Spaces:
Paused
Paused
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from models.classifiers.ground_truth.ground_truth_tsptw import GroundTruthTSPTW | |
| from models.classifiers.ground_truth.ground_truth_pctsp import GroundTruthPCTSP | |
| from models.classifiers.ground_truth.ground_truth_pctsptw import GroundTruthPCTSPTW | |
| from models.classifiers.ground_truth.ground_truth_cvrp import GroundTruthCVRP | |
| from models.classifiers.ground_truth.ground_truth_cvrptw import GroundTruthCVRPTW | |
class GroundTruth(nn.Module):
    """Select a problem-specific ground-truth implementation and delegate to it.

    The concrete implementation is chosen from ``problem`` once, at
    construction time, and stored in ``self.ground_truth``. Every public
    method below simply forwards its arguments to that object, so callers
    can use one class regardless of the problem variant.

    Args:
        problem: Problem name; one of ``"tsptw"``, ``"pctsp"``,
            ``"pctsptw"``, ``"cvrp"`` or ``"cvrptw"``.
        solver_type: Passed unchanged to the selected implementation's
            constructor; its meaning is defined by that implementation.

    Raises:
        NotImplementedError: If ``problem`` is not one of the supported names.
    """

    def __init__(self, problem, solver_type):
        super().__init__()
        self.problem = problem
        self.solver_type = solver_type
        if problem == "tsptw":
            self.ground_truth = GroundTruthTSPTW(solver_type)
        elif problem == "pctsp":
            self.ground_truth = GroundTruthPCTSP(solver_type)
        elif problem == "pctsptw":
            self.ground_truth = GroundTruthPCTSPTW(solver_type)
        elif problem == "cvrp":
            self.ground_truth = GroundTruthCVRP(solver_type)
        elif problem == "cvrptw":
            self.ground_truth = GroundTruthCVRPTW(solver_type)
        else:
            # Same exception type as before (callers may catch it), but now
            # say which value was rejected and what is accepted.
            raise NotImplementedError(
                f"Unsupported problem {problem!r}; expected one of "
                "'tsptw', 'pctsp', 'pctsptw', 'cvrp', 'cvrptw'."
            )

    def forward(self, inputs, annotation=False, parallel=False):
        """Call the selected implementation with ``inputs`` and both flags.

        ``annotation`` and ``parallel`` are forwarded positionally, unchanged;
        their interpretation is up to the underlying implementation.
        """
        return self.ground_truth(inputs, annotation, parallel)

    def get_inputs(self, tour, first_explained_step, node_feats, dist_matrix=None):
        """Delegate to the selected implementation's ``get_inputs``."""
        return self.ground_truth.get_inputs(
            tour, first_explained_step, node_feats, dist_matrix
        )

    def solve(self, step, input_tour, node_feats, instance_name=None):
        """Delegate to the selected implementation's ``solve``."""
        return self.ground_truth.solve(step, input_tour, node_feats, instance_name)