Source code for squid.models.main_model

from typing import Any, Dict, Union

import torch
from torch import nn

from .abstract_models import ClassicalModel, QuantumModel


class MainModel(nn.Module):
    """Wrapper for a sequential model, where data flows: `model1 -> model2 -> model3`.

    It allows model2 to be basically anything (under some constraints: given
    data it outputs data, and it calculates the gradient with respect to its
    input).

    :param model1: A torch nn.Module. Network whose output shape is compatible
        with the input of model 2.
    :type model1: ClassicalModel
    :param model2: Either a torch nn.Module, or any algorithm operating on
        numpy arrays.
    :type model2: Union[ClassicalModel, QuantumModel]
    :param model3: A torch nn.Module. Network whose input shape is compatible
        with the output of model 2.
    :type model3: ClassicalModel
    :param config: Configuration specifying various aspects of a model, for
        example whether a certain model is trainable or not. The sections
        read here are ``"model 1"``, ``"model 2"`` and ``"model 3"``, each
        with an optional ``"fixed"`` flag. Defaults to None.
    :type config: Dict[str, Any], optional
    """

    def __init__(
        self,
        model1: ClassicalModel,
        model2: Union[ClassicalModel, QuantumModel],
        model3: ClassicalModel,
        config: Union[Dict[str, Any], None] = None,
    ):
        """Constructor method.

        Stores the three sub-models and derives the ``model{1,2,3}_train``
        flags from ``config``. A missing config, a missing section, or a
        missing ``"fixed"`` key all mean "trainable".
        """
        super().__init__()
        self.model1 = model1
        self.model2 = model2
        self.model3 = model3
        self.config = config

        settings = self.config if self.config is not None else {}

        def is_fixed(section: str) -> bool:
            # A section that is absent (or empty) is treated as not fixed,
            # mirroring the default already used for the "fixed" key itself.
            return bool((settings.get(section) or {}).get("fixed", False))

        # Make the models not trainable if specified in the config.
        self.model1_train = not is_fixed("model 1")
        # Model 2 only counts as fixed when it is a torch module; any other
        # kind of model 2 keeps its flag set to True.
        self.model2_train = not (
            is_fixed("model 2") and isinstance(self.model2, nn.Module)
        )
        self.model3_train = not is_fixed("model 3")

    def parameters(self, trainable_only: bool = False):
        """Parameters function similar to `nn.Module`.

        :param trainable_only: Whether to only return trainable parameters,
            defaults to False
        :type trainable_only: bool, optional
        :return: List of parameters across all 3 models. Model 2 contributes
            only when it is an ``nn.Module``.
        :rtype: List
        """
        result = []
        if not trainable_only or self.model1_train:
            result.extend(self.model1.parameters())
        if (not trainable_only or self.model2_train) and isinstance(
            self.model2, nn.Module
        ):
            result.extend(self.model2.parameters())
        if not trainable_only or self.model3_train:
            result.extend(self.model3.parameters())

        if trainable_only:
            # On top of the per-model flags, honour each tensor's own setting.
            result = [p for p in result if p.requires_grad]
        return result

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function as required by nn.Module.

        It deals with the specific case of model2 being quantum: the output
        of model 1 is clipped by ``model2.clip``, converted to a numpy array,
        passed through ``model2.forward`` and converted back into a fresh
        leaf tensor that feeds model 3.

        :param x: Input to model 1.
        :type x: torch.Tensor
        :return: Output of model 3.
        :rtype: torch.Tensor
        """
        # Kept so that `backward` can report the gradient w.r.t. the input.
        self.input_1 = x

        if not isinstance(self.model2, QuantumModel):
            x = self.model1(x)
            x = self.model2(x)
            return self.model3(x)

        device_x = x.device

        # Encoder
        x = self.model1(x)

        # Process through the quantum model's clipping function
        x = self.model2.clip(x)
        # Kept (still attached to the autograd graph) for the fake-loss
        # step performed in `backward`.
        self.processed_input_2 = x

        # Torch -> Numpy
        angle_np = x.cpu().detach().numpy()

        # Middle, quantum part
        fwq = self.model2.forward(angle_np)

        # Numpy -> Torch. This is a new leaf tensor, so autograd stops here
        # and `self.input_3.grad` is populated by `loss.backward()`.
        self.input_3 = torch.tensor(
            fwq, requires_grad=True, dtype=torch.float32, device=device_x
        )

        # Decoder
        self.pred = self.model3(self.input_3)
        return self.pred

    def backward(self) -> Union[torch.Tensor, None]:
        """Propagates gradient backward through the model.

        Note that the typical PyTorch `loss.backward()` call still has to be
        made, and it has to be made before this method so that the gradient
        at the input of model 3 exists. This function propagates nothing if
        all models are classical. If model 2 is quantum then the *fake loss*
        trick is used.

        :raises RuntimeError: If model 2 is quantum and no gradient is
            available at the input of model 3 (no forward pass, or
            `loss.backward()` has not been called yet).
        :return: Gradient with respect to the input, i.e. the ``.grad`` of the
            tensor last passed to `forward` (``None`` when there is none).
        :rtype: torch.Tensor
        """
        if isinstance(self.model2, QuantumModel):
            input_3 = getattr(self, "input_3", None)
            if input_3 is None or input_3.grad is None:
                raise RuntimeError(
                    "MainModel.backward() needs the gradient at the input of "
                    "model 3; run forward() and loss.backward() first."
                )

            # Get gradient with respect to input to model 3
            # Torch -> Numpy
            device_x = input_3.device
            input_3_grad = input_3.grad.cpu().numpy()

            # Perform gradient propagation through the middle (quantum) part
            quantum_grad = self.model2.backward(input_3_grad)  # type: ignore

            # Convert the gradient returned from the quantum part, i.e.
            # dL/d(input to model 2 / output of model 1)
            # Numpy -> Torch
            output_1_grad = torch.tensor(
                quantum_grad,
                requires_grad=False,
                dtype=torch.float32,
                device=device_x,
            )

            # Make a fake loss, which allows us to pass gradients in PyTorch.
            # This is because d(gx)/dx = g, for outputs of model 1,
            # and everything is propagated from there by PyTorch.
            fake_angle_loss = torch.sum(output_1_grad * self.processed_input_2)

            # Update encoder gradients
            fake_angle_loss.backward()

        # Return gradient with respect to input (so these can be nested).
        input_1 = getattr(self, "input_1", None)
        return None if input_1 is None else input_1.grad

    def __call__(self, *args, **kwargs):
        """Call `forward` directly, passing positional and keyword arguments.

        NOTE: because this overrides ``nn.Module.__call__``, hooks registered
        on this wrapper module are not run.
        """
        return self.forward(*args, **kwargs)

    def to(self, device):
        """Copy of `nn.Module` `to` function but adjusting for the fact that
        self.model2 doesn't have to be a Module.

        :param device: Target passed on to each sub-module's ``to``.
        :return: This wrapper, to allow chaining.
        """
        self.model1 = self.model1.to(device)
        if isinstance(self.model2, nn.Module):
            self.model2 = self.model2.to(device)
        self.model3 = self.model3.to(device)
        return self

    def num_parameters(self, trainable_only: bool = True):
        """Count scalar parameters across the models.

        Assumes no shared parameters, and gives an option to include all or
        only trainable parameters.

        :param trainable_only: Forwarded to `parameters`, defaults to True
        :type trainable_only: bool, optional
        :return: Sum of ``numel()`` over the selected parameters.
        :rtype: int
        """
        return sum(p.numel() for p in self.parameters(trainable_only))