主要内容

训练变分自动编码器(VAE)生成图像

这个例子展示了如何在MATLAB中创建一个可变自动编码器(VAE)来生成数字图像。VAE以MNIST数据集的风格生成手绘数字。

vie与常规自动编码器的不同之处在于,它们不使用编码-解码过程来重构输入。相反,他们在潜在空间上施加一个概率分布,并学习这个分布,以便解码器输出的分布与观测数据的分布相匹配。然后,他们从这个分布中抽取样本,生成新的数据。

在本例中,构建一个VAE网络,在MNIST数据集中对其进行训练,并生成与数据集中的图像非常相似的新图像。

加载数据

下载MNIST文件http://yann.lecun.com/exdb/mnist/并将MNIST数据集加载到工作空间[1]中。调用processimagesmnist.processLabelsMNIST附加到本示例中的辅助函数将文件中的数据加载到MATLAB数组中。

由于VAE将重建的数字与输入进行比较,而不是与分类标签进行比较,因此不需要在MNIST数据集中使用训练标签。

trainImagesFile =“train-images-idx3-ubyte.gz”;testImagesFile =“t10k-images-idx3-ubyte.gz”;testLabelsFile =“t10k-labels-idx1-ubyte.gz”;xtrain = processimagesmnist(TrainimagesFile);
读取MNIST图像数据…数据集中的图像数量:60000…
numTrainImages =大小(XTrain 4);XTest = processImagesMNIST (testImagesFile);
读取MNIST图像数据…数据集中的图像数量:10000…
欧美= processLabelsMNIST (testLabelsFile);
读取MNIST标签数据…数据集中的标签数:10000…

构建网络

自动编码器有两部分:编码器和解码器。编码器接收图像输入并输出压缩表示(编码),这是一个大小向量latentDim,在本例中等于20。解码器获取压缩的表示,解码它,并重新生成原始图像。

为了使计算在数值上更加稳定,通过使网络从方差的对数中学习,将可能值的范围从[0,1]增加到[-inf, 0]。定义两个大小的向量latent_dim:一为方法 μ 另一个是方差的对数 日志 σ 2 .然后用这两个向量来创建样本的分布。

使用二维卷积,然后使用一个完全连接的层从28 × 28 × 1 MNIST图像向下采样到潜在空间中的编码。然后,使用转置的2d卷积将1 × 1 × 20的编码放大到28 × 28 × 1的图像。

latentDim = 20;imageSize = [28 28 1];encoderLG = layerGraph([imageInputLayer(imageSize,“名字”'input_encoder'“归一化”'没有任何'32岁的)convolution2dLayer (3“填充”“相同”“步”2,“名字”“conv1”) reluLayer (“名字”“relu1”64年)convolution2dLayer(3日,“填充”“相同”“步”2,“名字”'conv2') reluLayer (“名字”“relu2”(2 * latentDim,“名字”“fc_encoder”)));decoderLG = layerGraph([imageInputLayer([1 1 latentDim]),“名字”“我”“归一化”'没有任何'64年)transposedConv2dLayer(7日,“种植”“相同”“步”7“名字”“transpose1”) reluLayer (“名字”“relu1”64年)transposedConv2dLayer(3日,“种植”“相同”“步”2,“名字”“transpose2”) reluLayer (“名字”“relu2”32岁的)transposedConv2dLayer (3“种植”“相同”“步”2,“名字”“transpose3”) reluLayer (“名字”“relu3”) transposedConv2dLayer(3、1“种植”“相同”“名字”'transposs4')));

要用自定义的训练循环来训练两个网络并使其能够自动区分,请将层图转换为dlnetwork对象。

encoderNet = dlnetwork (encoderLG);decoderNet = dlnetwork (decoderLG);

定义模型梯度函数

辅助函数modelGradients接收编码器和解码器dlnetwork对象和一小批输入数据X,并返回损失相对于网络中可学习参数的梯度。此辅助功能在此示例的末尾定义。

