主要内容

使用深度学习分类文本数据

这个例子展示了如何使用深度学习长短期记忆(LSTM)网络对文本数据进行分类。

文本数据自然是顺序的。一段文本是一组单词,它们之间可能存在依赖关系。为了学习和使用长期依赖关系对序列数据进行分类,使用LSTM神经网络。LSTM网络是一种循环神经网络(RNN),可以学习序列数据时间步长之间的长期依赖关系。

要向LSTM网络输入文本,首先要将文本数据转换为数字序列。您可以使用将文档映射到数字索引序列的word编码来实现这一点。为了获得更好的效果,还可以在网络中加入单词嵌入层。词嵌入将词汇表中的词映射到数字向量,而不是标量索引。这些嵌入捕获单词的语义细节,因此具有相似含义的单词具有相似的向量。他们还通过向量运算建立单词之间的关系。例如,关系"罗马之于意大利如同巴黎去法国用等式意大利来描述- - - - - -罗马+巴黎=法国。

在这个例子中,训练和使用LSTM网络有四个步骤:

  • 导入并预处理数据。

  • 使用单词编码将单词转换为数字序列。

  • 用词嵌入层创建并训练LSTM网络。

  • 使用训练好的LSTM网络对新的文本数据进行分类。

导入数据

导入工厂报表数据。此数据包含工厂事件的标记文本描述。若要将文本数据作为字符串导入,请指定文本类型“字符串”

文件名=“factoryReports.csv”;数据= readtable(文件名,“TextType”“字符串”);头(数据)
ans =8×5表类别描述紧急解决成本  _____________________________________________________________________ ____________________ ________ ____________________ _____ " 项目是偶尔陷入扫描仪卷。”“机械故障”“中等”“重新调整机”45“装配活塞发出巨大的嘎嘎声和砰砰声。”“机械故障”“中等”“重新调整机器”“开机时电源有故障。”"电子故障" "高" "完全更换" 16200 "装配器电容器烧坏"“电子故障”“高”“更换部件”“352”“混合器熔断器触发。”"电子故障" "低" "列入观察名单" 55 "爆裂管道中施工剂正在喷洒冷却剂""泄漏" "高" "更换部件" 371 "搅拌机保险丝烧断。"“电子故障”“信号低”“更换部件”“东西继续从传送带上掉下来。”"机械故障" "低" "调整机

类中的标签对事件进行分类类别列。要将数据划分为类,请将这些标签转换为类别。

数据。Category = categorical(data.Category);

使用直方图查看数据中类的分布。

