欢迎您访问 最编程 本站为您分享编程语言代码,编程技术文章!
您现在的位置是: 首页

在一篇文章中掌握智能抠像深度图像匹配(pytorch 实现

最编程 2024-06-20 15:07:04
...

目录

一. 抠图概述

1. 抠图

2. 语义分割和抠图的协同处理

二.  MNet算法和Pytorch实现

1. MNet算法原理

2. 环境配置

3.  数据集

4.  加载数据集

5.  模型构建

6.  训练

三. 总结

参考文献


 

一. 抠图概述

语义分割模型能够从复杂的背景中准确的分割出前景边界,但是在提取边界细节信息时效果并不是很理想。本文将详细阐述如何利用深度学习来构造一个精抠模型,使得最终的图像边界更加自然。本文实现时主要参考阿里巴巴推出的SHM算法中的MNet模块。

1. 抠图

语义分割是端到端的对像素按照语义进行多分类,而抠图则更进一步,抠图将图片分成前景和背景两部分然后提取前景部分。在抠图中会有一个alpha通道的概念,可以理解为透明度,大部分抠图模型将抠图建模为下面的一个公式:

                                                                                       I= \alpha F+(1-\alpha)B

其中即为 I 为观测图像,F为前景,B为背景。\alpha即为alpha,透明度的意思。要理解抠图,必须先理解上述这个看似简单的公式,按照公式理解:观测图像等价于前景和背景的线性混合,各部分的混合度由\alpha控制。如果读者有PhotoShop的操作经验,那么完全可以将FB理解为PhotoShop中的图层,I 是两个图层的叠加。抠图的最终目标其实是为了得到\alpha,而这个\alpha并不是一个简单的0、1二值分类或者是简单的多分类,在实际操作中\alpha也看作一个通道(alpha通道),它的取值范围和RGB颜色空间一样是[0,255]的整数。从这个角度来理解,如果语义分割任务中只定义两个语义(前景和背景),那么语义分割的每个像素它的预测范围是0或1,即二值分类。而抠图的话,每个像素的预测范围是0到255的整数,其精度要求更高。因此,抠图可以看作是高版本的语义分割任务,其难度更大。

下面举一个例子来观察语义分割和抠图的不同:

                                                                                          图1.1 原始示例图

假设在暗室摄影棚中拍摄了一张器皿图,如上图所示。很明显,这个器皿具有明显的透明属性。由于背景几乎纯黑,我们可以采用语义分割技术直接将器皿分割出来,如下图所示:
 

                                                                                图1.2 语义分割掩码(mask)图

很明显,语义分割将整个器皿前景按照非0即255的方式(二值方式)分割了出来,分割出来的前景如下:

                                                                                      图1.3 语义分割前景图

由于在边界处直接硬分割,因此,器皿内部透明部分带着原始的灰色背景。此时,如果我们继续将该图与其它图像进行合成,那么会合成出下面的图像:

                                                                                               图1.4 语义分割合成图

可以看到,带着“灰色底纹”的器皿合成到新图上后具有明显的缺陷,主观感受较差。

下面再看一下抠图的效果,抠图后我们会得到一张代表图像前景透明度的alpha图,如下所示:

                                                                            图1.5 抠图得到的alpha通道图

可以看到,相比图1.4得到的二值掩码图,利用抠图技术得到的alpha通道图其取值范围更大。根据抠图合成公式,与新图进行合成可以得到下面的图像:

                                                                                             图1.6 抠图合成示例
此时新的合成图像其透明区域自带了一部分原始图像背景,其合成效果更加自然。以上就是语义分割和抠图的不同之处。

对于抠图任务来说,我们最需要的就是通过算法来计算其alpha通道图。

2. 语义分割和抠图的协同处理

前面的内容简单总结一下,语义分割采用的是在前景图像边界处硬分割的方式(0或255),抠图采用的是在前景图像边界处软分割的方式(0到255)。那么在实际的处理过程中语义分割和抠图是完全没有联系的吗?其实不是。在解释这个问题前需要先理解一个抠图领域经常使用的概念:Trimap。

由于抠图难度非常大,所以很多研究学者通过某些方式来获取一些先验值从而能够方便的获取alpha通道图。Trimap就是一种先验图的表示方式。在Trimap图中,可以完全肯定是前景区域的那么用纯255的白色标注,可以完全肯定是背景区域的那么用值为0的黑色标注。对于边界不确定部分,用值为128的灰色标注。例如图1.1,它的一种Trimap图如下所示:

                                                                                       图1.7 Trimap图示例

也就是说,Trimap图的作用是告诉后面的抠图算法,已经预先知道了哪些地方是前景,哪些地方是背景,而且这些区域基本上“板上钉钉”了,基本不需要再靠抠图算法去辨识了。抠图算法需要做的就是对那些不确定区域(灰色区域)进行精细的抠图即可。在这种情况下,抠图算法可以根据提前知道的前景和背景区域去探知区分前景和背景的特征,然后通过这种先验知识,再对灰色区域进行精细抠图。

Trimap的存在是为了简化整体的抠图难度,这也是目前大部分抠图算法采用的策略,例如典型的Closed-form抠图和KNN抠图算法,它们都需要提前提供一张用于辅助的Trimap图用来完成整个的抠图任务。也就是说抠图一般分为两个步骤:粗分割和细分割。粗分割是为了得到Trimap图,细分割是为了进一步细化抠图结果。

到这里再重新审视下Trimap,可以发现,如果将前景看作一种语义,背景看作一种语义,不确定区域(灰度区域)看作一种语义,那么Trimap每个像素只属于三种语义信息,这跟语义分割的内容完全一致,因此,完全可以采用语义分割的方法来得到Trimap图,然后再进行精细化抠图。

                                                                                                图1.9 完整的抠图流程

对于一般应用场景来说,使用语义分割就可以满足精度要求了,但是对于一些图像编辑、摄影剪辑、后期合成等任务,那么就需要高精度的抠图算法来实现了。

本文拟解决肖像照抠图合成任务,主要用来复现论文Semantic Human Matting(SHM)中的MNet模块,该算法由阿里巴巴团队提出,是一种基于深度学习的人像抠图技术。SHM是第一种全自动抠图算法,可以学习将语义信息和高质量细节与深层网络联合起来,实现端到端训练。整个算法不需要用户提供任何交互式输入,可以直接对原图进行人像高精度(发丝级别)提取。SHM核心算法采用了两个模型:TNet和MNet。其中,TNet就是一个PSPNet语义分割模型用来生成人像Trimap,而MNet则借鉴了Deep Image Matting (DIM)的思路采用了编解码结构网络进行精细抠图。该论文详细讲解了整个训练的实现细节,整体来说,SHM算法更加的工程化,综合性更强。因此,如果能够完成SHM的论文复现,那么相信读者在语义分割和抠图领域能够有更加深入的认识。

下面将以MNet模型为主线详细讲解实际的抠图处理流程。

二.  MNet算法和Pytorch实现

1. MNet算法原理

MNet算法主要参考DIM,其实现方式与DIM基本相同,都采用编解码结构进行精细抠图。不同之处在于DIM算法的输入是原图+Trimap图,而MNet算法的输入是原图+前景掩码+背景掩码+不确定区域掩码,本质上是一样的,因为Trimap就是由前景掩码、背景掩码、不确定区域掩码结合而成。另一个不同之处在于DIM算法在编解码结构后面还增加了几个卷积层用于调整,而MNet算法则摒弃了最后几个卷积层。

MNet模型结构如下图所示:

上图中TNet网络就是一个语义分割网络,用于进行图像的Trimap生成,这里生成的是背景、前景、不确定区域的掩码。可以看到,MNet将背景、前景、不确定区域三部分(3通道数据)和原图的R通道、G通道、B通道共组成6通道数据作为输入,然后通过一个编解码结构的网络来获取精细的alpha。

从实现效果上来总结一下,TNet网络得到的结果是个精准的粗预判结果,“准而粗”。MNet得到的结果是个精细的结果,但是不准,即“不准但细”。很自然的,将两者进行结合就可以得到一个“准而细”的抠图结果。整个网络设计最精彩的部分就是将两个模型连接在一起,可以形成一个端到端的模型,TNet和MNet网络可以在整体的一个训练过程中优势互补,真正的实现1+1>2。当然在实际实现的过程中,为了避免训练不收敛,可以分别对TNet和MNet模型进行预训练,然后再进行端到端的整体训练。

最后再详解一下编解码模型的概念。

分割任务中的编码器encode与解码器decode就像是玩“你来比划我来猜”的双方:比划的人想把看到的东西用一种方式描述出来,猜的人根据比划的人提供的信息猜出答案。其中,“比划的人”叫做编码器,“猜的人”就是解码器。具体来说,编码器的任务是在给定输入图像后,通过神经网络学习得到输入图像的特征图;而解码器则在编码器提供特征图后,逐步实现每个像素的类别标注,也就是分割。通常,分割任务中的编码器结构比较类似,大多来源于用于分类任务的网络结构,比如VGG。这样做有一个好处,就是可以借用在大数据库下训练得到的分类网络的权重参数,通过迁移学习实现更好的效果。因此,解码器的不同在很大程度上决定了一个基于编解码结构的分割网络的效果。

在图像分割领域,编解码网络结构的典型代表就是UNet网络,相比其它语义分割模型,使用UNet往往可以得到更加清晰的分割边界。如果进一步加强网络的编解码性能,那么自然的,可以精确提取出发丝级别的alpha抠图通道特征。值得一说的是MNet的前身即DIM算法在 alphamatting.com抠图挑战赛中获得第一名,性能超过了一众传统的KNN等抠图算法,这足以证明编解码网络模型的威力。

2. 环境配置

本文采用Pytorch进行算法建模,Pytorch版本为1.4,cuda版本为10.1,Python版本为3.6.1。详细的环境安装教程请参考另一篇博客:https://blog.****.net/qianbin3200896/article/details/104244538。本文使用Windows10操作系统,4块NVIDIA TITAN RTX显卡进行运算。另外,为了实施观看训练的中间结果,本文使用TenosorboardX这个查看工具。相关介绍和使用请参考博客https://blog.****.net/qianbin3200896/article/details/104181552

3.  数据集

相比语义分割数据集来说,用于抠图的训练数据往往难以获得,因为其抠图精度非常高,需要手工PS逐张精细的操作和挑选。既然如此,那么怎么获得数量众多的高精度抠图数据集用于模型训练呢?DIM算法的作者提出了一种非常实用的方法:人工合成。首先手工标注少量部分样本,得到每个样本的alpha通道图,然后根据每个样本的alpha通道图将前景抠出然后再和其它背景图像合成,由于每张前景图可以和多张背景图合成(DIM算法中每张前景与100张背景合成)从而可以大幅扩充数据集,并且每张训练图像都有现成的高精度alpha通道图。

这里会存在一个问题,这种人工合成的非自然产生的图像是否能够真实的提高模型的抠图精度呢?DIM算法已经为我们证明了这一点。从模型本身角度来分析,MNet算法更多的是将注意力关注于图像的局部,而不是图像的全局语义部分,因此,尽管新合成的图像很多在语义上存在明显的不合理(例如人不大可能出现在水面上等),但是,在局部细节上却可以教会模型如何精细的抠出局部细节。

本文使用DIM提供的数据集来预训练MNet网络。该数据集包含431张前景和对应的alpha图,每张alpha图均为人工PS精细抠图的结果,精度非常高。部分样例如下图所示:

  

  

可以看到DIM数据集的标注精度非常高,不管是发丝还是透明物体均具有非常高的抠图精度。

在实现时首先将DIM数据集在COCO2014上进行合成,每张图像与coco的100张背景图像进行合成,从而生成  431x 100=43100张训练图像,部分合成样例如下图所示:

   

在生成Trimap时可以在alpha图上采用膨胀腐蚀操作获得,这里参照SHM的实现策略,为了加强MNet模型的鲁棒性,采用随机核进行膨胀腐蚀操作,这样就会生成不同宽度未知区域的Trimap图。

DIM数据集需要自行向论文读者发邮件获取,受到协议保护的要求本文仅提供代码。

完整的数据集处理代码如下:

def rand_trimap(mask, smooth=False):
    """
    随机生成trimap
    输入mask:单通道掩码图
    """
    h, w = mask.shape
    #scale_up, scale_down = 0.022, 0.006  
    scale_up, scale_down = 0.03, 0.006 
    dmin = 0        
    emax = 255 - dmin   

    if smooth:
        # 用于腐蚀和膨胀结果的阈值处理
        scale_up, scale_down = 0.02, 0.006
        dmin = 5
        emax = 255 - dmin

        # 高斯模糊平滑
        if h < 1000:
            gau_ker = round(h*0.01)  
            gau_ker = gau_ker if gau_ker % 2 ==1 else gau_ker-1 # 确保核大小是奇数
            if h<500:
                gau_ker = max(3, gau_ker)
            mask = cv2.GaussianBlur(mask, (gau_ker, gau_ker), 0)

    kernel_size_high = max(10, round((h + w) / 2 * scale_up))
    kernel_size_low  = max(1, round((h + w) /2 * scale_down))
    erode_kernel_size  = np.random.randint(kernel_size_low, kernel_size_high)
    dilate_kernel_size = np.random.randint(kernel_size_low, kernel_size_high)

    erode_kernel  = cv2.getStructuringElement(cv2.MORPH_RECT, (erode_kernel_size, erode_kernel_size))
    dilate_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (dilate_kernel_size, dilate_kernel_size))
    eroded_alpha = cv2.erode(mask, erode_kernel)
    dilated_alpha = cv2.dilate(mask, dilate_kernel)

    dilated_alpha = np.where(dilated_alpha > dmin, 255, 0)
    eroded_alpha = np.where(eroded_alpha < emax, 0, 255)

    res = dilated_alpha.copy()
    res[((dilated_alpha == 255) & (eroded_alpha == 0))] = 128

    return res


