主要内容

训练一个暹罗网络来比较图像

这个例子展示了如何训练暹罗网络来识别手写字符的相似图像。

Siamese网络是一种深度学习网络,它使用两个或多个具有相同架构并共享相同参数和权重的相同子网。暹罗网络通常用于寻找两个可比较事物之间的关系的任务。Siamese网络的一些常见应用包括面部识别、签名验证[1]或释义识别[2]。Siamese网络在这些任务中表现良好,因为它们的共享权重意味着在训练过程中需要学习的参数更少,并且它们可以用相对较少的训练数据产生良好的结果。

Siamese网络在有大量类而每个类的观察值很少的情况下特别有用。在这种情况下,没有足够的数据来训练深度卷积神经网络将图像分类到这些类别中。相反,Siamese网络可以确定两个图像是否属于同一类。

这个例子使用Omniglot数据集[3]来训练Siamese网络来比较手写字符[4]的图像。Omniglot数据集包含50个字母的字符集,其中30个用于训练,20个用于测试。每个字母表都包含一些字符,从Ojibwe(加拿大土著苏拉语)的14个字符到Tifinagh的55个字符。最后,每个字符都有20个手写的评论。这个例子训练一个网络来识别两个手写的观察是否是同一字符的不同实例。

您还可以使用Siamese网络通过降维来识别相似的图像。使用示例请参见训练用于降维的暹罗网络

加载和预处理训练数据

下载并提取Omniglot训练数据集。

