Spaces:
Paused
Paused
| import torch.nn as nn | |
| from models.solvers.ortools.ortools_tsp import ORToolsTSP | |
| from models.solvers.ortools.ortools_tsptw import ORToolsTSPTW | |
| from models.solvers.ortools.ortools_pctsp import ORToolsPCTSP | |
| from models.solvers.ortools.ortools_pctsptw import ORToolsPCTSPTW | |
| from models.solvers.ortools.ortools_cvrp import ORToolsCVRP | |
| from models.solvers.ortools.ortools_cvrptw import ORToolsCVRPTW | |
class ORTools(nn.Module):
    """Facade that selects and wraps a problem-specific OR-Tools solver.

    The concrete solver is chosen once, at construction time, from the
    ``problem`` name; ``solve`` then delegates to it.
    """

    # Problem names accepted by ``get_ortools`` (used for the error message).
    SUPPORTED_PROBLEMS = ("tsp", "tsptw", "pctsp", "pctsptw", "cvrp", "cvrptw")

    def __init__(self, problem, large_value=1e+6, scaling=False):
        """
        Parameters
        ----------
        problem: str
            problem type; one of ``SUPPORTED_PROBLEMS``
        large_value: float
            forwarded unchanged to the problem-specific solver
            (presumably the multiplier applied when ``scaling`` is on;
            TODO confirm against the solver implementations)
        scaling: bool
            whether or not coords are multiplied by a large value
            to convert float-coords into int-coords

        Raises
        ------
        NotImplementedError
            if ``problem`` is not a supported problem type
        """
        super().__init__()
        self.coord_dim = 2
        self.problem = problem
        self.large_value = large_value
        self.scaling = scaling
        # Must come after large_value/scaling are set: get_ortools reads them.
        self.ortools = self.get_ortools(problem)

    def get_ortools(self, problem):
        """
        Parameters
        ----------
        problem: str
            problem type

        Returns
        -------
        ortools: ortools for the specified problem

        Raises
        ------
        NotImplementedError
            if ``problem`` is not a supported problem type
        """
        if problem == "tsp":
            return ORToolsTSP(self.large_value, self.scaling)
        if problem == "tsptw":
            return ORToolsTSPTW(self.large_value, self.scaling)
        if problem == "pctsp":
            return ORToolsPCTSP(self.large_value, self.scaling)
        if problem == "pctsptw":
            return ORToolsPCTSPTW(self.large_value, self.scaling)
        if problem == "cvrp":
            return ORToolsCVRP(self.large_value, self.scaling)
        if problem == "cvrptw":
            return ORToolsCVRPTW(self.large_value, self.scaling)
        # Name the rejected value so a typo in a config is easy to spot.
        raise NotImplementedError(
            f"ORTools solver is not implemented for problem {problem!r}; "
            f"supported problems: {', '.join(self.SUPPORTED_PROBLEMS)}"
        )

    def solve(self, node_feats, fixed_paths=None, dist_martix=None, instance_name=None):
        """
        Solve one instance with the solver selected at construction time.

        All arguments are forwarded positionally, in this order, to the
        wrapped solver's ``solve`` method.

        Parameters
        ----------
        node_feats: np.array [num_nodes x node_dim]
        fixed_paths: np.array [cf_step]
        dist_martix: optional
            forwarded unchanged to the wrapped solver (presumably a
            precomputed distance matrix; TODO confirm). The misspelled
            name is kept because callers may pass it by keyword.
        instance_name: optional
            forwarded unchanged to the wrapped solver (presumably an
            identifier of the instance; TODO confirm)

        Returns
        -------
        tour: np.array [seq_length]
        """
        return self.ortools.solve(node_feats, fixed_paths, dist_martix, instance_name)