该函数分两步执行此过程:抽样和损失.采样步骤对均值和方差向量进行采样,以创建最终的编码并传递给解码器网络。但是,由于不可能通过随机抽样操作进行反向传播,因此必须使用reparameterization技巧.这个技巧将随机抽样操作移动到一个辅助变量 ε ,然后通过均值移位 μ 然后按标准差缩放 σ .这个想法是从 N μ σ 2 和抽样一样吗 μ + ε σ ,在那里 ε N 0 1 .下图生动地描述了这个想法。

损失步骤将采样步骤生成的编码通过解码器网络,确定损失,然后使用损失来计算梯度。ve中的损失,也称为证据下界(ELBO)损失,定义为两个单独的损失项之和:

ELBO 损失 重建 损失 + 吉隆坡 损失

重建的损失使用均方误差(MSE)测量解码器输出与原始输入的距离:

重建 损失 均方误差 译码器 输出 原始 图像

KL损失,即Kullback-Leibler散度,用来衡量两种概率分布之间的差异。在这种情况下,最小化KL损失意味着确保学习的均值和方差尽可能接近目标(正态)分布。一个潜在的尺寸 n ,则得到KL损失为

吉隆坡 损失 - 0 5 1 n 1 + 日志 σ 2 - μ 2 - σ 2

包含KL损失项的实际效果是将由于重构损失而学到的聚类紧密地围绕在潜在空间的中心,形成一个连续的样本空间。

指定培训选项

在可用的GPU上进行训练(需要并行计算工具箱™)。

executionEnvironment =“汽车”

设置网络的训练选项。当使用Adam优化器时,您需要初始化每个网络的拖尾平均梯度和空数组的拖尾平均梯度平方衰减率

numEpochs = 50;miniBatchSize = 512;lr = 1 e - 3;numIterations =地板(numTrainImages / miniBatchSize);迭代= 0;avgGradientsEncoder = [];avgGradientsSquaredEncoder = [];avgGradientsDecoder = [];avgGradientsSquaredDecoder = [];

火车模型

使用自定义训练循环训练模型。

对于epoch中的每个迭代:

  • 从训练集获取下一个迷你批处理。

  • 将迷你批处理转换为dlarray.对象,确保指定维度标签“SSCB”(spatial, spatial, channel, batch)。

  • 对于GPU训练,转换dlarray.到一个gpuArray对象。

  • 方法评估模型梯度dlfevalmodelGradients功能。

  • 更新网络可学习性和两个网络的平均梯度,使用adamupdate函数。

在每个epoch结束时,将测试集图像通过自动编码器,并显示该epoch的损耗和训练时间。

