%% Deep Learning Data Synthesis for 5G Channel Estimation
%
% This example shows how to train a convolutional neural network (CNN) for
% channel estimation using Deep Learning Toolbox(TM) and data generated
% with 5G Toolbox(TM). Using the trained CNN, you perform channel
% estimation in single-input single-output (SISO) mode, utilizing the
% physical downlink shared channel (PDSCH) demodulation reference signal
% (DM-RS).
%
% Copyright 2019-2020 The MathWorks, Inc.
%% Introduction
%
% The general approach to channel estimation is to insert known reference
% pilot symbols into the transmission and then interpolate the rest of the
% channel response by using these pilot symbols.
%
% <<../DeepLearningDataSynthesis5G_ChEstimationOverview.png>>
%
%
% For an example showing how to use this channel estimation approach, see
% <docid:5g_ug#mw_new-radio-pdsch-throughput NR PDSCH
% Throughput>.
%
% You can also use deep learning techniques to perform channel estimation.
% For example, by viewing the PDSCH resource grid as a 2-D image, you can
% turn the problem of channel estimation into an image processing problem,
% similar to denoising or super-resolution, where CNNs are effective.
%
% Using 5G Toolbox, you can customize and generate standard-compliant
% waveforms and channel models to use as training data. Using Deep Learning
% Toolbox, you can use this training data to train a channel estimation
% CNN. This example shows how to generate such training data and how to
% train a channel estimation CNN. The example also shows how to use the
% channel estimation CNN to process images that contain linearly
% interpolated received pilot symbols. The example concludes by visualizing
% the results of the neural network channel estimator in comparison to
% practical and perfect estimators.
%
% <<../DeepLearningDataSynthesis5G_ExampleOverview.png>>
%
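%%
% As a minimal sketch of the pilot-based approach described above, the
% following code estimates a toy one-dimensional channel at a few pilot
% subcarriers and linearly interpolates the rest. The channel values, pilot
% positions, and all |toy*| variable names are assumptions made for
% illustration; they are not part of this example's simulation. Because the
% toy channel is linear in frequency and noise free, linear interpolation
% recovers it almost exactly, which is not the case for a realistic
% frequency-selective channel with noise.
toyNumSubcarriers = 12;
toySubcarriers = (1:toyNumSubcarriers).';
% Hypothetical channel that varies linearly across subcarriers
toyChannel = (0.5 + 0.1*toySubcarriers) + 1i*(0.2 - 0.05*toySubcarriers);
% Known unit-modulus (QPSK) pilots on every fourth subcarrier
toyPilotIdx = (1:4:toyNumSubcarriers).';
toyPilots = [1+1i; 1-1i; -1+1i]/sqrt(2);
% Noise-free received pilots
toyRxPilots = toyChannel(toyPilotIdx).*toyPilots;
% Least-squares estimate at the pilot positions (pilots have unit power)
toyPilotEst = toyRxPilots.*conj(toyPilots);
% Interpolate between pilots and extrapolate to the band edge
toyChannelEst = interp1(toyPilotIdx,toyPilotEst,toySubcarriers,'linear','extrap');
toyMaxError = max(abs(toyChannelEst - toyChannel));
fprintf('Maximum interpolation error: %.2e\n',toyMaxError);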
%% Neural Network Training
%
% Neural network training consists of these steps:
%
% * Data generation
% * Splitting the generated data into training and validation sets
% * Defining the CNN architecture
% * Specifying the training options, optimizer, and learning rate
% * Training the network
%
% Due to the large number of signals and possible scenarios, training can
% take several minutes. By default, training is disabled and a pretrained
% model is used. To enable training, set |trainModel| to |true|.
trainModel = false;
%%
% If you have Parallel Computing Toolbox(TM) installed and a supported
% CUDA-enabled NVIDIA(R) GPU set up, the network training uses GPU
% acceleration by default. The <docid:nnet_ref#bu6sn4c trainNetwork>
% function allows you to override this default behavior. For a list of
% supported GPUs, see
% <docid:distcomp_ug#mw_57e04559-0b60-42d5-ad55-e77ec5f5865f GPU Support by Release>.
%
% Data generation is set to produce 256 training examples. This amount of
% data is sufficient to train a functional channel estimation network on a
% CPU in a reasonable time. For comparison, the pretrained model is based
% on 16,384 training examples.
%
% The training data for the CNN model has fixed dimensions: the network
% accepts only 612-by-14-by-1 grids, that is, 612 subcarriers, 14 OFDM
% symbols, and 1 antenna. Therefore, the model can operate only on a fixed
% bandwidth allocation, a fixed cyclic prefix length, and a single receive
% antenna.
%
% The CNN treats the resource grids as 2-D images, hence each element of
% the grid must be a real number. In a channel estimation scenario, the
% resource grids have complex data. Therefore, the real and imaginary parts
% of these grids are input separately to the CNN. In this example, the
% training data is converted from a complex 612-by-14 matrix into a
% real-valued 612-by-14-by-2 matrix, where the third dimension denotes the
% real and imaginary components. Because you have to input the real and
% imaginary grids into the neural network separately when making
% predictions, the example converts the training data into 4-D arrays of
% the form 612-by-14-by-1-by-2N, where N is the number of training
% examples.
%
% To monitor whether the CNN overfits the training data, the generated data
% is split into training and validation sets. The validation data is used
% to check the performance of the network at regular intervals during
% training, as defined by |valFrequency| (approximately five times per
% epoch). Training stops when the validation loss stops improving. In this
% example, the validation set is the same size as a single mini-batch
% because the data set is small.
%
% The returned channel estimation CNN is trained on various channel
% configurations based on different delay spreads, Doppler shifts, and SNR
% values between 0 and 10 dB.
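%
% The reshaping and the training/validation split described above can be
% sketched on a small placeholder array. The sizes and the |demo*| names
% are assumptions chosen for illustration, and random numbers stand in for
% real training data.
demoNumExamples = 8;    % N, hypothetical number of examples
demoBatchSize = 2;      % hypothetical mini-batch size
% Third dimension holds the real and imaginary planes of each example
demoData = randn(612,14,2,demoNumExamples);
% Move both planes to the fourth (batch) dimension:
% 612-by-14-by-2-by-N becomes 612-by-14-by-1-by-2N
demoData = cat(4,demoData(:,:,1,:),demoData(:,:,2,:));
% Hold out one mini-batch for validation and train on the remainder
demoValData = demoData(:,:,:,1:demoBatchSize);
demoTrainData = demoData(:,:,:,demoBatchSize+1:end);
% Validate roughly five times per epoch
demoValFrequency = round(size(demoTrainData,4)/demoBatchSize/5);
disp(size(demoTrainData))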
% Set the random seed for reproducibility (this has no effect if a GPU is
% used)
rng(42)
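if trainModel
    % Generate the training data
    [trainData,trainLabels] = hGenerateTrainingData(256);
    % Number of examples per mini-batch (32 is an assumed value)
    batchSize = 32;
    % Split real and imaginary grids into 2 image sets, then concatenate
    trainData = cat(4,trainData(:,:,1,:),trainData(:,:,2,:));
    trainLabels = cat(4,trainLabels(:,:,1,:),trainLabels(:,:,2,:));
    % Hold out the first mini-batch as the validation set
    valData = trainData(:,:,:,1:batchSize);
    valLabels = trainLabels(:,:,:,1:batchSize);
    % Keep the remaining examples for training
    trainData = trainData(:,:,:,batchSize+1:end);
    trainLabels = trainLabels(:,:,:,batchSize+1:end);
else
    % Load the pretrained network if trainModel is set to false
    load('trainedChannelEstimationNetwork.mat')
end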
%%
% Inspect the composition and individual layers of the model. The model has
% 5 convolutional layers. The input layer expects matrices of size
% 612-by-14, where 612 is the number of subcarriers and 14 is the number of
% OFDM symbols. Each element is a real number, since the real and imaginary
% parts of the complex grids are input separately.
channelEstimationCNN.Layers
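%%
% For orientation, a network of this general shape could be defined as in
% the following sketch. The filter sizes, filter counts, optimizer
% settings, and |sketch*| names are assumptions for illustration and are
% not necessarily those of the pretrained model. The code requires Deep
% Learning Toolbox and only builds a layer array and training options; it
% does not train a network. 'Padding','same' keeps every feature map at
% 612-by-14 so that the output grid has the same size as the input grid.
sketchLayers = [ ...
    imageInputLayer([612 14 1],'Normalization','none')
    convolution2dLayer(9,64,'Padding','same')
    reluLayer
    convolution2dLayer(5,64,'Padding','same')
    reluLayer
    convolution2dLayer(5,64,'Padding','same')
    reluLayer
    convolution2dLayer(5,32,'Padding','same')
    reluLayer
    convolution2dLayer(5,1,'Padding','same')
    regressionLayer];
% 'ExecutionEnvironment' selects CPU or GPU training; 'auto' uses a
% supported GPU when one is available
sketchOptions = trainingOptions('adam', ...
    'InitialLearnRate',3e-4, ...
    'MaxEpochs',5, ...
    'MiniBatchSize',32, ...
    'Shuffle','every-epoch', ...
    'ExecutionEnvironment','auto', ...
    'Verbose',false);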
%%
% Load the predefined simulation parameters, including the PDSCH
% parameters and DM-RS configuration. The returned object |carrier| is a
% valid carrier configuration object and |pdsch| is a PDSCH configuration
% structure set for a SISO transmission.
[gnb,carrier,pdsch] = hDeepLearningChanEstSimParameters();
%%
% Create a TDL channel model and set channel parameters. To compare
% different channel responses of the estimators, you can change these
% parameters later.
channel = nrTDLChannel;
channel.Seed = 0;
channel.DelayProfile = 'TDL-A';
channel.DelaySpread = 3e-7;
channel.MaximumDopplerShift = 50;
waveformInfo = nrOFDMInfo(carrier);
channel.SampleRate = waveformInfo.SampleRate;
%%
% Get the maximum delay, in samples, that the channel introduces. This
% number is calculated from the channel path with the largest delay plus
% the implementation delay of the channel filter. The transmitted waveform
% is padded with this many zeros to flush the channel filter when
% obtaining the received signal.
chInfo = info(channel);
maxChDelay = ceil(max(chInfo.PathDelays*channel.SampleRate))+chInfo.ChannelFilterDelay;
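%%
% As a worked example of this calculation, assume a hypothetical channel
% whose longest path delay is 1 microsecond, a sample rate of 30.72 MHz,
% and a channel filter delay of 7 samples. These numbers and the |ex*|
% names are illustrative assumptions, not values taken from |channel|.
exPathDelays = [0 0.2e-6 1e-6];   % path delays in seconds
exSampleRate = 30.72e6;           % sample rate in Hz
exFilterDelay = 7;                % filter implementation delay in samples
% 1e-6 s * 30.72e6 Hz = 30.72 samples, which rounds up to 31
exMaxDelay = ceil(max(exPathDelays*exSampleRate)) + exFilterDelay;
% Append that many zeros so the tail of the signal leaves the filter
exTx = ones(100,1);
exTxPadded = [exTx; zeros(exMaxDelay,size(exTx,2))];
fprintf('Delay: %d samples, padded length: %d\n',exMaxDelay,numel(exTxPadded));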
rxGrid = nrOFDMDemodulate(carrier,rxWaveform);
% Pad the grid with zeros in case an incomplete slot has been demodulated
[K,L,R] = size(rxGrid);
if (L < carrier.SymbolsPerSlot)
    rxGrid = cat(2,rxGrid,zeros(K,carrier.SymbolsPerSlot-L,R));
end
% Concatenate the real and imaginary grids along the batch dimension
nnInput = cat(4,real(interpChannelGrid),imag(interpChannelGrid));
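%%
% Because the real and imaginary grids pass through the network as
% separate images, the two output images need to be recombined into one
% complex grid afterward. A minimal sketch of that round trip follows, with
% an identity mapping standing in for the network prediction. The grid
% values and the |rt*| names are assumptions for illustration.
rtGrid = complex(magic(4),magic(4).');
% Stack real and imaginary parts along the fourth (batch) dimension
rtInput = cat(4,real(rtGrid),imag(rtGrid));   % 4-by-4-by-1-by-2
% Stand-in for the CNN prediction (identity mapping)
rtOutput = rtInput;
% Recombine the two real-valued images into one complex grid
rtRecovered = complex(rtOutput(:,:,:,1),rtOutput(:,:,:,2));
disp(isequal(rtRecovered,rtGrid))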
%%
% Calculate the mean squared error (MSE) of each estimation method.
neural_mse = mean(abs(estChannelGridPerfect(:) - estChannelGridNN(:)).^2);
interp_mse = mean(abs(estChannelGridPerfect(:) - interpChannelGrid(:)).^2);
practical_mse = mean(abs(estChannelGridPerfect(:) - estChannelGrid(:)).^2);
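%%
% As a small numeric check of this MSE definition, consider a hypothetical
% 2-by-2 reference grid and an estimate that is off by 0.1 in two of its
% four elements. The values and the |mse*| names are illustrative
% assumptions.
mseRef = [1+1i 2; 0 -1i];
mseEst = mseRef + [0.1 0; 0 0.1i];
% Squared error magnitudes are 0.01, 0, 0, and 0.01, so the mean is 0.005
mseValue = mean(abs(mseRef(:) - mseEst(:)).^2);
fprintf('MSE: %.4f\n',mseValue);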
%%
% Plot the individual channel estimations and the actual channel
% realization obtained from the channel filter taps. Both the practical
% estimator and the neural network estimator outperform linear
% interpolation.
plotChEstimates(interpChannelGrid,estChannelGrid,estChannelGridNN,estChannelGridPerfect,...
    interp_mse,practical_mse,neural_mse);
%% References
% # van de Beek, Jan–Jaap, Ove Edfors, Magnus Sandell, Sarah Kate
% Wilson, and Per Ola Borjesson. “On Channel Estimation in OFDM
% Systems.” In 1995 IEEE 45th Vehicular Technology Conference.
% Countdown to the Wireless Twenty–First Century, 2:815–19,
% July 1995.
% # Ye, Hao, Geoffrey Ye Li, and Biing-Hwang Juang. “Power of Deep
% Learning for Channel Estimation and Signal Detection in OFDM
% Systems.” IEEE Wireless Communications Letters 7, no. 1 (February
% 2018): 114–17.
% # Soltani, Mehran, Vahid Pourahmadi, Ali Mirzaei, and Hamid Sheikhzadeh.
% “Deep Learning–Based Channel Estimation.” Preprint, submitted October 13,
% 2018.
%% Local Functions
% Find the row and column coordinates for a given DMRS configuration
[rows,cols] = find(rxDMRSGrid ~= 0);
dmrsSubs = [rows,cols,ones(size(cols))];
[l_hest,k_hest] = meshgrid(1:size(hest,2),1:size(hest,1));
end
% Main loop for data generation, iterating over the number of examples
% specified in the function call. Each iteration of the loop produces a
% new channel realization with a random delay spread, Doppler shift,
% and delay profile. Every perturbed version of the transmitted
% waveform with the DM-RS symbols is stored in trainData, and the
% perfect channel realization in trainLabels.
for i = 1:numExamples
    % Release the channel to change nontunable properties
    channel.release
    % Pick a random delay profile, delay spread, and maximum Doppler shift
    channel.DelayProfile = string(delayProfiles(randi([1 numel(delayProfiles)])));
    channel.DelaySpread = randi([1 300])*1e-9;
    channel.MaximumDopplerShift = randi([5 400]);
    % Send data through the channel model. Append zeros at the end of
    % the transmitted waveform to flush channel content. These zeros
    % take into account any delay introduced in the channel, such as
    % multipath delay and implementation delay. This value depends on
    % the sampling rate, delay profile, and delay spread
    txWaveform = [txWaveform_original; zeros(maxChDelay,size(txWaveform_original,2))];
    [rxWaveform,pathGains,sampleTimes] = channel(txWaveform);
    % Linear interpolation
    dmrsRx = rxGrid(dmrsIndices);
    dmrsEsts = dmrsRx .* conj(dmrsSymbols);
    f = scatteredInterpolant(dmrsSubs(:,2),dmrsSubs(:,1),dmrsEsts);
    hest = f(l_hest,k_hest);
function plotChEstimates(interpChannelGrid,estChannelGrid,estChannelGridNN,estChannelGridPerfect,...
    interp_mse,practical_mse,neural_mse)
    % Plot the different channel estimates and display the measured MSE
    figure
    subplot(1,4,1)
    imagesc(abs(interpChannelGrid));
    xlabel('OFDM Symbol');
    ylabel('Subcarrier');
    title({'Linear Interpolation', ['MSE: ', num2str(interp_mse)]});
    subplot(1,4,2)
    imagesc(abs(estChannelGrid));
    xlabel('OFDM Symbol');
    ylabel('Subcarrier');
    title({'Practical Estimator', ['MSE: ', num2str(practical_mse)]});
    subplot(1,4,3)
    imagesc(abs(estChannelGridNN));
    xlabel('OFDM Symbol');
    ylabel('Subcarrier');
    title({'Neural Network', ['MSE: ', num2str(neural_mse)]});
    subplot(1,4,4)
    imagesc(abs(estChannelGridPerfect));
    xlabel('OFDM Symbol');
    ylabel('Subcarrier');
    title({'Actual Channel'});
end