def genDIM():
    """
    生成标准化的DIM数据集,同时生成JSON文件列表
    """
    # 设置拷贝路径
    fg_folder='E:/deeplearn/DIM/fg' 
    alpha_folder='E:/deeplearn/DIM/alpha'
    bg_folder='E:/deeplearn/SRGAN/data/COCO2014/train2014'
    des_img_folder='./data/DIM/img' 
    des_alpha_folder='./data/DIM/alpha' 
    des_trimap_folder='./data/DIM/trimap' 

    # 检索文件
    fglist = getFileList(fg_folder, [], 'jpg')
    alphalist = getFileList(alpha_folder, [], 'jpg')
    bglist = getFileList(bg_folder, [], 'jpg')

    print('检索到 '+str(len(fglist))+' 个前景图像')
    print('检索到 '+str(len(alphalist))+ ' 个alpha通道图')
    print('检索到 '+str(len(bglist))+ ' 个背景图')

    num_bgs = 100

    # 逐张检查
    index=0
    save_img_list=list()
    save_alpha_list=list()
    save_trimap_list=list()

    bg_iter = iter(bglist)
    for imgpath in fglist:
        # 读取
        img = cv2.imread(imgpath)
        (path, filename) = os.path.split(imgpath)
        alphapath = os.path.join(alpha_folder,filename)
        alpha = cv2.imread(alphapath)
    
        bbox = img.shape
        w = bbox[1]
        h = bbox[0]
        
        bcount = 0 
        for i in range(num_bgs):
            bgpath = next(bg_iter)        
            bg = cv2.imread(bgpath)
            bg_bbox = bg.shape
            bw = bg_bbox[1]
            bh = bg_bbox[0]
            wratio = w / bw
            hratio = h / bh
            ratio = wratio if wratio > hratio else hratio     
            if ratio > 1:        
                dim = (math.ceil(bw*ratio), math.ceil(bh*ratio))
                bg = cv2.resize(bg, dim, interpolation = cv2.INTER_CUBIC)
            
            # 合成
            out = composite(img, bg, alpha, w, h) 
            cv2.imwrite(des_img_folder+('/%d.png' % (index)),out)  

            # 生成alpha
            alpha1 = cv2.cvtColor(alpha, cv2.COLOR_RGB2GRAY)
            cv2.imwrite(des_alpha_folder+('/%d.png' % (index)),alpha1)

            # 生成随机Trimap
            trimap = rand_trimap(alpha1)
            cv2.imwrite(des_trimap_folder+('/%d.png' % (index)),trimap)

            # 记录
            save_img_list.append(des_img_folder+('/%d.png' % (index)))
            save_alpha_list.append(des_alpha_folder+('/%d.png' % (index)))
            save_trimap_list.append(des_trimap_folder+('/%d.png' % (index)))

            index += 1          
            print('当前写入第 %d 张图片' % (index))
            bcount += 1
            if bcount==num_bgs:
                break

    # 写入json文件
    with open('./data/dim_img.json', 'w') as jsonfile1:
        json.dump(save_img_list, jsonfile1)

    with open('./data/dim_alpha.json', 'w') as jsonfile2:
        json.dump(save_alpha_list, jsonfile2)

    with open('./data/dim_trimap.json', 'w') as jsonfile3:
        json.dump(save_trimap_list, jsonfile3)

    print('共写入 %d 张图片' % (index))    

