Vision Transformer ViT CIFAR-10 图像分类 PyTorch
ViT-PyTorch — Vision Transformer 复现探索总结
最后编辑: 2026-05-31 16:57
一、项目概述
本项目基于 vit-pytorch 库,使用 SimpleViT(Vision Transformer 的简化实现)在 CIFAR-10 数据集上完成图像分类的快速验证。目标是理解 ViT 架构的核心组件和训练流程,为后续深入探索 ViT 变体(如分段/突变)打下基础。
| 维度 | 详情 |
|---|---|
| 框架 | vit-pytorch(lucidrains 实现) |
| 模型 | SimpleViT |
| 数据集 | CIFAR-10(10 类,60,000 张 32×32 图像) |
| 硬件 | NVIDIA RTX 3060 Laptop(6GB VRAM) |
| 任务 | 图像分类 |
二、ViT 架构理解
Vision Transformer(ViT)是 Transformer 架构在计算机视觉领域的直接应用,核心思想是将图像分割成 固定大小的 patch,然后将每个 patch 线性投影为 token,送入标准 Transformer Encoder 处理。
2.1 SimpleViT 的关键组件
•••
输入图像 (256×256)
↓
Patch Embedding (patch_size=32 → 8×8=64 patches)
↓
Position Embedding + Token 拼接
↓
Transformer Encoder × depth=6 层
├── LayerNorm → Multi-Head Self-Attention (heads=16)
└── LayerNorm → MLP (dim=1024 → mlp_dim=2048)
↓
LayerNorm + 分类头
↓
输出 (num_classes=10)2.2 与传统 CNN 的区别
| 维度 | CNN | ViT |
|---|---|---|
| 感受野 | 局部(卷积核限制) | 全局(自注意力) |
| 归纳偏置 | 强(平移不变性、局部性) | 弱(需更多数据学习) |
| 计算复杂度 | 线性于输入尺寸 | 二次于 patch 数量 |
| 缩放性 | 深度/宽度缩放 | 可扩展到超大算力 |
SimpleViT 相较于原始 ViT 的简化:去掉了 [CLS] token 的蒸馏机制,使用全局平均池化替代。
三、实验配置
3.1 模型参数
| 参数 | 值 | 说明 |
|---|---|---|
| image_size | 256 | 输入图像尺寸 |
| patch_size | 32 | 每个 patch 的像素大小 |
| num_patches | 64 | 256/32 × 256/32 = 64 |
| dim | 1024 | Transformer 隐层维度 |
| depth | 6 | Transformer 层数 |
| heads | 16 | 多头自注意力头数 |
| mlp_dim | 2048 | MLP 前馈网络隐层维度 |
| num_classes | 10 | CIFAR-10 类别数 |
3.2 训练配置
| 参数 | 值 |
|---|---|
| 优化器 | Adam(lr=3e-4) |
| 损失函数 | CrossEntropyLoss |
| Batch size | 64 |
| 训练轮数 | 1 epoch(快速验证) |
| 数据集 | CIFAR-10(50,000 训练 + 10,000 测试) |
| 预处理 | Resize(256) → ToTensor |
四、复现过程
4.1 关键调整
复现时遇到的主要问题与修正:
| 问题 | 原值 | 修正值 |
|---|---|---|
| 类别数不匹配 | 1000(ImageNet 默认) | 10(CIFAR-10) |
| 图像尺寸 | 无限制 | 统一 Resize 到 256 |
| 训练轮数 | 默认过多 | 1 epoch 快速验证 |
4.2 运行结果
•••
step 0, loss: 2.312
step 50, loss: 2.006
step 100, loss: 1.874
step 150, loss: 1.801
...
Simple_ViT 在 3060 Laptop 上训练成功!损失从 2.31 稳定下降至 ~1.7,模型在单 epoch 内有效收敛。在 RTX 3060 Laptop(6GB)上运行流畅,未出现显存不足。
五、总结
| 维度 | 成果 |
|---|---|
| 环境搭建 | vit-pytorch 库正确安装,CUDA 可用 |
| 代码适配 | CIFAR-10 类别数修正(1000→10),图像预处理适配 |
| 训练验证 | 1 epoch 快速训练成功,loss 正常下降 |
| 硬件兼容 | RTX 3060 6GB 运行流畅 |
后续探索方向:
- 完整训练 50-100 epoch,评估测试集准确率
- 对比不同 patch_size(16, 32, 64)对精度的影响
- 加入数据增强(RandomCrop, HorizontalFlip, ColorJitter)
- 尝试 ViT 变体(CvT, Swin Transformer, PiT)
- ViT 突变的系统性消融实验(注意力头数、深度、维度的影响)