关键词
| 术语 | 英文 | 核心概念 |
|---|---|---|
| 生成对抗网络 | GAN | 生成器与判别器的对抗框架 |
| 生成器 | Generator | 从随机噪声生成假样本 |
| 判别器 | Discriminator | 区分真假样本 |
| 模式崩溃 | Mode Collapse | 生成样本多样性不足 |
| JS散度 | Jensen-Shannon Divergence | GAN损失函数的理论基础 |
| Wasserstein距离 | EM Distance | 改进GAN训练稳定性的度量 |
| 条件生成 | Conditional Generation | 基于条件信息的可控生成 |
| 渐进式生长 | Progressive Growing | 从低分辨率逐步训练到高分辨率 |
| 风格迁移 | Style Transfer | 控制生成图像的风格特征 |
| 循环一致性 | Cycle Consistency | 无配对数据转换的核心约束 |
| 一致性损失 | Consistency Loss | 保证转换前后语义保持的损失函数 |
| 谱归一化 | Spectral Normalization | 约束判别器Lipschitz常数的技巧 |
| 梯度惩罚 | Gradient Penalty | 替代权重裁剪的WGAN改进 |
| 小样本学习 | Few-shot Learning | 仅用少量样本训练生成模型 |
| Zero-shot学习 | Zero-shot Learning | 生成未见过的类别样本 |
1. 引言:对抗学习的诞生
生成对抗网络(Generative Adversarial Network, GAN)由 Ian Goodfellow 等人在 2014 年提出,是一种基于博弈论的生成式建模框架。GAN 的核心思想是通过让两个神经网络相互对抗来学习数据分布,这种”左右互搏”的训练范式深刻影响了深度学习的发展方向。
历史意义
GAN 的提出标志着生成式 AI 进入了一个新纪元。在此之前,自编码器和变分自编码器是主流的生成模型,但 GAN 的出现使得生成图像的质量有了质的飞跃。
1.1 GAN的哲学思想
GAN的设计蕴含了深刻的哲学思想。想象一个伪造者(生成器)和一个鉴定师(判别器)之间的博弈:伪造者不断尝试制造以假乱真的赝品,而鉴定师则不断提升自己的鉴别能力。这种动态博弈过程最终会达到一个均衡点,此时伪造者已经掌握了制造几乎完美赝品的技术,而鉴定师也达到了近乎完美的鉴别能力。
从更宏观的视角看,GAN体现了”道高一尺,魔高一丈”的辩证思想。没有绝对的防御,也没有绝对的攻击,一切都是相对的、动态的。这种思想不仅适用于技术领域,也反映了自然界中普遍存在的协同进化现象。
1.2 生成模型家族概览
在深度学习的生成模型家族中,GAN占据着独特的位置。常见的生成模型包括:
自编码器(Autoencoder)系列:
- 标准自编码器:学习数据的压缩表示,用于重构
- 变分自编码器(VAE):学习数据的潜在分布,支持采样生成
- Denoising Autoencoder:学习从噪声数据中恢复原始信号
基于流的模型(Flow-based Models):
- RealNVP:使用可逆变换构建生成模型
- Glow:基于normalize flows的高质量图像生成
- 优点:精确的对数似然计算
- 缺点:计算量大,生成速度慢
扩散模型(Diffusion Models):
- DDPM:逐步添加噪声再逆向去噪
- Score Matching:学习数据分布的梯度场
- 优点:训练稳定,生成质量高
- 缺点:推理速度慢,需要大量计算
GAN的独特优势:
- 生成速度快(单次前向传播)
- 生成质量高(在某些领域领先)
- 损失函数灵活(多种变体)
2. Goodfellow 2014 原始GAN
2.1 核心架构
原始 GAN 包含两个核心组件:生成器 和判别器 。两者通过对抗训练相互提升,最终达到纳什均衡状态。
import torch
import torch.nn as nn
import torch.optim as optim
class Generator(nn.Module):
"""生成器网络:从 latent space 映射到数据空间"""
def __init__(self, latent_dim, img_shape):
super(Generator, self).__init__()
self.img_shape = img_shape
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
*block(latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(torch.prod(torch.tensor(img_shape)))),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), *self.img_shape)
return img
class Discriminator(nn.Module):
"""判别器网络:判断样本为真或假的二分类器"""
def __init__(self, img_shape):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(int(torch.prod(torch.tensor(img_shape))), 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity2.2 目标函数与博弈论本质
GAN 的训练过程可以形式化为一个二人零和博弈:
这个目标函数体现了博弈论中的零和博弈思想:判别器 试图最大化正确分类的概率,而生成器 试图最小化判别器的准确率。
2.3 训练流程
def train_gan(generator, discriminator, dataloader, latent_dim, epochs, device):
"""GAN 训练循环"""
adversarial_loss = torch.nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
for epoch in range(epochs):
for imgs, _ in dataloader:
batch_size = imgs.size(0)
imgs = imgs.to(device)
# 真实标签与假标签
valid = torch.ones(batch_size, 1).to(device)
fake = torch.zeros(batch_size, 1).to(device)
# 训练生成器
optimizer_G.zero_grad()
z = torch.randn(batch_size, latent_dim).to(device)
gen_imgs = generator(z)
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()
# 训练判别器
optimizer_D.zero_grad()
real_loss = adversarial_loss(discriminator(imgs), valid)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()2.4 原始GAN的数学证明
原始论文中给出了GAN收敛性的初步证明。假设在训练的每一步,判别器都能够达到最优,则生成器的优化目标等价于最小化生成分布与真实分布之间的JS散度。
定理(GAN的最优判别器和生成器):
当判别器可训练至最优时,最优判别器为:
此时生成器的目标函数为:
这意味着当且仅当 时,达到全局最优 。
3. JS散度与训练困难
3.1 原始GAN的理论分析
从信息论角度,原始 GAN 的目标函数可以解释为最小化生成分布 与真实分布 之间的 JS 散度:
其中 是混合分布。
3.2 JS散度的致命缺陷
JS 散度存在一个根本性问题:当两个分布完全不重叠时,JS 散度趋近于常数 ,梯度恒为 0。
# 梯度消失的直观演示
import numpy as np
def js_divergence_gradients():
"""
当生成分布与真实分布完全不重叠时,
JS散度的梯度为0,导致训练停滞
"""
# 假设真实分布为均值=0的高斯,生成分布为均值=10的高斯
p_data = np.random.normal(0, 1, 10000)
p_g = np.random.normal(10, 1, 10000)
# 计算JS散度(理论上接近log2 ≈ 0.693)
# 无论参数如何变化,梯度都接近0
pass训练不稳定性
当生成分布 与真实分布 的支撑集(support)不重叠或重叠可忽略时,判别器最优时 ,此时生成器的梯度消失。
3.3 模式崩溃(Mode Collapse)
模式崩溃是 GAN 训练中的另一个核心问题:生成器学会”欺骗”判别器,但只生成少数几种样本,忽略了数据分布的多样性。
真实数据分布: 双峰分布
p_data(x) = 0.5 * N(μ=2, σ=1) + 0.5 * N(μ=8, σ=1)
理想生成: 应该生成两个峰的所有样本
模式崩溃: 只生成 μ=2 附近的样本(或只生成 μ=8 附近的)
3.4 训练不稳定性的深层原因
原始GAN训练不稳定的原因是多方面的:
1. 同步训练问题:
- 生成器和判别器需要达到动态平衡
- 如果判别器过强,生成器梯度消失
- 如果判别器过弱,生成器失去学习目标
2. 模式多样化缺失:
- 损失函数无法捕捉分布的完整结构
- 生成器倾向于”作弊”,只学习部分模式
3. 梯度消失与梯度爆炸:
- JS散度的饱和特性导致梯度消失
- 判别器过深导致梯度不稳定
4. Wasserstein GAN (WGAN)
4.1 理论突破
Wasserstein GAN 由 Martin Arjovsky 等人在 2017 年提出,用 Earth-Mover(EM)距离替代 JS 散度:
EM 距离的核心优势:即使两个分布完全不重叠,EM 距离仍然能提供有意义的梯度信号。
4.2 WGAN实现
class WGAN_Generator(nn.Module):
"""WGAN 生成器"""
def __init__(self, latent_dim, img_shape):
super().__init__()
self.img_shape = img_shape
self.model = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, int(torch.prod(torch.tensor(img_shape)))),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), *self.img_shape)
return img
class WGAN_Discriminator(nn.Module):
"""WGAN 判别器( critics,无sigmoid输出层)"""
def __init__(self, img_shape):
super().__init__()
self.model = nn.Sequential(
nn.Linear(int(torch.prod(torch.tensor(img_shape))), 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1) # 无激活函数,输出为任意实数
)
def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
def train_wgan(generator, discriminator, dataloader, latent_dim, epochs, device, clip_value=0.01):
"""WGAN 训练:使用权重裁剪替代梯度惩罚"""
optimizer_G = optim.RMSprop(generator.parameters(), lr=0.00005)
optimizer_D = optim.RMSprop(discriminator.parameters(), lr=0.00005)
for epoch in range(epochs):
for imgs, _ in dataloader:
batch_size = imgs.size(0)
imgs = imgs.to(device)
# 训练判别器(多个迭代)
for _ in range(5):
optimizer_D.zero_grad()
z = torch.randn(batch_size, latent_dim).to(device)
gen_imgs = generator(z)
d_loss = -torch.mean(discriminator(imgs)) + torch.mean(discriminator(gen_imgs))
d_loss.backward()
optimizer_D.step()
# 权重裁剪
for p in discriminator.parameters():
p.data.clamp_(-clip_value, clip_value)
# 训练生成器
optimizer_G.zero_grad()
z = torch.randn(batch_size, latent_dim).to(device)
gen_imgs = generator(z)
g_loss = -torch.mean(discriminator(gen_imgs))
g_loss.backward()
optimizer_G.step()4.3 WGAN-GP:梯度惩罚版本
权重裁剪虽然简单,但会导致判别器趋向于最简单的函数。WGAN-GP提出了更优雅的解决方案:
class WGAN_GP_Generator(nn.Module):
"""WGAN-GP 生成器"""
def __init__(self, latent_dim, img_shape):
super().__init__()
self.img_shape = img_shape
self.model = nn.Sequential(
nn.Linear(latent_dim, 128),
nn.LeakyReLU(0.2),
nn.Linear(128, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, int(torch.prod(torch.tensor(img_shape)))),
nn.Tanh()
)
def forward(self, z):
return self.model(z).view(z.size(0), *self.img_shape)
class WGAN_GP_Discriminator(nn.Module):
"""WGAN-GP 判别器,使用残差连接增强表达能力"""
def __init__(self, img_shape):
super().__init__()
self.img_shape = img_shape
self.model = nn.Sequential(
nn.Linear(int(torch.prod(torch.tensor(img_shape))), 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1)
)
def forward(self, img):
return self.model(img.view(img.size(0), -1))
def compute_gradient_penalty(discriminator, real_images, fake_images, device):
"""
计算梯度惩罚
惩罚项:E(||∇_x̂ D(x̂)||_2 - 1)^2
其中 x̂ = ε * x_real + (1-ε) * x_fake, ε ~ U(0,1)
"""
batch_size = real_images.size(0)
alpha = torch.rand(batch_size, 1, 1, 1).to(device)
# 在真实图像和生成图像之间插值
interpolates = alpha * real_images + (1 - alpha) * fake_images
interpolates = interpolates.requires_grad_(True)
d_interpolates = discriminator(interpolates)
gradients = torch.autograd.grad(
outputs=d_interpolates,
inputs=interpolates,
grad_outputs=torch.ones_like(d_interpolates),
create_graph=True,
retain_graph=True,
only_inputs=True
)[0]
gradients = gradients.view(batch_size, -1)
gradient_norm = gradients.norm(2, dim=1)
gradient_penalty = ((gradient_norm - 1) ** 2).mean()
return gradient_penalty
def train_wgan_gp(generator, discriminator, dataloader, latent_dim, epochs, device):
"""WGAN-GP 训练循环"""
optimizer_G = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.0, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0004, betas=(0.0, 0.9))
lambda_gp = 10 # 梯度惩罚系数
for epoch in range(epochs):
for imgs, _ in dataloader:
batch_size = imgs.size(0)
real_images = imgs.to(device)
# 训练判别器
optimizer_D.zero_grad()
z = torch.randn(batch_size, latent_dim).to(device)
fake_images = generator(z)
d_real = discriminator(real_images)
d_fake = discriminator(fake_images.detach())
# WGAN损失
d_loss = d_fake.mean() - d_real.mean()
# 梯度惩罚
gp = compute_gradient_penalty(discriminator, real_images, fake_images, device)
total_d_loss = d_loss + lambda_gp * gp
total_d_loss.backward()
optimizer_D.step()
# 训练生成器(每n个判别器步骤训练一次)
optimizer_G.zero_grad()
z = torch.randn(batch_size, latent_dim).to(device)
fake_images = generator(z)
g_loss = -discriminator(fake_images).mean()
g_loss.backward()
optimizer_G.step()5. 条件GAN(cGAN)
5.1 条件生成的理论框架
条件生成对抗网络(Conditional GAN, cGAN)由 Mehdi Mirza 和 Simon Osindero 在 2014 年提出,通过将条件信息 引入生成器和判别器:
5.2 cGAN架构实现
class ConditionalGenerator(nn.Module):
"""条件生成器:同时接受噪声和条件信息"""
def __init__(self, latent_dim, num_classes, img_shape):
super().__init__()
self.img_shape = img_shape
self.label_emb = nn.Embedding(num_classes, num_classes)
# 将噪声和条件标签拼接
self.model = nn.Sequential(
nn.Linear(latent_dim + num_classes, 128),
nn.LeakyReLU(0.2),
nn.Linear(128, 256),
nn.BatchNorm1d(256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.BatchNorm1d(512),
nn.LeakyReLU(0.2),
nn.Linear(512, 1024),
nn.BatchNorm1d(1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, int(torch.prod(torch.tensor(img_shape)))),
nn.Tanh()
)
def forward(self, z, labels):
# 嵌入标签并与噪声拼接
label_embedding = self.label_emb(labels)
gen_input = torch.cat((z, label_embedding), dim=-1)
img = self.model(gen_input)
img = img.view(img.size(0), *self.img_shape)
return img
class ConditionalDiscriminator(nn.Module):
"""条件判别器:同时接受图像和条件信息"""
def __init__(self, num_classes, img_shape):
super().__init__()
self.img_shape = img_shape
self.label_emb = nn.Embedding(num_classes, int(torch.prod(torch.tensor(img_shape))))
self.model = nn.Sequential(
nn.Linear(int(torch.prod(torch.tensor(img_shape))) + num_classes, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.4),
nn.Linear(512, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.4),
nn.Linear(512, 1),
nn.Sigmoid()
)
def forward(self, img, labels):
label_embedding = self.label_emb(labels)
d_in = torch.cat((img.view(img.size(0), -1), label_embedding), dim=-1)
validity = self.model(d_in)
return validity6. DCGAN:深度卷积GAN
6.1 架构设计原则
DCGAN(Deep Convolutional GAN)由 Radford 等人在 2016 年提出,是将GAN与深度卷积网络结合的里程碑工作。其核心架构设计原则包括:
生成器设计:
- 使用转置卷积(fractionally strided convolution)进行上采样
- 移除所有池化层,使用转置卷积实现空间上采样
- 在生成器中使用批量归一化(BatchNorm)
- 输出层使用 Tanh 激活(归一化到 [-1, 1])
- 隐藏层使用 ReLU/LeakyReLU 激活
判别器设计:
- 使用步长卷积(strided convolution)代替池化
- 同样使用批量归一化(首层除外)
- 使用 LeakyReLU 激活防止梯度稀疏
6.2 DCGAN实现
class DCGenerator(nn.Module):
"""DCGAN 生成器:使用转置卷积"""
def __init__(self, latent_dim, channels, features_g=64):
super(DCGenerator, self).__init__()
# 初始输入: latent_dim x 1 x 1
self.init_size = 4
self.latent_dim = latent_dim
self.channels = channels
# 初始卷积块
self.init = nn.Sequential(
nn.ConvTranspose2d(latent_dim, features_g * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(features_g * 8),
nn.ReLU(True)
)
# 逐步上采样
self.upsample1 = nn.Sequential(
nn.ConvTranspose2d(features_g * 8, features_g * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_g * 4),
nn.ReLU(True)
) # 8x8
self.upsample2 = nn.Sequential(
nn.ConvTranspose2d(features_g * 4, features_g * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_g * 2),
nn.ReLU(True)
) # 16x16
self.upsample3 = nn.Sequential(
nn.ConvTranspose2d(features_g * 2, features_g, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_g),
nn.ReLU(True)
) # 32x32
self.upsample4 = nn.Sequential(
nn.ConvTranspose2d(features_g, channels, 4, 2, 1, bias=False),
nn.Tanh()
) # 64x64
def forward(self, z):
# 重塑为 (batch, channels, 1, 1)
x = z.view(z.size(0), self.latent_dim, 1, 1)
x = self.init(x)
x = self.upsample1(x)
x = self.upsample2(x)
x = self.upsample3(x)
x = self.upsample4(x)
return x
class DCDiscriminator(nn.Module):
"""DCGAN 判别器:使用步长卷积"""
def __init__(self, channels, features_d=64):
super(DCDiscriminator, self).__init__()
# 输入: channels x 64 x 64
self.main = nn.Sequential(
# 第一个卷积块(无BN)
nn.Conv2d(channels, features_d, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# 逐步下采样
nn.Conv2d(features_d, features_d * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_d * 2),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_d * 4),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(features_d * 4, features_d * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_d * 8),
nn.LeakyReLU(0.2, inplace=True),
# 最后一个卷积块
nn.Conv2d(features_d * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, x):
return self.main(x)7. Progressive GAN 与 StyleGAN 系列
7.1 Progressive Growing GAN
Progressive GAN(ProGAN)由 Karras 等人在 2018 年提出,核心思想是从低分辨率开始逐步增加网络深度:
训练阶段:
Level 1: 4×4 分辨率 ─────────────────────────────────────────────> 稳定
Level 2: 8×8 分辨率 ─────────────────────────────────────────────> 稳定
Level 3: 16×16 分辨率 ─────────────────────────────────────────────> 稳定
Level 4: 32×32 分辨率 ─────────────────────────────────────────────> 稳定
Level 5: 64×64 分辨率 ─────────────────────────────────────────────> 稳定
...
Level N: 1024×1024 分辨率 ─────────────────────────────────────> 稳定
Progressive Growing 的优势:
- 训练稳定性:低分辨率时分布更简单,更容易学习
- 计算效率:早期训练快,可快速获得大尺度结构
- 模式覆盖:减少模式崩溃的风险
class ProgressiveGenerator(nn.Module):
"""
渐进式增长生成器
关键组件:
1. 分辨率阶段:每个分辨率有对应的卷积块
2. 平滑过渡:在分辨率切换时使用alpha插值
3. 潜在向量:每层可以接收不同的潜在向量
"""
def __init__(self, latent_dim, max_resolution=1024, base_channels=512):
super().__init__()
self.latent_dim = latent_dim
self.max_resolution = max_resolution
self.base_channels = base_channels
# 计算分辨率级别
self.num_stages = int(np.log2(max_resolution)) - 1 # 4x4 -> 1024x1024
# 初始层:从潜在向量生成4x4特征图
self.initial = nn.Sequential(
nn.Linear(latent_dim, base_channels * 16 * 4 * 4),
nn.Reshape(base_channels * 16, 4, 4),
nn.ReLU()
)
# 各分辨率的上采样模块
self.upsample_blocks = nn.ModuleList()
self.to_rgb_blocks = nn.ModuleList()
channels = base_channels * 16
resolution = 4
for stage in range(1, self.num_stages):
new_resolution = resolution * 2
# 上采样卷积
self.upsample_blocks.append(nn.Sequential(
nn.Upsample(scale_factor=2, mode='nearest'),
nn.Conv2d(channels, channels // 2, 3, padding=1),
nn.BatchNorm2d(channels // 2),
nn.ReLU(),
nn.Conv2d(channels // 2, channels // 2, 3, padding=1),
nn.BatchNorm2d(channels // 2),
nn.ReLU()
))
# RGB输出层
self.to_rgb_blocks.append(nn.Conv2d(channels // 2, 3, 1))
channels = channels // 2
resolution = new_resolution
# 最终RGB输出
self.final_rgb = nn.Conv2d(channels, 3, 3, padding=1)
def forward(self, z, alpha=1.0, stage=None):
"""
前向传播
参数:
z: 潜在向量
alpha: 混合系数(用于平滑过渡)
stage: 当前训练的分辨率阶段
"""
x = self.initial(z)
if stage is None:
stage = len(self.upsample_blocks)
# 应用前stage个上采样块
for i in range(stage):
x = self.upsample_blocks[i](x)
# 输出RGB
rgb = self.final_rgb(x)
return torch.tanh(rgb)7.2 StyleGAN:风格控制的革命
StyleGAN 在 ProGAN 基础上引入自适应实例归一化(AdaIN),实现对生成图像风格的控制:
StyleGAN的核心创新:
- 映射网络(Mapping Network):将潜在向量 转换为中间潜在向量
- 风格块(Style Blocks):在每个分辨率级别注入风格信息
- 噪声输入:为每个级别添加随机噪声以增加细节变化
class StyleBlock(nn.Module):
"""StyleGAN 的核心风格控制模块"""
def __init__(self, in_channels, style_dim):
super().__init__()
self.conv = nn.Conv2d(in_channels, in_channels, 3, padding=1)
self.affine = nn.Linear(style_dim, in_channels * 2)
def forward(self, x, style):
"""
x: 特征图
style: 风格向量(来自映射网络)
"""
# 从风格向量预测缩放和偏移参数
style_params = self.affine(style)
scale, bias = style_params.chunk(2, dim=-1)
# 自适应实例归一化
x = (x - x.mean(dim=[2, 3], keepdim=True)) / (x.std(dim=[2, 3], keepdim=True) + 1e-5)
x = x * (scale.view(scale.size(0), -1, 1, 1) + 1) + bias.view(bias.size(0), -1, 1, 1)
return self.conv(x)
class MappingNetwork(nn.Module):
"""
映射网络:将潜在向量z映射到中间潜在空间w
目的:
1. 解耦潜在空间的各个维度
2. 使风格控制更加精确和独立
"""
def __init__(self, latent_dim, hidden_dim=512, num_layers=8):
super().__init__()
layers = []
for i in range(num_layers):
in_dim = latent_dim if i == 0 else hidden_dim
layers.extend([
nn.Linear(in_dim, hidden_dim),
nn.LeakyReLU(0.2)
])
self.mapping = nn.Sequential(*layers)
self.normalize = PixelNorm()
def forward(self, z):
z = self.normalize(z)
w = self.mapping(z)
return w
class StyleGenerator(nn.Module):
"""StyleGAN 生成器"""
def __init__(self, latent_dim=512, channels=3, base_resolution=4, max_resolution=1024):
super().__init__()
self.latent_dim = latent_dim
self.channels = channels
# 映射网络
self.mapping = MappingNetwork(latent_dim)
# 初始层(constant input)
self.constant_input = nn.Parameter(torch.ones(1, 512, 4, 4))
self.conv1 = StyledConv(512, 512, 3, latent_dim)
self.to_rgb1 = ToRGB(512, channels, latent_dim)
# 逐级上采样
resolutions = [8, 16, 32, 64, 128, 256, 512, 1024]
self.upsamples = nn.ModuleList()
self.convs = nn.ModuleList()
self.to_rgbs = nn.ModuleList()
input_channels = 512
for res in resolutions[1:]:
output_channels = input_channels // 2
self.upsamples.append(nn.Upsample(scale_factor=2, mode='bilinear'))
self.convs.append(StyledConv(input_channels, output_channels, 3, latent_dim))
self.to_rgbs.append(ToRGB(output_channels, channels, latent_dim))
input_channels = output_channels
def forward(self, z, styles=None, noise=None):
"""
前向传播
styles: 风格向量列表(如果为None,使用从z映射的w)
"""
if styles is None:
w = self.mapping(z)
# 对所有层使用相同的w
styles = [w] * 14
# 初始特征
x = self.constant_input.repeat(z.size(0), 1, 1, 1)
x = self.conv1(x, styles[0])
# 逐级上采样
rgb = self.to_rgb1(x, styles[1])
for i, (up, conv, to_rgb) in enumerate(zip(
self.upsamples, self.convs, self.to_rgbs
)):
x = up(x)
x = conv(x, styles[i + 2])
rgb = to_rgb(x, styles[i + 3])
return rgb7.3 StyleGAN2:质量与稳定性提升
StyleGAN2 针对原始StyleGAN的几个问题进行了改进:
- 权重 demodulation:替代AdaIN,减少伪影
- 路径长度正则化:使潜在空间插值更加平滑
- 重新设计架构:移除噪声输入的影响(可选)
class StyleGAN2Conv(nn.Module):
"""
StyleGAN2 的卷积模块
使用 weight demodulation 技术
"""
def __init__(self, in_channels, out_channels, style_dim, upsample=False):
super().__init__()
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear') if upsample else None
self.padding = 1
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, 3, 3))
self.bias = nn.Parameter(torch.zeros(out_channels))
self.affine = nn.Linear(style_dim, in_channels)
def forward(self, x, style):
"""
weight demodulation:
1. 计算每通道激活的标准差
2. 对权重进行归一化
3. 添加样式调制
"""
batch, _, height, width = x.shape
# 获取样式调制
style = self.affine(style)
style = style.view(batch, 1, -1, 1, 1)
# 权重归一化
weight = self.weight.unsqueeze(0)
weight_norm = torch.norm(weight, dim=(2, 3, 4), keepdim=True)
demod_weight = weight / (weight_norm + 1e-8)
# 应用样式调制
modulated_weight = demod_weight * style
# 重塑用于卷积
demod_weight = demod_weight.view(-1, *self.weight.shape)
# 卷积
x = x.view(1, -1, height, width)
out = F.conv2d(x, demod_weight, padding=self.padding, groups=batch)
out = out.view(batch, -1, height, width)
# 添加偏置并应用激活
out = out + self.bias.view(1, -1, 1, 1)
return out
class PathLengthPenalty(nn.Module):
"""
路径长度正则化
目的:使潜在空间w中的等距移动对应于图像空间中
的等距变化,从而提高插值质量
"""
def __init__(self, beta=0.99):
super().__init__()
self.beta = beta
self.styles_mean = None
def forward(self, w, generated_images):
"""
计算路径长度惩罚
"""
image_size = generated_images.shape[-1]
n = w.shape[1] # 层数
# 计算图像对随机方向的梯度
random_direction = torch.randn_like(w)
random_direction = random_direction / torch.norm(random_direction, dim=-1, keepdim=True)
# 生成带扰动的图像
w_perturbed = w + 0.01 * random_direction
# 这里需要调用生成器,但简化表示
# 路径长度惩罚
if self.styles_mean is None:
self.styles_mean = torch.zeros(1)
# 更新移动平均
with torch.no_grad():
path_lengths = torch.ones(w.shape[0], device=w.device)
self.styles_mean = self.beta * self.styles_mean + (1 - self.beta) * path_lengths.mean()
# 惩罚项
penalty = (path_lengths - self.styles_mean) ** 2
return penalty.mean()7.4 StyleGAN3:消除混叠伪影
StyleGAN3 在 2021 年提出,通过消除特征图的离散采样效应,实现了图像的连续平移和旋转不变性:
StyleGAN3 的核心改进
- 使用低通滤波器确保信号在网络传播中保持连续性
- 重新设计上采样和下采样操作
- 实现了真正的等变性而非近似等变性
class StyleGAN3Conv(nn.Module):
"""
StyleGAN3 的连续卷积
关键改进:
1. 使用sinc滤波器进行上/下采样
2. 确保信号在变换下保持连续
"""
def __init__(self, in_channels, out_channels, kernel_size, down=False):
super().__init__()
self.down = down
self.kernel_size = kernel_size
# 学习的滤波核(用于可控制的频谱)
if down:
self.filter = nn.Conv2d(
in_channels, in_channels, kernel_size,
stride=2, padding=kernel_size//2, groups=in_channels
)
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size//2)
def sinc_filter(self, x, cutoff=0.5):
"""
生成sinc低通滤波器
sinc滤波器具有理想的频率响应,是唯一能够
完美重建的低通滤波器
"""
n = torch.arange(-self.kernel_size//2, self.kernel_size//2 + 1, device=x.device)
window = torch.hann_window(self.kernel_size, device=x.device)
kernel = torch.sinc(2 * cutoff * n) * window
kernel = kernel / kernel.sum()
return kernel
def forward(self, x):
if self.down:
# 使用sinc滤波器进行下采样
kernel = self.sinc_filter(x).view(1, 1, -1, 1)
x = F.conv2d(x, kernel.repeat(x.shape[1], 1, 1, 1),
padding=self.kernel_size//2, groups=x.shape[1])
x = self.conv(x)
return x8. 图像翻译GAN系列
8.1 Pix2Pix:有配对的图像翻译
Pix2Pix是条件GAN在图像翻译任务上的经典应用,适用于有配对数据的图像转换任务,如:
- 边缘图 → 照片
- 白天 → 夜晚
- 航拍图 → 地图
- 素描 → 真实图像
class UNetGenerator(nn.Module):
"""
U-Net 生成器(用于Pix2Pix)
编码器-解码器架构 + 跳跃连接
保留低层次的细节信息
"""
def __init__(self, input_channels, output_channels, num_filters=64):
super().__init__()
# 编码器(下采样)
self.enc1 = self._conv_block(input_channels, num_filters) # 256
self.enc2 = self._conv_block(num_filters, num_filters * 2) # 128
self.enc3 = self._conv_block(num_filters * 2, num_filters * 4) # 64
self.enc4 = self._conv_block(num_filters * 4, num_filters * 8) # 32
self.enc5 = self._conv_block(num_filters * 8, num_filters * 8) # 16
self.enc6 = self._conv_block(num_filters * 8, num_filters * 8) # 8
self.enc7 = self._conv_block(num_filters * 8, num_filters * 8) # 4
self.enc8 = self._conv_block(num_filters * 8, num_filters * 8) # 2
# 解码器(上采样)
self.dec1 = self._upconv_block(num_filters * 8, num_filters * 8) # 4
self.dec2 = self._upconv_block(num_filters * 16, num_filters * 8) # 8
self.dec3 = self._upconv_block(num_filters * 16, num_filters * 8) # 16
self.dec4 = self._upconv_block(num_filters * 16, num_filters * 8) # 32
self.dec5 = self._upconv_block(num_filters * 16, num_filters * 4) # 64
self.dec6 = self._upconv_block(num_filters * 8, num_filters * 2) # 128
self.dec7 = self._upconv_block(num_filters * 4, num_filters) # 256
# 最终输出层
self.final = nn.ConvTranspose2d(num_filters * 2, output_channels, 3, padding=1)
self.tanh = nn.Tanh()
def _conv_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU(0.2)
)
def _upconv_block(self, in_channels, out_channels):
return nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
def forward(self, x):
# 编码
enc1 = self.enc1(x)
enc2 = self.enc2(enc1)
enc3 = self.enc3(enc2)
enc4 = self.enc4(enc3)
enc5 = self.enc5(enc4)
enc6 = self.enc6(enc5)
enc7 = self.enc7(enc6)
enc8 = self.enc8(enc7)
# 解码(带跳跃连接)
dec1 = self.dec1(enc8)
dec2 = self.dec2(torch.cat([dec1, enc7], dim=1))
dec3 = self.dec3(torch.cat([dec2, enc6], dim=1))
dec4 = self.dec4(torch.cat([dec3, enc5], dim=1))
dec5 = self.dec5(torch.cat([dec4, enc4], dim=1))
dec6 = self.dec6(torch.cat([dec5, enc3], dim=1))
dec7 = self.dec7(torch.cat([dec6, enc2], dim=1))
return self.tanh(self.final(torch.cat([dec7, enc1], dim=1)))
class PatchDiscriminator(nn.Module):
"""
Patch判别器(PatchGAN)
与其判断整个图像是否为真,
不如判断图像的每个patch是否为真
这使得判别器专注于高频细节
"""
def __init__(self, input_channels, num_filters=64, n_layers=3):
super().__init__()
layers = []
# 第一层:无BN
layers.append(nn.Conv2d(input_channels, num_filters, 4, stride=2, padding=1))
layers.append(nn.LeakyReLU(0.2))
# 中间层
nf_mult = 1
for i in range(1, n_layers):
nf_mult_prev = nf_mult
nf_mult = min(2 ** i, 8)
layers.append(nn.Conv2d(num_filters * nf_mult_prev, num_filters * nf_mult,
4, stride=2, padding=1))
layers.append(nn.BatchNorm2d(num_filters * nf_mult))
layers.append(nn.LeakyReLU(0.2))
# 最后一层
nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
layers.append(nn.Conv2d(num_filters * nf_mult_prev, num_filters * nf_mult,
4, stride=1, padding=1))
layers.append(nn.BatchNorm2d(num_filters * nf_mult))
layers.append(nn.LeakyReLU(0.2))
# 输出层
layers.append(nn.Conv2d(num_filters * nf_mult, 1, 4, stride=1, padding=1))
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)8.2 CycleGAN:无配对的图像翻译
CycleGAN解决了无配对数据的图像翻译问题,核心思想是循环一致性(Cycle Consistency):
class ResidualBlock(nn.Module):
"""残差块:用于CycleGAN的生成器"""
def __init__(self, channels):
super().__init__()
self.block = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(channels, channels, 3, stride=1, padding=0),
nn.BatchNorm2d(channels),
nn.ReLU(inplace=True),
nn.ReflectionPad2d(1),
nn.Conv2d(channels, channels, 3, stride=1, padding=0),
nn.BatchNorm2d(channels)
)
def forward(self, x):
return x + self.block(x)
class CycleGANGenerator(nn.Module):
"""
CycleGAN 生成器
使用编码器-解码器 + 残差连接
编码器:逐步降低分辨率,提取高层特征
解码器:逐步提高分辨率,恢复空间信息
残差连接:帮助传递细节信息
"""
def __init__(self, input_channels=3, residual_blocks=9):
super().__init__()
# 编码器
self.encoder = nn.Sequential(
nn.ReflectionPad2d(3),
nn.Conv2d(input_channels, 64, 7, stride=1, padding=0),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 128, 3, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, 3, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True)
)
# 残差块
self.residuals = nn.Sequential(
*[ResidualBlock(256) for _ in range(residual_blocks)]
)
# 解码器
self.decoder = nn.Sequential(
nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.ReflectionPad2d(3),
nn.Conv2d(64, input_channels, 7, stride=1, padding=0),
nn.Tanh()
)
def forward(self, x):
x = self.encoder(x)
x = self.residuals(x)
x = self.decoder(x)
return x
class CycleGANLoss:
"""
CycleGAN 损失函数
包括:
1. 对抗损失:让生成的图像骗过判别器
2. 循环一致性损失:X->Y->X 应回到 X
3. 同一性损失:G(Y) 在给定 Y 时应接近 Y
"""
def __init__(self, lambda_cycle=10, lambda_identity=5):
self.lambda_cycle = lambda_cycle
self.lambda_identity = lambda_identity
self.criterionGAN = nn.MSELoss()
self.criterionCycle = nn.L1Loss()
self.criterionIdt = nn.L1Loss()
def compute_loss_D(self, D, real, fake):
"""判别器损失"""
# 真实样本应该被判定为真
real_loss = self.criterionGAN(D(real), torch.ones_like(D(real)))
# 假样本应该被判定为假
fake_loss = self.criterionGAN(D(fake.detach()), torch.zeros_like(D(fake)))
return (real_loss + fake_loss) / 2
def compute_loss_G(self, G, F, D, real_x, real_y):
"""
生成器损失
包括对抗损失、循环一致性损失和同一性损失
"""
# 对抗损失
fake_y = G(real_x)
loss_GAN = self.criterionGAN(D(fake_y), torch.ones_like(D(fake_y)))
fake_x = F(real_y)
loss_F_GAN = self.criterionGAN(D(fake_x), torch.ones_like(D(fake_x)))
# 循环一致性损失
loss_cycle_X = self.criterionCycle(F(fake_y), real_x)
loss_cycle_Y = self.criterionCycle(G(fake_x), real_y)
# 同一性损失
loss_id_X = self.criterionIdt(G(real_y), real_y)
loss_id_Y = self.criterionIdt(F(real_x), real_x)
# 总损失
total_loss = (loss_GAN + loss_F_GAN +
self.lambda_cycle * (loss_cycle_X + loss_cycle_Y) +
self.lambda_identity * (loss_id_X + loss_id_Y))
return total_loss, {
'loss_GAN': loss_GAN.item(),
'loss_F_GAN': loss_F_GAN.item(),
'loss_cycle': (loss_cycle_X + loss_cycle_Y).item(),
'loss_identity': (loss_id_X + loss_id_Y).item()
}8.3 StarGAN:多域图像翻译
StarGAN解决了多域(multi-domain)图像翻译问题,只需一个生成器就能在多个域之间转换:
class StarGANGenerator(nn.Module):
"""
StarGAN 生成器
特点:
1. 单一生成器处理多个域
2. 使用域标签作为条件信息
3. 节省计算资源
"""
def __init__(self, image_channels=3, condition_channels=5, num_filters=64):
super().__init__()
# 编码器
self.encoder = nn.Sequential(
nn.Conv2d(image_channels + condition_channels, num_filters, 7, stride=1, padding=3),
nn.InstanceNorm2d(num_filters),
nn.ReLU(inplace=True),
nn.Conv2d(num_filters, num_filters * 2, 4, stride=2, padding=1),
nn.InstanceNorm2d(num_filters * 2),
nn.ReLU(inplace=True),
nn.Conv2d(num_filters * 2, num_filters * 4, 4, stride=2, padding=1),
nn.InstanceNorm2d(num_filters * 4),
nn.ReLU(inplace=True),
nn.Conv2d(num_filters * 4, num_filters * 8, 4, stride=2, padding=1),
nn.InstanceNorm2d(num_filters * 8),
nn.ReLU(inplace=True)
)
# 残差块
self.residual_blocks = nn.Sequential(
*[ResidualBlock(num_filters * 8) for _ in range(6)]
)
# 解码器
self.decoder = nn.Sequential(
nn.ConvTranspose2d(num_filters * 8, num_filters * 4, 4, stride=2, padding=1, output_padding=1),
nn.InstanceNorm2d(num_filters * 4),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(num_filters * 4, num_filters * 2, 4, stride=2, padding=1, output_padding=1),
nn.InstanceNorm2d(num_filters * 2),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(num_filters * 2, num_filters, 4, stride=2, padding=1, output_padding=1),
nn.InstanceNorm2d(num_filters),
nn.ReLU(inplace=True),
nn.Conv2d(num_filters, image_channels, 7, stride=1, padding=3),
nn.Tanh()
)
def forward(self, x, domain_labels):
"""
前向传播
x: 输入图像
domain_labels: 目标域的one-hot编码
"""
# 将域标签扩展到与图像相同的空间尺寸
batch, _, h, w = x.shape
domain_map = domain_labels.view(batch, -1, 1, 1).expand(batch, -1, h, w)
# 拼接图像和域标签
combined = torch.cat([x, domain_map], dim=1)
# 编码
encoded = self.encoder(combined)
# 残差块
residual_out = self.residual_blocks(encoded)
# 解码
output = self.decoder(residual_out)
return output
class StarGANDiscriminator(nn.Module):
"""
StarGAN 判别器
采用分类器结构:
1. 判断图像真假
2. 判断图像属于哪个域
"""
def __init__(self, image_channels=3, num_domains=5, num_filters=64):
super().__init__()
self.num_domains = num_domains
# 共享特征提取
self.feature_extractor = nn.Sequential(
nn.Conv2d(image_channels, num_filters, 4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(num_filters, num_filters * 2, 4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(num_filters * 2, num_filters * 4, 4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(num_filters * 4, num_filters * 8, 4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True)
)
# 真假判断头
self.adversarial_head = nn.Conv2d(num_filters * 8, 1, 3, stride=1, padding=1)
# 域分类头
self.domain_head = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(num_filters * 8, num_domains)
)
def forward(self, x):
features = self.feature_extractor(x)
real_fake = self.adversarial_head(features)
domain = self.domain_head(features)
return real_fake, domain9. 大规模GAN
9.1 BigGAN:大规模高保真生成
BigGAN通过大规模训练实现了前所未有的生成质量,核心创新包括:
- 大batch训练:使用2048的batch size
- 类别条件BatchNorm:根据类别动态调整归一化参数
- 截断技巧(Truncation Trick):控制潜在向量的分布
- 正交正则化:稳定训练
class BigGANConfig:
"""BigGAN 配置"""
batch_size = 2048
latent_dim = 120
embedding_dim = 128
channels = {
'G': [16, 16, 16, 8, 8, 8, 4, 4, 2, 2, 1],
'D': [1, 2, 2, 4, 4, 8, 8, 8, 16, 16, 16]
}
class SelfAttention(nn.Module):
"""
自注意力模块
使生成器能够捕获图像中的长距离依赖关系
对于生成复杂场景特别重要
"""
def __init__(self, channels):
super().__init__()
self.channels = channels
# 查询、键、值投影
self.query = nn.Conv2d(channels, channels // 8, 1)
self.key = nn.Conv2d(channels, channels // 8, 1)
self.value = nn.Conv2d(channels, channels, 1)
# 输出投影
self.out = nn.Conv2d(channels, channels, 1)
# 缩放因子
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
batch_size, C, H, W = x.shape
# 计算 Q, K, V
q = self.query(x).view(batch_size, -1, H * W).permute(0, 2, 1) # (B, HW, C')
k = self.key(x).view(batch_size, -1, H * W) # (B, C', HW)
v = self.value(x).view(batch_size, -1, H * W).permute(0, 2, 1) # (B, HW, C)
# 注意力权重
attention = torch.bmm(q, k) / (self.channels ** 0.5)
attention = F.softmax(attention, dim=-1)
# 应用注意力
out = torch.bmm(attention, v)
out = out.permute(0, 2, 1).view(batch_size, -1, H, W)
out = self.out(out)
# 残差连接
return self.gamma * out + x
class ConditionalBatchNorm2d(nn.Module):
"""
条件BatchNorm
根据类别标签动态调整归一化参数
"""
def __init__(self, num_features, num_classes):
super().__init__()
self.num_features = num_features
self.bn = nn.BatchNorm2d(num_features, affine=False)
self.embed = nn.Embedding(num_classes, num_features * 2)
# 初始化嵌入权重
nn.init.zeros_(self.embed.weight[:, :num_features])
nn.init.zeros_(self.embed.weight[:, num_features:])
def forward(self, x, class_labels):
out = self.bn(x)
gamma, beta = self.embed(class_labels).chunk(2, dim=-1)
gamma = gamma.unsqueeze(-1).unsqueeze(-1)
beta = beta.unsqueeze(-1).unsqueeze(-1)
return gamma * out + beta
class BigGANGeneratorBlock(nn.Module):
"""
BigGAN 生成器模块
包含:
1. 上采样
2. 条件BatchNorm
3. 非线性激活
4. 卷积
5. 可选的自注意力
"""
def __init__(self, in_channels, out_channels, latent_dim, num_classes,
upsample=True, attention=False):
super().__init__()
self.attention = SelfAttention(out_channels) if attention else None
# 条件BatchNorm
self.cbn1 = ConditionalBatchNorm2d(in_channels, num_classes)
self.cbn2 = ConditionalBatchNorm2d(out_channels, num_classes)
# 卷积层
self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.conv_out = nn.Conv2d(out_channels, out_channels, 3, padding=1)
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear') if upsample else None
def forward(self, x, latent, class_labels):
# 第一层
x = self.cbn1(x, class_labels)
x = F.relu(x)
if self.upsample:
x = self.upsample(x)
x = self.conv(x)
# 第二层
x = self.cbn2(x, class_labels)
x = F.relu(x)
x = self.conv_out(x)
# 自注意力
if self.attention:
x = self.attention(x)
return x9.2 SAGAN:自注意力GAN
SAGAN将自注意力机制引入GAN,使生成器和判别器都能捕获全局依赖:
class SelfAttentionGAN:
"""
SAGAN (Self-Attention GAN) 完整训练框架
"""
def __init__(self, num_classes=1000, latent_dim=128, image_size=128):
self.num_classes = num_classes
self.latent_dim = latent_dim
self.generator = SAGANGenerator(latent_dim, num_classes, image_size)
self.discriminator = SAGANDiscriminator(num_classes, image_size)
# 谱归一化(用于判别器)
self.apply_spectral_norm(self.discriminator)
@staticmethod
def apply_spectral_norm(model):
"""应用谱归一化到所有卷积层"""
for module in model.modules():
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
nn.utils.spectral_norm(module)
def generate_condition_vector(self, z, class_labels):
"""生成条件向量:潜在向量 + 类别嵌入"""
label_embedding = nn.Embedding(self.num_classes, self.latent_dim)(class_labels)
return z + 0.1 * label_embedding # 残差连接
class SAGANLoss:
"""
SAGAN 的 hinge loss
比原始GAN的logistic loss更稳定
"""
def __init__(self, lambda_gp=10):
self.lambda_gp = lambda_gp
def d_loss(self, real_logit, fake_logit):
"""判别器损失:hinge loss"""
real_loss = F.relu(1 - real_logit).mean()
fake_loss = F.relu(1 + fake_logit).mean()
return (real_loss + fake_loss) / 2
def g_loss(self, fake_logit):
"""生成器损失"""
return -fake_logit.mean()
def gradient_penalty(self, discriminator, real, fake):
"""WGAN-GP风格梯度惩罚"""
alpha = torch.rand(real.size(0), 1, 1, 1)
interpolated = alpha * real + (1 - alpha) * fake
interpolated.requires_grad_(True)
d_interp = discriminator(interpolated)
gradients = torch.autograd.grad(
outputs=d_interp,
inputs=interpolated,
grad_outputs=torch.ones_like(d_interp),
create_graph=True,
retain_graph=True
)[0]
gradients = gradients.view(real.size(0), -1)
grad_norm = gradients.norm(2, dim=1)
penalty = ((grad_norm - 1) ** 2).mean()
return self.lambda_gp * penalty10. GAN的训练技巧与稳定性
10.1 训练不稳定性的根源与解决方案
GAN训练的不稳定性源于多个方面,以下是系统性的解决方案:
问题1:模式崩溃(Mode Collapse)
解决方案:
- 使用Unrolled GAN:让生成器看到判别器的未来更新
- 使用Minibatch Discrimination:在判别器中添加批次判别
- 使用多个判别器或生成器
class MinibatchDiscrimination(nn.Module):
"""
小批次判别
目的:防止生成器只产生相似的样本
方法:在判别器中计算样本之间的相似度
"""
def __init__(self, input_channels, num_kernels, kernel_dim=5):
super().__init__()
self.num_kernels = num_kernels
self.kernel_dim = kernel_dim
# 将输入映射到潜在空间
self.mapping = nn.Linear(input_channels, num_kernels * kernel_dim)
def forward(self, x):
batch_size = x.size(0)
# 映射并重塑
mapped = self.mapping(x)
mapped = mapped.view(batch_size, self.num_kernels, self.kernel_dim)
# 计算每个样本与其他样本的L1距离
x1 = mapped.unsqueeze(2) # (B, num_kernels, 1, kernel_dim)
x2 = mapped.unsqueeze(1) # (B, 1, num_kernels, kernel_dim)
# L1距离
diff = torch.abs(x1 - x2)
distances = torch.sum(diff, dim=3) # (B, num_kernels, num_kernels)
# 取负指数得到相似度
similarities = torch.exp(-distances)
# 对每个样本,计算与其他所有样本的相似度之和(减去自身)
minibatch_features = torch.sum(similarities, dim=2) - 1
return minibatch_features问题2:梯度消失
解决方案:
- 使用Wasserstein距离替代JS散度
- 使用带有动量的优化器
- 避免使用sigmoid交叉熵损失
问题3:平衡困难
解决方案:
- 使用TTUR(Two Timescale Update Rule)
- 判别器多更新几次
- 自适应学习率
class TTUROptimizer:
"""
TTUR:双时间尺度更新规则
生成器和判别器使用不同的学习率
通常判别器用更大的学习率
"""
def __init__(self, generator, discriminator,
g_lr=0.0001, d_lr=0.0004,
b1=0.0, b2=0.999):
self.optimizer_G = optim.Adam(generator.parameters(), lr=g_lr, betas=(b1, b2))
self.optimizer_D = optim.Adam(discriminator.parameters(), lr=d_lr, betas=(b1, b2))
def step(self, generator_loss, discriminator_loss):
# 先更新判别器
self.optimizer_D.zero_grad()
discriminator_loss.backward()
self.optimizer_D.step()
# 再更新生成器
self.optimizer_G.zero_grad()
generator_loss.backward()
self.optimizer_G.step()10.2 谱归一化(Spectral Normalization)
谱归一化是稳定GAN训练的重要技术,它约束判别器的Lipschitz常数:
def apply_spectral_norm_to_conv(conv_layer):
"""
对卷积层应用谱归一化
核心思想:使每一层的Lipschitz常数 ≤ 1
实现:使用权重的最大奇异值进行归一化
"""
return nn.utils.spectral_norm(conv_layer)
class SNDense(nn.Module):
"""谱归一化的全连接层"""
def __init__(self, in_features, out_features):
super().__init__()
self.linear = nn.Linear(in_features, out_features)
nn.utils.spectral_norm(self.linear)
def forward(self, x):
return self.linear(x)
class SNConv2d(nn.Module):
"""谱归一化的卷积层"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
stride=stride, padding=padding)
nn.utils.spectral_norm(self.conv)
def forward(self, x):
return self.conv(x)
def create_sn_discriminator(base_discriminator):
"""
将任意判别器转换为谱归一化版本
递归遍历所有子模块,将卷积和线性层替换为谱归一化版本
"""
def replace_with_sn(module):
for name, child in module.named_children():
if isinstance(child, nn.Conv2d):
setattr(module, name, SNConv2d(
child.in_channels, child.out_channels,
child.kernel_size, child.stride, child.padding
))
elif isinstance(child, nn.Linear):
setattr(module, name, SNDense(child.in_features, child.out_features))
else:
replace_with_sn(child)
sn_discriminator = copy.deepcopy(base_discriminator)
replace_with_sn(sn_discriminator)
return sn_discriminator10.3 实例归一化与自适应归一化
class AdaptiveInstanceNorm(nn.Module):
"""
自适应实例归一化(AdaIN)
风格迁移的核心操作
"""
def __init__(self, eps=1e-5):
super().__init__()
self.eps = eps
def forward(self, content, style):
"""
content: 内容特征
style: 风格特征
"""
size = content.size()
# 计算content的均值和方差
content_mean = content.view(size[0], size[1], -1).mean(dim=2).view(size[0], size[1], 1, 1)
content_var = content.view(size[0], size[1], -1).var(dim=2, unbiased=False).view(size[0], size[1], 1, 1)
# 计算style的均值和方差
style_mean = style.view(size[0], size[1], -1).mean(dim=2).view(size[0], size[1], 1, 1)
style_var = style.view(size[0], size[1], -1).var(dim=2, unbiased=False).view(size[0], size[1], 1, 1)
# 归一化content,然后用style的统计量进行缩放
content_norm = (content - content_mean) / torch.sqrt(content_var + self.eps)
output = content_norm * torch.sqrt(style_var + self.eps) + style_mean
return output
class LayerNorm2d(nn.Module):
"""
层归一化(针对2D特征图)
与Instance Norm的区别:
- Instance Norm: 对每个样本、每个通道独立归一化
- Layer Norm: 对每个样本、每个位置的所有通道归一化
"""
def __init__(self, eps=1e-5):
super().__init__()
self.eps = eps
def forward(self, x):
# x: (B, C, H, W)
mean = x.mean(dim=(1, 2, 3), keepdim=True)
var = x.var(dim=(1, 2, 3), unbiased=False, keepdim=True)
return (x - mean) / torch.sqrt(var + self.eps)
class GroupNorm(nn.Module):
"""
组归一化(Group Normalization)
不依赖于batch size,适合batch很小时使用
"""
def __init__(self, num_groups, num_channels, eps=1e-5):
super().__init__()
self.num_groups = num_groups
self.num_channels = num_channels
self.eps = eps
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
def forward(self, x):
# x: (B, C, H, W)
B, C, H, W = x.shape
G = self.num_groups
# 将通道分成G组
x = x.view(B, G, C // G, H, W)
# 在每组内计算均值和方差
mean = x.mean(dim=(2, 3, 4), keepdim=True)
var = x.var(dim=(2, 3, 4), unbiased=False, keepdim=True)
# 归一化
x = (x - mean) / torch.sqrt(var + self.eps)
x = x.view(B, C, H, W)
# 仿射变换
return x * self.weight.view(1, C, 1, 1) + self.bias.view(1, C, 1, 1)11. GAN的评估指标
11.1 Inception Score (IS)
Inception Score是评估GAN生成质量的经典指标:
其中 是Inception网络的类别输出,。
def calculate_inception_score(images, model, num_splits=10):
"""
计算Inception Score
原理:
- 如果生成图像质量高,Inception网络对每张图的预测应该是确定的(低熵)
- 如果生成图像多样性高,各类别的边缘分布应该均匀(高熵)
KL散度 = 条件熵的减少量 = 模型对图像的确定性程度
"""
import torch.nn.functional as F
model.eval()
preds = []
with torch.no_grad():
for i in range(0, len(images), batch_size):
batch = images[i:i+batch_size]
batch = F.interpolate(batch, size=(299, 299), mode='bilinear')
pred = model(batch)
preds.append(F.softmax(pred, dim=-1).cpu().numpy())
preds = np.concatenate(preds, axis=0)
# 计算每张图的KL散度
split_scores = []
split_size = preds.shape[0] // num_splits
for k in range(num_splits):
part = preds[k * split_size:(k + 1) * split_size, :]
# 计算 p(y)
py = np.mean(part, axis=0)
# 计算每个样本的KL散度
kl = part * (np.log(part) - np.log(py + 1e-16))
kl = np.sum(kl, axis=1)
split_scores.append(np.exp(np.mean(kl)))
return np.mean(split_scores), np.std(split_scores)11.2 Fréchet Inception Distance (FID)
FID通过比较真实图像和生成图像在特征空间中的分布距离:
def calculate_fid(real_images, fake_images, model):
"""
计算Fréchet Inception Distance
步骤:
1. 用Inception网络提取真实图像和生成图像的特征
2. 假设特征服从多元高斯分布
3. 计算两个高斯分布之间的Fréchet距离
"""
model.eval()
def get_activations(images):
activations = []
with torch.no_grad():
for i in range(0, len(images), batch_size):
batch = images[i:i+batch_size]
batch = F.interpolate(batch, size=(299, 299), mode='bilinear')
act = model(batch)
activations.append(act.cpu().numpy())
return np.concatenate(activations, axis=0)
# 获取特征
real_acts = get_activations(real_images)
fake_acts = get_activations(fake_images)
# 计算均值和协方差
mu_real = np.mean(real_acts, axis=0)
mu_fake = np.mean(fake_acts, axis=0)
sigma_real = np.cov(real_acts, rowvar=False)
sigma_fake = np.cov(fake_acts, rowvar=False)
# 计算FID
diff = mu_real - mu_fake
covmean = sqrtm(sigma_real @ sigma_fake)
if np.iscomplexobj(covmean):
covmean = covmean.real
fid = diff @ diff + np.trace(sigma_real + sigma_fake - 2 * covmean)
return fid
def sqrtm(matrix):
"""矩阵平方根的数值稳定计算"""
eigenvalues, eigenvectors = np.linalg.eigh(matrix)
return eigenvectors @ np.diag(np.sqrt(np.maximum(eigenvalues, 0))) @ eigenvectors.T11.3 Precision, Recall和F1
def calculate_precision_recall(real_features, fake_features, k=3):
"""
计算生成模型的Precision和Recall
Precision:生成样本中有多少可以被视为真实样本的近邻
Recall:真实样本中有多少可以被生成样本覆盖
"""
from sklearn.neighbors import NearestNeighbors
# 使用k近邻
nn_real = NearestNeighbors(n_neighbors=k + 1).fit(real_features)
nn_fake = NearestNeighbors(n_neighbors=k + 1).fit(fake_features)
# 计算fake样本的precision
_, indices = nn_real.kneighbors(fake_features)
precision = np.mean([1 for idx in indices if idx[0] in idx[1:]]) # 简化
# 计算recall
_, indices = nn_fake.kneighbors(real_features)
recall = np.mean([1 for idx in indices if idx[0] in idx[1:]])
# F1分数
f1 = 2 * precision * recall / (precision + recall + 1e-8)
return precision, recall, f1
def manifold_kneighbors(X_train, X_test, n_neighbors=5):
"""
流形k近邻分析
用于评估生成样本是否位于真实数据流形上
"""
from sklearn.neighbors import NearestNeighbors
nn = NearestNeighbors(n_neighbors=n_neighbors)
nn.fit(X_train)
distances, indices = nn.kneighbors(X_test)
return distances, indices12. GAN的应用场景
12.1 图像生成与增强
GAN在图像生成领域的应用包括:
- 人脸合成:StyleGAN系列生成逼真的人脸
- 艺术创作:GAN辅助艺术设计,如DeepDream风格迁移
- 数据增强:生成稀有类别的样本以改善分类器训练
- 超分辨率:ESRGAN等用于图像放大
class DataAugmentationGAN:
"""
用于数据增强的GAN
特别适用于:
1. 医疗影像(数据稀缺)
2. 稀有事件检测
3. 小样本学习
"""
def __init__(self, generator, num_classes):
self.generator = generator
self.num_classes = num_classes
def generate_for_class(self, target_class, num_samples):
"""为特定类别生成样本"""
z = torch.randn(num_samples, self.generator.latent_dim)
labels = torch.full((num_samples,), target_class)
with torch.no_grad():
generated = self.generator(z, labels)
return generated
def augment_dataset(self, dataset, samples_per_class=100):
"""增强整个数据集"""
augmented_images = []
augmented_labels = []
for class_idx in range(self.num_classes):
gen_samples = self.generate_for_class(class_idx, samples_per_class)
augmented_images.append(gen_samples.cpu().numpy())
augmented_labels.extend([class_idx] * samples_per_class)
augmented_images = np.concatenate(augmented_images, axis=0)
augmented_labels = np.array(augmented_labels)
return augmented_images, augmented_labels12.2 图像编辑与操作
GAN使精细的图像编辑成为可能:
- 语义编辑:如GANSpace、InterfaceGAN等
- 局部修改:如Inpainting网络
- 风格混合:不同风格图像的融合
class GANBasedImageEditor:
"""
基于GAN的图像编辑
技术:
1. 潜在空间插值
2. 方向向量操纵
3. 层叠编辑
"""
def __init__(self, generator):
self.G = generator
def interpolate(self, z1, z2, num_steps=10):
"""在两个潜在向量之间插值"""
alphas = np.linspace(0, 1, num_steps)
interpolated = []
for alpha in alphas:
z = alpha * z1 + (1 - alpha) * z2
with torch.no_grad():
img = self.G(z)
interpolated.append(img)
return torch.stack(interpolated)
def find_edit_direction(self, attribute1_images, attribute2_images):
"""
学习属性编辑方向
方法:在两个属性类别的中心点之间建立方向向量
"""
with torch.no_grad():
feats1 = torch.stack([self.extract_feature(img) for img in attribute1_images])
feats2 = torch.stack([self.extract_feature(img) for img in attribute2_images])
center1 = feats1.mean(dim=0)
center2 = feats2.mean(dim=0)
direction = center2 - center1
return direction / (torch.norm(direction) + 1e-8)
def edit_image(self, z, direction, magnitude):
"""
沿方向向量编辑图像
z: 原始潜在向量
direction: 编辑方向
magnitude: 编辑强度
"""
z_edited = z + magnitude * direction
with torch.no_grad():
edited_image = self.G(z_edited)
return edited_image
def layer_wise_edit(self, z, layer_idx, direction):
"""
分层编辑
不同层控制不同级别的特征:
- 低层:颜色、纹理
- 中层:局部结构
- 高层:整体布局、语义
"""
# 实现细节:修改指定层的激活
pass12.3 视频生成
GAN也被扩展到视频生成领域:
- 视频预测:预测未来帧
- 视频风格化:如Vid2Vid
- 舞蹈生成:根据音乐生成舞蹈动作
class VideoGenerator:
"""
视频生成GAN
关键组件:
1. 时序建模:LSTM/Transformer
2. 运动一致性:光流约束
3. 帧间平滑:时序判别器
"""
def __init__(self, image_generator, hidden_dim=256):
self.image_gen = image_generator
self.temporal_model = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
def generate_video(self, num_frames, latent_dim):
"""生成长度为num_frames的视频"""
batch_size = 1
# 初始化隐藏状态
h = torch.zeros(1, batch_size, hidden_dim)
c = torch.zeros(1, batch_size, hidden_dim)
frames = []
z_prev = torch.randn(batch_size, latent_dim)
for t in range(num_frames):
# 时序建模
output, (h, c) = self.temporal_model(z_prev.unsqueeze(1), (h, c))
z_curr = output.squeeze(1) + torch.randn(batch_size, latent_dim) * 0.1
# 生成帧
with torch.no_grad():
frame = self.image_gen(z_curr)
frames.append(frame)
z_prev = z_curr
return torch.stack(frames, dim=1) # (B, T, C, H, W)12.4 文本生成与多模态
GAN与语言模型的结合:
- 文本到图像:如StackGAN、DALL-E系列思想
- 图像到文本:生成描述
- 跨模态转换:如CLIP引导的编辑
class TextToImageGAN:
"""
文本到图像生成
架构:
1. 文本编码器:LSTM/Transformer
2. 层级生成器:从粗到细
3. 匹配判别器:文本-图像匹配判断
"""
def __init__(self, vocab_size, embed_dim, latent_dim):
self.text_encoder = nn.LSTM(
vocab_size, embed_dim,
batch_first=True, bidirectional=True
)
self.stage1_generator = Stage1Generator(embed_dim * 2, latent_dim)
self.stage2_generator = Stage2Generator(embed_dim * 2, latent_dim)
self.matching_discriminator = MatchingDiscriminator()
def encode_text(self, text_tokens):
"""编码文本描述"""
output, (hidden, _) = self.text_encoder(text_tokens)
# 拼接双向隐藏状态
hidden = torch.cat([hidden[-2], hidden[-1]], dim=-1)
return hidden
def generate(self, text, stage=2):
"""从文本生成图像"""
text_features = self.encode_text(text)
if stage == 1:
return self.stage1_generator(text_features)
else:
# 两阶段:从粗到细
coarse = self.stage1_generator(text_features)
refined = self.stage2_generator(text_features, coarse)
return refined13. GAN的最新研究方向
13.1 Diffusion GAN:融合扩散模型
Diffusion Models的兴起促使研究者探索GAN与扩散模型的结合:
class DiffusionGAN:
"""
Diffusion GAN
思想:用GAN加速扩散模型的采样过程
原理:
- 扩散模型需要多步迭代去噪
- GAN可以学习一步到位的去噪
- 两者结合:GAN作为扩散模型的加速器
"""
def __init__(self, latent_dim, diffusion_steps=1000):
self.denoiser = ConditionalGenerator(latent_dim)
self.discriminator = Discriminator()
self.diffusion_steps = diffusion_steps
def forward_diffusion(self, x0, t):
"""前向扩散:添加噪声"""
noise = torch.randn_like(x0)
alpha_bar = self.get_noise_schedule(t)
return torch.sqrt(alpha_bar) * x0 + torch.sqrt(1 - alpha_bar) * noise, noise
def gan_denoise(self, xt, timestep):
"""
用GAN进行去噪
GAN学习从xt到x_{t-1}的映射
比传统去噪网络更高效
"""
t_embedding = self.get_timestep_embedding(timestep)
return self.denoiser(xt, t_embedding)
def train_step(self, real_images, text_features=None):
"""训练Diffusion GAN"""
# 采样时间步
t = torch.randint(0, self.diffusion_steps, (real_images.size(0),))
# 前向扩散
noisy_images, noise = self.forward_diffusion(real_images, t)
# GAN去噪
denoised = self.gan_denoise(noisy_images, t)
# 训练判别器
d_loss = self.compute_discriminator_loss(denoised.detach(), real_images)
# 训练生成器(去噪器)
g_loss = self.compute_generator_loss(denoised, real_images)
return d_loss, g_loss13.2 Transformer GAN
Transformer架构也开始应用于GAN:
class TransformerDiscriminator(nn.Module):
"""
基于Transformer的判别器
使用自注意力捕获图像中的长距离依赖
"""
def __init__(self, patch_size=16, embed_dim=768, num_heads=12, num_layers=12):
super().__init__()
self.patch_embed = nn.Conv2d(3, embed_dim, patch_size, stride=patch_size)
self.pos_embed = nn.Parameter(torch.zeros(1, (256 // patch_size) ** 2, embed_dim))
encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=num_heads,
dim_feedforward=embed_dim * 4,
dropout=0.1,
activation='gelu',
batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
self.norm = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, 1)
def forward(self, x):
# 分patch
x = self.patch_embed(x) # (B, embed_dim, H/P, W/P)
x = x.flatten(2).transpose(1, 2) # (B, num_patches, embed_dim)
# 添加位置编码
x = x + self.pos_embed
# Transformer编码
x = self.transformer(x)
x = self.norm(x)
# 全局池化 + 分类
x = x.mean(dim=1)
return self.head(x)13.3 NeRF-GAN:三维感知生成
class NeRFGAN:
"""
NeRF-GAN:三维感知生成
结合神经辐射场(NeRF)和GAN
支持从任意视角渲染生成的三维内容
"""
def __init__(self):
self.nerf = NeuralRadianceField()
self.rendering_layer = VolumetricRenderer()
def generate_3d_scene(self, z):
"""从潜在向量生成完整的三维场景"""
# 生成场景参数
scene_params = self.decode_latent(z)
# 体积渲染
rgb, depth = self.rendering_layer(
self.nerf,
scene_params,
camera_poses=self.sample_camera_poses()
)
return rgb, depth
def train(self, real_images, camera_poses):
"""训练NeRF-GAN"""
# 渲染伪影
fake_images = self.generate_3d_scene(self.get_latent(real_images))
# GAN损失
d_loss = self.compute_gan_loss(real_images, fake_images)
g_loss = self.compute_generator_loss(fake_images)
# 渲染一致性损失
render_loss = self.compute_render_consistency_loss(real_images, fake_images)
return d_loss, g_loss + render_loss14. 学术引用与参考文献
- Goodfellow, I., et al. (2014). “Generative Adversarial Networks.” NeurIPS.
- Arjovsky, M., Chintala, S., & Bottou, L. (2017). “Wasserstein GAN.” ICML.
- Radford, A., Metz, L., & Chintala, S. (2016). “Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.” ICLR.
- Mirza, M., & Osindero, S. (2014). “Conditional Generative Adversarial Nets.” arXiv.
- Karras, T., et al. (2018). “Progressive Growing of GANs for Improved Quality, Stability, and Variation.” ICLR.
- Karras, T., et al. (2019). “A Style-Based Generator Architecture for Generative Adversarial Networks.” CVPR.
- Karras, T., et al. (2021). “Alias-Free Generative Adversarial Networks.” NeurIPS.
- Isola, P., et al. (2017). “Image-to-Image Translation with Conditional Adversarial Networks.” CVPR.
- Zhu, J.-Y., et al. (2017). “Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.” ICCV.
- Choi, Y., et al. (2018). “StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation.” CVPR.
- Brock, A., Donahue, J., & Simonyan, K. (2019). “Large Scale GAN Training for High Fidelity Natural Image Synthesis.” ICLR.
- Zhang, H., et al. (2019). “Self-Attention Generative Adversarial Networks.” ICML.
- Gulrajani, I., et al. (2017). “Improved Training of Wasserstein GANs.” NeurIPS.
- Miyato, T., et al. (2018). “Spectral Normalization for Generative Adversarial Networks.” ICLR.
- Mao, X., et al. (2017). “Least Squares Generative Adversarial Networks.” ICCV.