PyTorch基础教程及注意事项-数据和模型篇


数据ETL(提取、转换、加载)

数据集:

  1. 内置数据集:
    定义: PyTorch生态系统(专业库)中内置的常用数据集
    torchvision:用于计算机视觉常用数据集
    torchvision.datasets 内部常用数据集
  2. 典型数据集:
数据集 名称 下载加载命令 数据量 用途
MNIST 手写数字图像数据集 torchvision.datasets.MNIST(root, train, download, transform) 7万张手写数字0-9的灰度图像,其中,6万张用于训练,1万张用于测试。每张图像的大小为28×28像素 图像分类
CIFAR-10 彩色图像数据集 torchvision.datasets.CIFAR10(root, train, download, transform) 10个类别、每个类别有6000张图像,总共有5万张训练图像和1万张测试图像,6万张32×32像素彩色图像 图像分类
COCO 通用物体检测、分割、关键点检测数据集 torchvision.datasets.CocoCaptions(root, anaFile, transform) 33 万张图像、150 万目标实例、80 个目标类、91 个物品类以及 25 万关键点人物 图像分割
ImageNet 经典图像数据集 torchvision.datasets.ImageNet(root, split, transform, loader) 120万张训练图像,5万张验证图像和10万张测试图像 图像分类和物体检测
STL-10 彩色图像数据集 torchvision.datasets.STL10(root, split, download, transform) 10个类组成,总共约6000+张96*96像素图像 图像识别
Cityscapes 城市街道场景图像 torchvision.datasets.Cityscapes(root, split, mode, transform) 50 个不同城市街景中记录的视频序列,其包含 20000 个弱注释帧和 5000 帧的高质量像素级 城市街景语义理解
  1. 自定义数据集:

工具:torch.utils.data.Dataset抽象类,从自己的数据源创建自定义数据集。
用法:需要继承该抽象类,并实现如下方法:len(self)(返回数据集中的样本数量),getitem(self, idx)(通过索引返回样本)。

  1. 外置数据集:一般是外部数据库,可以使用psycopg2等模块。
数据库 模块工具 数据库驱动包
GaussDB(DWS) psycopg2/PyGreSQL/psycopg2-binary 可使用开源驱动JDBC/ODBC
GaussDB psycopg2 Psycopg
Opengauss psycopg2 Psycopg
PostgreSQL psycopg2/psycopg3 Python模块工具安装后,不需要再单独在pg库侧安装驱动包
  1. 使用注意:
  • 规格不同按需适配|使用时需要注意模块、模块接口、Python和数据库版本限制。
  • 使用时需要注意防火墙等限制,防火墙即可以阻止下载,也可以阻止连接,一般会导致连接超时报错。
  • 账号密码和数据库配置等信息一般是单独放置,加密保存,可以使用Python模块Crypto、cryptodome或者pycryptodomex库中的加解密方法,获取使用信息。
  • 除了使用连接模块psycopg2的psycopg2.connect()连接数据库外,还可以使用DBUtils.PooledDB包(管理数据库连接池)连接数据库,DBUtils.PooledDB包使用的时候,需要和psycopg2或者importlib一起配合使用。

数据加载、处理和转换

  1. 数据加载器:torch.utils.data.DataLoader,从数据集中按批次加载数据,支持多线程加载加速和数据打乱。
  • 关键参数:torch.utils.data.DataLoader(dataset,batch_size, shuffle,num_workers,drop_last)
    • batch_size: int类型的数据, 每次加载样本的数量,如默认设置为1,那就是一行一行的喂数据给模型,效率比较慢。
    • shuffle: bool数据类型, 是否需要对数据进行洗牌(通常用于训练时将数据打乱使用),如果数据有规律特征(顺序或倒序),则不应该设置为True。
    • num_workers:int类型的数据,默认为0,表示使用主进程导入数据,非负数表示使用多少子进程导入数据。
    • drop_last:bool数据类型, 默认为False,和batch_size配合使用,可用于数据集中不能被batch_size整除中,确认是否丢弃最后一批数据。
  1. 多数据源加载:torch.utils.data.ConcatDataset,自定义加载多数据源,可以将多数据源合并为一个数据集。

