欢迎您访问 最编程 本站为您分享编程语言代码,编程技术文章!
您现在的位置是: 首页

MATLAB 代码分析:使用 DCGAN 实现图像数据的生成 - 经典代码:使用 DCGAN 生成花朵

最编程 2024-10-15 16:36:05
...

MATLAB官方其实给出了DCGAN生成花朵的示范代码,原文地址:训练生成对抗网络 (GAN) - MATLAB & Simulink - MathWorks 中国

先看看训练效果

训练1周期

训练11周期

训练56个周期

脚本文件 

为了能让各位更好的复现,该代码已打包,下载后解压运行用MATLAB运行"gan.mlx"即可
链接: https://pan.baidu.com/s/1hNYLw1xku2AdKf5CanoFzA?pwd=fb7n 提取码: fb7n 
 

代码详解:

首先是脚本gan:

数据获取
clear all
clc
imageFolder = fullfile("flower_photos");
imds = imageDatastore(imageFolder,IncludeSubfolders=true);
augmenter = imageDataAugmenter(RandXReflection=true);
augimds = augmentedImageDatastore([64 64],imds,DataAugmentation=augmenter);
生成器
filterSize = 5;
numFilters = 64;
numLatentInputs = 100;

projectionSize = [4 4 512];%

layersGenerator = [
    featureInputLayer(numLatentInputs)
    projectAndReshapeLayer(projectionSize)
    transposedConv2dLayer(filterSize,4*numFilters)
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same")
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same")
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,3,Stride=2,Cropping="same")
    tanhLayer];
netG = dlnetwork(layersGenerator);
判别器
dropoutProb = 0.5;
numFilters = 64;
scale = 0.2;

inputSize = [64 64 3];
filterSize = 5;

layersDiscriminator = [
    imageInputLayer(inputSize,Normalization="none")
    dropoutLayer(dropoutProb)
    convolution2dLayer(filterSize,numFilters,Stride=2,Padding="same")
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,2*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,4*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,8*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(4,1)
    sigmoidLayer];
netD = dlnetwork(layersDiscriminator);
指定训练选项
numEpochs = 500;
miniBatchSize = 128;
learnRate = 0.00008;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;
flipProb = 0.35;
validationFrequency = 100;
训练模型
augimds.MiniBatchSize = miniBatchSize;

mbq = minibatchqueue(augimds, ...
    MiniBatchSize=miniBatchSize, ...
    PartialMiniBatch="discard", ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB");
trailingAvgG = [];
trailingAvgSqG = [];
trailingAvg = [];
trailingAvgSqD = [];
numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,"single");
ZValidation = dlarray(ZValidation,"CB");
if canUseGPU
    ZValidation = gpuArray(ZValidation);
end

f = figure;
f.Position(3) = 2*f.Position(3);

imageAxes = subplot(1,2,1);
scoreAxes = subplot(1,2,2);

C = colororder;
lineScoreG = animatedline(scoreAxes,Color=C(1,:));
lineScoreD = animatedline(scoreAxes,Color=C(2,:));
legend("Generator","Discriminator");
ylim([0 1])
xlabel("Iteration")
ylabel("Score")
grid on

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs

    % Reset and shuffle datastore.
    shuffle(mbq);

    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;

        % Read mini-batch of data.
        X = next(mbq);

        % Generate latent inputs for the generator network. Convert to
        % dlarray and specify the format "CB" (channel, batch). If a GPU is
        % available, then convert latent inputs to gpuArray.
        Z = randn(numLatentInputs,miniBatchSize,"single");
        Z = dlarray(Z,"CB");

        if canUseGPU
            Z = gpuArray(Z);
        end

        % Evaluate the gradients of the loss with respect to the learnable
        % parameters, the generator state, and the network scores using
        % dlfeval and the modelLoss function.
        [L,~,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
            dlfeval(@modelLoss,netG,netD,X,Z,flipProb);
        netG.State = stateG;

        %%show data
        %"epoch"
        %epoch
        %"scoreG-D"
        %[scoreG,scoreD]

        % Update the discriminator network parameters.
        [netD,trailingAvg,trailingAvgSqD] = adamupdate(netD, gradientsD, ...
            trailingAvg, trailingAvgSqD, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);

        % Update the generator network parameters.
        [netG,trailingAvgG,trailingAvgSqG] = adamupdate(netG, gradientsG, ...
            trailingAvgG, trailingAvgSqG, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        % Every validationFrequency iterations, display batch of generated
        % images using the held-out generator input.
        if mod(iteration,validationFrequency) == 0 || iteration == 1
            % Generate images using the held-out generator input.
            XGeneratedValidation = predict(netG,ZValidation);

            % Tile and rescale the images in the range [0 1].
            I = imtile(extractdata(XGeneratedValidation));
            I = rescale(I);

            % Display the images.
            subplot(1,2,1);
            image(imageAxes,I)
            xticklabels([]);
            yticklabels([]);
            title("Generated Images");
        end

        % Update the scores plot.
        subplot(1,2,2)
        scoreG = double(extractdata(scoreG));
        addpoints(lineScoreG,iteration,scoreG);

        scoreD = double(extractdata(scoreD));
        addpoints(lineScoreD,iteration,scoreD);

        % Update the title with training progress information.
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        title(...
            "Epoch: " + epoch + ", " + ...
            "Iteration: " + iteration + ", " + ...
            "Elapsed: " + string(D))

        drawnow
    end
end
生成新图像  
numObservations = 4;
ZNew = randn(numLatentInputs,numObservations,"single");
ZNew = dlarray(ZNew,"CB");
if canUseGPU
    ZNew = gpuArray(ZNew);
end

XGeneratedNew = predict(netG,ZNew);

I = imtile(extractdata(XGeneratedNew));
I = rescale(I);
figure
image(I)
axis off
title("Generated Images")

生成器与判别器的设计

脚本gan中以及包含了生成器Generator和判别器Discriminator的结构设计,生成器利用装置卷积对特征进行上采样,最终生成了64*64*3的图像,而判别器则用卷积进行下采样,将输入提取至1*1的格式大小,利用sigmoid作为激活函数,判断输入图像的真假

如何自定义生成对抗网络?很简单,把握上采样和下采样的规模就行,利用MATLAB的DLtool(deep network designer)可以很好的观察到这一点,以刚刚的生成器为例,我们可以观察到,转置卷积后(步幅为2),输出的空间(S)长宽都翻倍,深度对应我们给定的filters数量,因此,我们想要生成特定大小的数据时,修改转置卷积的步幅、卷积核数量以及转置卷积层的数量就行,同时记得在添加的转置卷积层后连接新的BN层和ReLU激活函数。

比如我想生成128*128*3的图片,我只需要将刚刚示例中的其中一个转置卷积核的大小提高至7*7,同时步幅修改成4。或者,我直接添加一层步幅为2的转置卷积层。对于一些数据尺寸为非2倍数问题,如311*171*3,我们可以先生成312*172*3再resize一下,或者你提前将数据预处理成312*172.

注意:定义完网络结构后,要用dlnetwork()函数将layer参数转变成可训练的dlnetwork。

最近比较忙,先在这里停笔了,后面再慢慢补充-24-10-14

数据预处理

自定义模型训练

损失函数与梯度下降

优化器与参数更新

推荐阅读