欢迎您访问 最编程 本站为您分享编程语言代码,编程技术文章!
您现在的位置是: 首页

[代码阅读] U-Net++ Pytorch

最编程 2024-03-06 17:30:06
...

代码:
https://github.com/4uiiurz1/pytorch-nested-unet
文件包括:utils.py,preprocess_dsb2018.py,dataset.py,train.py, archs.py,losses.py,metrics.py,test.py

  • utils.py

def str2bool: => tru = 1; false = 0
def count_params(): =>计算训练参数量

  • preprocess_dsb2018.py

数据预处理
创建好image,mask所在目录,规范image channel = 3,resize image & mask,将千奇百态的图片按上述处理后再重命名保存。

  • dataset.py

将image,mask [0,255]的范围归一化为[0,1]。
如果进行数据增强,则将部分图片左右翻转或上下翻转,并将(h,w,c) =>(c,h,w)

  • train.py

用训练集训练模型,并在验证集的约束下优化模型并保存。

  • test.py

加载最优模型对测试集预测,保存分割的mask,并用相关指标评估。


preprocess_dsb2018.py

  • glob.glob()
    返回所有匹配的文件路径列表(list);该方法需要一个参数用来指定匹配的路径字符串(字符串可以为绝对路径也可以为相对路径),其返回的文件名只包括当前目录里的文件名,不包括子文件夹里的文件。

  • tqdm
    Tqdm 是 Python 进度条库,可以在 Python 长循环中添加一个进度提示信息用法:tqdm(iterator)

  • skimage.io.imread
    io.imread读出图片格式是uint8(unsigned int);value是numpy array;图像数据是以RGB的格式进行存储的,通道值默认范围0-255。(height,width, channel)
    skimage图片信息
from skimage import io, data
img = data.chelsea()
io.imshow(img)
print(type(img))  #显示类型
print(img.shape)  #显示尺寸
print(img.shape[0])  #图片高度
print(img.shape[1])  #图片宽度
print(img.shape[2])  #图片通道数
print(img.size)   #显示总像素个数
print(img.max())  #最大像素值
print(img.min())  #最小像素值
print(img.mean()) #像素平均值
print(img[0][0])  #图像的像素值

image.shape[0], 图片垂直尺寸
image.shape[1], 图片水平尺寸
image.shape[2], 图片通道数

  • len()
    返回字符串、列表、字典、元组等长度
    if (len(image.shape) == 2):
    如果图片只有height,width,没有channel

  • np.tile()
    函数形式: tile(A,rep)
    功能:重复A的各个维度
    参数类型:
    A: Array类的都可以
    rep:A沿着各个维度重复的次数

  • skimage.io.imsave
    使用io模块的imsave(fname,arr)函数来实现。第一个参数表示保存的路径和名称,第二个参数表示需要保存的数组变量。
    保存图片的同时也起到了转换格式的作用。如果读取时图片格式为jpg图片,保存为png格式,则将图片从jpg图片转换为png图片并保存。

  • os.path.basename()
    返回path最后的文件名。若path以/或\结尾,那么就会返回空值。

path='D:\****'
os.path.basename(path)=****
path='/root/runoob.txt'
os.path.basename(path)=runoob.txt
  • Numpy的布尔索引与花式索引
for mask_path in glob(path+'/masks/*'):
     mask_ = imread(mask_path) > 127
     mask[mask_] = 1
  • imread(mask_path) 是numpy.ndarray,(h,w,c),范围是[0,255],0为黑色,255为白色,
  • mask_ 数组中 > 127(白色)的元素记为ture,否则记为false.
  • mask[mask_] = 1将mask_ 元素为1的的地方赋值为1。
    详见:
    Numpy的布尔索引与花式索引

  • cv2.resize()
    cv2.resize(src,dsize,dst=None,fx=None,fy=None,interpolation=None)
    • scr:原图
    • dsize:输出图像尺寸
    • fx:沿水平轴的比例因子
    • fy:沿垂直轴的比例因子
    • interpolation:插值方法
      只会resize原图像的水平方向尺寸和垂直方向尺寸,不会对channel有影响。
      详见:
      cv2.resize()

train.py

  • .__ dict __
    通俗的理解:每个参数,变量,对象都是以字典的形式存储,每一个key对应一个value
    __ dict __.png

def main()

  • argparse.ArgumentParser()

    • choices - 设置参数值的范围,如果choices中的类型不是字符串,要指定type。#parser.add_argument(“-y”, choices=[‘a’, ‘b’, ‘d’])
    • metavar - 参数的名字,在显示 帮助信息时才用到. # parser.add_argument(“-o”, metavar=”OOOOOO”)
      更多见--python3中argparse模块详解
  • vars()
    返回对象object的属性和属性值的字典对象。

  • getattr()
    getattr():从名字上看获取属性值.

