GAN、Pix2pix 与 CycleGAN 论文阅读笔记
这是实习留下的笔记。
GAN 原理
GAN 简介
生成对抗网络(GAN)架构通过对抗性方法评估生成性模型,要同时训练两个模型
- 生成器G:根据真数据的分布制作假数据
- 判别器D:评估样本来自训练数据(而不是G的假数据)的概率
作者这样形容:生成器类似于一群造假者,试图制造假币,并在不被发现的情况下使用;而判别器D则类似于警察,试图检测假币。这个游戏竞赛驱使双方改进他们的方法,直到仿冒品与正品无法辨别。
使用反向传播和丢弃算法来训练这两个模型,并且只使用前向传播来训练生成器的样本
论文代码仓库:https://github.com/goodfeli/adversarial
GAN公式
参数有:
噪声变量:
生成器G:
判别器D:
极小极大对策,G试图最小化,D试图最大化
一开始,G训练次数不足,总是被D拒绝,所以
优缺点
优点:
- 计算效率高:不需要马尔可夫链,只需使用反向传播来获得梯度,在学习过程中不需要推理
- 统计上,无需直接使用数据样本更新,仅通过判别器的梯度进行学习,避免了将输入数据直接复制到生成器参数中(可以融合多种多样的函数)
- 可以表示非常尖锐,甚至退化的分布,而基于马尔可夫链的方法要求分布有点模糊,以便链能够在模式之间混合(生成结果清晰逼真)
缺点:
- 没有明确的生成数据分布
表示 - 在训练期间
必须与 很好地同步。G在不更新D的情况下不能被训练得太多, 将多个不同的输入 映射到相同的输出 ,从而丧失生成数据的多样性(容易出现不稳定、不收敛、模式崩溃等问题)
Pix2pix原理
Pix2pix简介
Pix2pix是一种用于将输入图像“转换”成相应的输出图像的条件对抗网络(cGAN)架构,即“像素到像素”的映射。
相比于卷积神经网络(CNN),Pix2pix而不需要大量专家来手动设计损失函数,而是自动学习适合于实现该目标的损失函数。
论文代码仓库:https://github.com/phillipi/pix2pix
案例演示:https://phillipi.github.io/pix2pix/
代码也可以看后面的CycleGAN代码的Pix2pix部分
Pix2pix公式
条件对抗网络的目标:
为了测试条件化判别器的重要性,作者还比较了无条件变量,其中判别器去掉了原图像
为了防止模糊,作者将L2距离改成了L1距离。
L1范数是向量元素绝对值的和,也称为曼哈顿范数。它在二维空间中的单位球是一个菱形。
L2范数是向量元素平方和的平方根,也称为欧几里得范数。它在二维空间中的单位球是一个圆。
Pix2pix网络架构
生成器和判别器都用了convolution-BatchNorm-ReLu。
生成器:跳跃连接
对于许多图像转换问题,在输入和输出之间共享了大量的低级信息,因此将这可以些信息直接通过网络传输。为了绕过这些信息瓶颈,添加了跳跃连接,遵循“U网”的一般形状。
判别器:马尔可夫判别器(PatchGAN)
L1和L2损失会导致模糊。传统GAN在遇到L1和L2这类损失时,高频清晰度不太清楚。
作者设计了一种判别器体系结构PatchGAN。该判别器尝试对图像中的切成一个一个N×N块,进行真假分类,在图像上卷积地运行这个判别器并平均化,以提供
常规GAN从
图像映射到单个标量输出,这表示“真”或“伪”,而 PatchGAN 从 映射到输出 的 数组
判别器有效地将图像建模为马尔可夫随机场。
马尔可夫随机场(MRF):一种著名的无向图模型,每个结点表示一个或者一组变量,结点之间的边表示两个变量的依赖关系。
马尔可夫随机场有一组势函数,主要用于定义概率分布函数
CycleGAN 原理
CycleGAN 简介
CycleGAN是一种用于图像↔图像转换的生成对抗网络(GAN)架构,特别适用于没有成对训练数据的情况。它通过引入循环一致性损失,实现了从一个域到另一个域的映射,同时保持输入图像的内容不变。
核心在于循环一致性:作者的举例,就像一个句子从英语翻译到法语,再从法语翻译回英语,应该得到原来的英语句子。图像处理同理。(不是简单的向前-向后一致性)
该方法建立在“Pix2pix”框架之上,Pix2pix也是使用了条件生成对抗性网络。但是区别是,Pix2pix测试数据中的输入-输出必须是成对的,CycleGAN不需要是成对的。
关于CycleGAN的用途,作者给了几个案例,包括
- 艺术画风格转换
- 物品变形
- 季节转换
- 照片与画作转换
- 照片增强
CycleGAN 公式
目标:学习给定训练样本的两个域X和Y之间的映射函数
两个映射:
两个对抗性判别器:
重要术语:对抗性损失、循环一致性损失
例:
对抗性损失
对抗性损失公式:
这个公式和GAN论文里是一样的
循环一致性损失
对抗性损失仅约束了整体分布一致性(因为用了期望值),但无法保证个体输入与期望输出的对应关系。所以引入循环一致性损失。
循环一致性公式:
补充:这里“双竖杠”右下角一个1,是用到了L1范数,用于衡量两个图像之间的差异
整体公式
也就是,目标函数 = X域的对抗性损失 + Y域的对抗性损失 + 重要性常参数 × 循环一致性
随后,作者通过仅保留对抗性损失或者仅保留循环一致性,说明两个公式都缺一不可,而且必须要循环双向。
训练细节
- 整体公式中,设置
- 使用批大小为1的Adam优化器
- 所有网络都是从头开始训练的,学习率为0.0002。在前100个批次保持相同的学习速率0.0002,并在接下来的100个批次中线性衰减到零。
- 生成器结构:对于128×128训练图像,使用6个残差块;对于256×256或更高分辨率的训练图像,使用9个残差块。作者详细地分别列出了两种情况的剩余块使用类型。
- 判别器结构:70×70PatchGAN
论文其它信息
与其它模型比较
有包括和其它相似目标模型的比较,略
局限性
- 几何变换收效甚微,更适合颜色和纹理的变化
- 对训练数据分布敏感,遇到非常规情况可能会混淆
- 与有监督方法有性能差距,可能需要某种形式的弱语义监督
代码
参数
基础参数
在文件options/base_options.py里写了一些参数。
| 参数名 | 默认值 | 类型 | 说明 | |||
|---|---|---|---|---|---|---|
--dataroot |
(必填) | str |
图像路径,包含子目录如 trainA, trainB, valA, valB 等 |
|||
--name |
experiment_name |
str |
实验名称,决定样本与模型保存位置 | |||
--checkpoints_dir |
./checkpoints |
str |
模型保存根目录 | |||
--model |
cycle_gan |
str |
选择模型类型 `[cycle_gan \ | pix2pix \ | test \ | colorization]` |
--input_nc |
3 |
int |
输入图像通道数(RGB=3,灰度=1) | |||
--output_nc |
3 |
int |
输出图像通道数(RGB=3,灰度=1) | |||
--ngf |
64 |
int |
生成器最后一层的滤波器数量 | |||
--ndf |
64 |
int |
判别器第一层的滤波器数量 | |||
--netD |
basic |
str |
判别器架构 `[basic \ | n_layers \ | pixel]`(basic 为 70x70 PatchGAN) | |
--netG |
resnet_9blocks |
str |
生成器架构 `[resnet_9blocks \ | resnet_6blocks \ | unet_256 \ | unet_128]` |
--n_layers_D |
3 |
int |
仅当 netD==n_layers 时使用的层数 |
|||
--norm |
instance |
str |
归一化方式 `[instance \ | batch \ | none \ | syncbatch]` |
--init_type |
normal |
str |
网络权重初始化方式 `[normal \ | xavier \ | kaiming \ | orthogonal]` |
--init_gain |
0.02 |
float |
初始化缩放因子 | |||
--no_dropout |
False |
bool(store_true) |
若指定则不使用生成器 dropout | |||
--dataset_mode |
unaligned |
str |
数据集加载方式 `[unaligned不对齐 \ | aligned对齐 \ | single单个 \ | colorization上色]` |
--direction |
AtoB |
str |
转换方向 AtoB 或 BtoA |
|||
--serial_batches |
False |
bool(store_true) |
若指定则按顺序取图像,否则随机 | |||
--num_threads |
4 |
int |
数据加载线程数 | |||
--batch_size |
1 |
int |
输入批次大小 | |||
--load_size |
286 |
int |
缩放到该尺寸 | |||
--crop_size |
256 |
int |
再裁剪到该尺寸 | |||
--max_dataset_size |
inf |
int(使用 float("inf") 表示不限制) |
数据集最大样本数(超过则截断) | |||
--preprocess |
resize_and_crop |
str |
加载时的缩放与裁剪方式 | |||
--no_flip |
False |
bool(store_true) |
若指定则不进行水平翻转数据增强 | |||
--display_winsize |
256 |
int |
可视化/HTML 显示窗口大小 | |||
--epoch |
latest |
str |
加载哪个 epoch,latest 使用最新缓存模型 |
|||
--load_iter |
0 |
int |
加载指定迭代(>0 时按 iter_[load_iter] 加载) |
|||
--verbose |
False |
bool(store_true) |
若指定则打印更多调试信息 | |||
--suffix |
(空字符串) | str |
自定义后缀,应用于 opt.name(例如 {model}_{netG}_size{load_size}) |
|||
--use_wandb |
False |
bool(store_true) |
若指定则初始化 wandb 日志 | |||
--wandb_project_name |
CycleGAN-and-pix2pix |
str |
wandb 项目名 |
训练参数
在文件options/train_options.py里写了一些参数。
| 参数名 | 默认值 | 类型 | 说明 |
|---|---|---|---|
--display_freq |
400 |
int |
在屏幕上显示训练结果的频率(步数) |
--update_html_freq |
1000 |
int |
保存训练结果到 HTML 的频率(步数) |
--print_freq |
100 |
int |
在控制台打印训练信息的频率(步数) |
--no_html |
False |
bool(store_true) |
不将中间训练结果保存到 [opt.checkpoints_dir]/[opt.name]/web/ |
--save_latest_freq |
5000 |
int |
保存最新模型结果的频率(步数) |
--save_epoch_freq |
5 |
int |
每隔多少个 epoch 保存一次检查点 |
--save_by_iter |
False |
bool(store_true) |
是否按迭代次数保存模型而非按 epoch |
--continue_train |
False |
bool(store_true) |
是否继续训练:加载最新模型并接着训练 |
--epoch_count |
1 |
int |
起始的 epoch 计数(用于保存命名等) |
--phase |
train |
str |
运行阶段,例如 train, val, test |
--n_epochs |
100 |
int |
使用初始学习率训练的 epoch 数 |
--n_epochs_decay |
100 |
int |
线性衰减学习率到 0 所需的 epoch 数 |
--beta1 |
0.5 |
float |
Adam 优化器的 momentum 项(beta1) |
--lr |
0.0002 |
float |
Adam 优化器的初始学习率 |
--gan_mode |
lsgan |
str |
GAN 损失类型,示例:vanilla / lsgan / wgangp |
--pool_size |
50 |
int |
存储之前生成图像的缓冲区大小(用于判别器) |
--lr_policy |
linear |
str |
学习率策略:linear线性衰减 / step阶梯衰减 / plateau监控指标自适应下降 / cosine余弦退火 |
--lr_decay_iters |
50 |
int |
每隔多少次迭代按因子乘以 gamma(仅在某些策略中使用) |
测试参数
在文件options/test_options.py里写了一些参数。
| 参数名 | 默认值 | 类型 | 说明 |
|---|---|---|---|
--results_dir |
./results/ |
str |
保存生成结果的目录 |
--aspect_ratio |
1.0 |
float |
生成结果图片的宽高比 |
--phase |
test |
str |
运行阶段,示例:train, val, test |
--eval |
False |
bool(flag,action='store_true') |
在测试时使用 eval 模式(调用 model.eval()) |
--num_test |
50 |
int |
要运行的测试图片数量 |