最终生成3个文件夹数据,分别为img、alpha和trimap文件夹,用于存放合成图、alpha通道图和trimap图。

4.  加载数据集

通过自定义的HumanDataset进行数据加载,一次加载3张,分别为合成图、alpha通道图和trimap图:

class HumanDataset(Dataset):
    """
    人像数据集
    """
    def __init__(self, dataname, transforms=None):

        items = []
        img_path = './data/'+ dataname + '_img.json'
        trimap_path = './data/'+ dataname + '_trimap.json'
        alpha_path = './data/'+ dataname + '_alpha.json'

        with open(img_path, 'r') as j:
            imglist = json.load(j)
        with open(trimap_path, 'r') as j:
            trimaplist = json.load(j)
        with open(alpha_path, 'r') as j:
            alphalist = json.load(j)

        for i in range(len(imglist)):
            items.append((imglist[i], trimaplist[i], alphalist[i]))

        self.items = items
        self.transforms = transforms

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        image_name, trimap_name, alpha_name = self.items[index]
        image = cv2.imread(image_name, cv2.IMREAD_COLOR)
        trimap = cv2.imread(trimap_name, cv2.IMREAD_GRAYSCALE)
        alpha = cv2.imread(alpha_name, cv2.IMREAD_GRAYSCALE)

        if self.transforms is not None:
            for transform in self.transforms:
                image, trimap, alpha = transform(image, trimap, alpha)

        return image, trimap, alpha

 

