import os
import random
import argparse
import time
import numpy as np
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.optim as optim
from torch.backends import cudnn
from data_controller import SegDataset
from loss import Loss
from segnet import SegNet as segnet
import sys
sys.path.append("..")
from lib.utils import setup_logger
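# Trains the SegNet semantic-segmentation model on the YCB-Video dataset:
# one training pass and one evaluation pass per epoch, with cross-entropy
# ("CEloss") logging, periodic snapshots, and best-model checkpointing on
# the average test loss.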
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_root', default='/home/data1/jeremy/YCB_Video_Dataset',
                    help="dataset root dir (YCB_Video_Dataset)")
parser.add_argument('--batch_size', type=int, default=3, help="batch size")
parser.add_argument('--n_epochs', type=int, default=600, help="number of epochs to train")
parser.add_argument('--workers', type=int, default=10, help="number of data loading workers")
parser.add_argument('--lr', type=float, default=0.0001, help="learning rate")
parser.add_argument('--model_save_path', default='trained_models/', help="path to save models")
parser.add_argument('--log_dir', default='logs/', help="path to save logs")
parser.add_argument('--resume_model', default='', help="name of a checkpoint to resume from")
opt = parser.parse_args()
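# Example invocation (the dataset path is illustrative, not part of the
# original script):
#   python train.py --dataset_root /path/to/YCB_Video_Dataset \
#       --batch_size 3 --n_epochs 600 --lr 0.0001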
if __name__ == '__main__':
    # Seed the Python and PyTorch RNGs so runs are reproducible.
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
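    # Optional extra determinism (a sketch, not part of the original script):
    # pinning the cuDNN backend trades speed for bitwise-reproducible runs.
    # cudnn.deterministic = True
    # cudnn.benchmark = False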
    # The trailing arguments (5000 / 1000) set the dataset length, i.e. the
    # number of samples drawn per epoch.
    dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True, 5000)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.workers)
    test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False, 1000)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=opt.workers)
    print(len(dataset), len(test_dataset))
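    # Each batch yields an RGB image tensor and a per-pixel class-index
    # target, as consumed by the training loop below.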
    model = segnet()
    model = model.cuda()

    # Make sure the output directories exist before logging / checkpointing.
    os.makedirs(opt.log_dir, exist_ok=True)
    os.makedirs(opt.model_save_path, exist_ok=True)

    if opt.resume_model != '':
        checkpoint = torch.load(os.path.join(opt.model_save_path, opt.resume_model))
        model.load_state_dict(checkpoint)
        # Clear logs left over from the interrupted run.
        for log in os.listdir(opt.log_dir):
            os.remove(os.path.join(opt.log_dir, log))
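    # Note that resuming restores model weights only; optimizer state and the
    # epoch counter start fresh. A fuller scheme (a sketch, not the original
    # behaviour) would checkpoint a dict such as
    # {'model': model.state_dict(), 'optim': optimizer.state_dict(), 'epoch': epoch}
    # and restore all three here.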
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    criterion = Loss()
    best_val_cost = np.inf
    st_time = time.time()
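    # ReduceLROnPlateau is imported above but never wired up. A minimal
    # sketch (an assumption, not the author's setup) would be:
    #   scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=10)
    # with scheduler.step(test_all_cost) called once per epoch after testing.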
    for epoch in range(1, opt.n_epochs + 1):
        model.train()
        train_all_cost = 0.0
        train_time = 0
        logger = setup_logger('epoch%d' % epoch, os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Training started'))
        for i, data in enumerate(dataloader, 0):
            rgb, target = data
            rgb, target = rgb.cuda(), target.cuda()
            semantic = model(rgb)
            optimizer.zero_grad()
            semantic_loss = criterion(semantic, target)
            train_all_cost += semantic_loss.item()
            semantic_loss.backward()
            optimizer.step()
            logger.info('Train time {0} Batch {1} CEloss {2}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), train_time, semantic_loss.item()))
            # Save a rolling snapshot every 1000 batches.
            if train_time != 0 and train_time % 1000 == 0:
                torch.save(model.state_dict(), os.path.join(opt.model_save_path, 'model_current.pth'))
            train_time += 1
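        # train_time now equals the number of batches seen, so the division
        # below yields the mean per-batch loss for the epoch.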
        train_all_cost = train_all_cost / train_time
        logger.info('Train Finish Avg CEloss: {0}'.format(train_all_cost))
        model.eval()
        test_all_cost = 0.0
        test_time = 0
        logger = setup_logger('epoch%d_test' % epoch, os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Testing started'))
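        # Gradients are not needed for evaluation, so the test pass runs
        # under torch.no_grad() to save memory and compute.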
        with torch.no_grad():
            for j, data in enumerate(test_dataloader, 0):
                rgb, target = data
                rgb, target = rgb.cuda(), target.cuda()
                semantic = model(rgb)
                semantic_loss = criterion(semantic, target)
                test_all_cost += semantic_loss.item()
                test_time += 1
                logger.info('Test time {0} Batch {1} CEloss {2}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), test_time, semantic_loss.item()))
        test_all_cost = test_all_cost / test_time
        logger.info('Test Finish Avg CEloss: {0}'.format(test_all_cost))
        # Keep the checkpoint whenever the average test loss improves.
        if test_all_cost <= best_val_cost:
            best_val_cost = test_all_cost
            torch.save(model.state_dict(), os.path.join(opt.model_save_path, 'model_{}_{:.6f}.pth'.format(epoch, test_all_cost)))
            print('----------->BEST SAVED<-----------')