url =“https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip”;downloadFolder = tempdir;filename = fullfile(下载文件夹,“images_background.zip”);dataFolderTrain = fullfile(下载文件夹,“images_background”);如果~存在(dataFolderTrain“dir”) disp (“正在下载Omniglot训练数据(4.5 MB)…”) websave(文件名,url);解压缩(文件名,downloadFolder);结束disp (“训练数据已经下载。”
训练数据下载。

方法将训练数据加载为图像数据存储imageDatastore函数。通过从文件名中提取标签并设置标签财产。

imdsTrain = imageDatastore(dataFolderTrain,IncludeSubfolders = true,LabelSource =“没有”);files = imdsTrain.Files;Parts = split(files,filesep);标签= join(parts(:,(end-2):(end-1)),“-”);imdsTrain。标签=categorical(labels);

Omniglot训练数据集由来自30个字母的黑白手写字符组成,每个字符有20个观察值。图像大小为105 × 105 × 1,每个像素的值在01

显示随机选择的图像。

idx = randperm(numel(imdsTrain.Files),8);i = 1: nummel (idx) subplot(4,2,i) imshow(readimage(imdsTrain,idx(i))) title(imdsTrain. labels (idx(i)),Interpreter=“没有”);结束

创建一对相似和不同的图像

为了训练网络,必须将数据分组成相似或不相似的图像对。这里,相似图像是同一字符的不同手写实例,具有相同的标签,而不同字符的不相似图像具有不同的标签。这个函数getSiameseBatch(定义见万博1manbetx支持功能部分)创建相似或不同图像的随机配对,pairImage1pairImage2.该函数还返回标签pairLabel,它可以识别这对图像是相似还是不同。类似的图像对pairLabel = 1,而不同的配对则有pairLabel = 0

作为一个示例,创建一个包含五对图像的小型代表性集

batchSize = 10;[pairImage1,pairImage2,pairLabel] = getSiameseBatch(imdsTrain,batchSize);

显示生成的图像对。

i = 1:batchSize如果pairLabel(i) == 1 s =“相似”;其他的s =“不同”;结束次要情节(2、5、我)imshow ([pairImage1 (::,:, i) pairImage2(::,:,我)]);标题(年代)结束

在本例中,为训练循环的每次迭代创建180个成对图像的新一批。这确保了网络在大量的随机图像对上进行训练,这些图像对的相似和不相似的比例大致相等。

定义网络架构

Siamese网络架构如下图所示。

为了比较两个图像,每个图像都要经过两个相同的子网中的一个,这些子网共享权重。子网将每个105 × 105 × 1的图像转换为4096维特征向量。同一类的图像具有相似的4096维表示。每个子网络的输出特征向量通过减法组合,结果通过a传递fullyconnect单输出操作。一个乙状结肠操作将此值转换为之间的概率01表示网络对图像是否相似或不相似的预测。在训练过程中,利用网络预测与真实标签之间的二值交叉熵损失来更新网络。

在本例中,两个相同的子网定义为adlnetwork对象。最后一个fullyconnect乙状结肠操作是对子网输出进行的功能操作。

将子网创建为一系列层,这些层接受105 × 105 × 1的图像,并输出大小为4096的特征向量。

convolution2dLayer对象,使用窄正态分布初始化权重和偏差。

maxPooling2dLayer对象,将步幅设置为2

为了决赛fullyConnectedLayer对象,指定输出大小为4096,并使用窄正态分布初始化权重和偏差。

图层= [imageInputLayer([105 105 1],正常化=“没有”) convolution2dLayer (64, WeightsInitializer =“narrow-normal”BiasInitializer =“narrow-normal”) reluLayer maxPooling2dLayer(2,Stride=2) convolution2dLayer(7128,WeightsInitializer= .“narrow-normal”BiasInitializer =“narrow-normal”) reluLayer maxPooling2dLayer(2,Stride=2) convolution2dLayer(4,128,WeightsInitializer= .“narrow-normal”BiasInitializer =“narrow-normal”) reluLayer maxPooling2dLayer(2,Stride=2) convolution2dLayer(5256,WeightsInitializer= . 2“narrow-normal”BiasInitializer =“narrow-normal”) reluLayer fullyConnectedLayer(4096,WeightsInitializer= .“narrow-normal”BiasInitializer =“narrow-normal”));lgraph = layerGraph(layers);

为了使用自定义训练环路训练网络并启用自动微分,将层图转换为adlnetwork对象。

Net = dlnetwork(lgraph);

为决赛创建权重fullyconnect操作。通过从标准差为0.01的窄正态分布中随机抽样来初始化权重。

fcWeights = darray (0.01*randn(1,4096));fcBias = darray (0.01*randn(1,1));fcParams = struct(“FcWeights”fcWeights,“FcBias”, fcBias);

要使用网络,创建函数forwardSiamese(定义见万博1manbetx支持功能节),定义了两个子网和减法,fullyconnect,乙状结肠操作是组合的。这个函数forwardSiamese接受包含参数的网络结构fullyconnect操作,以及两个训练图像。的forwardSiamese函数输出关于两幅图像相似性的预测。

定义模型损失函数

创建函数modelLoss(定义见万博1manbetx支持功能节)。的modelLoss函数取暹罗子网,参数结构为fullyconnect操作,以及一小批输入数据X1X2他们的标签pairLabels.该函数返回损失值以及损失相对于网络可学习参数的梯度。

暹罗网络的目标是区分两种输入X1X2.网络的输出是介于01,其中值更接近0表示图像不相似的预测,并且值更接近1图像是相似的。损失由预测分数与真实标签值之间的二值交叉熵给出:

损失 - tlog y - 1 - t 日志 1 - y

真正的标签 t 是0还是1 y 是预测的标签。

指定培训方案

指定培训期间要使用的选项。训练10000次迭代。

numIterations = 10000;miniBatchSize = 180;

指定ADAM优化的选项:

  • 设置学习速率为0.00006

  • 设置梯度衰减因子为0.9梯度衰减因子的平方是0.99

learningRate = 6e-5;gradDecay = 0.9;gradDecaySq = 0.99;

在GPU上训练,如果有的话。使用GPU需要Parallel Computing Toolbox™和支持的GPU设备。万博1manbetx有关支持的设备的信息,请参见万博1manbetxGPU计算要求(并行计算工具箱).若要自动检测是否有可用的GPU,并将相关数据放置在GPU上,请设置的值executionEnvironment“汽车”.如果您没有GPU,或者不想使用GPU进行训练,请设置的值executionEnvironment“cpu”.为确保使用GPU进行训练,请设置executionEnvironment“图形”

executionEnvironment =“汽车”;

火车模型

初始化训练进度图。

图C = colororder;lineLossTrain = animatedline(Color=C(2,:));Ylim ([0 inf]) xlabel(“迭代”) ylabel (“损失”网格)

初始化ADAM求解器的参数。

trailingAvgSubnet = [];trailingAvgSqSubnet = [];trailingAvgParams = [];trailingAvgSqParams = [];

使用自定义训练循环训练模型。循环遍历训练数据并在每次迭代时更新网络参数。

对于每次迭代:

  • 提取一批图像对和标签使用getSiameseBatch节中定义的函数创建批量镜像对

  • 将数据转换为dlarray对象指定维度标签“SSCB”(空间、空间、通道、批处理)为图像数据和“CB”(通道,批次)用于标签。

  • 对于GPU训练,将数据转换为gpuArray对象。

  • 评估模型损失和梯度使用dlfevalmodelLoss函数。

  • 更新网络参数adamupdate函数。

Start = tic;%小批量循环。迭代= 1:numIterations提取小批量图像对和对标签[X1,X2,pairLabels] = getSiameseBatch(imdsTrain,miniBatchSize);将小批量数据转换为数组。指定维度标签%“SSCB”(空间,空间,通道,批处理)用于图像数据X1 = darray (X1,“SSCB”);X2 = darray (X2,“SSCB”);%如果在GPU上训练,则将数据转换为gpuArray。如果(executionEnvironment = =“汽车”&& canUseGPU) || executionEnvironment ==“图形”X1 = gpuArray(X1);X2 = gpuArray(X2);结束使用dlfeval和modelLoss评估模型损失和梯度在示例末尾列出的%函数。[loss,gradientsSubnet,gradientsParams] = dlfeval(@modelLoss,net,fcParams,X1,X2,pairLabels);%更新暹罗子网参数。[net,trailingAvgSubnet,trailingAvgSqSubnet] = adamupdate(net,gradientsSubnet,)trailingAvgSubnet trailingAvgSqSubnet,迭代,learningRate、gradDecay gradDecaySq);%更新fullyconnect参数。[fcParams,trailingAvgParams,trailingAvgSqParams] = adamupdate(fcParams,gradientsParams,trailingAvgParams trailingAvgSqParams,迭代,learningRate、gradDecay gradDecaySq);%更新训练损失进度图。D = duration(0,0,toc(start),Format=“hh: mm: ss”);lossValue = double(损失);addpoints (lineLossTrain迭代,lossValue);标题(”经过:“+ string(D)) drawnow结束

