YOLOv5源代码深度剖析(3):train.py模块逐行详解与解析
最编程
2024-02-21 21:25:46
...
# YOLOv5 ???? by Ultralytics, GPL-3.0 license
"""
Train a YOLOv5 model on a custom dataset
在数据集上训练 yolo v5 模型
Usage:
$ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640
训练数据为coco128 coco128数据集中有128张图片 80个类别,是规模较小的数据集
"""
'''===============================================一、导入包==================================================='''
'''======================1.导入安装好的python库====================='''
import argparse # 解析命令行参数模块
import math # 数学公式模块
import os # 与操作系统进行交互的模块 包含文件路径操作和解析
import random # 生成随机数模块
import sys # sys系统模块 包含了与Python解释器和它的环境有关的函数
import time # 时间模块 更底层
from copy import deepcopy # 深度拷贝模块
from datetime import datetime # datetime模块能以更方便的格式显示日期或对日期进行运算。
from pathlib import Path # Path将str转换为Path对象 使字符串路径易于操作的模块
import numpy as np # numpy数组操作模块
import torch # 引入torch
import torch.distributed as dist # 分布式训练模块
import torch.nn as nn # 对torch.nn.functional的类的封装 有很多和torch.nn.functional相同的函数
import yaml # yaml是一种直观的能够被电脑识别的的数据序列化格式,容易被人类阅读,并且容易和脚本语言交互。一般用于存储配置文件。
from torch.cuda import amp # PyTorch amp自动混合精度训练模块
from torch.nn.parallel import DistributedDataParallel as DDP # 多卡训练模块
from torch.optim import SGD, Adam, lr_scheduler # tensorboard模块
from tqdm import tqdm # 进度条模块
'''===================2.获取当前文件的绝对路径========================'''
FILE = Path(__file__).resolve() # __file__指的是当前文件(即train.py),FILE最终保存着当前文件的绝对路径,比如D://yolov5/train.py
ROOT = FILE.parents[0] # YOLOv5 root directory ROOT保存着当前项目的父目录,比如 D://yolov5
if str(ROOT) not in sys.path: # sys.path即当前python环境可以运行的路径,假如当前项目不在该路径中,就无法运行其中的模块,所以就需要加载路径
sys.path.append(str(ROOT)) # add ROOT to PATH 把ROOT添加到运行路径上
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative ROOT设置为相对路径
'''===================3..加载自定义模块============================'''
import val # for end-of-epoch mAP
from models.experimental import attempt_load
from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.autobatch import check_train_batch_size
from utils.callbacks import Callbacks
from utils.datasets import create_dataloader
from utils.downloads import attempt_download
from utils.general import (LOGGER, NCOLS, check_dataset, check_file, check_git_status, check_img_size,
check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path,
init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
one_cycle, print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss
from utils.metrics import fitness
from utils.plots import plot_evolve, plot_labels
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
'''================4. 分布式训练初始化==========================='''
# https://pytorch.org/docs/stable/elastic/run.html该网址有详细介绍
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # -本地序号。这个 Worker 是这台机器上的第几个 Worker
RANK = int(os.getenv('RANK', -1)) # -进程序号。这个 Worker 是全局第几个 Worker
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) # 总共有几个 Worker
'''
查找名为LOCAL_RANK,RANK,WORLD_SIZE的环境变量,
若存在则返回环境变量的值,若不存在则返回第二个参数(-1,默认None)
rank和local_rank的区别: 两者的区别在于前者用于进程间通讯,后者用于本地设备分配。
'''
'''===============================================二、train()函数:训练过程==================================================='''
''' =====================1.载入参数和初始化配置信息========================== '''
def train(hyp, # 超参数 可以是超参数配置文件的路径或超参数字典 path/to/hyp.yaml or hyp
opt, # main中opt参数
device, # 当前设备
callbacks # 用于存储Loggers日志记录器中的函数,方便在每个训练阶段控制日志的记录情况
):
# 从opt获取参数。日志保存路径,轮次、批次、权重、进程序号(主要用于分布式训练)等
save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
'''
1.1创建目录,设置模型、txt等保存的路径
'''
# Directories 获取记录训练日志的保存路径
# 设置保存权重路径 如runs/train/exp1/weights
w = save_dir / 'weights' # weights dir
# 新建文件夹 weights train evolve
(w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir
# 保存训练结果的目录,如last.pt和best.pt
last, best = w / 'last.pt', w / 'best.pt'
'''
1.2 读取hyp(超参数)配置文件
'''
# Hyperparameters 加载超参数
if isinstance(hyp, str): # isinstance()是否是已知类型。 判断hyp是字典还是字符串
# 若hyp是字符串,即认定为路径,则加载超参数为字典
with open(hyp, errors='ignore') as f:
# 加载yaml文件
hyp = yaml.safe_load(f) # load hyps dict 加载超参信息
# 打印超参数 彩色字体
LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
'''
1.3 将本次运行的超参数(hyp),和选项操作(opt)给保存成yaml格式,
保存在了每次训练得到的exp文件中,这两个yaml显示了我们本次训练所选择的超参数和opt参数,opt参数是train代码下面那一堆参数选择
'''
# Save run settings 保存训练中的参数hyp和opt
with open(save_dir / 'hyp.yaml', 'w') as f:
# 保存超参数为yaml配置文件
yaml.safe_dump(hyp, f, sort_keys=False)
with open(save_dir / 'opt.yaml', 'w') as f:
# 保存命令行参数为yaml配置文件
yaml.safe_dump(vars(opt), f, sort_keys=False)
# 定义数据集字典
data_dict = None
'''
1.4 加载相关日志功能:如tensorboard,logger,wandb
'''
# Loggers 设置wandb和tb两种日志, wandb和tensorboard都是模型信息,指标可视化工具
if RANK in [-1, 0]: # 如果进程编号为-1或0
# 初始化日志记录器实例
loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
# W&B # wandb为可视化参数工具
if loggers.wandb:
data_dict = loggers.wandb.data_dict
# 如果使用中断训练 再读取一次参数
if resume:
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp
# Register actions
for k in methods(loggers):
# 将日志记录器中的方法与字符串进行绑定
callbacks.register_action(k, callback=getattr(loggers, k))
'''
1.5 配置:画图开关,cuda,种子,读取数据集相关的yaml文件
'''
# Config 画图
# 是否绘制训练、测试图片、指标图等,使用进化算法则不绘制
plots = not evolve # create plots
cuda = device.type != 'cpu'
# 设置随机种子
init_seeds(1 + RANK)
# 加载数据配置信息
with torch_distributed_zero_first(LOCAL_RANK): # torch_distributed_zero_first 同步所有进程
data_dict = data_dict or check_dataset(data) # check if None check_dataset 检查数据集,如果没找到数据集则下载数据集(仅适用于项目中自带的yaml文件数据集)
# 获取训练集、测试集图片路径
train_path, val_path = data_dict['train'], data_dict['val']
# nc:数据集有多少种类别
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
# names: 数据集所有类别的名字,如果设置了single_cls则为一类
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
# 判断类别长度和文件是否对应
assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
# 当前数据集是否是coco数据集(80个类别)
is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset
''' =====================2.model:加载网络模型========================== '''
# Model 载入模型
# 检查文件后缀是否是.pt
check_suffix(weights, '.pt') # check weights
# 加载预训练权重 yolov5提供了5个不同的预训练权重,可以根据自己的模型选择预训练权重
pretrained = weights.endswith('.pt')
'''
2.1预训练模型加载
'''
if pretrained:
# 使用预训练的话:
# torch_distributed_zero_first(RANK): 用于同步不同进程对数据读取的上下文管理器
with torch_distributed_zero_first(LOCAL_RANK):
# 如果本地不存在就从google云盘中自动下载模型
# 通常会下载失败,建议提前下载下来放进weights目录
weights = attempt_download(weights) # download if not found locally
# ============加载模型以及参数================= #
ckpt = torch.load(weights, map_location=device) # load checkpoint
"""
两种加载模型的方式: opt.cfg / ckpt['model'].yaml
这两种方式的区别:区别在于是否是使用resume
如果使用resume-断点训练:
将opt.cfg设为空,选择ckpt['model']yaml创建模型, 且不加载anchor。
这也影响了下面是否除去anchor的key(也就是不加载anchor), 如果resume则不加载anchor
原因:
使用断点训练时,保存的模型会保存anchor,所以不需要加载,
主要是预训练权重里面保存了默认coco数据集对应的anchor,
如果用户自定义了anchor,再加载预训练权重进行训练,会覆盖掉用户自定义的anchor。
"""
# ***加载模型*** #
model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
# ***以下三行是获得anchor*** #
# 若cfg 或 hyp.get('anchors')不为空且不使用中断训练 exclude=['anchor'] 否则 exclude=[]
exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
# 将预训练模型中的所有参数保存下来,赋值给csd
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
# 判断预训练参数和新创建的模型参数有多少是相同的
# 筛选字典中的键值对,把exclude删除
csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
# ***模型创建*** #
model.load_state_dict(csd, strict=False) # load
# 显示加载预训练权重的的键值对和创建模型的键值对
# 如果pretrained为ture 则会少加载两个键对(anchors, anchor_grid)
LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
else:
# #直接加载模型,ch为输入图片通道
model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
'''
2.2设置模型输入
'''
# Freeze 冻结训练的网络层
"""
冻结模型层,设置冻结层名字即可
作用:冰冻一些层,就使得这些层在反向传播的时候不再更新权重,需要冻结的层,可以写在freeze列表中
freeze为命令行参数,默认为0,表示不冻结
"""
freeze = [f'model.{x}.' for x in range(freeze)] # layers to freeze
# 首先遍历所有层
for k, v in model.named_parameters():
# 为所有层的参数设置梯度
v.requires_grad = True # train all layers
# 判断是否需要冻结
if any(x in k for x in freeze):
LOGGER.info(f'freezing {k}')
# 冻结训练的层梯度不更新
v.requires_grad = False
# Image size 设置训练和测试图片尺寸
# 获取模型总步长和模型输入图片分辨率
gs = max(int(model.stride.max()), 32) # grid size (max stride)
# 检查输入图片分辨率是否能被32整除
imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
# Batch size 设置一次训练所选取的样本数
if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
# 确保batch size满足要求
batch_size = check_train_batch_size(model, imgsz)
'''
2.3 优化器设置
'''
# Optimizer 优化器
nbs = 64 # nominal batch size
"""
nbs = 64
batchsize = 16
accumulate = 64 / 16 = 4
模型梯度累计accumulate次之后就更新一次模型 相当于使用更大batch_size
"""
accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
# 根据accumulate设置权重衰减参数,防止过拟合
hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
# 打印缩放后的权重衰减超参数
LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
# 将模型分成三组(BN层的weight,卷积层的weights,biases)进行优化
g0, g1, g2 = [], [], [] # optimizer parameter groups
# 遍历网络中的所有层,每遍历完一层向更深的层遍历
for v in model.modules():
# hasattr: 测试指定的对象是否具有给定的属性,返回一个布尔值
if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias
# 将层的bias添加至g2
g2.append(v.bias)
# YOLO v5的模型架构中只有卷积层和BN层
if isinstance(v, nn.BatchNorm2d): # weight (no decay)
# 将BN层的权重添加至g0 未经过权重衰减
g0.append(v.weight)
elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
# 将层的weight添加至g1 经过了权重衰减
# 这里指的是卷积层的weight
g1.append(v.weight)
# 选用优化器,并设置g0(bn参数)组的优化方式
if opt.adam:
optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
else:
optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
# 将卷积层的参数添加至优化器 并做权重衰减
# add_param_group()函数为添加一个参数组,同一个优化器可以更新很多个参数组,不同的参数组可以设置不同的超参数
optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']}) # add g1 with weight_decay
# 将所有的bias添加至优化器
optimizer.add_param_group({'params': g2}) # add g2 (biases)
# 打印优化信息
LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
f"{len(g0)} weight, {len(g1)} weight (no decay), {len(g2)} bias")
# 在内存中删除g0 g1 g2 目的是节省空间
del g0, g1, g2
'''
2.4 学习率设置
'''
# Scheduler 设置学习率策略:两者可供选择,线性学习率和余弦退火学习率
if opt.linear_lr:
# 使用线性学习率
lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
else:
# 使用余弦退火学习率
lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
# 可视化 scheduler
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
'''
2.5 训练前最后准备
'''
# EMA 设置ema(指数移动平均),考虑历史值对参数的影响,目的
上一篇: 如何将百联OK卡兑换成微信红包教程
下一篇: 全面指南:如何轻松回收百联OK卡