图直方图(data.Category);包含(“类”) ylabel (“频率”)标题(“类分配”

下一步是将其划分为用于训练和验证的集。将数据划分为一个训练分区和一个用于验证和测试的预留分区。指定坚持百分比为20%。

CVP = cvpartition(数据。类别,“坚持”, 0.2);dataTrain = data(training(cvp),:);dataValidation = data(test(cvp),:);

从已分区的表中提取文本数据和标签。

textDataTrain = dataTrain.Description;textDataValidation = datavalid . description;YTrain = dataTrain.Category;YValidation = datavalid . category;

要检查是否正确导入了数据,请使用单词云可视化训练文本数据。

图wordcloud (textDataTrain);标题(“训练数据”

预处理文本数据

创建一个对文本数据进行标记和预处理的函数。这个函数preprocessText,在示例末尾列出,执行以下步骤:

  1. 使用标记化文本tokenizedDocument

  2. 使用将文本转换为小写较低的

  3. 删除标点符号erasePunctuation

对训练数据和验证数据进行预处理preprocessText函数。

documentsTrain = preprocessText(textDataTrain);documentsValidation = preprocessText(textDataValidation);

查看前几个预处理的培训文档。

documentsTrain (1:5)
ans = 5×1 tokenizedDocument: 9令牌:项目偶尔卡在扫描仪线轴10令牌:响亮的嘎嘎声和敲打的声音来自组装活塞10令牌:启动工厂时有切断电源5令牌:在组装机炸电容器4令牌:搅拌机绊倒保险丝

将文档转换为序列

要将文档输入LSTM网络,请使用单词编码将文档转换为数字索引序列。

要创建单词编码,请使用wordEncoding函数。

enc = worddencoding (documentsTrain);

下一个转换步骤是填充和截断文档,使它们都具有相同的长度。的trainingOptions函数提供自动填充和截断输入序列的选项。但是,这些选项不太适合字向量序列。相反,手动填充和截断序列。如果你left-pad并截断词向量序列,可以提高训练效果。

要填充和截断文档,首先选择一个目标长度,然后截断比它长的文档,左填充比它短的文档。为了获得最佳结果,目标长度应该很短,而不会丢弃大量数据。要找到合适的目标长度,请查看训练文档长度的直方图。

documentlength = documentlength (documentsTrain);直方图(documentlength)“文档长度”)包含(“长度”) ylabel (“文件数量”

大多数培训文档的标记少于10个。使用它作为截断和填充的目标长度。

将文档转换为数字索引序列doc2sequence.若要截断或左填充序列以使其长度为10,请设置“长度”选项10。

sequenceLength = 10;XTrain = doc2sequence(enc,documentsTrain,“长度”, sequenceLength);XTrain (1:5)
ans =5×1单元格数组{1×10 double} {1×10 double} {1×10 double} {1×10 double} {1×10 double}

使用相同的选项将验证文档转换为序列。

XValidation = doc2sequence(enc,documentsValidation,“长度”, sequenceLength);

创建并训练LSTM网络

定义LSTM网络体系结构。为了将序列数据输入到网络中,包括一个序列输入层,并将输入大小设置为1。接下来,包含一个维度为50的单词嵌入层,与单词编码的单词数量相同。接下来,包括一个LSTM层,并设置隐藏单元的数量为80。若要将LSTM层用于从序列到标签的分类问题,请将输出模式设置为“最后一次”.最后,添加一个与类数量相同大小的全连接层,一个softmax层和一个classification层。

inputSize = 1;embeddingDimension = 50;numHiddenUnits = 80;numWords = c. numWords;numClasses = nummel(类别(YTrain));层= [...sequenceInputLayer(inputSize) wordEmbeddingLayer(embeddingDimension,numWords) lstmLayer(numHiddenUnits,“OutputMode”“最后一次”fullyConnectedLayer(numClasses) softmaxLayer classificationLayer
2”字嵌入层字嵌入层50维423个唯一单词3”LSTM LSTM 80个隐藏单元4”全连接4全连接层5”Softmax Softmax 6”分类输出crossentropyex

指定培训项目

指定培训选项:

  • 使用亚当求解器训练。

  • 指定一个小批大小为16。

  • 对每个纪元的数据进行洗牌。

  • 通过设置监控培训进度“阴谋”选项“训练进步”

  • 属性指定验证数据“ValidationData”选择。

  • 属性抑制详细输出“详细”选项

默认情况下,trainNetwork使用GPU(如果有的话)。否则,占用CPU。若要手动指定执行环境,请使用“ExecutionEnvironment”的名称-值对参数trainingOptions.在CPU上进行训练比在GPU上进行训练花费的时间要长得多。使用GPU进行训练需要并行计算工具箱™和受支持的GPU设备。万博1manbetx有关受支持设备的信息,请参见万博1manbetxGPU计算要求(并行计算工具箱)

选项= trainingOptions(“亚当”...“MiniBatchSize”, 16岁,...“GradientThreshold”2,...“洗牌”“every-epoch”...“ValidationData”{XValidation, YValidation},...“阴谋”“训练进步”...“详细”、假);

对LSTM网络进行训练trainNetwork函数。

net = trainNetwork(XTrain,YTrain,图层,选项);

使用新数据进行预测

对三个新报告的事件类型进行分类。创建包含新报告的字符串数组。

reportsNew = [...“冷却剂在分拣机下面。”“分拣机在启动时炸断保险丝。”“有一些非常响亮的咔嗒声从组装。”];

将预处理步骤作为训练文档对文本数据进行预处理。

documentsNew = preprocessText(reportsNew);

使用将文本数据转换为序列doc2sequence与创建训练序列时的选项相同。

XNew = doc2sequence(enc,documentsNew,“长度”, sequenceLength);

利用训练好的LSTM网络对新序列进行分类。

labelsNew =分类(net,XNew)
labelsNew =3×1分类泄漏电子故障机械故障

预处理功能

这个函数preprocessText执行以下步骤:

  1. 使用标记化文本tokenizedDocument

  2. 使用将文本转换为小写较低的

  3. 删除标点符号erasePunctuation

函数documents = preprocessText(textData)标记文本。documents = tokenizedDocument(textData);%转换为小写字母。文档=较低(文档);删除标点符号。documents = eraspunctuation(文档);结束

另请参阅

|||(深度学习工具箱)|(深度学习工具箱)|(深度学习工具箱)||(深度学习工具箱)|

相关的话题