pytorch使用多显卡并行加速训练模型(nn.DataParallel)
2022/8/7 23:23:07
本文主要是介绍pytorch使用多显卡并行加速训练模型(nn.DataParallel),对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
torch.nn.DataParallel是一种能够将数据分散到多张显卡上从而加快模型训练的方法。
它的原理是首先在指定的每张显卡上拷贝一份模型,然后将输入的数据分散到各张显卡上,计算梯度,回传到第一张显卡上,然后再对模型进行参数优化。
所以,第一张显卡的负载往往更高,但由于该方法集成度高,书写简便,使用仍十分广泛。
示例:
import torch import torch.nn as nn ... gpu_num = x # 可用的gpu数量 model = Model() if gpu_num == 1: # 单卡 model = model.cuda(0) else: # 多卡 device_ids = list(range(gpu_num)) model = nn.DataParallel(model, device_ids=device_ids).cuda(device=device_ids[0]) ... # 所有数据都需要先放到指定的第一张显卡上才能进行多卡训练 data = data.cuda(0) ... # train ...
***注意使用nn.DataParellel时,模型后会自动添加一个.module的属性,在save的时候会将其保存下来,所以在load该模型时需要去掉字典key中的'.module'字符串
***在使用nn.DataParellel时,由于自动添加了module模型,因此需要分块训练模型的时候,也需要将模型块名更改。
例如:
# 原optimizer定义 optimizer = optim.Adam(params=model.part.parameters(), lr=0.00001) # 使用多卡训练后 optimizer = optim.Adam(params=model.module.part.parameters(), lr=0.00001)
这篇关于pytorch使用多显卡并行加速训练模型(nn.DataParallel)的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-05-04安装 VPrix Desktop 的系统要求-icode9专业技术文章分享
- 2024-05-01巧用 TiCDC Syncpoint 构建银行实时交易和准实时计算一体化架构
- 2024-05-01银行核心背后的落地工程体系丨Oracle - TiDB 数据迁移详解
- 2024-04-26高性能表格工具VTable总体构成-icode9专业技术文章分享
- 2024-04-16软路由代理问题, tg 无法代理问题-icode9专业技术文章分享
- 2024-04-16程序猿用什么锅-icode9专业技术文章分享
- 2024-04-16自建 NAS 的方案-icode9专业技术文章分享
- 2024-04-14ansible 在远程主机上执行脚本,并传入参数-icode9专业技术文章分享
- 2024-04-14ansible 在远程主机上执行脚本,并传入参数, 加上remote_src: yes 配置-icode9专业技术文章分享
- 2024-04-14ansible 检测远程主机的8080端口,如果关闭,则echo 进程已关闭-icode9专业技术文章分享