from typing import Any, Dict, Union
import torch
from torch import nn
from .abstract_models import ClassicalModel, QuantumModel
class MainModel(nn.Module):
    """Wrapper for a sequential model, where data flows
    ``model1 -> model2 -> model3``.

    It allows model 2 to be basically anything, under two constraints:
    given data it outputs data, and it can calculate the gradient with
    respect to its input. When model 2 is a :class:`QuantumModel` it is fed
    numpy arrays, and :meth:`backward` bridges its gradient back into
    PyTorch.

    :param model1: A torch ``nn.Module``. Network whose output shape is
        compatible with the input of model 2.
    :type model1: ClassicalModel
    :param model2: Either a torch ``nn.Module``, or any algorithm operating
        on numpy arrays.
    :type model2: Union[ClassicalModel, QuantumModel]
    :param model3: A torch ``nn.Module``. Network whose input shape is
        compatible with the output of model 2.
    :type model3: ClassicalModel
    :param config: Configuration specifying various aspects of the model,
        for example whether a given sub-model is trainable. Sections are
        read from the keys ``"model 1"``, ``"model 2"`` and ``"model 3"``;
        a section containing ``"fixed": True`` marks that sub-model as not
        trainable. Missing sections are treated as trainable.
        Defaults to None (all sub-models trainable).
    :type config: Dict[str, Any], optional
    """

    def __init__(
        self,
        model1: ClassicalModel,
        model2: Union[ClassicalModel, QuantumModel],
        model3: ClassicalModel,
        config: Union[Dict[str, Any], None] = None,
    ):
        """Constructor method."""
        super().__init__()
        self.model1 = model1
        self.model2 = model2
        self.model3 = model3
        self.config = config
        # Record which sub-models are trainable. These flags only decide what
        # ``parameters(trainable_only=True)`` returns; they do not change
        # ``requires_grad`` on any parameter.
        self.model1_train = not self._is_fixed(config, "model 1")
        # "fixed" only matters for model 2 when it is a torch module; a
        # non-Module model 2 contributes no torch parameters either way.
        self.model2_train = not (
            self._is_fixed(config, "model 2")
            and isinstance(self.model2, nn.Module)
        )
        self.model3_train = not self._is_fixed(config, "model 3")

    @staticmethod
    def _is_fixed(config, section: str) -> bool:
        """Return whether ``config[section]["fixed"]`` is truthy.

        A ``None`` config, a missing section, or a ``None`` section all
        count as "not fixed".
        """
        if config is None:
            return False
        section_cfg = config.get(section) or {}
        return bool(section_cfg.get("fixed", False))

    def parameters(self, trainable_only: bool = False):
        """
        Parameters function similar to `nn.Module`.

        :param trainable_only: Whether to only return trainable parameters,
            i.e. parameters of sub-models not marked ``fixed`` in the config
            and having ``requires_grad`` set. Defaults to False.
        :type trainable_only: bool, optional
        :return: List of parameters across all 3 models (model 2 only
            contributes when it is an ``nn.Module``).
        :rtype: List
        """
        result = []
        if not trainable_only or self.model1_train:
            result.extend(self.model1.parameters())
        if isinstance(self.model2, nn.Module) and (
            not trainable_only or self.model2_train
        ):
            result.extend(self.model2.parameters())
        if not trainable_only or self.model3_train:
            result.extend(self.model3.parameters())
        if trainable_only:
            result = [p for p in result if p.requires_grad]
        return result

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward function as required by nn.Module.

        It deals with the specific case of model 2 being quantum: the output
        of model 1 is clipped by ``model2.clip``, converted to numpy, passed
        through ``model2.forward``, and converted back to a float32 tensor on
        the input's device. Intermediate values are stored on ``self`` so
        that :meth:`backward` can use them.
        """
        self.input_1 = x
        if not isinstance(self.model2, QuantumModel):
            return self.model3(self.model2(self.model1(x)))

        device_x = x.device
        # Encoder
        x = self.model1(x)
        # Process through the quantum model's clipping function
        x = self.model2.clip(x)
        # Kept so backward() can push the quantum gradient into model 1.
        self.processed_input_2 = x
        # Torch -> Numpy (the quantum part works outside the autograd graph)
        angle_np = x.detach().cpu().numpy()
        # Middle, quantum part
        fwq = self.model2.forward(angle_np)
        # Numpy -> Torch. A fresh leaf tensor with requires_grad=True so that
        # loss.backward() populates its .grad for backward() to read.
        self.input_3 = torch.tensor(
            fwq, requires_grad=True, dtype=torch.float32, device=device_x
        )
        # Decoder
        self.pred = self.model3(self.input_3)
        return self.pred

    def backward(self) -> torch.Tensor:
        """
        Propagates gradient backward through the model.

        The typical PyTorch ``loss.backward()`` call still has to be made,
        *before* calling this method. This function does nothing if model 2
        is not quantum. If model 2 is quantum, the *fake loss* trick is used
        to push the quantum gradient into model 1.

        :return: Gradient with respect to the input (``None`` if the input
            did not require grad).
        :rtype: torch.Tensor
        :raises RuntimeError: If model 2 is quantum and the gradient with
            respect to model 3's input is unavailable (forward or
            ``loss.backward()`` was not run first).
        """
        if isinstance(self.model2, QuantumModel):
            input_3 = getattr(self, "input_3", None)
            if input_3 is None or input_3.grad is None:
                raise RuntimeError(
                    "MainModel.backward() needs the gradient w.r.t. model 3's "
                    "input: run forward() and loss.backward() first."
                )
            # Gradient with respect to the input of model 3, Torch -> Numpy
            device_x = input_3.device
            input_3_grad = input_3.grad.cpu().numpy()
            # Perform gradient propagation through the middle (quantum) part
            quantum_grad = self.model2.backward(input_3_grad)  # type: ignore
            # Gradient dL/d(input to model 2 / output of model 1), Numpy -> Torch
            output_1_grad = torch.tensor(
                quantum_grad,
                requires_grad=False,
                dtype=torch.float32,
                device=device_x,
            )
            # Fake loss: d(g * x)/dx = g for the outputs x of model 1, so
            # backpropagating this scalar hands g to PyTorch, which propagates
            # it through model 1 from there.
            fake_angle_loss = torch.sum(output_1_grad * self.processed_input_2)
            fake_angle_loss.backward()
        # Gradient with respect to the input (so these models can be nested)
        return self.input_1.grad

    def __call__(self, *args, **kwargs):
        # Calls forward directly (nn.Module hooks are not run).
        return self.forward(*args, **kwargs)

    def to(self, *args, **kwargs):
        """
        Copy of the `nn.Module` `to` function, adjusted for the fact that
        ``self.model2`` doesn't have to be a Module.

        Arguments are forwarded unchanged to each sub-model's ``to``, so a
        device, a dtype, or keyword forms such as ``device=...`` are all
        accepted.
        """
        self.model1 = self.model1.to(*args, **kwargs)
        if isinstance(self.model2, nn.Module):
            self.model2 = self.model2.to(*args, **kwargs)
        self.model3 = self.model3.to(*args, **kwargs)
        return self

    def num_parameters(self, trainable_only: bool = True):
        """
        Count parameter elements across the sub-models.

        Assumes no shared parameters, and gives an option to include all or
        only trainable parameters (the default here is trainable only).
        """
        return sum(p.numel() for p in self.parameters(trainable_only))