关键词

术语英文核心概念
生成对抗网络GAN生成器与判别器的对抗框架
生成器Generator从随机噪声生成假样本
判别器Discriminator区分真假样本
模式崩溃Mode Collapse生成样本多样性不足
JS散度Jensen-Shannon DivergenceGAN损失函数的理论基础
Wasserstein距离EM Distance改进GAN训练稳定性的度量
条件生成Conditional Generation基于条件信息的可控生成
渐进式生长Progressive Growing从低分辨率逐步训练到高分辨率
风格迁移Style Transfer控制生成图像的风格特征
循环一致性Cycle Consistency无配对数据转换的核心约束
一致性损失Consistency Loss保证转换前后语义保持的损失函数
谱归一化Spectral Normalization约束判别器Lipschitz常数的技巧
梯度惩罚Gradient Penalty替代权重裁剪的WGAN改进
小样本学习Few-shot Learning仅用少量样本训练生成模型
Zero-shot学习Zero-shot Learning生成未见过的类别样本

0. 先聊聊:GAN到底是什么?

说实话,第一次看到GAN这个词的时候,我也是懵的。什么对抗网络?什么生成模型?听着高大上,但到底是个什么东西?

后来我看到Ian Goodfellow本人用一句话解释GAN,瞬间就懂了:

GAN就是让两个神经网络”左右互搏”,一个负责造假,一个负责打假,在对抗中共同变强。

就这么简单。让我给你讲个故事你就更清楚了。

造假者和鉴定师的故事

想象有个造假的,专门生产假画。他刚开始入行,水平很差,画的画一看就是假的,连我这种外行都能分辨出来。

然后有个鉴定师,专门鉴定画作真假。鉴定师也很菜,分不清真假,把假画当真画的概率跟扔硬币差不多。

有一天,造假者突发奇想:与其闭门造车,不如请鉴定师来”指导”一下。于是他把自己的假画给鉴定师看,鉴定师说”这是假的”。造假者就问:哪里假了?颜色?线条?还是整体感觉?

鉴定师虽然说不出所以然,但他能给出判断。造假者就根据这个反馈不断调整自己的技巧。慢慢地,他的假画越来越像那么回事了。

鉴定师也不是吃素的,他发现自己越来越难分辨真假了,于是也开始研究假画有哪些共同特征,努力提高自己的鉴别能力。

这个过程不断循环:造假者让假画更逼真,鉴定师让鉴别更精准。最后,造假者造出来的画,几乎达到了以假乱真的程度——这时候的鉴定师也成了顶级专家。

GAN就是这样工作的。只不过造假者是生成器(Generator),鉴定师是判别器(Discriminator),两者通过不断对抗,最终都能变得很强。

这就是GAN的核心思想:不是单独训练一个模型,而是让两个模型互相学习、互相提升。

为什么要”对抗”?不能直接学吗?

好问题。你可能会想:为什么不直接让模型学真实数据的分布?非要搞两个网络对抗干嘛?

因为这事儿真的很难。

假设你要生成一张猫的图片。真实世界里有无数种猫,各种姿势、各种颜色、各种背景……你想用一个函数直接描述”猫长什么样”,这个函数的复杂度超乎想象。

GAN的聪明之处在于,它把这个问题转化成了一个分类问题:判别器负责判断真假,生成器负责骗过判别器。通过对抗,生成器”被迫”学会了真实数据的分布。

打个比方:你想让一个人学会辨别红酒的好坏,正确的做法不是给他一本《红酒大全》让他死记硬背,而是让他不断品尝、不断对比、不断总结。对抗训练就是这种”在战斗中学习”的思想。


1. 对抗学习的诞生:改变游戏规则的工作

1.1 GAN的诞生背景

2014年,Ian Goodfellow还是蒙特利尔大学的一个博士生。他在酒吧里跟朋友讨论”怎么让机器自动生成图片”,突然灵光一现,想出了GAN的基本框架。

当晚他就回家写代码实现了原型,结果效果出奇的好——虽然还很简单,但已经证明了这个思路是可行的。

2014年6月,Goodfellow发表了那篇著名的论文《Generative Adversarial Networks》。当时没多少人关注,但很快就引起了轰动。这篇论文被引用了将近10万次,成为深度学习领域最具影响力的论文之一。

1.2 GAN的哲学思想

GAN的设计蕴含了深刻的哲学思想。道高一尺,魔高一丈——没有绝对的防御,也没有绝对的攻击,一切都是相对的、动态的。

这种思想在自然界中也很常见。比如猎豹和瞪羚的军备竞赛:猎豹跑得越来越快,瞪羚也必须跑得更快才能生存。两者互相促进,共同进化。

GAN就是这种思想的体现:生成器和判别器不是零和博弈,而是共同进化。当生成器变强时,判别器也必须变强;当判别器变强时,生成器又被迫继续进步。这种动态平衡最终会达到一个纳什均衡点。

1.3 生成模型三大流派对比

在深度学习的生成模型家族里,主要有三类选手:VAE、Diffusion Model和GAN。它们各有特点,也各有优缺点。

**VAE(变分自编码器)**是最早上场的选手。简单来说,VAE先 encoder 把图片压缩成一个低维向量,然后再 decoder 从这个向量还原出图片。它的好处是训练稳定,坏处是生成的图片通常比较模糊,细节不够。

想象一下:VAE就像一个画家,先把看到的东西记在脑子里(压缩),然后再画出来(生成)。这个过程难免会丢失细节。

**Diffusion Model(扩散模型)**是最近几年的大热门。它的工作方式是:先给图片不断加噪声,直到完全变成随机噪声,然后再学一个逆向过程——从噪声中逐步还原出图片。

这听起来很复杂,训练也慢得离谱——可能要几周时间。但它生成的图片质量非常高,而且训练非常稳定,不太会出现GAN那些幺蛾子问题。