class Person():
    age = 14
Tom = Person()
print(getattr(Tom,'age'))

此时的结果为14,
若,该属性不存在

getattr(Tom,'name')
AttributeError: 'Person' object has no attribute 'name'
  • train_test_split()
    • 所在包:sklearn.model_selection
    • 功能:划分数据的训练集与测试集
    • 参数解读:train_test_split (*arrays,test_size, train_size, rondom_state=None, shuffle=True, stratify=None)
    • arrays:特征数据和标签数据(array,list,dataframe等类型),要求所有数据长度相同。
    • test_size / train_size: 测试集/训练集的大小,若输入小数表示比例,若输入整数表示数据个数。
    • rondom_state:随机种子(一个整数),其实就是一个划分标记,对于同一个数据集,如果rondom_state相同,则划分结果也相同。
    • shuffle:是否打乱数据的顺序,再划分,默认True。
    • stratify:none或者array/series类型的数据,表示按这列进行分层采样。
xtrain,xtest,ytrain,ytest=train_test_split(data,label,test_size=0.2,stratify=data['a'],random_state=1)
  • pytorch固定部分参数进行网络训练
if args.optimizer == 'Adam':
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
elif args.optimizer == 'SGD':
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr,
                momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)

详见:
pytorch固定部分参数进行网络训练
pytorch固定部分参数进行网络训练

class AverageMeter

计算并存储平均值和当前值

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def train

  • .cuda()
input = input.cuda()

x.cuda()操作将 x 变为Tensor类型


cuda = >Tensor
  • python 默认参数
def enroll(name, gender, age=6, city='Beijing'):
    print('name:', name)
    print('gender:', gender)
    print('age:', age)
    print('city:', city)

只有与默认参数不符的学生才需要提供额外的信息:
enroll('Bob', 'M', 7)
enroll('Adam', 'M', city='Tianjin')

  • losses = AverageMeter()
    实例化AverageMeter(),调用init初始化函数

  • losses.update(loss.item(), input.size(0))
    input.size()覆盖默认参数n=1
    input经过.cuda后为Tensor类型,input [bacth_size, c, w, h]


dataset.py

from skimage.io import imread

image = imread(img_path)
mask = imread(mask_path)

image = image.astype('float32') / 255
mask = mask.astype('float32') / 255

if self.aug:
    if random.uniform(0, 1) > 0:
        image = image[:, ::-1, :].copy()
        mask = mask[:, ::-1].copy()
    if random.uniform(0, 1) > 0.5:
        image = image[::-1, :, :].copy()
        mask = mask[::-1, :].copy()

image = image.transpose((2, 0, 1))
mask = mask[:,:,np.newaxis]
mask = mask.transpose((2, 0, 1))
  • image = imread(img_path)

    • imread的读取的image结果为numpy.ndarray, type为('uint8'),image里都是整数uint8,范围[0-255]
    • 假设image为(300, 400, 3) (h,w,c),channel顺序为RGB
      array([
      [ [143, 198, 201 (dim=3)],[143, 198, 201],... (w=200)],
      [ [143, 198, 201],[143, 198, 201],... ], ...(h=100) ], dtype=uint8)
  • image = image.astype('float32') / 255
    将image的type改为float32,并把数据范围缩小到[0,1]

  • ** image = image[:, ::-1, :].copy()**
    image = image[:, ::-1, :] 表示将图像向右翻转180°
    image = image[::-1,: , :]表示将图像向下翻转180°

  • image = image.transpose((2, 0, 1))
    image (h,w,c) transpose为(c,w,h)


utils.py

def str2bool(v):
    if v.lower() in ['true', 1]:
        return True
    elif v.lower() in ['false', 0]:
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
  • count_params()
    计算模型训练参数量

  • pytorch 获取模型参数量

# Find total parameters and trainable parameters
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')

total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')

arch.py

有Unet和Unet++


losses.py

一些计算losses的函数


metrics.py

评估函数


参考链接:
image.shap--yournevermore
skimage.io.imread与cv2.imread的区别
python 处理图像的常见操作-- 平移缩放裁剪
skimage图像处理
skimage.io.imread API
numpy模块的tile()方法简单说明
Python os.path() 模块
os.path.basename()作用
python之getattr()函数
train_test_split数据集分割
OpenCV、Skimage、PIL图像处理的细节差异
Numpy的布尔索引与花式索引
image = image[:, ::-1, :]的含义是什么
python 矩阵转置transpose
pytorch 获取模型参数量