| import torch | |
| from torchvision import datasets | |
| import torchvision.transforms as transforms | |
# Mini-batch size handed to both DataLoader instances built in data_loader().
batch_size: int = 128
def data_transform():
    """Build the torchvision preprocessing pipelines for training and testing.

    The training pipeline applies random augmentation -- horizontal flip,
    ``RandomRotation(10)``, a random 32-pixel crop with 4 pixels of padding,
    and colour jitter -- and then converts to a tensor and normalises.
    The test pipeline performs only the tensor conversion and normalisation.

    Returns:
        A ``(transform_train, transform_test)`` tuple of
        ``transforms.Compose`` objects.
    """
    # Per-channel statistics passed to Normalize for all three channels.
    channel_mean = (0.5, 0.5, 0.5)
    channel_std = (0.5, 0.5, 0.5)

    def tensor_steps():
        # Built fresh on every call so the two pipelines share no objects.
        return [
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]

    augmentation_steps = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.RandomCrop(32, padding=4),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    ]

    train_pipeline = transforms.Compose(augmentation_steps + tensor_steps())
    test_pipeline = transforms.Compose(tensor_steps())
    return train_pipeline, test_pipeline
def data_loader(transform_train, transform_test, *, root='./data',
                num_workers=2, batch_size_override=None):
    """Create the CIFAR-10 training and test ``DataLoader`` objects.

    Both dataset splits are requested with ``download=True``.

    Args:
        transform_train: Transform passed to the training split
            (``train=True``).
        transform_test: Transform passed to the test split (``train=False``).
        root: Directory handed to ``datasets.CIFAR10`` as its ``root``.
            Defaults to ``'./data'``.
        num_workers: ``num_workers`` value for both loaders. Defaults to 2.
        batch_size_override: Batch size for both loaders. When ``None``
            (the default) the module-level ``batch_size`` is used.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader is
        created with ``shuffle=True``; the test loader with ``shuffle=False``.
    """
    # Resolve the fallback at call time so a later change to the module-level
    # constant is still honoured, exactly as before this parameter existed.
    if batch_size_override is None:
        effective_batch_size = batch_size
    else:
        effective_batch_size = batch_size_override

    train_dataset = datasets.CIFAR10(
        root=root, train=True, download=True, transform=transform_train
    )
    test_dataset = datasets.CIFAR10(
        root=root, train=False, download=True, transform=transform_test
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=effective_batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=effective_batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, test_loader