评估网络的准确性

下载并提取Omniglot测试数据集。

url =“https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip”;downloadFolder = tempdir;filename = fullfile(下载文件夹,“images_evaluation.zip”);dataFolderTest = fullfile(下载文件夹,“images_evaluation”);如果~存在(dataFolderTest“dir”) disp (“正在下载Omniglot测试数据(3.2 MB)…”) websave(文件名,url);解压缩(文件名,downloadFolder);结束disp (“测试数据已下载。”
测试数据已下载。

将测试数据加载为图像数据存储imageDatastore函数。通过从文件名中提取标签并设置标签财产。

imdsTest = imageDatastore(dataFolderTest,IncludeSubfolders = true,LabelSource =“没有”);files = imdsTest.Files;Parts = split(files,filesep);标签= join(parts(:,(end-2):(end-1)),“_”);imdsTest。标签=categorical(labels);

测试数据集包含20个不同于网络训练时使用的字母。测试数据集中总共有659个不同的类。

numClasses = nummel (unique(imdest . labels))
numClasses = 659

为了计算网络的准确性,创建一组随机的5个小批测试对。使用predictSiamese函数(定义在万博1manbetx支持功能来评估网络预测并计算小批的平均精度。

精度= 0 (1,5);accurybatchsize = 150;I = 1:5提取小批量图像对和对标签[X1,X2,pairLabelsAcc] = getSiameseBatch(imdsTest, accurybatchsize);将小批量数据转换为数组。指定维度标签%“SSCB”(空间,空间,通道,批处理)用于图像数据。X1 = darray (X1,“SSCB”);X2 = darray (X2,“SSCB”);%如果使用GPU,则将数据转换为gpuArray。如果(executionEnvironment = =“汽车”&& canUseGPU) || executionEnvironment ==“图形”X1 = gpuArray(X1);X2 = gpuArray(X2);结束使用训练过的网络评估预测Y = predictSiamese(net,fcParams,X1,X2);%将预测值转换为二进制0或1Y = collect (extractdata(Y));Y = round(Y);计算小批量的平均精度accuracy(i) = sum(Y == pairLabelsAcc)/ accurybatchsize;结束

计算所有小批量的精度

averageAccuracy =平均值(精度)*100
averageAccuracy = 89.0667

显示带有预测的图像测试集

为了直观地检查网络是否正确识别相似和不相似对,可以创建一小批图像对进行测试。使用predictSiamese函数获取每个测试对的预测。显示具有预测、概率分数和指示预测是否正确的标签的图像对。

testBatchSize = 10;[XTest1,XTest2,pairLabelsTest] = getSiameseBatch(imdsTest,testBatchSize);

将测试批数据转换为dlarray.指定维度标签“SSCB”(空间、空间、通道、批处理)用于图像数据。

XTest1 = array(XTest1,“SSCB”);XTest2 = darray (XTest2,“SSCB”);

如果使用GPU,则将数据转换为gpuArray

如果(executionEnvironment = =“汽车”&& canUseGPU) || executionEnvironment ==“图形”XTest1 = gpuArray(XTest1);XTest2 = gpuArray(XTest2);结束

计算预测概率。

YScore = predictSiamese(net,fcParams,XTest1,XTest2);YScore = collect (extractdata(YScore));

将预测值转换为二进制0或1。

YPred = round(YScore);

提取数据进行绘图。

XTest1 = extractdata(XTest1);XTest2 = extractdata(XTest2);

用预测的标签和预测的分数绘制图像。

F =图;tiledlayout(2、5);f.Position(3) = 2*f.Position(3);predLabels = categorical(YPred,[0 1],[0],]“不同”“相似”]);targetLabels = categorical(pairLabelsTest,[0 1],[0])“不同”“相似”]);i = 1:元素个数(pairLabelsTest) nexttile imshow ([XTest1 (::,:, i) XTest2(::,:,我)]);标题(目标:“+ string(targetLabels(i)) +换行符预测:“+ string(predLabels(i)) +换行符“分数:+ YScore (i))结束

