关键词
| 术语 | 英文 | 核心概念 |
|---|---|---|
| 生成对抗网络 | GAN | 生成器与判别器的对抗框架 |
| 生成器 | Generator | 从随机噪声生成假样本 |
| 判别器 | Discriminator | 区分真假样本 |
| 模式崩溃 | Mode Collapse | 生成样本多样性不足 |
| JS散度 | Jensen-Shannon Divergence | GAN损失函数的理论基础 |
| Wasserstein距离 | EM Distance | 改进GAN训练稳定性的度量 |
| 条件生成 | Conditional Generation | 基于条件信息的可控生成 |
| 渐进式生长 | Progressive Growing | 从低分辨率逐步训练到高分辨率 |
| 风格迁移 | Style Transfer | 控制生成图像的风格特征 |
| 循环一致性 | Cycle Consistency | 无配对数据转换的核心约束 |
| 一致性损失 | Consistency Loss | 保证转换前后语义保持的损失函数 |
| 谱归一化 | Spectral Normalization | 约束判别器Lipschitz常数的技巧 |
| 梯度惩罚 | Gradient Penalty | 替代权重裁剪的WGAN改进 |
| 小样本学习 | Few-shot Learning | 仅用少量样本训练生成模型 |
| Zero-shot学习 | Zero-shot Learning | 生成未见过的类别样本 |
0. 先聊聊:GAN到底是什么?
说实话,第一次看到GAN这个词的时候,我也是懵的。什么对抗网络?什么生成模型?听着高大上,但到底是个什么东西?
后来我看到Ian Goodfellow本人用一句话解释GAN,瞬间就懂了:
GAN就是让两个神经网络”左右互搏”,一个负责造假,一个负责打假,在对抗中共同变强。
就这么简单。让我给你讲个故事你就更清楚了。
造假者和鉴定师的故事
想象有个造假的,专门生产假画。他刚开始入行,水平很差,画的画一看就是假的,连我这种外行都能分辨出来。
然后有个鉴定师,专门鉴定画作真假。鉴定师也很菜,分不清真假,把假画当真画的概率跟扔硬币差不多。
有一天,造假者突发奇想:与其闭门造车,不如请鉴定师来”指导”一下。于是他把自己的假画给鉴定师看,鉴定师说”这是假的”。造假者就问:哪里假了?颜色?线条?还是整体感觉?
鉴定师虽然说不出所以然,但他能给出判断。造假者就根据这个反馈不断调整自己的技巧。慢慢地,他的假画越来越像那么回事了。
鉴定师也不是吃素的,他发现自己越来越难分辨真假了,于是也开始研究假画有哪些共同特征,努力提高自己的鉴别能力。
这个过程不断循环:造假者让假画更逼真,鉴定师让鉴别更精准。最后,造假者造出来的画,几乎达到了以假乱真的程度——这时候的鉴定师也成了顶级专家。
GAN就是这样工作的。只不过造假者是生成器(Generator),鉴定师是判别器(Discriminator),两者通过不断对抗,最终都能变得很强。
这就是GAN的核心思想:不是单独训练一个模型,而是让两个模型互相学习、互相提升。
为什么要”对抗”?不能直接学吗?
好问题。你可能会想:为什么不直接让模型学真实数据的分布?非要搞两个网络对抗干嘛?
因为这事儿真的很难。
假设你要生成一张猫的图片。真实世界里有无数种猫,各种姿势、各种颜色、各种背景……你想用一个函数直接描述”猫长什么样”,这个函数的复杂度超乎想象。
GAN的聪明之处在于,它把这个问题转化成了一个分类问题:判别器负责判断真假,生成器负责骗过判别器。通过对抗,生成器”被迫”学会了真实数据的分布。
打个比方:你想让一个人学会辨别红酒的好坏,正确的做法不是给他一本《红酒大全》让他死记硬背,而是让他不断品尝、不断对比、不断总结。对抗训练就是这种”在战斗中学习”的思想。
1. 对抗学习的诞生:改变游戏规则的工作
1.1 GAN的诞生背景
2014年,Ian Goodfellow还是蒙特利尔大学的一个博士生。他在酒吧里跟朋友讨论”怎么让机器自动生成图片”,突然灵光一现,想出了GAN的基本框架。
当晚他就回家写代码实现了原型,结果效果出奇的好——虽然还很简单,但已经证明了这个思路是可行的。
2014年6月,Goodfellow发表了那篇著名的论文《Generative Adversarial Networks》。当时没多少人关注,但很快就引起了轰动。这篇论文被引用了将近10万次,成为深度学习领域最具影响力的论文之一。
1.2 GAN的哲学思想
GAN的设计蕴含了深刻的哲学思想。道高一尺,魔高一丈——没有绝对的防御,也没有绝对的攻击,一切都是相对的、动态的。
这种思想在自然界中也很常见。比如猎豹和瞪羚的军备竞赛:猎豹跑得越来越快,瞪羚也必须跑得更快才能生存。两者互相促进,共同进化。
GAN就是这种思想的体现:生成器和判别器不是零和博弈,而是共同进化。当生成器变强时,判别器也必须变强;当判别器变强时,生成器又被迫继续进步。这种动态平衡最终会达到一个纳什均衡点。
1.3 生成模型三大流派对比
在深度学习的生成模型家族里,主要有三类选手:VAE、Diffusion Model和GAN。它们各有特点,也各有优缺点。
**VAE(变分自编码器)**是最早上场的选手。简单来说,VAE先 encoder 把图片压缩成一个低维向量,然后再 decoder 从这个向量还原出图片。它的好处是训练稳定,坏处是生成的图片通常比较模糊,细节不够。
想象一下:VAE就像一个画家,先把看到的东西记在脑子里(压缩),然后再画出来(生成)。这个过程难免会丢失细节。
**Diffusion Model(扩散模型)**是最近几年的大热门。它的工作方式是:先给图片不断加噪声,直到完全变成随机噪声,然后再学一个逆向过程——从噪声中逐步还原出图片。
这听起来很复杂,训练也慢得离谱——可能要几周时间。但它生成的图片质量非常高,而且训练非常稳定,不太会出现GAN那些幺蛾子问题。
打个比方:Diffusion就像一个人在练习素描,先把一张完美的画揉成纸团,再学着怎么把纸团还原成原来的画。这个过程虽然慢,但最后还原出来的画质量很高。
GAN 的思路完全不一样。它的核心是”对抗”,不需要重建损失函数。GAN生成图片只需要一次前向传播,速度飞快。但它的训练过程出了名的难搞定——容易训练崩溃、容易模式崩溃、需要小心平衡生成器和判别器的节奏。
打个比方:GAN就像两个人在下棋,一个人出招一个人应对,在你来我往中棋艺都提高了。速度快,但需要技巧和经验。
三种模型的对比总结:
| 特性 | VAE | Diffusion | GAN |
|---|---|---|---|
| 生成速度 | 快(单次) | 慢(几十到几百步) | 快(单次) |
| 生成质量 | 一般 | 很高 | 很高 |
| 训练稳定性 | 稳定 | 稳定 | 不稳定 |
| 训练难度 | 简单 | 复杂 | 很难 |
| 模式覆盖 | 较好 | 很好 | 容易崩溃 |
| 典型应用 | 潜在空间插值 | 高质量图像生成 | 实时生成、风格控制 |
现在的趋势是:Diffusion在图像生成质量上已经超越了GAN,比如DALL-E 3、Stable Diffusion都是Diffusion模型。但GAN因为生成速度快、可以进行细粒度的风格控制,在某些场景下仍然不可替代。而且GAN的思想影响了整个生成式AI领域。
2. 理解GAN的数学直觉
2.1 从生成器到判别器
GAN包含两个核心组件:
生成器(Generator, G):输入是一个随机向量 (通常服从正态分布或均匀分布),输出是一张”假”图片。生成器的任务是让自己的输出尽可能接近真实数据分布。
判别器(Discriminator, D):输入是一张图片,输出是一个概率值,表示这张图片是”真”的概率。判别器的任务是准确区分真实图片和生成图片。
训练过程如下:
- 判别器看真实图片,告诉它”这是真的”(期望输出=1)
- 判别器看生成图片,告诉它”这是假的”(期望输出=0)
- 生成器看了判别器的判断后,开始”进化”,生成更逼真的图片来骗过判别器
- 重复以上步骤
2.2 目标函数:对抗的数学表达
GAN的目标函数可以写成这样:
别被公式吓到,让我用大白话解释:
- 是我们要优化的目标
- :判别器想要最大化这个值,也就是尽可能正确区分真假
- :生成器想要最小化这个值,也就是尽可能骗过判别器
具体来说:
- 第一项 :对于真实图片 ,判别器 越接近1越好(对数越大)
- 第二项 :对于生成图片 ,我们希望 越接近0越好,这样 就越大
2.3 为什么对抗训练能work?
这是GAN最核心的问题:为什么让两个网络互相对抗,就能学到数据分布?
答案在于纳什均衡。
假设在理想的纳什均衡状态下,生成器学到了真实的数据分布 ,判别器只能随机猜测(因为真假分布完全一样),此时 。
可以证明,当 时,目标函数达到全局最优值 。所以生成器的目标就是让 尽可能接近 。
但现实没那么理想。原始GAN使用的是JS散度来衡量分布距离,而JS散度有个致命问题——当两个分布完全不重叠时,梯度会变成0。这就好比你考试得了0分,老师说”你还需要继续努力”,但不给任何具体反馈。
3. 从零实现GAN:PyTorch完整教程
3.1 最小可运行的GAN代码
让我先给你一个最简单版本的GAN,让你直观感受一下GAN是怎么工作的。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 超参数
latent_dim = 100 # 随机向量的维度
batch_size = 64
epochs = 50
learning_rate = 0.0002
image_size = 28 * 28 # MNIST图片是28x28
# 数据加载
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]) # 归一化到[-1, 1]
])
dataset = datasets.MNIST('./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)3.2 定义生成器和判别器
class Generator(nn.Module):
"""生成器:把随机向量变成图片"""
def __init__(self, latent_dim, output_dim):
super(Generator, self).__init__()
self.net = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, output_dim),
nn.Tanh() # 输出范围[-1, 1]
)
def forward(self, z):
return self.net(z)
class Discriminator(nn.Module):
"""判别器:判断图片是真是假"""
def __init__(self, input_dim):
super(Discriminator, self).__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, 1024),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid() # 输出概率[0, 1]
)
def forward(self, x):
return self.net(x)逐行解析:
- 生成器的输入:一个长度为
latent_dim的随机向量。比如输入[0.1, -0.5, 0.3, ...]这样100维的向量 - 生成器的输出:一张展平后的图片(28×28=784维),用
Tanh激活把值映射到 [-1, 1] - 判别器的输入:一张展平后的图片
- 判别器的输出:一个0到1之间的概率值,越接近1表示越像真图
3.3 训练循环
# 初始化网络
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator(latent_dim, image_size).to(device)
discriminator = Discriminator(image_size).to(device)
# 优化器
opt_g = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
opt_d = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
# 损失函数
criterion = nn.BCELoss()
for epoch in range(epochs):
for batch_idx, (real_images, _) in enumerate(dataloader):
batch_size = real_images.size(0)
real_images = real_images.view(batch_size, -1).to(device)
# 创建标签
real_labels = torch.ones(batch_size, 1).to(device) # 真图标签=1
fake_labels = torch.zeros(batch_size, 1).to(device) # 假图标签=0
# ========== 训练判别器 ==========
discriminator.zero_grad()
# 判别器看真图,应该输出1
real_output = discriminator(real_images)
d_loss_real = criterion(real_output, real_labels)
# 判别器看假图,应该输出0
noise = torch.randn(batch_size, latent_dim).to(device)
fake_images = generator(noise)
fake_output = discriminator(fake_images.detach()) # detach避免梯度传到G
d_loss_fake = criterion(fake_output, fake_labels)
# 总判别器损失
d_loss = (d_loss_real + d_loss_fake) / 2
d_loss.backward()
opt_d.step()
# ========== 训练生成器 ==========
generator.zero_grad()
# 生成器生成假图,希望判别器输出1(骗过判别器)
noise = torch.randn(batch_size, latent_dim).to(device)
fake_images = generator(noise)
fake_output = discriminator(fake_images)
g_loss = criterion(fake_output, real_labels) # 标签是1
g_loss.backward()
opt_g.step()
if batch_idx % 200 == 0:
print(f"Epoch [{epoch}/{epochs}] Batch [{batch_idx}] "
f"D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}")关键点解释:
-
为什么
fake_images.detach()? 因为我们更新判别器时,不应该让生成器的权重跟着变。detach() 就是切断梯度流。 -
为什么判别器损失要除以2? 这不是必须的,只是为了让判别器的学习速度和生成器更平衡。
-
生成器的标签为什么用1? 生成器希望骗过判别器,所以它希望判别器把自己生成的图片判断为真(标签=1)。
-
betas=(0.5, 0.999) 是什么? 这是Adam优化器的参数,控制梯度的指数加权平均。0.5是个比较小的值,让优化器更”激进”地适应最近的梯度变化。
3.4 查看生成效果
import matplotlib.pyplot as plt
def generate_and_plot(generator, latent_dim, num_images=16):
generator.eval()
with torch.no_grad():
noise = torch.randn(num_images, latent_dim).to(device)
fake_images = generator(noise)
fake_images = fake_images.view(-1, 28, 28).cpu().numpy()
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for i, ax in enumerate(axes.flat):
ax.imshow(fake_images[i], cmap='gray')
ax.axis('off')
plt.show()
# 训练完后生成图片
generate_and_plot(generator, latent_dim)4. DCGAN实战:用深度卷积GAN生成图片
4.1 为什么需要DCGAN?
上面那个简单的全连接GAN生成MNIST数字还行,但如果想生成更复杂的图片,比如真人脸、CIFAR图片,就力不从心了。
问题在于全连接层的参数太多了,而且无法捕捉图片的空间结构特征。
DCGAN(Deep Convolutional GAN)就是来解决这个问题的。它把卷积神经网络引入GAN:
- 生成器用转置卷积(Transposed Convolution)来上采样
- 判别器用普通卷积来提取特征和下采样
- 通过卷积的局部连接特性,能更好地捕捉图像的空间结构
4.2 DCGAN的核心设计原则
Radford等人在2016年的论文中总结了DCGAN的设计经验:
生成器设计:
- 用转置卷积(也叫反卷积或fractionally strided convolution)替代全连接层
- 转置卷积能让特征图的空间尺寸翻倍
- 每个卷积层后加 BatchNorm(批归一化),稳定训练
- 激活函数用 ReLU,最后一层用 Tanh
判别器设计:
- 用步长卷积(strided convolution)替代池化层
- 同样加 BatchNorm(第一层除外)
- 激活函数用 LeakyReLU(负斜率0.2),避免梯度稀疏
- 最后用 Sigmoid 输出概率
4.3 DCGAN代码实现
class DCGenerator(nn.Module):
"""DCGAN生成器:使用转置卷积"""
def __init__(self, latent_dim, channels, features_g=64):
super(DCGenerator, self).__init__()
# 网络结构:latent_dim -> 4x4x(512*8) -> 8x8x(512*4) -> 16x16x(512*2) -> 32x32x512 -> 64x64xchannels
self.latent_dim = latent_dim
self.channels = channels
# 初始块:1x1 -> 4x4
self.initial = nn.Sequential(
nn.ConvTranspose2d(latent_dim, features_g * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(features_g * 8),
nn.ReLU(True)
)
# 上采样阶段
self.upsample1 = nn.Sequential(
nn.ConvTranspose2d(features_g * 8, features_g * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_g * 4),
nn.ReLU(True)
)
self.upsample2 = nn.Sequential(
nn.ConvTranspose2d(features_g * 4, features_g * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_g * 2),
nn.ReLU(True)
)
self.upsample3 = nn.Sequential(
nn.ConvTranspose2d(features_g * 2, features_g, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_g),
nn.ReLU(True)
)
self.final = nn.Sequential(
nn.ConvTranspose2d(features_g, channels, 4, 2, 1, bias=False),
nn.Tanh() # 输出范围[-1, 1]
)
def forward(self, z):
# z: (batch_size, latent_dim)
x = z.view(z.size(0), self.latent_dim, 1, 1) # reshape到4D
x = self.initial(x) # 4x4
x = self.upsample1(x) # 8x8
x = self.upsample2(x) # 16x16
x = self.upsample3(x) # 32x32
x = self.final(x) # 64x64
return x
class DCDiscriminator(nn.Module):
"""DCGAN判别器:使用步长卷积"""
def __init__(self, channels, features_d=64):
super(DCDiscriminator, self).__init__()
self.main = nn.Sequential(
# 输入: channels x 64 x 64
nn.Conv2d(channels, features_d, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# features_d -> features_d*2
nn.Conv2d(features_d, features_d * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_d * 2),
nn.LeakyReLU(0.2, inplace=True),
# features_d*2 -> features_d*4
nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_d * 4),
nn.LeakyReLU(0.2, inplace=True),
# features_d*4 -> features_d*8
nn.Conv2d(features_d * 4, features_d * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_d * 8),
nn.LeakyReLU(0.2, inplace=True),
# 最终输出: 1x1
nn.Conv2d(features_d * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, x):
return self.main(x).view(-1, 1).squeeze(1)4.4 DCGAN训练技巧
DCGAN虽然比原始GAN稳定多了,但也有一些坑需要注意:
1. BatchNorm的坑
如果batch size太小(小于32),BatchNorm的效果会变差。所以建议使用较大的batch size,比如64或128。
2. 学习率的设置
原始论文建议使用 Adam 优化器,学习率0.0002,beta1=0.5。这个配置被广泛验证有效。
3. 标签平滑(Label Smoothing)
不要用硬标签0和1,稍微平滑一下效果更好:
real_labels = torch.ones_like(labels) * (0.9 + torch.rand_like(labels) * 0.1) # [0.9, 1.0]
fake_labels = torch.zeros_like(labels) + torch.rand_like(labels) * 0.1 # [0.0, 0.1]4. 平衡训练
如果判别器loss一直很低,说明它太强了,生成器得不到有效的梯度反馈。可以:
- 减少判别器的训练频率(比如每2步训练一次)
- 降低判别器的学习率
- 在判别器里多加几层
5. Mode Collapse问题:为什么生成器只会那几招?
5.1 问题描述
Mode Collapse(模式崩溃)是GAN训练中最让人头疼的问题之一。
举个例子:真实数据分布是一个双峰分布(比如MNIST里的数字4和数字9),但训练好的生成器只生成数字4,或者只生成数字9,忽略了另一个峰。
更严重的情况:生成器学会了”作弊”——它找到了一种骗过判别器的方法,但这种方法极其单一。比如不管输入什么随机向量,输出的都是同一张”万能脸”。
5.2 为什么会发生Mode Collapse?
根本原因:JS散度的局限性。
当生成器生成的分布 只覆盖了真实分布 的一部分时,判别器可以轻松地把这部分识别出来。但对于生成器来说,继续优化这一小部分比去学习其他模式更容易。
就好比一个学生发现背答案是最省力的考试方法,于是他就一直背答案,完全不学习真正理解知识。
5.3 解决方案
方案1:Minibatch Discrimination
让判别器不仅看单张图片,还看一整批图片的相似度。这样如果生成器只生成相似的图片,判别器会发现这批图片太”千篇一律”了。
class MinibatchDiscrimination(nn.Module):
"""批次判别模块"""
def __init__(self, input_dim, num_kernels, kernel_dim=5):
super().__init__()
self.num_kernels = num_kernels
self.kernel_dim = kernel_dim
self.mapping = nn.Linear(input_dim, num_kernels * kernel_dim)
def forward(self, x):
batch_size = x.size(0)
# 映射
mapped = self.mapping(x)
mapped = mapped.view(batch_size, self.num_kernels, self.kernel_dim)
# 计算样本间的L1距离
x1 = mapped.unsqueeze(2)
x2 = mapped.unsqueeze(1)
diff = torch.abs(x1 - x2)
distances = torch.sum(diff, dim=3)
# 转成相似度
similarities = torch.exp(-distances)
# 每个样本和其他样本的相似度之和
minibatch_features = torch.sum(similarities, dim=2) - 1
return torch.cat([x, minibatch_features], dim=1)方案2:Unrolled GAN
让生成器在更新时,能”看到”判别器未来会怎么更新。这样生成器就不会针对当前的判别器优化,而是针对未来的判别器优化。
def unrolled_loss(generator, discriminator, real_batch, latent_dim, unroll_steps=5):
"""Unrolled GAN损失"""
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)
# 保存原始状态
saved_state = copy.deepcopy(discriminator.state_dict())
# 更新判别器若干步
for _ in range(unroll_steps):
optimizer_d.zero_grad()
fake_batch = generator(torch.randn(len(real_batch), latent_dim))
d_loss = -torch.mean(discriminator(real_batch)) + torch.mean(discriminator(fake_batch))
d_loss.backward()
optimizer_d.step()
# 用未来的判别器计算生成器损失
fake_batch = generator(torch.randn(len(real_batch), latent_dim))
g_loss = -torch.mean(discriminator(fake_batch))
# 恢复判别器
discriminator.load_state_dict(saved_state)
return g_loss方案3:WGAN/WGAN-GP
换用Wasserstein距离,从根本上解决JS散度的梯度消失问题。WGAN能提供更稳定的梯度,让生成器有动力去探索更多的模式。
6. Wasserstein GAN:更稳定的训练
6.1 Wasserstein距离的直观理解
WGAN的核心创新是用Wasserstein距离(也叫Earth-Mover距离)替代JS散度。
Wasserstein距离的物理含义是:要把一堆土从形状A变成形状B,需要搬运的土方量。分布越接近,需要搬运的越少。
关键优势是:即使两个分布完全不重叠,Wasserstein距离仍然能反映它们的远近。
这就好比:
- JS散度 = 只告诉你”对不对”,不说”差多少”
- Wasserstein距离 = 告诉你”差多少”,还告诉你”怎么补”
6.2 WGAN的实现
WGAN相比原始GAN,有几个关键变化:
- 判别器不使用Sigmoid激活:直接输出任意实数
- 损失函数变化:不再是log似然,而是直接用判别器的输出
- 权重裁剪(Weight Clipping):把判别器的参数限制在一个小范围内
class WGAN_Discriminator(nn.Module):
"""WGAN判别器(也叫Critic)"""
def __init__(self, img_shape):
super().__init__()
self.img_shape = img_shape
self.model = nn.Sequential(
nn.Linear(int(torch.prod(torch.tensor(img_shape))), 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1) # 无激活,输出任意实数
)
def forward(self, img):
return self.model(img.view(img.size(0), -1))
def train_wgan(generator, discriminator, dataloader, latent_dim, epochs, device):
"""WGAN训练"""
optimizer_G = optim.RMSprop(generator.parameters(), lr=0.00005)
optimizer_D = optim.RMSprop(discriminator.parameters(), lr=0.00005)
clip_value = 0.01 # 权重裁剪范围
for epoch in range(epochs):
for imgs, _ in dataloader:
batch_size = imgs.size(0)
imgs = imgs.view(batch_size, -1).to(device)
# 训练判别器(多更新几次)
for _ in range(5):
optimizer_D.zero_grad()
z = torch.randn(batch_size, latent_dim).to(device)
fake_imgs = generator(z)
# WGAN损失:真图片的得分高,假图片的得分低
d_loss = -torch.mean(discriminator(imgs)) + torch.mean(discriminator(fake_imgs))
d_loss.backward()
optimizer_D.step()
# 权重裁剪
for p in discriminator.parameters():
p.data.clamp_(-clip_value, clip_value)
# 训练生成器
optimizer_G.zero_grad()
z = torch.randn(batch_size, latent_dim).to(device)
fake_imgs = generator(z)
g_loss = -torch.mean(discriminator(fake_imgs))
g_loss.backward()
optimizer_G.step()6.3 WGAN-GP:梯度惩罚版本
权重裁剪有个问题:它会强迫判别器学习简单的函数,限制了表达能力。
WGAN-GP(Gradient Penalty WGAN)用梯度惩罚替代权重裁剪,效果更好:
def compute_gradient_penalty(discriminator, real_images, fake_images, device):
"""计算梯度惩罚"""
batch_size = real_images.size(0)
# 随机插值系数
alpha = torch.rand(batch_size, 1, 1, 1).to(device)
# 在真实和虚假之间插值
interpolates = alpha * real_images + (1 - alpha) * fake_images
interpolates = interpolates.requires_grad_(True)
d_interpolates = discriminator(interpolates)
# 计算梯度
gradients = torch.autograd.grad(
outputs=d_interpolates,
inputs=interpolates,
grad_outputs=torch.ones_like(d_interpolates),
create_graph=True,
retain_graph=True,
only_inputs=True
)[0]
gradients = gradients.view(batch_size, -1)
gradient_norm = gradients.norm(2, dim=1)
# 惩罚项:梯度范数应该接近1
penalty = ((gradient_norm - 1) ** 2).mean()
return penalty
def train_wgan_gp(generator, discriminator, dataloader, latent_dim, epochs, device):
"""WGAN-GP训练"""
optimizer_G = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.0, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0004, betas=(0.0, 0.9))
lambda_gp = 10 # 梯度惩罚系数
for epoch in range(epochs):
for imgs, _ in dataloader:
batch_size = imgs.size(0)
real_images = imgs.to(device)
# 训练判别器
optimizer_D.zero_grad()
z = torch.randn(batch_size, latent_dim).to(device)
fake_images = generator(z)
d_real = discriminator(real_images)
d_fake = discriminator(fake_images.detach())
# WGAN损失 + 梯度惩罚
d_loss = d_fake.mean() - d_real.mean()
gp = compute_gradient_penalty(discriminator, real_images, fake_images, device)
total_d_loss = d_loss + lambda_gp * gp
total_d_loss.backward()
optimizer_D.step()
# 训练生成器(可以每两步训练一次)
optimizer_G.zero_grad()
z = torch.randn(batch_size, latent_dim).to(device)
fake_images = generator(z)
g_loss = -discriminator(fake_images).mean()
g_loss.backward()
optimizer_G.step()7. 条件生成CGAN:想生成什么就生成什么
7.1 什么是条件生成?
普通的GAN输入只有一个随机向量,输出完全是”随机”的。你可能生成一只猫,也可能生成一只狗,或者完全四不像。
条件GAN(Conditional GAN, cGAN)让你能控制生成的内容。比如输入”生成一只猫”,就真的生成猫;输入”生成数字7”,就真的生成7。
7.2 cGAN的实现
class ConditionalGenerator(nn.Module):
"""条件生成器"""
def __init__(self, latent_dim, num_classes, img_shape, embed_dim=50):
super().__init__()
self.img_shape = img_shape
# 类别嵌入层
self.label_emb = nn.Embedding(num_classes, embed_dim)
# 输入: 随机向量 + 类别嵌入
self.model = nn.Sequential(
nn.Linear(latent_dim + embed_dim, 256),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(256),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(512),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(1024),
nn.Linear(1024, int(torch.prod(torch.tensor(img_shape)))),
nn.Tanh()
)
def forward(self, z, labels):
label_emb = self.label_emb(labels)
combined = torch.cat([z, label_emb], dim=-1)
img = self.model(combined)
return img.view(img.size(0), *self.img_shape)
class ConditionalDiscriminator(nn.Module):
"""条件判别器"""
def __init__(self, num_classes, img_shape, embed_dim=50):
super().__init__()
self.img_shape = img_shape
self.label_emb = nn.Embedding(num_classes, int(torch.prod(torch.tensor(img_shape))))
self.model = nn.Sequential(
nn.Linear(int(torch.prod(torch.tensor(img_shape))) + embed_dim, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.4),
nn.Linear(512, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.4),
nn.Linear(512, 1),
nn.Sigmoid()
)
def forward(self, img, labels):
label_emb = self.label_emb(labels)
combined = torch.cat([img.view(img.size(0), -1), label_emb], dim=-1)
return self.model(combined)使用示例:
# 生成数字5的图片
z = torch.randn(1, latent_dim).to(device)
target_label = torch.tensor([5]).to(device)
generated_img = generator(z, target_label)7.3 类别信息的其他注入方式
除了Embedding,还可以:
1. 类别条件BatchNorm:
class ConditionalBatchNorm(nn.Module):
def __init__(self, num_features, num_classes):
super().__init__()
self.bn = nn.BatchNorm2d(num_features, affine=False)
self.embed = nn.Embedding(num_classes, num_features * 2)
def forward(self, x, class_id):
out = self.bn(x)
gamma, beta = self.embed(class_id).chunk(2, dim=-1)
return out * (gamma.view(-1, 1, 1) + 1) + beta.view(-1, 1, 1)2. 类别调制(Class Modulation):
把类别信息当作”风格向量”,对特征进行调制。
3. 辅助分类器(ACGAN):
判别器不仅判断真假,还额外预测类别:
class ACGAN_Discriminator(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.features = DCDiscriminator()
self.adv_head = nn.Linear(512, 1) # 真假判断
self.cls_head = nn.Linear(512, num_classes) # 类别分类
def forward(self, x):
features = self.features(x)
return self.adv_head(features), self.cls_head(features)8. Progressive GAN:从模糊到清晰
8.1 渐进式训练的思路
想象你学画画的过程:不是一开始就画细节,而是先画轮廓,再慢慢加细节。
Progressive GAN(ProGAN)就是用这个思路:
- 先用4×4分辨率训练,生成器判别器都很简单
- 等4×4学好了,加上8×8的训练
- 等8×8学好了,加上16×16
- 以此类推,直到1024×1024
这样每个阶段要学的东西都很少,训练稳定,而且低分辨率阶段学到的”大结构”知识会被保留。
8.2 平滑过渡
从一个分辨率切换到下一个分辨率时,不能突然切换,否则会导致训练崩溃。
ProGAN用的是alpha混合:在过渡期间,新层的输出和旧层的输出会按照alpha系数加权混合:
def forward(self, z, alpha=1.0, stage=1):
"""
alpha: 混合系数
stage: 当前阶段(从1开始)
"""
x = self.initial(z) # 4x4
for i in range(stage - 1):
x = self.upsample_blocks[i](x)
if alpha < 1.0:
# 过渡阶段:混合新旧输出
new_output = self.upsample_blocks[stage - 1](x)
old_output = self.to_rgb_old(x)
return alpha * new_output + (1 - alpha) * old_output
else:
# 稳定阶段
x = self.upsample_blocks[stage - 1](x)
return self.to_rgb_new(x)8.3 渐进式训练的优势
- 训练更稳定:每一步只需要学习小幅提升
- 速度更快:低分辨率阶段计算量小
- 减少模式崩溃:低分辨率已经覆盖了大尺度模式
- 可以生成高分辨率:1024×1024甚至更高
9. StyleGAN:控制生成的每一个细节
9.1 StyleGAN的核心创新
StyleGAN是GAN领域的里程碑工作,它实现了对生成图片细粒度的控制。
核心创新有三个:
1. 映射网络(Mapping Network)
原始的随机向量 可能各维度之间有相关性,比如”狗的大小”和”背景的亮度”可能绑在一起。映射网络把 转换成解耦的中间向量 ,让每个维度控制独立的特征。
2. 自适应实例归一化(AdaIN)
在生成器的每个分辨率级别,都注入一次风格信息。这个风格信息来自 ,通过调整均值和方差来控制该层生成的特征。
3. 噪声输入
每个分辨率级别还可以输入独立的噪声,用来增加细节变化。
9.2 为什么StyleGAN能控制风格?
不同分辨率级别,控制不同层级的特征:
- 4×4 - 8×8:最粗糙的尺度,控制姿态、脸型等大致轮廓
- 16×16 - 32×32:中等尺度,控制发型、肤色等
- 64×64 - 256×256:细节尺度,控制皮肤纹理、眼睛细节等
- 512×512 - 1024×1024:最精细的尺度,控制头发丝、背景等微观细节
所以如果你只修改高层(低分辨率)的 ,就会改变整体风格;只修改低层(高分辨率)的 ,只会改变细节。
9.3 StyleGAN2的改进
StyleGAN2主要改进了两点:
1. 移除伪影
StyleGAN1的AdaIN会产生”水滴”伪影。StyleGAN2用Weight Demodulation替代:
def forward(self, x, style):
# 样式调制
style = self.affine(style)
scale, bias = style.chunk(2, dim=-1)
# 权重归一化
weight = self.weight * scale.view(1, -1, 1, 1)
demod = weight / (weight.norm(dim=(2,3), keepdim=True) + 1e-8)
# 卷积
x = F.conv2d(x, demod, padding=1)
return x + bias.view(1, -1, 1, 1)2. 路径长度正则化
让潜在空间的等距移动对应图像空间的等距变化,使得插值更平滑。
9.4 StyleGAN3:消除锯齿
StyleGAN3发现了一个问题:生成的图片在平移时会出现”跳动”。原因是离散采样导致的aliasing(混叠)。
StyleGAN3通过使用sinc滤波器重新设计所有上采样和下采样操作,实现了真正的连续等变性——图片可以平滑平移而不会出现跳变。
10. GAN评估指标:怎么衡量生成的图片好不好?
10.1 Inception Score(IS)
Inception Score是最早的GAN评估指标:
直观理解:
- 如果图片质量高,Inception网络对它的分类应该很确定—— 的熵很低
- 如果图片多样性高,各类别的边缘分布应该均匀—— 的熵很高
- KL散度 = 高条件确定性 - 高边缘均匀性
计算方法:
def calculate_inception_score(images, model, num_splits=10):
model.eval()
preds = []
with torch.no_grad():
for img in images:
# 调整大小到299x299(Inception的输入尺寸)
img = F.interpolate(img.unsqueeze(0), size=(299, 299), mode='bilinear')
pred = F.softmax(model(img), dim=-1)
preds.append(pred.cpu().numpy())
preds = np.concatenate(preds)
# 计算每个split的IS
split_scores = []
split_size = len(preds) // num_splits
for i in range(num_splits):
part = preds[i * split_size:(i + 1) * split_size]
py = np.mean(part, axis=0)
kl = part * (np.log(part + 1e-8) - np.log(py + 1e-8))
kl = np.sum(kl, axis=1)
split_scores.append(np.exp(np.mean(kl)))
return np.mean(split_scores), np.std(split_scores)IS的局限:
- 只看生成图片,不和真实图片比较
- 可以通过”生成ImageNet里的某一类”来刷分,但这不代表真的学到了分布
10.2 Fréchet Inception Distance(FID)
FID是目前最流行的GAN评估指标。它比较真实图片和生成图片在特征空间中的分布差异。
思想:用Inception网络提取特征,假设这些特征服从高斯分布,然后计算两个高斯分布之间的Fréchet距离。
- :均值向量
- :协方差矩阵
- FID越小越好(0表示完美匹配)
计算方法:
def calculate_fid(real_images, fake_images, inception_model):
"""计算FID"""
real_acts = extract_features(real_images, inception_model)
fake_acts = extract_features(fake_images, inception_model)
mu_real = np.mean(real_acts, axis=0)
mu_fake = np.mean(fake_acts, axis=0)
sigma_real = np.cov(real_acts, rowvar=False)
sigma_fake = np.cov(fake_acts, rowvar=False)
# Fréchet距离
diff = mu_real - mu_fake
covmean = sqrtm(sigma_real @ sigma_fake)
if np.iscomplexobj(covmean):
covmean = covmean.real
fid = diff @ diff + np.trace(sigma_real + sigma_fake - 2 * covmean)
return fid
def sqrtm(matrix):
"""矩阵平方根的数值稳定计算"""
eigenvalues, eigenvectors = np.linalg.eigh(matrix)
return eigenvectors @ np.diag(np.sqrt(np.maximum(eigenvalues, 0))) @ eigenvectors.TFID的优势:
- 同时考虑真实和生成图片
- 对模式崩溃敏感(生成的图片单一会导致FID变高)
- 与人类感知相关性较好
10.3 Precision和Recall
FID只看分布距离,但无法区分”生成的图片质量高但覆盖不全”和”生成的图片质量低但覆盖全”。
Precision-Recall分析可以区分这两种情况:
- Precision(精确度):生成的图片里,有多少是”合格”的
- Recall(召回率):真实数据分布被覆盖了多少
def calculate_precision_recall(real_features, fake_features, k=3):
"""计算Precision和Recall"""
from sklearn.neighbors import NearestNeighbors
# k近邻
nn_real = NearestNeighbors(n_neighbors=k + 1).fit(real_features)
nn_fake = NearestNeighbors(n_neighbors=k + 1).fit(fake_features)
# 计算每个fake样本的precision
distances, indices = nn_real.kneighbors(fake_features)
# 如果fake样本的最近邻大多是real样本,则precision高
precision = np.mean([np.mean(idx[1:] < len(real_features)) for idx in indices])
# 计算recall
distances, indices = nn_fake.kneighbors(real_features)
recall = np.mean([np.mean(idx[1:] < len(fake_features)) for idx in indices])
return precision, recall11. GAN的应用:这些都在用GAN
11.1 艺术创作与设计
GAN在艺术领域的应用五花八门:
人脸生成与编辑:StyleGAN能生成极其逼真的人脸,还支持各种编辑操作——换发型、加眼镜、改年龄、调表情。Midjourney、DALL-E这些工具的背后,都有GAN的思想。
风格迁移:CycleGAN可以把马变成斑马、把夏天变成冬天。艺术家们用它来创作独特的视觉效果。
服装设计:时尚品牌用GAN来生成新款式,或者根据用户的喜好推荐设计。
音乐生成:虽然图片GAN更成熟,但AudioGAN也开始用于音乐创作,比如生成特定风格的音乐片段。
11.2 数据增强与医学影像
医疗影像有个老大难问题:数据太少了。GAN可以生成逼真的医学影像来扩充数据集:
class MedicalImageAugmentation:
"""医学影像数据增强"""
def __init__(self, generator, classifier):
self.generator = generator
self.classifier = classifier
def generate_samples(self, target_label, num_samples=100):
"""为特定类别生成样本"""
z = torch.randn(num_samples, self.generator.latent_dim)
labels = torch.full((num_samples,), target_label)
with torch.no_grad():
generated = self.generator(z, labels)
# 过滤:只保留分类器认为正确的
probs = self.classifier(generated)
valid_indices = (probs.argmax(dim=1) == target_label).nonzero().squeeze()
return generated[valid_indices]
def balance_dataset(self, dataset, target_count_per_class=500):
"""平衡数据集"""
# 对于样本不足的类别,用GAN补充
pass11.3 图像修复与超分辨率
图像修复(Inpainting):移除图片中的不需要的物体(比如路人、水印),并用合理的背景填充。GAN能生成语义正确、视觉自然的结果。
超分辨率(Super Resolution):把低分辨率图片放大并增强细节。ESRGAN(Enhanced SRGAN)是这个领域的经典工作。
去噪与恢复:去除老照片的噪点、划痕,恢复破损的图片。GAN能学习真实的纹理,生成自然的结果。
11.4 游戏与电影制作
角色生成:游戏开发者用GAN来生成NPC(非玩家角色)的脸型、皮肤纹理等,大幅减少美术工作量。
场景构建:用GAN生成游戏场景的纹理、背景,或者根据文字描述生成概念图。
特效生成:电影后期用GAN来生成火焰、烟雾、爆炸等特效的自然变化。
11.5 3D内容生成
NeRF + GAN:神经辐射场(NeRF)能生成高质量的3D视图,但训练慢。GAN可以用来加速NeRF的渲染,或者让NeRF生成更多样化的内容。
纹理生成:用GAN为3D模型生成贴图、材质。
可控角色生成:给定一个骨骼姿态,生成穿着对应衣服的角色。
12. 调试GAN的实用技巧
12.1 监控训练过程的指标
GAN的训练过程很复杂,需要监控多个指标:
class GANTrainingMonitor:
def __init__(self):
self.history = {
'd_loss_real': [],
'd_loss_fake': [],
'g_loss': [],
'd_accuracy': [],
'fid_scores': [] # 定期计算
}
def update(self, d_loss_real, d_loss_fake, g_loss, d_preds_real, d_preds_fake):
self.history['d_loss_real'].append(d_loss_real)
self.history['d_loss_fake'].append(d_loss_fake)
self.history['g_loss'].append(g_loss)
# 判别器准确率
acc_real = (d_preds_real > 0.5).float().mean().item()
acc_fake = (d_preds_fake < 0.5).float().mean().item()
self.history['d_accuracy'].append((acc_real + acc_fake) / 2)
def check_mode_collapse(self, fake_images):
"""检测模式崩溃"""
if len(fake_images) < 10:
return False
with torch.no_grad():
fake_flat = fake_images.view(len(fake_images), -1)
# 计算样本间的相似度
similarity = torch.mm(fake_flat, fake_flat.t())
# 如果样本太相似,相似度矩阵会很"极端"
std = similarity.std().item()
return std < 0.1 # 阈值可调12.2 常见问题排查
问题:判别器loss一直很低
- 说明判别器太强了,生成器得不到有效梯度
- 解决:降低判别器学习率,或减少判别器训练步数
问题:生成器loss一直很高且不下降
- 可能梯度消失或判别器太弱
- 解决:检查权重初始化,增加判别器深度,使用WGAN
问题:生成的图片出现明显的”模式崩溃”
- 生成器只生成少数几种图片
- 解决:使用Minibatch Discrimination,减小学习率,增加随机性
问题:训练刚开始就崩溃
- 学习率太高
- 解决:降低学习率,使用学习率预热(warmup)
问题:图片出现伪影或棋盘格效应
- 转置卷积的”棋盘格”问题
- 解决:使用PixelShuffle或上采样+卷积替代转置卷积
12.3 实用tricks清单
- 标签平滑:真实标签用0.9而不是1,假标签用0.1而不是0
- LeakyReLU:判别器用LeakyReLU(0.2)而非ReLU
- 批量归一化的玄学:如果batch size太小,试试InstanceNorm
- 特征匹配:生成器的损失可以加上一项,让假图片的特征统计量接近真图片
- 历史平均:对生成器的参数做历史平均,可以稳定训练
- TTUR:生成器和判别器用不同的学习率(通常是1:4)
13. 学术论文与参考文献
-
Goodfellow, I., et al. (2014). “Generative Adversarial Networks.” NeurIPS. — GAN的开山之作
-
Arjovsky, M., Chintala, S., & Bottou, L. (2017). “Wasserstein GAN.” ICML. — Wasserstein距离的里程碑
-
Radford, A., Metz, L., & Chintala, S. (2016). “Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.” ICLR. — DCGAN
-
Mirza, M., & Osindero, S. (2014). “Conditional Generative Adversarial Nets.” arXiv. — 条件GAN
-
Karras, T., et al. (2018). “Progressive Growing of GANs for Improved Quality, Stability, and Variation.” ICLR. — ProGAN
-
Karras, T., et al. (2019). “A Style-Based Generator Architecture for Generative Adversarial Networks.” CVPR. — StyleGAN
-
Karras, T., et al. (2021). “Alias-Free Generative Adversarial Networks.” NeurIPS. — StyleGAN3
-
Isola, P., et al. (2017). “Image-to-Image Translation with Conditional Adversarial Networks.” CVPR. — Pix2Pix
-
Zhu, J.-Y., et al. (2017). “Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.” ICCV. — CycleGAN
-
Choi, Y., et al. (2018). “StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation.” CVPR. — StarGAN
-
Brock, A., Donahue, J., & Simonyan, K. (2019). “Large Scale GAN Training for High Fidelity Natural Image Synthesis.” ICLR. — BigGAN
-
Zhang, H., et al. (2019). “Self-Attention Generative Adversarial Networks.” ICML. — SAGAN
-
Gulrajani, I., et al. (2017). “Improved Training of Wasserstein GANs.” NeurIPS. — WGAN-GP
-
Mao, X., et al. (2017). “Least Squares Generative Adversarial Networks.” ICCV. — LSGAN
-
Miyato, T., et al. (2018). “Spectral Normalization for Generative Adversarial Networks.” ICLR. — 谱归一化GAN
14. 结语
GAN是深度学习领域最有趣的思想之一。它用”对抗”的框架,把生成模型这个难题转化成了一个可学习的博弈问题。
从2014年到现在,GAN已经走过了很长的路。从最初的简单全连接网络,到StyleGAN的超写实人脸,GAN的能力有了质的飞跃。
但GAN的难点——训练不稳定、模式崩溃、难以平衡——至今仍然存在。Diffusion模型在生成质量上已经超越GAN,但GAN因为生成速度快、可以进行细粒度控制,在很多场景下仍然是首选。
学习GAN,不只是学习一个算法,更是学习一种思维方式:通过对抗来学习,通过竞争来进步。这种思想在强化学习、多智能体系统、对抗样本等领域都有广泛应用。
希望这篇指南能帮助你理解GAN的核心思想,并有能力动手实现自己的GAN项目。如果有任何问题,欢迎在评论区讨论!