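"""Evaluate classifiers (a trained NN classifier or a ground-truth solver) on
routing-problem datasets (TSPTW, PCTSP, PCTSPTW, CVRP), reporting macro
multiclass F1 scores overall, per sequence length, and per decision step."""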
import os
import argparse
import json
import multiprocessing
import time

import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score

from utils.util_calc import TemporalConfusionMatrix
from models.classifiers.nn_classifiers.nn_classifier import NNClassifier
from models.classifiers.ground_truth.ground_truth import GroundTruth
from models.classifiers.ground_truth.ground_truth_base import FAIL_FLAG
from utils.data_utils.tsptw_dataset import TSPTWDataloader
from utils.data_utils.pctsp_dataset import PCTSPDataloader
from utils.data_utils.pctsptw_dataset import PCTSPTWDataloader
from utils.data_utils.cvrp_dataset import CVRPDataloader
from utils.utils import set_device, load_dataset


def load_eval_dataset(dataset_path, problem, model_type, batch_size, num_workers, parallel, num_cpus):
    if model_type == "nn":
        if problem == "tsptw":
            eval_dataset = TSPTWDataloader(dataset_path, sequential=True, parallel=parallel, num_cpus=num_cpus)
        elif problem == "pctsp":
            eval_dataset = PCTSPDataloader(dataset_path, sequential=True, parallel=parallel, num_cpus=num_cpus)
        elif problem == "pctsptw":
            eval_dataset = PCTSPTWDataloader(dataset_path, sequential=True, parallel=parallel, num_cpus=num_cpus)
        elif problem == "cvrp":
            eval_dataset = CVRPDataloader(dataset_path, sequential=True, parallel=parallel, num_cpus=num_cpus)
        else:
            raise NotImplementedError

        #------------
        # dataloader
        #------------
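        # Collate function: post-pads each per-instance tensor to the longest
        # sequence in the batch and adds a boolean "pad_mask" marking which
        # positions hold real (non-padded) steps.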
        def pad_seq_length(batch):
            data = {}
            for key in batch[0].keys():
                padding_value = True if key == "mask" else 0.0
                # post-padding
                data[key] = torch.nn.utils.rnn.pad_sequence([d[key] for d in batch], batch_first=True, padding_value=padding_value)
            pad_mask = torch.nn.utils.rnn.pad_sequence([torch.full((d["mask"].size(0), ), True) for d in batch], batch_first=True, padding_value=False)
            data.update({"pad_mask": pad_mask})
            return data
        eval_dataloader = DataLoader(eval_dataset,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     collate_fn=pad_seq_length,
                                     num_workers=num_workers)
        return eval_dataloader
    else:
        eval_dataset = load_dataset(dataset_path)
        return eval_dataset


def eval_classifier(problem: str,
                    dataset,
                    model_type: str,
                    model_dir: str = None,
                    gpu: int = -1,
                    num_workers: int = 4,
                    batch_size: int = 128,
                    parallel: bool = True,
                    solver: str = "ortools",
                    num_cpus: int = 1):
    #--------------
    # gpu settings
    #--------------
    use_cuda, device = set_device(gpu)

    #-------
    # model
    #-------
    num_classes = 3 if problem == "pctsptw" else 2
    if model_type == "nn":
        assert model_dir is not None, "please specify model_dir when model_type is nn."
        # load the command-line arguments saved at training time
        with open(f"{model_dir}/cmd_args.dat", "r") as f:
            params = argparse.Namespace(**json.load(f))
        assert params.problem == problem, "problem of the trained model should match that of the dataset"
        model = NNClassifier(problem=params.problem,
                             node_enc_type=params.node_enc_type,
                             edge_enc_type=params.edge_enc_type,
                             dec_type=params.dec_type,
                             emb_dim=params.emb_dim,
                             num_enc_mlp_layers=params.num_enc_mlp_layers,
                             num_dec_mlp_layers=params.num_dec_mlp_layers,
                             num_classes=num_classes,
                             dropout=params.dropout,
                             pos_encoder=params.pos_encoder)
        # load trained weights (the best epoch)
        with open(f"{model_dir}/best_epoch.dat", "r") as f:
            best_epoch = int(f.read())
        model.load_state_dict(torch.load(f"{model_dir}/model_epoch{best_epoch}.pth", map_location=device))
        print(f"loaded {model_dir}/model_epoch{best_epoch}.pth.")
        if use_cuda:
            model.to(device)
        is_sequential = model.is_sequential
    elif model_type == "ground_truth":
        model = GroundTruth(problem=problem, solver_type=solver)
        is_sequential = False
    else:
        raise ValueError(f"Invalid model type: {model_type}")

    #---------
    # Metrics
    #---------
    overall_accuracy = MulticlassF1Score(num_classes=num_classes, average="macro").to(device)
    eval_accuracy_dict = {}       # seq_length -> MulticlassF1Score
    temp_confmat_dict = {}        # seq_length -> TemporalConfusionMatrix
    temporal_accuracy_dict = {}   # seq_length -> list of per-step MulticlassF1Score
    num_nodes_dist_dict = {}      # seq_length -> number of samples
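    # per-sequence-length metrics are created lazily, the first time each
    # sequence length appears in the evaluation data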

    #------------
    # Evaluation
    #------------
    if model_type == "nn":
        model.eval()
        eval_time = 0.0
        print("Evaluating models ...", end="")
        start_time = time.perf_counter()
        for data in dataset:
            if use_cuda:
                data = {key: value.to(device) for key, value in data.items()}
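            # non-sequential models score each decision step independently,
            # so fold the sequence dimension into the batch dimension first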
            if not is_sequential:
                shp = data["curr_node_id"].size()
                data = {key: value.flatten(0, 1) for key, value in data.items()}
            with torch.no_grad():
                probs = model(data)  # [batch_size x num_classes] or [batch_size x max_seq_length x num_classes]
            if not is_sequential:
                probs = probs.view(*shp, -1)  # [batch_size x max_seq_length x num_classes]
                data["labels"] = data["labels"].view(*shp)
                data["pad_mask"] = data["pad_mask"].view(*shp)

            #------------
            # evaluation
            #------------
            start_eval_time = time.perf_counter()
            # accuracy
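            # group instances by their true (unpadded) sequence length so that
            # metrics can be tracked separately per sequence length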
            seq_length_list = torch.unique(data["pad_mask"].sum(-1))
            for seq_length_tensor in seq_length_list:
                seq_length = seq_length_tensor.item()
                if seq_length not in eval_accuracy_dict.keys():
                    eval_accuracy_dict[seq_length] = MulticlassF1Score(num_classes=num_classes, average="macro").to(device)
                    temp_confmat_dict[seq_length] = TemporalConfusionMatrix(num_classes=num_classes, seq_length=seq_length, device=device)
                    temporal_accuracy_dict[seq_length] = [MulticlassF1Score(num_classes=num_classes, average="macro").to(device) for _ in range(seq_length)]
                    num_nodes_dist_dict[seq_length] = 0
                seq_length_mask = (data["pad_mask"].sum(-1) == seq_length)  # [batch_size]
                extracted_labels = data["labels"][seq_length_mask]
                extracted_probs = probs[seq_length_mask]
                extracted_mask = data["pad_mask"][seq_length_mask].view(-1)  # [batch_size x max_seq_length] -> [(batch_size*max_seq_length)]
                eval_accuracy_dict[seq_length](extracted_probs.argmax(-1).view(-1)[extracted_mask], extracted_labels.view(-1)[extracted_mask])
                # confusion matrix
                temp_confmat_dict[seq_length].update(probs.argmax(-1), data["labels"], data["pad_mask"])
                # temporal accuracy
                for step in range(seq_length):
                    temporal_accuracy_dict[seq_length][step](extracted_probs[:, step, :], extracted_labels[:, step])
                # number of samples whose sequence length is seq_length
                num_nodes_dist_dict[seq_length] += len(extracted_labels)
            # overall accuracy (updated once per batch, not once per sequence length)
            mask = data["pad_mask"].view(-1)
            overall_accuracy(probs.argmax(-1).view(-1)[mask], data["labels"].view(-1)[mask])
            eval_time += time.perf_counter() - start_eval_time
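        # report pure inference time, excluding time spent updating metrics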
        calc_time = time.perf_counter() - start_time - eval_time
        total_eval_accuracy = {key: value.compute().item() for key, value in eval_accuracy_dict.items()}
        overall_accuracy = overall_accuracy.compute()
        temporal_confmat = {key: value.compute() for key, value in temp_confmat_dict.items()}
        temporal_accuracy = {key: [value.compute().item() for value in values] for key, values in temporal_accuracy_dict.items()}
        print("done")
        return overall_accuracy, total_eval_accuracy, temporal_accuracy, calc_time, temporal_confmat, num_nodes_dist_dict
    else:
        eval_accuracy = MulticlassF1Score(num_classes=num_classes, average="macro").to(device)
        print("Loading data ...", end=" ")
        with multiprocessing.Pool(num_cpus) as pool:
            input_list = list(pool.starmap(model.get_inputs, [(instance["tour"], 0, instance) for instance in dataset]))
        print("done")
| print("Infering labels ...", end="") | |
| pool = multiprocessing.Pool(num_cpus) | |
| start_time = time.perf_counter() | |
| prob_list = list(pool.starmap(model, tqdm([(inputs, False, False) for inputs in input_list]))) | |
| calc_time = time.perf_counter() - start_time | |
| pool.close() | |
| print("done") | |
| print("Evaluating models ...", end="") | |
| for i, instance in enumerate(dataset): | |
| labels = instance["labels"] | |
| for vehicle_id in range(len(labels)): | |
| for step, label in labels[vehicle_id]: | |
| pred_label = prob_list[i][vehicle_id][step-1] # [num_classes] | |
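                    # if the solver failed on this step, substitute a label that
                    # differs from the ground truth so the failure counts as an error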
                    if pred_label == FAIL_FLAG:
                        pred_label = label - 1 if label != 0 else label + 1
                    eval_accuracy(torch.LongTensor([pred_label]).view(1, -1), torch.LongTensor([label]).view(1, -1))
        total_eval_accuracy = eval_accuracy.compute()
        print("done")
        return total_eval_accuracy.item(), calc_time


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    #------------------
    # general settings
    #------------------
    parser.add_argument("--gpu", default=-1, type=int, help="Used GPU number: gpu=-1 indicates using the CPU")
    parser.add_argument("--num_workers", default=4, type=int, help="Number of workers in the dataloader")
    #---------------
    # data settings
    #---------------
    parser.add_argument("--dataset_path", type=str, help="Path to a dataset", required=True)
    #----------------
    # model settings
    #----------------
    parser.add_argument("--model_type", type=str, default="nn", help="Select from [nn, ground_truth]")
    # nn classifier
    parser.add_argument("--model_dir", type=str, default=None)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--parallel", action="store_true")
    # ground truth
    parser.add_argument("--solver", type=str, default="ortools")
    parser.add_argument("--num_cpus", type=int, default=os.cpu_count())
    args = parser.parse_args()
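    # the problem name is inferred from the dataset's parent directory,
    # e.g. data/tsptw/eval_dataset.pkl -> "tsptw" (path layout is illustrative)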
    problem = str(os.path.basename(os.path.dirname(args.dataset_path)))
    dataset = load_eval_dataset(args.dataset_path, problem, args.model_type, args.batch_size, args.num_workers, args.parallel, args.num_cpus)
    eval_classifier(problem=problem,
                    dataset=dataset,
                    model_type=args.model_type,
                    model_dir=args.model_dir,
                    gpu=args.gpu,
                    num_workers=args.num_workers,
                    batch_size=args.batch_size,
                    parallel=args.parallel,
                    solver=args.solver,
                    num_cpus=args.num_cpus)
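
    # Example invocation (script name and paths are illustrative):
    #   python eval.py --dataset_path data/tsptw/eval_dataset.pkl \
    #                  --model_type nn --model_dir checkpoints/tsptw --gpu 0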