Go Back
Vision Transformer ViT CIFAR-10 图像分类 PyTorch

ViT-PyTorch — Vision Transformer 复现探索总结

最后编辑: 2026-05-31 16:57

一、项目概述

本项目基于 vit-pytorch 库,使用 SimpleViT(Vision Transformer 的简化实现)在 CIFAR-10 数据集上完成图像分类的快速验证。目标是理解 ViT 架构的核心组件和训练流程,为后续深入探索 ViT 变体(如分段/突变)打下基础。

维度详情
框架vit-pytorch(lucidrains 实现)
模型SimpleViT
数据集CIFAR-10(10 类,60,000 张 32×32 图像)
硬件NVIDIA RTX 3060 Laptop(6GB VRAM)
任务图像分类

二、ViT 架构理解

Vision Transformer(ViT)是 Transformer 架构在计算机视觉领域的直接应用,核心思想是将图像分割成 固定大小的 patch,然后将每个 patch 线性投影为 token,送入标准 Transformer Encoder 处理。

2.1 SimpleViT 的关键组件

•••
输入图像 (256×256)
    ↓
Patch Embedding (patch_size=32 → 8×8=64 patches)
    ↓
Position Embedding + Token 拼接
    ↓
Transformer Encoder × depth=6 层
  ├── LayerNorm → Multi-Head Self-Attention (heads=16)
  └── LayerNorm → MLP (dim=1024 → mlp_dim=2048)
    ↓
LayerNorm + 分类头
    ↓
输出 (num_classes=10)

2.2 与传统 CNN 的区别

维度CNNViT
感受野局部(卷积核限制)全局(自注意力)
归纳偏置强(平移不变性、局部性)弱(需更多数据学习)
计算复杂度线性于输入尺寸二次于 patch 数量
缩放性深度/宽度缩放可扩展到超大算力

SimpleViT 相较于原始 ViT 的简化:去掉了 [CLS] token 的蒸馏机制,使用全局平均池化替代。


三、实验配置

3.1 模型参数

参数说明
image_size256输入图像尺寸
patch_size32每个 patch 的像素大小
num_patches64256/32 × 256/32 = 64
dim1024Transformer 隐层维度
depth6Transformer 层数
heads16多头自注意力头数
mlp_dim2048MLP 前馈网络隐层维度
num_classes10CIFAR-10 类别数

3.2 训练配置

参数
优化器Adam(lr=3e-4)
损失函数CrossEntropyLoss
Batch size64
训练轮数1 epoch(快速验证)
数据集CIFAR-10(50,000 训练 + 10,000 测试)
预处理Resize(256) → ToTensor

四、复现过程

4.1 关键调整

复现时遇到的主要问题与修正:

问题原值修正值
类别数不匹配1000(ImageNet 默认)10(CIFAR-10)
图像尺寸无限制统一 Resize 到 256
训练轮数默认过多1 epoch 快速验证

4.2 运行结果

•••
step 0, loss: 2.312
step 50, loss: 2.006
step 100, loss: 1.874
step 150, loss: 1.801
...
Simple_ViT 在 3060 Laptop 上训练成功!

损失从 2.31 稳定下降至 ~1.7,模型在单 epoch 内有效收敛。在 RTX 3060 Laptop(6GB)上运行流畅,未出现显存不足。


五、总结

维度成果
环境搭建vit-pytorch 库正确安装,CUDA 可用
代码适配CIFAR-10 类别数修正(1000→10),图像预处理适配
训练验证1 epoch 快速训练成功,loss 正常下降
硬件兼容RTX 3060 6GB 运行流畅

后续探索方向

  • 完整训练 50-100 epoch,评估测试集准确率
  • 对比不同 patch_size(16, 32, 64)对精度的影响
  • 加入数据增强(RandomCrop, HorizontalFlip, ColorJitter)
  • 尝试 ViT 变体(CvT, Swin Transformer, PiT)
  • ViT 突变的系统性消融实验(注意力头数、深度、维度的影响)