Leveraging Topological Guidance for Improved Knowledge Distillation, GRaM workshop @ ICML 2024
This repository contains the code for our proposed method, TGD, which uses topological features in knowledge distillation (KD) to train and evaluate a lightweight model on the CIFAR-10 and CINIC-10 datasets.
The CIFAR-10 data can be downloaded at link. Unzip the cifar10.zip file in order to access the data. The data is organized as follows:
Data
└── cifar10
    ├── test         # Contains the original testing images
    ├── test_PIc     # Contains the column-wise PI testing images
    ├── test_PIr     # Contains the row-wise PI testing images
    ├── train        # Contains the original training images
    ├── train_PIc    # Contains the column-wise PI training images
    ├── train_PIr    # Contains the row-wise PI training images
    ├── labels.txt   # File containing the list of labels
    ├── test.csv
    └── train_csv
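Before training, it can help to confirm that the unzipped data matches the layout above. The following is a minimal sketch, not part of the repository: the helper name missing_subdirs and the EXPECTED_DIRS list are hypothetical, and only the six image directories from the tree are checked.

```python
from pathlib import Path

# Image directories listed in the tree above (hypothetical constant name).
EXPECTED_DIRS = ["test", "test_PIc", "test_PIr", "train", "train_PIc", "train_PIr"]


def missing_subdirs(dataset_root):
    """Return the expected subdirectories that are absent under dataset_root."""
    root = Path(dataset_root)
    return [name for name in EXPECTED_DIRS if not (root / name).is_dir()]


# Example: check the layout under Data/cifar10 (path taken from the tree above).
missing = missing_subdirs(Path("Data") / "cifar10")
if missing:
    print("Missing directories:", ", ".join(missing))
else:
    print("All expected image directories are present.")
```

If any directory is reported as missing, re-check where cifar10.zip was extracted and the value passed as the dataset directory.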
If you want to generate the persistence image (PI) images for your own data, you can use the GenPI_Images.ipynb notebook.
The GenPI_Images.ipynb notebook was used to generate the PI images. Generating the PI images involves the following steps:
- Normalize the images to the range [0, 1].
- By default, the notebook generates the row-wise PI images. To generate the column-wise PI images, uncomment the following line:
x_data.T
This statement is in the gen_PI_image function.
Note: You need to specify the data directory. Please verify the directory before executing the GenPI_Images.ipynb notebook.
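The normalization step and the row-wise versus column-wise choice can be sketched as follows. This is a minimal sketch and not the notebook's code: normalize_01 and orient are hypothetical helper names, min-max scaling is an assumed way of reaching the range [0, 1] (the notebook may instead divide by 255), and the PI computation itself is omitted.

```python
import numpy as np


def normalize_01(image):
    """Min-max scale an image array to the range [0, 1] (assumed normalization)."""
    image = np.asarray(image, dtype=np.float64)
    lo, hi = image.min(), image.max()
    if hi == lo:
        # Constant image: avoid division by zero and map everything to 0.
        return np.zeros_like(image)
    return (image - lo) / (hi - lo)


def orient(x_data, column_wise=False):
    """Return x_data unchanged (row-wise) or transposed via x_data.T (column-wise)."""
    return x_data.T if column_wise else x_data


# Tiny 2x2 grayscale example standing in for a real image.
gray = np.array([[0, 128], [64, 255]], dtype=np.uint8)
x_data = normalize_01(gray)
x_rows = orient(x_data)                    # input for row-wise PI images
x_cols = orient(x_data, column_wise=True)  # input for column-wise PI images
```

Transposing swaps rows and columns, so the same downstream code that processes rows then effectively processes the columns of the original image.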
To train the model, run the main_train.py script. The script accepts the following arguments:
- epochs: Number of training epochs. Default: 200.
- dataset: Dataset to train on. One of cifar10, cifar100, or cinic10. Default: 'cifar10'.
- batch_size: Batch size. Default: 128.
- alpha: The alpha hyperparameter used when calculating the loss. Default: 0.95.
- learning_rate: Initial learning rate. Default: 0.1.
- momentum: SGD momentum. Default: 0.9.
- weight_decay: SGD weight decay. Default: 1e-4.
- teacher1: Model architecture of the first teacher model.
- teacher2: Model architecture of the second teacher model.
- student: Model architecture of the student model.
- teacher_checkpoint1: Optional. Pretrained checkpoint for the first teacher model.
- teacher_checkpoint2: Optional. Pretrained checkpoint for the second teacher model.
- student_checkpoint: Optional. Pretrained checkpoint for the student model.
- cuda: Whether to use CUDA (train on the GPU).
- dataset_dir: Dataset directory.
- trial: Trial memo number.
- sbj: The sbj number.
- seed: Random seed for the experiment.
- save_weight: Whether to save the trained model weights. Default: 0. Set it to 1 to save the weights.
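The argument list above maps naturally onto an argparse parser. The following is an illustrative sketch only, not the repository's actual parser: build_parser is a hypothetical name, the types are assumptions, and arguments without a default stated above are left as None. Both the batch_size and batch-size spellings are registered because the list uses the former and the training command uses the latter.

```python
import argparse


def build_parser():
    """Build a parser mirroring the documented training arguments (assumed types)."""
    parser = argparse.ArgumentParser(description="TGD training arguments (illustrative sketch)")
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--dataset", type=str, default="cifar10",
                        choices=["cifar10", "cifar100", "cinic10"])
    # First option string determines the attribute name: args.batch_size.
    parser.add_argument("--batch_size", "--batch-size", type=int, default=128)
    parser.add_argument("--alpha", type=float, default=0.95)
    parser.add_argument("--learning_rate", type=float, default=0.1)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    # No defaults are documented for the arguments below, so they default to None.
    for name in ("teacher1", "teacher2", "student",
                 "teacher_checkpoint1", "teacher_checkpoint2", "student_checkpoint",
                 "dataset_dir", "trial", "sbj"):
        parser.add_argument(f"--{name}", type=str, default=None)
    parser.add_argument("--cuda", type=int, default=None)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--save_weight", type=int, default=0)
    return parser


# Parse an explicit list instead of sys.argv so the example is self-contained.
args = build_parser().parse_args(
    ["--alpha", "0.99", "--batch-size", "128", "--student", "wrn161", "--seed", "1234"]
)
print(args.epochs, args.alpha, args.batch_size, args.student, args.seed)
```

Arguments that are not passed keep their documented defaults, so epochs stays at 200 in this example while alpha is overridden.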
Command to train the model:
python3 main_train.py --epochs 200 --alpha 0.99 --teacher1 wrn163 --teacher2 wrn163 --teacher_checkpoint1 ./models/wrn163_Teacher1.pth.tar --teacher_checkpoint2 ./models/wrn163_Teacher2.pth.tar --student wrn161 --cuda 1 --dataset cifar10 --batch-size 128 --trial T_wrn163_163_S_wrn161_TGD --seed 1234 --save_weight 0 --student_checkpoint ./models/wrn161_Student.pth.tar
To evaluate the model, run the main_eval.py script.
Note: The main_eval.py script accepts only a subset of the training script's arguments described above: 'dataset', 'batch_size', 'student', 'student_checkpoint', 'cuda', 'dataset_dir', 'trial' and 'seed'.
Command to evaluate the model:
python3 main_eval.py --student wrn161 --batch-size 1 --cuda 1 --dataset cifar10 --trial eval_161 --seed 1234 --save_weight 0 --student_checkpoint ./models/T163_S161_TGD.pth.tar