import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
# Hyperparameters
epochs = 200     # value kept from the original script
batch_size = 64  # mini-batch size handed to both DataLoaders
# NOTE(review): 0.02 is unusually large for Adam; GAN recipes commonly use
# ~2e-4. Value kept as-is -- confirm it is intentional.
lr = 2e-2        # learning rate shared by both Adam optimizers
# Preprocessing applied to every MNIST image: convert it to a tensor, then
# normalize the single channel with mean 0.5 and std 1.0.
# NOTE(review): with std=1.0 the pixels presumably end up in [-0.5, 0.5],
# while the Generator's Tanh output spans [-1, 1] -- confirm this is intended.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(1.0,))]
)
# MNIST train and test splits, stored under ./data (download=True asks
# torchvision to fetch the files when needed). Both splits share the same
# preprocessing; the train split is constructed first.
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=use_train_split, download=True, transform=transform)
    for use_train_split in (True, False)
)

# Batch iterators: training batches are reshuffled, test batches keep order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
# Generator and discriminator networks of the GAN
class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to flattened images.

    The network is an MLP (latent_dim -> 256 -> 512 -> 1024 -> img_dim) with
    ReLU activations between layers and a final Tanh, so every output value
    lies in [-1, 1].

    Args:
        latent_dim: Size of the input noise vector. Defaults to 100.
        img_dim: Number of values in each generated, flattened image.
            Defaults to 784.
    """

    def __init__(self, latent_dim=100, img_dim=784):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_dim = img_dim
        # The attribute name ``fc`` and the layer order are kept so that
        # state_dict keys ("fc.0.weight", ...) stay the same.
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, img_dim),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate flattened images from latent vectors.

        Args:
            z: Tensor whose last dimension equals ``latent_dim``.

        Returns:
            Tensor with the same leading dimensions as ``z`` and a last
            dimension of ``img_dim``; values are in [-1, 1] because of Tanh.
        """
        return self.fc(z)
class Discriminator(nn.Module):
    """Fully connected discriminator scoring images as a value in (0, 1).

    The network is an MLP (img_dim -> 512 -> 256 -> 1) with LeakyReLU(0.2)
    activations and a final Sigmoid, so each image gets one score between
    0 and 1.

    Args:
        img_dim: Number of values per image after flattening.
            Defaults to 784.
    """

    def __init__(self, img_dim=784):
        super().__init__()
        self.img_dim = img_dim
        # The attribute name ``fc`` and the layer order are kept so that
        # state_dict keys ("fc.0.weight", ...) stay the same.
        self.fc = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened and must total ``img_dim`` values.

        Returns:
            Tensor of shape (batch, 1) holding the Sigmoid output per image.
        """
        # flatten() copies when it has to, so non-contiguous inputs (e.g.
        # permuted or sliced tensors) work; .view() raises on those.
        img_flat = img.flatten(start_dim=1)
        return self.fc(img_flat)
# Build the two networks (generator first, then discriminator).
generator, discriminator = Generator(), Discriminator()

# One Adam optimizer per network, both using the shared learning rate.
# NOTE(review): no loss function is defined in this section -- confirm that
# the training code creates its own.
optimizer_g = optim.Adam(params=generator.parameters(), lr=lr)
optimizer_d = optim.Adam(params=discriminator.parameters(), lr=lr)
# Training function
def train(epoch):
generator.train()
discriminator.train()
for batch_idx, (data, _) in enumerate(train_loader):
# 生成假样本
z = torch.randn(batch_size, 100)
# "评论已关闭" ("Comments are closed") -- web-page footer text, not part of the program