Go Back
工作总结与学习

五月工作总结

最后编辑: 2026-05-31 16:19

五月工作总结

  • 复现并改进了 CycleGAN-and-pix2pix 的项目,完成自己自定义的风格迁移(浮世绘风格);改进 cyclegan 开源项目,引入自注意力机制来扩大感受野,降低循环一致性损失,并将结果提交至 github repository,学习 git 命令的同时了解如何建立和维护自己的私有库;
  • 复现与部署 xlstm-mixer 项目,拓展到天气预测领域;
  • princess_xuge_project 项目,完成了一套网站从初始代码到部署到自己专属的域名并维护的流程;
  • 学习《动手学深度学习》、《动手学 ROS2》;
  • 完成毕业设计。

一、复现 CycleGAN-and-pix2pix 项目,并进行了改进:

首先贴出初始作者 Junyan Zhu 的工作:
Github 仓库:
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
论文:
"E:\pytorch_application\cyclegan.pdf"

Official resources:

原项目结构详解

该项目基于生成对抗循环神经网络实现了一个通用的图像到图像翻译框架。它用一套统一的训练/测试骨架,同时支持:

  • 无配对训练的 CycleGAN
  • 有配对训练的 pix2pix
  • 基于 pix2pix 的 colorization
  • 单边推理用的 test 模式

这个框架最重要的设计思想是:

  • 训练入口统一
  • 数据集接口统一
  • 模型接口统一
  • 参数系统可扩展
  • 具体模型和具体数据集通过“字符串 + 动态导入”装配

换成人话就是 train.pytest.py 并不知道自己在跑哪一个具体模型,它们只依赖统一接口。先说怎么工作:

先说目录中哪些是源码,哪些是实验产物
核心源码目录
  • train.py
  • test.py
  • options/
  • data/
  • models/
  • util/
  • scripts/
  • docs/

本地实验/运行产物目录

  • checkpoints/
    • 保存训练权重、训练过程网页和日志
  • results/
    • 保存测试推理结果
  • wandb/
    • 保存 Weights & Biases 日志
  • datasets/
    • 保存样例数据集和数据准备脚本
  • my_test_images/
    • 本地测试图片
  • recordings.md
    • 本地补充文档,不属于框架核心源码

核心模型层工作方式:两个生成器两个判别器

•••
CycleGAN 的核心前向是:

- `fake_B = G_A(real_A)`
- `rec_A = G_B(fake_B)`
- `fake_A = G_B(real_B)`
- `rec_B = G_A(fake_A)`

含义:

- `G_A` 把 A 域图翻译成 B 域风格
- `G_B` 再把翻译结果翻回 A 域

为什么要做改进?原始模型有什么不足:

问题维度原始实现主要缺陷
生成器ResNet-9blocks + ConvTranspose2d单尺度局部特征提取,缺乏全局感受野;反卷积易产生棋盘格伪影
判别器70×70 PatchGAN (Markovian)感受野局限于局部窗口,无法捕获全局结构一致性
对抗损失LSGAN (最小二乘)虽比原始 GAN 稳定,但分布对齐问题未根本解决
循环损失L1 逐像素损失忽略图像结构/纹理的感知相似性,与人类审美有偏差

你怎么做改进?
改进点一:自注意力残差块 + U-Net 跳跃连接生成器
核心思想:
在原始 ResNet Generator 的基础上做两个结构性改变:

  • 引入 Self-Attention:让网络在 bottleneck 层获得全局感受野,建模长距离像素依赖关系
  • 引入 U-Net 式跳跃连接:将编码器各层特征直接传递给解码器对应层,保留多尺度空间信息

Self_Attention 模块

•••
class Self_Attention(nn.Module):
"""SAGAN 风格的自注意力模块,为卷积网络提供全局感受野"""

