Unbalanced Data Loading For Multi-Task Learning in PyTorch (Blog)
Working on multi-task learning (MTL) problems requires a unique training setup, mainly in terms of data handling, model architecture, and performance evaluation metrics. In this post, I am reviewing the data handling part: specifically, how to train a multi-task learning model on multiple datasets and how to handle tasks with a highly unbalanced dataset. The setup breaks down into three parts:
1. Combining two (or more) datasets into a single PyTorch Dataset. This dataset will be
the input for a PyTorch DataLoader.
2. Modifying the batch preparation process to produce either samples from a single task in each mini-batch or, alternatively, a mix of samples from both tasks in each mini-batch.
3. Handling the highly unbalanced datasets at the batch level by using a batch sampler
as part of the DataLoader.
I am only reviewing Dataset- and DataLoader-related code, ignoring other important modules such as the model, optimizer, and metrics definitions.
. . .
For simplicity, I am using a generic two-dataset example. However, the number of datasets and the type of data should not affect the main setup. We can even use several instances of the same dataset, in case we have more than one set of labels for the same set of samples: for example, a dataset of images labeled with both an object class and a spatial location, or a facial dataset with both an emotion and an age label per image.
A PyTorch Dataset class needs to implement the __getitem__() function, which handles sample fetching and preparation for a given index. When using two datasets, it is therefore possible to have two different methods of creating samples. We can even use a single underlying dataset, fetch samples with different labels, and change the sample processing scheme, as long as the output samples have the same shape, since they are stacked into a batch tensor.
import torch
from torch.utils.data.dataset import ConcatDataset
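As a minimal sketch, the two toy datasets used throughout the post and their concatenation can look like this (the class names MyFirstDataset and MySecondDataset are placeholder names of my own; concat_dataset is the variable the later snippets expect):

import torch
from torch.utils.data.dataset import Dataset, ConcatDataset


class MyFirstDataset(Dataset):
    # placeholder name - a balanced toy dataset with five +1 and five -1 samples
    def __init__(self):
        self.samples = torch.cat((-torch.ones(5), torch.ones(5)))

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)


class MySecondDataset(Dataset):
    # placeholder name - a highly unbalanced toy dataset with fifty +5 and five -5 samples
    def __init__(self):
        self.samples = torch.cat((5 * torch.ones(50), -5 * torch.ones(5)))

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)


first_dataset = MyFirstDataset()
second_dataset = MySecondDataset()
concat_dataset = ConcatDataset([first_dataset, second_dataset])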
We define two (binary) datasets: the first with ten samples of ±1 (equally distributed), and the second with 55 samples, 50 samples of the digit 5 and 5 samples of the digit -5. These datasets are only for illustration. In real datasets you would have both the samples and the labels, and you would probably read the data from a database or parse it from data folders, but these simple datasets are enough to understand the main concepts.
Next, we need to define a DataLoader. We provide it with our concat_dataset and set the loader parameters, such as the batch size and whether or not to shuffle the samples.
batch_size = 8

# basic dataloader
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                         batch_size=batch_size,
                                         shuffle=True)

for inputs in dataloader:
    print(inputs)
Each batch is a tensor of 8 samples from our concat_dataset. The order is random, and the samples in each batch are drawn from the combined pool of both datasets.
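For illustration only (the exact values and order differ every run), the printed batches might look something like this, with ±1 samples from the first dataset and ±5 samples from the second mixed in a single batch:

tensor([ 5.,  5., -1.,  5.,  1.,  5.,  5., -5.])
tensor([ 5.,  1.,  5.,  5.,  5., -1.,  5.,  5.])
...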
Until now, everything was relatively straightforward. The datasets are combined into a single one, and samples are randomly picked from both of the original datasets to construct each mini-batch. Now let's try to control and manipulate the samples in each batch. We want to get samples from only one dataset in each mini-batch, switching between the datasets every other batch.
import torch
from torch.utils.data.sampler import RandomSampler


class BatchSchedulerSampler(torch.utils.data.sampler.Sampler):
    """
    iterate over tasks and provide a random batch per task in each mini-batch
    """
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)

    def __len__(self):
        return len(self.dataset) * self.number_of_datasets

    def __iter__(self):
        samplers_list = []
        sampler_iterators = []
        datasets_length = []
        for dataset_idx in range(self.number_of_datasets):
            cur_dataset = self.dataset.datasets[dataset_idx]
            sampler = RandomSampler(cur_dataset)
            samplers_list.append(sampler)
            cur_sampler_iterator = sampler.__iter__()
            sampler_iterators.append(cur_sampler_iterator)
            datasets_length.append(len(cur_dataset))

        push_index_val = [0] + self.dataset.cumulative_sizes[:-1]
        step = self.batch_size * self.number_of_datasets
        samples_to_grab = self.batch_size
        largest_dataset_index = torch.argmax(torch.as_tensor(datasets_length)).item()
        # we want to see all samples of the largest dataset, which forces us to resample the smaller ones
        epoch_samples = datasets_length[largest_dataset_index] * self.number_of_datasets

        final_samples_list = []  # this is a list of indexes from the combined dataset
        for _ in range(0, epoch_samples, step):
            for i in range(self.number_of_datasets):
                cur_batch_sampler = sampler_iterators[i]
                cur_samples = []
                for _ in range(samples_to_grab):
                    try:
                        cur_sample_org = cur_batch_sampler.__next__()
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                    except StopIteration:
                        if i == largest_dataset_index:
                            # the largest dataset iterator is done, so we can break;
                            # adjust samples_to_grab to what was actually collected,
                            # extend the final list and continue
                            samples_to_grab = len(cur_samples)
                            break
                        else:
                            # restart the iterator - we want more samples until finishing with the largest dataset
                            sampler_iterators[i] = samplers_list[i].__iter__()
                            cur_batch_sampler = sampler_iterators[i]
                            cur_sample_org = cur_batch_sampler.__next__()
                            cur_sample = cur_sample_org + push_index_val[i]
                            cur_samples.append(cur_sample)
                final_samples_list.extend(cur_samples)

        return iter(final_samples_list)
Now let’s run and print the samples using a new DataLoader, which gets our
BatchSchedulerSampler as an input sampler (shuffle can’t be set to True when working
with a sampler).
import torch
from multi_task_batch_scheduler import BatchSchedulerSampler

batch_size = 8

# dataloader with BatchSchedulerSampler
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                         sampler=BatchSchedulerSampler(dataset=concat_dataset,
                                                                       batch_size=batch_size),
                                         batch_size=batch_size,
                                         shuffle=False)

for inputs in dataloader:
    print(inputs)
The output now looks like this:
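(Illustrative output, not the post's actual screenshot; the exact values change every run, but each mini-batch now comes from a single dataset, alternating between the two tasks.)

tensor([ 1., -1., -1.,  1., -1.,  1.,  1., -1.])
tensor([ 5.,  5., -5.,  5.,  5.,  5.,  5.,  5.])
tensor([-1.,  1.,  1., -1.,  1., -1., -1.,  1.])
tensor([ 5.,  5.,  5.,  5.,  5., -5.,  5.,  5.])
...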
Hurray!!!
For each mini-batch, we now get samples from only one dataset.
We can play with this type of scheduling in order to downsample or upsample more
important tasks.
The remaining problem in our batches now comes from the second highly unbalanced
dataset. This is often the case in MTL, having a main task and a few other satellite sub-
tasks. Training the main task and sub-tasks together might lead to improved performance
and contribute to the generalization of the overall model. The problem is that samples of
the sub-tasks are often very sparse, having only a few positive (or negative) samples.
Let's use our previous logic, but this time also force a balanced batch with respect to the distribution of samples in each task.
To handle the imbalance, we need to replace the random sampler in the BatchSchedulerSampler class with an ImbalancedDatasetSampler (I am using a great implementation from the ufoym/imbalanced-dataset-sampler repository). This class handles the balancing of the dataset. We can also mix samplers, using RandomSampler for some tasks and ImbalancedDatasetSampler for others.
import torch
from torch.utils.data import RandomSampler
from sampler import ImbalancedDatasetSampler


class ExampleImbalancedDatasetSampler(ImbalancedDatasetSampler):
    """
    ImbalancedDatasetSampler is taken from https://github.com/ufoym/imbalanced-dataset-sampler
    In order to be able to show the usage of ImbalancedDatasetSampler in this example,
    I override the _get_label method to fit my datasets
    """
    def _get_label(self, dataset, idx):
        return dataset.samples[idx].item()


class BalancedBatchSchedulerSampler(torch.utils.data.sampler.Sampler):
    """
    iterate over tasks and provide a balanced batch per task in each mini-batch
    """
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)

    def __len__(self):
        return len(self.dataset) * self.number_of_datasets

    def __iter__(self):
        samplers_list = []
        sampler_iterators = []
        datasets_length = []
        for dataset_idx in range(self.number_of_datasets):
            cur_dataset = self.dataset.datasets[dataset_idx]
            if dataset_idx == 0:
                # the first dataset is kept at RandomSampler
                sampler = RandomSampler(cur_dataset)
            else:
                # the second, unbalanced dataset is changed to ExampleImbalancedDatasetSampler
                sampler = ExampleImbalancedDatasetSampler(cur_dataset)
            samplers_list.append(sampler)
            cur_sampler_iterator = sampler.__iter__()
            sampler_iterators.append(cur_sampler_iterator)
            datasets_length.append(len(cur_dataset))

        push_index_val = [0] + self.dataset.cumulative_sizes[:-1]
        step = self.batch_size * self.number_of_datasets
        samples_to_grab = self.batch_size
        largest_dataset_index = torch.argmax(torch.as_tensor(datasets_length)).item()
        # we want to see all samples of the largest dataset, which forces us to resample the smaller ones
        epoch_samples = datasets_length[largest_dataset_index] * self.number_of_datasets

        # from here on the logic is identical to BatchSchedulerSampler.__iter__ above
        final_samples_list = []  # this is a list of indexes from the combined dataset
        for _ in range(0, epoch_samples, step):
            for i in range(self.number_of_datasets):
                cur_batch_sampler = sampler_iterators[i]
                cur_samples = []
                for _ in range(samples_to_grab):
                    try:
                        cur_sample_org = cur_batch_sampler.__next__()
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                    except StopIteration:
                        if i == largest_dataset_index:
                            # the largest dataset iterator is done - adjust samples_to_grab and break
                            samples_to_grab = len(cur_samples)
                            break
                        else:
                            # restart the iterator - we want more samples until finishing with the largest dataset
                            sampler_iterators[i] = samplers_list[i].__iter__()
                            cur_batch_sampler = sampler_iterators[i]
                            cur_sample_org = cur_batch_sampler.__next__()
                            cur_sample = cur_sample_org + push_index_val[i]
                            cur_samples.append(cur_sample)
                final_samples_list.extend(cur_samples)

        return iter(final_samples_list)
import torch
from balanced_sampler import BalancedBatchSchedulerSampler

batch_size = 8

# dataloader with BalancedBatchSchedulerSampler
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                         sampler=BalancedBatchSchedulerSampler(dataset=concat_dataset,
                                                                               batch_size=batch_size),
                                         batch_size=batch_size,
                                         shuffle=False)

for inputs in dataloader:
    print(inputs)
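(Again an illustrative print-out rather than the post's actual output: the second task's mini-batches now contain roughly as many -5 samples as +5 samples.)

tensor([ 1., -1.,  1., -1., -1.,  1.,  1., -1.])
tensor([-5.,  5., -5.,  5.,  5., -5., -5.,  5.])
...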
The mini-batches of the unbalanced task are now much more balanced.
There is a lot of room to play with this setup even further. We can combine the tasks in a balanced way: by setting samples_to_grab to 4, half of the batch size, we get a mixed mini-batch with 4 samples taken from each task. To produce a 1:2 ratio toward a more important task, we can set samples_to_grab=2 for the first task and samples_to_grab=6 for the second task, as sketched below.
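As a rough sketch of that idea (this is my own variation, not code from the original post), the scheduler can accept a per-task list of samples to grab instead of a single number; when the DataLoader batch size equals the sum of that list, every mini-batch mixes the tasks at the requested ratio:

import torch
from torch.utils.data.sampler import RandomSampler


class MixedBatchSchedulerSampler(torch.utils.data.sampler.Sampler):
    """
    mix samples from all tasks in every mini-batch, taking a fixed number of
    samples per task, e.g. samples_per_task=[2, 6] for a 1:2 task ratio
    """
    def __init__(self, dataset, samples_per_task):
        self.dataset = dataset
        self.samples_per_task = samples_per_task  # one entry per sub-dataset
        self.batch_size = sum(samples_per_task)
        self.number_of_datasets = len(dataset.datasets)
        largest_dataset_size = max(len(d) for d in dataset.datasets)
        # a simple heuristic for the number of mini-batches in one epoch
        self.number_of_batches = largest_dataset_size // max(samples_per_task)

    def __len__(self):
        return self.number_of_batches * self.batch_size

    def __iter__(self):
        samplers = [RandomSampler(d) for d in self.dataset.datasets]
        iterators = [iter(s) for s in samplers]
        push_index_val = [0] + self.dataset.cumulative_sizes[:-1]
        final_samples_list = []  # indexes into the combined dataset
        for _ in range(self.number_of_batches):
            for i in range(self.number_of_datasets):
                for _ in range(self.samples_per_task[i]):
                    try:
                        idx = next(iterators[i])
                    except StopIteration:
                        # an exhausted task iterator is simply restarted (resampling)
                        iterators[i] = iter(samplers[i])
                        idx = next(iterators[i])
                    final_samples_list.append(idx + push_index_val[i])
        return iter(final_samples_list)


# usage: 2 samples from the first task and 6 from the second in every mini-batch
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                         sampler=MixedBatchSchedulerSampler(concat_dataset,
                                                                            samples_per_task=[2, 6]),
                                         batch_size=8,
                                         shuffle=False)

With samples_per_task=[2, 6], each mini-batch holds 2 samples from the first task and 6 from the second; samples_per_task=[4, 4] gives the even 4/4 mix mentioned above.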
Tags: Machine Learning, PyTorch, Multi-Task Learning, Data Handling, Unbalanced Data