打个比方:Diffusion就像一个人在练习素描,先把一张完美的画揉成纸团,再学着怎么把纸团还原成原来的画。这个过程虽然慢,但最后还原出来的画质量很高。

GAN 的思路完全不一样。它的核心是”对抗”,不需要重建损失函数。GAN生成图片只需要一次前向传播,速度飞快。但它的训练过程出了名的难搞定——容易训练崩溃、容易模式崩溃、需要小心平衡生成器和判别器的节奏。

打个比方:GAN就像两个人在下棋,一个人出招一个人应对,在你来我往中棋艺都提高了。速度快,但需要技巧和经验。

三种模型的对比总结:

特性VAEDiffusionGAN
生成速度快(单次)慢(几十到几百步)快(单次)
生成质量一般很高很高
训练稳定性稳定稳定不稳定
训练难度简单复杂很难
模式覆盖较好很好容易崩溃
典型应用潜在空间插值高质量图像生成实时生成、风格控制

现在的趋势是:Diffusion在图像生成质量上已经超越了GAN,比如DALL-E 3、Stable Diffusion都是Diffusion模型。但GAN因为生成速度快、可以进行细粒度的风格控制,在某些场景下仍然不可替代。而且GAN的思想影响了整个生成式AI领域。


2. 理解GAN的数学直觉

2.1 从生成器到判别器

GAN包含两个核心组件:

生成器(Generator, G):输入是一个随机向量 (通常服从正态分布或均匀分布),输出是一张”假”图片。生成器的任务是让自己的输出尽可能接近真实数据分布。

判别器(Discriminator, D):输入是一张图片,输出是一个概率值,表示这张图片是”真”的概率。判别器的任务是准确区分真实图片和生成图片。

训练过程如下:

  1. 判别器看真实图片,告诉它”这是真的”(期望输出=1)
  2. 判别器看生成图片,告诉它”这是假的”(期望输出=0)
  3. 生成器看了判别器的判断后,开始”进化”,生成更逼真的图片来骗过判别器
  4. 重复以上步骤

2.2 目标函数:对抗的数学表达

GAN的目标函数可以写成这样:

别被公式吓到,让我用大白话解释:

  • 是我们要优化的目标
  • :判别器想要最大化这个值,也就是尽可能正确区分真假
  • :生成器想要最小化这个值,也就是尽可能骗过判别器

具体来说:

  • 第一项 :对于真实图片 ,判别器 越接近1越好(对数越大)
  • 第二项 :对于生成图片 ,我们希望 越接近0越好,这样 就越大

2.3 为什么对抗训练能work?

这是GAN最核心的问题:为什么让两个网络互相对抗,就能学到数据分布?

答案在于纳什均衡

假设在理想的纳什均衡状态下,生成器学到了真实的数据分布 ,判别器只能随机猜测(因为真假分布完全一样),此时

可以证明,当 时,目标函数达到全局最优值 。所以生成器的目标就是让 尽可能接近

但现实没那么理想。原始GAN使用的是JS散度来衡量分布距离,而JS散度有个致命问题——当两个分布完全不重叠时,梯度会变成0。这就好比你考试得了0分,老师说”你还需要继续努力”,但不给任何具体反馈。


3. 从零实现GAN:PyTorch完整教程

3.1 最小可运行的GAN代码

让我先给你一个最简单版本的GAN,让你直观感受一下GAN是怎么工作的。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
 
# 超参数
latent_dim = 100        # 随机向量的维度
batch_size = 64
epochs = 50
learning_rate = 0.0002
image_size = 28 * 28    # MNIST图片是28x28
 
# 数据加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # 归一化到[-1, 1]
])
dataset = datasets.MNIST('./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

3.2 定义生成器和判别器

class Generator(nn.Module):
    """生成器:把随机向量变成图片"""
    def __init__(self, latent_dim, output_dim):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, output_dim),
            nn.Tanh()  # 输出范围[-1, 1]
        )
    
    def forward(self, z):
        return self.net(z)
 
class Discriminator(nn.Module):
    """判别器:判断图片是真是假"""
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # 输出概率[0, 1]
        )
    
    def forward(self, x):
        return self.net(x)

逐行解析

  1. 生成器的输入:一个长度为 latent_dim 的随机向量。比如输入 [0.1, -0.5, 0.3, ...] 这样100维的向量
  2. 生成器的输出:一张展平后的图片(28×28=784维),用 Tanh 激活把值映射到 [-1, 1]
  3. 判别器的输入:一张展平后的图片
  4. 判别器的输出:一个0到1之间的概率值,越接近1表示越像真图

3.3 训练循环

# 初始化网络
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator(latent_dim, image_size).to(device)
discriminator = Discriminator(image_size).to(device)
 
# 优化器
opt_g = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
opt_d = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
 
# 损失函数
criterion = nn.BCELoss()
 
for epoch in range(epochs):
    for batch_idx, (real_images, _) in enumerate(dataloader):
        batch_size = real_images.size(0)
        real_images = real_images.view(batch_size, -1).to(device)
        
        # 创建标签
        real_labels = torch.ones(batch_size, 1).to(device)   # 真图标签=1
        fake_labels = torch.zeros(batch_size, 1).to(device)  # 假图标签=0
        
        # ========== 训练判别器 ==========
        discriminator.zero_grad()
        
        # 判别器看真图,应该输出1
        real_output = discriminator(real_images)
        d_loss_real = criterion(real_output, real_labels)
        
        # 判别器看假图,应该输出0
        noise = torch.randn(batch_size, latent_dim).to(device)
        fake_images = generator(noise)
        fake_output = discriminator(fake_images.detach())  # detach避免梯度传到G
        d_loss_fake = criterion(fake_output, fake_labels)
        
        # 总判别器损失
        d_loss = (d_loss_real + d_loss_fake) / 2
        d_loss.backward()
        opt_d.step()
        
        # ========== 训练生成器 ==========
        generator.zero_grad()
        
        # 生成器生成假图,希望判别器输出1(骗过判别器)
        noise = torch.randn(batch_size, latent_dim).to(device)
        fake_images = generator(noise)
        fake_output = discriminator(fake_images)
        g_loss = criterion(fake_output, real_labels)  # 标签是1
        
        g_loss.backward()
        opt_g.step()
        
        if batch_idx % 200 == 0:
            print(f"Epoch [{epoch}/{epochs}] Batch [{batch_idx}] "
                  f"D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}")

