使用深度学习构建High Five计数器
这篇文章来自布莱恩·道格拉斯,YouTube内容创建者,用于控制系统和深度学习应用程序
大约十年来,我一直想实现一个愚蠢的想法,即测量一个人的手的加速度,以计算他们一天中击掌的次数。我不确定如何使用我熟悉的基于规则的算法开发方法来实现这一点,因此项目被搁置。那只是在我做决定的时候关于深入学习的MATLAB技术谈话视频系列我意识到深度学习对于解决这个问题是完美的!
本系列第4期视频的主题是迁移学习,结果证明这是我需要的关键概念,让我快速建立并运行击掌计数算法。在这篇博文中,我将介绍我编写的代码的细节,以及我用来让我的击掌计数器工作的工具。希望你们可以用这个作为起点来解决那些困难的分类问题这些问题你们也在过去10年里一直在研究。
本文分为以下几部分:
那么,让我们开始吧!
硬件概述
硬件设置非常简单。我有一个加速计,它连接到微控制器通过I2C总线。然后Arduino通过USB连接到我的电脑。
为了感知加速度,我用的是MPU-9250. 这是TDK InvenSense的一个9自由度惯性测量单元。我没有将芯片集成到我自己的定制电路设计中,而是使用了分接板暴露电源、接地和I2C通信引脚。我使用这种特殊芯片的唯一原因是因为我已经有一个放在周围了,但是任何加速度计都可以工作,只要它足够小,可以用手快速移动。
你可以看到,我的硬件设置是相当粗略的,用一个面包板和一些跳线,但我认为这是一种不错的,你不需要设置任何太花哨的东西,以使它工作。
用MATLAB读取加速度计
要通过Arduino读取MPU-9250的加速度,我使用Arduino万博1manbetx硬件的MATLAB支持包. 该软件包允许您与Arduino通信,而无需为其编译代码。另外,还有一个内置的mpu9250允许您使用单行命令读取传感器的功能。
连接Arduino、实例化MPU9250对象和读取加速计只需要三行代码。
数据预处理与标度图
如果您观看了关于深度学习的技术讲座系列的第四个视频,您就会知道我选择将三轴加速度数据转换为图像以利用GoogLeNet-经过训练以识别图像的网络。特别是,我使用了连续小波变换创建比例图.
标度图是一种适用于存在于多个标度的信号的时频表示。也就是说,信号频率低且变化缓慢,但偶尔会因高频瞬态而中断。事实证明,它们对于可视化缓慢移动的手内偶尔出现的高频五点加速数据非常有用。
下面的可折叠块是我用来制作上述图的MATLAB代码的一个清理版本。
全关%如果您的计算机无法实时运行此操作,请减少示例%对比例图部分进行评分或注释fs=50;% 50hz运行a = arduino('COM3', 'Uno', 'Libraries', 'I2C');%换上你的arduinoimu=mpu9250(a);缓冲区长度秒=2;%要存储在缓冲区中的数据秒数Accel = 0 (floor(buffer_length_sec * fs) + 1,3);%初始化缓冲区t=0:1/fs:(缓冲区长度秒(结束));%时间向量Subplot (2,1,1) plot_accel = plot(t, accel);%设置加速曲线图轴([0,缓冲区长度秒,-50,50]);子图(2,1,2)绘图_比例=图像(零(224,224,3));%设置比例图抽搐%启动计时器上次读取时间=0;i=0;%跑20秒While (toc <= 20) current_read_time = toc;If (current_read_time - last_read_time) >= 1/fs I = I + 1;Accel (1:end-1,:) = Accel (2:end,:);在FIFO缓冲区中移位值accel(end,:) = readAcceleration(imu);plot_accel(1)。YData = accel(:, 1);plot_accel(2)。YData = accel(:, 2);plot_accel(3)。YData = accel(:, 3);%仅每三次采样运行一次比例图,以节省计算时间if mod(i, 3) == 0 fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs,…“VoicesPerOctave”,12);Sig = accel(:, 1);[cfs, ~] = wt(fb, sig);cfs_abs = abs (cfs);Accel_i = imresize(cfs_abs/8, [224 224]);fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs,…“VoicesPerOctave”,12);Sig = accel(:, 2);[cfs, ~] = wt(fb, sig); cfs_abs = abs(cfs); accel_i(:, :, 2) = imresize(cfs_abs/8, [224 224]); fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs, ... 'VoicesPerOctave', 12); sig = accel(:, 3); [cfs, ~] = wt(fb, sig); cfs_abs = abs(cfs); accel_i(:, :, 3) = imresize(cfs_abs/8, [224 224]); if~(isempty(accel_i(accel_i>1))) accel_i(accel_i>1) = 1; end plot_scale.CData = accel_i; end last_read_time = current_read_time; end end
请注意,此代码使用一个名为cwtfilterbank要创建比例图,该比例图是小波工具箱. 如果您无法访问此工具箱,并且不想自己编写代码,请尝试尝试使用另一种类型的时频可视化。也许是光谱图或者你想出的其他算法。无论你选择什么,这里的想法是,我们试图创建一个图像,使high five图案的独特和可识别特征变得明显。我已经证明了比例图是可行的,但其他方法也可以。
创建培训数据
为了训练人际网络识别击掌,我们需要多个击掌的例子来说明击掌是什么样子,击掌不是什么样子。因为我们将从一个预先培训过的网络开始,所以我们不需要像从头开始培训网络那样多的培训示例。我不知道到底需要多少培训数据才能完全捕获所有可能的击掌游戏的解决方案空间,但是,我收集了100个击掌游戏和100个非击掌游戏的数据,这似乎效果很好。我怀疑我制作的视频可以少一些,但我想如果我真的在制作一个产品,我会用更多的例子。您可以处理标记的训练数据量,并查看其对结果的影响。
收集200张图像似乎需要很多工作,但我编写了一个脚本,一个接一个地循环浏览,并将图像保存在适当的文件夹中。我运行了以下脚本两次;一旦使用“high_five”标签,图像将保存到数据/高_五文件夹一旦有了'no_high_five'标签图像被保存到数据/无高5文件夹
%此脚本收集训练数据并将其放置在指定的位置%标签子文件夹。从中收集3秒钟的数据%传感器,但仅保留并保存最后2秒。%这给了用户一些缓冲时间来启动high five。%程序在图像之间暂停并提示用户继续。%注意,您需要将图形从MATLAB窗口移开,以便%您可以在响应等待提示后看到加速度。全关,全关%如果您的计算机无法实时运行,请降低采样率fs=50;% 50hz运行parentDir=pwd;dataDir='data';%%为正在生成的数据设置标签%标签= 'no_high_five';标签='high_five';a=arduino('COM3'、'Uno'、'Libraries'、'I2C');%换上你的arduinoimu=mpu9250(a);缓冲区长度秒=2;%要存储在缓冲区中的数据秒数Accel = 0 (floor(buffer_length_sec * fs) + 1,3);%初始化缓冲区t=0:1/fs:(缓冲区长度秒(结束));%时间向量Subplot (2,1,1) plot_accel = plot(t, accel);%设置加速曲线图轴([0 buffer_length_sec -50 50]);Subplot (2,1,2) plot_scale = image(0 (224, 224, 3));%设置比例图对于j=1:100%收集100张图片%提示用户准备好录制下一个击掌H=输入('准备就绪时按回车键:');抽搐%启动计时器上次读取时间=0;i=0;%跑3秒钟而(toc<=3)当前读取时间=toc;如果(当前读取时间-上次读取时间)>=1/fs i=i+1;加速度(1:end-1,:)=加速度(2:end,:);%在缓冲区中移动值accel(end,:) = readAcceleration(imu);plot_accel(1)。YData = accel(:, 1);plot_accel(2)。YData = accel(:, 2);plot_accel(3)。YData = accel(:, 3);%每三次采样运行一次比例图if mod(i, 3) == 0 fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs,…“VoicesPerOctave”,12);Sig = accel(:, 1);[cfs, ~] = wt(fb, sig);cfs_abs = abs (cfs);Accel_i = imresize(cfs_abs/8, [224 224]);fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs,…“VoicesPerOctave”,12);Sig = accel(:, 2);[cfs, ~] = wt(fb, sig); cfs_abs = abs(cfs); accel_i(:, :, 2) = imresize(cfs_abs/8, [224 224]); fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs, ... 'VoicesPerOctave', 12); sig = accel(:, 3); [cfs, ~] = wt(fb, sig); cfs_abs = abs(cfs); accel_i(:, :, 3) = imresize(cfs_abs/8, [224 224]); if~(isempty(accel_i(accel_i>1))) accel_i(accel_i>1) = 1; end plot_scale.CData = accel_i; end last_read_time = current_read_time; end end%保存图像到数据文件夹imageRoot=fullfile(parentDir、dataDir);imgLoc=fullfile(imageRoot,char(标签));imFileName=strcat(char(labels),“uu”,num2str(j),“.jpg”);imwrite(plot_scale.CData,fullfile(imgLoc,imFileName),“JPEG”);终止
运行脚本后,我手动检查了培训数据并删除了我认为会破坏培训的图像。这些图片是高五不在框架或图像中间,我知道我做了一个可怜的高五运动。在下面的gif中,我删除了high five image 49,因为它不在帧的中心。
迁移学习与谷歌网
我的所有培训数据都在相应的文件夹中,下一步就是建立网络。对于这一部分,我将跟随MATLAB示例基于小波分析和深度学习的时间序列分类,但是,我发现使用MATLAB脚本设置和训练网络更容易,而不是通过MATLAB脚本运行所有内容深度网络设计器应用程序。
我从预先训练好的GoogLeNet开始,利用这个网络在识别图像中物体方面的所有知识。GoogLeNet受过训练,能够识别图像中的鱼和热狗之类的东西——显然不是我要找的东西——但这正是迁移学习有用的地方。通过迁移学习,我可以保留大部分现有网络,并且只替换网络末端的两层,这两层将这些通用特性结合到我正在寻找的特定模式中。然后,当我对网络进行再培训时,几乎只需要对这两个层次进行培训,这就是为什么转移学习的培训速度要快得多。
如果你想确切地知道我是如何替换图层的,以及我使用了什么样的训练参数,我建议你跟随我使用的MATLAB示例,或者观看技术讲座。然而,这里还是一个很好的地方,可以让你尝试一些不同的东西。您可以尝试从一个不同的预训练网络开始,如SqueezeNet,也可以在GoogLeNet中替换更多层,或者更改训练参数。这里有很多选择,我认为偏离我所做的可以帮助你对所有这些变量如何影响结果产生一些直觉。
培训网络
网络已经准备就绪,在深度网络设计师应用程序中的培训非常简单。在data选项卡中,我通过选择保存击掌和不击掌图像集的文件夹导入训练数据。我还留出20%的图像用于训练过程中的验证。
然后在“培训”选项卡中,我设置了培训选项。在这里,我使用的选项与下面的MATLAB示例中使用的选项相同,但是,我再次鼓励您使用其中一些值,看看它们如何影响结果。
在我的单个CPU上,培训只花了4分钟多,验证准确率达到了97%左右。对几个小时的工作来说还不错!
测试高五计数器
现在我有了一个训练有素的网络,我使用了这个函数分类从深度学习工具箱中,在每次采样时传入比例图,并让网络返回标签。如果返回的标签是“high_five”,我将增加一个计数器。为了避免在加速度数据在整个缓冲区中划过时多次计算同一个high five,我添加了一个超时,该超时不会计算新的high five,除非它距离上一个high five至少2秒。
下面是我用来击掌的代码的清理版本。
全关%%更新到已培训网络的名称负载训练网络=训练网络;%如果您的计算机无法实时运行此操作,请减少示例%对比例图部分进行评分或注释fs=50;% 50hz运行a = arduino('COM3', 'Uno', 'Libraries', 'I2C');%换上你的arduinoimu=mpu9250(a);缓冲区长度秒=2;%要存储在缓冲区中的数据秒数Accel = 0 (floor(buffer_length_sec * fs) + 1,3);%初始化缓冲区t=0:1/fs:(缓冲区长度秒(结束));%时间向量%设置图h=数字;h、 位置=[100 900 700];p1=子批次(2,1,1);曲线图加速度=曲线图(t,加速度);绘图加速度(1)。线宽=3;绘图加速度(2)。线宽=3;绘图加速度(3)。线宽=3;p1.FontSize=20;p1.Title.String=‘加速度’;轴线([0T(末端)-5060]);xlabel(‘秒’);ylabel(“加速度,mpss”);网格化;label_string=文本(1.3,45,'无高五');label_string.解释器='none';label_string.FontSize=25;count_string=text(0.1,45,'高位五位计数器:');count_string.解释器='none';count_string.FontSize=15;val_string=text(0.65,45,'0');val_string.解释器='none';val_string.FontSize=15;p2=子批次(2,1,2);缩放加速度=图像(零(2242243));p2.Title.String='scalegram';p2.FontSize=20;末端=0;hfcount=0;抽搐%启动计时器上次读取时间=0;i=0;%运行high five计数器20秒而(toc<=20)当前读取时间=toc;如果(当前读取时间-上次读取时间)>=1/fs i=i+1;末端=末端+1;%读加速度Accel (1:end-1,:) = Accel (2:end,:);在FIFO缓冲区中移位值accel(end,:) = readAcceleration(imu);plot_accel(1)。YData = accel(:, 1);plot_accel(2)。YData = accel(:, 2);plot_accel(3)。YData = accel(:, 3);%仅每三次采样运行一次比例图,以节省计算时间如果mod(i,3)=0%比例图fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs,…“VoicesPerOctave”,12);Sig = accel(:, 1);[cfs, ~] = wt(fb, sig);cfs_abs = abs (cfs);Accel_i = imresize(cfs_abs/8, [224 224]);fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs,…“VoicesPerOctave”,12);Sig = accel(:, 2);[cfs, ~] = wt(fb, sig); cfs_abs = abs(cfs); accel_i(:, :, 2) = imresize(cfs_abs/8, [224 224]); fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs, ... 'VoicesPerOctave', 12); sig = accel(:, 3); [cfs, ~] = wt(fb, sig); cfs_abs = abs(cfs); accel_i(:, :, 3) = imresize(cfs_abs/8, [224 224]);%将像素饱和为1If ~(isempty(accel_i(accel_i>1))) accel_i(accel_i>1) = 1;scale_accel结束。CData = im2uint8 (accel_i);%分类比例图[YPred,probs]=分类(trainedNetwork,scale\u accel.CData);如果strcmp(string(YPred),'high_five')label_string.BackgroundColor=[1 0 0];label_string.string=“High Five!”;%仅当100个样本从上一个high five之后已过时才计算如果终端>100 hfcount=hfcount+1;val_string.string=字符串(hfcount);末端=0;end else标签_string.BackgroundColor=[1];label_string.string=“无高五”;结束
现在它开始行动了!
- 类别:
- 深度学习
评论
如需留言,请点击在这里登录到您的MathWorks帐户或创建新帐户。