5.  模型构建

模型结构采用编解码形式构建,完整代码如下:

class MNet(nn.Module):
    """
    人像精细抠图模型
    """
    def __init__(self):
        super(MNet, self).__init__()
        # 编码
        # stage-1
        self.conv_1_1 = nn.Sequential(nn.Conv2d(6, 64, 3, 1, 1, bias=True), nn.BatchNorm2d(64), nn.ReLU())
        self.conv_1_2 = nn.Sequential(nn.Conv2d(64, 64, 3, 1, 1, bias=True), nn.BatchNorm2d(64), nn.ReLU())
        self.max_pooling_1 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

        # stage-2
        self.conv_2_1 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1, bias=True), nn.BatchNorm2d(128), nn.ReLU())
        self.conv_2_2 = nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1, bias=True), nn.BatchNorm2d(128), nn.ReLU())
        self.max_pooling_2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

        # stage-3
        self.conv_3_1 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1, bias=True), nn.BatchNorm2d(256), nn.ReLU())
        self.conv_3_2 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1, bias=True), nn.BatchNorm2d(256), nn.ReLU())
        self.conv_3_3 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1, bias=True), nn.BatchNorm2d(256), nn.ReLU())
        self.max_pooling_3 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

        # stage-4
        self.conv_4_1 = nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1, bias=True), nn.BatchNorm2d(512), nn.ReLU())
        self.conv_4_2 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1, bias=True), nn.BatchNorm2d(512), nn.ReLU())
        self.conv_4_3 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1, bias=True), nn.BatchNorm2d(512), nn.ReLU())
        self.max_pooling_4 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

        # stage-5
        self.conv_5_1 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1, bias=True), nn.BatchNorm2d(512), nn.ReLU())
        self.conv_5_2 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1, bias=True), nn.BatchNorm2d(512), nn.ReLU())
        self.conv_5_3 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1, bias=True), nn.BatchNorm2d(512), nn.ReLU())

        # 解码
        # stage-5
        self.deconv_5 = nn.Sequential(nn.Conv2d(512, 512, 5, 1, 2, bias=True), nn.BatchNorm2d(512), nn.ReLU())

        # stage-4
        self.up_pool_4 = nn.MaxUnpool2d(2, stride=2)
        self.deconv_4 = nn.Sequential(nn.Conv2d(512, 256, 5, 1, 2, bias=True), nn.BatchNorm2d(256), nn.ReLU())

        # stage-3
        self.up_pool_3 = nn.MaxUnpool2d(2, stride=2)
        self.deconv_3 = nn.Sequential(nn.Conv2d(256, 128, 5, 1, 2, bias=True), nn.BatchNorm2d(128), nn.ReLU())

        # stage-2
        self.up_pool_2 = nn.MaxUnpool2d(2, stride=2)
        self.deconv_2 = nn.Sequential(nn.Conv2d(128, 64, 5, 1, 2, bias=True), nn.BatchNorm2d(64), nn.ReLU())

        # stage-1
        self.up_pool_1 = nn.MaxUnpool2d(2, stride=2)
        self.deconv_1 = nn.Sequential(nn.Conv2d(64, 64, 5, 1, 2, bias=True), nn.BatchNorm2d(64), nn.ReLU())

        # stage-0
        self.conv_0 = nn.Conv2d(64, 1, 5, 1, 2, bias=True)

    def forward(self, input):
        # encoder
        x11 = self.conv_1_1(input)
        x12 = self.conv_1_2(x11)
        x1p, id1 = self.max_pooling_1(x12)

        x21 = self.conv_2_1(x1p)
        x22 = self.conv_2_2(x21)
        x2p, id2 = self.max_pooling_2(x22)

        x31 = self.conv_3_1(x2p)
        x32 = self.conv_3_2(x31)
        x33 = self.conv_3_3(x32)
        x3p, id3 = self.max_pooling_3(x33)

        x41 = self.conv_4_1(x3p)
        x42 = self.conv_4_2(x41)
        x43 = self.conv_4_3(x42)
        x4p, id4 = self.max_pooling_4(x43)

        x51 = self.conv_5_1(x4p)
        x52 = self.conv_5_2(x51)
        x53 = self.conv_5_3(x52)

        # decoder
        x5d = self.deconv_5(x53)

        x4u = self.up_pool_4(x5d, id4)
        x4d = self.deconv_4(x4u)

        x3u = self.up_pool_3(x4d, id3)
        x3d = self.deconv_3(x3u)

        x2u = self.up_pool_2(x3d, id2)
        x2d = self.deconv_2(x2u)

        x1u = self.up_pool_1(x2d, id1)
        x1d = self.deconv_1(x1u)

        raw_alpha = self.conv_0(x1d)
        return raw_alpha

 

