from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from torch import nn
from .abstract_models import ClassicalModel
from .classical_utils import str_to_activation
class ConvFeedForward(ClassicalModel):
    """
    ConvFeedForward is a general convolutional and feed forward network.

    Any (valid) sequence of convolution / transposed convolution / max-pool
    layers, ``nn.Linear`` layers and activation functions can be represented
    by it. The convolutional part runs first, its output is flattened per
    example, and the result is fed through the linear part.

    :param in_shape: Shape of incoming data (one example, without batch
        dimension, channels first). A two-element shape is treated as a
        single-channel input, i.e. ``[1] + in_shape``.
        This is needed to create the first layer of this model.
    :type in_shape: List[int]
    :param conv_layers: A list of either:

        * `Tuple` - A tuple description of a layer, where the first element
          specifies the type of layer (either C1-3, CT1-3 or M1-3, for
          convolution, convolution transpose and maxpool respectively,
          through 1 to 3 dimensions). The remaining elements are passed
          positionally to the layer's constructor. For convolutions the
          number of input channels is inferred from the previous layer and
          must NOT be part of the tuple.
          For example ``("C2", 6, 3)`` is a `Conv2d` layer
          with 6 out channels and a kernel size of 3, and ``("M2", 2)`` is a
          `MaxPool2d` layer with a kernel size of 2.
        * str - an activation function in a given layer.

        These are then converted to actual nn classes and functions,
        which PyTorch can run through and optimize.
        Entries of any other type are skipped (with a warning).
    :type conv_layers: List[Union[Tuple, str]]
    :param layers: A list of either:

        * int - which specifies number of neurons in a given layer.
        * str - an activation function in a given layer.

        These are then converted to actual nn.Linear and nn functions,
        which PyTorch can run through and optimize.
        Entries of any other type are skipped (with a warning).
    :type layers: List[Union[int, str]]
    :param num_classes: Number of classes to classify.
        If given, a final ``nn.Linear`` layer with ``num_classes`` outputs is
        appended. Note that this requires this model to be the last one in
        the stack of models, if this model is part of a larger one
        (like MainModel).
        If set to None (default), no extra layer is appended and the
        outgoing shape of this model is determined by the last entry
        of ``layers``.
    :type num_classes: int, optional
    """

    # Layer codes whose constructor takes (in_channels, out_channels, ...).
    _CHANNEL_LAYERS = {
        "C1": nn.Conv1d,
        "C2": nn.Conv2d,
        "C3": nn.Conv3d,
        "CT1": nn.ConvTranspose1d,
        "CT2": nn.ConvTranspose2d,
        "CT3": nn.ConvTranspose3d,
    }
    # Layer codes that leave the number of channels untouched.
    _POOL_LAYERS = {
        "M1": nn.MaxPool1d,
        "M2": nn.MaxPool2d,
        "M3": nn.MaxPool3d,
    }

    def __init__(
        self,
        in_shape: List[int],
        conv_layers: List[Union[Tuple, str]],
        layers: List[Union[int, str]],
        num_classes: Optional[int] = None,
        *args,
        **kwargs
    ) -> None:
        import warnings

        super(ConvFeedForward, self).__init__()

        # Figure out number of channels: a 2-element shape has no explicit
        # channel dimension, so a single channel is assumed.
        if len(in_shape) == 2:
            in_shape = [1] + list(in_shape)
        in_channels = in_shape[0]

        # Convert conv_layers descriptions to actual nn modules, tracking
        # the channel count so each convolution gets the right in_channels.
        _conv_layers: List[nn.Module] = []
        for layer in conv_layers or []:
            if isinstance(layer, str):
                _conv_layers.append(str_to_activation(layer))
            elif isinstance(layer, tuple):
                module, in_channels = ConvFeedForward._tuple_to_conv(
                    in_channels, layer
                )
                _conv_layers.append(module)
            else:
                # Historically such entries were dropped silently; keep
                # skipping them (so existing models do not change) but make
                # the omission visible.
                warnings.warn(
                    "ConvFeedForward: ignoring conv_layers entry {!r} of "
                    "unsupported type {}".format(layer, type(layer).__name__),
                    stacklevel=2,
                )
        self.conv_layers = nn.Sequential(*_conv_layers)

        # Figure out the flattened feature count coming out of the conv part
        # by pushing a dummy batch through it. No gradients are needed for a
        # shape probe. torch.randn is kept (rather than zeros) so the RNG
        # state seen by subsequent layer initialisation is unchanged.
        with torch.no_grad():
            probe = self.conv_layers(torch.randn([2] + list(in_shape)))
        # int(): np.prod returns a NumPy scalar (a float for an empty shape),
        # while nn.Linear expects a plain Python int.
        in_features = int(np.prod(probe.shape[1:]))

        # Convert layers descriptions to nn.Linear's and activations.
        _layers: List[nn.Module] = []
        for _layer in layers or []:
            if isinstance(_layer, int):
                _layers.append(nn.Linear(in_features, _layer))
                in_features = _layer
            elif isinstance(_layer, str):
                _layers.append(str_to_activation(_layer))
            else:
                warnings.warn(
                    "ConvFeedForward: ignoring layers entry {!r} of "
                    "unsupported type {}".format(_layer, type(_layer).__name__),
                    stacklevel=2,
                )

        # Last Layer (if applicable), regardless of whether any hidden
        # layers were requested.
        if num_classes is not None:
            _layers.append(nn.Linear(in_features, num_classes))

        # Cast to Sequential (an empty Sequential acts as identity).
        self.layers: Optional[nn.Module] = nn.Sequential(*_layers)

    def forward(self, x: torch.Tensor, *args) -> torch.Tensor:
        """
        Forward function as required by nn.Module.

        Runs the convolutional part, flattens everything but the first
        (batch) dimension, then runs the linear part.

        :param x: Input
        :type x: Tensor
        :return: Output
        :rtype: Tensor
        """
        x = self.conv_layers(x)
        if x.dim() > 2:
            # reshape (unlike view) also works for non-contiguous tensors.
            x = x.reshape(x.shape[0], -1)
        if self.layers is not None:
            x = self.layers(x)
        return x

    @staticmethod
    def _tuple_to_conv(
        prev_channels: int, tup: Tuple[Any, ...]
    ) -> Tuple[nn.Module, int]:
        """
        Converts tuple to actual nn.Conv, nn.ConvTranspose or nn.MaxPool layer.

        :param prev_channels: Channels incoming to the layers
        :type prev_channels: int
        :param tup: Tuple defining layer
            (first element defining type of layer, remaining elements
            defining positional arguments to that layer; for convolutions
            the first of those is the number of output channels).
        :type tup: Tuple[str, Any]
        :return: Tuple, with first element being PyTorch layer,
            and second element being number of channels after this layer is applied.
        :rtype: Tuple[nn.Module, int]
        :raises ValueError: If the tuple is empty or its first element is not
            one of the supported layer codes.
        """
        if not tup:
            raise ValueError("Layer description tuple must not be empty")

        kind = str(tup[0])
        args_to_layer = tup[1:]

        layer: nn.Module
        if kind in ConvFeedForward._CHANNEL_LAYERS:
            # (Transposed) convolutions: in_channels is supplied here,
            # out_channels is the first user-provided argument.
            layer = ConvFeedForward._CHANNEL_LAYERS[kind](
                prev_channels, *args_to_layer
            )
            new_channels = args_to_layer[0]
        elif kind in ConvFeedForward._POOL_LAYERS:
            # Pooling does not change the number of channels.
            layer = ConvFeedForward._POOL_LAYERS[kind](*args_to_layer)
            new_channels = prev_channels
        else:
            supported = sorted(
                list(ConvFeedForward._CHANNEL_LAYERS)
                + list(ConvFeedForward._POOL_LAYERS)
            )
            raise ValueError(
                "Unknown layer type {!r}; expected one of {}".format(
                    tup[0], ", ".join(supported)
                )
            )
        return layer, new_channels