关键点解释

  1. 为什么 fake_images.detach() 因为我们更新判别器时,不应该让生成器的权重跟着变。detach() 就是切断梯度流。

  2. 为什么判别器损失要除以2? 这不是必须的,只是为了让判别器的学习速度和生成器更平衡。

  3. 生成器的标签为什么用1? 生成器希望骗过判别器,所以它希望判别器把自己生成的图片判断为真(标签=1)。

  4. betas=(0.5, 0.999) 是什么? 这是Adam优化器的参数,控制梯度的指数加权平均。0.5是个比较小的值,让优化器更”激进”地适应最近的梯度变化。

3.4 查看生成效果

import matplotlib.pyplot as plt
 
def generate_and_plot(generator, latent_dim, num_images=16):
    generator.eval()
    with torch.no_grad():
        noise = torch.randn(num_images, latent_dim).to(device)
        fake_images = generator(noise)
        fake_images = fake_images.view(-1, 28, 28).cpu().numpy()
        
        fig, axes = plt.subplots(4, 4, figsize=(8, 8))
        for i, ax in enumerate(axes.flat):
            ax.imshow(fake_images[i], cmap='gray')
            ax.axis('off')
        plt.show()
 
# 训练完后生成图片
generate_and_plot(generator, latent_dim)

4. DCGAN实战:用深度卷积GAN生成图片

4.1 为什么需要DCGAN?

上面那个简单的全连接GAN生成MNIST数字还行,但如果想生成更复杂的图片,比如真人脸、CIFAR图片,就力不从心了。

问题在于全连接层的参数太多了,而且无法捕捉图片的空间结构特征。

DCGAN(Deep Convolutional GAN)就是来解决这个问题的。它把卷积神经网络引入GAN:

  • 生成器用转置卷积(Transposed Convolution)来上采样
  • 判别器用普通卷积来提取特征和下采样
  • 通过卷积的局部连接特性,能更好地捕捉图像的空间结构

4.2 DCGAN的核心设计原则

Radford等人在2016年的论文中总结了DCGAN的设计经验:

生成器设计

  • 用转置卷积(也叫反卷积或fractionally strided convolution)替代全连接层
  • 转置卷积能让特征图的空间尺寸翻倍
  • 每个卷积层后加 BatchNorm(批归一化),稳定训练
  • 激活函数用 ReLU,最后一层用 Tanh

判别器设计

  • 用步长卷积(strided convolution)替代池化层
  • 同样加 BatchNorm(第一层除外)
  • 激活函数用 LeakyReLU(负斜率0.2),避免梯度稀疏
  • 最后用 Sigmoid 输出概率

4.3 DCGAN代码实现

