A detailed tutorial on PyTorch Hub, with a worked example based on YOLOv5's hubconf.py
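1 Introduction to PyTorch Hub
1.1 A brief introduction to PyTorch Hub
PyTorch Hub is a repository of pre-trained models designed to facilitate research reproducibility. Put simply, many people publish the pre-trained models from their projects to PyTorch Hub, and we can then call them directly through the torch.hub.load() interface to reproduce a project's results quickly.
Judging from the release dates on the PyTorch Hub GitHub page, it appeared around April 2020 (the torch.hub API itself was announced by the PyTorch team in June 2019), which makes it somewhat later than TensorFlow Hub and PaddlePaddle Hub.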
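1.2 A brief look at publishing pre-trained models to PyTorch Hub
PyTorch Hub supports publishing pre-trained models (model definitions and pre-trained weights) to a GitHub repository by adding a simple hubconf.py file. hubconf.py can have multiple entry points, and each entry point is defined as a Python function (for example, a pre-trained model you want to publish).
def entrypoint_name(*args, **kwargs):
    # args & kwargs are optional, for models which take positional/keyword arguments.
    ...
The details are covered later in this article.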
1.3 PyTorch Hub links
- PyTorch Hub GitHub page: github.com/pytorch/hub
- PyTorch Hub official model repository: pytorch.org/hub/researc…
- PyTorch Hub official documentation: pytorch.org/docs/master…
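1.4 Ecosystem areas currently supported by PyTorch Hub
The models in the PyTorch Hub ecosystem now cover almost every area of deep learning, and there are many pre-trained models. The main categories are:
- Audio: speech
- Generative: generative models
- Nlp: natural language processing
- Scriptable: scriptable models
- Vision: computer vision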
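2 How to publish a model on torch.hub
- Sections 2 and 3 below follow the official PyTorch Hub documentation
By adding a simple hubconf.py file, PyTorch Hub lets you publish pre-trained models (model definitions and pre-trained weights) to a GitHub repository. hubconf.py can define multiple entry points, and each entry point is defined as a Python function (for example, a pre-trained model you want to publish).
def entrypoint_name(*args, **kwargs):
    # args & kwargs are optional, for models which take positional/keyword arguments.
    ...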
2.1 How to implement an entry point
In the GitHub project pytorch/vision, we create a hubconf.py file and implement an entry point in it. The snippet below specifies an entry point for the resnet18 model. In most cases, importing the right function in hubconf.py is sufficient; the expanded version here only illustrates how it works. See the pytorch/vision repo for the full script.
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18
# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
    """ # This docstring shows up in hub.help()
    Resnet18 model
    pretrained (bool): kwargs, load pretrained weights into the model
    """
    # Call the model, load pretrained weights
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model
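1. The dependencies variable is a list of the package names required to load the model. Note that this may differ from the dependencies needed to train the model.
2. args and kwargs are passed through to the real callable.
3. The function's docstring serves as the help message. It explains what the model does and which positional/keyword arguments are allowed. Adding a few examples here is strongly recommended.
4. An entry point function can return a model (nn.Module) or auxiliary tools that smooth the user's workflow, such as tokenizers.
5. Callables prefixed with an underscore are treated as helper functions and are not shown in torch.hub.list().
6. Pre-trained weights can either be stored in the GitHub repository or be loaded with torch.hub.load_state_dict_from_url(). If they are smaller than 2GB, it is recommended to attach them to a project release and use the URL from that release. In the example above, torchvision.models.resnet.resnet18 handles pretrained; alternatively, you can write the following logic in the entry point definition:
if pretrained:
    # For checkpoint saved in local github repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
    # (this branch needs `import os` and `import torch` at the top of hubconf.py)
    dirname = os.path.dirname(__file__)
    checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
    state_dict = torch.load(checkpoint)
    model.load_state_dict(state_dict)
    # For checkpoint saved elsewhere
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
Note:
Published models should be at least in a branch/tag. They cannot be a random commit.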
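To make the role of the dependencies list more concrete, here is a minimal sketch of how a loader could check it before importing hubconf.py. The function names missing_dependencies and check_dependencies are hypothetical, and the use of importlib.util.find_spec is an assumption about one reasonable way to do the check, not a copy of the torch.hub internals.

```python
import importlib.util


def missing_dependencies(dependencies):
    """Return the top-level package names that cannot be found for import."""
    return [name for name in dependencies
            if importlib.util.find_spec(name) is None]


def check_dependencies(dependencies):
    """Raise if any package listed in `dependencies` is not installed."""
    missing = missing_dependencies(dependencies)
    if missing:
        raise RuntimeError("Missing dependencies: " + ", ".join(missing))


# 'json' ships with Python; the second name is deliberately made up.
print(missing_dependencies(["json", "a_package_that_is_not_installed_xyz"]))
```

A check like this only needs the packages required at load time, which is why the list can be shorter than the training requirements.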
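3 Loading models from torch.hub
PyTorch Hub provides convenient APIs:
- torch.hub.list(): list all models available in a hub repository
- torch.hub.help(): show docstrings and examples
- torch.hub.load(): load a pre-trained model
3.1 torch.hub.list(): view the models a repository defines in hubconf.py
torch.hub.list() shows the entry points that a given repository defines in hubconf.py. In effect, it lists the names of the functions defined in hubconf.py.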
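The idea of "listing the function names in hubconf.py" can be tried without downloading anything. The sketch below builds a stand-in module in memory (fake_hubconf and its attributes are hypothetical) and applies the same kind of filter: keep callables, drop names that start with an underscore.

```python
import types


def list_entrypoints(hub_module):
    """Names of public callables in a hubconf-like module."""
    return [name for name in dir(hub_module)
            if callable(getattr(hub_module, name)) and not name.startswith("_")]


# A stand-in for an imported hubconf.py (all names here are made up).
fake_hubconf = types.ModuleType("hubconf")
fake_hubconf.dependencies = ["torch"]                 # a list, so not callable
fake_hubconf._create = lambda name: {"name": name}    # private helper
fake_hubconf.yolov5s = lambda: fake_hubconf._create("yolov5s")
fake_hubconf.custom = lambda path: fake_hubconf._create(path)

print(list_entrypoints(fake_hubconf))  # ['custom', 'yolov5s']
```

Because the filter only looks at "callable and public", anything callable that hubconf.py imports at module level under a public name would also show up, which is one reason helper imports are often aliased with a leading underscore.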
3.1.1 torch.hub.list() source code
1. Source of the torch.hub.list() function:
def list(github, force_reload=False):
    r"""
    List all entrypoints available in `github` hubconf.
    Args:
        github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision[:hub]'
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
            Default is `False`.
    Returns:
        entrypoints: a list of available entrypoint names
    Example:
        >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
    """
    repo_dir = _get_cache_or_reload(github, force_reload, True)
    sys.path.insert(0, repo_dir)
    hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)
    sys.path.remove(repo_dir)
    # We take functions starts with '_' as internal helper functions
    entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]
    return entrypoints
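2. Parameters
torch.hub.list() lists the entry points defined in the hubconf.py file of the specified GitHub repository and returns them as the list entrypoints.
Parameters:
- github: a string of the form "repo_owner/repo_name[:tag_name]", where the trailing tag_name (a tag or branch) is optional. For example, the YOLOv5 repository lives at https://github.com/ultralytics/yolov5, so the argument can be written as github="ultralytics/yolov5"
- force_reload: defaults to False. Whether to force a fresh download of the repository source. If False, the cache directory ~/.cache/torch/hub is checked first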
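The "repo_owner/repo_name[:tag_name]" format is easy to take apart by hand. The following is a small sketch of such a parser; parse_repo_info is a hypothetical name, and the default of "master" simply mirrors the docstring quoted above rather than any particular PyTorch release.

```python
def parse_repo_info(github, default_ref="master"):
    """Split 'owner/name[:ref]' into (owner, name, ref)."""
    repo_info, _, ref = github.partition(":")
    repo_owner, repo_name = repo_info.split("/")
    return repo_owner, repo_name, ref or default_ref


print(parse_repo_info("ultralytics/yolov5"))       # ('ultralytics', 'yolov5', 'master')
print(parse_repo_info("pytorch/vision:v0.10.0"))   # ('pytorch', 'vision', 'v0.10.0')
```

Pinning a tag in the string is the simplest way to keep a script reproducible when the default branch of the repository moves on.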
3.1.2 torch.hub.list(): a YOLOv5 example
We use YOLOv5 as an example to illustrate torch.hub.list().
- YOLOv5 GitHub repository: https://github.com/ultralytics/yolov5
- The hubconf.py defined in the YOLOv5 repository: https://github.com/ultralytics/yolov5/blob/develop/hubconf.py
1. List all models available in YOLOv5's hubconf.py
import torch
# yolov5: https://github.com/ultralytics/yolov5
entrypoints = torch.hub.list("ultralytics/yolov5")
print(entrypoints)
# ['custom', 'yolov5l', 'yolov5l6', 'yolov5m', 'yolov5m6', 'yolov5s', 'yolov5s6', 'yolov5x', 'yolov5x6']
Output:
['custom', 'yolov5l', 'yolov5l6', 'yolov5m', 'yolov5m6', 'yolov5s', 'yolov5s6', 'yolov5x', 'yolov5x6']
Using cache found in /home/shl/.cache/torch/hub/ultralytics_yolov5_master
With force_reload=True, the output is instead:
Downloading: "https://github.com/ultralytics/yolov5/archive/master.zip" to /home/shl/.cache/torch/hub/master.zip
['custom', 'yolov5l', 'yolov5l6', 'yolov5m', 'yolov5m6', 'yolov5s', 'yolov5s6', 'yolov5x', 'yolov5x6']
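Every name in the list above is a function defined in the hubconf.py file of the YOLOv5 repository; functions whose names start with _ are left out of the list.
2. So when torch.hub.list() is called, the code of the specified GitHub repository is first downloaded into the ~/.cache/torch/hub cache directory, and the repository directory is renamed to repoOwner_repoName_master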
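The repoOwner_repoName_master naming seen in the log lines above can be reproduced with a few lines of string handling. This is only a sketch: cached_repo_dir is a hypothetical helper, and replacing "/" in the branch name is an assumption made so that the result is always a single directory name.

```python
import os


def cached_repo_dir(hub_dir, repo_owner, repo_name, ref="master"):
    """Build the cache directory path for a downloaded hub repository."""
    normalized_ref = ref.replace("/", "_")  # assumption: keep the name flat
    return os.path.join(hub_dir, "_".join([repo_owner, repo_name, normalized_ref]))


path = cached_repo_dir("/home/shl/.cache/torch/hub", "ultralytics", "yolov5")
print(os.path.basename(path))  # ultralytics_yolov5_master
```

Deleting that directory by hand has the same effect on the source code as passing force_reload=True: the next call has to download the repository again.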
3.2 torch.hub.help(): show docstrings and examples
3.2.1 torch.hub.help() source code
1. Source of torch.hub.help():
def help(github, model, force_reload=False):
    r"""
    Show the docstring of entrypoint `model`.
    Args:
        github (string): a string with format <repo_owner/repo_name[:tag_name]> with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision[:hub]'
        model (string): a string of entrypoint name defined in repo's hubconf.py
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
            Default is `False`.
    Example:
        >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
    """
    repo_dir = _get_cache_or_reload(github, force_reload, True)
    sys.path.insert(0, repo_dir)
    hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)
    sys.path.remove(repo_dir)
    entry = _load_entry_from_hubconf(hub_module, model)
    return entry.__doc__
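2. Parameters
- github: a string of the form "repo_owner/repo_name[:tag_name]", where the trailing tag_name (a tag or branch) is optional. For example, the YOLOv5 repository lives at https://github.com/ultralytics/yolov5, so the argument can be written as github="ultralytics/yolov5"
- model: the name of a function defined in the hubconf.py file
- force_reload: defaults to False. Whether to force a fresh download of the repository source. If False, the cache directory ~/.cache/torch/hub is checked first
The return value is a string: the docstring of the specified function.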
3.2.2 torch.hub.help(): view the help text of the resnet18 model
1. Code
import torch
# torchvision: https://github.com/pytorch/vision
help_doc = torch.hub.help('pytorch/vision', 'resnet18', force_reload=True)
print(help_doc)
2. Result:
The printed help text is the docstring defined in the file ~/.cache/torch/hub/pytorch_vision_master/torchvision/models/resnet.py.
3.3 torch.hub.load(): load a pre-trained model
3.3.1 torch.hub.load() source code
1. Source of torch.hub.load():
def load(repo_or_dir, model, *args, **kwargs):
    r"""
    Load a model from a github repo or a local directory.
    Note: Loading a model is the typical use case, but this can also be used to
    for loading other objects such as tokenizers, loss functions, etc.
    If :attr:`source` is ``'github'``, :attr:`repo_or_dir` is expected to be
    of the form ``repo_owner/repo_name[:tag_name]`` with an optional
    tag/branch.
    If :attr:`source` is ``'local'``, :attr:`repo_or_dir` is expected to be a
    path to a local directory.
    Args:
        repo_or_dir (string): repo name (``repo_owner/repo_name[:tag_name]``),
            if ``source = 'github'``; or a path to a local directory, if
            ``source = 'local'``.
        model (string): the name of a callable (entrypoint) defined in the
            repo/dir's ``hubconf.py``.
        *args (optional): the corresponding args for callable :attr:`model`.
        source (string, optional): ``'github'`` | ``'local'``. Specifies how
            ``repo_or_dir`` is to be interpreted. Default is ``'github'``.
        force_reload (bool, optional): whether to force a fresh download of
            the github repo unconditionally. Does not have any effect if
            ``source = 'local'``. Default is ``False``.
        verbose (bool, optional): If ``False``, mute messages about hitting
            local caches. Note that the message about first download cannot be
            muted. Does not have any effect if ``source = 'local'``.
            Default is ``True``.
        **kwargs (optional): the corresponding kwargs for callable
            :attr:`model`.
    Returns:
        The output of the :attr:`model` callable when called with the given
        ``*args`` and ``**kwargs``.
    Example:
        >>> # from a github repo
        >>> repo = 'pytorch/vision'
        >>> model = torch.hub.load(repo, 'resnet50', pretrained=True)
        >>> # from a local directory
        >>> path = '/some/local/path/pytorch/vision'
        >>> model = torch.hub.load(path, 'resnet50', pretrained=True)
    """
    source = kwargs.pop('source', 'github').lower()
    force_reload = kwargs.pop('force_reload', False)
    verbose = kwargs.pop('verbose', True)
    if source not in ('github', 'local'):
        raise ValueError(
            f'Unknown source: "{source}". Allowed values: "github" | "local".')
    if source == 'github':
        repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, verbose)
    model = _load_local(repo_or_dir, model, *args, **kwargs)
    return model
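2. Parameters: load(repo_or_dir, model, *args, **kwargs)
- repo_or_dir: accepts two kinds of value:
  - repo_owner/repo_name[:tag_name]: as with list and help above, the owner and name of the repository
  - a local source path: if you have already cloned the source code locally, pass the path of the clone together with source='local'. This works on the same principle as loading from the cache
- model: the name of an entry point defined in the hubconf.py file
- plus the other optional parameters (source, force_reload, verbose)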
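What load() does with the model argument, once the repository is on disk, can be pictured as "look the name up in hubconf.py and call it". The sketch below imitates that step with an in-memory module; load_entry and fake_hubconf are hypothetical, and the toy entry point returns a plain dict instead of an nn.Module so the example runs without PyTorch installed.

```python
import types


def load_entry(hub_module, model, *args, **kwargs):
    """Find the entry point named `model` and call it with the given arguments."""
    entry = getattr(hub_module, model, None)
    if entry is None or not callable(entry):
        raise RuntimeError(f"Cannot find callable {model} in hubconf")
    return entry(*args, **kwargs)


def resnet50(pretrained=False, **kwargs):
    """Toy entry point: a real one would build and return a model."""
    return {"arch": "resnet50", "pretrained": pretrained, **kwargs}


fake_hubconf = types.ModuleType("hubconf")
fake_hubconf.resnet50 = resnet50

model = load_entry(fake_hubconf, "resnet50", pretrained=True)
print(model)                      # {'arch': 'resnet50', 'pretrained': True}
print(fake_hubconf.resnet50.__doc__)  # what a help()-style call would return
```

Everything after the entry point name is forwarded untouched, which is why model-specific options such as pretrained=True can be passed straight to load().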
3. Reference examples from the source docstring:
>>> # from a github repo
>>> repo = 'pytorch/vision'
>>> model = torch.hub.load(repo, 'resnet50', pretrained=True)
>>> # from a local directory
>>> path = '/some/local/path/pytorch/vision'
>>> model = torch.hub.load(path, 'resnet50', pretrained=True)
3.3.2 torch.hub.load(): loading a YOLOv5 model
1. Put the following in a file named test.py
import torch
# yolov5: https://github.com/ultralytics/yolov5
model = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=True)
print(model)
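2. Output:
- First the source code is downloaded, by default into the ~/.cache/torch/hub directory
- Then the specified pre-trained weights are downloaded, by default into the directory the script is run from. In my case, the yolov5s.pt weights are downloaded next to test.py.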
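3.4 torch.hub.download_url_to_file(): download a pre-trained model to a given local path
1. Parameters of download_url_to_file(url, dst, hash_prefix=None, progress=True)
- url: URL of the pre-trained model to download
- dst: full local path, including the file name, where the downloaded file is saved
- hash_prefix (optional): if not None, the SHA256 hash of the downloaded file must start with this prefix. Default is None
- progress (optional): whether to show a download progress bar. Default is True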
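The hash_prefix check amounts to hashing the downloaded bytes and comparing the start of the hex digest. Here is a minimal sketch with hashlib from the standard library; matches_hash_prefix is a hypothetical helper, and the payload is a stand-in for real file contents.

```python
import hashlib


def matches_hash_prefix(data: bytes, hash_prefix: str) -> bool:
    """True if the SHA256 hex digest of `data` starts with `hash_prefix`."""
    digest = hashlib.sha256(data).hexdigest()
    return digest.startswith(hash_prefix)


payload = b"hello"  # stand-in for the bytes of a downloaded checkpoint
full_digest = hashlib.sha256(payload).hexdigest()

print(matches_hash_prefix(payload, full_digest[:8]))   # True
print(matches_hash_prefix(payload, "00000000"))        # False
```

A short prefix is enough to catch truncated or corrupted downloads, which is the usual failure this parameter guards against.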
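2. Example: download the pre-trained resnet18 weights into the current directory (dst is a file path, so the file name is included)
torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', './resnet18-5c106cde.pth')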
3. Other files, such as images, can be downloaded the same way:
torch.hub.download_url_to_file('https://github.com/ultralytics/yolov5/raw/master/data/images/zidane.jpg', 'zidane.jpg')
torch.hub.download_url_to_file('https://github.com/ultralytics/yolov5/raw/master/data/images/bus.jpg', 'bus.jpg')
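3.5 torch.hub.load_state_dict_from_url()
torch.hub.load_state_dict_from_url() loads a torch-serialized object from the given URL. If the downloaded file is a zip file, it is decompressed automatically.
1. Parameters of torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None)
- url: URL of the object to download
- model_dir (optional): directory in which to save the download. If not given, it defaults to os.path.join(torch.hub.get_dir(), "checkpoints"). torch.hub.get_dir() returns '~/.cache/torch/hub' by default, so the default save path is ~/.cache/torch/hub/checkpoints
- map_location (optional): a function or dict specifying how to remap storage locations
- progress (optional): whether to show a download progress bar. Default is True
- check_hash (optional): if True, the file name part of the URL should follow the naming convention filename-<sha256>.ext, where <sha256> is the first eight or more digits of the SHA256 hash of the file contents; the hash is used to verify the download. Default is False
- file_name (optional): name to give the downloaded file. If not set, the file name from the URL is used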
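To see how the URL, model_dir, and file_name fit together, the sketch below derives the local checkpoint path and pulls the hash prefix out of a file name such as resnet18-5c106cde.pth. Both helper names are hypothetical, and the regular expression is an assumption modelled on the filename-<sha256>.ext convention rather than a quote from the library.

```python
import os
import re
from urllib.parse import urlparse

# Assumption: the hash is the hex run between the last '-' and the extension dot.
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")


def cached_checkpoint_path(url, hub_dir, file_name=None):
    """Where a checkpoint from `url` would be stored under `hub_dir`."""
    filename = file_name or os.path.basename(urlparse(url).path)
    return os.path.join(hub_dir, "checkpoints", filename)


def hash_prefix_from_filename(filename):
    """Return the embedded hash prefix, or None if the name has none."""
    match = HASH_REGEX.search(filename)
    return match.group(1) if match else None


url = "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth"
path = cached_checkpoint_path(url, "/home/shl/.cache/torch/hub")
print(os.path.basename(path))                         # resnet18-5c106cde.pth
print(hash_prefix_from_filename("resnet18-5c106cde.pth"))  # 5c106cde
print(hash_prefix_from_filename("yolov5s.pt"))        # None
```

Files without an embedded hash, such as yolov5s.pt, simply cannot be verified this way, so check_hash only makes sense for URLs that follow the convention.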
2. Example
>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
3.6 Where models are downloaded and saved
3.6.1 torch.hub.get_dir(): get the download path
1. torch.hub.get_dir()
>>> torch.hub.get_dir()
'/home/shl/.cache/torch/hub'
>>>
2. Checkpoints are downloaded to os.path.join(torch.hub.get_dir(), "checkpoints") by default. torch.hub.get_dir() returns '~/.cache/torch/hub' by default, so the default save path for downloaded models is ~/.cache/torch/hub/checkpoints
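The default of ~/.cache/torch/hub is not hard-coded: according to the torch.hub documentation it is $TORCH_HOME/hub, where TORCH_HOME falls back to $XDG_CACHE_HOME/torch and XDG_CACHE_HOME falls back to ~/.cache. The sketch below reproduces that fallback chain; default_hub_dir is a hypothetical function that takes the environment as a dict so it can be tried without touching real environment variables.

```python
import os


def default_hub_dir(env):
    """Resolve the hub directory from TORCH_HOME / XDG_CACHE_HOME style settings."""
    cache_home = env.get("XDG_CACHE_HOME", "~/.cache")
    torch_home = env.get("TORCH_HOME", os.path.join(cache_home, "torch"))
    return os.path.expanduser(os.path.join(torch_home, "hub"))


print(default_hub_dir({}))                              # e.g. /home/<user>/.cache/torch/hub
print(default_hub_dir({"TORCH_HOME": "/data/torch"}))   # /data/torch/hub on POSIX
```

Setting TORCH_HOME is therefore an alternative to calling torch.hub.set_dir() when large checkpoints should live on a different disk.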
3.6.2 torch.hub.set_dir(path)
1. Set the model download path
>>> torch.hub.get_dir()
'/home/shl/.cache/torch/hub'
>>> torch.hub.set_dir("~/projects")
>>> torch.hub.get_dir()
'~/projects'
>>>
4 The hubconf.py in YOLOv5
1. Below is the hubconf.py defined in YOLOv5:
"""YOLOv5 PyTorch Hub models https://pytorch.org/hub/ultralytics_yolov5/
Usage:
    import torch
    model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
"""
import torch
def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    """Creates a specified YOLOv5 model
    Arguments:
        name (str): name of model, i.e. 'yolov5s'
        pretrained (bool): load pretrained weights into the model
        channels (int): number of input channels
        classes (int): number of model classes
        autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
        verbose (bool): print all information to screen
        device (str, torch.device, None): device to use for model parameters
    Returns:
        YOLOv5 pytorch model
    """
    from pathlib import Path
    from models.yolo import Model, attempt_load
    from utils.general import check_requirements, set_logging
    from utils.google_utils import attempt_download
    from utils.torch_utils import select_device
    check_requirements(Path(__file__).parent / 'requirements.txt', exclude=('tensorboard', 'pycocotools', 'thop'))
    set_logging(verbose=verbose)
    fname = Path(name).with_suffix('.pt')  # checkpoint filename
    try:
        if pretrained and channels == 3 and classes == 80:
            model = attempt_load(fname, map_location=torch.device('cpu'))  # download/load FP32 model
        else:
            cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0]  # model.yaml path
            model = Model(cfg, channels, classes)  # create model
            if pretrained:
                ckpt = torch.load(attempt_download(fname), map_location=torch.device('cpu'))  # load
                msd = model.state_dict()  # model state_dict
                csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
                csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape}  # filter
                model.load_state_dict(csd, strict=False)  # load
                if len(ckpt['model'].names) == classes:
                    model.names = ckpt['model'].names  # set class names attribute
        if autoshape:
            model = model.autoshape()  # for file/URI/PIL/cv2/np inputs and NMS
        device = select_device('0' if torch.cuda.is_available() else 'cpu') if device is None else torch.device(device)
        return model.to(device)
    except Exception as e:
        help_url = 'https://github.com/ultralytics/yolov5/issues/36'
        s = 'Cache may be out of date, try `force_reload=True`. See %s for help.' % help_url
        raise Exception(s) from e
def custom(path='path/to/model.pt', autoshape=True, verbose=True, device=None):
    # YOLOv5 custom or local model
    return _create(path, autoshape=autoshape, verbose=verbose, device=device)
def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    # YOLOv5-small model https://github.com/ultralytics/yolov5
    return _create('yolov5s', pretrained, channels, classes, autoshape, verbose, device)
def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    # YOLOv5-medium model https://github.com/ultralytics/yolov5
    return _create('yolov5m', pretrained, channels, classes, autoshape, verbose, device)
def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    # YOLOv5-large model https://github.com/ultralytics/yolov5
    return _create('yolov5l', pretrained, channels, classes, autoshape, verbose, device)
def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    # YOLOv5-xlarge model https://github.com/ultralytics/yolov5
    return _create('yolov5x', pretrained, channels, classes, autoshape, verbose, device)
def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    # YOLOv5-small-P6 model https://github.com/ultralytics/yolov5
    return _create('yolov5s6', pretrained, channels, classes, autoshape, verbose, device)
def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    # YOLOv5-medium-P6 model https://github.com/ultralytics/yolov5
    return _create('yolov5m6', pretrained, channels, classes, autoshape, verbose, device)
def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    # YOLOv5-large-P6 model https://github.com/ultralytics/yolov5
    return _create('yolov5l6', pretrained, channels, classes, autoshape, verbose, device)
def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    # YOLOv5-xlarge-P6 model https://github.com/ultralytics/yolov5
    return _create('yolov5x6', pretrained, channels, classes, autoshape, verbose, device)
if __name__ == '__main__':
    model = _create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True, verbose=True)  # pretrained
    # model = custom(path='path/to/model.pt')  # custom
    # Verify inference
    import cv2
    import numpy as np
    from PIL import Image
    imgs = ['data/images/zidane.jpg',  # filename
            'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg',  # URI
            cv2.imread('data/images/bus.jpg')[:, :, ::-1],  # OpenCV
            Image.open('data/images/bus.jpg'),  # PIL
            np.zeros((320, 640, 3))]  # numpy
    results = model(imgs)  # batched inference
    results.print()
    results.save()
2. List all models available in YOLOv5's hubconf.py
import torch
# yolov5: https://github.com/ultralytics/yolov5
entrypoints = torch.hub.list("ultralytics/yolov5")
print(entrypoints)
# ['custom', 'yolov5l', 'yolov5l6', 'yolov5m', 'yolov5m6', 'yolov5s', 'yolov5s6', 'yolov5x', 'yolov5x6']
Output:
['custom', 'yolov5l', 'yolov5l6', 'yolov5m', 'yolov5m6', 'yolov5s', 'yolov5s6', 'yolov5x', 'yolov5x6']
Using cache found in /home/shl/.cache/torch/hub/ultralytics_yolov5_master
3. A yolov5s model can then be loaded like this:
import torch
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
5 Testing a model in the browser with torch.hub.load() and Gradio
5.1 YOLOv5 hub resources
- YOLOv5 GitHub project page: https://github.com/ultralytics/yolov5
- YOLOv5 documentation on loading models with torch.hub.load(): https://github.com/ultralytics/yolov5/issues/36
- YOLOv5 usage on the PyTorch Hub site: https://pytorch.org/hub/ultralytics_yolov5/
- Gradio GitHub page: https://github.com/gradio-app/gradio
5.2 Testing YOLOv5 object detection in the browser
5.2.1 Object detection test code using torch.hub.load() and Gradio
- Source of the test code: Demo Model Output
- You can also run the test code directly in Google Colab
1. Code
# yolov5: https://github.com/ultralytics/yolov5
import gradio as gr
import torch
from PIL import Image
# Download the test images (saved to the current directory by default).
# Note: the images are downloaded again even if they already exist locally; this could be optimized.
torch.hub.download_url_to_file('https://github.com.cnpmjs.org/ultralytics/yolov5/raw/master/data/images/zidane.jpg', 'zidane.jpg')
torch.hub.download_url_to_file('https://github.com.cnpmjs.org/ultralytics/yolov5/raw/master/data/images/bus.jpg', 'bus.jpg')
# Load the model and download the pre-trained weights; the weights are saved to the current directory.
# The source code is downloaded to ~/.cache/torch/hub
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')  # force_reload=True to update
# Run inference with the yolov5s model
def yolo(img, size=640):
    # e.g. zidane.jpg is 1280x720, so scale = 640 / max(1280, 720) = 640 / 1280 = 0.5
    scale = size / max(img.size)
    # Resize the input by that scale: (1280, 720) * 0.5 = (640, 360)
    # Image.LANCZOS is the same filter as the older Image.ANTIALIAS alias, which newer Pillow releases removed
    img = img.resize(tuple(int(x * scale) for x in img.size), Image.LANCZOS)  # resize
    # yolov5s inference
    results = model(img)
    # Draw the detected boxes and labels onto the image
    results.render()  # updates results.imgs with boxes and labels
    return Image.fromarray(results.imgs[0])
# Define the input and output image types (gr.inputs / gr.outputs is the older Gradio API this demo was written for)
inputs = gr.inputs.Image(type='pil', label="Original Image")
outputs = gr.outputs.Image(type="pil", label="Output Image")
title = "YOLOv5"
description = "YOLOv5 Gradio demo for object detection. Upload an image or click an example image to use."
article = "<p style='text-align: center'>YOLOv5 is a family of compound-scaled object detection models trained on the COCO dataset, and includes " \
          "simple functionality for Test Time Augmentation (TTA), model ensembling, hyperparameter evolution, " \
          "and export to ONNX, CoreML and TFLite. <a href='https://github.com/ultralytics/yolov5'>Source code</a> |" \
          "<a href='https://apps.apple.com/app/id1452689527'>iOS App</a> | <a href='https://pytorch.org/hub/ultralytics_yolov5'>PyTorch Hub</a></p>"
# The two downloaded test images
examples = [['zidane.jpg'], ['bus.jpg']]
gr.Interface(yolo, inputs, outputs, title=title, description=description, article=article, examples=examples, analytics_enabled=False).launch(
    debug=True)
2. When the program runs, it first downloads the test images. Then open the link http://127.0.0.1:7860/ in a browser to start testing.
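The resize arithmetic in the yolo() function above can be checked on its own, without loading a model or an image. The sketch below works on a plain (width, height) pair; fit_longest_side is a hypothetical name for the same "scale so the longest side equals size" calculation.

```python
def fit_longest_side(width, height, size=640):
    """Scale (width, height) so that the longest side becomes `size`."""
    scale = size / max(width, height)
    return int(width * scale), int(height * scale)


print(fit_longest_side(1280, 720))   # (640, 360)
print(fit_longest_side(720, 1280))   # (360, 640)
```

Keeping the aspect ratio this way explains why the saved output image in the demo is 640x360 while the input was 1280x720.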
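5.2.2 Selecting an image in the browser for object detection
1. Opened in the browser, the page looks as follows:
- The box labelled Original Image takes the input image. There are three ways to provide it:
  - Drop Image Here: drag the input image into the dashed box
  - Click to Upload: click inside the dashed box and choose the input image manually
  - Choose an Example image: two test images are listed at the bottom left; click one of them to test with it
- The box labelled Output Image shows the image produced by detection
2. My steps were:
- 1) Click one of the test images at the bottom left (I chose zidane.jpg); the image then appears in the Original Image box
- 2) Click the Submit button to start detection; the result is shown in the Output Image box
- 3) Click Flag to save the input and output images
3. After I clicked the Flag button, a flagged directory was created:
flagged
├── Original Image
│   └── 0.jpeg  # 1280x720 input image
└── Output Image
    └── 0.png  # 640x360 output image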
6 The YOLOv5 hub inference interface
Reference material for this section
6.1 Loading a YOLOv5 model with torch.hub.load()
# Model
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
6.2 Running inference on input data
1. The model accepts many types of input
# Images
imgs = ['https://ultralytics.com/images/zidane.jpg']  # batch of images
# Inference
# You can pass an image URL, an image path, or an already loaded image, e.g. img1 = Image.open('zidane.jpg')  # PIL image
# When a URL is passed, the image is downloaded to the current directory; it is downloaded again on the next run even if it already exists
# results = model('https://ultralytics.com/images/zidane.jpg')
# Pass an image path directly
results = model("./zidane.jpg")
# Pass an image read with PIL, see https://github.com/ultralytics/yolov5/issues/36
# img = Image.open("./zidane.jpg")
# results = model(img)
# Pass an image read with OpenCV
# img = cv2.imread("./zidane.jpg")[:, :, ::-1]  # OpenCV BGR to RGB
# results = model(img)
# For a batch of images, pass a list; its items can be any of the types above (paths, URLs, and so on)
# images = ["./bus.jpg", "./zidane.jpg"]
# results = model(images)
print(results)  # <models.common.Detections object at 0x7f3d6962ce80>
2. For more inference parameters, see yolov5/models/common.py
def forward(self, imgs, size=640, augment=False, profile=False):
    # Inference from various sources. For height=720, width=1280, RGB images example inputs are:
    #   filename:   imgs = 'data/samples/zidane.jpg'
    #   URI:             = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
    #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(720,1280,3)
    #   PIL:             = Image.open('image.jpg')  # HWC x(720,1280,3)
    #   numpy:           = np.zeros((720,1280,3))  # HWC
    #   torch:           = torch.zeros(16,3,720,1280)  # BCHW
    #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
6.3 Complete code
__Author__ = "Shliang"
__Email__ = "shliang0603@gmail.com"
import torch
from PIL import Image
import cv2
# Model
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
# Images
imgs = ['https://ultralytics.com/images/zidane.jpg']  # batch of images
# Inference
# You can pass an image URL, an image path, or an already loaded image, e.g. img1 = Image.open('zidane.jpg')  # PIL image
# When a URL is passed, the image is downloaded to the current directory; it is downloaded again on the next run even if it already exists
# results = model('https://ultralytics.com/images/zidane.jpg')
# Pass an image path directly
results = model("./zidane.jpg")
# Pass an image read with PIL, see https://github.com/ultralytics/yolov5/issues/36
# img = Image.open("./zidane.jpg")
# results = model(img)
# Pass an image read with OpenCV
# img = cv2.imread("./zidane.jpg")[:, :, ::-1]  # OpenCV BGR to RGB
# results = model(img)
# For a batch of images, pass a list; its items can be any of the types above (paths, URLs, and so on)
# images = ["./bus.jpg", "./zidane.jpg"]
# results = model(images)
print(results)  # <models.common.Detections object at 0x7f3d6962ce80>
# For more inference parameters see [yolov5/models/common.py](https://github.com/ultralytics/yolov5/blob/3551b072b3/models/common.py#L182)
# model.conf = 0.25  # confidence threshold (0-1)
# model.iou = 0.45  # NMS IoU threshold (0-1)
# model.classes = None  # (optional list) filter by class, i.e. = [0, 15, 16] for persons, cats and dogs
# A custom inference size can be set
# results = model(imgs, size=320)  # custom inference size
# Print the detection results
results.print()
# image 1/1: 720x1280 2 persons, 2 ties
# Speed: 14.8ms pre-process, 14.7ms inference, 15.2ms NMS per image at shape (1, 3, 384, 640)
# Save the image with the detections drawn on it
results.save()
# Saved zidane.jpg to runs/hub/exp
# Running it again saves to runs/hub/exp2 instead
# Show the result
results.show()
# The predictions for an image are a tensor with the columns: xmin ymin xmax ymax confidence class
print(results.xyxy[0])  # img1 predictions (tensor)
# tensor([[7.50000e+02, 4.30000e+01, 1.14800e+03, 7.09000e+02, 8.76953e-01, 0.00000e+00],
#         [4.33500e+02, 4.34000e+02, 5.18000e+02, 7.15000e+02, 6.58203e-01, 2.70000e+01],
#         [1.13500e+02, 1.96250e+02, 1.09200e+03, 7.10000e+02, 5.95215e-01, 0.00000e+00],
#         [9.86000e+02, 3.04500e+02, 1.02800e+03, 4.20000e+02, 2.84424e-01, 2.70000e+01]], device='cuda:0')
# Convert the detections to pandas, i.e. turn the results.xyxy[0] tensor into the DataFrame below (with an added name column)
print(results.pandas().xyxy[0])  # img1 predictions (pandas)
#      xmin    ymin    xmax   ymax  confidence  class    name
# 0  749.50   43.50  1148.0  704.5    0.874023      0  person
# 1  433.50  433.50   517.5  714.5    0.687988     27     tie
# 2  114.75  195.75  1095.0  708.0    0.624512      0  person
# 3  986.00  304.00  1028.0  420.0    0.286865     27     tie
# Serialize the results as a JSON string
results.pandas().xyxy[0].to_json(orient="records")  # JSON img1 predictions
print(results.pandas().xyxy[0].to_json(orient="records"))
'''
[{"xmin":750.0,"ymin":43.0,"xmax":1148.0,"ymax":709.0,"confidence":0.876953125,"class":0,"name":"person"},
{"xmin":433.5,"ymin":434.0,"xmax":518.0,"ymax":715.0,"confidence":0.658203125,"class":27,"name":"tie"},
{"xmin":113.5,"ymin":196.25,"xmax":1092.0,"ymax":710.0,"confidence":0.5952148438,"class":0,"name":"person"},
{"xmin":986.0,"ymin":304.5,"xmax":1028.0,"ymax":420.0,"confidence":0.2844238281,"class":27,"name":"tie"}]
'''
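The step from the raw prediction rows to the JSON records shown above is a simple relabelling of columns plus a class-name lookup. The sketch below does it with the standard library only, using two of the rows printed earlier; rows_to_records is a hypothetical helper, and NAMES contains just the two class ids that appear in this output, not the full label set of the model.

```python
import json

COLUMNS = ["xmin", "ymin", "xmax", "ymax", "confidence", "class"]
NAMES = {0: "person", 27: "tie"}  # assumption: only the classes seen above


def rows_to_records(rows, names):
    """Turn [xmin, ymin, xmax, ymax, confidence, class] rows into dict records."""
    records = []
    for row in rows:
        record = dict(zip(COLUMNS, row))
        record["class"] = int(record["class"])   # class ids arrive as floats
        record["name"] = names[record["class"]]
        records.append(record)
    return records


rows = [
    [750.0, 43.0, 1148.0, 709.0, 0.876953125, 0.0],
    [433.5, 434.0, 518.0, 715.0, 0.658203125, 27.0],
]
print(json.dumps(rows_to_records(rows, NAMES)))
```

Records in this shape are convenient to return from a web service, since each detection carries its own field names.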
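7 More PyTorch Hub projects
1. PyTorch Hub exists to make research quick to reproduce. More PyTorch Hub projects can be found on the PyTorch Hub search page:
- PyTorch Hub model search (official repository): pytorch.org/hub/researc…
2. As the category tabs there show (All, Audio, Generative, Nlp, Scriptable, Vision), PyTorch Hub covers projects in many areas:
- Audio: speech
- Generative: generative networks
- NLP: natural language processing
- Scriptable: scriptable models
- Vision: computer vision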