epoch = 1:numEpochs tic;i = 1:numIterations迭代=迭代+ 1;idx =(张)* miniBatchSize + 1:我* miniBatchSize;XBatch = XTrain (:,:,:, idx);XBatch = dlarray(单(XBatch),“SSCB”);如果(executionEnvironment = =“汽车”&& canUseGPU) || executionEnvironment ==“图形”XBatch = gpuArray (XBatch);结尾[infGrad, genGrad] = dlfeval(...@modelGradients, encoderNet, decoderNet, XBatch);[decoderNet。可学的,avgGradientsDecoder, avgGradientsSquaredDecoder] =...adamupdate (decoderNet。可学的,...genGrad, avgGradientsDecoder, avgGradientsSquaredDecoder, iteration, lr);[encoderNet。可学的,avgGradientsEncoder, avgGradientsSquaredEncoder] =...adamupdate (encoderNet。可学的,...infGrad, avgGradientsEncoder, avgGradientsSquaredEncoder, iteration, lr);结尾elapsedTime = toc;[z, zMean, zLogvar] =采样(encoderNet, XTest);xPred = sigmoid(forward(decoderNet, z));elbo = ELBOloss(XTest, xPred, zMean, zLogvar);disp (”时代:“+时代+"测试ELBO损耗= "+收集(extractdata (elbo)) +...".纪元所花费的时间= "+ elapsedTime +“s”结尾
时代:1试验eLbo损失= 28.0145。纪元= 28.0573S时代= 28.0573S时代:2试验eLBO损失= 24.8995。纪元= 8.797S时期所花费的时间:3试验eLBO损失= 23.2756。为期纪念时间= 8.8824S时代:4试验eLBO损失= 21.151。纪元= 8.5979S时期所花费的时间:5试验eLbo损失= 20.5335。纪元= 8.8472S时期的时间:6试验eLBO损失= 20.232。为时代占地时间= 8.5068S时代:7试验eLBO损失= 19.9988。时代纪元= 8.4356S时代:8试验eLBO损失= 19.8955。纪元= 8.4015S时代拍摄时间:9试验eLbo损失= 19.7991。纪元= 8.8089s时代= 8.8089s时代:10试验eLbo损失= 19.6773。 Time taken for epoch = 8.4269s Epoch : 11 Test ELBO loss = 19.5181. Time taken for epoch = 8.5771s Epoch : 12 Test ELBO loss = 19.4532. Time taken for epoch = 8.4227s Epoch : 13 Test ELBO loss = 19.3771. Time taken for epoch = 8.5807s Epoch : 14 Test ELBO loss = 19.2893. Time taken for epoch = 8.574s Epoch : 15 Test ELBO loss = 19.1641. Time taken for epoch = 8.6434s Epoch : 16 Test ELBO loss = 19.2175. Time taken for epoch = 8.8641s Epoch : 17 Test ELBO loss = 19.158. Time taken for epoch = 9.1083s Epoch : 18 Test ELBO loss = 19.085. Time taken for epoch = 8.6674s Epoch : 19 Test ELBO loss = 19.1169. Time taken for epoch = 8.6357s Epoch : 20 Test ELBO loss = 19.0791. Time taken for epoch = 8.5512s Epoch : 21 Test ELBO loss = 19.0395. Time taken for epoch = 8.4674s Epoch : 22 Test ELBO loss = 18.9556. Time taken for epoch = 8.3943s Epoch : 23 Test ELBO loss = 18.9469. Time taken for epoch = 10.2924s Epoch : 24 Test ELBO loss = 18.924. Time taken for epoch = 9.8302s Epoch : 25 Test ELBO loss = 18.9124. Time taken for epoch = 9.9603s Epoch : 26 Test ELBO loss = 18.9595. Time taken for epoch = 10.9887s Epoch : 27 Test ELBO loss = 18.9256. Time taken for epoch = 10.1402s Epoch : 28 Test ELBO loss = 18.8708. Time taken for epoch = 9.9109s Epoch : 29 Test ELBO loss = 18.8602. Time taken for epoch = 10.3075s Epoch : 30 Test ELBO loss = 18.8563. Time taken for epoch = 10.474s Epoch : 31 Test ELBO loss = 18.8127. Time taken for epoch = 9.8779s Epoch : 32 Test ELBO loss = 18.7989. Time taken for epoch = 9.6963s Epoch : 33 Test ELBO loss = 18.8. Time taken for epoch = 9.8848s Epoch : 34 Test ELBO loss = 18.8095. Time taken for epoch = 10.3168s Epoch : 35 Test ELBO loss = 18.7601. Time taken for epoch = 10.8058s Epoch : 36 Test ELBO loss = 18.7469. Time taken for epoch = 9.9365s Epoch : 37 Test ELBO loss = 18.7049. Time taken for epoch = 10.0343s Epoch : 38 Test ELBO loss = 18.7084. Time taken for epoch = 10.3214s Epoch : 39 Test ELBO loss = 18.6858. Time taken for epoch = 10.3985s Epoch : 40 Test ELBO loss = 18.7284. Time taken for epoch = 10.9685s Epoch : 41 Test ELBO loss = 18.6574. Time taken for epoch = 10.5241s Epoch : 42 Test ELBO loss = 18.6388. Time taken for epoch = 10.2392s Epoch : 43 Test ELBO loss = 18.7133. Time taken for epoch = 9.8177s Epoch : 44 Test ELBO loss = 18.6846. Time taken for epoch = 9.6858s Epoch : 45 Test ELBO loss = 18.6001. Time taken for epoch = 9.5588s Epoch : 46 Test ELBO loss = 18.5897. Time taken for epoch = 10.4554s Epoch : 47 Test ELBO loss = 18.6184. Time taken for epoch = 10.0317s Epoch : 48 Test ELBO loss = 18.6389. Time taken for epoch = 10.311s Epoch : 49 Test ELBO loss = 18.5918. Time taken for epoch = 10.4506s Epoch : 50 Test ELBO loss = 18.5081. Time taken for epoch = 9.9671s

可视化的结果

要可视化和解释结果,请使用助手可视化功能.这些辅助函数在本例的最后定义。

VisualizeReconstruction函数显示从每个类中随机选择的数字,并在通过自动编码器后显示其重构。

VisualizeLatentSpace函数取测试图像通过编码器网络后生成的均值和方差编码(各为20维),并对每幅图像的编码矩阵进行主成分分析(PCA)。然后,您可以可视化由两个第一个主成分表征的两个维度中的均值和方差定义的潜在空间。

生成函数初始化从正态分布中采样的新编码,并输出这些编码通过解码器网络时生成的图像。

visualizerconstruction (XTest, YTest, encoderNet, decoderNet)

visualizeLatentSpace (XTest欧美encoderNet)

生成(decoderNet latentDim)

下一步

变分自动编码器只是用于执行生成任务的众多可用模型之一。它们在图像较小且特征明确的数据集上工作得很好(如MNIST)。对于包含较大图像的更复杂的数据集,生成式对抗网络(gan)往往表现得更好,生成的图像噪声更小。有关如何实现gan生成64 × 64 RGB图像的示例,请参见训练生成对抗网络(GAN)

参考文献

  1. Lecun,Y.,C. Cortes,以及C. J. C.博览会。“手写数字的Mnist数据库。”http://yann.lecun.com/exdb/mnist/

辅助函数

模型梯度函数

modelGradients函数采用编码器和解码器dlnetwork对象和一小批输入数据X,并返回损失相对于网络中可学习参数的梯度。该函数执行三种操作:

  1. 的方法获取编码抽样功能上的小批图像,通过编码器网络。

  2. 通过将编码通过解码器网络并调用ELBOloss函数。

  3. 通过调用,计算损失相对于两个网络的可学习参数的梯度dlgradient函数。

函数[Infgrad,Gen​​grad] = MapeStriants(Encodernet,Decodernet,x)[z,zmean,zlogvar] =采样(eNcodernet,x);xPred = sigmoid(forward(decoderNet, z));丢失= elboloss(x,xpred,zmean,zlogvar);[GENGRAD,INFGRAD] = DLGRADIENT(丢失,DECODERNET.LEARNABLE,...encoderNet.Learnables);结尾

采样和损失功能

抽样函数从输入图像获取编码。最初,它通过编码器网络传递一小批图像,并分割大小输出(2 * latentDim) * miniBatchSize变成一个均值矩阵和一个方差矩阵,每个矩阵的大小latentDim * batchSize.然后,它使用这些矩阵来实现重新参数化技巧和计算编码。最后,它将此编码转换为dlarray.对象的SSCB格式。

函数[zSampled, zMean, zLogvar] = sampling(encoderNet, x)d =大小(压缩,1)/ 2;zMean =压缩(1:d:);zLogvar =压缩(1 + d:最终,);深圳=大小(zMean);ε= randn(深圳);σ= exp(。5* zLogvar); z = epsilon .* sigma + zMean; z = reshape(z, [1,1,sz]); zSampled = dlarray(z,“SSCB”);结尾

ELBOloss函数采用手段的编码和返回的差异抽样函数,并利用它们计算ELBO损失。

函数elbo = ELBOloss(x, xPred, zMean, zLogvar) square = 0.5*(xPred-x).^2;reconstructionLoss = sum(平方和,[1,2,3]);KL = -。5* sum(1 + zLogvar - zMean.^2 - exp(zLogvar), 1); elbo = mean(reconstructionLoss + KL);结尾

可视化功能

VisualizeReconstruction函数为MNIST数据集的每一位数字随机选择两幅图像,将它们通过VAE,并将重建图像与原始输入图像并排绘制。请注意,要绘制包含在dlarray.对象,您需要首先使用extractdata收集功能。

函数visualizerconstruction (XTest,YTest, encoderNet, decoderNet) f = figure;图(f)标题("地面真实图像与重建图像的对比"i = 1:2c=0:9 idx = iRandomIdxOfClass(YTest,c);X = XTest (:,:,:, idx);[z, ~, ~] =采样(encoderNet, X);XPred = sigmoid(forward(decoderNet, z));X =收集(extractdata (X));XPred =收集(extractdata (XPred));比较= [X, ones(size(X,1),1), XPred];次要情节(4、5、(张)* 10 + c + 1), imshow(比较,[]),结尾结尾结尾函数idx = iRandomIdxOfClass(T,c);idx =找到(idx);idx = idx (randi(元素个数(idx), 1));结尾

VisualizeLatentSpace函数可视化由形成编码器网络的输出的平均值和方差矩阵定义的潜像,并定位由每个数字的潜在空间表示形成的群集。

函数首先从中提取均值和方差矩阵dlarray.对象。由于不能用通道/批处理维度(C和B)置换矩阵,因此函数调用剥离丁片在转置矩阵之前。然后,对两个矩阵进行主成分分析(PCA)。为了使潜在空间在二维空间中形象化,该函数保留了前两个主分量,并将它们绘制在一起。最后,该函数为数字类着色,以便您可以观察集群。

函数visualizeLatentSpace(XTest, YTest, encoderNet) [~, zMean, zLogvar] =采样(encoderNet, XTest);zMean = stripdims (zMean)”;zMean =收集(extractdata (zMean));zLogvar = stripdims (zLogvar)”;zLogvar =收集(extractdata (zLogvar));[~, scoreMean] = pca (zMean);[~, scoreLogvar] = pca (zLogvar);c = parula (10);f1 =图;图(f1)标题(“潜在的空间”) ah = subplot(1,2,1);散射(scoreMean (:, 2), scoreMean (: 1), [], c(双重(欧美):));啊。YDir =“反向”;轴平等的包含(“Z_m_u(2)”) ylabel (“Z_m_u(1)”) cb = colorbar;cb。蜱虫= 0 (1/9):1;cb。TickLabels =字符串(0:9);啊=情节(1、2、2);散射(scoreLogvar (:, 2), scoreLogvar (: 1), [], c(双重(欧美):));啊。YDir =“反向”;包含(“Z_v_a_r(2)”) ylabel (“Z_v_a_r(1)”) cb = colorbar;cb。蜱虫= 0 (1/9):1;cb。TickLabels =字符串(0:9);轴平等的结尾

生成功能测试VAE的生成能力。它初始化一个dlarray.包含25个随机生成的编码的对象通过解码网络传递它们,并绘制输出。

函数randomNoise = dlarray(randn(1,1,latentDim,25),“SSCB”);generatedImage = sigmoid(predict(decoderNet, randomNoise)));generatedImage = extractdata (generatedImage);f3 =图;图(f3) imshow (imtile (generatedImage“ThumbnailSize”,[100,100]))标题(“生成的数字样本”) drawnow结尾

另请参阅

||||||

相关的话题