
Modulation Classification with Deep Learning

This example shows how to use a convolutional neural network (CNN) for modulation classification. You generate synthetic, channel-impaired waveforms. Using the generated waveforms as training data, you train a CNN for modulation classification. You then test the CNN with software-defined radio (SDR) hardware and over-the-air signals.

Predict Modulation Type Using CNN

The trained CNN in this example recognizes these eight digital and three analog modulation types:

  • Binary phase shift keying (BPSK)

  • Quadrature phase shift keying (QPSK)

  • 8-ary phase shift keying (8-PSK)

  • 16-ary quadrature amplitude modulation (16-QAM)

  • 64-ary quadrature amplitude modulation (64-QAM)

  • 4-ary pulse amplitude modulation (PAM4)

  • Gaussian frequency shift keying (GFSK)

  • Continuous phase frequency shift keying (CPFSK)

  • Broadcast FM (B-FM)

  • Double sideband amplitude modulation (DSB-AM)

  • Single sideband amplitude modulation (SSB-AM)

modulationTypes = categorical(["BPSK",“正交相移编码”,"8PSK",..."16QAM","64QAM","PAM4","GFSK","CPFSK",..."B-FM","DSB-AM","SSB-AM"]);

First, load the trained network. For details on network training, see the Training a CNN section.

load trainedModulationClassificationNetwork trainedNet
trainedNet = 
  SeriesNetwork with properties:

         Layers: [28×1 nnet.cnn.layer.Layer]
     InputNames: {'Input Layer'}
    OutputNames: {'Output'}

The trained CNN takes 1024 channel-impaired samples and predicts the modulation type of each frame. Generate several PAM4 frames that are impaired with Rician multipath fading, center frequency and sampling time drift, and AWGN. Use the following functions to generate synthetic signals to test the CNN. Then use the CNN to predict the modulation type of the frames.

  • randi: Generate random bits

  • pammod (Communications Toolbox): PAM4-modulate the bits

  • rcosdesign (Signal Processing Toolbox): Design a square-root raised cosine pulse shaping filter

  • filter: Pulse shape the symbols

  • comm.RicianChannel (Communications Toolbox): Apply Rician multipath channel

  • comm.PhaseFrequencyOffset (Communications Toolbox): Apply phase and/or frequency shift due to clock offset

  • interp1: Apply timing drift due to clock offset

  • awgn (Communications Toolbox): Add AWGN

% Set the random number generator to a known state to be able to regenerate
% the same frames every time the simulation is run
rng(123456)
% Random bits
d = randi([0 3], 1024, 1);
% PAM4 modulation
syms = pammod(d,4);
% Square-root raised cosine filter
filterCoeffs = rcosdesign(0.35,4,8);
tx = filter(filterCoeffs,1,upsample(syms,8));
% Channel
SNR = 30;
maxOffset = 5;
fc = 902e6;
fs = 200e3;
multipathChannel = comm.RicianChannel( ...
  'SampleRate', fs, ...
  'PathDelays', [0 1.8 3.4] / 200e3, ...
  'AveragePathGains', [0 -2 -10], ...
  'KFactor', 4, ...
  'MaximumDopplerShift', 4);
frequencyShifter = comm.PhaseFrequencyOffset( ...
  'SampleRate', fs);
% Apply an independent multipath channel
reset(multipathChannel)
outMultipathChan = multipathChannel(tx);
% Determine clock offset factor
clockOffset = (rand() * 2*maxOffset) - maxOffset;
C = 1 + clockOffset / 1e6;
% Add frequency offset
frequencyShifter.FrequencyOffset = -(C-1)*fc;
outFreqShifter = frequencyShifter(outMultipathChan);
% Add sampling time drift
t = (0:length(tx)-1)' / fs;
newFs = fs * C;
tp = (0:length(tx)-1)' / newFs;
outTimeDrift = interp1(t, outFreqShifter, tp);
% Add noise
rx = awgn(outTimeDrift,SNR,0);
% Frame generation for classification
unknownFrames = helperModClassGetNNFrames(rx);
% Classification
[prediction1,score1] = classify(trainedNet,unknownFrames);

prediction1 returns the classifier predictions, which are analogous to hard decisions. The network correctly identifies the frames as PAM4 frames. For details on the generation of the modulated signals, see the helperModClassGetModulator function.

prediction1
prediction1 = 7×1 categorical
     PAM4
     PAM4
     PAM4
     PAM4
     PAM4
     PAM4
     PAM4

The classifier also returns a vector of scores for each frame. The score corresponds to the probability that each frame has the predicted modulation type. Plot the scores.

helperModClassPlotScores(score1,modulationTypes)

Before we can use a CNN for modulation classification, or any other task, we first need to train the network with known (or labeled) data. The first part of this example shows how to use Communications Toolbox™ features, such as modulators, filters, and channel impairments, to generate synthetic training data. The second part focuses on defining, training, and testing the CNN for the task of modulation classification. The third part tests the network performance with over-the-air signals using software defined radio (SDR) platforms.

Waveform Generation for Training

Generate 10,000 frames for each modulation type, where 80% is used for training, 10% is used for validation and 10% is used for testing. We use training and validation frames during the network training phase. Final classification accuracy is obtained using test frames. Each frame is 1024 samples long and has a sample rate of 200 kHz. For digital modulation types, eight samples represent a symbol. The network makes each decision based on single frames rather than on multiple consecutive frames (as in video). Assume a center frequency of 902 MHz and 100 MHz for the digital and analog modulation types, respectively.

To run this example quickly, use the trained network and generate a small number of training frames. To train the network on your computer, choose the "Train network now" option (i.e. set trainNow to true).

trainNow = false;
if trainNow == true
  numFramesPerModType = 10000;
else
  numFramesPerModType = 200;
end
percentTrainingSamples = 80;
percentValidationSamples = 10;
percentTestSamples = 10;

sps = 8;                % Samples per symbol
spf = 1024;             % Samples per frame
symbolsPerFrame = spf / sps;
fs = 200e3;             % Sample rate
fc = [902e6 100e6];     % Center frequencies

Create Channel Impairments

Pass each frame through a channel with

  • AWGN

  • Rician multipath fading

  • Clock offset, resulting in center frequency offset and sampling time drift

Because the network in this example makes decisions based on single frames, each frame must pass through an independent channel.

AWGN

The channel adds AWGN with an SNR of 30 dB. Implement the channel using the awgn (Communications Toolbox) function.
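
For reference, a single frame could be impaired on its own with a call along these lines. This is a minimal sketch; txFrame is a hypothetical complex baseband frame, and the combined channel object introduced later applies the same impairment internally.

% Minimal sketch: add white Gaussian noise at an SNR of 30 dB, assuming a
% signal power of 0 dBW (txFrame is a hypothetical complex baseband frame)
SNR = 30;
rxNoisy = awgn(txFrame, SNR, 0);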

Rician Multipath

The channel passes the signals through a Rician multipath fading channel using the comm.RicianChannel (Communications Toolbox) System object™. Assume a delay profile of [0 1.8 3.4] samples with corresponding average path gains of [0 -2 -10] dB. The K-factor is 4 and the maximum Doppler shift is 4 Hz, which is equivalent to a walking speed at 902 MHz. Implement the channel with the following settings.
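
A channel configured with these settings could look like the following sketch; the helperModClassTestChannel object used later applies the same parameters internally, and txFrame is a hypothetical input frame.

% Minimal sketch of the Rician multipath channel described above
multipathChannel = comm.RicianChannel( ...
    'SampleRate', fs, ...
    'PathDelays', [0 1.8 3.4] / fs, ...
    'AveragePathGains', [0 -2 -10], ...
    'KFactor', 4, ...
    'MaximumDopplerShift', 4);
fadedFrame = multipathChannel(txFrame);   % txFrame: hypothetical baseband frame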

Clock Offset

Clock offset occurs because of the inaccuracies of the internal clock sources of transmitters and receivers. Clock offset causes the center frequency, which is used to downconvert the signal to baseband, and the digital-to-analog converter sampling rate to differ from the ideal values. The channel simulator uses the clock offset factor C, expressed as C = 1 + Δclock/10^6, where Δclock is the clock offset. For each frame, the channel generates a random Δclock value from a uniformly distributed set of values in the range [-maxΔclock, maxΔclock], where maxΔclock is the maximum clock offset. Clock offset is measured in parts per million (ppm). For this example, assume a maximum clock offset of 5 ppm.

maxDeltaOff = 5;
deltaOff = (rand()*2*maxDeltaOff) - maxDeltaOff;
C = 1 + (deltaOff/1e6);

Frequency Offset

Subject each frame to a frequency offset based on clock offset factor C and the center frequency. Implement the channel using comm.PhaseFrequencyOffset (Communications Toolbox).
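
A minimal sketch of this step, assuming C and fc from the clock offset discussion and a hypothetical input frame inFrame:

% Frequency offset implied by the clock offset factor C at center frequency fc
frequencyShifter = comm.PhaseFrequencyOffset('SampleRate', fs);
frequencyShifter.FrequencyOffset = -(C - 1)*fc;   % offset in Hz
outFreqShifted = frequencyShifter(inFrame);       % inFrame: hypothetical baseband frame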

Sampling Rate Offset

Subject each frame to a sampling rate offset based on clock offset factor C. Implement the channel using the interp1 function to resample the frame at the new rate of C × fs.
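
A minimal sketch of the resampling step, again assuming a hypothetical input frame inFrame:

% Resample the frame at the drifted rate C*fs using linear interpolation
t      = (0:length(inFrame)-1)' / fs;        % original sample instants
tDrift = (0:length(inFrame)-1)' / (fs*C);    % sample instants at the offset rate
outTimeDrift = interp1(t, inFrame, tDrift);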

Combined Channel

Use the helperModClassTestChannel object to apply all three channel impairments to the frames.

channel = helperModClassTestChannel( ...
  'SampleRate', fs, ...
  'SNR', SNR, ...
  'PathDelays', [0 1.8 3.4] / fs, ...
  'AveragePathGains', [0 -2 -10], ...
  'KFactor', 4, ...
  'MaximumDopplerShift', 4, ...
  'MaximumClockOffset', 5, ...
  'CenterFrequency', 902e6)
channel = 
  helperModClassTestChannel with properties:

                    SNR: 30
        CenterFrequency: 902000000
             SampleRate: 200000
             PathDelays: [0 9.0000e-06 1.7000e-05]
       AveragePathGains: [0 -2 -10]
                KFactor: 4
    MaximumDopplerShift: 4
     MaximumClockOffset: 5

You can view basic information about the channel using the info object function.

chInfo = info(channel)
chInfo = struct with fields:
               ChannelDelay: 6
     MaximumFrequencyOffset: 4510
    MaximumSampleRateOffset: 1

Waveform Generation

Create a loop that generates channel-impaired frames for each modulation type and stores the frames with their corresponding labels in MAT files. By saving the data into files, you eliminate the need to generate the data every time you run this example. You can also share the data more effectively.

Remove a random number of samples from the beginning of each frame to remove transients and to make sure that the frames have a random starting point with respect to the symbol boundaries.

% Set the random number generator to a known state to be able to regenerate
% the same frames every time the simulation is run
rng(1235)
tic

numModulationTypes = length(modulationTypes);

channelInfo = info(channel);
transDelay = 50;
dataDirectory = fullfile(tempdir,"ModClassDataFiles");
disp("Data file directory is " + dataDirectory)
Data file directory is C:\Users\kkearney\AppData\Local\Temp\ModClassDataFiles
fileNameRoot ="frame";% Check if data files existdataFilesExist = false;ifexist(dataDirectory,“dir”) files = dir(fullfile(dataDirectory,sprintf("%s*",fileNameRoot)));iflength(files) == numModulationTypes*numFramesPerModType dataFilesExist = true;endendif~dataFilesExist disp("Generating data and saving in data files...") [success,msg,msgID] = mkdir(dataDirectory);if~success error(msgID,msg)endformodType = 1:numModulationTypes elapsedTime = seconds(toc); elapsedTime.Format ='hh:mm:ss'; fprintf('%s - Generating %s frames\n',...elapsedTime, modulationTypes(modType)) label = modulationTypes(modType); numSymbols = (numFramesPerModType / sps); dataSrc = helperModClassGetSource(modulationTypes(modType), sps, 2*spf, fs); modulator = helperModClassGetModulator(modulationTypes(modType), sps, fs);ifcontains(char(modulationTypes(modType)), {'B-FM','DSB-AM','SSB-AM'})%模拟调制types use a center frequency of 100 MHzchannel.CenterFrequency = 100e6;else% Digital modulation types use a center frequency of 902 MHzchannel.CenterFrequency = 902e6;endforp=1:numFramesPerModType% Generate random datax = dataSrc();% Modulatey = modulator(x);% Pass through independent channelsrxSamples = channel(y);% Remove transients from the beginning, trim to size, and normalizeframe = helperModClassFrameGenerator(rxSamples, spf, spf, transDelay, sps);% Save data filefileName = fullfile(dataDirectory,...sprintf("%s%s%03d",fileNameRoot,modulationTypes(modType),p)); save(fileName,"frame","label")endendelsedisp("Data files exist. Skip data generation.")end
Generating data and saving in data files...
00:00:00 - Generating BPSK frames
00:00:01 - Generating QPSK frames
00:00:03 - Generating 8PSK frames
00:00:04 - Generating 16QAM frames
00:00:06 - Generating 64QAM frames
00:00:07 - Generating PAM4 frames
00:00:09 - Generating GFSK frames
00:00:10 - Generating CPFSK frames
00:00:11 - Generating B-FM frames
00:00:26 - Generating DSB-AM frames
00:00:27 - Generating SSB-AM frames
% Plot the amplitude of the real and imaginary parts of the example frames
% against the sample number
helperModClassPlotTimeDomain(dataDirectory,modulationTypes,fs)

Figure: time-domain plots (real and imaginary parts) of example frames for each of the 11 modulation types (BPSK, QPSK, 8PSK, 16QAM, 64QAM, PAM4, GFSK, CPFSK, B-FM, DSB-AM, SSB-AM).

% Plot the spectrogram of the example frames
helperModClassPlotSpectrogram(dataDirectory,modulationTypes,fs,sps)

Figure: spectrograms of example frames for the same 11 modulation types.

Create a Datastore

Use a signalDatastore object to manage the files that contain the generated complex waveforms. Datastores are especially useful when each individual file fits in memory, but the entire collection does not necessarily fit.

frameDS = signalDatastore(dataDirectory,'SignalVariableNames',["frame","label"]);

Transform Complex Signals to Real Arrays

The deep learning network in this example expects real inputs while the received signal has complex baseband samples. Transform the complex signals into real-valued 4-D arrays. The output frames have size 1-by-spf-by-2-by-N, where the first page (3rd dimension) is in-phase samples and the second page is quadrature samples. When the convolutional filters are of size 1-by-spf, this approach ensures that the information in the I and Q gets mixed even in the convolutional layers and makes better use of the phase information. See helperModClassIQAsPages for details.

frameDSTrans = transform(frameDS,@helperModClassIQAsPages);
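
For illustration, a transform with the layout described above could be written roughly as follows. This is a sketch under the stated assumptions, not the shipped helperModClassIQAsPages implementation.

function out = iqAsPagesSketch(in)
% in is the cell array returned by the datastore read: {complexFrame, label}
frame = in{1};
spf   = numel(frame);
iqFrame = zeros(1, spf, 2, 'like', real(frame));
iqFrame(1,:,1) = reshape(real(frame), 1, spf);   % first page: in-phase samples
iqFrame(1,:,2) = reshape(imag(frame), 1, spf);   % second page: quadrature samples
out = {iqFrame, in{2}};
end

Applying such a transform frame by frame, and concatenating over the fourth dimension, yields the 1-by-spf-by-2-by-N array that the network consumes.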

Split into Training, Validation, and Test

Next, divide the frames into training, validation, and test data. See helperModClassSplitData for details.

splitPercentages = [percentTrainingSamples,percentValidationSamples,percentTestSamples];
[trainDSTrans,validDSTrans,testDSTrans] = helperModClassSplitData(frameDSTrans,splitPercentages);
Starting parallel pool (parpool) using the 'local' profile ... Connected to the parallel pool (number of workers: 6).

Import Data into Memory

Network training is iterative. At every iteration, the datastore reads data from files and transforms the data before updating the network coefficients. If the data fits into the memory of your computer, importing the data from the files into the memory enables faster training by eliminating this repeated read-from-file and transform process. Instead, the data is read from the files and transformed once. Training this network using data files on disk takes about 110 minutes, while training using in-memory data takes about 50 minutes.

Import all the data in the files into memory. The files have two variables, frame and label, and each read call to the datastore returns a cell array, where the first element is the frame and the second element is the label. Use the transform functions helperModClassReadFrame and helperModClassReadLabel to read frames and labels. Use readall with the "UseParallel" option set to true to enable parallel processing of the transform functions, in case you have a Parallel Computing Toolbox™ license. Since the readall function, by default, concatenates the output of the read function over the first dimension, return the frames in a cell array and manually concatenate over the 4th dimension.

% Read the training and validation frames into the memory
pctExists = parallelComputingLicenseExists();
trainFrames = transform(trainDSTrans, @helperModClassReadFrame);
rxTrainFrames = readall(trainFrames,"UseParallel",pctExists);
rxTrainFrames = cat(4, rxTrainFrames{:});
validFrames = transform(validDSTrans, @helperModClassReadFrame);
rxValidFrames = readall(validFrames,"UseParallel",pctExists);
rxValidFrames = cat(4, rxValidFrames{:});

% Read the training and validation labels into the memory
trainLabels = transform(trainDSTrans, @helperModClassReadLabel);
rxTrainLabels = readall(trainLabels,"UseParallel",pctExists);
validLabels = transform(validDSTrans, @helperModClassReadLabel);
rxValidLabels = readall(validLabels,"UseParallel",pctExists);

Train the CNN

This example uses a CNN that consists of six convolution layers and one fully connected layer. Each convolution layer except the last is followed by a batch normalization layer, rectified linear unit (ReLU) activation layer, and max pooling layer. In the last convolution layer, the max pooling layer is replaced with an average pooling layer. The output layer has softmax activation. For network design guidance, see Deep Learning Tips and Tricks.

modClassNet = helperModClassCNN(modulationTypes,sps,spf);
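
For orientation, a layer array matching that description might look roughly like the following sketch. The filter counts, filter size, and pooling sizes here are illustrative assumptions rather than the exact values used by helperModClassCNN.

% Illustrative sketch of the architecture described above; filter counts,
% filter size, and pooling sizes are assumptions, not the helper's exact values
numModTypes = numel(modulationTypes);
filterSize  = [1 sps];
poolSize    = [1 2];
layersSketch = [
    imageInputLayer([1 spf 2], 'Normalization', 'none')

    convolution2dLayer(filterSize, 16, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(poolSize, 'Stride', [1 2])

    convolution2dLayer(filterSize, 24, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(poolSize, 'Stride', [1 2])

    convolution2dLayer(filterSize, 32, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(poolSize, 'Stride', [1 2])

    convolution2dLayer(filterSize, 48, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(poolSize, 'Stride', [1 2])

    convolution2dLayer(filterSize, 64, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(poolSize, 'Stride', [1 2])

    convolution2dLayer(filterSize, 96, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    averagePooling2dLayer([1 spf/32])   % collapse the remaining 1-by-32 map

    fullyConnectedLayer(numModTypes)
    softmaxLayer
    classificationLayer];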

Next, configure TrainingOptionsSGDM to use an SGDM solver with a mini-batch size of 256. Set the maximum number of epochs to 12, since a larger number of epochs provides no further training advantage. By default, the 'ExecutionEnvironment' property is set to 'auto', where the trainNetwork function uses a GPU if one is available or uses the CPU if not. To use the GPU, you must have a Parallel Computing Toolbox license. Set the initial learning rate to 2 × 10^-2. Reduce the learning rate by a factor of 10 every 9 epochs. Set 'Plots' to 'training-progress' to plot the training progress. On an NVIDIA® Titan Xp GPU, the network takes approximately 25 minutes to train.

maxEpochs = 12;
miniBatchSize = 256;
options = helperModClassTrainingOptions(maxEpochs,miniBatchSize, ...
  numel(rxTrainLabels),rxValidFrames,rxValidLabels);
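
For reference, the settings described above correspond roughly to a trainingOptions call like the following sketch. The shuffle setting and validation frequency are assumptions, and the actual helperModClassTrainingOptions may differ in detail.

% Illustrative sketch of the SGDM options described above; the shuffle setting
% and validation frequency are assumptions, not the helper's exact values
validationFrequency = floor(numel(rxTrainLabels)/miniBatchSize);
optionsSketch = trainingOptions('sgdm', ...
    'InitialLearnRate', 2e-2, ...
    'MaxEpochs', maxEpochs, ...
    'MiniBatchSize', miniBatchSize, ...
    'Shuffle', 'every-epoch', ...
    'Plots', 'training-progress', ...
    'Verbose', false, ...
    'ValidationData', {rxValidFrames, rxValidLabels}, ...
    'ValidationFrequency', validationFrequency, ...
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropPeriod', 9, ...
    'LearnRateDropFactor', 0.1);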

Either train the network or use the already trained network. By default, this example uses the trained network.

if trainNow == true
  elapsedTime = seconds(toc);
  elapsedTime.Format = 'hh:mm:ss';
  fprintf('%s - Training the network\n', elapsedTime)
  trainedNet = trainNetwork(rxTrainFrames,rxTrainLabels,modClassNet,options);
else
  load trainedModulationClassificationNetwork
end

As the plot of the training progress shows, the network converges in about 12 epochs to more than 95% accuracy.

Evaluate the trained network by obtaining the classification accuracy for the test frames. The results show that the network achieves about 94% accuracy for this group of waveforms.

elapsedTime = seconds(toc);
elapsedTime.Format = 'hh:mm:ss';
fprintf('%s - Classifying test frames\n', elapsedTime)
00:02:22 - Classifying test frames
% Read the test frames into the memory
testFrames = transform(testDSTrans, @helperModClassReadFrame);
rxTestFrames = readall(testFrames,"UseParallel",pctExists);
rxTestFrames = cat(4, rxTestFrames{:});

% Read the test labels into the memory
testLabels = transform(testDSTrans, @helperModClassReadLabel);
rxTestLabels = readall(testLabels,"UseParallel",pctExists);

rxTestPred = classify(trainedNet,rxTestFrames);
testAccuracy = mean(rxTestPred == rxTestLabels);
disp("Test accuracy: " + testAccuracy*100 + "%")
Test accuracy: 94.5455%

Plot the confusion matrix for the test frames. As the matrix shows, the network confuses 16-QAM and 64-QAM frames. This problem is expected since each frame carries only 128 symbols and 16-QAM is a subset of 64-QAM. The network also confuses QPSK and 8-PSK frames, since the constellations of these modulation types look similar once phase-rotated due to the fading channel and frequency offset.

figure
cm = confusionchart(rxTestLabels, rxTestPred);
cm.Title = 'Confusion Matrix for Test Data';
cm.RowSummary = 'row-normalized';
cm.Parent.Position = [cm.Parent.Position(1:2) 740 424];

Figure: row-normalized confusion matrix for the test data.

Test with SDR

Test the performance of the trained network with over-the-air signals using the helperModClassSDRTest function. To perform this test, you must have dedicated SDRs for transmission and reception. You can use two ADALM-PLUTO radios, or one ADALM-PLUTO radio for transmission and one USRP® radio for reception. You must Install Support Package for Analog Devices ADALM-PLUTO Radio (Communications Toolbox Support Package for Analog Devices ADALM-Pluto Radio). If you are using a USRP® radio, you must also Install Communications Toolbox Support Package for USRP Radio (Communications Toolbox Support Package for USRP Radio). The helperModClassSDRTest function uses the same modulation functions as used for generating the training signals, and then transmits them using an ADALM-PLUTO radio. Instead of simulating the channel, capture the channel-impaired signals using the SDR that is configured for signal reception (ADALM-PLUTO or USRP® radio). Use the trained network with the same classify function used previously to predict the modulation type. Running the next code segment produces a confusion matrix and prints out the test accuracy.

radioPlatform = "ADALM-PLUTO";

switch radioPlatform
  case "ADALM-PLUTO"
    if helperIsPlutoSDRInstalled() == true
      radios = findPlutoRadio();
      if length(radios) >= 2
        helperModClassSDRTest(radios);
      else
        disp('Selected radios not found. Skipping over-the-air test.')
      end
    end
  case {"USRP B2xx","USRP X3xx","USRP N2xx"}
    if (helperIsUSRPInstalled() == true) && (helperIsPlutoSDRInstalled() == true)
      txRadio = findPlutoRadio();
      rxRadio = findsdru();
      switch radioPlatform
        case "USRP B2xx"
          idx = contains({rxRadio.Platform}, {'B200','B210'});
        case "USRP X3xx"
          idx = contains({rxRadio.Platform}, {'X300','X310'});
        case "USRP N2xx"
          idx = contains({rxRadio.Platform}, 'N200/N210/USRP2');
      end
      rxRadio = rxRadio(idx);
      if (length(txRadio) >= 1) && (length(rxRadio) >= 1)
        helperModClassSDRTest(rxRadio);
      else
        disp('Selected radios not found. Skipping over-the-air test.')
      end
    end
end

When using two stationary ADALM-PLUTO radios separated by about 2 feet, the network achieves 99% overall accuracy with the following confusion matrix. Results will vary based on experimental setup.

Further Exploration

It is possible to improve the accuracy by optimizing hyperparameters, such as the number of filters or the filter size, or by optimizing the network structure, such as adding more layers or using different activation layers.

Communications Toolbox provides many more modulation types and channel impairments. For more information, see the Modulation (Communications Toolbox) and Propagation and Channel Models (Communications Toolbox) sections. You can also add standard-specific signals with LTE Toolbox, WLAN Toolbox, and 5G Toolbox, and radar signals with Phased Array System Toolbox.

The helperModClassGetModulator function provides the MATLAB® functions used to generate modulated signals. You can also explore the functions and System objects listed earlier, such as pammod, rcosdesign, comm.RicianChannel, and comm.PhaseFrequencyOffset, for more details.