class DCGenerator(nn.Module):
    """DCGAN生成器:使用转置卷积"""
    def __init__(self, latent_dim, channels, features_g=64):
        super(DCGenerator, self).__init__()
        
        # 网络结构:latent_dim -> 4x4x(512*8) -> 8x8x(512*4) -> 16x16x(512*2) -> 32x32x512 -> 64x64xchannels
        
        self.latent_dim = latent_dim
        self.channels = channels
        
        # 初始块:1x1 -> 4x4
        self.initial = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, features_g * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(features_g * 8),
            nn.ReLU(True)
        )
        
        # 上采样阶段
        self.upsample1 = nn.Sequential(
            nn.ConvTranspose2d(features_g * 8, features_g * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(True)
        )
        
        self.upsample2 = nn.Sequential(
            nn.ConvTranspose2d(features_g * 4, features_g * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(True)
        )
        
        self.upsample3 = nn.Sequential(
            nn.ConvTranspose2d(features_g * 2, features_g, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g),
            nn.ReLU(True)
        )
        
        self.final = nn.Sequential(
            nn.ConvTranspose2d(features_g, channels, 4, 2, 1, bias=False),
            nn.Tanh()  # 输出范围[-1, 1]
        )
    
    def forward(self, z):
        # z: (batch_size, latent_dim)
        x = z.view(z.size(0), self.latent_dim, 1, 1)  # reshape到4D
        x = self.initial(x)    # 4x4
        x = self.upsample1(x)  # 8x8
        x = self.upsample2(x)  # 16x16
        x = self.upsample3(x)  # 32x32
        x = self.final(x)      # 64x64
        return x
 
class DCDiscriminator(nn.Module):
    """DCGAN判别器:使用步长卷积"""
    def __init__(self, channels, features_d=64):
        super(DCDiscriminator, self).__init__()
        
        self.main = nn.Sequential(
            # 输入: channels x 64 x 64
            nn.Conv2d(channels, features_d, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            
            # features_d -> features_d*2
            nn.Conv2d(features_d, features_d * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 2),
            nn.LeakyReLU(0.2, inplace=True),
            
            # features_d*2 -> features_d*4
            nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 4),
            nn.LeakyReLU(0.2, inplace=True),
            
            # features_d*4 -> features_d*8
            nn.Conv2d(features_d * 4, features_d * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 8),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 最终输出: 1x1
            nn.Conv2d(features_d * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.main(x).view(-1, 1).squeeze(1)

4.4 DCGAN训练技巧

DCGAN虽然比原始GAN稳定多了,但也有一些坑需要注意:

1. BatchNorm的坑

如果batch size太小(小于32),BatchNorm的效果会变差。所以建议使用较大的batch size,比如64或128。

2. 学习率的设置

原始论文建议使用 Adam 优化器,学习率0.0002,beta1=0.5。这个配置被广泛验证有效。

3. 标签平滑(Label Smoothing)

不要用硬标签0和1,稍微平滑一下效果更好:

real_labels = torch.ones_like(labels) * (0.9 + torch.rand_like(labels) * 0.1)  # [0.9, 1.0]
fake_labels = torch.zeros_like(labels) + torch.rand_like(labels) * 0.1  # [0.0, 0.1]

4. 平衡训练

如果判别器loss一直很低,说明它太强了,生成器得不到有效的梯度反馈。可以:

  • 减少判别器的训练频率(比如每2步训练一次)
  • 降低判别器的学习率
  • 在判别器里多加几层

5. Mode Collapse问题:为什么生成器只会那几招?

5.1 问题描述

Mode Collapse(模式崩溃)是GAN训练中最让人头疼的问题之一。

举个例子:真实数据分布是一个双峰分布(比如MNIST里的数字4和数字9),但训练好的生成器只生成数字4,或者只生成数字9,忽略了另一个峰。

更严重的情况:生成器学会了”作弊”——它找到了一种骗过判别器的方法,但这种方法极其单一。比如不管输入什么随机向量,输出的都是同一张”万能脸”。

5.2 为什么会发生Mode Collapse?

根本原因:JS散度的局限性。

当生成器生成的分布 只覆盖了真实分布 的一部分时,判别器可以轻松地把这部分识别出来。但对于生成器来说,继续优化这一小部分比去学习其他模式更容易

就好比一个学生发现背答案是最省力的考试方法,于是他就一直背答案,完全不学习真正理解知识。

5.3 解决方案

方案1:Minibatch Discrimination

让判别器不仅看单张图片,还看一整批图片的相似度。这样如果生成器只生成相似的图片,判别器会发现这批图片太”千篇一律”了。

class MinibatchDiscrimination(nn.Module):
    """批次判别模块"""
    def __init__(self, input_dim, num_kernels, kernel_dim=5):
        super().__init__()
        self.num_kernels = num_kernels
        self.kernel_dim = kernel_dim
        self.mapping = nn.Linear(input_dim, num_kernels * kernel_dim)
    
    def forward(self, x):
        batch_size = x.size(0)
        # 映射
        mapped = self.mapping(x)
        mapped = mapped.view(batch_size, self.num_kernels, self.kernel_dim)
        
        # 计算样本间的L1距离
        x1 = mapped.unsqueeze(2)
        x2 = mapped.unsqueeze(1)
        diff = torch.abs(x1 - x2)
        distances = torch.sum(diff, dim=3)
        
        # 转成相似度
        similarities = torch.exp(-distances)
        
        # 每个样本和其他样本的相似度之和
        minibatch_features = torch.sum(similarities, dim=2) - 1
        
        return torch.cat([x, minibatch_features], dim=1)

方案2:Unrolled GAN

让生成器在更新时,能”看到”判别器未来会怎么更新。这样生成器就不会针对当前的判别器优化,而是针对未来的判别器优化。

def unrolled_loss(generator, discriminator, real_batch, latent_dim, unroll_steps=5):
    """Unrolled GAN损失"""
    optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)
    
    # 保存原始状态
    saved_state = copy.deepcopy(discriminator.state_dict())
    
    # 更新判别器若干步
    for _ in range(unroll_steps):
        optimizer_d.zero_grad()
        fake_batch = generator(torch.randn(len(real_batch), latent_dim))
        d_loss = -torch.mean(discriminator(real_batch)) + torch.mean(discriminator(fake_batch))
        d_loss.backward()
        optimizer_d.step()
    
    # 用未来的判别器计算生成器损失
    fake_batch = generator(torch.randn(len(real_batch), latent_dim))
    g_loss = -torch.mean(discriminator(fake_batch))
    
    # 恢复判别器
    discriminator.load_state_dict(saved_state)
    
    return g_loss

方案3:WGAN/WGAN-GP

换用Wasserstein距离,从根本上解决JS散度的梯度消失问题。WGAN能提供更稳定的梯度,让生成器有动力去探索更多的模式。


6. Wasserstein GAN:更稳定的训练

6.1 Wasserstein距离的直观理解

WGAN的核心创新是用Wasserstein距离(也叫Earth-Mover距离)替代JS散度。

Wasserstein距离的物理含义是:要把一堆土从形状A变成形状B,需要搬运的土方量。分布越接近,需要搬运的越少。

关键优势是:即使两个分布完全不重叠,Wasserstein距离仍然能反映它们的远近

这就好比:

  • JS散度 = 只告诉你”对不对”,不说”差多少”
  • Wasserstein距离 = 告诉你”差多少”,还告诉你”怎么补”

6.2 WGAN的实现

WGAN相比原始GAN,有几个关键变化:

  1. 判别器不使用Sigmoid激活:直接输出任意实数
  2. 损失函数变化:不再是log似然,而是直接用判别器的输出
  3. 权重裁剪(Weight Clipping):把判别器的参数限制在一个小范围内
class WGAN_Discriminator(nn.Module):
    """WGAN判别器(也叫Critic)"""
    def __init__(self, img_shape):
        super().__init__()
        self.img_shape = img_shape
        
        self.model = nn.Sequential(
            nn.Linear(int(torch.prod(torch.tensor(img_shape))), 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1)  # 无激活,输出任意实数
        )
    
    def forward(self, img):
        return self.model(img.view(img.size(0), -1))
 
def train_wgan(generator, discriminator, dataloader, latent_dim, epochs, device):
    """WGAN训练"""
    optimizer_G = optim.RMSprop(generator.parameters(), lr=0.00005)
    optimizer_D = optim.RMSprop(discriminator.parameters(), lr=0.00005)
    
    clip_value = 0.01  # 权重裁剪范围
    
    for epoch in range(epochs):
        for imgs, _ in dataloader:
            batch_size = imgs.size(0)
            imgs = imgs.view(batch_size, -1).to(device)
            
            # 训练判别器(多更新几次)
            for _ in range(5):
                optimizer_D.zero_grad()
                z = torch.randn(batch_size, latent_dim).to(device)
                fake_imgs = generator(z)
                
                # WGAN损失:真图片的得分高,假图片的得分低
                d_loss = -torch.mean(discriminator(imgs)) + torch.mean(discriminator(fake_imgs))
                d_loss.backward()
                optimizer_D.step()
                
                # 权重裁剪
                for p in discriminator.parameters():
                    p.data.clamp_(-clip_value, clip_value)
            
            # 训练生成器
            optimizer_G.zero_grad()
            z = torch.randn(batch_size, latent_dim).to(device)
            fake_imgs = generator(z)
            g_loss = -torch.mean(discriminator(fake_imgs))
            g_loss.backward()
            optimizer_G.step()

6.3 WGAN-GP:梯度惩罚版本

权重裁剪有个问题:它会强迫判别器学习简单的函数,限制了表达能力。

WGAN-GP(Gradient Penalty WGAN)用梯度惩罚替代权重裁剪,效果更好:

def compute_gradient_penalty(discriminator, real_images, fake_images, device):
    """计算梯度惩罚"""
    batch_size = real_images.size(0)
    
    # 随机插值系数
    alpha = torch.rand(batch_size, 1, 1, 1).to(device)
    
    # 在真实和虚假之间插值
    interpolates = alpha * real_images + (1 - alpha) * fake_images
    interpolates = interpolates.requires_grad_(True)
    
    d_interpolates = discriminator(interpolates)
    
    # 计算梯度
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    
    # 惩罚项:梯度范数应该接近1
    penalty = ((gradient_norm - 1) ** 2).mean()
    
    return penalty
 
def train_wgan_gp(generator, discriminator, dataloader, latent_dim, epochs, device):
    """WGAN-GP训练"""
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.0, 0.9))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0004, betas=(0.0, 0.9))
    
    lambda_gp = 10  # 梯度惩罚系数
    
    for epoch in range(epochs):
        for imgs, _ in dataloader:
            batch_size = imgs.size(0)
            real_images = imgs.to(device)
            
            # 训练判别器
            optimizer_D.zero_grad()
            z = torch.randn(batch_size, latent_dim).to(device)
            fake_images = generator(z)
            
            d_real = discriminator(real_images)
            d_fake = discriminator(fake_images.detach())
            
            # WGAN损失 + 梯度惩罚
            d_loss = d_fake.mean() - d_real.mean()
            gp = compute_gradient_penalty(discriminator, real_images, fake_images, device)
            total_d_loss = d_loss + lambda_gp * gp
            
            total_d_loss.backward()
            optimizer_D.step()
            
            # 训练生成器(可以每两步训练一次)
            optimizer_G.zero_grad()
            z = torch.randn(batch_size, latent_dim).to(device)
            fake_images = generator(z)
            g_loss = -discriminator(fake_images).mean()
            g_loss.backward()
            optimizer_G.step()