6.  训练

整个训练使用4个NVIDIA TITAN RTX进行训练,训练耗时24小时左右,训练代码如下:

import torch.backends.cudnn as cudnn
import torch
from torch import nn
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from models import MNet
from datasets import HumanDataset,RandomPatch,TrimapToCategorical,Normalize,NumpyToTensor
from utils import *
from loss import PredictionL1Loss


# 数据集参数
data_folder = './data/'   # 数据存放路径
dataname = 'dim'      # 数据集名称

# 学习参数
checkpoint = None     # 预训练模型路径,如果不存在则为None
batch_size = 20        # 批大小
start_epoch = 1       # 轮数起始位置
epochs = 100           # 迭代轮数
workers = 4           # 工作线程数
lr = 0.0001           # 学习率             
weight_decay = 0.0005 # 权重延迟

# 设备参数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ngpu = 2           # 用来运行的gpu数量

cudnn.benchmark = True # 对卷积进行加速

writer = SummaryWriter() # 实时监控     使用命令 tensorboard --logdir runs  进行查看

def main():
    """
    训练.
    """
    global checkpoint,start_epoch,writer

    # 初始化
    model = MNet()
    # 初始化优化器
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                    lr=lr, betas=(0.9, 0.999),
                                    weight_decay=weight_decay)

    # 迁移至默认设备进行训练
    model = model.to(device)
    criterion = PredictionL1Loss()
    criterion.to(device)

    # 加载预训练模型
    if checkpoint is not None:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    
    if torch.cuda.is_available() and ngpu > 1:
        model = nn.DataParallel(model, device_ids=list(range(ngpu)))

    # 定制化的dataloaders
    transforms = [
                RandomPatch(320),
                Normalize(),
                TrimapToCategorical(),
                NumpyToTensor()
            ]
    train_dataset = HumanDataset(dataname,transforms)
    train_loader = torch.utils.data.DataLoader(train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True) 

    # 开始逐轮训练
    preloss = 10000000
    for epoch in range(start_epoch, epochs+1):
        # if epoch == 30:  # 适当降低学习率
        #     adjust_learning_rate(optimizer, 0.1)

        model.train()  # 训练模式:允许使用批样本归一化

        loss_epoch = AverageMeter()  # 统计损失函数

        n_iter = len(train_loader)

        # 按批处理
        for i, (imgs, trimaps_gt, alphas_gt) in enumerate(train_loader):

            # 数据移至默认设备进行训练
            imgs = imgs.to(device)  
            alphas_gt = alphas_gt.to(device) 
            trimaps_gt = trimaps_gt.float().to(device) 
            input = torch.cat([imgs, trimaps_gt], dim=1)  

            # 前向传播
            alphas_pre = model(input)
            
            # 计算损失
            loss, loss_alpha, loss_comps = criterion(imgs, alphas_pre, alphas_gt)

            # 后向传播
            optimizer.zero_grad()
            loss.backward()

            # 更新模型
            optimizer.step()

            # 记录损失值
            loss_epoch.update(loss.item(), imgs.size(0))

            # 监控图像变化
            if i == n_iter-2:
                alphas_pre_temp = alphas_pre[:4,:3:,:,:]                
                writer.add_image('MNet/epoch_'+str(epoch)+'_1', make_grid(imgs[:4,:,:,:].cpu(), nrow=4,normalize=True),epoch)
                writer.add_image('MNet/epoch_'+str(epoch)+'_2', make_grid(alphas_pre_temp.cpu(), nrow=4, normalize=True),epoch)
                writer.add_image('MNet/epoch_'+str(epoch)+'_3', make_grid(alphas_gt[:4,:,:,:].float().cpu(), nrow=4, normalize=True),epoch)

            # 打印结果
            print("第 "+str(i)+ " 个batch训练结束")
 
        # 手动释放内存              
        del imgs,trimaps_gt, alphas_gt, alphas_pre_temp

        # 监控损失值变化
        writer.add_scalar('PreTrainMNet/Loss', loss_epoch.val, epoch)    

        # 保存预训练模型
        if loss_epoch.val < preloss:
            preloss = loss_epoch.val
            torch.save({
                'epoch': epoch,
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict()
            }, 'results/checkpoint_mnet.pth')
    
    # 训练结束关闭监控
    writer.close()