```
def __init__(self, in_dim, activation):
    super(Self_Attention, self).__init__()
    self.channel_in = in_dim
    self.activation = activation

    # 1×1 卷积降维:query, key 通道数缩小 8 倍以降低计算量
    self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
    self.key_conv   = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
    self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)

    # gamma 初始化为 0:训练初期让网络先学局部特征,再逐步引入全局注意力
    self.gamma = nn.Parameter(torch.zeros(1))
    self.softmax = nn.Softmax(dim=-1)

def forward(self, x):
    m_batchsize, C, width, height = x.size()

    # Q: B × N × C'   (N = W*H)
    proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)

    # K: B × C' × N
    proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)

    # 注意力图: B × N × N  (每个像素对每个像素的关注度)
    energy = torch.bmm(proj_query, proj_key)
    attention = self.softmax(energy)

    # V: B × C × N
    proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)

    # 加权聚合: B × C × N
    out = torch.bmm(proj_value, attention.permute(0, 2, 1))
    out = out.view(m_batchsize, C, width, height)

    # 残差连接,gamma 从 0 开始学习
    out = self.gamma * out + x
    return out

设计要点

  • gamma 初始化为 0:训练初期自注意力分支贡献为 0,等效于原始卷积,让网络先建立局部特征基础,随后逐步学习全局依赖
  • query/key 通道压缩 8 倍:对于 256×256 特征图,注意力矩阵为 65536×65536,通道压缩可将计算量从 O(C²) 降至 O(C²/64)

SEA_ResnetBlock (自注意力增强残差块)

•••
class SEA_ResnetBlock(nn.Module):
    """融合自注意力机制的残差块:局部卷积 + 全局注意力 + 恒等跳跃连接"""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(SEA_ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
        self.self_attention = Self_Attention(dim, 'relu')

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError(f'padding [{padding_type}] is not implemented')

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim)]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        # 局部特征 (conv_block) + 全局特征 (self_attention) + 恒等映射 (x)
        out = self.conv_block(x) + self.self_attention(x) + x
        return out

设计要点

  • 三条路径并行:conv_block(x) 提取局部纹理 → self_attention(x) 建模全局依赖 → x 恒等映射防止退化
  • 残差嵌套:自注意力本身也是残差形式 (gamma * attention_out + x),形成双层残差保护

2.4 U-Net + 自注意力生成器 (Unet_SEA_ResnetGenerator)

•••
class Unet_SEA_ResnetGenerator(nn.Module):
    """U-Net 架构 + 自注意力残差块的生成器

    结构:
        Encoder (下采样 ×3):  256→128→64
        Bottleneck (SEA_resnet): 64×64 特征提取 + 全局自注意力
        Decoder (上采样 ×2):  64→128→256 (U-Net 跳跃连接)
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.InstanceNorm2d,
                 use_dropout=False, n_blocks=9, padding_type='reflect'):
        assert n_blocks >= 0
        super(Unet_SEA_ResnetGenerator, self).__init__()

        use_bias = True  # InstanceNorm2d 不需要 bias

        # ── Encoder ──
        self.pad = nn.ReflectionPad2d(3)
        self.enc1 = nn.Sequential(
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
            norm_layer(ngf), nn.ReLU(True)
        )                                                           # 256×256 → 256×256

        self.enc2 = nn.Sequential(
            nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
            norm_layer(ngf * 2), nn.ReLU(True)
        )                                                           # 256×256 → 128×128

        self.enc3 = nn.Sequential(
            nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=1, bias=use_bias),
            norm_layer(ngf * 4), nn.ReLU(True)
        )                                                           # 128×128 → 64×64

        # ── Bottleneck: SEA_ResnetBlocks + 原始 ResnetBlocks ──
        bottleneck = []
        # 前 3 个使用 SEA_ResnetBlock (带自注意力)
        for i in range(3):
            bottleneck += [SEA_ResnetBlock(ngf * 4, padding_type=padding_type,
                                           norm_layer=norm_layer, use_dropout=use_dropout,
                                           use_bias=use_bias)]
        # 剩余使用原始 ResnetBlock (节省计算量)
        for i in range(n_blocks - 3):
            bottleneck += [ResnetBlock(ngf * 4, padding_type=padding_type,
                                       norm_layer=norm_layer, use_dropout=use_dropout,
                                       use_bias=use_bias)]
        self.bottleneck = nn.Sequential(*bottleneck)

        # ── Decoder (U-Net skip connections) ──
        # 上采样1: 输入通道 = bottleneck(ngf*4) + enc3 skip(ngf*4) = ngf*8
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 8, ngf * 2, kernel_size=3, stride=2,
                               padding=1, output_padding=1, bias=use_bias),
            norm_layer(ngf * 2), nn.ReLU(True)
        )                                                           # 64×64 → 128×128

        # 上采样2: 输入通道 = dec1(ngf*2) + enc2 skip(ngf*2) = ngf*4
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 4, ngf, kernel_size=3, stride=2,
                               padding=1, output_padding=1, bias=use_bias),
            norm_layer(ngf), nn.ReLU(True)
        )                                                           # 128×128 → 256×256

        # 输出层: 输入通道 = dec2(ngf) + enc1 skip(ngf) = ngf*2
        self.out_conv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf * 2, output_nc, kernel_size=7, padding=0),
            nn.Tanh()
        )

    def forward(self, x):
        # Encoder (保存中间特征用于 U-Net skip connections)
        e1 = self.enc1(self.pad(x))          # ngf   × 256 × 256
        e2 = self.enc2(e1)                    # ngf*2 × 128 × 128
        e3 = self.enc3(e2)                    # ngf*4 × 64  × 64

        # Bottleneck with self-attention
        b = self.bottleneck(e3)               # ngf*4 × 64  × 64

        # Decoder with U-Net skip connections
        d1 = self.dec1(torch.cat([b, e3], 1))  # ngf*2 × 128 × 128
        d2 = self.dec2(torch.cat([d1, e2], 1)) # ngf   × 256 × 256
        out = self.out_conv(torch.cat([d2, e1], 1))

        return out

与原方案的关键区别(Bug 修正)

问题原代码修正
归一化层通道不匹配norm_layer(input_nc)复用于所有层每层独立的 norm_layer(ch)
SA 块定义但未使用self.SA,self.Sa_block_3 定义了但 forward 未调用整合进 bottleneck Sequential
Self_Attention_no_connect 未定义引用了不存在的类移除,用标准 Self_Attention
SEA_Block_3 未定义引用了不存在的类SEA_ResnetBlock 替代
上采样通道计算未考虑 U-Net concat 导致的通道翻倍正确计算 ngf*8,ngf*4,ngf*2
只用了 1 个 ResnetBlockself.resnet(x3)只有 1 块用 9 块(3 SEA + 6 原始)

3. 改进点二:Auto-Encoder 判别器 (EBGAN 风格)

3.1 核心思想

传统 CycleGAN 判别器输出一个标量(真/假概率),新判别器采用 Auto-Encoder 结构:输入图像 → 编码 → 压缩嵌入 → 解码 → 重建图像。

工作原理

  • 判别器学习"重建真实图像"的能力
  • 真实图片 → D 能较好地重建(低 MSE)
  • 生成图片 → D 重建质量差(高 MSE)
  • 生成器目标:让 D 也能重建 G 的输出 → 迫使 G 生成"像真实数据分布"的图像

理论优势

  • 判别器可以 独立预训练(不需要生成器),先学会重建真实数据分布
  • 避免了 GAN 训练中的模式坍塌和训练不平衡
  • 能量基模型(EBGAN)框架,训练更稳定

3.2 辅助模块

•••
class conv_block(nn.Module):
    """判别器编码器下采样块:Conv(3×3, stride=2) → ELU → Conv(3×3, stride=1)"""
    def __init__(self, in_ch, out_ch):
        super(conv_block, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
            nn.ELU(True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.ELU(True)
        )
​
    def forward(self, x):
        return self.block(x)
​
​
class deconv_block(nn.Module):
    """判别器解码器上采样块:Upsample(×2) → Conv(3×3) → ELU → Conv(3×3) → ELU"""
    def __init__(self, in_ch, out_ch):
        super(deconv_block, self).__init__()
        self.block = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.ELU(True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.ELU(True)
        )
​
    def forward(self, x):
        return self.block(x)

设计要点:解码器使用 Upsample + Conv 而非 ConvTranspose2d,避免棋盘格伪影。

3.3 Discriminator (Auto-Encoder)

•••
class AER_Discriminator(nn.Module):
    """Auto-Encoder 风格的能量基判别器
​
    Encoder: 256×256 → 128 → 64 → 32 → 16 → 8×8 (特征图)
             → 展平 → Linear(64) → Linear(ndf*8*8) → 重塑
    Decoder: 8×8 → 16 → 32 → 64 → 128 → 256×256 (重建图像)
​
    输出: 重建的图像 (与输入同尺寸),MSE 作为"能量"(真实图像能量低,生成图像能量高)
    """
​
    def __init__(self, input_nc, ndf=64):
        super(AER_Discriminator, self).__init__()
​
        # ── Encoder: 256 → 128 → 64 → 32 → 16 → 8 ──
        self.enc1 = nn.Sequential(
            nn.Conv2d(input_nc, ndf, kernel_size=3, stride=1, padding=1),
            nn.ELU(True),
            conv_block(ndf, ndf)
        )                                                           # 256×256 → 128×128
​
        self.enc2 = conv_block(ndf, ndf * 2)                        # 128×128 → 64×64
        self.enc3 = conv_block(ndf * 2, ndf * 3)                    # 64×64  → 32×32
        self.enc4 = conv_block(ndf * 3, ndf * 4)                    # 32×32  → 16×16
        self.enc5 = conv_block(ndf * 4, ndf * 5)                    # 16×16  → 8×8
​
        self.enc6 = nn.Sequential(
            nn.Conv2d(ndf * 5, ndf * 5, kernel_size=3, stride=1, padding=1),
            nn.ELU(True),
            nn.Conv2d(ndf * 5, ndf * 5, kernel_size=3, stride=1, padding=1),
            nn.ELU(True)
        )                                                           # 8×8 → 8×8
​
        # 压缩嵌入: 将特征压缩到 64 维再解压(类似 auto-encoder bottleneck)
        self.embed1 = nn.Linear(ndf * 5 * 8 * 8, 64)
        self.embed2 = nn.Linear(64, ndf * 8 * 8)
​
        # ── Decoder: 8 → 16 → 32 → 64 → 128 → 256 ──
        self.dec1 = deconv_block(ndf, ndf)                          # 8×8   → 16×16
        self.dec2 = deconv_block(ndf, ndf)                          # 16×16 → 32×32
        self.dec3 = deconv_block(ndf, ndf)                          # 32×32 → 64×64
        self.dec4 = deconv_block(ndf, ndf)                          # 64×64 → 128×128
        self.dec5 = deconv_block(ndf, ndf)                          # 128×128 → 256×256
​
        self.dec6 = nn.Sequential(
            nn.Conv2d(ndf, ndf, kernel_size=3, stride=1, padding=1),
            nn.ELU(True),
            nn.Conv2d(ndf, ndf, kernel_size=3, stride=1, padding=1),
            nn.ELU(True),
            nn.Conv2d(ndf, input_nc, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )                                                           # 256×256 → 256×256
​
        self.ndf = ndf
​
    def forward(self, x):
        # Encoder
        out = self.enc1(x)
        out = self.enc2(out)
        out = self.enc3(out)
        out = self.enc4(out)
        out = self.enc5(out)
        out = self.enc6(out)                         # B × (ndf*5) × 8 × 8
​
        # Bottleneck compression
        out = out.view(out.size(0), self.ndf * 5 * 8 * 8)
        out = self.embed1(out)                       # B × 64
        out = self.embed2(out)                       # B × (ndf*8*8)
        out = out.view(out.size(0), self.ndf, 8, 8) # B × ndf × 8 × 8
​
        # Decoder
        out = self.dec1(out)
        out = self.dec2(out)
        out = self.dec3(out)
        out = self.dec4(out)
        out = self.dec5(out)
        out = self.dec6(out)                         # B × input_nc × 256 × 256
​
        return out

3.4 与 WGAN-GP 的兼容性问题

重要:原方案同时提出了 EBGAN 式 AE 判别器和 WGAN-GP 梯度惩罚,但这两者是 互不兼容 的。

判别器类型输出损失函数梯度惩罚兼容
PatchGAN (原始)标量场 (N×N)LSGAN / BCEWGAN-GP 需改为标量输出
Auto-Encoder (EBGAN)重建图像MSE(输入, 重建)❌ 不兼容
WGAN Critic标量Wasserstein 距离✅ GP 原生支持

建议:使用 AE 判别器时,判别器损失应使用 MSE 重建损失,而非 WGAN-GP。如果坚持要用 WGAN-GP,需要将判别器改回卷积分类器结构(原始 PatchGAN + 输出标量)。

本方案推荐 AE 判别器 + MSE 损失 的 EBGAN 路线,理由:

  • AE 判别器可以独立预训练,更稳定
  • 对风格迁移任务(浮世绘),重建约束比对抗约束更温和
  • 梯度惩罚的正则化效果在 AE 框架中自然存在(MSE 天然平滑)

4. 改进点三:LPIPS 感知损失替代 L1 循环一致性损失

4.1 为什么 L1 不够好

原始 CycleGAN 使用 nn.L1Loss() 作为循环一致性损失:

•••
# models/cycle_gan_model.py 第93行
self.criterionCycle = torch.nn.L1Loss()

L1 逐像素比较两个图像,问题在于:

  • 两张感知上相似的图像可能像素值差异很大(如轻微平移 1px,人眼无感但 L1 很大)
  • 两张像素值接近的图像可能感知差异很大(如纹理替换,L1 小但风格完全不同)
  • L1 忽略了图像的结构、纹理、高层语义信息

4.2 LPIPS (Learned Perceptual Image Patch Similarity)

LPIPS 使用预训练的 AlexNet/VGG 提取多层特征,在 特征空间 而非像素空间计算距离:

•••
import lpips
​
# 在 CycleGANModel.__init__ 中替换
self.criterionCycle = lpips.LPIPS(net='alex').to(self.device)
​
# 在 backward_G 中使用 (LPIPS 值域约 0~1,乘以系数调整)
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A).mean() * 2.0 * lambda_A
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B).mean() * 2.0 * lambda_B

注意事项

  • LPIPS 输出是 per-image 向量(B × 1 × 1 × 1),需要 .mean() 归约为标量
  • 系数 2.0 根据实际训练效果调整,也可用原 lambda_A=10.0 直接乘
  • 安装依赖:pip install lpips
  • AlexNet backbone 比 VGG 更轻量,适合 6GB VRAM

4.3 对 identity loss 的影响

identity 损失也可以同步升级为 LPIPS(可选):

•••
# 可选改进
self.criterionIdt = lpips.LPIPS(net='alex').to(self.device)
self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B).mean() * lambda_idt

但 identity 损失本身权重较小(0.5),使用 LPIPS 的收益可能不如 cycle loss 显著。建议先只改 cycle loss,观察效果后再决定。

5. 改进点四:训练技巧增强稳定性

5.1 Label Smoothing (标签平滑/噪声标签)

原理:不让判别器对"真/假"过于自信,防止判别器过拟合和梯度消失。

在 ​GANLoss.__init__ 中修改

重要:此技巧仅适用于原始 PatchGAN 判别器(输出标量标签)。AE 判别器使用 MSE 重建损失,不涉及真/假标签,因此 不需要 此技巧。

对于仍使用 PatchGAN 的训练分支:

•••
class GANLoss(nn.Module):
    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
                 use_label_noise=True):
        super(GANLoss, self).__init__()
        self.use_label_noise = use_label_noise
        # ... 其余不变 ...
​
    def get_target_tensor(self, prediction, target_is_real):
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        target = target_tensor.expand_as(prediction)
​
        # 标签噪声: Real → [0.7, 1.2], Fake → [0.0, 0.3]
        if self.training and self.use_label_noise:
            if target_is_real:
                noise = torch.rand_like(target) * 0.5 + 0.7  # U(0.7, 1.2)
            else:
                noise = torch.rand_like(target) * 0.3          # U(0.0, 0.3)
            return noise
        return target

5.2 输入加噪

原理:向判别器输入添加高斯噪声,作为正则化手段,防止判别器记住训练样本。

在 ​backward_D_A 和 ​backward_D_B 中修改

•••
def backward_D_A(self):
    """对 D_A 的输入添加噪声后再计算损失"""
    fake_B = self.fake_B_pool.query(self.fake_B)
​
    # 添加高斯噪声 (std=0.05 为经验值)
    if self.isTrain:
        real_B_noisy = self.real_B + torch.randn_like(self.real_B) * 0.05
        fake_B_noisy = fake_B + torch.randn_like(fake_B) * 0.05
    else:
        real_B_noisy = self.real_B
        fake_B_noisy = fake_B
​
    self.loss_D_A = self.backward_D_basic(self.netD_A, real_B_noisy, fake_B_noisy)

同样修改 backward_D_B

AE 判别器适配:对于 AE 判别器,加噪后的图像仍然需要 D 能够重建出 原始无噪图像。即 loss_D = MSE(D(real+noise), real),以此增强判别器去噪/重建能力。

5.3 判别器训练频率加倍

原理:每轮迭代训练 G 一次,但训练 D 三次。让判别器保持领先优势,迫使生成器更努力地生成逼真图像。

在 ​optimize_parameters 中修改

•••
def optimize_parameters(self):
    """每轮迭代: 更新 G 一次, 更新 D n_critic 次"""
    n_critic = 3  # 判别器更新次数
​
    # ── 更新 Generator ──
    self.forward()
    self.set_requires_grad([self.netD_A, self.netD_B], False)
    self.optimizer_G.zero_grad()
    self.backward_G()
    self.optimizer_G.step()
​
    # ── 更新 Discriminator (n_critic 次) ──
    self.set_requires_grad([self.netD_A, self.netD_B], True)
    for _ in range(n_critic):
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()

注意:对于 AE 判别器,额外训练 D 特别重要——AE 判别器需要更多步数来学习数据分布的重建。

6. 关键问题诊断与修正

6.1 原方案中的代码 Bug 总结

#严重程度位置问题影响
1CriticalUnet_SEA_ResnetGeneratornorm_layer(input_nc)被所有层复用,通道数不匹配运行时报错
2CriticalUnet_SEA_ResnetGenerator.forward()SA,Sa_block_3,Sa_resnetblock_1 定义了但从未调用自注意力无效
3HighDiscriminator.__init__conv_block,deconv_block 使用但未定义运行时报错
4Highgradient_penaltyreal*epsilon*fake*(1-epsilon)应为 real*epsilon + fake*(1-epsilon)WGAN-GP 计算错误
5Medium整体架构WGAN-GP 与 AE 判别器架构不兼容损失函数设计冲突
6MediumUnet_SEA_ResnetGeneratorSelf_Attention_no_connectSEA_Block_3 未定义无法实例化
7LowUnet_SEA_ResnetGeneratorself.resnet 只用 1 个 ResnetBlock,bottleneck 深度不足特征提取能力弱

6.2 修正后的完整模型集成

models/networks.py 中注册新网络:

•••
# 在 define_G 函数中添加新选项
def define_G(input_nc, output_nc, ngf, netG, norm='instance', use_dropout=False,
             init_type='normal', init_gain=0.02):
    # ... 现有代码 ...
​
    elif netG == 'unet_sea_resnet_9blocks':
        net = Unet_SEA_ResnetGenerator(input_nc, output_nc, ngf,
                                        norm_layer=norm_layer,
                                        use_dropout=use_dropout,
                                        n_blocks=9)
​
    return init_net(net, init_type, init_gain)
​
# 在 define_D 函数中添加新选项
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='instance',
             init_type='normal', init_gain=0.02):
    # ... 现有代码 ...
​
    elif netD == 'ae_res':
        net = AER_Discriminator(input_nc, ndf)
​
    return init_net(net, init_type, init_gain)