Spaces:
Paused
Paused
| import torch | |
| import torch.nn as nn | |
class RandomPredictor(nn.Module):
    """Predictor that ignores input features and outputs a random one-hot class per sample."""

    def __init__(self, num_classes):
        """
        Parameters
        ----------
        num_classes: int
            number of classes; each prediction is drawn uniformly from
            ``range(num_classes)``
        """
        super().__init__()
        self.num_classes = num_classes

    def forward(self, inputs):
        """
        Draw one random class per sample and return it as a one-hot vector.

        Parameters
        ----------
        inputs: int or dict
            batch_size, or a dict of input features whose "curr_node_id"
            entry is a tensor with the batch along dimension 0

        Returns
        -------
        probs: torch.tensor [batch_size x num_classes]
            float32 rows containing exactly one 1.0 (the sampled class)
        """
        batch_size = inputs if isinstance(inputs, int) else inputs["curr_node_id"].size(0)
        # Exactly one index per row. Sampling num_classes indices per row
        # scattered several 1.0s into most rows, so rows were not one-hot.
        random_index = torch.randint(self.num_classes, (batch_size, 1))
        probs = torch.zeros(batch_size, self.num_classes, dtype=torch.float)
        probs.scatter_(-1, random_index, 1.0)
        return probs

    def get_inputs(self, tour, first_explained_step, node_feats):
        """
        Build the ``inputs`` argument for ``forward``.

        Only a batch size is needed: the number of elements in
        ``tour[first_explained_step:-1]``. ``node_feats`` is ignored.
        """
        return len(tour[first_explained_step:-1])
class FixedClassPredictor(nn.Module):
    """Predictor that ignores input features and always outputs the same one-hot class."""

    def __init__(self, predicted_class, num_classes):
        """
        Parameters
        ----------
        predicted_class: int
            a class that this predictor always predicts; must satisfy
            ``0 <= predicted_class < num_classes``
        num_classes: int
            number of classes

        Raises
        ------
        AssertionError
            if predicted_class is outside ``[0, num_classes - 1]``
        """
        super().__init__()
        self.predicted_class = predicted_class
        self.num_classes = num_classes
        # The lower bound matters too: a negative class used to pass this check
        # and was only rejected later, inside forward's scatter_.
        assert 0 <= predicted_class < num_classes, f"predicted_class should be 0 - {num_classes - 1}."

    def forward(self, inputs):
        """
        Return the fixed class as a one-hot vector for every sample.

        Parameters
        ----------
        inputs: int or dict
            batch_size, or a dict of input features whose "curr_node_id"
            entry is a tensor with the batch along dimension 0

        Returns
        -------
        probs: torch.tensor [batch_size x num_classes]
            float32 rows with 1.0 at predicted_class and 0.0 elsewhere
        """
        batch_size = inputs if isinstance(inputs, int) else inputs["curr_node_id"].size(0)
        # One index column suffices for a one-hot scatter. The explicit long
        # dtype keeps scatter_ valid without relying on torch.full's dtype
        # inference from the fill value.
        index = torch.full((batch_size, 1), self.predicted_class, dtype=torch.long)
        probs = torch.zeros(batch_size, self.num_classes, dtype=torch.float)
        probs.scatter_(-1, index, 1.0)
        return probs

    def get_inputs(self, tour, first_explained_step, node_feats):
        """
        Build the ``inputs`` argument for ``forward``.

        Only a batch size is needed: the number of elements in
        ``tour[first_explained_step:-1]``. ``node_feats`` is ignored.
        """
        return len(tour[first_explained_step:-1])