该网络能够比较测试图像以确定它们的相似性,即使这些图像都不在训练数据集中。

万博1manbetx支持功能

用于训练和预测的模型函数

这个函数forwardSiamese用于网络训练。该函数定义了子网和fullyconnect乙状结肠这些操作结合起来形成了完整的暹罗网络。forwardSiamese接受网络结构和两张训练图像,并输出关于两张图像相似度的预测。在本例中,函数forwardSiamese在这部分介绍了什么定义网络架构

函数Y = forwardSiamese(net,fcParams,X1,X2)% forwardSiamese接受网络和对训练图像,和%返回对相似(更接近)的概率的预测%为1)或不相似(接近于0)。训练时使用正向暹罗语。%通过双子网传递第一个映像Y1 = forward(net,X1);Y1 = sigmoid(Y1);%通过双子网传递第二个映像Y2 = forward(net,X2);Y2 = sigmoid(Y2)%减去特征向量Y = abs(Y1 - Y2);%通过完全连接操作传递结果Y = fulllyconnect (Y,fcParams.FcWeights,fcParams.FcBias);%转换为0到1之间的概率。Y = sigmoid(Y);结束

这个函数predictSiamese使用经过训练的网络对两幅图像的相似性进行预测。函数与函数相似forwardSiamese,前面定义的。然而,predictSiamese使用预测功能与网络,而不是向前函数,因为一些深度学习层在训练和预测过程中的行为不同。在本例中,函数predictSiamese在这部分介绍了什么评估网络的准确性

函数Y = predictSiamese(net,fcParams,X1,X2)% predictSiamese接受网络和图像对,并返回对相似概率的%预测(接近1)或%不相似(接近于0)。在预测时使用predictSiamese。%通过双子网传递第一个映像。Y1 = predict(net,X1);Y1 = sigmoid(Y1);%通过双子网传递第二个映像。Y2 = predict(net,X2);Y2 = sigmoid(Y2)%减去特征向量。Y = abs(Y1 - Y2);%通过完全连接操作传递结果。Y = fulllyconnect (Y,fcParams.FcWeights,fcParams.FcBias);%转换为0到1之间的概率。Y = sigmoid(Y);结束

模型损失函数

这个函数modelLoss暹罗人dlnetwork对象,一对小批量输入数据X1X2,以及表明它们是否相似或不同的标签。该函数返回预测值与真实值之间的二元交叉熵损失,以及该损失相对于网络中可学习参数的梯度。在本例中,函数modelLoss在这部分介绍了什么定义模型损失函数

函数[loss,gradientsSubnet,gradientsParams] = modelLoss(net,fcParams,X1,X2,pairLabels)%通过网络传递镜像对。Y = forwardSiamese(net,fcParams,X1,X2);%计算二元交叉熵损失。loss = binarycrossentropy(Y,pairLabels);计算损失相对于可学习网络的梯度%的参数。[gradientsSubnet,gradientsParams] = dlgradient(loss,net.Learnables,fcParams);结束

