from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torchvision.transforms import RandomHorizontalFlip, RandomVerticalFlip, RandomAffine
from torchvision.datasets.utils import download_url
import os
import os.path as osp
import sys
import pickle
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
# 定义数据集类
class AugmentedCIFAR10(CIFAR10):
"""`CIFAR10` 数据集的扩充版本."""
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
super(AugmentedCIFAR10, self).__init__(root, train, transform, target_transform, download)
self.transform = transforms.Compose([
RandomHorizontalFlip(),
RandomVerticalFlip(),
RandomAffine(degrees=0, translate=(0.1, 0.1)),
ToTensor()
])
# 使用示例
data_dir = 'path/to/data'
train_set = AugmentedCIFAR10(data_dir, train=True, download=True)
test_set = AugmentedCIFAR10(data_dir, train=False, download=True)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=True)
# 开始训练或测试循环
for inputs, labels in train_loader:
# 使用 inputs 和 labels 进行训练
# ...
pass
这段代码定义了一个名为AugmentedCIFAR10
的扩充版本的CIFAR10数据集类,它在初始化时接收和CIFAR10相同的参数,并使用RandomHorizontalFlip
,RandomVerticalFlip
和RandomAffine
进行数据增强。在使用时,只需传入数据集的路径以及是否进行下载,就可以像使用标准CIFAR10数据集一样使用它,包括创建数据加载器和在训练循环中使用。