创建 pytorch 分类检测分割数据集和数据加载器
最编程
2024-04-04 16:45:52
...
1.前言
在PyTorch中,Dataset
和DataLoader
是两个重要的工具,用于构建输入数据的管道。
(1)Dataset
是一个抽象类,表示数据集,需要实现__len__
和__getitem__
方法。
(2)DataLoader
是一个可迭代的数据加载器,它封装了数据集的加载、批处理、打乱和并行加载等功能。
2.分类任务创建Dataset
和DataLoader
(1)对于分类任务,Dataset
需要返回图像和对应的标签
from torch.utils.data import Dataset
from PIL import Image
import os
import torch
class ClassificationDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.transform = transform
self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]
self.labels = [...] # 这里应该是与图像对应的标签列表
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
image = Image.open(img_path).convert('RGB')
label = self.labels[idx]
if self.transform:
image = self.transform(image)
return image, label
(2)DataLoader
加载数据
from torch.utils.data import DataLoader
transform = ... # 这里定义你的数据预处理流程
dataset = ClassificationDataset(root_dir='path_to_your_data', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
3.检测任务创建Dataset
和DataLoader
(1)Dataset
需要返回图像和对应的边界框信息
class DetectionDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.transform = transform
self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]
self.annotations = [...] # 这里应该是与图像对应的边界框信息列表
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
image = Image.open(img_path).convert('RGB')
boxes = self.annotations[idx] # 这些是边界框信息
if self.transform:
image, boxes = self.transform(image, boxes)
return image, boxes
(2)DataLoader
加载数据
dataloader = DataLoader(DetectionDataset(root_dir='path_to_your_data', transform=transform), batch_size=2, shuffle=True)
4.分割任务创建Dataset
和DataLoader
(1)Dataset
需要返回图像和对应的分割掩码
class SegmentationDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.transform = transform
self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]
self.masks = [...] # 这里应该是与图像对应的分割掩码列表
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
mask_path = self.masks[idx]
image = Image.open(img_path).convert('RGB')
mask = Image.open(mask_path).convert('L') # 假设掩码是灰度图
if self.transform:
image, mask = self.transform(image, mask)
return image, mask
(2)DataLoader
加载数据
dataloader = DataLoader(SegmentationDataset(root_dir='path_to_your_data', transform=transform), batch_size=4, shuffle=True)
在PyTorch的
Dataset
和DataLoader
框架中,idx
(或称为索引)是通过迭代DataLoader
时自动生成的。当你创建一个DataLoader
实例,并在训练循环中迭代它时,DataLoader
会内部调用Dataset
的__getitem__
方法,并自动为你提供索引idx
。
上一篇: mac/win 使用 pyinstaller 对应用程序/exe 文件进行打包,执行存活脚本,双击运行 -???? 坑爹记录
下一篇: [STM32 嵌入式系统设计与开发] - 16InputCapture(输入捕捉应用程序) - II.任务执行
推荐阅读
-
[姿势估计] 实践记录:使用 Dlib 和 mediapipe 进行人脸姿势估计 - 本文重点介绍方法 2):方法 1:基于深度学习的方法:。 基于深度学习的方法:基于深度学习的方法利用深度学习模型,如卷积神经网络(CNN)或递归神经网络(RNN),直接从人脸图像中学习姿势估计。这些方法能够学习更复杂的特征表征,并在大规模数据集上取得优异的性能。方法二:基于二维校准信息估计三维姿态信息(计算机视觉 PnP 问题)。 特征点定位:人脸姿态估计的第一步是通过特征点定位来检测和定位人脸的关键点,如眼睛、鼻子和嘴巴。这些关键点提供了人脸的局部结构信息,可用于后续的姿势估计。 旋转表示:常见的旋转表示方法包括欧拉角和旋转矩阵。欧拉角通过三个旋转角度(通常是俯仰、偏航和滚动)描述头部的旋转姿态。旋转矩阵是一个 3x3 矩阵,表示头部从一个坐标系到另一个坐标系的变换。 三维模型重建:根据特征点的定位结果,三维人脸模型可用于姿势估计。通过将人脸的二维图像映射到三维模型上,可以估算出人脸的旋转和平移信息。这就需要建立人脸的三维模型,然后通过优化方法将模型与特征点对齐,从而获得姿势估计结果。 特征点定位 特征点定位是用于检测人脸关键部位的五官基础部分,还有其他更多的特征点表示方法,大家可以参考我上一篇文章中介绍的特征点检测方案实践:人脸校正二次定位操作来解决人脸校正的问题,客户在检测关键点的代码上略有修改,坐标转换部分客户见上图 def get_face_info(image). img_copy = image.copy image.flags.writeable = False image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) results = face_detection.process(image) # 在图像上绘制人脸检测注释。 image.flags.writeable = True image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) box_info, facial = None, None if results.detections: for detection in results. for detection in results.detections: mp_drawing.Drawing.detection = 无 mp_drawing.draw_detection(image, detection) 面部 = detection.location_data.relative_keypoints 返回面部 在上述代码中,返回的数据是五官(6 个关键点的坐标),这是用 mediapipe 库实现的,下面我们可以尝试用另一个库:dlib 来实现。 使用 dlib 使用 Dlib 库在 Python 中实现人脸关键点检测的步骤如下: 确保已安装 Dlib 库,可使用以下命令: pip install dlib 导入必要的库: 加载 Dlib 的人脸检测器和关键点检测器模型: 读取图像并将其灰度化: 使用人脸检测器检测图像中的人脸: 对检测到的人脸进行遍历,并使用关键点检测器检测人脸关键点: 显示绘制了关键点的图像: 以下代码将参数 landmarks_part 添加到要返回的关键点坐标中。
-
创建 pytorch 分类检测分割数据集和数据加载器
-
基于机器学习的网络入侵检测与特征选择和随机森林分类器性能评估(NSL-KDD 数据集)--代码实现