二元交叉熵损失函数

binarycrossentropy函数接受网络预测和对标签,并返回二元交叉熵损失值。

函数loss = binarycrossentropy(Y,pairLabels)获得预测的精度,以防止由于浮点数导致的错误%的精度。precision = underlyingType(Y);将小于浮点精度的值转换为eps。Y(Y < eps(precision)) = eps(precision);将1-eps和1-eps之间的值转换为1-eps。Y(Y > 1 - eps(精度))= 1 - eps(精度)%计算每对二进制交叉熵损失loss = -pairLabels.*log(Y) - (1 -pairLabels)。*log(1 - Y);对minibatch中所有对求和并规范化。loss = sum(loss)/numel(pairLabels);结束

创建批量镜像对

下面的函数根据标签创建相似或不相似的随机图像对。在本例中,函数getSiameseBatch在这部分介绍了什么创建一对相似和不同的图像。

获取暹罗批处理功能

getSiameseBatch返回随机选择的批处理或成对图像。平均而言,这个函数产生一组相似和不相似的对。

函数[X1,X2,pairLabels] = getSiameseBatch(imds,miniBatchSize) pairLabels = 0 (1,miniBatchSize);imgSize = size(readimage(imds,1));X1 = 0 ([imgSize 1 miniBatchSize],“单身”);X2 = 0 ([imgSize 1 miniBatchSize],“单身”);i = 1:miniBatchSize = rand(1);如果选择< 0.5 [pairIdx1,pairIdx2,pairLabels(i)] = getSimilarPair(imds.Labels);其他的[pairIdx1,pairIdx2,pairLabels(i)] = getDissimilarPair(imds.Labels);结束X1(:,:,:,i) = imds.readimage(pairIdx1);X2(:,:,:,i) = imds.readimage(pairIdx2);结束结束

得到相似对函数

getSimilarPair函数返回同一类中图像的随机索引对,并且相似的pair标签等于1。

函数[pairIdx1,pairIdx2,pairLabel] = getSimilarPair(classLabel)找到所有独特的类。class = unique(classLabel);%随机选择一个类,将被用来得到一个相似的配对。classChoice = randi(numel(classes));求所选类中所有观测值的指数。idxs = find(classLabel==classes(classChoice));%从所选的类别中随机选择两个不同的图像。pairIdxChoice = randperm(numel(idxs),2);pairIdx1 = idxs(pairIdxChoice(1));pairIdx2 = idxs(pairIdxChoice(2));pairLabel = 1;结束

得到不相似对函数

getDissimilarPair函数返回不同类别图像的随机索引对,并且不相似对标签等于0。

函数[pairIdx1,pairIdx2,label] = getDissimilarPair(classLabel)找到所有独特的类。class = unique(classLabel);随机选择两个不同的类别,这两个类别将被用来获得a%不相似对。classesChoice = randperm(nummel (classes),2);求第一次和第二次所有观测值的指数%的类。idxs1 = find(classLabel==classes(classesChoice(1)));idxs2 = find(classLabel==classes(classesChoice(2)));%从每个类别中随机选择一张图像。pairidx1 = randi(nummel (idxs1));pairIdx2Choice = randi(nummel (idxs2));pairIdx1 = idxs1(pairIdx1Choice);pairIdx2 = idxs2(pairIdx2);Label = 0;结束

参考文献

[1] Bromley, J., I. Guyon, Y. LeCun, E. Säckinger和R. Shah。使用“暹罗”时延神经网络的签名验证。第六届神经信息处理系统国际会议论文集(NIPS 1993), 1994, pp737-744。可以在使用“暹罗”时延神经网络的签名验证在NIPS学报网站上。

[2] Wenpeg, Y.和H . sch兹。”用于意译识别的卷积神经网络。《2015年美国ACL北美分会会议论文集》,2015年,第901-911页。可以在用于意译识别的卷积神经网络在ACL选集网站上

[3] Lake, b.m., Salakhutdinov, R.和Tenenbaum, j.b.。基于概率程序归纳法的人类概念学习。科学,350(6266),(2015)pp1332-1338。

[4] Koch, G., Zemel, R., and Salakhutdinov, R.(2015)。用于一次性图像识别的连体神经网络”。第32届国际机器学习会议论文集,37(2015)。可以在用于一次性图像识别的连体神经网络在ICML'15网站上。

另请参阅

||||

相关的话题