
Basic PyTorch Training Framework

Preface

A brief introduction to the basic framework for training a network with PyTorch. This skeleton is nearly identical across projects, and you can build on top of it: construct more complex networks, evaluate model performance, and add extra features such as TensorBoard visualization of the training process or training-log recording.


Workflow

To train a network, the model and the data are the two most fundamental ingredients:

  1. Preliminary setup: load arguments, create the log, create TensorBoard, and so on
  2. Prepare the data: load the dataset, (optionally apply data augmentation,) convert the data to tensors, and build the DataLoader
  3. Build the model; a checkpoint or pretrained weights may need to be loaded (see the sketch after this list)
  4. (When using a GPU) move the model and data to the GPU
  5. Create the optimizer, typically gradient descent via SGD, and optionally a learning-rate scheduler (lr_scheduler)
  6. Choose the loss function, i.e. the model's optimization objective, typically cross-entropy loss nn.CrossEntropyLoss
  7. Start the training loop: fetch data from the DataLoader (moving it to the GPU), run the forward pass, compute the loss, run the backward pass, and let the optimizer update the network
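
Step 3 mentions resuming from a checkpoint or pretrained weights, which the skeletons below leave out. A minimal sketch of how that typically looks; the file name checkpoint.pth and the keys of the saved dict are just a common convention, not fixed API, and model / optimizer stand for objects built as in the skeletons below:

import torch

# Saving: bundle everything needed to resume into one file
# (the dict keys are a convention, not mandated by PyTorch)
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')

# Resuming: restore model and optimizer state, then continue training
ckpt = torch.load('checkpoint.pth', map_location='cpu')
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
start_epoch = ckpt['epoch'] + 1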

GPU Training

The basic skeleton for training on GPUs.

Single Machine, Single GPU

import torch
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_dataset = ...
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=...)

model = ...
model = model.to(device)  # move the model to the selected device
optimizer = optim.SGD(model.parameters(), lr=...)
criterion = ...

for epoch in range(opt.num_epoch):
    for i, (input, target) in enumerate(train_loader):
        input = input.to(device)
        target = target.to(device)
        ...
        output = model(input)
        loss = criterion(output, target)
        ...
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Single Machine, Multiple GPUs

DP Mode

Simply wrap the original model in nn.DataParallel:

import torch
import torch.nn as nn
import torch.optim as optim

train_dataset = ...
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=...)

model = ...
# device_ids=None uses all visible GPUs; output_device defaults to device_ids[0]
model = nn.DataParallel(model.cuda(), device_ids=None, output_device=None)
optimizer = optim.SGD(model.parameters(), lr=...)
criterion = ...

for epoch in range(opt.num_epoch):
    for i, (input, target) in enumerate(train_loader):
        input = input.cuda()
        target = target.cuda()
        ...
        output = model(input)
        loss = criterion(output, target)
        ...
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

DDP Mode

The steps are as follows:

  1. Set the local_rank argument, which can be understood as the index of the current process. It is assigned automatically by the launch command at the end of this subsection, and it differs on every GPU.
  2. Configure the initialization method; the common options are TCP and env. The code below uses env; the TCP variant is dist.init_process_group(backend='nccl', init_method='tcp://localhost:23456', ...) (see the sketch after this list).
  3. Use local_rank to select the device for this process: torch.cuda.set_device(opt.local_rank)
  4. For data loading (introduced in the first part of this tutorial), torch.utils.data.distributed.DistributedSampler derives the data indices for each GPU; each GPU loads its own subset and assembles it into a batch. When a sampler is passed, shuffle in the DataLoader must stay False, since the sampler takes over the shuffling.
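
For step 2, here is a minimal sketch that completes the TCP snippet above. The address localhost:23456 comes from that snippet; the world_size value is an assumption about a 4-GPU single-node setup, where the local rank equals the global rank:

import torch.distributed as dist

# TCP initialization: every process connects to the same address.
# Unlike env://, rank and world_size must be passed explicitly here.
dist.init_process_group(
    backend='nccl',
    init_method='tcp://localhost:23456',
    rank=opt.local_rank,  # single node: local rank == global rank
    world_size=4,         # assumed total number of processes
)

The full DDP skeleton:
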
import torch
import torch.distributed as dist
import torch.optim as optim
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', default=-1, type=int,
                    help='node rank for distributed training')
opt = parser.parse_args()

# Initialize the GPU communication backend (NCCL) and the way the process
# group parameters are obtained (env:// means via environment variables).
dist.init_process_group(backend='nccl', init_method='env://')

torch.cuda.set_device(opt.local_rank)

train_dataset = ...
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=..., sampler=train_sampler)

# Wrap the model with DistributedDataParallel, which all-reduces the
# gradients computed on the different GPUs (i.e. it aggregates the
# per-GPU gradients and synchronizes the result). After the all-reduce,
# the gradient on every GPU is the mean of the gradients the GPUs
# computed individually.
model = ...
model = model.cuda()  # the model must be on this process's GPU before wrapping
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[opt.local_rank])

optimizer = optim.SGD(model.parameters(), lr=...)
criterion = ...

for epoch in range(opt.num_epoch):
    # re-seed the sampler so each epoch sees a different shuffle
    train_sampler.set_epoch(epoch)
    for i, (input, target) in enumerate(train_loader):
        input = input.cuda()
        target = target.cuda()
        ...
        output = model(input)
        loss = criterion(output, target)
        ...
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Launch command
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train.py
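
As an aside, newer PyTorch releases deprecate torch.distributed.launch in favor of torchrun, which passes the local rank through the LOCAL_RANK environment variable instead of a --local_rank argument. A sketch of the equivalent launch, assuming the same 4-GPU setup:

CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train.py

Inside the script, the rank would then be read with local_rank = int(os.environ['LOCAL_RANK']).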

Multiple Machines, Multiple GPUs

I have no need for this myself, so it is skipped here.

A Simple Example

Let's train resnet18 on CIFAR-10 as an example.

Code

# -*- coding: utf-8 -*-
# File : base_train.py
# Author : MeteorDream
# Time : 2022/08/01 17:18:38
# language: Python
# Software: Visual Studio Code

import torch
import torch.utils.data as data
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision import datasets
from torchvision.models import resnet18

def build_dataloader(batch_size):

    print("Build dataset CIFAR-10")

    # CIFAR-10 constants
    CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
    CIFAR10_STD = [0.2023, 0.1994, 0.2010]

    transform = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)
    ])

    # CIFAR10 dataset (download=True fetches it on the first run)
    train_data = datasets.CIFAR10('../data/cifar-10-python', train=True, transform=transform, download=True)
    test_data = datasets.CIFAR10('../data/cifar-10-python', train=False, transform=transform, download=True)

    train_dataloader = data.DataLoader(
        train_data, sampler=data.RandomSampler(train_data),
        batch_size=batch_size,
        num_workers=8,
        pin_memory=True,
        drop_last=False,
    )

    test_dataloader = data.DataLoader(
        test_data, sampler=data.SequentialSampler(test_data),
        batch_size=batch_size,
        num_workers=8,
        pin_memory=True,
        drop_last=False
    )

    print("Finish: Build dataset CIFAR-10")

    return train_dataloader, test_dataloader

def main():

    use_cuda = True
    epochs = 30
    lr = 0.01

    # Prepare training data
    train_data, test_data = build_dataloader(256)

    # model
    model = resnet18(num_classes=10)

    if use_cuda:
        model = model.cuda()

    # optimizer and lr_scheduler
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.05)
    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, 1e-4)

    criterion = nn.CrossEntropyLoss()

    best_acc = 0

    for epoch in range(1, epochs + 1):
        # training
        model.train()
        train_acc, train_bs = 0, 0
        for img, labels in train_data:
            if use_cuda:
                img = img.cuda()
                labels = labels.cuda()

            optimizer.zero_grad()
            out = model(img)
            loss = criterion(out, labels)
            loss.backward()
            optimizer.step()

            train_acc += out.argmax(dim=1).eq(labels).sum().item()
            train_bs += labels.size(0)

        lr_scheduler.step()

        # test (eval mode so batchnorm uses running statistics)
        model.eval()
        test_acc, test_bs = 0, 0
        with torch.no_grad():
            for img, labels in test_data:
                if use_cuda:
                    img = img.cuda()
                    labels = labels.cuda()

                out = model(img)
                test_acc += out.argmax(dim=1).eq(labels).sum().item()
                test_bs += labels.size(0)

        if test_acc > best_acc:
            best_acc = test_acc

        print("epoch {:^3}: train accuracy {:.2f}% \t test accuracy {:.2f}%".format(epoch, train_acc / train_bs * 100, test_acc / test_bs * 100))

    print("Finish training, best accuracy is {:.2f}".format(best_acc / test_bs * 100))

if __name__ == '__main__':
    main()

For multi-GPU (DataParallel) training, only one extra line of code is needed:

if use_cuda:
    model = nn.DataParallel(model)
    model = model.cuda()
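
One caveat when saving a model wrapped this way: nn.DataParallel (and DistributedDataParallel) keep the original network in model.module, and the wrapper's state dict keys gain a "module." prefix. Saving the inner module avoids that, so the checkpoint loads into a plain, unwrapped model later (a sketch; the file name is arbitrary):

# Save the underlying model rather than the DataParallel wrapper
torch.save(model.module.state_dict(), 'model.pth')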

Training Log

Run

CUDA_VISIBLE_DEVICES=0 python base_train.py

Output

Build dataset CIFAR-10
Finish: Build dataset CIFAR-10
epoch 1 : train accuracy 42.42% test accuracy 51.48%
epoch 2 : train accuracy 59.89% test accuracy 61.70%
epoch 3 : train accuracy 66.25% test accuracy 66.88%
epoch 4 : train accuracy 70.40% test accuracy 68.71%
epoch 5 : train accuracy 72.80% test accuracy 67.77%
epoch 6 : train accuracy 74.11% test accuracy 71.70%
epoch 7 : train accuracy 74.91% test accuracy 74.28%
epoch 8 : train accuracy 76.14% test accuracy 72.74%
epoch 9 : train accuracy 77.14% test accuracy 76.21%
epoch 10 : train accuracy 77.88% test accuracy 76.52%
epoch 11 : train accuracy 78.33% test accuracy 75.22%
epoch 12 : train accuracy 79.57% test accuracy 76.96%
epoch 13 : train accuracy 80.43% test accuracy 78.86%
epoch 14 : train accuracy 81.07% test accuracy 78.41%
epoch 15 : train accuracy 82.38% test accuracy 78.82%
epoch 16 : train accuracy 83.43% test accuracy 79.09%
epoch 17 : train accuracy 84.36% test accuracy 80.42%
epoch 18 : train accuracy 85.67% test accuracy 80.70%
epoch 19 : train accuracy 87.30% test accuracy 81.80%
epoch 20 : train accuracy 88.73% test accuracy 82.55%
epoch 21 : train accuracy 90.69% test accuracy 83.60%
epoch 22 : train accuracy 92.77% test accuracy 83.92%
epoch 23 : train accuracy 94.76% test accuracy 84.79%
epoch 24 : train accuracy 97.13% test accuracy 85.81%
epoch 25 : train accuracy 98.78% test accuracy 87.30%
epoch 26 : train accuracy 99.61% test accuracy 88.14%
epoch 27 : train accuracy 99.85% test accuracy 88.44%
epoch 28 : train accuracy 99.89% test accuracy 88.36%
epoch 29 : train accuracy 99.91% test accuracy 88.36%
epoch 30 : train accuracy 99.92% test accuracy 88.40%
Finish training, best accuracy is 88.44

Output with multiple GPUs (DataParallel)

Build dataset CIFAR-10
Finish: Build dataset CIFAR-10
epoch 1 : train accuracy 28.95% test accuracy 40.68%
epoch 2 : train accuracy 46.05% test accuracy 49.88%
epoch 3 : train accuracy 54.92% test accuracy 57.35%
epoch 4 : train accuracy 61.06% test accuracy 62.65%
epoch 5 : train accuracy 66.66% test accuracy 67.10%
epoch 6 : train accuracy 70.87% test accuracy 69.61%
epoch 7 : train accuracy 75.06% test accuracy 73.85%
epoch 8 : train accuracy 78.62% test accuracy 74.51%
epoch 9 : train accuracy 81.11% test accuracy 77.26%
epoch 10 : train accuracy 83.26% test accuracy 79.15%
epoch 11 : train accuracy 85.32% test accuracy 79.54%
epoch 12 : train accuracy 86.53% test accuracy 79.82%
epoch 13 : train accuracy 88.45% test accuracy 81.70%
epoch 14 : train accuracy 90.27% test accuracy 82.01%
epoch 15 : train accuracy 91.94% test accuracy 82.84%
epoch 16 : train accuracy 93.86% test accuracy 82.54%
epoch 17 : train accuracy 95.61% test accuracy 83.89%
epoch 18 : train accuracy 97.23% test accuracy 84.05%
epoch 19 : train accuracy 98.45% test accuracy 85.64%
epoch 20 : train accuracy 99.31% test accuracy 86.10%
epoch 21 : train accuracy 99.78% test accuracy 86.65%
epoch 22 : train accuracy 99.91% test accuracy 87.39%
epoch 23 : train accuracy 99.94% test accuracy 87.34%
epoch 24 : train accuracy 99.96% test accuracy 87.56%
epoch 25 : train accuracy 99.97% test accuracy 87.56%
epoch 26 : train accuracy 99.97% test accuracy 87.67%
epoch 27 : train accuracy 99.97% test accuracy 87.74%
epoch 28 : train accuracy 99.97% test accuracy 87.77%
epoch 29 : train accuracy 99.97% test accuracy 87.71%
epoch 30 : train accuracy 99.97% test accuracy 87.71%
Finish training, best accuracy is 87.77