7. 条件生成CGAN:想生成什么就生成什么

7.1 什么是条件生成?

普通的GAN输入只有一个随机向量,输出完全是”随机”的。你可能生成一只猫,也可能生成一只狗,或者完全四不像。

条件GAN(Conditional GAN, cGAN)让你能控制生成的内容。比如输入”生成一只猫”,就真的生成猫;输入”生成数字7”,就真的生成7。

7.2 cGAN的实现

class ConditionalGenerator(nn.Module):
    """条件生成器"""
    def __init__(self, latent_dim, num_classes, img_shape, embed_dim=50):
        super().__init__()
        self.img_shape = img_shape
        
        # 类别嵌入层
        self.label_emb = nn.Embedding(num_classes, embed_dim)
        
        # 输入: 随机向量 + 类别嵌入
        self.model = nn.Sequential(
            nn.Linear(latent_dim + embed_dim, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, int(torch.prod(torch.tensor(img_shape)))),
            nn.Tanh()
        )
    
    def forward(self, z, labels):
        label_emb = self.label_emb(labels)
        combined = torch.cat([z, label_emb], dim=-1)
        img = self.model(combined)
        return img.view(img.size(0), *self.img_shape)
 
class ConditionalDiscriminator(nn.Module):
    """条件判别器"""
    def __init__(self, num_classes, img_shape, embed_dim=50):
        super().__init__()
        self.img_shape = img_shape
        self.label_emb = nn.Embedding(num_classes, int(torch.prod(torch.tensor(img_shape))))
        
        self.model = nn.Sequential(
            nn.Linear(int(torch.prod(torch.tensor(img_shape))) + embed_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),
            nn.Linear(512, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )
    
    def forward(self, img, labels):
        label_emb = self.label_emb(labels)
        combined = torch.cat([img.view(img.size(0), -1), label_emb], dim=-1)
        return self.model(combined)

使用示例

# 生成数字5的图片
z = torch.randn(1, latent_dim).to(device)
target_label = torch.tensor([5]).to(device)
generated_img = generator(z, target_label)

7.3 类别信息的其他注入方式

除了Embedding,还可以:

1. 类别条件BatchNorm

class ConditionalBatchNorm(nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        self.embed = nn.Embedding(num_classes, num_features * 2)
        
    def forward(self, x, class_id):
        out = self.bn(x)
        gamma, beta = self.embed(class_id).chunk(2, dim=-1)
        return out * (gamma.view(-1, 1, 1) + 1) + beta.view(-1, 1, 1)

2. 类别调制(Class Modulation)

把类别信息当作”风格向量”,对特征进行调制。

3. 辅助分类器(ACGAN)

判别器不仅判断真假,还额外预测类别:

class ACGAN_Discriminator(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.features = DCDiscriminator()
        self.adv_head = nn.Linear(512, 1)      # 真假判断
        self.cls_head = nn.Linear(512, num_classes)  # 类别分类
    
    def forward(self, x):
        features = self.features(x)
        return self.adv_head(features), self.cls_head(features)

8. Progressive GAN:从模糊到清晰

8.1 渐进式训练的思路

想象你学画画的过程:不是一开始就画细节,而是先画轮廓,再慢慢加细节。

Progressive GAN(ProGAN)就是用这个思路:

  1. 先用4×4分辨率训练,生成器判别器都很简单
  2. 等4×4学好了,加上8×8的训练
  3. 等8×8学好了,加上16×16
  4. 以此类推,直到1024×1024

这样每个阶段要学的东西都很少,训练稳定,而且低分辨率阶段学到的”大结构”知识会被保留。

8.2 平滑过渡

从一个分辨率切换到下一个分辨率时,不能突然切换,否则会导致训练崩溃。

ProGAN用的是alpha混合:在过渡期间,新层的输出和旧层的输出会按照alpha系数加权混合:

def forward(self, z, alpha=1.0, stage=1):
    """
    alpha: 混合系数
    stage: 当前阶段(从1开始)
    """
    x = self.initial(z)  # 4x4
    
    for i in range(stage - 1):
        x = self.upsample_blocks[i](x)
    
    if alpha < 1.0:
        # 过渡阶段:混合新旧输出
        new_output = self.upsample_blocks[stage - 1](x)
        old_output = self.to_rgb_old(x)
        return alpha * new_output + (1 - alpha) * old_output
    else:
        # 稳定阶段
        x = self.upsample_blocks[stage - 1](x)
        return self.to_rgb_new(x)

8.3 渐进式训练的优势

  1. 训练更稳定:每一步只需要学习小幅提升
  2. 速度更快:低分辨率阶段计算量小
  3. 减少模式崩溃:低分辨率已经覆盖了大尺度模式
  4. 可以生成高分辨率:1024×1024甚至更高

9. StyleGAN:控制生成的每一个细节

9.1 StyleGAN的核心创新

StyleGAN是GAN领域的里程碑工作,它实现了对生成图片细粒度的控制。

核心创新有三个:

1. 映射网络(Mapping Network)

原始的随机向量 可能各维度之间有相关性,比如”狗的大小”和”背景的亮度”可能绑在一起。映射网络把 转换成解耦的中间向量 ,让每个维度控制独立的特征。

2. 自适应实例归一化(AdaIN)

在生成器的每个分辨率级别,都注入一次风格信息。这个风格信息来自 ,通过调整均值和方差来控制该层生成的特征。

3. 噪声输入

每个分辨率级别还可以输入独立的噪声,用来增加细节变化。

9.2 为什么StyleGAN能控制风格?

不同分辨率级别,控制不同层级的特征:

  • 4×4 - 8×8:最粗糙的尺度,控制姿态、脸型等大致轮廓
  • 16×16 - 32×32:中等尺度,控制发型、肤色等
  • 64×64 - 256×256:细节尺度,控制皮肤纹理、眼睛细节等
  • 512×512 - 1024×1024:最精细的尺度,控制头发丝、背景等微观细节

所以如果你只修改高层(低分辨率)的 ,就会改变整体风格;只修改低层(高分辨率)的 ,只会改变细节。

9.3 StyleGAN2的改进

StyleGAN2主要改进了两点:

1. 移除伪影

StyleGAN1的AdaIN会产生”水滴”伪影。StyleGAN2用Weight Demodulation替代:

def forward(self, x, style):
    # 样式调制
    style = self.affine(style)
    scale, bias = style.chunk(2, dim=-1)
    
    # 权重归一化
    weight = self.weight * scale.view(1, -1, 1, 1)
    demod = weight / (weight.norm(dim=(2,3), keepdim=True) + 1e-8)
    
    # 卷积
    x = F.conv2d(x, demod, padding=1)
    return x + bias.view(1, -1, 1, 1)

2. 路径长度正则化

让潜在空间的等距移动对应图像空间的等距变化,使得插值更平滑。

9.4 StyleGAN3:消除锯齿

StyleGAN3发现了一个问题:生成的图片在平移时会出现”跳动”。原因是离散采样导致的aliasing(混叠)。

StyleGAN3通过使用sinc滤波器重新设计所有上采样和下采样操作,实现了真正的连续等变性——图片可以平滑平移而不会出现跳变。


10. GAN评估指标:怎么衡量生成的图片好不好?

10.1 Inception Score(IS)

Inception Score是最早的GAN评估指标:

直观理解

  1. 如果图片质量高,Inception网络对它的分类应该很确定—— 的熵很低
  2. 如果图片多样性高,各类别的边缘分布应该均匀—— 的熵很高
  3. KL散度 = 高条件确定性 - 高边缘均匀性

计算方法

def calculate_inception_score(images, model, num_splits=10):
    model.eval()
    preds = []
    
    with torch.no_grad():
        for img in images:
            # 调整大小到299x299(Inception的输入尺寸)
            img = F.interpolate(img.unsqueeze(0), size=(299, 299), mode='bilinear')
            pred = F.softmax(model(img), dim=-1)
            preds.append(pred.cpu().numpy())
    
    preds = np.concatenate(preds)
    
    # 计算每个split的IS
    split_scores = []
    split_size = len(preds) // num_splits
    
    for i in range(num_splits):
        part = preds[i * split_size:(i + 1) * split_size]
        py = np.mean(part, axis=0)
        kl = part * (np.log(part + 1e-8) - np.log(py + 1e-8))
        kl = np.sum(kl, axis=1)
        split_scores.append(np.exp(np.mean(kl)))
    
    return np.mean(split_scores), np.std(split_scores)

IS的局限

  • 只看生成图片,不和真实图片比较
  • 可以通过”生成ImageNet里的某一类”来刷分,但这不代表真的学到了分布

10.2 Fréchet Inception Distance(FID)

FID是目前最流行的GAN评估指标。它比较真实图片和生成图片在特征空间中的分布差异。

思想:用Inception网络提取特征,假设这些特征服从高斯分布,然后计算两个高斯分布之间的Fréchet距离。

  • :均值向量
  • :协方差矩阵
  • FID越小越好(0表示完美匹配)

计算方法

def calculate_fid(real_images, fake_images, inception_model):
    """计算FID"""
    real_acts = extract_features(real_images, inception_model)
    fake_acts = extract_features(fake_images, inception_model)
    
    mu_real = np.mean(real_acts, axis=0)
    mu_fake = np.mean(fake_acts, axis=0)
    
    sigma_real = np.cov(real_acts, rowvar=False)
    sigma_fake = np.cov(fake_acts, rowvar=False)
    
    # Fréchet距离
    diff = mu_real - mu_fake
    covmean = sqrtm(sigma_real @ sigma_fake)
    
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    
    fid = diff @ diff + np.trace(sigma_real + sigma_fake - 2 * covmean)
    return fid
 
def sqrtm(matrix):
    """矩阵平方根的数值稳定计算"""
    eigenvalues, eigenvectors = np.linalg.eigh(matrix)
    return eigenvectors @ np.diag(np.sqrt(np.maximum(eigenvalues, 0))) @ eigenvectors.T

FID的优势

  • 同时考虑真实和生成图片
  • 对模式崩溃敏感(生成的图片单一会导致FID变高)
  • 与人类感知相关性较好

10.3 Precision和Recall

FID只看分布距离,但无法区分”生成的图片质量高但覆盖不全”和”生成的图片质量低但覆盖全”。

Precision-Recall分析可以区分这两种情况:

  • Precision(精确度):生成的图片里,有多少是”合格”的
  • Recall(召回率):真实数据分布被覆盖了多少
def calculate_precision_recall(real_features, fake_features, k=3):
    """计算Precision和Recall"""
    from sklearn.neighbors import NearestNeighbors
    
    # k近邻
    nn_real = NearestNeighbors(n_neighbors=k + 1).fit(real_features)
    nn_fake = NearestNeighbors(n_neighbors=k + 1).fit(fake_features)
    
    # 计算每个fake样本的precision
    distances, indices = nn_real.kneighbors(fake_features)
    # 如果fake样本的最近邻大多是real样本,则precision高
    precision = np.mean([np.mean(idx[1:] < len(real_features)) for idx in indices])
    
    # 计算recall
    distances, indices = nn_fake.kneighbors(real_features)
    recall = np.mean([np.mean(idx[1:] < len(fake_features)) for idx in indices])
    
    return precision, recall

11. GAN的应用:这些都在用GAN

11.1 艺术创作与设计

GAN在艺术领域的应用五花八门:

人脸生成与编辑:StyleGAN能生成极其逼真的人脸,还支持各种编辑操作——换发型、加眼镜、改年龄、调表情。Midjourney、DALL-E这些工具的背后,都有GAN的思想。

风格迁移:CycleGAN可以把马变成斑马、把夏天变成冬天。艺术家们用它来创作独特的视觉效果。

服装设计:时尚品牌用GAN来生成新款式,或者根据用户的喜好推荐设计。

音乐生成:虽然图片GAN更成熟,但AudioGAN也开始用于音乐创作,比如生成特定风格的音乐片段。

11.2 数据增强与医学影像

医疗影像有个老大难问题:数据太少了。GAN可以生成逼真的医学影像来扩充数据集:

class MedicalImageAugmentation:
    """医学影像数据增强"""
    def __init__(self, generator, classifier):
        self.generator = generator
        self.classifier = classifier
    
    def generate_samples(self, target_label, num_samples=100):
        """为特定类别生成样本"""
        z = torch.randn(num_samples, self.generator.latent_dim)
        labels = torch.full((num_samples,), target_label)
        
        with torch.no_grad():
            generated = self.generator(z, labels)
        
        # 过滤:只保留分类器认为正确的
        probs = self.classifier(generated)
        valid_indices = (probs.argmax(dim=1) == target_label).nonzero().squeeze()
        
        return generated[valid_indices]
    
    def balance_dataset(self, dataset, target_count_per_class=500):
        """平衡数据集"""
        # 对于样本不足的类别,用GAN补充
        pass

11.3 图像修复与超分辨率

图像修复(Inpainting):移除图片中的不需要的物体(比如路人、水印),并用合理的背景填充。GAN能生成语义正确、视觉自然的结果。

超分辨率(Super Resolution):把低分辨率图片放大并增强细节。ESRGAN(Enhanced SRGAN)是这个领域的经典工作。

去噪与恢复:去除老照片的噪点、划痕,恢复破损的图片。GAN能学习真实的纹理,生成自然的结果。

11.4 游戏与电影制作

角色生成:游戏开发者用GAN来生成NPC(非玩家角色)的脸型、皮肤纹理等,大幅减少美术工作量。

场景构建:用GAN生成游戏场景的纹理、背景,或者根据文字描述生成概念图。

特效生成:电影后期用GAN来生成火焰、烟雾、爆炸等特效的自然变化。

11.5 3D内容生成

NeRF + GAN:神经辐射场(NeRF)能生成高质量的3D视图,但训练慢。GAN可以用来加速NeRF的渲染,或者让NeRF生成更多样化的内容。

纹理生成:用GAN为3D模型生成贴图、材质。

可控角色生成:给定一个骨骼姿态,生成穿着对应衣服的角色。


12. 调试GAN的实用技巧

12.1 监控训练过程的指标

GAN的训练过程很复杂,需要监控多个指标:

class GANTrainingMonitor:
    def __init__(self):
        self.history = {
            'd_loss_real': [],
            'd_loss_fake': [],
            'g_loss': [],
            'd_accuracy': [],
            'fid_scores': []  # 定期计算
        }
    
    def update(self, d_loss_real, d_loss_fake, g_loss, d_preds_real, d_preds_fake):
        self.history['d_loss_real'].append(d_loss_real)
        self.history['d_loss_fake'].append(d_loss_fake)
        self.history['g_loss'].append(g_loss)
        
        # 判别器准确率
        acc_real = (d_preds_real > 0.5).float().mean().item()
        acc_fake = (d_preds_fake < 0.5).float().mean().item()
        self.history['d_accuracy'].append((acc_real + acc_fake) / 2)
    
    def check_mode_collapse(self, fake_images):
        """检测模式崩溃"""
        if len(fake_images) < 10:
            return False
        
        with torch.no_grad():
            fake_flat = fake_images.view(len(fake_images), -1)
            # 计算样本间的相似度
            similarity = torch.mm(fake_flat, fake_flat.t())
            
            # 如果样本太相似,相似度矩阵会很"极端"
            std = similarity.std().item()
            
            return std < 0.1  # 阈值可调

12.2 常见问题排查

问题:判别器loss一直很低

  • 说明判别器太强了,生成器得不到有效梯度
  • 解决:降低判别器学习率,或减少判别器训练步数

问题:生成器loss一直很高且不下降

  • 可能梯度消失或判别器太弱
  • 解决:检查权重初始化,增加判别器深度,使用WGAN

问题:生成的图片出现明显的”模式崩溃”

  • 生成器只生成少数几种图片
  • 解决:使用Minibatch Discrimination,减小学习率,增加随机性

问题:训练刚开始就崩溃

  • 学习率太高
  • 解决:降低学习率,使用学习率预热(warmup)

问题:图片出现伪影或棋盘格效应

  • 转置卷积的”棋盘格”问题
  • 解决:使用PixelShuffle或上采样+卷积替代转置卷积

12.3 实用tricks清单

  1. 标签平滑:真实标签用0.9而不是1,假标签用0.1而不是0
  2. LeakyReLU:判别器用LeakyReLU(0.2)而非ReLU
  3. 批量归一化的玄学:如果batch size太小,试试InstanceNorm
  4. 特征匹配:生成器的损失可以加上一项,让假图片的特征统计量接近真图片
  5. 历史平均:对生成器的参数做历史平均,可以稳定训练
  6. TTUR:生成器和判别器用不同的学习率(通常是1:4)

13. 学术论文与参考文献

  1. Goodfellow, I., et al. (2014). “Generative Adversarial Networks.” NeurIPS. — GAN的开山之作

  2. Arjovsky, M., Chintala, S., & Bottou, L. (2017). “Wasserstein GAN.” ICML. — Wasserstein距离的里程碑

  3. Radford, A., Metz, L., & Chintala, S. (2016). “Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.” ICLR. — DCGAN

  4. Mirza, M., & Osindero, S. (2014). “Conditional Generative Adversarial Nets.” arXiv. — 条件GAN

  5. Karras, T., et al. (2018). “Progressive Growing of GANs for Improved Quality, Stability, and Variation.” ICLR. — ProGAN

  6. Karras, T., et al. (2019). “A Style-Based Generator Architecture for Generative Adversarial Networks.” CVPR. — StyleGAN

  7. Karras, T., et al. (2021). “Alias-Free Generative Adversarial Networks.” NeurIPS. — StyleGAN3

  8. Isola, P., et al. (2017). “Image-to-Image Translation with Conditional Adversarial Networks.” CVPR. — Pix2Pix

  9. Zhu, J.-Y., et al. (2017). “Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.” ICCV. — CycleGAN

  10. Choi, Y., et al. (2018). “StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation.” CVPR. — StarGAN

  11. Brock, A., Donahue, J., & Simonyan, K. (2019). “Large Scale GAN Training for High Fidelity Natural Image Synthesis.” ICLR. — BigGAN

  12. Zhang, H., et al. (2019). “Self-Attention Generative Adversarial Networks.” ICML. — SAGAN

  13. Gulrajani, I., et al. (2017). “Improved Training of Wasserstein GANs.” NeurIPS. — WGAN-GP

  14. Mao, X., et al. (2017). “Least Squares Generative Adversarial Networks.” ICCV. — LSGAN

  15. Miyato, T., et al. (2018). “Spectral Normalization for Generative Adversarial Networks.” ICLR. — 谱归一化GAN


14. 结语

GAN是深度学习领域最有趣的思想之一。它用”对抗”的框架,把生成模型这个难题转化成了一个可学习的博弈问题。

从2014年到现在,GAN已经走过了很长的路。从最初的简单全连接网络,到StyleGAN的超写实人脸,GAN的能力有了质的飞跃。

但GAN的难点——训练不稳定、模式崩溃、难以平衡——至今仍然存在。Diffusion模型在生成质量上已经超越GAN,但GAN因为生成速度快、可以进行细粒度控制,在很多场景下仍然是首选。

学习GAN,不只是学习一个算法,更是学习一种思维方式:通过对抗来学习,通过竞争来进步。这种思想在强化学习、多智能体系统、对抗样本等领域都有广泛应用。

希望这篇指南能帮助你理解GAN的核心思想,并有能力动手实现自己的GAN项目。如果有任何问题,欢迎在评论区讨论!


相关文档