
A Detailed Tutorial on PyTorch Hub, with a Concrete Example Using yolov5's hubconf.py



1 Introduction to PyTorch Hub

1.1 A brief introduction to PyTorch Hub

PyTorch Hub is a repository of pre-trained models designed to facilitate research reproducibility. In short, many people publish the pre-trained models from their projects to PyTorch Hub, and we can then load them directly through the torch.hub.load() interface and quickly reproduce the project's results.

Judging from the publish time shown on the PyTorch Hub GitHub homepage, it appeared around April 2020, somewhat later than TensorFlow Hub and PaddlePaddle Hub.

1.2 A brief overview of publishing pre-trained models to PyTorch Hub

PyTorch Hub supports publishing pre-trained models (model definitions and pre-trained weights) to a GitHub repository by adding a simple hubconf.py file. A hubconf.py can have multiple entry points, each defined as a Python function (for example, a pre-trained model you want to publish):

def entrypoint_name(*args, **kwargs):
    # args & kwargs are optional, for models which take positional/keyword arguments.
    ...

The details are covered later in this article.

1.3 PyTorch Hub related links

  • PyTorch Hub GitHub homepage: github.com/pytorch/hub
  • PyTorch Hub official model repository: pytorch.org/hub/researc…
  • PyTorch Hub official documentation: pytorch.org/docs/master…

1.4 Ecosystem domains currently supported by PyTorch Hub

The PyTorch Hub ecosystem already covers nearly every area of deep learning, with a large number of pre-trained models. The main domains include:

  • Audio (speech)
  • Generative (generative models)
  • NLP (natural language processing)
  • Scriptable
  • Vision (computer vision)


2 How to publish a model to torch.hub

  • Sections 2–3 below are based on the official PyTorch Hub documentation.

PyTorch Hub supports publishing pre-trained models (model definitions and pre-trained weights) to a GitHub repository by adding a simple hubconf.py file. hubconf.py can define multiple entry points, each of which is a Python function (for example, a pre-trained model you want to publish):

def entrypoint_name(*args, **kwargs):
    # args & kwargs are optional, for models which take positional/keyword arguments.
    ...

2.1 How to implement an entry point

Suppose we create a hubconf.py file in the pytorch/vision project on GitHub and implement an entry point in it. The code snippet below defines an entry point for the resnet18 model. In most cases, importing the right function in hubconf.py is sufficient; the snippet here only illustrates how it works. See the pytorch/vision repo for the complete script.

dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18

# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
    """ # This docstring shows up in hub.help()
    Resnet18 model
    pretrained (bool): kwargs, load pretrained weights into the model
    """
    # Call the model, load pretrained weights
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model

1. The dependencies variable

The dependencies variable is a list of the packages required to load the model. Note that these may differ from the dependencies required to train the model.

2. args and kwargs are passed through to the real callable.

3. The docstring of the function serves as the help message. It explains what the model does and what positional/keyword arguments it accepts. It is highly recommended to add a few usage examples here.

4. An entrypoint function can return a model (nn.Module), or auxiliary tools that make the user's workflow smoother, such as tokenizers.

5. Callables whose names start with an underscore are treated as helper functions and are not shown by torch.hub.list().

6. Pre-trained weights can either be stored in the GitHub repository or be loaded with torch.hub.load_state_dict_from_url(). If they are smaller than 2GB, it is recommended to attach them to a project release and use the URL from that release. In the example above, torchvision.models.resnet.resnet18 handles the pretrained logic; alternatively, you can write the following logic in the entrypoint definition:

if pretrained:
    # For checkpoint saved in local github repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
    dirname = os.path.dirname(__file__)
    checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
    state_dict = torch.load(checkpoint)
    model.load_state_dict(state_dict)

    # For a checkpoint saved elsewhere, e.g. hosted at a URL:
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))

Note:

Published models should live on at least a branch or tag; they must not point to a random commit.
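
For example, a specific tag or branch can be pinned directly in the repo string passed to the hub APIs. A minimal sketch (the v0.10.0 tag below is only illustrative; use a tag that actually exists in the target repo):

import torch

# Pin a release with the "repo_owner/repo_name:ref" syntax;
# "v0.10.0" is an illustrative torchvision tag.
entrypoints = torch.hub.list('pytorch/vision:v0.10.0')
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)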

3 Loading models from torch.hub

PyTorch Hub provides convenient APIs to explore and load models:

  • torch.hub.list(): list all models available in a hub repo

  • torch.hub.help(): show the docstring and examples of an entry point

  • torch.hub.load(): load a pre-trained model

3.1 torch.hub.list(): view the models defined in a repo's hubconf.py

torch.hub.list() lists the entry-point models defined in the specified repository's hubconf.py; in effect, it simply lists the function names defined in hubconf.py.

3.1.1 torch.hub.list() source code

1. Source code of torch.hub.list():

def list(github, force_reload=False):
    r"""
    List all entrypoints available in `github` hubconf.

    Args:
        github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision[:hub]'
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
            Default is `False`.
    Returns:
        entrypoints: a list of available entrypoint names

    Example:
        >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
    """
    repo_dir = _get_cache_or_reload(github, force_reload, True)

    sys.path.insert(0, repo_dir)

    hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)

    sys.path.remove(repo_dir)

    # We take functions starts with '_' as internal helper functions
    entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]

    return entrypoints

2. Parameter description

torch.hub.list() lists the entry points defined in the hubconf.py of the specified GitHub repository; the return value is the list of entrypoints.

Parameters:

  • github: a string of the form "repo_owner/repo_name[:tag_name]", where the optional tag_name is a tag or branch. For example, the yolov5 GitHub repo is https://github.com/ultralytics/yolov5, so this argument can be written as github="ultralytics/yolov5".

  • force_reload: defaults to False; whether to force a fresh download of the source repository. If False, the cache directory ~/.cache/torch/hub is checked first.

3.1.2 torch.hub.list() example: inspecting the yolov5 entry points

Let's use yolov5 as an example to illustrate how torch.hub.list() is used.

  • yolov5 GitHub repo: https://github.com/ultralytics/yolov5
  • hubconf.py defined in the yolov5 repo: https://github.com/ultralytics/yolov5/blob/develop/hubconf.py

1. List all models available in yolov5's hubconf.py

import torch

# yolov5: https://github.com/ultralytics/yolov5
entrypoints = torch.hub.list("ultralytics/yolov5")
print(entrypoints)
# ['custom', 'yolov5l', 'yolov5l6', 'yolov5m', 'yolov5m6', 'yolov5s', 'yolov5s6', 'yolov5x', 'yolov5x6']

Output:

['custom', 'yolov5l', 'yolov5l6', 'yolov5m', 'yolov5m6', 'yolov5s', 'yolov5s6', 'yolov5x', 'yolov5x6']
Using cache found in /home/shl/.cache/torch/hub/ultralytics_yolov5_master

If you set force_reload=True, the output is instead:

Downloading: "https://github.com/ultralytics/yolov5/archive/master.zip" to /home/shl/.cache/torch/hub/master.zip
['custom', 'yolov5l', 'yolov5l6', 'yolov5m', 'yolov5m6', 'yolov5s', 'yolov5s6', 'yolov5x', 'yolov5x6']

Every entry in the list above is a function defined in the hubconf.py of the yolov5 repository; of course, functions whose names start with _ are not included in the list.

2. So when torch.hub.list() is called, the specified GitHub repository is first cloned into the ~/.cache/torch/hub cache directory, and the repository directory is renamed to repoOwner_repoName_master.
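
A quick way to confirm this on your own machine is to list the hub cache directory (a small sketch; the exact path depends on your system):

import os
import torch

hub_dir = torch.hub.get_dir()    # typically '~/.cache/torch/hub'
print(hub_dir)
print(os.listdir(hub_dir))       # e.g. ['ultralytics_yolov5_master', 'checkpoints', ...]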


3.2 torch.hub.help(): show the docstring and examples

3.2.1 torch.hub.help() source code

1. Source code of torch.hub.help():

def help(github, model, force_reload=False):
    r"""
    Show the docstring of entrypoint `model`.

    Args:
        github (string): a string with format <repo_owner/repo_name[:tag_name]> with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision[:hub]'
        model (string): a string of entrypoint name defined in repo's hubconf.py
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
            Default is `False`.
    Example:
        >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
    """
    repo_dir = _get_cache_or_reload(github, force_reload, True)

    sys.path.insert(0, repo_dir)

    hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)

    sys.path.remove(repo_dir)

    entry = _load_entry_from_hubconf(hub_module, model)

    return entry.__doc__

2. Parameters

  • github: a string of the form "repo_owner/repo_name[:tag_name]", where the optional tag_name is a tag or branch. For example, the yolov5 GitHub repo is https://github.com/ultralytics/yolov5, so this argument can be written as github="ultralytics/yolov5".

  • model: the name of a function defined in hubconf.py.

  • force_reload: defaults to False; whether to force a fresh download of the source repository. If False, the cache directory ~/.cache/torch/hub is checked first.

The return value is a string: the docstring of the specified function.

3.2.2 torch.hub.help(): viewing the help documentation of the resnet18 model

1. Code

import torch

# torchvision: https://github.com/pytorch/vision
help_doc = torch.hub.help('pytorch/vision', 'resnet18', force_reload=True)
print(help_doc)

2. The printed result is the docstring of the resnet18 entry point. This help text is defined in the ~/.cache/torch/hub/pytorch_vision_master/torchvision/models/resnet.py file.

3.3 torch.hub.load(): load a pre-trained model

3.3.1 torch.hub.load() source code

1. Source code of torch.hub.load():

def load(repo_or_dir, model, *args, **kwargs):
    r"""
    Load a model from a github repo or a local directory.

    Note: Loading a model is the typical use case, but this can also be used to
    for loading other objects such as tokenizers, loss functions, etc.

    If :attr:`source` is ``'github'``, :attr:`repo_or_dir` is expected to be
    of the form ``repo_owner/repo_name[:tag_name]`` with an optional
    tag/branch.

    If :attr:`source` is ``'local'``, :attr:`repo_or_dir` is expected to be a
    path to a local directory.

    Args:
        repo_or_dir (string): repo name (``repo_owner/repo_name[:tag_name]``),
            if ``source = 'github'``; or a path to a local directory, if
            ``source = 'local'``.
        model (string): the name of a callable (entrypoint) defined in the
            repo/dir's ``hubconf.py``.
        *args (optional): the corresponding args for callable :attr:`model`.
        source (string, optional): ``'github'`` | ``'local'``. Specifies how
            ``repo_or_dir`` is to be interpreted. Default is ``'github'``.
        force_reload (bool, optional): whether to force a fresh download of
            the github repo unconditionally. Does not have any effect if
            ``source = 'local'``. Default is ``False``.
        verbose (bool, optional): If ``False``, mute messages about hitting
            local caches. Note that the message about first download cannot be
            muted. Does not have any effect if ``source = 'local'``.
            Default is ``True``.
        **kwargs (optional): the corresponding kwargs for callable
            :attr:`model`.

    Returns:
        The output of the :attr:`model` callable when called with the given
        ``*args`` and ``**kwargs``.

    Example:
        >>> # from a github repo
        >>> repo = 'pytorch/vision'
        >>> model = torch.hub.load(repo, 'resnet50', pretrained=True)
        >>> # from a local directory
        >>> path = '/some/local/path/pytorch/vision'
        >>> model = torch.hub.load(path, 'resnet50', pretrained=True)
    """
    source = kwargs.pop('source', 'github').lower()
    force_reload = kwargs.pop('force_reload', False)
    verbose = kwargs.pop('verbose', True)

    if source not in ('github', 'local'):
        raise ValueError(
            f'Unknown source: "{source}". Allowed values: "github" | "local".')

    if source == 'github':
        repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, verbose)

    model = _load_local(repo_or_dir, model, *args, **kwargs)
    return model

2. Parameters: load(repo_or_dir, model, *args, **kwargs)

  • repo_or_dir: this parameter accepts two kinds of values:

    • repo_owner/repo_name[:tag_name]: the repository owner and name, just as for list() and help() above

    • a path to local source code: if the source has already been cloned locally, you can pass the path of the clone directly, which is effectively the same as loading from the cache

  • model: the name of a model (entry point) defined in hubconf.py

  • other optional keyword arguments: source, force_reload, verbose, plus any kwargs forwarded to the entry point

3. The reference example given in the source docstring:

>>> # from a github repo
>>> repo = 'pytorch/vision'
>>> model = torch.hub.load(repo, 'resnet50', pretrained=True)
>>> # from a local directory
>>> path = '/some/local/path/pytorch/vision'
>>> model = torch.hub.load(path, 'resnet50', pretrained=True)
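
Per the implementation above, the default source is 'github', so when repo_or_dir is a local directory you also pass source='local'. A minimal sketch, assuming the yolov5 repo has already been cloned into the hub cache by an earlier hub call (the path below is machine-specific):

import torch

# Assumption: the repo was previously cloned into the hub cache.
local_repo = '/home/shl/.cache/torch/hub/ultralytics_yolov5_master'
model = torch.hub.load(local_repo, 'yolov5s', source='local', pretrained=True)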

3.3.2 torch.hub.load(): loading a yolov5 model

1. Define the following in a test.py file:

import torch

# yolov5: https://github.com/ultralytics/yolov5

model = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=True)
print(model)

2. Running it produces the following behavior:

  • The source code is downloaded first, by default into the ~/.cache/torch/hub directory.

  • The specified pre-trained weights are then downloaded, by default into the same directory as the script being executed. In this example, the yolov5s.pt weights are downloaded into the same directory as test.py.

3.4 torch.hub.download_url_to_file(): download a pre-trained model to a given local path

1. Parameters of download_url_to_file(url, dst, hash_prefix=None, progress=True)

  • url: the URL of the pre-trained model to download

  • dst: the full local path, including the file name, where the downloaded file will be saved

  • hash_prefix (optional): if not None, the SHA256 hash of the downloaded file must start with this prefix; defaults to None

  • progress (optional): whether to display a download progress bar; defaults to True

2. Example: download the resnet18 pre-trained weights to the current directory

torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', './resnet18-5c106cde.pth')


3. You can also use it to download other files, such as images:

torch.hub.download_url_to_file('https://github.com/ultralytics/yolov5/raw/master/data/images/zidane.jpg', 'zidane.jpg')
torch.hub.download_url_to_file('https://github.com/ultralytics/yolov5/raw/master/data/images/bus.jpg', 'bus.jpg')

3.5 torch.hub.load_state_dict_from_url()

torch.hub.load_state_dict_from_url() loads a torch-serialized model state dict from the given URL; if the downloaded file is a zip archive, it is automatically decompressed.

1. Parameters of torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None)

  • url: the URL to download from

  • model_dir (optional): the directory where the downloaded weights are saved. If not specified, the default is os.path.join(torch.hub.get_dir(), "checkpoints"); since torch.hub.get_dir() returns '~/.cache/torch/hub', the default save path is ~/.cache/torch/hub/checkpoints

  • map_location (optional): a function or dict specifying how to remap storage locations

  • progress (optional): whether to display a download progress bar; defaults to True

  • check_hash (optional): if True, the file name part of the URL is expected to follow the convention filename-<sha256>.ext, and the contents of the downloaded file are checked against that hash

  • file_name (optional): the name to save the downloaded file under; if not set, the file name from the URL is used

2. Example

>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
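
The returned state_dict can then be loaded into a matching architecture. A minimal sketch, assuming a torchvision resnet18 built without pre-trained weights:

import torch
import torchvision

# Build the architecture without weights, then load the downloaded state_dict into it.
model = torchvision.models.resnet18(pretrained=False)
state_dict = torch.hub.load_state_dict_from_url(
    'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
model.load_state_dict(state_dict)
model.eval()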

3.6 Where models are downloaded and saved

3.6.1 torch.hub.get_dir(): get the hub download path

1. torch.hub.get_dir()

>>> torch.hub.get_dir()
'/home/shl/.cache/torch/hub'
>>> 

2. By default, checkpoints are downloaded to os.path.join(torch.hub.get_dir(), "checkpoints"); since torch.hub.get_dir() returns '~/.cache/torch/hub', the default save path for downloaded models is ~/.cache/torch/hub/checkpoints.

3.6.2 torch.hub.set_dir(path)

1. Set the hub download path:

>>> torch.hub.get_dir()
'/home/shl/.cache/torch/hub'
>>> torch.hub.set_dir("~/projects")
>>> torch.hub.get_dir()
'~/projects'
>>> 
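
After set_dir(), subsequent repo clones and checkpoint downloads go under the new directory. A minimal sketch (the path is only an example):

import torch

torch.hub.set_dir('/home/shl/hub_cache')  # example path
model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)
# The repo is cloned to /home/shl/hub_cache/pytorch_vision_master and the
# resnet18 weights are saved under /home/shl/hub_cache/checkpoints/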

4 The hubconf.py defined in yolov5

1. Below is the hubconf.py defined in the yolov5 repository:

"""YOLOv5 PyTorch Hub models https://pytorch.org/hub/ultralytics_yolov5/
Usage:
    import torch
    model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
"""

import torch


def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    """Creates a specified YOLOv5 model
    Arguments:
        name (str): name of model, i.e. 'yolov5s'
        pretrained (bool): load pretrained weights into the model
        channels (int): number of input channels
        classes (int): number of model classes
        autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
        verbose (bool): print all information to screen
        device (str, torch.device, None): device to use for model parameters
    Returns:
        YOLOv5 pytorch model
    """
    from pathlib import Path

    from models.yolo import Model, attempt_load
    from utils.general import check_requirements, set_logging
    from utils.google_utils import attempt_download
    from utils.torch_utils import select_device

    check_requirements(Path(__file__).parent / 'requirements.txt', exclude=('tensorboard', 'pycocotools', 'thop'))
    set_logging(verbose=verbose)

    fname = Path(name).with_suffix('.pt')  # checkpoint filename
    try:
        if pretrained and channels == 3 and classes == 80:
            model = attempt_load(fname, map_location=torch.device('cpu'))  # download/load FP32 model
        else:
            cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0]  # model.yaml path
            model = Model(cfg, channels, classes)  # create model
            if pretrained:
                ckpt = torch.load(attempt_download(fname), map_location=torch.device('cpu'))  # load
                msd = model.state_dict()  # model state_dict
                csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
                csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape}  # filter
                model.load_state_dict(csd, strict=False)  # load
                if len(ckpt['model'].names) == classes:
                    model.names = ckpt['model'].names  # set class names attribute
        if autoshape:
            model = model.autoshape()  # for file/URI/PIL/cv2/np inputs and NMS
        device = select_device('0' if torch.cuda.is_available() else 'cpu') if device is None else torch.device(device)
        return model.to(device)

    except Exception as e:
        help_url = 'https://github.com/ultralytics/yolov5/issues/36'
        s = 'Cache may be out of date, try `force_reload=True`. See %s for help.' % help_url
        raise Exception(s) from e


def custom(path='path/to/model.pt', autoshape=True, verbose=True, device=None):
    # YOLOv5 custom or local model
    return _create(path, autoshape=autoshape, verbose=verbose, device=device)


def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    # YOLOv5-small model https://github.com/ultralytics/yolov5
    return _create('yolov5s', pretrained, channels, classes, autoshape, verbose, device)


def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    # YOLOv5-medium model https://github.com/ultralytics/yolov5
    return _create('yolov5m', pretrained, channels, classes, autoshape, verbose, device)


def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    # YOLOv5-large model https://github.com/ultralytics/yolov5
    return _create('yolov5l', pretrained, channels, classes, autoshape, verbose, device)


def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    # YOLOv5-xlarge model https://github.com/ultralytics/yolov5
    return _create('yolov5x', pretrained, channels, classes, autoshape, verbose, device)


def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    # YOLOv5-small-P6 model https://github.com/ultralytics/yolov5
    return _create('yolov5s6', pretrained, channels, classes, autoshape, verbose, device)


def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    # YOLOv5-medium-P6 model https://github.com/ultralytics/yolov5
    return _create('yolov5m6', pretrained, channels, classes, autoshape, verbose, device)


def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    # YOLOv5-large-P6 model https://github.com/ultralytics/yolov5
    return _create('yolov5l6', pretrained, channels, classes, autoshape, verbose, device)


def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    # YOLOv5-xlarge-P6 model https://github.com/ultralytics/yolov5
    return _create('yolov5x6', pretrained, channels, classes, autoshape, verbose, device)


if __name__ == '__main__':
    model = _create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True, verbose=True)  # pretrained
    # model = custom(path='path/to/model.pt')  # custom

    # Verify inference
    import cv2
    import numpy as np
    from PIL import Image

    imgs = ['data/images/zidane.jpg',  # filename
            'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg',  # URI
            cv2.imread('data/images/bus.jpg')[:, :, ::-1],  # OpenCV
            Image.open('data/images/bus.jpg'),  # PIL
            np.zeros((320, 640, 3))]  # numpy

    results = model(imgs)  # batched inference
    results.print()
    results.save()

2. List all models available in yolov5's hubconf.py:

import torch

# yolov5: https://github.com/ultralytics/yolov5
entrypoints = torch.hub.list("ultralytics/yolov5")
print(entrypoints)
# ['custom', 'yolov5l', 'yolov5l6', 'yolov5m', 'yolov5m6', 'yolov5s', 'yolov5s6', 'yolov5x', 'yolov5x6']

Output:

['custom', 'yolov5l', 'yolov5l6', 'yolov5m', 'yolov5m6', 'yolov5s', 'yolov5s6', 'yolov5x', 'yolov5x6']
Using cache found in /home/shl/.cache/torch/hub/ultralytics_yolov5_master

3. A single yolov5s model can then be loaded like this:

import torch
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
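
The custom entry point in the same hubconf.py can be used for weights you trained yourself. A sketch, where 'path/to/best.pt' is a placeholder for your own checkpoint:

import torch

# 'path/to/best.pt' is a placeholder for a locally trained YOLOv5 checkpoint.
model = torch.hub.load('ultralytics/yolov5', 'custom', path='path/to/best.pt')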

5 Testing the model in the browser with torch.hub.load() and gradio

5.1 yolov5 hub related resource links

  • yolov5 GitHub project homepage: https://github.com/ultralytics/yolov5

  • yolov5 documentation on loading models via torch.hub.load(): https://github.com/ultralytics/yolov5/issues/36

  • yolov5 usage page on PyTorch Hub: https://pytorch.org/hub/ultralytics_yolov5/

  • gradio GitHub homepage: https://github.com/gradio-app/gradio

5.2 Testing yolov5 object detection in the browser

5.2.1 Test code for in-browser object detection with torch.hub.load() and gradio

  • Source of the test code: Demo Model Output

  • You can also run the test code directly in Google Colab

1. Code

# yolov5: https://github.com/ultralytics/yolov5

import gradio as gr
import torch
from PIL import Image


# Download the test images (to the current directory by default)
# Note: even if the images have already been downloaded, they are downloaded again on every run; this could be optimized
torch.hub.download_url_to_file('https://github.com.cnpmjs.org/ultralytics/yolov5/raw/master/data/images/zidane.jpg', 'zidane.jpg')
torch.hub.download_url_to_file('https://github.com.cnpmjs.org/ultralytics/yolov5/raw/master/data/images/bus.jpg', 'bus.jpg')

# Load the model and download the pre-trained weights (the weights are saved to the current directory)
# The source code is cloned into the ~/.cache/torch/hub directory
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')  # force_reload=True to update


# Run inference with the yolov5s model
def yolo(img, size=640):
    # e.g. for an input image zidane.jpg of size (1280x720): scale = 640 / max(1280, 720) = 640/1280 = 0.5
    scale = (size / max(img.size))
    # Resize the input image by the computed scale: (1280, 720) * 0.5 = (640, 360)
    img = img.resize((int(x * scale) for x in img.size), Image.ANTIALIAS)  # resize

    # yolov5s model inference
    results = model(img)
    # Update the results: draw the detected boxes and labels onto the image
    results.render()  # updates results.imgs with boxes and labels
    return Image.fromarray(results.imgs[0])


# Define the input and output image types
inputs = gr.inputs.Image(type='pil', label="Original Image")
outputs = gr.outputs.Image(type="pil", label="Output Image")

title = "YOLOv5"
description = "YOLOv5 Gradio demo for object detection. Upload an image or click an example image to use."
article = "<p style='text-align: center'>YOLOv5 is a family of compound-scaled object detection models trained on the COCO dataset, and includes " \
          "simple functionality for Test Time Augmentation (TTA), model ensembling, hyperparameter evolution, " \
          "and export to ONNX, CoreML and TFLite. <a href='https://github.com/ultralytics/yolov5'>Source code</a> |" \
          "<a href='https://apps.apple.com/app/id1452689527'>iOS App</a> | <a href='https://pytorch.org/hub/ultralytics_yolov5'>PyTorch Hub</a></p>"

# The two downloaded test images
examples = [['zidane.jpg'], ['bus.jpg']]
gr.Interface(yolo, inputs, outputs, title=title, description=description, article=article, examples=examples, analytics_enabled=False).launch(
    debug=True)

2. When the program runs, it first downloads the test images; then open the printed link http://127.0.0.1:7860/ in a browser and you can start testing.


5.2.2 A concrete example of selecting an image in the browser and running object detection

1. The link opened in the browser looks like this:

  • The Original Image panel is where you provide the test input image; there are three ways to do so:

    • Drop Image Here: drag and drop the input image into the dashed box
    • Click to Upload: click inside the dashed box and manually select the image to upload
    • Pick an example image: two test images are shown in the lower-left Examples area; click one of them to run the test with it
  • The Output Image panel shows the image with the detection results


2. My steps were as follows:

  • 1) Click one of the test images in the lower-left corner (I chose zidane.jpg); the image then appears in the Original Image box

  • 2) Click the Submit button to start detection; the detection result is shown in the Output Image box

  • 3) Click Flag to save the input and output images


3. After I clicked the Flag button, a flagged directory was created:

flagged
├── Original Image
│   └── 0.jpeg   # 1280x720    input image (Original Image)
└── Output Image
    └── 0.png   # 640x360   output result image (Output Image)


6 The yolov5 hub inference API

This section draws on the yolov5 PyTorch Hub documentation linked above.

6.1 Load the yolov5 model with torch.hub.load()

# Model
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)

6.2 Run inference on the model's input data

1. The model accepts many different kinds of input data:

# Test images
imgs = ['https://ultralytics.com/images/zidane.jpg']  # batch of images

# Inference
# You can pass an image URL, an image path, or an already loaded image, e.g. img1 = Image.open('zidane.jpg')  # PIL image

# When a URL is passed, the image is downloaded to the current directory; if it already exists, it will still be downloaded again on the next run
# results = model('https://ultralytics.com/images/zidane.jpg')

# Pass an image path directly
results = model("./zidane.jpg")

# Pass an image loaded with PIL; see: https://github.com/ultralytics/yolov5/issues/36
# img = Image.open("./zidane.jpg")
# results = model(img)

# Pass an image loaded with OpenCV
# img = cv2.imread("./zidane.jpg")[:, :, ::-1]   # OpenCV BGR to RGB
# results = model(img)


# If you have a batch of images, you can pass a list; its items can be any of the types above (paths, URLs, loaded images, ...)
# images = ["./bus.jpg", "./zidane.jpg"]
# results = model(images)

print(results)  # <models.common.Detections object at 0x7f3d6962ce80>

2. More inference parameters can be found in yolov5/models/common.py:

    def forward(self, imgs, size=640, augment=False, profile=False):
        # Inference from various sources. For height=720, width=1280, RGB images example inputs are:
        #   filename:   imgs = 'data/samples/zidane.jpg'
        #   URI:             = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(720,1280,3)
        #   PIL:             = Image.open('image.jpg')  # HWC x(720,1280,3)
        #   numpy:           = np.zeros((720,1280,3))  # HWC
        #   torch:           = torch.zeros(16,3,720,1280)  # BCHW
        #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images

6.3 Full code

__Author__ = "Shliang"
__Email__ = "shliang0603@gmail.com"

import torch
from PIL import Image
import cv2

# Model
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)

# Test images
imgs = ['https://ultralytics.com/images/zidane.jpg']  # batch of images

# Inference
# You can pass an image URL, an image path, or an already loaded image, e.g. img1 = Image.open('zidane.jpg')  # PIL image

# When a URL is passed, the image is downloaded to the current directory; if it already exists, it will still be downloaded again on the next run
# results = model('https://ultralytics.com/images/zidane.jpg')

# Pass an image path directly
results = model("./zidane.jpg")

# Pass an image loaded with PIL; see: https://github.com/ultralytics/yolov5/issues/36
# img = Image.open("./zidane.jpg")
# results = model(img)

# Pass an image loaded with OpenCV
# img = cv2.imread("./zidane.jpg")[:, :, ::-1]   # OpenCV BGR to RGB
# results = model(img)


# If you have a batch of images, you can pass a list; its items can be any of the types above (paths, URLs, loaded images, ...)
# images = ["./bus.jpg", "./zidane.jpg"]
# results = model(images)

print(results)  # <models.common.Detections object at 0x7f3d6962ce80>


# More inference parameters: see yolov5/models/common.py (https://github.com/ultralytics/yolov5/blob/3551b072b3/models/common.py#L182)
# model.conf = 0.25  # confidence threshold (0-1)
# model.iou = 0.45  # NMS IoU threshold (0-1)
# model.classes = None  # (optional list) filter by class, i.e. = [0, 15, 16] for persons, cats and dogs
# You can set a custom inference input size
# results = model(imgs, size=320)  # custom inference size

# Print the detection results
results.print()
# image 1/1: 720x1280 2 persons, 2 ties
# Speed: 14.8ms pre-process, 14.7ms inference, 15.2ms NMS per image at shape (1, 3, 384, 640)

# Save the detection result images
results.save()
# Saved zidane.jpg to runs/hub/exp
# If run again, the results are saved to the runs/hub/exp2 directory

# Show the results
results.show()

# The prediction result for an image is a tensor with columns: xmin, ymin, xmax, ymax, confidence, class
print(results.xyxy[0])  # img1 predictions (tensor)
# tensor([[7.50000e+02, 4.30000e+01, 1.14800e+03, 7.09000e+02, 8.76953e-01, 0.00000e+00],
#         [4.33500e+02, 4.34000e+02, 5.18000e+02, 7.15000e+02, 6.58203e-01, 2.70000e+01],
#         [1.13500e+02, 1.96250e+02, 1.09200e+03, 7.10000e+02, 5.95215e-01, 0.00000e+00],
#         [9.86000e+02, 3.04500e+02, 1.02800e+03, 4.20000e+02, 2.84424e-01, 2.70000e+01]], device='cuda:0')

# Convert the detection results to pandas format, i.e. convert the results.xyxy[0] tensor into the pandas DataFrame shown below
print(results.pandas().xyxy[0])  # img1 predictions (pandas)
#      xmin    ymin    xmax   ymax  confidence  class    name
# 0  749.50   43.50  1148.0  704.5    0.874023      0  person
# 1  433.50  433.50   517.5  714.5    0.687988     27     tie
# 2  114.75  195.75  1095.0  708.0    0.624512      0  person
# 3  986.00  304.00  1028.0  420.0    0.286865     27     tie


# Serialize the results to a JSON string
results.pandas().xyxy[0].to_json(orient="records")  # JSON img1 predictions
print(results.pandas().xyxy[0].to_json(orient="records"))
'''
[{"xmin":750.0,"ymin":43.0,"xmax":1148.0,"ymax":709.0,"confidence":0.876953125,"class":0,"name":"person"},
{"xmin":433.5,"ymin":434.0,"xmax":518.0,"ymax":715.0,"confidence":0.658203125,"class":27,"name":"tie"},
{"xmin":113.5,"ymin":196.25,"xmax":1092.0,"ymax":710.0,"confidence":0.5952148438,"class":0,"name":"person"},
{"xmin":986.0,"ymin":304.5,"xmax":1028.0,"ymax":420.0,"confidence":0.2844238281,"class":27,"name":"tie"}]
'''
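
Continuing from the script above, the pandas DataFrame makes it easy to post-process the detections, e.g. filtering by class name and confidence (a small sketch; the column names come from the output shown above):

# Appended to the end of the script above: filter detections using the pandas results.
df = results.pandas().xyxy[0]
persons = df[(df['name'] == 'person') & (df['confidence'] > 0.5)]
print(persons[['xmin', 'ymin', 'xmax', 'ymax', 'confidence']])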

7 More PyTorch Hub projects

1. PyTorch Hub exists to make research easy to reproduce quickly. More PyTorch Hub projects can be found on the PyTorch Hub search page:

  • PyTorch Hub model search page: pytorch.org/hub/researc…


2. As you can see, PyTorch Hub includes projects in many areas: All, Audio, Generative, NLP, Scriptable, Vision

  • Audio: speech

  • Generative: generative models

  • NLP: natural language processing

  • Scriptable: scriptable (TorchScript-compatible) models

  • Vision: computer vision