Source code for squid.models.feed_forward

from typing import List, Union

from torch import Tensor, nn

from .abstract_models import ClassicalModel
from .classical_utils import str_to_activation


class FeedForward(ClassicalModel):
    """FeedForward is a general feed forward network wrapper.

    Any sequence of nn.Linear layers and non-linearities can be represented
    by it.

    :param layers: A list of either:

        * int - which specifies number of neurons in a given layer.
        * str - an activation function in a given layer.

        These are then converted to actual nn.Linear layers and activation
        modules (via ``str_to_activation``), which PyTorch can run through
        and optimize. Entries that are neither int nor str are ignored.
        ``None`` is treated the same as an empty list.
    :type layers: List[Union[int, str]]
    :param in_features: Number of incoming features to this network. This is
        needed to create the first layer of this model.
    :type in_features: int
    :param num_classes: Number of classes to classify. When given, one extra
        nn.Linear mapping to ``num_classes`` outputs is appended after all
        entries of ``layers``, so this is meant for the case where this model
        is the last one in a stack of models (e.g. as part of a larger
        model). If set to None (Default), no extra layer is added and the
        outgoing shape of this model is determined by ``layers`` alone.
    :type num_classes: int, optional
    """

    def __init__(
        self,
        layers: List[Union[int, str]],
        in_features: int,
        num_classes: Union[int, None] = None,
        *args,
        **kwargs
    ) -> None:
        super().__init__()

        # Convert the specification into modules. ``in_features`` is updated
        # as we go so that it always holds the output width of the most
        # recently created nn.Linear (or the original input width if none
        # has been created yet).
        _layers: List[nn.Module] = []
        if layers is not None:
            for layer in layers:
                if isinstance(layer, int):
                    _layers.append(nn.Linear(in_features, layer))
                    in_features = layer
                elif isinstance(layer, str):
                    _layers.append(str_to_activation(layer))

        # Last layer (if applicable). Because ``in_features`` already tracks
        # the width of the last integer entry, no second scan over ``layers``
        # is needed, and ``layers=None`` is handled without calling len() on
        # it.
        if num_classes is not None and in_features is not None:
            _layers.append(nn.Linear(in_features, num_classes))

        # Cast to Sequential
        self.layers = nn.Sequential(*_layers)

    def forward(self, x: Tensor, *args) -> Tensor:
        """Forward function as required by nn.Module.

        If the network holds no layers, the input is returned untouched.
        Otherwise the input is reshaped to ``(len(x), -1)``, cast to float
        and passed through the layers.

        :param x: Input
        :type x: Tensor
        :return: Output
        :rtype: Tensor
        """
        # Empty network: act as identity, without reshaping or casting.
        if len(self.layers) == 0:
            return x

        # Keep the first dimension and flatten everything else before the
        # linear layers.
        x = x.view(len(x), -1).float()
        return self.layers(x)