-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMNISTMLP.py
executable file
·70 lines (54 loc) · 2.01 KB
/
MNISTMLP.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from backbone import MammothBackbone, num_flat_features, xavier
class MNISTMLP(MammothBackbone):
    """
    Network composed of two hidden layers, each containing 100 ReLU activations.
    Designed for the MNIST dataset.
    """

    def __init__(self, input_size: int, output_size: int) -> None:
        """
        Instantiates the layers of the network.

        :param input_size: the size of the input data
        :param output_size: the size of the output
        """
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size

        # Two fully-connected hidden layers of 100 units each.
        self.fc1 = nn.Linear(self.input_size, 100)
        self.fc2 = nn.Linear(100, 100)

        # Feature extractor: fc1 -> ReLU -> fc2 -> ReLU.
        self._features = nn.Sequential(
            self.fc1,
            nn.ReLU(),
            self.fc2,
            nn.ReLU(),
        )
        self.classifier = nn.Linear(100, self.output_size)
        # Full pipeline, kept as one module so reset_parameters() can
        # re-initialize every layer with a single apply().
        self.net = nn.Sequential(self._features, self.classifier)
        self.num_classes = output_size
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Calls the Xavier parameter initialization function on all layers.
        """
        self.net.apply(xavier)

    def forward(self, x: torch.Tensor, returnt: str = 'out') -> torch.Tensor:
        """
        Compute a forward pass.

        :param x: input tensor (batch_size, input_size), or any shape that
            flattens to input_size per sample (e.g. 1x28x28 images)
        :param returnt: one of 'out', 'features', or 'all'
        :return: with 'features', the (batch_size, 100) hidden representation;
            with 'out', the tuple (logits, None); with 'all', the tuple
            (logits, features)
        :raises NotImplementedError: if returnt is not a recognized value
        """
        # Flatten each sample to a vector before the fully-connected stack.
        x = x.view(-1, num_flat_features(x))

        feats = self._features(x)
        if returnt == 'features':
            return feats

        out = self.classifier(feats)
        if returnt == 'out':
            # NOTE(review): returns a (logits, None) tuple rather than a bare
            # tensor — callers appear to rely on this shape; preserved as-is.
            return out, None
        elif returnt == 'all':
            return (out, feats)

        raise NotImplementedError("Unknown return type")