if __name__ == '__main__':
    main()

训练曲线如下图所示:

训练部分结果如下所示:

上图中中间一行为预测结果,最后一行为真值(GT),可以看到,整体的训练效果还是较好的,人像发丝部分也已经达到较好的处理效果,另外,从第三张透明玻璃杯子的效果看上去,已经达到非常高的抠图精度。

下图是接近训练结束时的效果:

 

三. 总结

本文实现了抠图算法DIM,整体效果尚佳。如果需要继续提高模型精度,可以扩大训练集规模和多样性,提升算法鲁棒性和精度。当然,读者也可以阅读最新的文献,尝试新的算法来提升性能。

阿里巴巴团队提出的Semantic Human Matting是一篇非常经典且实用的端到端人像抠图方法,其效果相比于其它算法性能更加突出,其中用于精细抠图的模型就基于DIM算法。推荐读者可以精读此论文并尝试加以实现。本文重点实现SHM算法中的MNet模型,具体操作步骤可以按照本文脉络实现。

最后说明一下,本文所用的数据来源于DIM,该数据集需要联系作者获取。考虑到数据保护协议,请各位读者自行向论文作者获取。

参考文献

【1】Long J, Shelhamer E, Darrell T, et al. Fully convolutional networks for semantic segmentation[C]. computer vision and pattern recognition, 2014: 3431-3440.

【2】Olaf Ronneberger, Philipp Fischer, Thomas Brox. U-Net: Convolutional Networks for Biomedical Image Segmentation[C]// International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer International Publishing, 2015.

【3】Iglovikov, Vladimir, Shvets, Alexey. TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation[C]// International Conference on Compute Vision and Pattern Recognition. 2018.

【4】Zhao H, Shi J, Qi X, et al. Pyramid Scene Parsing Network[C]. computer vision and pattern recognition, 2017: 6230-6239.

【5】Chen Q, Ge T, Xu Y, et al. Semantic Human Matting[C]. acm multimedia, 2018: 618-626.

【6】Xu N, Price B, Cohen S, et al. Deep Image Matting[C]. computer vision and pattern recognition, 2017: 311-320.

【7】Shen X, Tao X, Gao H, et al. Deep Automatic Portrait Matting[C]. european conference on computer vision, 2016: 92-107.