TNT: An Interpretable Tree-Network-Tree Learning Framework using Knowledge Distillation
Abstract
1. Introduction
- We improve the traditional decision tree and propose a novel James–Stein Decision Tree (JSDT), which provides better embedding representations at its leaf nodes; the resulting estimates are more robust to the input data and directly applicable to DNNs (a sketch of the underlying shrinkage follows this list).
- Inspired by recent advances in differentiable models, we propose a distillable Gradient Boosted Decision Tree (dGBDT), which can learn dark knowledge from DNNs while remaining interpretable on test cases.
- To improve the robustness and interpretability of deep models simultaneously, we explore potential pipelines, data flows, and structures for leveraging tree models. Based on this analysis, we propose the TNT framework and verify it with extensive experiments.
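For context on the first contribution, here is a minimal sketch of the positive-part James–Stein estimator (James & Stein, 1961; Efron & Hastie, 2016) that underlies JSDT's leaf estimates: the raw per-leaf sample means are shrunk toward the grand mean. The function name and the known-variance assumption are ours; the paper's estimator in Section 3.1 may differ in detail.

```python
import numpy as np

def james_stein_shrink(leaf_means, sigma2):
    """Shrink per-leaf mean estimates toward the grand mean.

    Classic positive-part James-Stein estimator; `leaf_means` are the raw
    sample means of the leaves, `sigma2` the (assumed known) variance of
    each mean.
    """
    y = np.asarray(leaf_means, dtype=float)
    p = y.size
    if p < 4:                    # shrinking toward the mean needs p >= 4
        return y
    grand = y.mean()
    ss = ((y - grand) ** 2).sum()
    shrink = max(0.0, 1.0 - (p - 3) * sigma2 / ss)   # positive-part rule
    return grand + shrink * (y - grand)
```

When the tree has many leaves and the per-leaf means are noisy, the shrunk estimates have lower total squared error than the raw means, which is what makes the leaf embeddings more robust.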
2. Related Works
2.1. Deep Models in Black Box
2.2. Tree Models
2.3. Knowledge Distillation
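For concreteness, the canonical soft-target distillation loss reviewed here (Hinton et al., 2015) can be sketched as follows; the temperature T and mixing weight alpha below are typical values, not necessarily the settings used in this paper.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.9):
    """Soft-target KD loss: tempered KL to the teacher plus a hard-label term."""
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)                              # rescale gradients by T^2
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1.0 - alpha) * hard
```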
3. Proposed Tree Models
3.1. James–Stein Decision Trees
Algorithm 1 Feature selection of the James–Stein Decision Tree (JSDT).
Input: the current node N, the feature set, the number of leaf nodes, and the stop condition.
Output: the best split feature.
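In generic notation, the selection loop can be sketched as follows: each feature is scored by the squared error that remains after fitting shrunk child means. The variance-ratio shrinkage used here is our stand-in for the paper's James–Stein estimator, and the exact scoring rule in Algorithm 1 may differ.

```python
import numpy as np

def shrunk_mean(parent_mean, child, sigma2):
    """Empirical-Bayes shrinkage of a child node's mean toward its parent's
    mean (a variance-ratio heuristic, not necessarily the paper's estimator)."""
    m, n = child.mean(), len(child)
    b = (m - parent_mean) ** 2
    w = b / (b + sigma2 / n)              # more data, less shrinkage
    return parent_mean + w * (m - parent_mean)

def best_split_feature(X, y, sigma2=1.0):
    """Return the feature whose best threshold split minimizes the squared
    error around the shrunk child means (a sketch of Algorithm 1)."""
    parent_mean = y.mean()
    best_feat, best_err = None, np.inf
    for j in range(X.shape[1]):
        order = np.argsort(X[:, j])
        xs, ys = X[order, j], y[order]
        for i in range(1, len(ys)):       # every threshold between samples
            if xs[i] == xs[i - 1]:
                continue
            left, right = ys[:i], ys[i:]
            err = sum(((part - shrunk_mean(parent_mean, part, sigma2)) ** 2).sum()
                      for part in (left, right))
            if err < best_err:
                best_feat, best_err = j, err
    return best_feat
```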
3.2. Distillable Gradient Boosted Decision Trees
Algorithm 2 Training Distillable Gradient Boosted Decision Trees (dGBDT).
Input: the training batches, the number of trees M, and the dGBDT parameters.
Output: the updated dGBDT parameters.
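Again in generic notation, a sketch of how a distillable tree ensemble can be trained against a DNN teacher: soft sigmoid gates make each tree differentiable, so the summed ensemble can follow gradients of a distillation loss. The paper's Algorithm 2 may stage trees one at a time in boosting fashion; this sketch updates all M trees jointly for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftTree(nn.Module):
    """A depth-d soft decision tree: every inner node is a sigmoid gate over a
    linear projection of the input, so routing is differentiable end to end."""
    def __init__(self, in_dim, out_dim, depth=3):
        super().__init__()
        self.inner = nn.Linear(in_dim, 2 ** depth - 1)     # one gate per inner node
        self.leaves = nn.Parameter(torch.zeros(2 ** depth, out_dim))
        self.depth = depth

    def forward(self, x):
        gates = torch.sigmoid(self.inner(x))               # (B, #inner nodes)
        prob = x.new_ones(x.size(0), 1)                    # path prob of the root
        for d in range(self.depth):                        # descend level by level
            g = gates[:, 2 ** d - 1: 2 ** (d + 1) - 1]     # gates of level d
            prob = torch.stack((prob * g, prob * (1 - g)), dim=2).flatten(1)
        return prob @ self.leaves                          # leaf mixture -> logits

def train_dgbdt(trees, batches, teacher, labels_weight=0.1, T=4.0, lr=1e-2):
    """One epoch over `batches` of (x, y): the summed tree ensemble chases the
    DNN teacher's tempered soft targets plus a small hard-label term."""
    opt = torch.optim.Adam([p for t in trees for p in t.parameters()], lr=lr)
    for x, y in batches:
        with torch.no_grad():
            soft = F.softmax(teacher(x) / T, dim=1)        # dark knowledge
        logits = sum(t(x) for t in trees)                  # boosted ensemble output
        loss = (T * T) * F.kl_div(F.log_softmax(logits / T, dim=1),
                                  soft, reduction="batchmean") \
               + labels_weight * F.cross_entropy(logits, y)
        opt.zero_grad(); loss.backward(); opt.step()
    return trees
```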
4. Proposed TNT Framework
4.1. Tree-Network-Tree Learning Framework
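Reading the framework from the paper's structure, the Tree-Network-Tree data flow can be prototyped with off-the-shelf parts. Every component below is an illustrative stand-in rather than the paper's implementation: plain CART for JSDT, a small MLP for the deep model, and a GBDT regressor fitted on soft targets for dGBDT; a binary task is assumed.

```python
from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import OneHotEncoder
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import GradientBoostingRegressor

def tnt_pipeline(X_tr, y_tr, X_te):
    # Stage 1 (front tree): leaf indices give every sample a sparse,
    # tree-structured embedding.
    tree = DecisionTreeClassifier(max_depth=5).fit(X_tr, y_tr)
    enc = OneHotEncoder(handle_unknown="ignore")
    Z_tr = enc.fit_transform(tree.apply(X_tr).reshape(-1, 1)).toarray()
    Z_te = enc.transform(tree.apply(X_te).reshape(-1, 1)).toarray()

    # Stage 2 (network): a small MLP learns from the tree embedding.
    dnn = MLPClassifier(hidden_layer_sizes=(64, 64), max_iter=500).fit(Z_tr, y_tr)
    soft = dnn.predict_proba(Z_tr)[:, 1]          # teacher's soft targets

    # Stage 3 (back tree): a GBDT regressor fitted on the soft targets gives
    # an interpretable student of the network.
    student = GradientBoostingRegressor(n_estimators=100).fit(Z_tr, soft)
    return student.predict(Z_te)                  # distilled, tree-based scores
```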
4.2. Further Exploration
5. Experiments
5.1. Datasets and Setup
5.2. Robustness and Performance
5.3. Interpretability
5.3.1. Partial Dependence Plots
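Partial dependence plots marginalize a fitted model over all but one or two features. Since the Cancer dataset listed above (569 × 30) matches scikit-learn's built-in breast-cancer data, a comparable plot for a stand-in model can be drawn along these lines; the feature indices are arbitrary choices, not the ones analyzed in the paper.

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.inspection import PartialDependenceDisplay

X, y = load_breast_cancer(return_X_y=True)       # 569 samples x 30 features
model = GradientBoostingClassifier().fit(X, y)   # stand-in for the distilled tree
# One-way PDPs for two features plus their two-way interaction surface.
PartialDependenceDisplay.from_estimator(model, X, features=[0, 7, (0, 7)])
plt.show()
```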
5.3.2. Classification Activation Mapping
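Class activation mapping (Zhou et al., 2016) projects the last convolutional feature maps onto the classifier weights of a chosen class; a minimal PyTorch sketch follows, with tensor shapes that are our convention rather than the paper's.

```python
import torch
import torch.nn.functional as F

def class_activation_map(features, fc_weight, class_idx):
    """CAM: weight the final conv feature maps by one class's classifier
    weights, then normalize the resulting heatmap for display.

    features : (C, H, W) activations feeding global average pooling
    fc_weight: (num_classes, C) weights of the final linear layer
    """
    cam = torch.einsum("c,chw->hw", fc_weight[class_idx], features)
    cam = F.relu(cam)                    # common practice: keep positive evidence
    cam = cam - cam.min()
    return cam / (cam.max() + 1e-8)      # scale to [0, 1]
```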
6. Conclusions
Author Contributions
Funding
Conflicts of Interest
References
- Pan, Y.; Mei, T.; Yao, T.; Li, H.; Rui, Y. Jointly modeling embedding and translation to bridge video and language. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, Las Vegas, NV, USA, 27–30 June 2016; pp. 4594–4602.
- Huang, L.; Wang, W.; Chen, J.; Wei, X.Y. Attention on attention for image captioning. In Proceedings of the IEEE International Conference on Computer Vision, Seoul, Korea, 27 October–2 November 2019; pp. 4634–4643.
- Lu, J.; Yang, J.; Batra, D.; Parikh, D. Hierarchical question-image co-attention for visual question answering. In Proceedings of the Advances in Neural Information Processing Systems, Barcelona, Spain, 5–10 December 2016; pp. 289–297.
- Guidotti, R.; Monreale, A.; Ruggieri, S.; Turini, F.; Giannotti, F.; Pedreschi, D. A survey of methods for explaining black box models. ACM Comput. Surv. (CSUR) 2018, 51, 1–42.
- Molnar, C. Interpretable Machine Learning: A Guide for Making Black Box Models Explainable. Available online: https://christophm.github.io/interpretable-ml-book/ (accessed on 6 June 2018).
- Che, Z.; Purushotham, S.; Khemani, R.; Liu, Y. Interpretable deep models for ICU outcome prediction. AMIA Annu. Symp. Proc. 2016, 2016, 371.
- Ozbayoglu, A.M.; Gudelek, M.U.; Sezer, O.B. Deep learning for financial applications: A survey. Appl. Soft Comput. 2020, 93, 106384.
- Zhang, C.; Bengio, S.; Hardt, M.; Recht, B.; Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv 2016, arXiv:1611.03530.
- Kontschieder, P.; Fiterau, M.; Criminisi, A.; Rota Bulo, S. Deep neural decision forests. In Proceedings of the IEEE International Conference on Computer Vision, Santiago, Chile, 13–16 December 2015; pp. 1467–1475.
- Ioannou, Y.; Robertson, D.; Zikic, D.; Kontschieder, P.; Shotton, J.; Brown, M.; Criminisi, A. Decision forests, convolutional networks and the models in-between. arXiv 2016, arXiv:1603.01250.
- Frosst, N.; Hinton, G. Distilling a neural network into a soft decision tree. arXiv 2017, arXiv:1711.09784.
- Feng, J.; Xu, Y.X.; Jiang, Y.; Zhou, Z.H. Soft Gradient Boosting Machine. arXiv 2020, arXiv:2006.04059.
- Wang, X.; He, X.; Feng, F.; Nie, L.; Chua, T.S. Tem: Tree-enhanced embedding model for explainable recommendation. In Proceedings of the 2018 World Wide Web Conference, Lyon, France, 23–27 April 2018; pp. 1543–1552.
- Goodfellow, I.; Bengio, Y.; Courville, A. Deep Learning; MIT Press: Cambridge, MA, USA, 2016.
- Chen, T.; Guestrin, C. Xgboost: A scalable tree boosting system. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, San Francisco, CA, USA, 13–17 August 2016; pp. 785–794.
- Zhao, B.; Xiao, X.; Zhang, W.; Zhang, B.; Gan, G.; Xia, S. Self-Paced Probabilistic Principal Component Analysis for Data with Outliers. In Proceedings of the 2020 IEEE International Conference on Acoustics, Speech and Signal Processing, Barcelona, Spain, 4–8 May 2020; pp. 3737–3741.
- Li, J.; Dai, T.; Tang, Q.; Xing, Y.; Xia, S.T. Cyclic annealing training convolutional neural networks for image classification with noisy labels. In Proceedings of the 2018 IEEE International Conference on Image Processing, Athens, Greece, 7–10 October 2018; pp. 21–25.
- Papernot, N.; McDaniel, P.; Goodfellow, I. Transferability in machine learning: From phenomena to black-box attacks using adversarial samples. arXiv 2016, arXiv:1605.07277.
- Chen, X.; Yan, X.; Zheng, F.; Jiang, Y.; Xia, S.; Zhao, Y.; Ji, R. One-Shot Adversarial Attacks on Visual Tracking With Dual Attention. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, Virtual, Seattle, WA, USA, 13–19 June 2020; pp. 10176–10185.
- Yan, X.; Chen, X.; Jiang, Y.; Xia, S.; Zhao, Y.; Zheng, F. Hijacking Tracker: A Powerful Adversarial Attack on Visual Tracking. In Proceedings of the 2020 IEEE International Conference on Acoustics, Speech and Signal Processing, Barcelona, Spain, 4–8 May 2020; pp. 2897–2901.
- Zhang, H.; Yu, Y.; Jiao, J.; Xing, E.; El Ghaoui, L.; Jordan, M.I. Theoretically Principled Trade-off between Robustness and Accuracy. arXiv 2019, arXiv:1901.08573.
- Krizhevsky, A.; Sutskever, I.; Hinton, G.E. Imagenet classification with deep convolutional neural networks. In Proceedings of the Advances in Neural Information Processing Systems, Lake Tahoe, NV, USA, 3–6 December 2012; pp. 1097–1105.
- Zhou, B.; Khosla, A.; Lapedriza, A.; Oliva, A.; Torralba, A. Learning deep features for discriminative localization. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, Las Vegas, NV, USA, 27–30 June 2016; pp. 2921–2929.
- Wang, W.; Zhu, M.; Wang, J.; Zeng, X.; Yang, Z. End-to-end encrypted traffic classification with one-dimensional convolution neural networks. In Proceedings of the 2017 IEEE International Conference on Intelligence and Security Informatics, Beijing, China, 22–24 July 2017; pp. 43–48.
- Bai, S.; Kolter, J.Z.; Koltun, V. An empirical evaluation of generic convolutional and recurrent networks for sequence modeling. arXiv 2018, arXiv:1803.01271.
- Breiman, L. Random forests. Mach. Learn. 2001, 45, 5–32.
- Friedman, J.H. Greedy function approximation: A gradient boosting machine. Ann. Stat. 2001, 29, 1189–1232.
- Chen, H.; Zhang, H.; Boning, D.; Hsieh, C.J. Robust Decision Trees Against Adversarial Examples. In Proceedings of the International Conference on Machine Learning, Long Beach, CA, USA, 10–15 June 2019; pp. 1122–1131.
- Bai, J.; Li, Y.; Li, J.; Jiang, Y.; Xia, S. Rectified Decision Trees: Exploring the Landscape of Interpretable and Effective Machine Learning. arXiv 2020, arXiv:2008.09413.
- Chen, H.; Zhang, H.; Si, S.; Li, Y.; Boning, D.; Hsieh, C.J. Robustness verification of tree-based models. In Proceedings of the Advances in Neural Information Processing Systems, Vancouver, BC, Canada, 8–14 December 2019; pp. 12317–12328.
- Ranzato, F.; Zanella, M. Robustness Verification of Decision Tree Ensembles. OVERLAY@AI*IA 2019, 2509, 59–64.
- Cheng, H.T.; Koc, L.; Harmsen, J.; Shaked, T.; Chandra, T.; Aradhye, H.; Anderson, G.; Corrado, G.; Chai, W.; Ispir, M.; et al. Wide & deep learning for recommender systems. In Proceedings of the 1st Workshop on Deep Learning for Recommender Systems, Boston, MA, USA, 15–19 September 2016; pp. 7–10.
- Irsoy, O.; Yıldız, O.T.; Alpaydın, E. Soft decision trees. In Proceedings of the 21st International Conference on Pattern Recognition, Tsukuba, Japan, 11–15 November 2012; pp. 1819–1822.
- Zhou, Z.H.; Feng, J. Deep forest: Towards an alternative to deep neural networks. In Proceedings of the 26th International Joint Conference on Artificial Intelligence, Melbourne, Australia, 19–25 August 2017; pp. 3553–3559.
- Rota Bulo, S.; Kontschieder, P. Neural decision forests for semantic image labelling. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, Columbus, OH, USA, 24–27 June 2014; pp. 81–88.
- Hinton, G.; Vinyals, O.; Dean, J. Distilling the knowledge in a neural network. arXiv 2015, arXiv:1503.02531.
- Li, J.; Xiang, X.; Dai, T.; Xia, S.T. Making Large Ensemble of Convolutional Neural Networks via Bootstrap Re-sampling. In Proceedings of the 2019 IEEE Visual Communications and Image Processing, Sydney, Australia, 1–4 December 2019; pp. 1–4.
- Li, J.; Li, Y.; Yang, J.; Guo, T.; Xia, S.T. UA-DRN: Unbiased Aggregation of Deep Neural Networks for Regression Ensemble. Aust. J. Intell. Inf. Process. Syst. 2019, 15, 86–93.
- Yim, J.; Joo, D.; Bae, J.; Kim, J. A gift from knowledge distillation: Fast optimization, network minimization and transfer learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, Honolulu, HI, USA, 21–26 July 2017; pp. 4133–4141.
- Chen, Z.; Zhang, L.; Cao, Z.; Guo, J. Distilling the knowledge from handcrafted features for human activity recognition. IEEE Trans. Ind. Inform. 2018, 14, 4334–4342.
- Shen, Z.; He, Z.; Xue, X. Meal: Multi-model ensemble via adversarial learning. In Proceedings of the AAAI Conference on Artificial Intelligence, Honolulu, HI, USA, 27 January–1 February 2019; Volume 33, pp. 4886–4893.
- Breiman, L.; Friedman, J.; Stone, C.J.; Olshen, R.A. Classification and Regression Trees; CRC Press: Boca Raton, FL, USA, 1984.
- Xiang, X.; Tang, Q.; Zhang, H.; Dai, T.; Li, J.; Xia, S. JSRT: James-Stein Regression Tree. arXiv 2020, arXiv:2010.09022.
- James, W.; Stein, C. Estimation with quadratic loss. In Proceedings of the 4th Berkeley Symposium on Mathematical Statistics and Probability, Berkeley, CA, USA, 20 June–30 July 1960; University of California Press: Berkeley, CA, USA, 1961; Volume 1.
- Efron, B.; Hastie, T. Computer Age Statistical Inference; Cambridge University Press: Cambridge, UK, 2016.
- Bock, M.E. Minimax estimators of the mean of a multivariate normal distribution. Ann. Stat. 1975, 3, 209–218.
- Feldman, S.; Gupta, M.; Frigyik, B. Multi-task averaging. In Proceedings of the Advances in Neural Information Processing Systems, Stateline, NV, USA, 3–8 December 2012; pp. 1169–1177.
- Shi, T.; Agostinelli, F.; Staib, M.; Wipf, D.; Moscibroda, T. Improving survey aggregation with sparsely represented signals. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, San Francisco, CA, USA, 13–17 August 2016; pp. 1845–1854.
- Ke, G.; Meng, Q.; Finley, T.; Wang, T.; Chen, W.; Ma, W.; Ye, Q.; Liu, T.Y. Lightgbm: A highly efficient gradient boosting decision tree. In Proceedings of the Advances in Neural Information Processing Systems, Long Beach, CA, USA, 4–9 December 2017; pp. 3146–3154.
- Huang, J.; Li, G.; Yan, Z.; Luo, F.; Li, S. Joint learning of interpretation and distillation. arXiv 2020, arXiv:2005.11638.
- Ke, G.; Xu, Z.; Zhang, J.; Bian, J.; Liu, T.Y. DeepGBM: A deep learning framework distilled by GBDT for online prediction tasks. In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, Anchorage, AK, USA, 4–8 August 2019; pp. 384–394.
- Fukui, S.; Yu, J.; Hashimoto, M. Distilling Knowledge for Non-Neural Networks. In Proceedings of the 2019 Asia-Pacific Signal and Information Processing Association Annual Summit and Conference, Lanzhou, China, 18–21 November 2019; pp. 1411–1416.
- Feng, F.; He, X.; Wang, X.; Luo, C.; Liu, Y.; Chua, T.S. Temporal relational ranking for stock prediction. ACM Trans. Inf. Syst. 2019, 37, 1–30.
- Caicedo-Torres, W.; Gutierrez, J. ISeeU: Visually interpretable deep learning for mortality prediction inside the ICU. J. Biomed. Inform. 2019, 98, 103269.
- He, X.; Yang, X.; Zhang, S.; Zhao, J.; Zhang, Y.; Xing, E.; Xie, P. Sample-Efficient Deep Learning for COVID-19 Diagnosis Based on CT Scans. medRxiv 2020.
- Guo, H.; Tang, R.; Ye, Y.; Li, Z.; He, X. DeepFM: A factorization-machine based neural network for CTR prediction. In Proceedings of the 26th International Joint Conference on Artificial Intelligence, Melbourne, Australia, 19–25 August 2017; pp. 1725–1731.
| | T | N | T-N | N-T | T-N-T | N-T-N | T-N-T-N | Others |
|---|---|---|---|---|---|---|---|---|
| Performance | x | √ | √ | √ | √ | √ | makes sense but is redundant | redundant and not necessary |
| Robustness | √ | x | √ | x | √ | x | | |
| Interpretability | √ | x | x | √ | √ | x | | |
| Dataset | Size | Task Description |
|---|---|---|
| Cancer | 569 × 30 | Risk Probability Prediction |
| Criteo | 51.8 M × 39 | Click Rate Prediction |
| NASDAQ | 1026 × 1245 | Relational Stock Ranking |
| MIMIC-III | 38,425 × 22 | ICU Mortality Prediction |
| | Methods | AUROC (No Sparse) | AUPRC (No Sparse) | AUROC (20% Sparse) | AUPRC (20% Sparse) | AUROC (40% Sparse) | AUPRC (40% Sparse) |
|---|---|---|---|---|---|---|---|
| Tree Models | CART (single tree) | 0.9367 | 0.9529 | 0.9273 | 0.9449 | 0.9114 | 0.9424 |
| | JSDT (single tree) | 0.9449 | 0.9561 | 0.9341 | 0.9496 | 0.9185 | 0.9480 |
| Deep Models | DNNs (6-layer MLP) | 0.9665 | 0.9522 | 0.9394 | 0.9428 | 0.9288 | 0.9227 |
| T & N Fusion | W & D (DAG pattern) | 0.9779 | 0.9496 | 0.9565 | 0.9423 | 0.9468 | 0.9312 |
| | CART-DNNs (T-N) | 0.9742 | 0.9463 | 0.9610 | 0.9428 | 0.9474 | 0.9357 |
| | JSDT-DNNs (T-N) | 0.9784 | 0.9531 | 0.9629 | 0.9487 | 0.9523 | 0.9398 |
| | DNNs-SDT (N-T) | 0.9620 | 0.9440 | 0.9381 | 0.9331 | 0.9223 | 0.9207 |
| Proposed TNT | CART-DNNs-SDT | 0.9674 | 0.9460 | 0.9602 | 0.9387 | 0.9436 | 0.9340 |
| | JSDT-DNNs-SDT | 0.9723 | 0.9471 | 0.9626 | 0.9406 | 0.9488 | 0.9389 |
| | Methods | Criteo AUROC | Criteo LogLoss | NASDAQ MSE | NASDAQ MRR | MIMIC-III AUROC | MIMIC-III AUPRC |
|---|---|---|---|---|---|---|---|
| Tree Models | GBDT (tree ensemble) | 0.7853 | 0.46425 | 6.04 × 10 | 2.95 × 10 | 0.7836 | 0.4371 |
| | sGBM (tree ensemble) | 0.7889 | 0.46267 | 5.72 × 10 | 3.27 × 10 | 0.7883 | 0.4420 |
| Deep Models | DFM/rLSTM/Conv | 0.8004 | 0.45039 | 3.88 × 10 | 4.13 × 10 | 0.8728 | 0.5327 |
| T & N Fusion | W & D (DAG pattern) | 0.7970 | 0.45942 | 4.60 × 10 | 3.92 × 10 | 0.8783 | 0.5351 |
| | GBDT-[DNNs] (T-N) | 0.8136 | 0.44695 | 3.43 × 10 | 4.25 × 10 | 0.8949 | 0.5482 |
| | JSDF-[DNNs] (T-N) | 0.8168 | 0.44237 | 3.27 × 10 | 4.43 × 10 | 0.9015 | 0.5503 |
| | [DNNs]-sGBM (N-T) | 0.7958 | 0.46041 | 4.24 × 10 | 3.53 × 10 | 0.8689 | 0.5217 |
| Proposed TNT | GBDT-[DNNs]-sGBM | 0.8044 | 0.45733 | 3.78 × 10 | 4.18 × 10 | 0.8694 | 0.5410 |
| | GBDT-[DNNs]-dGBDT | 0.8079 | 0.44980 | 3.64 × 10 | 4.23 × 10 | 0.8916 | 0.5425 |
| | JSDF-[DNNs]-dGBDT | 0.8095 | 0.44887 | 3.51 × 10 | 4.29 × 10 | 0.8988 | 0.5433 |
© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).
Share and Cite
Li, J.; Li, Y.; Xiang, X.; Xia, S.-T.; Dong, S.; Cai, Y. TNT: An Interpretable Tree-Network-Tree Learning Framework using Knowledge Distillation. Entropy 2020, 22, 1203. https://doi.org/10.3390/e22111203