数据转换

  1. 目的:将原始数据转换为适合模型训练的数据格式
  2. 组成:
  • 数据预处理:数据归一化,调整数据格式、大小和数据范围,使其适合模型输入。
  • 数据增强:在训练时对数据做变换,例如随机剪裁/翻转等,提高模型泛化能力,避免过拟合。
  1. 一般工具:torchvision视觉库,torchvision.transfroms工具,以下是典型数据转换和增强操作:
    基础数据变换操作
函数 用途
torchvision.transfroms.Compose() 多变换操作组合,按照顺序依次执行
torchvision.transfroms.Resize() 调整图像大小,保证输入到网络的图像大小一致
torchvision.transfroms.ToTensor() 图像转化为Tensor张量,像素数值归一化为[0,1]范围
torchvision.transfroms.Normalize() 图像数据标准化,使数据符合特定均值和标准差
torchvision.transfroms.CenterCrop() 从图像中心剪裁指定大小区域

数据增强操作

函数 用途
torchvision.transfroms.RandomHorizontalFlip() 随机水平翻转图像
torchvision.transfroms.RandomRotation() 随机旋转图像一定角度
torchvision.transfroms.ColorJitter() 调整图像亮度、对比度、饱和度和色调
torchvision.transfroms.RandomCrop() 随机裁剪指定大小的区域
torchvision.transfroms.RandomResizeCrop() 随机裁剪图像并调整到指定大小

模型保存,加载和部署

  1. 功能:训练中断时恢复训练;在不同的训练阶段比较模型性能;方便模型部署和成员之间的模型共享;可用于迁移学习。
  2. 类型:
  • 保存整个模型:torch.save(model, path), 保存模型架构和所有参数。
    优势:完整保留模型结构。
    劣势:文件体积大,对模型类的定义有依赖。
  • 保存模型参数:torch.save(model.dict(), path), 只保留模型的状态字典信息。
    优势:文件小,可加载到不同模型架构中,兼容性好。
    缺点:过程繁琐,使用前需要每次先创建相同架构的模型。
  1. 注意事项:
  • 保存命名:模型和参数的保存要按照架构要求和命名规范进行,要有意义。
  • 设置定期保存命令,最好是每隔几轮训练迭代就保存一次检查点,防止训练出现问题。
  • 在保存完成后,需要测试加载能力,确保保存的模型能正常加载使用。
  • 训练和保存过程中,由于比较耗时,可以在此期间将模型架构、训练参数等信息做文档保存,以备查阅。
  • 模型文件等材料放入Git/svn/Dbox/代码仓的版本管理系统中,做版本管理。
  • 如果无法加载保存旧的版本模型,可以查阅文档,确定当时使用的PyTorch版本后再加载,或者转换模型的格式。

模型加载

跨设备模型加载:

  1. CPU或GPU加载:先保存模型参数到指定路径中,设定模型使用设备(torch.device(“cpu)),再将模型参数加载到模型中(torch.load(参数路径,map_location = ‘cpu)),最后将模型转移到指定的模型使用设备上。
  2. 多GPU模型加载:
  • 目的:调度GPU,实现计算加速, GPU 并行计算(DataParallel 或 torch.distributed)
  • GPU方法和函数清单:
方法函数 说明
torch.cuda.is_available 判断GPU是否可用
torch.device() 创建设备对象,可用于张量设置在GPU上
torch.to(device) 将张量移动到指定设备上
torch.device(‘cuda’ if torch.cuda.is_available() else ‘cpu’) 优先使用GPU计算,没有GPU情况下用CPU计算
注意事项:
  1. 在设定本次计算使用的工具时,使用torch.device(‘cpu’),cpu需要小写。
  2. 在使用GPU运算时,要求模型和数据都要在GPU或者CPU同一个设备上。
  3. 模型和输入数据使用的设备确定时,可以在模型之前设定,也可以在使用过程中设置,张量和模型必须在同一个设备上可以使用torch.to()工具来转移。
  4. GPU中专用GPU内存和共享GPU内存的差别:专用GPU内存我们通常称为“显存”,就是显卡上独立专门的物理内存。而共享GPU内存是指从系统内存中单独划出,供GPU使用的内存,是显存的补充。可以在CUDA中显式管理共享内存。

