Spaces:
Paused
Paused
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import os | |
| import multiprocessing | |
| from models.solvers.general_solver import GeneralSolver | |
| from utils.utils import calc_tour_length | |
| def get_visited_mask(tour, step, node_feats, dist_matrix=None): | |
| """ | |
| Visited nodes -> feasible, Unvisited nodes -> infeasible. | |
| When solving a problem with visited_paths fixed, they should be included to the solution. | |
| Therefore, visited nodes are set to feasible nodes. | |
| """ | |
| if dist_matrix is not None: | |
| num_nodes = len(dist_matrix) | |
| else: | |
| num_nodes = len(node_feats["coords"]) | |
| visited = np.isin(np.arange(num_nodes), tour[:step]) | |
| return visited | |
| def get_tw_mask(tour, step, node_feats, dist_matrix=None): | |
| """ | |
| Nodes whose tw exceeds current_time -> infeasible, otherwise -> feasible. | |
| Parameters | |
| ---------- | |
| tour: list [seq_length] | |
| step: int | |
| node_feats: dict of np.array | |
| Returns | |
| ------- | |
| mask_tw: np.array [num_nodes] | |
| """ | |
| node_feats = node_feats.copy() | |
| time_window = node_feats["time_window"] | |
| if dist_matrix is not None: | |
| num_nodes = len(dist_matrix) | |
| curr_time = 0.0 | |
| not_exceed_tw = np.ones(num_nodes).astype(np.int32) | |
| for i in range(1, step): | |
| prev_id = tour[i - 1] | |
| curr_id = tour[i] | |
| travel_time = dist_matrix[prev_id, curr_id] | |
| # assert curr_time + travel_time < time_window[curr_id, 1], f"Invalid tour! arrival_time: {curr_time + travel_time}, time_window: {time_window[curr_id]}" | |
| if curr_time + travel_time < time_window[curr_id, 0]: | |
| curr_time = time_window[curr_id, 0].copy() | |
| else: | |
| curr_time += travel_time | |
| curr_time = curr_time + dist_matrix[tour[step-1]] # [num_nodes] TODO: check | |
| else: | |
| coords = node_feats["coords"] | |
| num_nodes = len(coords) | |
| curr_time = 0.0 | |
| not_exceed_tw = np.ones(num_nodes).astype(np.int32) | |
| for i in range(1, step): | |
| prev_id = tour[i - 1] | |
| curr_id = tour[i] | |
| travel_time = np.linalg.norm(coords[prev_id] - coords[curr_id]) | |
| # assert curr_time + travel_time < time_window[curr_id, 1], f"Invalid tour! arrival_time: {curr_time + travel_time}, time_window: {time_window[curr_id]}" | |
| if curr_time + travel_time < time_window[curr_id, 0]: | |
| curr_time = time_window[curr_id, 0].copy() | |
| else: | |
| curr_time += travel_time | |
| curr_time = curr_time + np.linalg.norm(coords[tour[step-1]][None, :] - coords, axis=-1) # [num_nodes] TODO: check | |
| not_exceed_tw[curr_time > time_window[:, 1]] = 0 | |
| not_exceed_tw = not_exceed_tw > 0 | |
| return not_exceed_tw | |
| def get_cap_mask(tour, step, node_feats): | |
| num_nodes = len(node_feats["coords"]) | |
| demands = node_feats["demand"] | |
| remaining_cap = node_feats["capacity"].copy() | |
| less_than_cap = np.ones(num_nodes).astype(np.int32) | |
| for i in range(step): | |
| remaining_cap -= demands[tour[i]] | |
| less_than_cap[remaining_cap < demands] = 0 | |
| less_than_cap = less_than_cap > 0 | |
| return less_than_cap | |
| def get_pc_mask(tour, step, node_feats): | |
| """ | |
| Mask for Price collecting problems (e.g., PCTSP, PCTSPTW, PCCVRP, PCCVRPTW, ...) | |
| Returns | |
| ------- | |
| not_exceed_max_length | |
| """ | |
| large_value = 1e+5 | |
| coords = node_feats["coords"] | |
| max_length = (node_feats["max_length"] * large_value).astype(np.int64) | |
| tour_length = 0 | |
| for i in range(1, step): | |
| prev_id = tour[i - 1] | |
| curr_id = tour[i] | |
| tour_length += (np.linalg.norm(coords[prev_id] - coords[curr_id]) * large_value).astype(np.int64) | |
| curr_to_next = (np.linalg.norm(coords[tour[step-1]][None, :] - coords, axis=-1) * large_value).astype(np.int64) # [num_nodes] | |
| next_to_depot = (np.linalg.norm(coords[tour[0]][None, :] - coords, axis=-1) * large_value).astype(np.int64) # [num_nodes] | |
| not_exceed_max_length = (tour_length + curr_to_next + next_to_depot) <= max_length # [num_nodes] | |
| return not_exceed_max_length | |
| def analyze_tour(tour, node_feats): | |
| coords = node_feats["coords"] | |
| time_window = node_feats["time_window"] | |
| curr_time = 0 | |
| for i in range(1, len(tour)): | |
| prev_id = tour[i - 1] | |
| curr_id = tour[i] | |
| travel_time = np.linalg.norm(coords[prev_id] - coords[curr_id]) | |
| valid = curr_time + travel_time < time_window[curr_id, 1] | |
| print(f"visit #{i}: {prev_id} -> {curr_id}, travel_time: {travel_time}, arrival_time: {curr_time + travel_time}, time_window: {time_window[curr_id]}, valid: {valid}") | |
| if curr_time + travel_time < time_window[curr_id, 0]: | |
| curr_time = time_window[curr_id, 0] | |
| else: | |
| curr_time += travel_time | |
| FAIL_FLAG = -1 | |
| class GroundTruthBase(nn.Module): | |
| def __init__(self, problem, compared_problems, solver_type): | |
| """ | |
| Parameters | |
| ---------- | |
| """ | |
| super().__init__() | |
| self.problem = problem | |
| self.compared_problems = compared_problems | |
| self.num_compared_problems = len(compared_problems) | |
| self.solver_type = solver_type | |
| self.solvers = [] | |
| for i in range(self.num_compared_problems): | |
| # TODO: | |
| self.solvers.append(GeneralSolver(self.compared_problems[i], self.solver_type, scaling=False)) | |
| def forward(self, inputs, annotation=False, parallel=True): | |
| """ | |
| Parameters | |
| ---------- | |
| inputs: dict | |
| tour: 2d list [num_vehicles x seq_length] | |
| first_explained_step: int | |
| node_feats: dict of np.array | |
| annotation: bool | |
| please set it True when annotating data | |
| Returns | |
| ------- | |
| labels: | |
| probs: torch.tensor [batch_size (num_vehicles) x max_seq_length x num_classes] | |
| """ | |
| input_tours = inputs["tour"] | |
| node_feats = inputs["node_feats"] | |
| dist_matrix = inputs["dist_matrix"] | |
| first_explained_step = inputs["first_explained_step"] | |
| num_vehicles = len(input_tours) | |
| if annotation: | |
| labels = [[] for _ in range(num_vehicles)] | |
| for vehicle_id in range(num_vehicles): | |
| input_tour = input_tours[vehicle_id] | |
| # analyze_tour(input_tour, node_feats) | |
| for step in range(first_explained_step + 1, len(input_tour)): | |
| _, __, label = self.label_path(vehicle_id, step, input_tour, node_feats) | |
| if label == FAIL_FLAG: | |
| return | |
| labels[vehicle_id].append((step, label)) | |
| return labels | |
| else: | |
| if parallel: | |
| labels = [[-1] * (len(range(first_explained_step+1, len(input_tours[vehicle_id])))) for vehicle_id in range(num_vehicles)] | |
| num_cpus = os.cpu_count() | |
| with multiprocessing.Pool(num_cpus) as pool: | |
| for vehicle_id, step, label in pool.starmap(self.label_path, [(vehicle_id, step, input_tours[vehicle_id], node_feats, dist_matrix) | |
| for vehicle_id in range(num_vehicles) | |
| for step in range(first_explained_step+1, len(input_tours[vehicle_id]))]): | |
| labels[vehicle_id][step-(first_explained_step+1)] = label | |
| else: | |
| labels = [[-1] * (len(range(first_explained_step+1, len(input_tours[vehicle_id])))) for vehicle_id in range(num_vehicles)] | |
| for vehicle_id in range(num_vehicles): | |
| for step in range(first_explained_step+1, len(input_tours[vehicle_id])): | |
| vehicle_id, step, label = self.label_path(vehicle_id, step, input_tours[vehicle_id], node_feats, dist_matrix) | |
| labels[vehicle_id][step-(first_explained_step+1)] = label | |
| # validate labels | |
| for vehicle_id in range(num_vehicles): | |
| assert (len(input_tours[vehicle_id]) - 1) == len(labels[vehicle_id]), f"vehicle_id={vehicle_id}, {input_tours}, {labels}" | |
| return labels | |
| # labels = [torch.LongTensor(label) for label in labels] # [num_vehicles x seq_length] | |
| # labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True) # [num_vehicles x max_seq_length] | |
| # probs = torch.zeros((labels.size(0), labels.size(1), self.num_compared_problems+1)) # [num_vehicles x max_seq_length x (num_compared_problems+1)] | |
| # probs.scatter_(-1, labels.unsqueeze(-1).expand_as(probs), 1.0) | |
| # return probs | |
| def label_path(self, vehicle_id, step, input_tour, node_feats, dist_matrix=None): | |
| compared_tour_list = [[] for _ in range(self.num_compared_problems)] | |
| visited_path = input_tour[:step].copy() | |
| new_node_id, new_node_feats, new_dist_matrix = self.get_feasible_nodes(input_tour, step, node_feats, dist_matrix) | |
| new_visited_path = np.array(list(map(lambda x: np.where(new_node_id==x)[0].item(), visited_path))) | |
| for i in range(self.num_compared_problems): | |
| # TODO: in CVRPTW / PCCVRPTW, need to modify classification of the first and last paths | |
| compared_tours = self.solvers[i].solve(new_node_feats, new_visited_path, new_dist_matrix) | |
| if compared_tours is None: | |
| return vehicle_id, step, FAIL_FLAG | |
| compared_tour = None | |
| for compared_tour_tmp in compared_tours: | |
| if new_visited_path[-1] in compared_tour_tmp: | |
| compared_tour = compared_tour_tmp | |
| break | |
| assert compared_tour is not None, f"Found no appropriate vhiecle. {compared_tours}, {new_visited_path}" | |
| compared_tour = np.array(list(map(lambda x: new_node_id[x], compared_tour))) | |
| if (step > 0) and (compared_tour[1] != input_tour[1]): | |
| compared_tour = np.flipud(compared_tour) # make direction of the cf tour the same as factual one | |
| compared_tour_list[i] = compared_tour | |
| # print("fixed_paths :", visited_path) | |
| # print("input_tour :", input_tour) | |
| # print("compared_tour:", compared_tour) | |
| # print() | |
| # annotation | |
| label = self.get_label(input_tour, compared_tour_list, step) | |
| return vehicle_id, step, label | |
| def solve(self, step, input_tour, node_feats, instance_name=None): | |
| compared_tours = {} | |
| visited_path = input_tour[:step].copy() | |
| new_node_id, new_node_feats = self.get_feasible_nodes(input_tour, step, node_feats) | |
| new_visited_path = np.array(list(map(lambda x: np.where(new_node_id==x)[0].item(), visited_path))) | |
| for i, compared_problem in enumerate(self.compared_problems): | |
| compared_tours[compared_problem] = self.solvers[i].solve(new_node_feats, new_visited_path, instance_name) | |
| compared_tours[compared_problem] = list(map(lambda compared_tour: list(map(lambda x: new_node_id[x], compared_tour)), compared_tours[compared_problem])) | |
| compared_tours[compared_problem] = list(map(lambda compared_tour: calc_tour_length(compared_tour, node_feats["coords"]), compared_tours[compared_problem])) | |
| return compared_tours | |
| def get_label(self, input_tour, compared_tours, step): | |
| for i in range(self.num_compared_problems): | |
| compared_tour = compared_tours[i] | |
| if input_tour[step] == compared_tour[step]: | |
| return i | |
| return self.num_compared_problems | |
| def get_inputs(self, tour, first_explained_step, node_feats, dist_matrix=None): | |
| input_features = { | |
| "tour": tour, | |
| "first_explained_step": first_explained_step, | |
| "node_feats": node_feats, | |
| "dist_matrix": dist_matrix | |
| } | |
| return input_features | |
| def get_feasible_nodes(self, tour, step, node_feats, dist_matrix=None): | |
| """ | |
| Parameters | |
| ---------- | |
| tour: np.array [seq_length] | |
| step: int | |
| node_feats: np.array [num_nodes x node_dim] | |
| Returns | |
| ------- | |
| new_node_id: np.array [num_feasible_nodes] | |
| new_node_feats: dict of np.array [num_feasible_nodes x coord_dim] | |
| """ | |
| if dist_matrix is not None: | |
| num_nodes = len(dist_matrix) | |
| else: | |
| num_nodes = len(node_feats["coords"]) | |
| mask = self.get_mask(tour, step, node_feats, dist_matrix) | |
| node_id = np.arange(num_nodes) | |
| new_node_id = node_id[mask].copy() | |
| new_node_feats = { | |
| key: node_feat[mask].copy() | |
| if key in ["coords", "time_window", "demand", "penalties", "prizes"] else | |
| node_feat.copy() | |
| for key, node_feat in node_feats.items() | |
| } | |
| if dist_matrix is not None: | |
| delete_id = node_id[~mask] | |
| new_dist_matrix = np.delete(np.delete(dist_matrix, delete_id, 0), delete_id, 1) | |
| else: | |
| new_dist_matrix = None | |
| return new_node_id, new_node_feats, new_dist_matrix | |
| def get_mask(self, tour, step, node_feats, dist_matrix=None): | |
| raise NotImplementedError | |
| def check_feasibility(self, tour, node_feats): | |
| raise NotImplementedError |