from typing import List, Optional, Union
import torch
from torch import Tensor, nn
from .abstract_models import ClassicalModel
class SelectSingleOutput(ClassicalModel):
    """
    Select one or more outputs of a previous model, optionally followed by a
    linear layer.

    This simulates the standard technique of looking at a single qubit to
    determine the value of a function. The user may select either a single
    output or multiple outputs of the preceding model; the selection can then
    be passed through a single ``nn.Linear`` layer.

    The trailing linear layer is particularly useful when a single qubit is
    chosen: such an output typically takes only non-negative values (e.g. a
    probability), and the layer's bias allows the values to change sign.
    This matters because a subsequent SoftMax depends on the sign of its
    inputs.

    :param qubit_idx: Index or indices of the qubit(s) to select (columns of
        the input), defaults to the first qubit (``0``).
    :type qubit_idx: Union[List[int], int], optional
    :param use_linear_after: Whether to apply a linear layer after selecting
        the qubits, defaults to True.
        If ``use_linear_after`` is False, the output has two columns per
        sample. The first column is the sum of the selected entries, and the
        second is ``1`` minus that sum. When the inputs are probabilities,
        the second column equals the sum of the entries at all other indices.
    :type use_linear_after: bool, optional
    :param num_classes: Number of output features of the linear layer,
        defaults to the number of selected qubits. Ignored if
        ``use_linear_after`` is False.
    :type num_classes: Optional[int], optional
    """

    def __init__(
        self,
        qubit_idx: Union[List[int], int] = 0,
        use_linear_after: bool = True,
        num_classes: Optional[int] = None,
        *args,
        **kwargs,
    ):
        # NOTE(review): passing ClassicalModel to super() starts the MRO lookup
        # *after* ClassicalModel, so ClassicalModel.__init__ itself is skipped.
        # Preserved as-is; confirm this is intended.
        super(ClassicalModel, self).__init__()
        # Extra positional/keyword arguments are accepted but not used.

        # Normalise a single index to a one-element sequence, then freeze it.
        if isinstance(qubit_idx, int):
            qubit_idx = [qubit_idx]
        self.qubit_idx = tuple(qubit_idx)
        # Measure the frozen tuple, not the argument, so one-shot iterables
        # (already consumed by tuple()) are handled correctly.
        n_selected = len(self.qubit_idx)
        if num_classes is None:
            num_classes = n_selected

        self.use_linear_after = use_linear_after
        if use_linear_after:
            self.fc = nn.Linear(n_selected, num_classes)

    def forward(self, x: Tensor, *args) -> Tensor:
        """
        Select the configured columns of ``x`` and map them to the output.

        :param x: Input tensor whose dimension 1 is indexed by ``qubit_idx``.
            Assumed to be 2-D ``(batch, n_outputs)``; TODO confirm for callers.
        :type x: Tensor
        :param args: Extra positional arguments, accepted and ignored.
        :return: ``fc(x[:, qubit_idx])`` if ``use_linear_after`` is True.
            Otherwise, for 2-D input, a ``(batch, 2)`` tensor whose rows are
            ``[s, 1 - s]``, where ``s`` is the row-wise sum of the selected
            columns.
        :rtype: Tensor
        """
        # A list index triggers advanced indexing, so dim 1 is kept even when
        # a single qubit is selected: shape (batch, len(qubit_idx)).
        selected_x = x[:, list(self.qubit_idx)]
        if self.use_linear_after:
            return self.fc(selected_x)

        summed = torch.sum(selected_x, dim=1)
        # Build [s, 1 - s] out-of-place instead of writing into a transposed
        # view of a repeated tensor; values and gradients are identical.
        return torch.stack((summed, 1 - summed), dim=1)