分布式模型训练框架和工具:

  1. PyTorch原生分布式能力:torch.DistributedDataParalleltorch.DataParalleltorch.distributed,torch.DistributedDataParallel效率更高。
  • 数据并行模式支持:PyTorch支持数据并行模式,不支持流水线、Tensor、混合和自动并行模式,其1.9.0以上版本支持ZeRO,可以用于torch.DistributedDataParallel的communicationhook调用。torch.DistributedDataParallel支持通讯优化。
  • 计算加速支持:可以使用torch.amp实现低精度训练(1.6以上版本,FP16精度)
  • 内存优化支持:可以使用torch.utils.checkpointing进行重计算
  1. NeMo- Megatron:英伟达开发的分布式训练模型开源框架,支持数据并行和模型并行。
  2. DeepSpeed:微软开发的分布式训练模型开源框架,可以和NeMo- Megatron兼容,是目前主要使用的分布式训练工具。
    大模型分布式训练并行技术:https://zhuanlan.zhihu.com/p/598714869

模型转换

  1. 版本兼容性:在torch.save()中使用_use_new_zipfile_serialization来确保好的兼容性,可以确保仍然能使用1.6版本之前的旧格式。
  2. 格式转换:可以使用torch.jit.scipt()将模型转换为TorchScript格式,再使用torch.jit.save()和torch.jit.load()加载模型。 可以解决旧版本模型加载失败的问题。
  3. 导出格式:
格式 使用命令 特点 适用场景
TorchScript torch.jit.trace() torch.jit.script() Pytorch原生格式,保持动态图特性 Pytorch生态内部
ONNX torch.onnx.export() 开发标准,跨框架兼容 多框架协同环境
Torch-TensorRT import torch_tensorrt nvidia优化格式 GPU推理加速
  1. 注意事项:
  • TorchScript导出前要使用model.eval()调用模型,以转换模型模式状态为评估。

模型部署

  1. 目的:使用AI平台(ModelArts),通过API调用的方式,将AI能力整合到Web等系统中,供调用使用。
  2. 一般部署流程:训练模型-》模型优化-》格式转换-》部署环境选择-》服务封装-》性能监控
  3. 部署方式:
  • 本地部署:ONNX部署,使用onnxruntime模块(ONNX Runtime 是由微软维护的一个跨平台机器学习推理加速器),使用onnxruntime.InferenceSession()加载模型,创建推理会话,再使用onnxruntime.InferenceSession().run()执行推理
  • 云端部署(优先):使用fastapi模块构建REST API,通过API接口调用。 fastapi工具使用。

参考文献

[1] 深度学习与PyTorch入门实战
[2] Zhang, A., Lipton, Z. C., Li, M., & Smola, A. J. (2023). Dive into Deep Learning. Cambridge University Press. URL: https://D2L.ai
[3] PyTorch深度学习
[4] 菜鸟教程
[5] 深入浅出PyTorch


 上一篇
PyTorch基础教程及注意事项-基础篇 PyTorch基础教程及注意事项-基础篇
PyTorch基础教程及注意事项-基础篇,主要介绍PyTorch的安装、环境准备、Anaconda环境配置、PyTorch结构和核心数据结构张量(Tensor)及其使用方法。
2025-08-25
下一篇 
PyTorch基础教程及注意事项-神经网络篇 PyTorch基础教程及注意事项-神经网络篇
PyTorch基础教程及注意事项-神经网络篇,主要介绍神经网络模型、典型神经网络(FNN/CNN/RNN)、神经网络模型训练过程和PyTorch神经网络工具。
2025-08-25
  目录