# Source code for squid.models.select_single_output

from typing import List, Optional, Union

import torch
from torch import Tensor, nn

from .abstract_models import ClassicalModel


class SelectSingleOutput(ClassicalModel):
    """Select one or more outputs (qubits) of a previous model.

    SelectSingleOutput simulates the standard technique of looking at a
    single qubit to determine the value of the function. This class lets the
    user select either a single output or multiple outputs of the previous
    model; the selection can then be passed through a single linear layer.

    The trailing linear layer is particularly useful when a single qubit is
    chosen, because such a value can only be positive, while the layer's
    bias allows the sign of the value to vary. This matters because a
    downstream SoftMax depends on the sign of the output.

    :param qubit_idx: Index or indices of the qubit(s) to select along
        dimension 1 of the input, defaults to the first qubit (``0``).
    :type qubit_idx: Union[List[int], int], optional
    :param use_linear_after: Whether to apply a linear layer after selecting
        the qubits, defaults to ``True``. If ``False``, the output has two
        columns: the first is the sum of the selected entries and the second
        is ``1 -`` that sum (if the entries are probabilities, this
        corresponds to the sum of the entries at the other indices).
    :type use_linear_after: bool, optional
    :param num_classes: Ignored if ``use_linear_after`` is ``False``. Number
        of classes to classify (output size of the linear layer), defaults
        to the number of chosen qubits.
    :type num_classes: Optional[int], optional
    :raises ValueError: If ``qubit_idx`` is an empty sequence.
    """

    def __init__(
        self,
        qubit_idx: Union[List[int], int] = 0,
        use_linear_after: bool = True,
        num_classes: Optional[int] = None,
        *args,
        **kwargs,
    ):
        # NOTE(review): this skips ClassicalModel.__init__ and calls the next
        # initializer in the MRO. Kept unchanged; confirm against
        # abstract_models whether that is intentional.
        super(ClassicalModel, self).__init__()
        # Extra positional/keyword arguments are accepted and ignored.
        if isinstance(qubit_idx, int):
            _qubit_idx = [qubit_idx]
        else:
            # list() also accepts one-shot iterables, which len() would reject.
            _qubit_idx = list(qubit_idx)
        if not _qubit_idx:
            raise ValueError("qubit_idx must select at least one qubit.")
        # Stored as a tuple; forward() uses it to index dimension 1 of x.
        self.qubit_idx = tuple(_qubit_idx)
        if num_classes is None:
            num_classes = len(self.qubit_idx)
        self.use_linear_after = use_linear_after
        if use_linear_after:
            self.fc = nn.Linear(len(self.qubit_idx), num_classes)

    def forward(self, x: Tensor, *args) -> Tensor:
        """Select the configured qubits from ``x`` and map them to outputs.

        :param x: Output of the previous model. Qubits are selected along
            dimension 1 (presumably shaped ``(batch, n_qubits)``; TODO confirm).
        :type x: Tensor
        :return: ``fc(x[:, qubit_idx])`` when ``use_linear_after`` is set.
            Otherwise the tensor ``[s, 1 - s]`` stacked along dimension 1,
            where ``s`` is the sum of the selected entries. For a 2-D input
            this has shape ``(batch, 2)``.
        :rtype: Tensor
        """
        # Extra positional arguments are accepted and ignored.
        selected_x = x[:, self.qubit_idx]
        if self.use_linear_after:
            return self.fc(selected_x)
        summed = torch.sum(selected_x, dim=1)
        # Build both columns out-of-place instead of repeating, transposing
        # and then writing into a (non-contiguous) view.
        return torch.stack((summed, 1 - summed), dim=1)