# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import numpy as np
import torch
from torch.optim import SGD
from torchvision import transforms

from datasets.utils.validation import ValidationDataset
from models.utils.continual_model import ContinualModel
from utils.args import add_management_args, add_experiment_args, ArgumentParser
from utils.status import progress_bar


def get_parser() -> ArgumentParser:
    parser = ArgumentParser(description='Joint training: a strong, simple baseline.')
    add_management_args(parser)
    add_experiment_args(parser)
    return parser
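

# Example invocation (hypothetical flags/values; assumes the standard Mammoth
# entry point utils/main.py and the sequential CIFAR-10 benchmark, so adjust
# to your setup):
#
#   python utils/main.py --model joint --dataset seq-cifar10 \
#       --lr 0.03 --batch_size 32 --n_epochs 50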


class Joint(ContinualModel):
    """Joint training baseline: store the data of every task seen so far and,
    once the last task arrives, retrain the backbone from scratch on the
    union. Because it ignores the sequential constraint, it is commonly read
    as an upper bound for continual-learning methods."""
    NAME = 'joint'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(Joint, self).__init__(backbone, loss, args, transform)
        self.old_data = []    # per-task raw data (class-/task-il) or train loaders (domain-il)
        self.old_labels = []  # per-task label tensors (class-/task-il only)
        self.current_task = 0

    def end_task(self, dataset):
        if dataset.SETTING != 'domain-il':
            # class-il / task-il: stash the raw arrays of the task just finished
            self.old_data.append(dataset.train_loader.dataset.data)
            self.old_labels.append(torch.tensor(dataset.train_loader.dataset.targets))
            self.current_task += 1

            # for non-incremental joint training: train only after the last task
            if len(dataset.test_loaders) != dataset.N_TASKS:
                return

            # reinit network
            self.net = dataset.get_backbone()
            self.net.to(self.device)
            self.net.train()
            self.opt = SGD(self.net.parameters(), lr=self.args.lr)

            # prepare a dataloader over the union of all stored tasks
            all_data, all_labels = None, None
            for i in range(len(self.old_data)):
                if all_data is None:
                    all_data = self.old_data[i]
                    all_labels = self.old_labels[i]
                else:
                    all_data = np.concatenate([all_data, self.old_data[i]])
                    all_labels = np.concatenate([all_labels, self.old_labels[i]])

            transform = dataset.TRANSFORM if dataset.TRANSFORM is not None else transforms.ToTensor()
            temp_dataset = ValidationDataset(all_data, all_labels, transform=transform)
            loader = torch.utils.data.DataLoader(temp_dataset, batch_size=self.args.batch_size, shuffle=True)

            # standard supervised training on the joint dataset
            for e in range(self.args.n_epochs):
                for i, batch in enumerate(loader):
                    inputs, labels = batch
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    self.opt.zero_grad()
                    outputs = self.net(inputs)
                    loss = self.loss(outputs, labels.long())
                    loss.backward()
                    self.opt.step()
                    progress_bar(i, len(loader), e, 'J', loss.item())
        else:
            # domain-il: keep the whole train loader, since inputs are already tensors
            self.old_data.append(dataset.train_loader)

            # train only after the last task has been seen
            if len(dataset.test_loaders) != dataset.N_TASKS:
                return

            # materialize every stored loader into one tensor of inputs and labels
            all_inputs = []
            all_labels = []
            for source in self.old_data:
                for x, l, _ in source:
                    all_inputs.append(x)
                    all_labels.append(l)
            all_inputs = torch.cat(all_inputs)
            all_labels = torch.cat(all_labels)
            bs = self.args.batch_size
            scheduler = dataset.get_scheduler(self, self.args)

            for e in range(self.args.n_epochs):
                # reshuffle once per epoch, then slice out mini-batches
                order = torch.randperm(len(all_inputs))
                shuffled_inputs = all_inputs[order]
                shuffled_labels = all_labels[order]
                for i in range(int(math.ceil(len(all_inputs) / bs))):
                    inputs = shuffled_inputs[i * bs: (i + 1) * bs]
                    labels = shuffled_labels[i * bs: (i + 1) * bs]
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    self.opt.zero_grad()
                    outputs = self.net(inputs)
                    loss = self.loss(outputs, labels.long())
                    loss.backward()
                    self.opt.step()
                    progress_bar(i, int(math.ceil(len(all_inputs) / bs)), e, 'J', loss.item())

                if scheduler is not None:
                    scheduler.step()

    def observe(self, inputs, labels, not_aug_inputs):
        # Joint never learns from the online stream: all training is deferred
        # to end_task, so the per-batch loss is reported as zero.
        return 0
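

# Minimal usage sketch (assumptions: SequentialCIFAR10 stands in for any
# concrete Mammoth benchmark; in practice utils/main.py builds and drives the
# model for you):
#
#   args = get_parser().parse_args()
#   dataset = SequentialCIFAR10(args)
#   model = Joint(dataset.get_backbone(), dataset.get_loss(), args,
#                 dataset.get_transform())
#   for _ in range(dataset.N_TASKS):
#       dataset.get_data_loaders()   # advance the benchmark by one task
#       model.end_task(dataset)      # trains only once all tasks are stored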