Location via proxy:   [ UP ]  
[Report a bug]   [Manage cookies]                

Cs181 Textbook

Download as pdf or txt
Download as pdf or txt
You are on page 1of 163

Undergraduate Fundamentals

of Machine Learning

The initial version of this textbook was created by William J. Deuschle for his senior thesis,
based on his notes of CS181 during the Spring of 2017. This textbook has since been maintained
by the CS181 course staff with bug fixes from many CS181 students.
Contents

1 Introduction to Machine Learning 1


1.1 What is Machine Learning? . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 1
1.2 What Will This Book Teach Me? . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 1
1.3 Our Machine Learning Framework . . . . . . . . . . . . . . . . . . . . . . . . . . . . 2
1.4 This Book’s Notation . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 3
Mathematical and Statistical Notation . . . . . . . . . . . . . . . . . . . . . . 3
Textbook Specific Notation . . . . . . . . . . . . . . . . . . . . . . . . . . . . 3

2 Regression 4
2.1 Defining the Problem . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 4
2.2 Solution Options . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 4
2.2.1 K-Nearest-Neighbors . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 5
2.2.2 Neural Networks . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 5
2.2.3 Random Forests . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 5
2.2.4 Gradient Boosted Trees . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 5
2.2.5 Turning to Linear Regression . . . . . . . . . . . . . . . . . . . . . . . . . . . 5
2.3 Introduction to Linear Regression . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 6
2.4 Basic Setup . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 6
2.4.1 Merging of Bias . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 7
2.4.2 Visualization of Linear Regression . . . . . . . . . . . . . . . . . . . . . . . . 7
2.5 Finding the Best Fitting Line . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 8
2.5.1 Objective Functions and Loss . . . . . . . . . . . . . . . . . . . . . . . . . . . 8
2.5.2 Least Squares Loss . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 9
2.6 Linear Regression Algorithms . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 10
2.6.1 Optimal Weights via Matrix Differentiation . . . . . . . . . . . . . . . . . . . 10
2.6.2 Bayesian Solution: Maximum Likelihood Estimation . . . . . . . . . . . . . . 11
2.6.3 Alternate Interpretation: Linear Regression as Projection . . . . . . . . . . . 13
2.7 Model Flexibility . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 13
2.7.1 Basis Functions . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 13
2.7.2 Regularization . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 16

ii
CONTENTS iii

2.7.3 Generalizing Regularization . . . . . . . . . . . . . . . . . . . . . . . . . . . . 19


2.7.4 Bayesian Regularization . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 20
2.8 Choosing Between Models . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 22
2.8.1 Bias-Variance Tradeoff and Decomposition . . . . . . . . . . . . . . . . . . . 22
2.8.2 Cross-Validation . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 26
2.8.3 Making a Model Choice . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 26
2.8.4 Bayesian Model Averaging . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 27
2.9 Linear Regression Extras . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 27
2.9.1 Predictive Distribution . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 27
2.10 Conclusion . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 28

3 Classification 29
3.1 Defining the Problem . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 29
3.2 Solution Options . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 29
3.3 Discriminant Functions . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 30
3.3.1 Basic Setup: Binary Linear Classification . . . . . . . . . . . . . . . . . . . . 30
3.3.2 Multiple Classes . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 31
3.3.3 Basis Changes in Classification . . . . . . . . . . . . . . . . . . . . . . . . . . 31
3.4 Numerical Parameter Optimization and Gradient Descent . . . . . . . . . . . . . . . 33
3.4.1 Gradient Descent . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 35
3.4.2 Batch Gradient Descent versus Stochastic Gradient Descent . . . . . . . . . . 36
3.5 Objectives for Decision Boundaries . . . . . . . . . . . . . . . . . . . . . . . . . . . . 36
3.5.1 0/1 Loss . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 36
3.5.2 Least Squares Loss . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 37
3.5.3 Hinge Loss . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 38
3.6 Probabilistic Methods . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 39
3.6.1 Probabilistic Discriminative Models . . . . . . . . . . . . . . . . . . . . . . . 39
Logistic Regression . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 40
Multi-Class Logistic Regression and Softmax . . . . . . . . . . . . . . . . . . 42
3.6.2 Probabilistic Generative Models . . . . . . . . . . . . . . . . . . . . . . . . . 44
Classification in the Generative Setting . . . . . . . . . . . . . . . . . . . . . 44
MLE Solution . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 45
Naive Bayes . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 47
3.7 Conclusion . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 49

4 Neural Networks 50
4.1 Motivation . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 50
4.1.1 Comparison to Other Methods . . . . . . . . . . . . . . . . . . . . . . . . . . 51
4.1.2 Universal Function Approximation . . . . . . . . . . . . . . . . . . . . . . . . 51
4.2 Feed-Forward Networks . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 52
iv CONTENTS

4.3 Neural Network Basics and Terminology . . . . . . . . . . . . . . . . . . . . . . . . . 52


4.3.1 Adaptive Basis Functions . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 53
4.4 Network Training . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 56
4.4.1 Objective Function . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 56
4.4.2 Optimizing Parameters . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 56
4.4.3 Backpropagation . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 57
4.4.4 Computing Derivatives Using Backpropagation . . . . . . . . . . . . . . . . . 57
4.5 Choosing a Network Structure . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 61
4.5.1 Cross Validation for Neural Networks . . . . . . . . . . . . . . . . . . . . . . 61
4.5.2 Preventing Overfitting . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 62
Regularization . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 62
Data Augmentation . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 62
4.6 Specialized Forms of Neural Networks . . . . . . . . . . . . . . . . . . . . . . . . . . 63
4.6.1 Convolutional Neural Networks (CNNs) . . . . . . . . . . . . . . . . . . . . . 63
4.6.2 Recurrent Neural Networks (RNNs) . . . . . . . . . . . . . . . . . . . . . . . 63
4.6.3 Bayesian Neural Networks (BNNs) . . . . . . . . . . . . . . . . . . . . . . . . 64

5 Support Vector Machines 65


5.1 Motivation . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 65
5.1.1 Max Margin Methods . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 65
5.1.2 Applications . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 66
5.2 Hard Margin Classifier for Linearly Separable Data . . . . . . . . . . . . . . . . . . . 67
5.2.1 Why the Hard Margin . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 67
5.2.2 Deriving our Optimization Problem . . . . . . . . . . . . . . . . . . . . . . . 67
5.2.3 What is a Support Vector . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 69
5.3 Soft Margin Classifier . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 70
5.3.1 Why the Soft Margin? . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 70
5.3.2 Updated Optimization Problem for Soft Margins . . . . . . . . . . . . . . . . 71
5.3.3 Soft Margin Support Vectors . . . . . . . . . . . . . . . . . . . . . . . . . . . 72
5.4 Conversion to Dual Form . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 72
5.4.1 Lagrange Multipliers . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 73
5.4.2 Deriving the Dual Formulation . . . . . . . . . . . . . . . . . . . . . . . . . . 74
5.4.3 Making Predictions . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 75
5.4.4 Why is the Dual Formulation Helpful? . . . . . . . . . . . . . . . . . . . . . . 76
5.4.5 Kernel Composition . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 77

6 Clustering 78
6.1 Motivation . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 78
6.1.1 Applications . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 79
6.2 K-Means Clustering . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 79
CONTENTS v

6.2.1 Lloyd’s Algorithm . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 79


6.2.2 Example of Lloyd’s . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 81
6.2.3 Number of Clusters . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 84
6.2.4 Initialization and K-Means++ . . . . . . . . . . . . . . . . . . . . . . . . . . 84
6.2.5 K-Medoids Alternative . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 86
6.3 Hierarchical Agglomerative Clustering . . . . . . . . . . . . . . . . . . . . . . . . . . 86
6.3.1 HAC Algorithm . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 87
6.3.2 Linkage Criterion . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 89
Min-Linkage Criteria . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 89
Max-Linkage Criterion . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 89
Average-Linkage Criterion . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 90
Centroid-Linkage Criterion . . . . . . . . . . . . . . . . . . . . . . . . . . . . 90
Different Linkage Criteria Produce Different Clusterings . . . . . . . . . . . . 90
6.3.3 How HAC Differs from K-Means . . . . . . . . . . . . . . . . . . . . . . . . . 91

7 Dimensionality Reduction 92
7.1 Motivation . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 92
7.2 Applications . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 93
7.3 Principal Component Analysis . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 93
7.3.1 Reconstruction Loss . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 94
7.3.2 Minimizing Reconstruction Loss . . . . . . . . . . . . . . . . . . . . . . . . . 96
7.3.3 Multiple Principal Components . . . . . . . . . . . . . . . . . . . . . . . . . . 97
7.3.4 Identifying Directions of Maximal Variance in our Data . . . . . . . . . . . . 97
7.3.5 Choosing the Optimal Number of Principal Components . . . . . . . . . . . . 98
7.4 Conclusion . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 101

8 Graphical Models 102


8.1 Motivation . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 102
8.2 Directed Graphical Models (Bayesian Networks) . . . . . . . . . . . . . . . . . . . . 102
8.2.1 Joint Probability Distributions . . . . . . . . . . . . . . . . . . . . . . . . . . 105
8.2.2 Generative Models . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 105
8.2.3 Generative Modeling vs. Discriminative Modeling . . . . . . . . . . . . . . . 107
8.2.4 Understanding Complexity . . . . . . . . . . . . . . . . . . . . . . . . . . . . 108
8.2.5 Independence and D-Separation . . . . . . . . . . . . . . . . . . . . . . . . . . 109
8.3 Example: Naive Bayes . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 112
8.4 Conclusion . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 112

9 Mixture Models 113


9.1 Motivation . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 113
9.2 Applications . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 115
vi CONTENTS

9.3 Fitting a Model . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 115


9.3.1 Maximum Likelihood for Mixture Models . . . . . . . . . . . . . . . . . . . . 115
9.3.2 Complete-Data Log Likelihood . . . . . . . . . . . . . . . . . . . . . . . . . . 116
9.4 Expectation-Maximization (EM) . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 116
9.4.1 Expectation Step . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 117
9.4.2 Maximization Step . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 118
9.4.3 Full EM Algorithm . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 119
9.4.4 The Math of EM . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 119
The Evidence Lower Bound (ELBO) . . . . . . . . . . . . . . . . . . . . . . . 120
Optimization . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 121
Correctness . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 123
Equivalence to Prior Formulation . . . . . . . . . . . . . . . . . . . . . . . . . 123
9.4.5 Connection to K-Means Clustering . . . . . . . . . . . . . . . . . . . . . . . . 123
9.4.6 Dice Example: Mixture of Multinomials . . . . . . . . . . . . . . . . . . . . . 124
9.5 Gaussian Mixture Models (GMM) . . . . . . . . . . . . . . . . . . . . . . . . . . . . 126
9.6 Admixture Models: Latent Dirichlet Allocation (LDA) . . . . . . . . . . . . . . . . . 127
9.6.1 LDA for Topic Modeling . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 127
9.6.2 Applying EM to LDA . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 128
9.7 Conclusion . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 130

10 Hidden Markov Models 131


10.1 Motivation . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 131
10.2 Applications . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 132
10.3 HMM Data, Model, and Parameterization . . . . . . . . . . . . . . . . . . . . . . . . 133
10.3.1 HMM Data . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 133
10.3.2 HMM Model Assumptions . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 133
10.3.3 HMM Parameterization . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 134
10.4 Inference in HMMs . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 134
10.4.1 The Forward-Backward Algorithm . . . . . . . . . . . . . . . . . . . . . . . . 135
10.4.2 Using α’s and β’s for Training and Inference . . . . . . . . . . . . . . . . . . 138
p(Seq) . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 138
Prediction . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 138
Smoothing . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 138
Transition . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 139
Filtering . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 139
Best path . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 139
10.5 Using EM to Train a HMM . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 140
10.5.1 E-Step . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 140
10.5.2 M-Step . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 141
10.6 Conclusion . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 141
CONTENTS vii

11 Markov Decision Processes 142


11.1 Formal Definition of an MDP . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 143
11.2 Finite Horizon Planning . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 144
11.3 Infinite Horizon Planning . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 144
11.3.1 Value iteration . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 145
Bellman Consistency Equation and Bellman Optimality . . . . . . . . . . . . 145
Bellman Operator . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 147
Value Iteration Algorithm . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 147
11.3.2 Policy Iteration . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 148
Policy Evaluation . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 148

12 Reinforcement Learning 149


12.1 Motivation . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 149
12.2 General Approaches to RL . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 149
12.3 Model-Free Learning . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 150
12.3.1 SARSA and Q-Learning . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 151
Convergence Conditions . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 152
12.3.2 Deep Q-Networks . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 152
12.3.3 Policy Learning . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 153
12.4 Model-Based Learning . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 154
12.5 Conclusion . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 155
viii CONTENTS
Chapter 1

Introduction to Machine Learning

1.1 What is Machine Learning?


There is a great deal of misunderstanding about what machine learning is, fueled by recent success
and at times sensationalist media coverage. While its applications have been and will continue
to be extraordinarily powerful under the right circumstances, it’s important to gain some sense
of where and why the tools presented in this book will be applicable. Broadly, machine learning
is the application of statistical, mathematical, and numerical techniques to derive some form of
knowledge from data. This ‘knowledge’ may afford us some sort of summarization, visualization,
grouping, or even predictive power over data sets.
With all that said, it’s important to emphasize the limitations of machine learning. It is not nor
will it ever be a replacement for critical thought and methodical, procedural work in data science.
Indeed, machine learning can be reasonably characterized a loose collection of disciplines and tools.
Where the lines begin that separate machine learning from statistics or mathematics or probability
theory or any other handful of fields that it draws on are not clear. So while this book is a synopsis
of the basics of machine learning, it might be better understood as a collection of tools that can be
applied to a specific subset of problems.

1.2 What Will This Book Teach Me?


The purpose of this book is to provide you the reader with the following: a framework with which
to approach problems that machine learning learning might help solve. You will hopefully come
away with a sense of the strengths and weaknesses of the tools presented within and, even more
importantly, gain an understanding of how these tools are situated among problems of interest. To
that end, we will aim to develop systems for thinking about the structure of the problems we work
on. That way, as we continue to add new tools and techniques to our repertoire, we will always
have a clear view of the context in which we can expect to use them. This will not only create
a nice categorization of the different practices in machine learning, it will also help motivate why
these techniques exist in the first place.
You will not be an expert in any individual ML concept after reading this text. Rather, you
should come away with three different levels of understanding. First, you should gain a general
contextual awareness of the different problem types that ML techniques may be used to solve.
Second, you should come away with a practical awareness of how different ML techniques operate.
This means that after you have successfully identified an appropriate ML technique for a given
problem, you will also know how that method actually accomplishes the goal at hand. If you only

1
2 CHAPTER 1. INTRODUCTION TO MACHINE LEARNING

come away with these two levels of understanding, you will be off to a good start. The third level
of understanding relates to having a derivational awareness of the algorithms and methods we will
make use of. This level of understanding is not strictly necessary to successfully interact with
existing machine learning capabilities, but it will be required if you desire to go further and deepen
existing knowledge. Thus, we will be presenting derivations, but it will be secondary to a high level
understanding of problem types and the practical intuition behind available solutions.

1.3 Our Machine Learning Framework


Let’s consider for the first time what we will call the Machine Learning Cube. The purpose of
this cube is to describe the domain of problems we will encounter, and it will be a useful way to
delineate the techniques we will apply to different types of problems. Understanding the different
facets of the cube will aid you in understanding machine learning as a whole, and can even give you
intuition about techniques that you have never encountered before. Let’s now describe the features
of the cube.
Our cube has three axes. On the first axis we will put the output domain. The domain of our
output can take on one of two forms: discrete or continuous. Discrete, or categorical data, is
data that can only fall into one of n specific classes. For example: male or female, integer values
between 1 and 10, or different states in the U.S. are all examples of categorical data. Continuous
data is that which falls on the real number line.
The second axis of the cube is reserved for the statistical nature of the machine learning tech-
nique in question. Specifically, it will fall into one of two broad categories: probabilistic or
non-probabilistic techniques. Probabilistic techniques are those for which we incorporate our
data using some form of statistical distribution or summary. In general, we are then able to dis-
card some or all of our data once we have finished tuning our probabilistic model. In contrast,
non-probabilistic techniques are those that use the data directly to perform some action. A very
common and general example of this is comparing how close a new data point is to other points
in your existing data set. Non-probabilistic techniques potentially make fewer assumptions, but
they do require that you keep around some or all of your data. They are also potentially slower
techniques at runtime because they may require touching all of the data in your dataset to perform
some action. These are a very broad set of guidelines for the distinction between probabilistic and
non-probabilistic techniques - you should expect to see some exceptions and even some techniques
that fit into both of these categories to some degree. Having a sense for their general benefits and
drawbacks is useful, and you will gain more intuition about the distinction as we begin to explore
different techniques.
The third and final axis of the cube describes the type of training we will use. There are
two major classes of machine learning techniques: supervised and unsupervised. In fact, these
two classes of techniques are so important to describing the field of machine learning that we will
roughly divide this textbook into two halves dedicated to techniques found within each of these
categories. A supervised technique is one for which we get to observe a data set of both the inputs
and the outputs ahead of time, to be used for training. For example, we might be given a data
set about weather conditions and crop production over the years. Then, we could train a machine
learning model that learns a relationship between the input data (weather) and output data (crop
production). The implication here is that given new input data, we will be able to predict the
unseen output data. An unsupervised technique is one for which we only get a data set of ‘inputs’
ahead of time. In fact, we don’t even need to consider these as inputs anymore: we can just consider
them to be a set of data points that we wish to summarize or describe. Unsupervised techniques
1.4. THIS BOOK’S NOTATION 3

revolve around clustering or otherwise describing our data.


We will see examples of all of these as we progress throughout the book, and you will gain an
intuition for where different types of data and techniques fall in our cube. Eventually, given just
the information in the cube for a new technique, you will have a solid idea of how that technique
operates.

1.4 This Book’s Notation


The machine learning community uses a number of different conventions, and learning to decipher
the different versions of those conventions is important to understanding work done in the field.
For this book, we will try to stick to a standard notation that we define here in part. In addition
to mathematical and statistical notation, we will also describe the conventions used in this book
for explaining and breaking up different concepts.

Mathematical and Statistical Notation


We will describe the dimensionality of variables when necessary, but generic variables will often be
sufficient when explaining new techniques. Boldface variables (x) represent vectors, capital boldface
characters (X) represent matrices, and standard typeface variables (x) describe scalars.
Statistical distributions will sometimes be described in terms of their probability density func-
2 2
tion (PDF), e.g. Y ∼ σ√12π e−(x−µ) /2σ . Alternatively, in the case of a well known probability
distribution, we will describe those in terms of their standard notation, e.g. Y ∼ N (µ, σ 2 )

Textbook Specific Notation


We have also introduced a few conventions to make consumption of this material easier. We have
boxes dedicated to definitions, explaining techniques in the context of the ML Framework Cube,
and for explaining common misconceptions:

Definition 1.4.1 (Definition Explanation): You will find definitions in these dark gray boxes.

Derivation 1.4.1 (Derivation Explanation): You will find derivations in these light gray boxes.

ML Framework Cube: ML Framework Cube


You will find ML Framework Cube explanations in these blue wrapped boxes.

⋆ You will find explanations for subtle or confusing concepts in these red wrapped boxes.
Chapter 2

Regression

A major component of machine learning, the one that most people associate with ML, is dedicated
to making predictions about a target given some inputs, such as predicting how much money an
individual will earn in their lifetime given their demographic information. In this chapter, w!e’re
going to focus on the case where our prediction is a continuous, real number. When the target is
a real number, we call this prediction procedure regression.

2.1 Defining the Problem


In order to understand regression, let’s start in a more natural place: what types of problems are
we trying to solve? What exactly does it mean to make a prediction that is a continuous, real
number ? To build some intuition, here are a few examples of problems that regression could be
used for:

1. Predicting a person’s height given the height of their parents.

2. Predicting the amount of time someone will take to pay back a loan given their credit history.

3. Predicting what time a package will arrive given current weather and traffic conditions.

Hopefully you are starting to see the pattern emerging here. Given some inputs, we need to
produce a prediction for a continuous output. That is exactly the purpose of regression. Notice
that regression isn’t any one technique in particular. It’s just a class of methods that helps us
achieve our overall goal of predicting a continuous output.

Definition 2.1.1 (Regression): A class of techniques that seeks to make predictions about un-
known continuous target variables given observed input variables.

2.2 Solution Options


Now that you understand the type of problems we are trying to solve with regression, we can
start to think at a high level of the different ways we might arrive at a solution. Here is a short,
nonexhaustive list of regression techniques with brief explanations:

4
2.2. SOLUTION OPTIONS 5

2.2.1 K-Nearest-Neighbors
K-Nearest-Neighbors is an extremely intuitive, non-parametric technique for regression or classifi-
cation. It works as follows in the regression case:
1. Identify the K points in our data set that are closest to the new data point. ‘Closest’ is some
measure of distance, usually Euclidean.
2. Average the value of interest for those K data points.
3. Return that averaged value of interest: it is the prediction for our new data point.

⋆ A non-parametric model simply means we don’t make any assumptions about the form of our data. We only need
to use the data itself to make predictions.

2.2.2 Neural Networks


A neural network works by scaling and combining input variables many times. Furthermore, at
each step of scaling and combining inputs, we typically apply some form of nonlinearity over our
data values. The proof is beyond the scope of this textbook, but neural networks are known to be
universal function approximators. This means that given enough scaling and combining steps, we
can approximate any function to an arbitrarily high degree of accuracy using a neural network.

2.2.3 Random Forests


Random forests are what’s known as an ensemble method. This means we average the results of
many smaller regressors known as decision trees to produce a prediction. These decision trees
individually operate by partitioning our original data set with respect to the value of predictive
interest using a subset of the features present in that data set. Each decision tree individually may
not be a reliable regressor; by combining many of them we achieve a model with better performance.

2.2.4 Gradient Boosted Trees


Gradient boosted trees are another technique built on decision trees. Assume we start with a set
of decision trees that together achieve a certain level of performance. Then, we iteratively add new
trees to improve performance on the hardest examples in our data set, and reweight our new set of
decision trees to produce the best level of performance possible.

2.2.5 Turning to Linear Regression


You’ve likely never even heard of some of these preceding techniques - that’s okay. The point is
that we have a number of existing methods that take different approaches to solving regression
problems. Each of these will have their own strengths and weaknesses that contribute to a decision
about whether or not you might use them for a given regression problem. We obviously do not
cover these methods in great detail here; what’s more important is the understanding that there
are a variety of ways to go about solving regression problems. From here on out, we will take a
deeper dive into linear regression. There are several reasons for which we are exploring this specific
technique in greater detail. The two main reasons are that it has been around for a long time
and thus is very well understood, and it also will introduce many concepts and terms that will be
utilized extensively in other machine learning topics.
6 CHAPTER 2. REGRESSION

2.3 Introduction to Linear Regression


In this chapter, we’re specifically going to focus on linear regression, which means that our goal
is to find some linear combination of the x1 , ..., xD input values that predict our target y.

Definition 2.3.1 (Linear Regression): Suppose we have an input x ∈ RD and a continuous


target y ∈ R. Linear regression determines weights wi ∈ R that combine the values of xi to
produce y:
y = w0 + w1 x1 + ... + wD xD (2.1)

⋆ Notice w0 in the expression above, which doesn’t have a corresponding x0 value. This is known as the bias term.
If you consider the definition of a line y = mx + b, the bias term is corresponds to the intercept b. It accounts for
data that has a non-zero mean.

Let’s illustrate how linear regression works using an example, considering the case of 10 year old
Sam. She is curious about how tall she will be when she grows up. She has a data set of parents’
heights and the final heights of their children. The inputs x are:
x1 = height of mother (cm)
x2 = height of father (cm)
Using linear regression, she determines the weights w to be:
w = [34, 0.39, 0.33]

Sam’s mother is 165 cm tall and her father is 185 cm tall. Using the results of the linear regression
solution, Sam solves for her expected height:
Sam’s height = 34 + 0.39(165) + 0.33(185) = 159.4 cm

ML Framework Cube: Linear Regression

Let’s inspect the categories linear regression falls into for our ML framework cube. First, as we’ve
already stated, linear regression deals with a continuous output domain. Second, our goal is
to make predictions on future data points, and to construct something capable of making those
predictions we first need a labeled data set of inputs and outputs. This makes linear regression a
supervised technique. Third and finally, linear regression is non-probabilistic. Note that there
also exist probabilistic interpretations of linear regression which we will discuss later in the chapter.
Domain Training Probabilistic
Continuous Supervised No

2.4 Basic Setup


The most basic form of linear regression is a simple weighted combination of the input variables x,
which you will often see written as:

f (x, w) = w0 + w1 x1 + ... + wD xD (2.2)


2.4. BASIC SETUP 7

Figure 2.1: Data set with clear trend.

2.4.1 Merging of Bias


We’re going to introduce a common notational trick here for making the bias term, w0 , easier to
handle. At the moment w0 is unwieldy because it is not being multiplied by an xi value. A simple
way to make our bias term easier to handle is to simply introduce another variable, x0 , that is
always 1 for every data point. For example, considering the case of Sam’s height from above, we
have the height of her parents, x:

x = (165, 185)

We now add a 1 in the first position of the data point to make it:

x’ = (1, 165, 185)

We do this for every point in our data set. This bias trick lets us write:

f (x, w) = w⊤ x = w0 x0 + w1 x1 + ... + wD xD (2.3)

This is more compact, easier to reason about, and makes properties of linear algebra nicer for the
calculations we will be performing.

2.4.2 Visualization of Linear Regression


Let’s try to build some intuition about how linear regression works. Our algorithm is given a
collection of data points: inputs x and corresponding targets y. Our goal is to find the best set of
weights w such that given a new data point x, we can accurately predict the true target value y.
This is visualizable in the simple case where x is a 1-dimensional input variable, as in Figure 2.1.
Our eyes are naturally able to detect a very clear trend in this data. If we were given a new
x∗ data point, how would be predict its target value y? We would first fit a line to our data, as in
Figure 2.2, and then find where on that line the new x∗ value sits.
That is the entirety of linear regression! It fits the ‘best’ line to our data, and then uses that
line to make predictions. In 1-D input space, this manifests itself as the simple problem seen above,
8 CHAPTER 2. REGRESSION

Figure 2.2: Data set with clear trend, best fitting line included.

where we need only find a single bias term w0 (which acts as the intercept of the line) and single
weight w1 (which acts as the slope of the line). However, the same principle applies to higher
dimensional data as well. We’re always fitting the hyperplane that best predicts the data.

⋆ Although our input data points x can take on multiple dimensions, our output data y is always a 1-dimensional
real number when dealing with regression problems.

Now that we have some intuition for what linear regression is, a natural question arises: how
do we find the optimal values for w? That is the remaining focus of this chapter.

2.5 Finding the Best Fitting Line

2.5.1 Objective Functions and Loss

Now that we’ve defined our model as a weighted combination of our input variables, we need some
way to choose our value of w. To do this, we need an objective function.

Definition 2.5.1 (Objective Function): A function that measures the ‘goodness’ of a model.
We can optimize this function to identify the best possible model for our data.

As the definition explains, the purpose of an objective function is to measure how good a specific
model is. We can therefore optimize this function to find a good model. Note that in the case of
linear regression, our ‘model’ is just a setting of our parameters w.
An objective function will sometimes be referred to as loss. Loss actually measures how bad a
model is, and then our goal is to minimize it. It is common to think in terms of loss when discussing
linear regression, and we incur loss when the hyperplane we fit is far away from our data.
So how do we compute the loss for a specific setting of w? To do this, we often use residuals.
2.5. FINDING THE BEST FITTING LINE 9

Definition 2.5.2 (Residual): The residual is the difference between the target (y) and predicted
(y(x, w)) value that a model produces:

residual = target − prediction = y − f (x, w) = y − w⊤ x

Commonly, loss is a function of the residuals produced by a model. For example, you can imagine
taking the absolute value of all of the residuals and adding those up to produce a measurement of
loss. This is sometimes referred to as L1 Loss. Or, you might square all of the residuals and then
add those up to produce loss, which is called L2 loss or least squares loss. You might also use some
combination of L1 and L2 loss. For the most part, these are the two most common forms of loss
you will see when discussing linear regression.
When minimized, these distinct measurements of loss will produce solutions for w that have
different properties. For example, L2 loss is not robust to outliers due to the fact that we are
squaring residuals. Furthermore, L2 loss will produce only a single solution while L1 loss can
potentially have many equivalent solutions. Finally, L1 loss produces unstable solutions, meaning
that for small changes in our data set, we may see large changes in our solution w.
Loss is a concept that we will come back to very frequently in the context of supervised machine
learning methods. Before exploring exactly how we use loss to fit a line, let’s consider least squares
loss in greater depth.

2.5.2 Least Squares Loss


As we mentioned above, there are different methods for computing loss. One of the most commonly
used measurements is known as least squares or L2 loss. Least squares, as it is often abbreviated,
says that the loss for a given data point is the square of the difference between the target and
predicted values:
Ln (w) = (y − w⊤ xn )2 (2.4)

⋆ The notation Ln (w) is used to indicate the loss incurred by a model w for a single data point (xn , y). L(w) indicates
the loss incurred for an entire data set by the model w. Be aware that this notation is sometimes inconsistent between
different sources.

There is a satisfying statistical interpretation for using this loss function which we will explain
later in this chapter, but for now it will suffice to discuss some of the properties of this loss function
that make it desirable.
First, notice that it will always take on positive values. This is convenient because we can focus
exclusively on minimizing our loss, and it also allows us to combine the loss incurred from different
data points without worrying about them cancelling out.
A more subtle but enormously important property of this loss function is that we know a lot
about how to efficiently optimize quadratic functions. This is not a textbook about optimization,
but some quick and dirty intuition that we will take advantage of throughout this book is that
we can easily and reliably take the derivative of quadratic functions because they are continuously
differentiable. We also know that optima of a quadrative function will be located at points where
the derivative of the function is equal to 0, as seen in Figure 2.3. In contrast, L1 loss is not
continuously differentiable over the entirety of its domain.
10 CHAPTER 2. REGRESSION

Figure 2.3: Quadratic function with clear optimum at x = 2, where the derivative of the function
is 0.

2.6 Linear Regression Algorithms


Now that we have our least squares loss function, we can finally begin to fit a line to our data. We
will walk through the derivation of how this is done, in the process revealing the algorithm used
for linear regression.

2.6.1 Optimal Weights via Matrix Differentiation


Let X denote a N × D dimension matrix (the “design matrix”) where the row n is x⊤ n , i.e., the
features corresponding to the nth example. Let Y denote a N × 1 vector corresponding to the
target values, where the nth entry corresponds to yn .
First, we can define the loss incurred by parameters w over our entire data set X as follows:
N
1X
L(w) = (yn − w⊤ xn )2 (2.5)
2
n=1

⋆ Note that we added a constant 12 to the beginning of our loss expression. This scales the loss, which will not change
our final result for the optimal parameters. It has the benefit of making our calculations cleaner once we’ve taken
the gradient of the loss.

We now want to solve for the values of w that minimize this expression.

Derivation 2.6.1 (Least Squares Optimal Weights Derivation): We find the optimal
weights w∗ as follows:

Start by taking the gradient of the loss with respect to our parameter w:
N
X
∇L(w) = (yn − w⊤ xn )(−xn )
n=1
2.6. LINEAR REGRESSION ALGORITHMS 11

Setting this gradient to 0, and multiplying both sides by -1, we have:


N
X N
X
yn xn − (w⊤ xn )xn = 0. (2.6)
n=1 n=1

Since (w⊤ xn )xn = (x⊤ ⊤ ⊤


n w)xn = xn (xn w), recognizing that for scalar a and vector v we have a = a
and av = va. Substituting, we want to solve for w such that
N
X N
X
yn xn − xn (x⊤
n w) = 0.
n=1 n=1

At this point, it is convenient to rewrite these summations as matrix operations, making use of
design matrix X (N × D) and target values y (N × 1). We have
N
X N
X
⊤ ⊤
X y= yn xn , X Xw = xn (x⊤
n w)
n=1 n=1

Substituting, we have

X⊤ y − X⊤ Xw = 0.

Solving for the optimal weight vector, we have

w∗ = (X⊤ X)−1 X⊤ y. (2.7)

For this to be well defined we need X to have full column rank (features are not colinear) so that
X⊤ X is positive definite and the inverse exists.

The quantity (X⊤ X)−1 X⊤ in Derivation 2.6.1 has a special name: the Moore-Penrose
pseudo inverse. You can think of it as a generalization of a matrix inversion operation to a
non-square matrix.

2.6.2 Bayesian Solution: Maximum Likelihood Estimation


We’ve thus far been discussing linear regression exclusively in terms of a loss function that helps
us fit a set of weights to our data. In particular, we have been working with least squares, which
has nice properties that make it a reasonable loss function.
In a very satisfying fashion, least squares also has a statistical foundation. In fact, you can
recover the least squares loss function purely from a statistical derivation that we present here.
Consider our data set D = {(xn , yn )}N m
n=1 , xn ∈ R , y ∈ R. Let’s imagine that our data was
generated according to the following process:

yn ∼ N (w⊤ xn , β −1 )

Which can be written equivalently as:

p(yn |xn , w, β) = N (w⊤ xn , β −1 ) (2.8)

The interpretation of this is that our target value y is generated according to a linear combination
12 CHAPTER 2. REGRESSION

of our inputs x, but there is also some noise in the data generating process described by the variance
parameter β −1 . It’s an acknowledgement that some noise exists naturally in our data.

⋆ It’s common to write variance as an inverse term, such as β −1 . The parameter β is then known as the precision,
which is sometimes easier to work with than the variance.

As before, we now ask the question: how do we solve for the optimal weights w? One approach
we can take is to maximize the probability of observing our target data y. This technique is known
as maximum likelihood estimation.

Derivation 2.6.2 (Maximum Likelihood Estimation for Bayesian Linear Regression):


The likelihood of our data set is given by:
N
Y
p(y|X, w, β) = N (w⊤ xn , β −1 )
n=1

We then take the logarithm of the likelihood, and since the logarithm is a strictly increasing,
continuous function, this will not change our optimal weights w:
N
X
ln p(y|X, w, β) = ln N (w⊤ xn , β −1 )
n=1

Using the density function of a univariate Gaussian:


N
X 1 ⊤ 2 −1
ln p(y|X, w, β) = ln p e−(yn −w xn ) /2β
2πβ −1
n=1
N
N N βX
= ln β − ln (2π) − (yn − w⊤ xn )2
2 2 2
n=1

Notice that this is a quadratic function in w, which means that we can solve for it by taking the
derivative with respect to w, setting that expression to 0, and solving for w:
N
∂ ln p(y|X, w, β) X
= −β (yn − w⊤ xn )(−xn )
∂w
n=1
N
X N
X
⇔ yn xn − (w⊤ xn )xn = 0.
n=1 n=1

Notice that this is exactly the same form as Equation 2.6. Solving for w as before, we have:

w∗ = (X⊤ X)−1 X⊤ y (2.9)

Notice that our final solution is exactly the same form as the solution in Equation 2.7, which
we solved for by minimizing the least squares loss! The takeaway here is that minimizing
a least squares loss function is equivalent to maximizing the probability under the
assumption of a linear model with Gaussian noise.
2.7. MODEL FLEXIBILITY 13

2.6.3 Alternate Interpretation: Linear Regression as Projection


Another common interpretation of linear regression is that of a projection of our targets, y, onto
the column space of our inputs X. This can be useful for building intuition.
We showed above that the quantity (X⊤ X)−1 X⊤ can be thought of as the pseudoinverse for
our inputs X. Let’s now consider the case where X is square and the pseudoinverse is equal to the
true inverse: X−1 = (X⊤ X)−1 X⊤ . We have for our optimal w∗ :

w∗ = (X⊤ X)−1 X⊤ y

which simplifies as

w∗ = X−1 y

We can recover our target values y by multiplying either side by X:

Xw∗ = XX−1 y
Xw∗ = y

We were able to recover our targets y exactly because X is an invertible tranformation. However,
in the general case where X is not invertible and we have to use the approximate pseudoinverse
(X⊤ X)−1 X⊤ , we instead recover ŷ:

Xw∗ = X(X⊤ X)−1 X⊤ y


Xw∗ = ŷ

where ŷ can be thought of as the closest projection of y onto the column space of X.
Furthermore, this motivates the intuition that w∗ is the set of coefficients that best transforms
our input space X into our target values y.

2.7 Model Flexibility


Occasionally, linear regression will fail to recover a good solution for a data set. While this may
be because our data doesn’t actually have predictive power, it might also just indicate that our
data is provided in a format unsuitable for linear regression. This section explores that problem, in
particular focusing on how we can manipulate the flexibility of our models to make them perform
better.

2.7.1 Basis Functions


There are some situations where linear regression is not the best choice of model for our input data
X. Because linear regression only scales and combines input variables, it is unable to apply more
complex transformations to our data such as a sin or square root function. In those situations where
we need to transform our input variable somehow prior to performing linear regression (which is
known as moving our data into a new basis), we can apply a basis function.

Definition 2.7.1 (Basis Function): Typically denoted by the symbol ϕ(·), a basis function is
a transformation applied to an input data point x to move our data into a different input basis,
which is another phrase for input domain.
14 CHAPTER 2. REGRESSION

Figure 2.4: Data with no basis function applied.

For example, consider our original data point:

x = (x(1) , x(2) )′

We may choose our basis function ϕ(x) such that our transformed data point in its new basis is:
2
ϕ(x) = (x(1) , x(1) , x(2) , sin(x(2) ))′

Using a basis function is so common that we will sometimes describe our input data points as
ϕ = (ϕ(1) , ϕ(2) , ..., ϕ(D) )′ .

⋆ The notation x = (x(1) , x(2) )′ is a way to describe the dimensions of a single data point x. The term x(1) is the
first dimension of a data point x, while x1 is the first data point in a data set.

Basis functions are very general - they could specify that we just keep our input data the same.
As a result, it’s common to rewrite the least squares loss function from Equation 2.4 for linear
regression in terms of the basis function applied to our input data:

N
1X
L(w) = (yn − w⊤ ϕn )2 (2.10)
2
n=1

To motivate why we might need basis functions for performing linear regression, let’s consider
this graph of 1-dimensional inputs X along with their target outputs y, presented in Figure 2.4.
As we can see, we’re not going to be able to fit a good line to this data. The best we can hope
to do is something like that of Figure 2.5.
However, if we just apply a simple basis function to our data, in this case the square root

function, ϕ(x) = ( x1 )′ , we then have the red line in Figure 2.6. We now see that we can fit a very
good line to our data, thanks to basis functions.
2.7. MODEL FLEXIBILITY 15

Figure 2.5: Data with no basis function applied, attempt to fit a line.

Figure 2.6: Data with square root basis function applied.


16 CHAPTER 2. REGRESSION

Figure 2.7: Data set with a clear trend and Gaussian noise.

Still, the logical question remains: how can I choose the appropriate basis function? This toy
example had a very obviously good basis function, but in general with high-dimensional, messy
input data, how do we choose the basis function we need?
The answer is that this is not an easy problem to solve. Often, you may have some domain
specific knowledge that tells you to try a certain basis, such as if you’re working with chemical data
and know that an important equation involves a certain function of one of your inputs. However,
more often than not we won’t have this expert knowledge either. Later, in the chapter on neural
networks, we will discuss methods for discovering the best basis functions for our data automatically.

2.7.2 Regularization
When we introduced the idea of basis functions above, you might have wondered why we didn’t
just try adding many basis transformations to our input data to find a good transformation. For
example, we might use this large basis function on a D-dimensional data point z:
2 100 2 100 2 100 ′
ϕ(z) = (z (1) , z (1) , ..., z (1) , z (2) , z (2) , ..., z (2) , ..., z (D) , z (D) , ..., z (D) )
where you can see that we expand the dimensions of the data point to be 100 times its original
size.
Let’s say we have an input data point x that is 1-dimensional, and we apply the basis function
described above, so that after the transformation each data point is represented by 100 values. Say
we have 100 data points on which to perform linear regression, and because our transformed input
space has 100 values, we have 100 parameters to fit. In this case, with one parameter per data
point, it’s possible for us to fit our regression line perfectly to our data so that we have no loss!
But is this a desirable outcome? The answer is no, and we’ll provide a visual example to illustrate
that.
Imagine Figure 2.7 is our data set. There is a very clear trend in this data, and you would likely
draw a line that looks something like that of Figure 2.8 to fit it.
However, imagine we performed a large basis transformation like the one described above. If we
do that, it’s possible for us to fit our line perfectly, threading every data point, like that in Figure
2.9.
2.7. MODEL FLEXIBILITY 17

Figure 2.8: Natural fit for this data set.

Figure 2.9: Unnatural fit for this data set.


18 CHAPTER 2. REGRESSION

Let’s see how both of these would perform on new data points. With our first regression line,
if we have a new data point x = (10)′ , we would predict a target value of 14.1, which most people
would agree is a pretty good measurement. However, with the second regression line, we would
predict a value of 9.5, which most people would agree does not describe the general trend in the
data. So how can we handle this problem elegantly?
Examining our loss function, we see that right now we’re only penalizing predictions that are
not correct in training. However, what we ultimately care about is doing well on new data points,
not just our training set. This leads us to the idea of generalization.

Definition 2.7.2 (Generalization): Generalization is the ability of a model to perform well on


new data points outside of the training set.

A convoluted line that matches the noise of our training set exactly isn’t going to generalize
well to new data points that don’t look exactly like those found in our training set. If wish to avoid
recovering a convoluted line as our solution, we should also penalize the total size of our weights
w. The effect of this is to discourage many complex weight values that produce a messy regression
line. By penalizing large weights, we favor simple regression lines like the one in Figure 2.8 that
take advantage of only the most important basis functions.
The concept that we are introducing, penalizing large weights, is an example of what’s known as
regularization, and it’s one that we will see come up often in different machine learning methods.

Definition 2.7.3 (Regularization): Applying penalties to parameters of a model.

There is obviously a tradeoff between how aggressively we regularize our weights and how tightly
our solution fits to our data, and we will formalize this tradeoff in the next section. However, for
now, we will simply introduce a regularization parameter λ to our least squares loss function:

N
1X λ
L(w) = (yn − w⊤ ϕn )2 + w⊤ w (2.11)
2 2
n=1

The effect of λ is to penalize large weight parameters. The larger λ is, the more we will favor
simple solutions. In the limit limλ→∞ L(w), we will drive all weights to 0, while with a nonexistant
λ = 0 we will apply no regularization at all. Notice that we’re squaring our weight parameters -
this is known as L2 norm regularization or ridge regression. While L2 norm regularization is
very common, it is just one example of many ways we can perform regularization.
To build some intuition about the effect of this regularization parameter, examine Figure 2.10.
Notice how larger values of λ produce less complex lines, which is the result of applying more
regularization. This is very nice for the problem we started with - needing a way to choose which
basis functions we wanted to use. With regularization, we can select many basis functions, and then
allow regularization to ‘prune’ the ones that aren’t meaningful (by driving their weight parameters
to 0). While this doesn’t mean that we should use as many basis transformations as possible (there
will be computational overhead for doing this), it does allow us to create a much more flexible
linear regression model without creating a convoluted regression line.
2.7. MODEL FLEXIBILITY 19

Figure 2.10: Effect of different regularization parameter values on final regression solution.

2.7.3 Generalizing Regularization


We’ve thus far only discussed one form of regularization: ridge regression. Remember that under
ridge regression, the loss function takes the form:
N
1X λ
L(w) = (yn − w⊤ ϕn )2 + w⊤ w,
2 2
n=1

where the (λ/2)w⊤ w term is for the regularization. We can generalize our type of regularization
by writing it as:
N
1X λ h
L(w) = (yn − w⊤ ϕn )2 + w h
2 2
n=1

where h determines the type of regularization we are using and thus the form of the optimal solution
that we recover. For example, if h = 2 then we add λ/2 times the square of the L2 norm. The
three most commonly used forms of regularization are lasso, ridge, and elastic net.

Ridge Regression
This is the case of h = 2, which we’ve already discussed, but what type of solutions does it tend
to recover? Ridge regression prevents any individual weight from growing too large, providing us
with solutions that are generally moderate.

Lasso Regression
Lasso regression is the case of h = 1. Unlike ridge regression, lasso regression will drive some
parameters wi to zero if they aren’t informative for our final solution. Thus, lasso regression is
good if you wish to recover a sparse solution that will allow you to throw out some of your basis
functions. You can see the forms of ridge and lasso regression functions in Figure 2.11. If you think
about how Lasso is L1 Norm (absolute value) and Ridge is L2 Norm (squared distance), you can
think of those shapes as being the set of points (w1 , w2 ) for which the norm takes on a constant
20 CHAPTER 2. REGRESSION

Figure 2.11: Form of the ridge (blue) and lasso (red) regression functions.

value.

Elastic Net
Elastic net is a middle ground between ridge and lasso regression, which it achieves by using a linear
combination of the previous two regularization terms. Depending on how heavily each regulariza-
tion term is weighted, this can produce results on a spectrum between lasso and ridge regression.

2.7.4 Bayesian Regularization


We’ve seen regularization in the context of loss functions, where the goal is to penalize large weight
values. How does the concept of regularization apply to Bayesian linear regression?
The answer is that we can interpret regularizing our weight parameters as adding a prior
distribution over w. Note that this is a different conception of regularization than we saw in the
previous section. In the Bayesian framework, we are averaging over different models specified by
different values of w. Therefore, in this context regularization entails weighting models with smaller
values of w more heavily.

Derivation 2.7.1 (Bayesian Regularization Derivation): Because we wish to shrink our


weight values toward 0 (which is exactly what regularization does), we will select a Normal prior
with mean 0 and variance S−1
0 :

w ∼ N (0, S−1
0 )

Remember from Equation 2.8 that the distribution over our observed data is Normal as well, written
here in terms of our entire data set:

p(y|X, w, β) = N (Xw, β −1 I)
2.7. MODEL FLEXIBILITY 21

We want to combine the likelihood and the prior to recover the posterior distribution of w, which
follows directly from Bayes’ Theorem:

p(w|X, y, β) ∝ p(y|X, w, β) p(w)


| {z } | {z } | {z }
posterior likelihood prior

We now wish to find the value of w that maximizes the posterior distribution. We can maximize
the log of the posterior with respect to w, which simplifies the problem slightly:

ln p(w|X, y, β) = ln p(y|X, w, β) + ln p(w) + C1

where C1 is some constant free of w. Let’s handle ln p(y|X, w, β) first:


N
Y
ln p(y|X, w, β) = ln N (yn |w⊤ xn , β −1 )
n=1
N  
Y 1 β
= ln p exp − (yn − w⊤ xn )2
2πβ −1 2
n=1
N
βX 1
= C2 − (yn − w⊤ xn )2 + p
2 2πβ −1
n=1

where C2 collects more constant terms that don’t depend on w. Let’s now handle ln p(w):

ln p(w) = ln N (0, S−1


0 )

1 1 ⊤
= ln 1 exp − w S0 w
(|2πS−1
0 |) 2 2
1
= C3 − w⊤ S0 w
2
combining the terms for ln p(y|X, w, β) and ln p(w):
N
βX 1
ln p(w|X, y, β) = − (yn − w⊤ xn )2 − w⊤ S0 w + C
2 2
n=1
N
!
1 1X ⊤ 2 1 ⊤
=− (yn − w xn ) + w S0 w + C
β 2 2β
n=1

Since β > 0, notice that maximizing the posterior probability with respect to w is equivalent to
minimizing the sum of squared errors (yn − w⊤ xn )2 and the regularization term β1 w⊤ S0 w.

The interpretation of this in the case S0 = sI is that adding a prior over the distribution of our
weight parameters w and then maximizing the resulting posterior distribution is equivalent to ridge
regression with λ = βs , as the regularization term simplifies to βs w⊤ S0 w. Recall that both s and β
are precisions. This states that if the precision on the observations is small, then the regularization
term is relatively large, meaning our posterior leans more towards our prior. If the precision on our
prior is small, then the regularization is small, meaning our posterior leans more towards the data.
22 CHAPTER 2. REGRESSION

2.8 Choosing Between Models


2.8.1 Bias-Variance Tradeoff and Decomposition
Now that you know about regularization, you might have some intuition for why we need to find
a balance between complex and simple regression solutions. A complex solution, while it might fit
all of our training data, may not generalize well to future data points. On the other hand, a line
that is too simple might not vary enough to provide good predictions at all. This phenomenon is
not unique to linear regression- it’s actually a very fundamental concept in machine learning that’s
known as the bias-variance tradeoff.

Definition 2.8.1 (Bias-Variance Tradeoff ): When constructing machine learning models, we


have a choice somewhere on a spectrum between two extremes: fitting exactly to our training data
or not varying in response to our training data at all. The first extreme, fitting all of our training
data, is a situation of high variance, because our output changes heavily in reponse to our input
data (see the red line in Figure 2.10). At the other extreme, a solution that doesn’t change in
response to our training data at all is a situation of high bias (see the yellow line in Figure 2.10).
This means our model heavily favors a specific form regardless of the training data, so our target
outputs don’t fluctuate between distinct training sets.

Obviously a good solution will fall somewhere in between these two extremes of high variance
and high bias. Indeed, we have techniques like regularization to help us balance the two extremes
(improving generalization), and we have other techniques like cross-validation that help us deter-
mine when we have found a good balance (measuring generalization).

⋆ In case you are not familiar with the terms bias and variance, we provide their statistical definitions here:

bias(θ) = E[θ] − θ

variance(θ) = E[(θ − E[θ])2 ]

Before we discuss how to effectively mediate between these opposing forces of error in our models,
we will first show that the bias-variance tradeoff is not only conceptual but also has probabilistic
underpinnings. Specifically, any loss that we incur over our training set using a given model can
be described in terms of bias and variance, as we will demonstrate now.

Derivation 2.8.1 (Bias-Variance Decomposition): Let’s begin by asserting that we have a


model f (·) that makes a prediction of our target y given input data point x. We wish to break
down the squared error of f into terms involving bias and variance.

Start with the expected squared error (MSE), where the expectation is taken with respect to both
our data set D (variation in our modeling error comes from what data set we get), which is a
random variable of (x, y) pairs sample from a distribution F , and our conditional distribution y|x
(there may be additional error because the data are noisy):

MSE = ED,y|x [(y − fD (x))2 ]

where we use the notation fD to explicitly acknowledge the dependence of our fitted model f on
the dataset D. For reasons that will become clear in a few steps, add and subtract our target mean
2.8. CHOOSING BETWEEN MODELS 23

ȳ, which is the true conditional mean given by ȳ = Ey|x [y], inside of the squared term:

MSE = ED,y|x [(y − ȳ + ȳ − fD (x))2 ]

Group together the first two terms and the last two terms:

MSE = ED,y|x [((y − ȳ) + (ȳ − fD (x)))2 ]

Expanding this expression and using linearity of expectation:

MSE = ED,y|x [(y − ȳ)2 ] + ED,y|x [(ȳ − fD (x))2 ] + 2ED,y|x [(y − ȳ)(ȳ − fD (x))] (2.12)

Let’s examine the last term, 2E[(y − ȳ)(ȳ − fD (x))]. Notice that (ȳ − fD (x)) does not depend on the
conditional distribution y|x at all. Thus, we are able to move one of those expecations in, which
makes this term:

2ED,y|x [(y − ȳ)(ȳ − fD (x))] = 2ED [Ey|x [(y − ȳ)](ȳ − fD (x))]

And note that:

Ey|x [(y − ȳ)] = 0

Which eliminates this last term entirely:

2ED,y|x [(y − ȳ)(ȳ − fD (x))] = 2ED [0 · (ȳ − fD (x))] = 0

We can now write Equation 2.12 as:

MSE = ED,y|x [(y − ȳ)2 ] + ED,y|x [(ȳ − fD (x))2 ] (2.13)


= Ey|x [(y − ȳ)2 ] + ED [(ȳ − fD (x))2 ]

where we have removed expectations that do not apply (e.g. ȳ does not depend on the dataset D).
We now have two terms contributing to our squared error. We will put aside the first term
Ey|x [(y − ȳ)2 ], as this is unidentifiable noise in our data set. In other words, our data will randomly
deviate from the mean in ways we cannot predict. On the other hand, we can work with the
second term ED [(ȳ − fD (x))2 ] as it involves our model function f (·)

As before, for reasons that will become clear in a few steps, let’s add and subtract our prediction
mean f¯(·) = ED [fD (x)], which is the expectation of our model function taken with respect to our
random data set.

ED [(ȳ − fD (x))2 ] = ED [(ȳ − f¯(x) + f¯(x) − fD (x))2 ]

Expanding this squared term, we have:

ED [(ȳ − fD (x))2 ] = (ȳ − f¯(x))2 + ED [(f¯(x) − fD (x))2 ] + 2ED [(ȳ − f¯(x))(f¯(x) − fD (x))]

As before, the third term here is 0:

2ED [(ȳ − f¯(x))(f¯(x) − fD (x))] = 2(ȳ − f¯(x))ED [(f¯(x) − fD (x))] = 2(ȳ − f¯(x))(0) = 0
24 CHAPTER 2. REGRESSION

Leaving us with these two terms:

ED [(ȳ − fD (x))2 ] = (ȳ − f¯(x))2 + ED [(f¯(x) − fD (x))2 ]

Notice the form of these two terms. The first one, (ȳ − f¯(x))2 , is the squared bias of our model,
since it is the square of the average difference between our prediction and the true target value.
The second one, ED [(f¯(x) − fD (x))2 ], is the variance of our model, since it is the expected squared
difference between our model and its average value. Thus:

ED [(ȳ − fD (x))2 ] = bias(f (x))2 + variance(f (x))

Thus, our total squared error, plugging in to Equation 2.13 can be written as:

MSE = noise(x) + bias(f (x))2 + variance(f (x))


2.8. CHOOSING BETWEEN MODELS 25

Figure 2.12: Bias and variance both contribute to the overall error of our model.

The key takeaway of the bias-variance decomposition is that the controllable error in our model
is given by the squared bias and variance. Holding our error constant, to decrease bias requires
increasing the variance in our model, and vice-versa. In general, a graph of the source of error in
our model might look something like Figure 2.12.
For a moment, consider what happens on the far left side of this graph. Our variance is very
high, and our bias is very low. In effect, we’re fitting perfectly to all of the data in our data set.
This is exactly why we introduced the idea of regularization from before - we’re fitting a very
convoluted line that is able to pass through all of our data but which doesn’t generalize well to new
data points. There is a name for this: overfitting.

Definition 2.8.2 (Overfitting): A phenomenon where we construct a convoluted model that is


able to predict every point in our data set perfectly but which doesn’t generalize well to new data
points.

The opposite idea, underfitting, is what happens at the far right of the graph: we have high
bias and aren’t responding to the variation in our data set at all.

Definition 2.8.3 (Underfitting): A phenomenon where we construct a model that doesn’t re-
spond to variation in our data.

So you can hopefully now see that the bias-variance tradeoff is important to managing the
problem of overfitting and underfitting. Too much variance in our model and we’ll overfit to our
data set. Too much bias and we won’t account for the trends in our data set at all.
In general, we would like to find a sweet spot of moderate bias and variance that produces
minimal error. In the next section, we will explore how we find this sweet spot.
26 CHAPTER 2. REGRESSION

2.8.2 Cross-Validation
We’ve seen that in choosing a model, we incur error that can be described in terms of bias and
variance. We’ve also seen that we can regulate the source of error through regularization, where
heavier regularization increases the bias of our model. A natural question then is how do we know
how much regularization to apply to achieve a good balance of bias and variance?
Another way to look at this is that we’ve traded the question of finding the optimal number of
basis functions for finding the optimal value of the regularization parameter λ, which is often an
easier problem in most contexts.
One very general technique for finding the sweet spot of our regularization parameter, other hy-
perparameters, or even for choosing among entirely different models is known as cross-validation.

Definition 2.8.4 (Cross-Validation): A subsampling procedure used over a data set to tune
hyperparameters and avoid over-fitting. Some portion of a data set (10-20% is common) is set aside,
and training is performed on the remaining, larger portion of data. When training is complete,
the smaller portion of data left out of training is used for testing. The larger portion of data is
sometimes referred to as the training set, and the smaller portion is sometimes referred to as the
validation set.

Cross-validation is often performed more than once for a given setting of hyperparameters to
avoid a skewed set of validation data being selected by chance. In K-Folds cross-validation, you
1
perform cross-validation K times, allocating K of your data for the validation set at each iteration.
Let’s tie this back into finding a good regularization parameter. For a given value of λ, we will
incur a certain amount of error in our model. We can measure this error using cross-validation,
where we train our model on the training set and compute the final error using the validation set.
To find the optimal value for λ, we perform cross-validation using different values of λ, eventually
settling on the value that produces the lowest final error. This will effectively trade off bias and
variance, finding the value of λ that minimizes the total error.
You might wonder why we need to perform cross-validation at all - why can’t we train on the
entire data set and then compute the error over the entire data set as well?
The answer is again overfitting. If we train over the entire data set and then validate our results
on the exact same data set, we are likely to choose a regularization parameter that encourages our
model to conform to the exact variation in our data set instead of finding the generalizable trends.
By training on one set of data, and then validating on a completely different set of data, we
force our model to find good generalizations in our data set. This ultimately allows us to pick
the regularization term λ that finds the sweet spot between bias and variance, overfitting and
underfitting.

2.8.3 Making a Model Choice


Now that we’re aware of overfitting, underfitting, and how those concepts relate to the bias-variance
tradeoff, we still need to come back to the question of how we actually select a model. Intuitively,
we are trying to find the middle ground between bias and variance: picking a model that fits our
data but that is also general enough to perform well on yet unseen data. Furthermore, there is
no such thing as the ‘right’ model choice. Instead, there are only model options that are either
better or worse than others. To that end, it can be best to rely on the techniques presented above,
specifically cross-validation, to make your model selection. Then, although you will not be able to
make any sort of guarantee about your selection being the ‘best’ of all possible models, you can at
2.9. LINEAR REGRESSION EXTRAS 27

least have confidence your model achieved the best generalizability that could be proven through
cross-validation.

2.8.4 Bayesian Model Averaging


We can also handle model selection using a Bayesian approach. This means we account for our
uncertainty about the true model by averaging over the possible candidate models, weighting each
model by our prior certainty that it is the one producing our data. If we have M models indexed
by m = 1, ..., M , we can write the likelihood of observing our data set X as follows:
M
X
p(X) = p(X|m)p(m)
m=1

where p(m) is our prior certainty for a given model and p(X|m) is the likelihood of our data set
given that model. The elegance of this approach is that we don’t have to pick any particular model,
instead choosing to marginalize out our uncertainty.

2.9 Linear Regression Extras


With most of linear regression under our belt at this point, it’s useful to drill down on a few concepts
to come to a deeper understanding of how we can use them in the context of linear regression and
beyond.

2.9.1 Predictive Distribution


Remaining in the setting of Bayesian Linear Regression, we may wish to get a distribution over
our weights w instead of a point estimator for it using maximum likelihood. As we saw in Section
2.7.4, we can introduce a prior distribution over w, then together with our observed data, we can
produce a posterior distribution over w as desired.

Derivation 2.9.1 (Posterior Predictive Derivation): For the sake of simplicity and ease of
use, we will select our prior over w to be a Normal distribution with mean µ0 and variance S −1
0 :

p(w) = N (µ0 , S −1
0 )

Remembering that the observed data is normally distributed, and accounting for Normal-Normal
conjugacy, our posterior distribution will be Normal as well:

p(w|X, y, β) = N (µN , S −1
N )

where

S N = (S −1 ⊤
0 + βX X)
−1

µN = S N (S −1
0 µ0 + βXy)

We now have a posterior distribution over w. However, usually this distribution is not what we
care about. We’re actually interested in making a point prediction for the target y ∗ given a new
input x∗ . How do we go from a posterior distribution over w to this prediction?
28 CHAPTER 2. REGRESSION

The answer is using what’s known as the posterior predictive over y ∗ given by:
Z
∗ ∗
p(y |x , X, y) = p(y ∗ |x∗ , w)p(w|X, y)dw
Zw (2.14)
= N (y ∗ |w⊤ x∗ , β −1 )N (w|µN , S −1
N )dw
w

The idea here is to average the probability of y ∗ over all the possible setting of w,
weighting the probabilities by how likely each setting of w is according to its posterior
distribution.

2.10 Conclusion
In this chapter, we looked at a specific tool for handling regression problems known as linear regres-
sion. We’ve seen linear regression described in terms of loss functions, probabilistic expressions,
and geometric projections, which reflects the deep body of knowledge that we have around this
very common technique.
We’ve also discussed many concepts in this chapter that will prove useful in other areas of
machine learning, particularly for other supervised techniques: loss functions, regularization, bias
and variance, over and underfitting, posterior distributions, maximum likelihood estimation, and
cross-validation among others. Spending time to develop an understanding of these concepts now
will pay off going forward.
It may or may not be obvious at this point that we are missing a technique for a very large class
of problems: those where the solution is not just a continuous, real number. How do we handle
situations where we need to make a choice between different discrete options? This is the question
we will turn to in the next chapter.
Chapter 3

Classification

In the last chapter we explored ways of predicting a continuous, real-number target. In this chapter,
we’re going to think about a different problem- one where our target output is discrete-valued. This
type of problem, one where we make a prediction by choosing between finite class options, is known
as classification.

3.1 Defining the Problem


As we did when studying regression, let’s begin by thinking about the type of problems we are
trying to solve. Here are a few examples of classification tasks:

1. Predicting whether a given email is spam.

2. Predicting the type of object in an image.

3. Predicting whether a manufactured good is defective.

The point of classification is hopefully clear: we’re trying to identify the most appropriate class for
an input data point.

Definition 3.1.1 (Classification): A set of problems that seeks to make predictions about un-
observed target classes given observed input variables.

3.2 Solution Options


There are several different means by which we can solve classification problems. We’re going
to discuss three in this chapter: discriminant functions, probabilistic discriminative models (e.g.
logistic regression), and probabilistic generative models. Note that these are not the only methods
for performing classification tasks, but they are similar enough that it makes sense to present
and explore them together. Specifically, these techniques all use some linear combination of input
variables to produce a class prediction. For that reason, we will refer to these techniques as
generalized linear models.

29
30 CHAPTER 3. CLASSIFICATION

ML Framework Cube: Generalized Linear Models


Since we are using these techniques to perform classification, generalized linear models deal with
a discrete output domain. Second, as with linear regression, our goal is to make predictions on
future data points given an initial set of data to learn from. Thus, generalized linear models are
supervised techniques. Finally, depending on the type of generalized linear model, they can be
either probabilistic or non-probabilistic.
Domain Training Probabilistic
Discrete Supervised Yes / No

3.3 Discriminant Functions


Generalized linear models for classification come in several different flavors. The most straightfor-
ward method carries over very easily from linear regression: discriminant functions. As we will
see, with discriminant functions we are linearly separating the input space into sections belonging
to different target classes. We will explore this method first. One thing to keep in mind is that
it’s generally easiest to initially learn these techniques in the case where we have only two target
classes, but there is typically a generalization that allows us to handle the multi-class case as well.

As with linear regression, discriminant functions h(x, w) seek to find a weighted combination of
our input variables to make a prediction about the target class:
h(x, w) = w(0) x(0) + w(1) x(1) + ... + w(D) x(D) (3.1)
where we are using the bias trick of appending x(0) = 1 to all of our data points.

3.3.1 Basic Setup: Binary Linear Classification


The simplest use case for a discriminant function is when we only have two classes that we are
trying to decide between. Let’s denote these two classes 1 and -1. Our discriminant function in
Equation 3.1 will then predict class 1 if h(x, w) ≥ 0 and class -1 if h(x, w) < 0:
(
1 if h(x, w) ≥ 0
−1 if h(x, w) < 0
Geometrically, the linear separation between these two classes then looks like that of Figure 3.1.
Notice the line where our prediction switches from class 1 to class -1. This is precisely where
h(x, w) = 0, and it is known as the decision boundary.

Definition 3.3.1 (Decision Boundary): The decision boundary is the line that divides the input
space into different target classes. It is learned from an initial data set, and then the target class
of new data points can be predicted based on where they fall relative to the decision boundary. At
the decision boundary, the discriminant function takes on a value of 0.

⋆ You will sometimes see the term decision surface in place of decision boundary, particularly if the input space is
larger than two dimensions.
3.3. DISCRIMINANT FUNCTIONS 31

Figure 3.1: Clear separation between classes.

3.3.2 Multiple Classes


Now consider the case that we have K > 2 classes C1 , C2 , ..., CK to choose between. One obvious
approach we might try is to use K different discriminant functions that each determine whether or
not a given input is in that class Ck . This is known as a one-versus-all approach, and it doesn’t work
properly because we end up with ambiguous regions as demonstrated in Figure 3.2. Intuitively,
several of the discriminator functions could claim that a data point is a part of their class, which
is an undesirable result.
Another obvious approach we might employ is to use K2 discriminant functions that each


determine whether a given point is more likely to be in class Cj or class Ck . This is known as a
one-versus-one approach, and it also doesn’t work because we again end up with ambiguous regions
as demonstrated in Figure 3.3.
Instead, we can avoid these ambiguities in the multi-class case by using K different linear
classifiers hk (x, wk ), and then assigning new data points to the class Ck for which hk (x, wk ) >
hj (x, wj ) for all j ̸= k. Then, similar to the two-class case, the decision boundaries are described
by the surface along which hk (x, wk ) = hj (x, wj ).
Now that we’ve explored the multi-class generalization, we can consider how to learn the weights
w that define the optimal discriminant functions. However, prior to solving for w, we need to discuss
how basis transformations apply to classification problems.

3.3.3 Basis Changes in Classification


We initially discussed basis changes in the context of linear regression, and they are equally impor-
tant for classification tasks. For example, consider the data set in Figure 3.4.
It’s obviously not possible for us to use a linear classifier to separate this data set. However, if
we apply a basis change by squaring one of the data points, we instead have Figure 3.5, which is
32 CHAPTER 3. CLASSIFICATION

Figure 3.2: Ambiguities arise from one-versus-all method.

Figure 3.3: Ambiguities arise from one-versus-one method.


3.4. NUMERICAL PARAMETER OPTIMIZATION AND GRADIENT DESCENT 33

Figure 3.4: Data set without any basis functions applied, not linearly separable.

now linearly separable by a plane between the two classes. Applying a generic basis change ϕ(·),
we can write our generalized linear model as:

hk (x, wk ) = w⊤ ⊤
k ϕ(x) = wk ϕ (3.2)

For the sake of simplicity in the rest of this chapter, we will leave out any basis changes in our
derivations, but you should recognize that they could be applied to any of our input data to make
the problems more tractable.

⋆ For an input matrix X, there is a matrix generalization of our basis transformed inputs: Φ = ϕ(X), where Φ is
known as the design matrix.

3.4 Numerical Parameter Optimization and Gradient Descent


Recall from the previous chapter on linear regression that when it comes to optimizing our model’s
weight parameters w, the goal is to minimize our loss function L(w) (also called the objective
function). We did this by taking the derivative of our objective function with respect to w, setting
that expression equal to 0, and solving for w. We were previously able to perform that procedure
with confidence because the least squares loss function was convex with respect to the weight
parameters, which meant it had a global solution we could solve for directly. Thus, the point of
minimization for the objective function would occur where ∇L(w) = 0.
Unfortunately, it’s not always the case that our objective function will be convex with respect
to our weight parameters. In fact, in the next section, we will consider an objective function that
is not convex, and as a result we will need a new way to optimize our parameters. Typically, when
facing a non-convex objective function, we will need to resort to a numerical procedure.
34 CHAPTER 3. CLASSIFICATION

Figure 3.5: Data set with basis functions applied, now linearly separable.
3.4. NUMERICAL PARAMETER OPTIMIZATION AND GRADIENT DESCENT 35

Figure 3.6: Step Size in Gradient Descent.

⋆ The terms numerical and analytical procedures come up very frequently in machine learning literature. An analytical
solution typically utilizes a closed form equation that accepts your model and input data and returns a solution in
the form of optimized model parameters. On the other hand, numerical solutions are those that require some
sort of iteration to move toward an ever better solution, eventually stopping once the solution is deemed ‘good
enough’. Analytical solutions are typically more desirable than numerical solutions due to computational efficiency
and performance guarantees, but they often are not possible for complex problems due to non-convexity.

Gradient descent is one such numerical optimization technique.

3.4.1 Gradient Descent

Definition 3.4.1 (Gradient Descent): Gradient descent is a numerical, iterative optimization


technique for finding the minimum of a function. It is often used to fit complex model parameters.

The high level idea behind gradient descent is as follows: to update our parameters, we take
a small step in the opposite direction of the gradient of our objective function with respect to the
weight parameters w(t) . Notationally, this looks like the following:

w(t+1) = w(t) − η∇L(w(t) ) (3.3)

where w(t) corresponds to the state of the parameters w at time t, L(w(t) ) is the gradient of our
objective function, and η > 0 is known as the learning rate. Note that the parameter values at
time t = 0 given by w(0) are often initialized randomly.

⋆ In general, we want a learning rate that is large enough so that we make progress toward reaching a better solution,
but not so large that we take a step that puts us in a worse place in the parameter space than we were at the previous
step. Notice in Figure 3.6 that an appropriately small step size improves our objective function, while a large step
size overshoots the update and leaves us in a worse position.
36 CHAPTER 3. CLASSIFICATION

Why take a step in the opposite direction of the gradient of the objective function? You can
think of the objective function as a hill, and the current state of our parameters w(t) is our position
on that hill. The gradient tells us the steepest direction of increase in the objective function (i.e.
it specifies the direction that will make our model worse). Since we want to minimize the objective
function, we choose to move away from the direction of the gradient, sending our model down the
hill towards an area of lower error. We typically cease optimization when our updates become
sufficiently small, indicating that we’ve reached a local minimum. Note that it’s a good idea to run
gradient descent multiple times to settle on a final value for w, ideally initializing w(0) to a different
starting value each time, because we are optimizing a function with multiple local minima.

3.4.2 Batch Gradient Descent versus Stochastic Gradient Descent


There are different means by which we can compute the gradient of our objective function at each
step. The first way, often called batch gradient descent, computes the gradient for our objective
function at each step using the entire data set. In contrast, the technique known as stochastic
gradient descent (also known as SGD) utilizes a subset of the data points at each step to compute
the gradient, sometimes just a single data point. Stochastic gradient descent is typically a more
popular technique for several reasons. First, the computation time is often significantly smaller as
you don’t need to pass over the entire data set at each iteration. Furthermore, it’s less likely that
you will get stuck in local minima while running SGD because a point in the parameter space that
is a local minima for the entire data set combined is much less likely to be a local minima for each
data point individually. Finally, SGD lends itself to being used for training online models (meaning
models built on data points that are arriving at regular intervals) as the entirety of the data does
not need to be present in order to train.

3.5 Objectives for Decision Boundaries


Now that we have a high-level understanding of what we’re trying to accomplish with discriminant
functions as well as a grasp on gradient descent, we can consider how to solve for the decision
boundaries that will dictate our classification decisions. Similar to linear regression, we first need
to establish an objective function to optimize. We begin with a very simple objective function
known as 0/1 loss.

3.5.1 0/1 Loss


Recall that a loss function penalizes mistakes made by our model. The idea behind the 0/1 loss
function is very simple: if our model misclassifies a point, we incur a loss of 1, and if our model
classifies it correctly, we incur no loss.
While this is a very intuitive loss function, it does not have a closed form solution like least
squares does, and it is non-convex so it is not easily optimized. Intuitively, because we incur a loss
of 0 or 1 for every prediction, we have no sense of ‘how good’ a given prediction was. For example,
one prediction could be very close to correct, while another could be way off, but they would both
receive an equivalent loss of 1. Formally, because this loss function is not differentiable, we cannot
get gradient information with which to optimize our model parameters w. We will find a way
around this in a moment when we discuss hinge loss, but before we get to that, let’s consider
using least squares loss as we did for linear regression.
3.5. OBJECTIVES FOR DECISION BOUNDARIES 37

3.5.2 Least Squares Loss


We are already familiar with the least squares loss function from linear regression, and we can
apply it again in this context to find the set of weights w that form the optimal decision boundary
between target classes.
We first need to introduce the idea of one-hot encoding, which simply means that the class
of a given data point is described by a vector with K options that has a 1 in the position that
corresponds to class Ck and 0s everywhere else (note that these classes aren’t usually 0-indexed).
For example, class C1 of 4 classes would be represented by the vector:
 
1
0
 
0 (3.4)
0

While class C2 would be represented by the vector:


 
0
1
 
0 (3.5)
0

and so on. Now that we have the idea of one-hot encoding, we can describe our target classes
for each data point in terms of a one-hot encoded vector, which can then be used in our training
process for least squares.

Each class Ck gets its own linear function with a different set of weights wk :

hk (x, wk ) = w⊤
kx

We can combine the set of weights for each class into a matrix W, which gives us our linear classifier:

h(x, W) = W⊤ x (3.6)

where each row in the transposed weight matrix W⊤ corresponds to the linear function of an
individual class, and matrix W is D × K. We can use the results derived in the last chapter to find
the solution for W that minimizes the least squares loss function. Assuming a data set of input
data points X and one-hot encoded target vectors Y (where every row is a single target vector, so
that Y is N × K), the optimal solution for W is given by:

W∗ = (X⊤ X)−1 X⊤ Y,

which we can then use in our discriminant function h(x, W∗ ) to make predictions on new data
points.
While least squares gives us an analytic solution for our discriminant function, it has significant
limitations when used for classification. For one, least squares penalizes data points that are ‘too
good’, meaning they fall too far on the correct side of the decision boundary. Furthermore, it is not
robust to outliers, meaning the decision boundary significantly changes with the addition of just a
few outlier data points, as seen in Figure 3.7.
We can help remedy the problems with least squares by using an alternative loss function for
determining our weight parameters.
38 CHAPTER 3. CLASSIFICATION

Figure 3.7: Outliers significantly impact our decision boundary.

Figure 3.8: Form of the ReLU function.

3.5.3 Hinge Loss


Recall that the problem with 0/1 loss was that we couldn’t use it optimize our model parameters
w. It didn’t produce a closed form solution like least squares loss, and it wasn’t differentiable so
we couldn’t take gradients.
The hinge loss function is a modification of the 0/1 loss function that both provides more fine-
grained information about the ‘goodness’ of a prediction and makes the loss function differentiable.
To understand the hinge loss function, it’s first necessary to introduce the rectified linear activation
unit, known as ReLU, seen in Figure 3.8.

ReLU(z) = max{0, z} (3.7)

We can use the form of this function to our advantage in constructing the hinge loss by recognizing
that we wish to incur error when we’re wrong (which corresponds to z > 0, the right side of
the graph that is continuously increasing), and we wish to incur 0 error if we are correct (which
corresponds to the left side of the graph where z < 0).
Remember from the previous section on least squares that in the two-class case, we classify a
data point x∗ as being from class 1 if h(x∗ , w) ≥ 0, and class -1 otherwise. We can combine this
3.6. PROBABILISTIC METHODS 39

logic with ReLU by recognizing that −h(x∗ , w)y ∗ ≥ 0 when there is a classification error, where
y ∗ is the true class of data point x∗ . This has exactly the properties we described above: we incur
error when we misclassify, and otherwise we do not incur error.

We can then write the entirety of the hinge loss function:


N
X
L(w) = ReLU(−h(xi , w)yi ) (3.8)
i=1
N
X
=− h(xi , w)yi (3.9)
yi ̸=ŷi
N
X
=− w⊤ xi yi (3.10)
yi ̸=ŷi

where ŷi is our class prediction and yi is the true class value. Notice that misclassified examples
contribute positive loss, as desired. We can take the gradient of this loss function, which will allow
us to optimize it using stochastic gradient descent. The gradient of the loss with respect to our
parameters w is as follows:
N
∂L(w) X
=− xi yi
∂w
yi ̸=ŷi

and then our update equation from time t to time t + 1 for a single misclassified example and with
learning rate η is given by:
∂L(w)
w(t+1) = w(t) − η = w(t) + ηxi yi
∂w
To sum up, the benefits of the hinge loss function are its differentiability (which allows us to
optimize our weight parameters), the fact that it doesn’t penalize any correctly classified data
points (unlike basic linear classification), and that it penalizes more heavily data points that are
more poorly misclassified.
Using hinge loss with discriminant functions to solve classification tasks (and applying stochastic
gradient descent to optimize the model parameters) is known as the perceptron algorithm. The
perceptron algorithm guarantees that if there is separability between all of our data points and we
run the algorithm for long enough, we will find a setting of parameters that perfectly separates our
data set. The proof for this is beyond the scope of this textbook.

3.6 Probabilistic Methods


Unsurprisingly, we can also cast the problem of classification into a probabilistic context, which
we now turn our attention to. Within this setting, we have a secondary choice to make between
two distinct probabilistic approaches: discriminative or generative. We will explore both of these
options.

3.6.1 Probabilistic Discriminative Models


Ultimately, our classification task can be summarized as follows: given a new data point x∗ , can
we accurately predict the target class y ∗ ?
40 CHAPTER 3. CLASSIFICATION

Given this problem statement, it makes sense that we might try to model p(y ∗ |x∗ ). In fact,
modeling this conditional distribution directly is what’s known as probabilistic discriminative
modeling.

Definition 3.6.1 (Probabilistic Discriminative Modeling): Probabilistic modeling is a clas-


sification technique whereby we choose to directly model the conditional class distribution in order
to make classification predictions.

This means that we will start with the functional form of the generalized linear model described
by Equation 3.2, convert this to a conditional distribution, and then optimize the parameters of
the conditional distribution directly using a maximum likelihood procedure. From here, we will be
able to make predictions on new data points x∗ . The key feature of this procedure, which is known
as discriminative training, is that it optimizes the parameters of a conditional distribution directly.
We describe a specific, common example of this type of procedure called logistic regression in
the next section.

Logistic Regression

One problem we need to face in our discriminative modeling paradigm is that the results of our
generalized linear model are not probabilities; they are simply real numbers. This is why in the
previous paragraph we mentioned needing to convert our generalized linear model to a conditional
distribution. That step boils down to somehow squashing the outputs of our generalized linear
model onto the real numbers between 0 and 1, which will then correspond to probabilities. To do
this, we will apply what is known as the logistic sigmoid function, σ(·).

Definition 3.6.2 (Logistic Sigmoid Function, σ(·)): The logistic sigmoid function is commonly
used to compress the real number line down to values between 0 and 1. It is defined functionally
as:
1
σ(z) =
1 + exp (−z)

As you can see in Figure 3.9 where the logistic sigmoid function is graphed, it squashes our output
domain between 0 and 1 as desired for a probability.

⋆ There is a more satisfying derivation for our use of the logistic sigmoid function in logistic regression, but under-
standing its squashing properties as motivation is sufficient for the purposes of this book.

Using the logistic sigmoid function, we now have a means of generating a probability that a new
data point x∗ is part of class y ∗ . Because we are currently operating in the two-class case, which
in this context will be denoted C1 and C2 , we’ll write the probability for each of these classes as:

p(y ∗ = C1 |x∗ ) = σ(w⊤ x∗ )


p(y ∗ = C2 |x∗ ) = 1 − p(y ∗ = C1 |x∗ )

Now that we have such functions, we can apply the maximum likelihood procedure to determine
the optimal parameters for our logistic regression model.
3.6. PROBABILISTIC METHODS 41

Figure 3.9: Logistic Sigmoid Function.

For a data set {xi , yi } where i = 1..N and yi ∈ {0, 1}, the likelihood for our setting of parameters
w can be written as:
N
ŷiyi {1 − ŷi }1−yi
Y
N
p({yi }i=1 |w) = (3.11)
i=1

where ŷi = p(yi = C1 |xi ) = σ(w⊤ xi ).


In general, we would like to maximize this probability to find the optimal setting of our param-
eters. This is exactly what we intend to do, but with two further simplifications. First, we’re going
to maximize the probability of the logarithm of the likelihood, as in Equation 3.12.

N
X
ln(p({yi }N
i=1 |w)) = {yi ln ŷi + (1 − yi ) ln (1 − yˆi )} (3.12)
i=1

As a monotonically increasing function, maximizing the logarithm of the likelihood (called the
log likelihood ) will result in the same optimal setting of parameters as if we had just optimized
the likelihood directly. Furthermore, using the log likelihood has the nice effect of turning what
is currently a product of terms from 1..N to a sum of terms from 1..N , which will make our
calculations nicer.
Second, we will turn our log likelihood into an error function by taking the negative of our log
likelihood expression. Now, instead of maximizing the log likelihood, we will be minimizing the
error function, which will again find us the same setting of parameters.

⋆ It’s worth rereading the above paragraph again to understand the pattern presented there, which we will see several
times throughout this book. Instead of maximizing a likelihood function directly, it is often easier to define an error
function using the negative log likelihood, which we can then minimize to find the optimal setting of parameters for
our model.

After taking the negative logarithm of the likelihood function defined by Equation 3.11, we
are left with the following term, known as the cross-entropy error function, which we will seek to
42 CHAPTER 3. CLASSIFICATION

minimize:
N
X
E(w) = − ln p({yi }|w) = − {yi ln ŷi + (1 − yi ) ln (1 − yˆi )} (3.13)
i=1

where as before ŷi = p(yi = C1 |xi ) = σ(w⊤ xi ). The cross-entropy error refers to the log likelhood of
the labels conditioned on the examples. When used with the specific form of the logistic regression,
this is also the logistic loss. Now, to solve for the optimal setting of parameters using a maximum
likelihood approach as we’ve done previously, we start by taking the gradient of the cross-entropy
error function with respect to w:
N
X
∇E(w) = (ŷi − yi )xi (3.14)
i=1

which we arrive at by recognizing that the derivative of the logistic sigmoid function can be written
in terms of itself as:
∂σ(z)
= σ(z)(1 − σ(z))
∂z
Let’s inspect the form of Equation 3.14 for a moment to understand its implications. First, it’s
a summation over all of our data points, as we would expect. Then, for each data point, we are
taking the difference between our predicted value ŷi and the actual value yi , and multiplying that
difference by the input vector xi .
While a closed form solution does not present itself here as it did in the case of linear regression
due to the nonlinearity of the logistic sigmoid function, we can still optimize the parameters w of
our model using an iterative procedure like gradient descent, where the objective function is defined
by Equation 3.13.

Multi-Class Logistic Regression and Softmax


As we saw when working with discriminant functions, we also need to account for multi-class
problems, which are practically speaking more common than the simple two-class scenario.
In the logistic regression setting (which is a form of discriminative modeling, not to be confused
with discriminant functions), we are now working with probabilities, which is why we introduced the
‘probability squashing’ sigmoidal function σ(·). Note that the sigmoidal function is also sometimes
known as the sigmoidal activation function.
Similarly, in the multi-class logistic regression setting, we would like to also have a probabil-
ity squashing function that generalizes beyond two classes. This generalization of the sigmoidal
function is known as softmax.

Definition 3.6.3 (Softmax): Softmax is the multi-class generalization of the sigmoidal activa-
tion function. It accepts a vector of activations (inputs) and returns a vector of probabilities
corresponding to those activations. It is defined as follows:

exp (zk )
softmaxk (z) = PK , for all k
i=1 exp (zi )

Multi-class logistic regression uses softmax over a vector of activations to select the most likely
target class for a new data point. It does this by applying softmax and then assigning the new data
point to the class with the highest probability.
3.6. PROBABILISTIC METHODS 43

Example 3.1 (Softmax Example): Consider an example that has three classes: C1 , C2 , C3 . Let’s
say we have an activation vector z for our new data point x that we wish to classify, given by:
 
4

z = W x = 1
7

where

zj = w⊤
j x

Then, using our definition of softmax, we have:


 
0.047
softmax(z) = 0.002
0.950

And therefore, we would assign our new data point x to class C3 , which has the largest activation.

As in the two-class logistic regression case, we now need to solve for the parameters W of our
model, also written as {wj }. Assume we have an observed data set {xi , yi } for i = 1..N where yi
are one-hot encoded target vectors. We begin this process by writing the likelihood for our data,
which is only slightly modified here to account for multiple classes:
N Y
K N Y
K
y
Y Y
p({yi }N
i=1 |W) = p(yi = Cj |xi )yij = ŷijij (3.15)
i=1 j=1 i=1 j=1

where ŷij = softmaxj (Wxi )

We can now take the negative logarithm to get the cross-entropy error function for the multi-
class classification problem:
N X
X K
E(W) = − ln p({yi }N
i=1 |W) =− yij ln ŷij (3.16)
i=1 j=1

As in the two-class case, we now take the gradient with respect to one of our weight parameter
vectors wj :
XN
∇wj E(W) = (ŷij − yij )xi (3.17)
i=1

which we arrived at by recognizing that the derivative of the softmax function with respect to the
input activations zj can be written in terms of itself:

∂softmaxk (z)
= softmaxk (z)(Ikj − softmaxj (z))
∂zj

where I is the identity matrix.


As in the two-class case, now that we have this gradient expression, we can use an iterative
procedure like gradient descent to optimize our parameters W.
44 CHAPTER 3. CLASSIFICATION

3.6.2 Probabilistic Generative Models


With the probabilistic discriminative modeling approach, we elected to directly model p(y ∗ |x∗ ).
However, there was an alternative option: we could have instead modeled the joint distribution of
the class y ∗ and the input data point x∗ together as p(y ∗ , x∗ ). This approach is what’s known as
probabilistic generative modeling because we actually model the process by which the data
was generated.
To model the data generating process in classification tasks generally acknowledges that a data
point is produced by first selecting a class y ∗ from a categorical class prior p(y ∗ ) and then generating
the data point x∗ itself from the class-conditional distribution p(x∗ |y ∗ ), the form of which is problem
specific. This generative approach is a particularly good idea if we want to create more data (by
sampling from the joint distribution) or if we have some sort of expert knowledge about how the
data was generated, which can make our model more powerful than the discriminative approach.

⋆ Notice that with probabilistic generative modeling, we choose a specific distribution for our class-conditional
densities instead of simply using a generalized linear model combined with a sigmoid/softmax function as we did
in the logistic regression setting. This highlights the difference between discriminative and generative modeling: in
the generative setting, we are modeling the production of the data itself instead of simply optimizing the parameters
of a more general model that predicts class membership directly.

Classification in the Generative Setting


Now that we’re situated in the generative setting, we turn our attention to the actual problem of
using our model to predict class membership of new data points x∗ .
To perform classification, we will pick the class Ck that maximizes the probability of x∗ being
from that class as defined by p(y ∗ = Ck |x∗ ). We can relate this conditional density to the joint
density p(y ∗ , x∗ ) through Bayes’ Rule:
p(y ∗ , x∗ ) p(x∗ |y ∗ = Ck )p(y ∗ = Ck )
p(y ∗ = Ck |x∗ ) = = ∝ p(x∗ |y ∗ = Ck )p(y ∗ = Ck )
p(x∗ ) p(x∗ )
where p(x∗ ) is a constant that can be ignored as it will be the same for every conditional probability
p(y ∗ = Ck |x∗ ).
Recall that the class prior p(y) will always be a categorical distribution (the multi-class general-
ization of the Bernoulli distribution), while the class-conditional distribution can be specified using
prior knowledge of the problem domain. Once we have specified this class conditional distribution,
we can solve for the parameters of both that model and the categorical distribution by optimizing
the likelihood function. Let’s now derive that likelihood function.

Derivation 3.6.1 (Probabilistic Generative Model Likelihood Function):


We’re going to derive the likelihood function for the parameters of our probabilistic generative
model in the two-class setting, allowing that the multi-class generalization will be a straightforward
exercise.
Let’s start by assuming a Gaussian conditional distribution for our data p(x|y = Ck ). Given a data
set {xi , yi } for i = 1..N , where yi = 1 corresponds to class C1 and yi = 0 corresponds to class C2 ,
we can construct our maximum likelihood solution. Let’s first specify our class priors:
p(C1 ) = π
p(C2 ) = 1 − π
3.6. PROBABILISTIC METHODS 45

For simplicity, we’ll assume a shared covariance matrix Σ between our two classes. Then, for data
points xi from class C1 , we have:

p(xi , C1 ) = p(C1 )p(xi |C1 ) = πN (xi |µ1 , Σ)

And for data points xi from class C2 , we have:

p(xi , C2 ) = p(C2 )p(xi |C2 ) = (1 − π)N (xi |µ2 , Σ)

Using these two densities, we can construct our likelihood function:


N 
Y yi  1−yi
p(π, µ1 , µ2 , Σ) = πN (xi |µ1 , Σ) (1 − π)N (xi |µ2 , Σ)
i=1

As usual, we will take the logarithm which is easier to work with:


N
X    
ln p(π, µ1 , µ2 , Σ) = yi ln πN (xi |µ1 , Σ) + (1 − yi ) ln (1 − π)N (xi |µ2 , Σ)
i=1

Now that we have specified the log-likelihood function for our model, we can go about optimizing
our model by maximizing this likelihood. One way to do this is with a straightforward maximimum
likelihood estimation approach. We will optimize our parameters π, µ1 , µ2 , and, Σ separately, using
the usual procedure of taking the derivative, setting equal to 0, and then solving for the parameter
of interest. We write down this MLE solution in the following section.

MLE Solution
Solving for π
Beginning with π, we’ll concern ourselves only with the terms that depend on π which are:
N
X
yi ln π + (1 − yi ) ln (1 − π)
i=1

Taking the derivative with respect to π, setting equal to 0, rearranging, we get:


N
1 X N1 N1
π= yi = =
N N N1 + N2
i=1

where N1 is the number of data points in our data set from class C1 , N2 is the number of data
points from class C2 , and N is just the total number of data points. This means that the maximum
likelihood solution for π is the fraction of points that are assigned to class C1 , a fairly intuitive
solution and one that will be commonly seen when working with maximum likelihood calculations.

Solving for µ
Let’s now perform the maximization for µ1 . Start by considering the terms from our log likelihood
involving µ1 :
N N
X 1X
yi ln N (xi |µ1 , Σ) = − yi (xi − µ1 )⊤ Σ−1 (xi − µ1 ) + c
2
i=1 i=1
46 CHAPTER 3. CLASSIFICATION

where c are constants not involving the µ1 term. Taking the derivative with respect to µ1 , setting
equal to 0, and rearranging:
N
1 X
µ1 = yi xi
N1
i=1

which is simply the average of all the data points xi assigned to class C1 , a very intuitive result.
By the same derivation, the maximum likelihood solution for µ2 is:
N
1 X
µ2 = (1 − yi )xi
N2
i=1

Solving for Σ

We can also the maximum likelihood solution for the shared covariance matrix Σ. Start by
considering the terms in our log likelihood expression involving Σ:
N N N
1X 1X 1X
− yi ln |Σ| − yi (xi − µ1 )⊤ Σ−1 (xi − µ1 ) − (1 − yi ) ln |Σ|
2 2 2
i=1 i=1 i=1
N
1 X
− (1 − yi )(xi − µ2 )⊤ Σ−1 (xi − µ2 )
2
i=1

We can use the following “matrix cookbook formulas” to help with taking the derivative with
respect to Σ. Also, we adopt convention Z−⊤ := (Z⊤ )−1 . The two helpful formulas are:

∂a⊤ Z−1 b
= −Z−⊤ ab⊤ Z−⊤
∂Z
∂ ln | det(Z)|
= Z−⊤ .
∂Z
Now, taking the derivative with respect to Σ and collecting terms, we have:
N N
1 1X 1X
− N Σ−⊤ + yi Σ−⊤ (xi − µ1 )(xi − µ1 )⊤ Σ−⊤ + (1 − yi )Σ−⊤ (xi − µ2 )(xi − µ2 )⊤ Σ−⊤
2 2 2
i=1 i=1

Setting this to zero and multiplying both sides by Σ⊤ from the left and right (with the effect of
−1 −1 ⊤
retaining Σ⊤ in only the first term, since Σ⊤ Σ−⊤ = Σ⊤ Σ⊤ = I and Σ−⊤ Σ⊤ = Σ⊤ Σ =
I), and multiplying by 2, we have
N 
X 
⊤ ⊤ ⊤
NΣ − yi (xi − µ1 )(xi − µ1 ) + (1 − yi )(xi − µ2 )(xi − µ2 ) = 0.
i=1

Rearranging to solve for Σ, and recognizing that covariance matrices are symmetric, and so
Σ⊤ = Σ, we have:
N  
1 X ⊤ ⊤
Σ= yi (xi − µ1 )(xi − µ1 ) + (1 − yi )(xi − µ2 )(xi − µ2 ) .
N
i=1
3.6. PROBABILISTIC METHODS 47

This has the intuitive interpretation that the maximum likelihood solution for the shared co-
variance matrix is the weighted average of the two individual covariance matrices. Note that
(xi − µ1 )(xi − µ1 )⊤ is a matrix (the outer product of the two vectors). Also, yi is a scalar, which
means that each term is a sum of matrices. For any point i, only one of two matrices inside will
contribute due to the use of yi and (1 − yi ).

It is relatively straightforward to extend these maximum likelihood derivations from their two-
class form to their more general, multi-class form.

Naive Bayes

There exists a further simplification to probabilistic generative modeling in the context of classifi-
cation known as Naive Bayes.

Definition 3.6.4 (Naive Bayes): Naive Bayes is a type of generative model for classification
tasks. It imposes the simplifying rule that for a given class Ck , we assume that each feature of
the data points x generated within that class are independent (hence the descriptor ‘naive’). This
means that the conditional distribution p(x|y = Ck ) can be written as:
D
Y
p(x|y = Ck ) = p(xi |y = Ck )
i=1

where D is the number of features in our data point x and Ck is the class. Note that Naive Bayes
does not specify the form of the model p(xi |y = Ck ), this decision is left up to us.

This is obviously not a realistic simplification for all scenarios, but it can make our calculations
easier and may actually hold true in certain cases. We can build more intuition for how Naive
Bayes works through an example.

Example 3.2 (Naive Bayes Example): Suppose you are given a biased two-sided coin and two
biased dice. The coin has probabilities as follows:

Heads : 30%
Tails : 70%

The dice have the numbers 1 through 6 on them, but they are biased differently. Die 1 has
probabilities as follows:

1 : 40%
2 : 20%
3 : 10%
4 : 10%
5 : 10%
6 : 10%
48 CHAPTER 3. CLASSIFICATION

Die 2 has probabilities as follows:

1 : 20%
2 : 20%
3 : 10%
4 : 30%
5 : 10%
6 : 10%

Your friend is tasked with doing the following. First, they flip the coin. If it lands Heads, they
select Die 1, otherwise they select Die 2. Then, they roll that die 10 times in a row, recording the
results of the die rolls. After they have completed this, you get to observe the aggregated results
from the die rolls. Using this information (and assuming you know the biases associated with the
coin and dice), you must then classify which die the rolls came from. Assume your friend went
through this procedure and produced the following counts:

1:3
2:1
3:2
4:2
5:1
6:1

Determine which die this roll count most likely came from.

Solution:
This problem is situated in the Naive Bayes framework: for a given class (dictated by the coin
flip), the outcomes within that class (each die roll) are independent. Making a classification in this
situation is as simple as computing the probability that the selected die produced the given roll
counts. Let’s start by computing the probability for Die 1:

p(Die 1) = p(Coin Flip = Heads) ∗ p(Roll Count = [3, 1, 2, 2, 1, 1])


∝ 0.3 ∗ (0.4)3 ∗ (0.2)1 ∗ (0.1)2 ∗ (0.1)2 ∗ (0.1)1 ∗ (0.1)1
∝ 3.84 ∗ 10−9

Notice that we don’t concern ourselves with the normalization constant for the probability of the
roll count - this will not differ between the choice of dice and we can thus ignore it for simplicity.
Now the probability for Die 2:

p(Die 2) = p(Coin Flip = Tails) ∗ p(Roll Count = [3, 1, 2, 2, 1, 1])


∝ 0.7 ∗ (0.2)3 ∗ (0.2)1 ∗ (0.1)2 ∗ (0.3)2 ∗ (0.1)1 ∗ (0.1)1
∝ 1.008 ∗ 10−8

Therefore, we would classify this roll count as having come from Die 2.
3.7. CONCLUSION 49

Note that this problem asked us only to make a classification prediction after we already
knew the parameters governing the coin flip and dice rolls. However, given a data set, we could
have also used a maximum likelihood procedure under the Naive Bayes assumption to estimate
the values of the parameters governing the probability of the coin flip and die rolls.

3.7 Conclusion
In this chapter, we looked at different objectives and techniques for solving classification problems,
including discriminant functions, probabilistic discriminative models, and probabilistic generative
models. In particular, we emphasized the distinction between two-class and multi-class problems
as well as the philosophical differences between generative and discriminative modeling.
We also covered several topics that we will make use of in subsequent chapters, including sigmoid
functions and softmax, maximum likelihood solutions, and further use of basis changes.
By now, you have a sound understanding of generative modeling and how it can be applied to
classification tasks. In the next chapter, we will explore how generative modeling is applied to a
still broader class of problems.
Chapter 4

Neural Networks

Despite how seemingly popular neural networks have become recently, they aren’t actually a novel
technique. The first neural networks were described in the early 1940s, and the only reason they
weren’t put into practice shortly thereafter was the fact that we didn’t yet have access to the large
amounts of storage and compute that complex neural network require. Over the last two decades,
and particularly with the advent of cloud computing, we now have more and more access to the
cheap processing power and memory required to make neural networks a viable option for model
building.
As we will come to see in this chapter, neural networks are an extraordinarily flexible class of
models used to solve a variety of different problem types. In fact, this flexibility is both what makes
them so widely applicable and yet so difficult to use properly. We will explore the applications,
underlying theory, and training schemes behind neural networks.

4.1 Motivation
For problems that fall into the category of regression or classification, we’ve already discussed the
utility of basis functions. Sometimes, a problem that is intractable with our raw input data will
be readily solvable with basis-transformed data. We often select these basis changes using expert
knowledge. For example, if we were working with a data set that related to chemical information,
and there were certain equations that a chemist told us to be important for the particular problem
we were trying to solve, we might include a variety of the transformations that are present in those
equations.
However, imagine now that we have a data set with no accompanying expert information. More
often than not, complex problem domains don’t come with a useful set of suggested transformations.
How do we find useful basis functions in these situations? This is exactly the strength of neural
networks - they identify the best basis for a data set!
Neural networks simultaneously solve for our model parameters and the best basis transforma-
tions. This makes them exceedingly flexible. Unfortunately, this flexibility is also the weakness of
neural nets: while it enables us to solve difficult problems, it also creates a host of other complica-
tions. Chief among these complications is the fact that neural networks require a lot of computation
to train. This is a result of the effective model space being so large - to explore it all takes time
and resources. Furthermore, this flexibilty can cause rather severe overfitting if we are not careful.
In summary, neural networks identify good basis transformations for our data, and the strengths
and weaknesses of neural networks stem from the same root cause: model flexibility. It will be our
goal then to appropriately harness these properties to create useful models.

50
4.1. MOTIVATION 51

4.1.1 Comparison to Other Methods

In the previous two chapters, we explored two broad problem types: classification and regression,
and it’s natural to wonder where neural networks fit in. The answer is that they are applicable to
both. The flexibility of neural networks even extends to the types of problems they can be made to
handle. Thus, the tasks that we’ve explored over the last two chapters, such as predicting heights
in the regression case or object category in the classification case, can be performed by neural
networks.
Given that neural networks are flexible enough to be used as models for either regression or
classification tasks, this means that every time you’re faced with a problem that falls into one of
these categories, you have a choice to make between the methods we’ve already covered or using
a neural network. Before we’ve explored the specifics of neural networks, how can we discern at a
high level when they will be a good choice for a specific problem?
One simple way to think about this is that if we never needed to use neural networks, we probably
wouldn’t. In other words, if a problem can be solved effectively by one of the techniques we’ve
already described for regression or classification (such as linear regression, discriminant functions,
etc.), we would prefer to use those. The reason is that neural networks are often more memory
and processor intensive than these other techniques, and they are much more complex to train and
debug.
The flip side of this is that hard problems are often too complex or too hard to engineer features
for to use a simple regression or classification technique. Indeed, even if you eventually think you
will need to use a neural network to solve a given problem, it makes sense to try a simple technique
first both to get a baseline of performance and because it may just happen to be good enough.
What is so special about neural networks that they can solve problems that the other techniques
we’ve explored may not be able to? And why are they so expensive to train? These questions will
be explored over the course of the chapter, and a good place to start is with the status of neural
networks as universal function approximators.

4.1.2 Universal Function Approximation

The flexibility of neural networks is a well-established phenomenon. In fact, neural networks are
what are known as universal function approximators. This means that with a large enough network,
it is possible to approximate any function. The proof of this is beyond the scope of this textbook,
but it provides some context for why flexibility is one of the key attributes of neural networks.
ML Framework Cube: Neural Networks
As universal function approximators, neural networks can operate over discrete or continuous out-
puts. We primarily use neural networks to solve regression or classification problems, which involve
training on data sets with example inputs and outputs, making this a supervised technique. Fi-
nally, while there exist probabilistic extensions for neural networks, they primarily operate in the
non-probabilistic setting.
Domain Training Probabilistic
Continuous/Discrete Supervised No
52 CHAPTER 4. NEURAL NETWORKS

Figure 4.1: Simple Neural Network.

4.2 Feed-Forward Networks


The feed-forward neural network is the most basic setup for a neural network. Most of the logic
behind neural networks can be explained using a feed-forward network, with additional bells and
whistles typically added to form more complex networks. We will explore this basic neural network
structure first.

4.3 Neural Network Basics and Terminology


Looking at Figure 4.1, we see that a feed-forward neural network is a series of connected layers that
transform an input data point x into an output data point y. Each layer is composed of nodes,
the small black circles. Each node in the input layer corresponds to one dimension of a single data
point x (meaning the first node is x1 , the second node x2 , etc.). The same is true of the nodes in
the output layer, which represent each dimension of y. For a binary classification problem there is
a single output node, representing the predicted probability of the positive class (with K outputs
for multi-classification). For a regression problem there may one or more nodes, depending on the
dimensions of the output. The nodes in the hidden layers correspond to activation functions.
Let’s zoom in on the first node in hidden layer 1, shown in Figure 4.2, to describe what happens
at each node as we transform an input. For now we adopt simple notation and don’t worry about
indexing by later. Looking at Figure 4.2, notice that every node in the input layer is connected to
this first node in hidden layer 1. Each of the lines is a connection, that has a weight wd associated
with it. We multiply all the nodes in the input layer by their corresponding weight, and then add
them all together to produce the activation a, i..e, the input at this first node (we’ve included the
bias term in the input vector):
a = x 1 w1 + x 2 w2 + x 3 w3 (4.1)
4.3. NEURAL NETWORK BASICS AND TERMINOLOGY 53

Figure 4.2: Zooming in on the inputs and the first node of the first layer.

⋆ Every node in every layer has distinct weights associated with it.

This gives us the activation for the first node in the first hidden layer. Once we’ve done this
for every node in the first hidden layer, we make a non-linear transform of these activations, and
then move on to computing the activations for the second hidden layer (which require the outputs
from the first hidden layer, as indicated by the network of connections). We keep pushing values
through the network in this manner until we have our complete output layer, at which point we
are finished.
We’ve skipped over some important details in this high-level overview, but with this general
information about what a neural network looks like and the terminology associated with it, we can
now dig into the details a little deeper.

4.3.1 Adaptive Basis Functions


As we mentioned in the introduction, the strength of neural networks is that we can learn an
effective basis for our problem domain at the same time as we train the parameters of our model.
In fact, learning this basis becomes just another part of our parameter training. Let’s make this
notion of learning a basis more concrete.
Thinking back to our chapter on linear regression, we were training a model that made predic-
tions using a functional form that looked like:
D
X
T
y(x, w) = wϕ = wd ϕ d
d=1
54 CHAPTER 4. NEURAL NETWORKS

where ϕ = ϕ(x), ϕ is the basis transformation function, and D is the dimensionality of the data
point.
Typically with this linear regression setup, we are training our model to optimize the parameters
w. With neural networks this is no different— we still train to learn those parameters. However,
the difference in the neural network setting is that the basis transformation function ϕ is no longer
fixed. Instead, the transformations are incorporated into the model parameters, and thus learned
at the same time.
This leads to a different functional form for neural networks. A neural network with M nodes
in its first hidden layer performs M linear combinations of an input data point x:

D
(1) (1) (1)
X
aj = wjd xd + wj0 ∀j ∈ 1..M (4.2)
d=1

Here, we use a(1) to denote the activation of a unit in layer 1 and notation w(1) denotes the
weights used to determine the activations in layer 1. We also make the bias explicit. We will still
(1)
use the bias trick in general, but we’ve left it out here to explicitly illustrate the bias term wj0 .
Other than this, equation 4.2 describes what we’ve already seen in Figure 4.2. The only difference
is that we index each node in the hidden layer (along with its weights) by j.
(1)
The M different values aj are the activations. We transform these activations with a non-linear
activation function h(·) to give:
(1)
zj = h(aj ) (4.3)

⋆ Note that we didn’t mention activation functions in the previous section only for the sake of simplicity. These
non-linearities are crucial to the performance of neural networks because they allow for modeling of outcomes that
vary non-linearly with their input variables.

(1)
These values zj correspond to the outputs of the hidden units, each of which is associated
with an activation function. Superscript (1) indicates they are the outputs of units in layer 1. A
typical activation function is the sigmoid function, but other common choices are the tanh function
and rectified linear unit (ReLU).
(1)
These output values zj , for units j ∈ {1, . . . , M } in layer 1, form the inputs to the next
layer. The activation of unit j ′ in layer 2 depends on the outputs from layer 1 and the weights
(2) (2) (2)
wj ′ 0 , wj ′ 1 , . . . , wj ′ M that define the linear sum at the input of unit j ′ :

M
(2) (2) (1) (2)
X
aj ′ = wj ′ m z j + wj ′ 0 (4.4)
j=1

We can connect many layers together in this way. They need not all have the same number of
nodes but we will adopt M for the number of nodes in each layer for convenience of exposition.
Eventually, we will reach the output layer, and each output is denoted yk , for k ∈ {1, . . . , K}. The
final activation function may be the sigmoid function, softmax function, or just linear (and thus no
transform).
We can now examine a more complete diagram of a feed-forward neural network, shown in
Figure 4.3. It may be helpful to reread the previous paragraphs and use the diagram to visualize
how a neural network transforms its inputs. This is a single hidden layer, or two-layer, network.
Here, we use z to denote the output values of the units in the hidden layer.
4.3. NEURAL NETWORK BASICS AND TERMINOLOGY 55

Figure 4.3: Feed-Forward Neural Network.

⋆ Different resources choose to count the number of layers in a neural net in different ways. We’ve elected to count
each layer of non-input nodes, thus the two-layer network in Figure 4.3. However, some resources will choose to count
every layer of nodes (three in this case) and still others count only the number of hidden layers (making this a one
layer network).

Combining Figure 4.3 and our preceeding functional description, we can describe the operation
performed by a two-layer neural network using a single functional transformation (with m to index
a unit in the hidden layer):
M
X D
X  
(2) (1) (1) (2)
yk (x, w) = σ wkm h wmd xd + wm0 + wk0 (4.5)
m=1 d=1

where we’ve elected to make the final activation function the sigmoid function σ(·), as is suitable
for binary classification. We use h to denote the non-linear activation function for a hidden unit.
Written like this, a neural network is simply a non-linear function that transforms an input x into
an output y that is controlled by our set of parameters w.
Furthermore, we see now why this basic variety of neural networks is a feed-forward neural
network. We’re simply feeding our input x forward through the network from the first layer to
the last layer. Assuming we have a fully trained network, we can make predictions on new input
data points by propagating them through the network to generate output predictions (“the forward
pass”).
We can also simplify this equation by utilizing the bias trick and appending an x0 = 1 value to
each of our data points such that:
XM X D 
(2) (1)
yk (x, w) = σ wkm h wmd xd
m=1 d=1

Finally, it’s worth considering that while a neural network is a series of linear combinations, it
is special because of the differentiable non-linearities applied at each of the hidden layers. Without
56 CHAPTER 4. NEURAL NETWORKS

these non-linearities, the successive application of different network weights would be equivalent to
a single large linear combination.

4.4 Network Training


Now that we understand the structure of a basic feed-forward neural network and how they can be
used to make predictions, we turn our attention to the training process.

4.4.1 Objective Function


To train our network, it’s first necessary to establish an objective function. Remember that neural
networks can be used to solve both regression and classification problems, which means that our
choice of objective will depend on the type of problem and the properties we desire.
For the case of linear regression, a common objective function is the least squares loss:

N  2
1X
L(w) = y(xn , w) − yn , (4.6)
2
n=1

where yn is the target value on example n. Sometimes we will have a regression problem with
multiple outputs, in which case the loss would also take the sum over these different target values.
For a binary classification problem, which we model through a single, sigmoid output activation
unit, then negated log-likelihood (or cross-entropy) is the typical loss function:

N 
X 
L(w) = − yn ln ŷn + (1 − yn )(ln (1 − ŷn ) (4.7)
n=1

For a multiclass classification problem, produced by a softmax function in the output activation
layer, we would use the negated log likelihood (cross entropy) loss:

N X
K  
X exp(ak (x, w))
L(w) = − ykn ln PK (4.8)
n=1 k=1 j=1 exp(aj (x, w))

⋆ Loss function and objective function all refer to the same concept: the function we optimize to train our model.

4.4.2 Optimizing Parameters


We want to find weight parameters w to minimize the objective function. The highly non-linear
nature of neural networks means that this objective function will be non-convex with respect to
our weight parameters. But still, we make use of stochastic gradient descent for optimization (refer
to the previous chapter for a refresher).
In order to use gradient descent, we first need to figure out how to compute the gradient of our
objective function with respect to our weights. That is the topic of the next section.
4.4. NETWORK TRAINING 57

4.4.3 Backpropagation
Considering how our feed-forward neural network works, by propagating activations through our
network to produce a final output, it’s not immediately clear how we can compute gradients for
the weights that lie in the middle of our networ. There is an elegant solution to this, which comes
from “sending errors backwards” through our network, in a process known as backpropagation.

Definition 4.4.1 (Backpropagation):


Backpropagation is the procedure by which we pass errors backwards through a feed-forward neural
network in order to compute gradients for the weight parameters of the network.

Backpropagation refers specifically to the portion of neural network training during which we
compute the derivative of the objective function with respect to the weight parameters. This is
done by propagating errors backwards through the network, hence the name.

⋆ Note that we still need to update the value of the weight parameters after computing their derivatives. This is
typically done using gradient descent or some variant of it.

We now explore the details of backpropagation in greater depth.

4.4.4 Computing Derivatives Using Backpropagation


(ℓ)
Recall that the activation aj for a node j in layer ℓ of a neural network can be described by the
equation:
M
(ℓ) (ℓ)
X
(ℓ−1)
aj = wjm zm , (4.9)
m=1

(ℓ−1) (ℓ−1)
where there are M incoming nodes, each with corresponding output values z1 , ..., zM , and
(ℓ) (ℓ)
with the weights in layer ℓ corresponding to node j denoted by wj1 , ..., wjM . This activation is
(ℓ)
transformed by an activation function h(·) to give unit output zj :

(ℓ) (ℓ)
zj = h(aj ). (4.10)

Computing these values as we flow through the network constitutes the forward pass through
our network.
We now wish to begin the process of computing derivatives of the objective function with respect
to our weights. For the sake of simplicity, we’ll assume that the current setting of our parameters
w generates a loss of L for a single data point, as though we were performing stochastic gradient
descent.
Let’s consider how we could compute the derivative of L with respect to an individual weight
(ℓ)
in our network, wjm (the mth weight for activation j in layer ℓ):

∂L
(ℓ)
. (4.11)
∂wjm

We first need to figure out what the dependence of L is on this weight. This weight contributes
(ℓ)
to the final result only via its contribution to the activation aj . This allows us to use the chain
58 CHAPTER 4. NEURAL NETWORKS

Figure 4.4: Gradient of the loss function in a neural network with respect to a weight. It depends
on the input value zm and the “error” corresponding to the activation value at the output end of
the weight.

rule to simplify Equation 4.11 as:

(ℓ)
∂L ∂L ∂aj
(ℓ)
= (ℓ)
· (ℓ)
. (4.12)
∂wjm ∂aj ∂wjm

The first part of this is the, typically non-linear, dependence of loss on activation. The second
part is the linear dependence of activation on weight. Using Equation 4.9, we have that:

(ℓ)
∂aj (ℓ−1)
(ℓ)
= zm ,
∂wjm

and just the value of the input from the previous layer. We now introduce the following notation
for the first term,
(ℓ) ∂L
δj = (ℓ)
, (4.13)
∂aj

(ℓ)
where δj values are referred to as errors. We rewrite Equation 4.12 as:

∂L (ℓ)
(ℓ−1)
(ℓ)
= δj zm . (4.14)
∂wjm

The implications of Equation 4.14 are significant for understanding backpropagation. The
derivative of the loss with respect to an arbitrary weight in the network can be calculated as the
(ℓ) (ℓ−1)
product of the error δj at the “output end of that weight” and the value zm at the “input end
of the weight.” We visualize this property in Figure 4.4 (dropping the layer subscripting).
To compute the derivatives, it suffices to compute the values of δj for each node, also saving
the output values zm during the forward pass through the network (to be multiplied by the values
of δj to get partials).
4.4. NETWORK TRAINING 59

⋆ We will only have “errors values” δj for the hidden and output units of our network. This is logical because there
is no notion of applying an error to our input data, which we have no control over.

We now consider how to compute these error values. For a unit in the output layer, indexing
it here by k, and assuming the output activation function is linear and adopting least squares loss
(i.e., regression), we have for the dependence of loss on the activation of this unit,
d 12 (ŷk − yk )2

(ℓ) ∂L ∂L
δk = (ℓ)
= = = ŷk − yk .
∂a ∂ ŷk dŷk
k
(ℓ)
Here, we use shorthand ak = ŷk , providing the kth dimension of the prediction of the model.
Although a regression problem, we’re imagining here that there are multiple regression targets (say,
the height, weight and blood pressure of an individual). Here, yk is the true target value for this
data point. Note that this is for OLS. The expression would be different for a classification problem
and negated log likelihood as the loss.
(ℓ)
To compute the error δj for a hidden unit j in a layer ℓ, we again make use of the chain rule,
and write:
M (ℓ+1)
(ℓ) ∂L X ∂L ∂am
δj = (ℓ)
= (ℓ+1) (ℓ)
, (4.15)
∂aj m=1 ∂am ∂aj
where the summation runs over all of the M nodes to which the node j in layer ℓ sends connections,
as seen in Figure 4.5. This expression recognizes that the activation value of this unit contributes
only via its contribution to the activation value of each unit to which it is connected in the next layer.
The first term in one of the products in the summation is the, typically non-linear, dependence
between loss and activation value of a unit in the next layer. The second term in one of the products
captures the relationship between this activation and the subsequent activation.
Now, we can simplify by noticing that:
∂L (ℓ+1)
(ℓ+1)
= δm {by definition} (4.16)
∂am
(ℓ+1) (ℓ)
∂am dh(aj ) (ℓ+1) (ℓ) (ℓ+1)
(ℓ)
= (ℓ)
· wmj = h′ (aj ) · wmj . {chain rule} (4.17)
∂aj daj
Substituting, and pulling forward the derivative of the activation function, we can rewrite the
expression for the error on a hidden unit j in layer ℓ as:
M
(ℓ) (ℓ) (ℓ+1)
X
δj = h′ (aj ) (ℓ+1)
wmj δm . (4.18)
m=1

This is very useful, and is the key insight in backpropagation. It means that the value of the
errors can be computed by “passing back” (backpropagating) the errors for nodes farther up in the
network! Since we know the values of δ for the final layer of output node, we can recursively apply
Equation 4.18 to compute the values of δ for all the nodes in the network.
Remember that all of these calculations were done for a single input data point that generated
the loss L. If we were using SGD with mini-batches, then we would perform same calculation for
each data point in mini-batch B, and average the gradients as follows:
∂L 1 X ∂Ln
(ℓ)
= (ℓ)
, (4.19)
∂wjm |B|
n∈B ∂wjm
60 CHAPTER 4. NEURAL NETWORKS

Figure 4.5: Summation over the nodes (blue) in layer ℓ + 1 to which node j in layer ℓ (gold) sends
connections (green). Note: read this as m and m′ .

where Ln is the loss on example n.


To solidify our understanding of the backpropagation algorithm, it can be useful to try a concrete
example.

Example 4.1 (Backpropagation Example): Imagine the case of a simple two layer neural net-
work as in Figure 4.3, with K outputs (we denote them ŷ1 through ŷK for a given data point). We
imagine this is a regression problem, but one with multiple dimensions to the output, and assume
OLS. For a given data point, the loss is computed as:
K
X 1
L= (ŷk − yk )2 ,
2
k=1

where we write yk for the kth dimension of the target value. For a unit in the hidden layer, with
activation value a, we make use of the sigmoid activation, with
1
z = σ(a) = ,
1 + exp (−a)
whose derivative is given by:
∂σ(a)
= σ(a)(1 − σ(a)).
∂a
For an input data point x, we forward propagate through the network to get the activations of the
hidden layer, and for each m in this layer we have:
D
(1)
X
a(1)
m = wmd xd ,
d=0
4.5. CHOOSING A NETWORK STRUCTURE 61

(1) (1)
given weights wm0 , . . . , wmD , and with output value from unit m as,
(1)
zm = σ(a(1)
m ).

We propagate these output values forward to get the outputs, and for each output unit k, we have:
M
(2)
X
(1)
ŷk = wkm zm ,
m=0

(2) (2)
where wk0 , . . . , wkM are the weights for unit k.
Now that we’ve propagated forward, we propagate our errors backwards! We start by computing
the errors for the output layer as follows:

(2) ∂L
δk = = ŷk − yk .
∂ ŷk
We then backpropagate these errors back to each hidden unit m in layer 1 as follows:
K
∂L X (2) (2)
(1)
δm = (1)
= h′ (a(1)
m ) wkm δk
∂am k=1
K
(2)
X
= σ(a(1) (1)
m )(1 − σ(am )) wkm (ŷk − yk )
k=1
K
(2)
X
(1) (1)
= zm (1 − zm ) wkm (ŷk − yk ).
k=1

And now that we have our errors for the hidden and output layers, we can compute the derivative
of the loss with respect to our weights as follows, for the dth weight on the mth unit in layer 1,
and the mth weight on the kth unit in layer 2:
∂L (1) ∂L (2)
(1)
(1)
= δm xd , (2)
= δk zm .
∂wmd ∂wkm

We then use these derivatives along with an optimization technique such as stochastic gradient
descent to improve the model weights.

4.5 Choosing a Network Structure


Now that we know the general form of a neural network and how the training process works, we
must step back and consider the question of how we actually arrive at an optimal network structure.
We’ll begin with an idea we’ve already seen before: cross validation.

4.5.1 Cross Validation for Neural Networks


We’ve previously discussed cross validation in the chapter on linear regression. We used it then
to compare the performance of different models, attempting to identify the best model while also
62 CHAPTER 4. NEURAL NETWORKS

Figure 4.6: Networks with different structures and numbers of internal nodes.

avoiding overfitting. We can use a similar process to identify a reasonable network structure.
First of all, the input and output parameters of a neural network are generally decided for us:
the dimensionality of our input data dictates the number of input units and the dimensionality of
the required output dictates the number of output units. For example, if we have an 8-by-8 pixel
image and need to predict whether it is a ‘0’ or a ‘1’, our input dimensions are fixed at 64 and our
output dimensions are fixed at 2. Depending on whether you wish to perform some sort of pre or
post-processing on the inputs/outputs of your network, this might not actually be the case, but
in general when choosing a network structure we don’t consider the first or last layer of nodes as
being a relevant knob that we can tune.
That leaves us to choose the structure of the hidden layers in our network. Unsurprisingly, the
more hidden layers we have and the more nodes we have in each of those layers, the more variation
we will produce in our results and the closer we will come to overfitting.
Thus, we can use cross validation in the same way we’ve done before: train our model with
differing numbers of internal units and structures (as in Figure 4.6) and then select the model that
performs best on the validation set.

⋆ There are other considerations at play beyond performance when choosing a network structure. For example, the
more internal units you have in your network, the more storage and compute time you will need to train them. If
either training time or response time after training a model is critical, you may need to consider consolidating your
network at the expense of some performace.

4.5.2 Preventing Overfitting


Besides keeping your neural network small, there are other means of preventing it from overfitting.

Regularization
You can also apply regularization to the weights in your network to help prevent overfitting. For
example, we could introduce a simple quadratic regularizer of the form λ2 wT w to our objective
function. There are other considerations to be made here, for example we would like our regularizer
to be invariant to scaling, meaning that multiplying our input data by a constant would produce
a proportionally equivalent network after training. The quadratic regularizer is not invariant to
scaling, but the basic concept of avoiding extreme weights is the same nonetheless.

Data Augmentation
We can use transformations to augment our data sets, which helps prevent overfitting. This tech-
nique is not specific to neural networks, but often the types of unstructured data for which we use
4.6. SPECIALIZED FORMS OF NEURAL NETWORKS 63

neural networks can benefit greatly from it.

Definition 4.5.1 (Data Augmentation): Data augmentation refers to the practice of increasing
the size and diversity of your training data by applying transformations to the initial data set.

For example, if we are working with image data, we might choose to rotate or reflect the image,
depending on the type of network we are trying to build and whether or not this would preserve
the integrity of the image. We might also change something like the brightness or density of the
image data. In this way, we can produce more and more varied training points, thus reducing the
likelihood of overfitting.

4.6 Specialized Forms of Neural Networks


Simple neural networks are useful for a general set of tasks, and as universal function approxima-
tors, they could be useful for any task. However, there are certain data types and use cases for
which we’ve developed more specialized forms of neural networks that perform even better in their
respective domains. We will take a high level view of these different flavors of neural networks.

4.6.1 Convolutional Neural Networks (CNNs)


Convolutional neural networks (abbreviated CNNs) are most often used for image data, but their
underlying principles apply in other domains as well.
To understand why a CNN is useful, consider this specific problem: you are trying to determine
whether or not there is a dog in an image. There are two general difficulties we have to deal with
in solving this problem. First, while dogs have a lot of similar features (ears, tails, paws, etc.), we
need some means of breaking an image down into smaller pieces that we can identify as being ears
or tails or paws. Second, what happens if we train on images of dogs that are all in the center of
the photo, and then we try to test our network on an image where the dog is in the upper left hand
corner? It’s going to fail miserably.
CNNs overcome these problems by extracting smaller local features from images via what’s
known as a sliding window. You can imagine this sliding window as a matrix kernel that moves
over every subsection of an image, producing a summary of those subsections that feed into the
next layer in our network. We do this over the entire image, and with several different sliding
windows. Without going into too many details, this solves the two general problems we had above:
our small sliding window can summarize a feature of interest (such as a dog ear) and it is also
location invariant, meaning that we can identify that dog ear anywhere in an image.

4.6.2 Recurrent Neural Networks (RNNs)


As with CNNs, recurrent neural networks (abbreviated RNNs) are used to more efficiently solve
a specific problem type. To motivate the structure of an RNN, we will turn again to a specific
example.
Imagine we were building a tool with the goal of predicting what word comes next in a newspaper
article. Obviously, the words that came before the word that we are currently trying to predict
are crucial to predicting the next word. Imagine we propagate the preceding ten words through
our network to predict which word we think will come next. It would also be useful if we could
send some of the information at each layer backwards through the network to help with the next
prediction - since we know the sequence of words matters. In this sense, our network is ‘stateful’
64 CHAPTER 4. NEURAL NETWORKS

Figure 4.7: Excellent diagram of the structure of a CNN. source:


https://www.mathworks.com/videos/introduction-to-deep-learning-what-are-convolutional-
neural-networks–1489512765771.html

Figure 4.8: Simple example of an RNN.

because it’s remembering what came before. We don’t have this ability with a feed-forward network,
which by design only propagates information forward through the network. RNNs add backward
passing of activations into their network structure to improve predictions on data where there is
some temporal dependence on what came previously.

4.6.3 Bayesian Neural Networks (BNNs)


Up until now, our training process has been one of maximum likelihood estimation, or a maximum
posterior approach if we utilize a regularizer that can be interpreted as introducing a prior.
A Bayesian neural network, or BNN, does exactly what you might imagine: it introduces
a distribution over the parameters of our model, which then requires marginalizing over those
distributions in order to make a prediction. The specifics of how exactly a BNN is constructed
are beyond the scope of this textbook, but the idea behind why we would utilize a BNN is the
same as the reason we utilize Bayesian techniques in other domains, particularly the use of prior
information to aid model performance.
Chapter 5

Support Vector Machines

In this chapter, we will explore what are known as a support vector machines, or SVMs for short.
SVMs are broadly useful for problems in classification and regression, and they are part of a
family of techniques known as margin methods. The defining goal of margin methods, and SVMs
specifically, is to put as much distance as possible between data points and decision boundaries.
We will dig deeper into what exactly this means over the course of the chapter. One of the most
appealing aspects of SVMs is that they can be solved as convex optimization problems, for which
we can find a global optimum with relative ease. We will explore the mathematical underpinnings
of SVMs, which can be slightly more challenging than our previous topics, as well as their typical
use cases.

5.1 Motivation
While SVMs can be used for classification or regression, we will reason about them in the classifi-
cation case as it is more straightforward.
The grand idea behind SVMs is that we should construct a linear hyperplane in our feature
space that maximally separates our classes, which means that the different classes should be as
far from that hyperplane as possible. The distance of our data from the hyperplane is known as
margin.

Definition 5.1.1 (Margin): Margin is the distance of the nearest data point from the separating
hyperplane of an SVM model, as seen in Figure 5.1. Larger margins often lead to more generalizable
models.

A larger margin tends to mean that our model will generalize better, since it provides more
wiggle room to correctly classify unseen data (think about new data being a perturbation on current
data).
This idea of the margin of a separator is quite intuitive. If you were presented with Figure 5.1
and were asked to separate the two classes, you would likely draw the line that keeps data points
as far from it as possible. SVMs and other margin-based methods will attempt to algorithmically
recreate this intuition.

5.1.1 Max Margin Methods


SVMs are a specific instance of a broader class of model known as max margin methods. Their
name describes them well: they deal with creating a maximum margin between training data and

65
66 CHAPTER 5. SUPPORT VECTOR MACHINES

Figure 5.1: Hyperplane with margin between different classes.

decision boundary, with the idea that this leads to model generalizability.
Other max margin methods are outside the scope of this textbook. These alternative methods
may differ from SVMs in a non-trivial way. For example, SVMs do not produce probabilities on
different classes, but rather decision rules for handling new data points. If you needed probabilities,
there are other max margin methods that can be used for the task.
ML Framework Cube: Support Vector Machines

SVMs are typically used in settings with discrete outputs. We need labeled training data to identify
the relevant hyperplane in an SVM model. Finally, SVMs operate in a non-probabilistic setting.
Domain Training Probabilistic
Discrete Supervised No

5.1.2 Applications

The theory behind SVMs has been around for quite some time (since 1963), and prior to the rise of
neural networks and other more computationally intensive techniques, SVMs were used extensively
for image recognition, object categorization, and other typical machine learning tasks.
SVMs are still widely used in practice, for example for classification problems known as anomaly
detection.
⋆ The purpose of anomaly detection is to identify unusual data points. For example, if we are manufacturing shoes,
we may wish to inspect and flag any shoe that seems atypical with respect to the rest of the shoes we produce.

Anomaly detection can be as simple as a binary classification problem where the data set is
comprised of anomalous and non-anomalous data points. As we will see, an SVM can be constructed
from this data set to identify future anomalous points very efficiently. SVMs extend beautifully to
settings where we want to use basis functions, and thus non-linear interactions on features. For
this reason, they continue to be competitive in many real-world situations where these kinds of
interactions are important to work with.
5.2. HARD MARGIN CLASSIFIER FOR LINEARLY SEPARABLE DATA 67

5.2 Hard Margin Classifier for Linearly Separable Data


We will learn the theory behind SVMs by starting with a simple two-class classification problem,
as we’ve seen several times in previous chapters. We will constrain the problem even further by
assuming, at least to get started, that the two classes are linearly separable, which is the basis of
the hard margin formulation for SVMs.

⋆ The expression ‘hard margin’ simply means that we don’t allow any data to be classified incorrectly. If it’s not
possible to find a hyperplane that perfectly separates the data based on class, then the hard margin classifier will
return no solution.

5.2.1 Why the Hard Margin


The hard margin constraint, which assumes that our data is linearly separable, is not actually
a requirement for constructing an SVM, but it simplifies the problem initially and makes our
derivations significantly easier. After we’ve established the hard margin formulation, we will extend
the technique to work in situations where our data is not linearly separable.

5.2.2 Deriving our Optimization Problem


Recall that our goal is to define a hyperplane that separates our data points and maintains the
maximum possible distance between the hyperplane and nearest data points on either side of it.
There are N examples x1 , ..., xN and there is a bias term w0 . Each example has a label y1 , ..., yN
which is either 1 or −1. To uncover this hyperplane, we start with a simple linear model for a
two-class classification problem:
h(x) = w⊤ x + w0 . (5.1)

This is the discriminant function and we classify a new example to class 1 or −1 according to
the sign produced by our trained model h(x). Later, we will also make this more general by using
a basis function, ϕ(x) to transform to a higher dimensional feature space.
By specifying our model this way, we have implicity defined a hyperplane separating our two
classes given by:
w⊤ x + w0 = 0 (5.2)

Furthermore, we have that w is orthogonal to the hyperplane, which we demonstrate now:

Derivation 5.2.1 (Hyperplane Orthogonal to w): Imagine two data points x1 and x2 on the
hyperplane defined by w⊤ x + w0 = 0. When we project their difference onto our model w, we find:

w⊤ (x1 − x2 ) = w⊤ x1 − w⊤ x2 = −w0 − (−w0 ) = 0 (5.3)

which means that w is orthogonal to our hyperplane. We can visualize this in Figure 5.2.

Remember that we’re trying to maximize the margin between our training data and the hyper-
plane. The fact that w is orthogonal to our hyperplane will help with this.
To determine the distance between a data point x and the hyperplane, which we denote d, we
need the distance in the direction of w between the point and the hyperplane. We denote xp to be
68 CHAPTER 5. SUPPORT VECTOR MACHINES

Figure 5.2: Our weight vector w is orthogonal to the separating hyperplane.

the projection of x onto the hyperplane, which allows us to decompose x as the following:
w
x = xp + d (5.4)
||w||2

which is the sum of the portion of the projection of x onto the hyperplane and the portion of x
that is parallel to w (and orthogonal to the hyperplane). From here we can solve for d:

Derivation 5.2.2 (Distance from Hyperplane Derivation): We start by left multiplying


Equation 5.4 with w⊤ ,

w⊤ w
w⊤ x = w⊤ xp + d .
||w||2

Simplifying (note that w⊤ xp = −w0 from Equation 5.3):

w⊤ x = −w0 + d||w||2

Rearranging:

w⊤ x + w0
d= .
||w||2

For each data point x, we now have the signed distance of that data point from the hyperplane.
For an example that is classified correctly, this signed distance d will be positive for class
yn = 1, and negative for class yn = −1. Given this, we can make the distance unsigned (and
always positive) for a correctly classified data point by multiplying by yn . Then, the margin for an
correctly classified data point (xn , yn ) is given by:

yn (w⊤ xn + w0 )
. (5.5)
||w||2

The margin for an entire data set is given by the margin to the closest point in the data set,
and
yn (w⊤ xn + w0 )
min . (5.6)
n ||w||2
5.2. HARD MARGIN CLASSIFIER FOR LINEARLY SEPARABLE DATA 69

Then, it is our goal to maximize this margin with respect to our model parameters w and w0 .
This is given by:
1 h i
max min yn (w⊤ xn + w0 ) (5.7)
w,w0 ||w||2 n

Here, we pull the 1/||w||2 term forward. Note carefully that w0 does not play a role in the
denominator ||w||2 .
This is a hard problem to optimize, but we can make it more tractable by recognizing some
important features of Equation 5.7. First, rescaling w → αw and w0 → αw0 , for any α > 0, has
no impact on the margin for any correctly classified data point xn . This is because the effect of α
cancels out in the numerator and denominator of Equation 5.5.
We can use this rescaling liberty to enforce

yn (w⊤ xn + w0 ) ≥ 1, for all n. (5.8)

This does not change the optimal margin because we can always scale up both w and w0 by
α > 0 to achieve yn (w⊤ xn + w0 ) ≥ 1, and without affecting the margin. Moreover, since the
problem is to maximize 1/||w||2 , an optimal solution will want ||w||2 to be as small as possible, and
thus at least one of these constraints (5.8) will be binding and equal to one in an optimal solution.
Thus our optimization problem now looks like:
1
max s.t. yn (w⊤ xn + w0 ) ≥ 1, for all n. (5.9)
w,w0 ||w||2

Here, we recognized that minn yn (w⊤ xn + w0 ) = 1 in an optimal solution with these new
constraints, and adopted this in the objective. This simplifies considerably, removing the “min n ”
part of the objective.
1
Notice that maximizing ||w|| 2
is equivalent to minimizing ||w||22 . We will also add a constant
term 12 for convenience, leaving the hard-margin formulation of the training problem:

1
min ||w||22 s.t. yn (w⊤ xn + w0 ) ≥ 1, for all n. (5.10)
w,w0 2

Note that Equation 5.10 is now a quadratic programming problem, which means we wish to
optimize a quadratic function subject to a set of linear constraints on our parameters. Arriving at
this form was the motivation for the preceding mathematic manipulations. We will discuss shortly
how we actually optimize this function.

5.2.3 What is a Support Vector


Up until now, we have discussed Support Vector Machines without identifying what a support
vector is. We now have enough information from the previous section to define them. Later, when
we work with a dual formulation, we will see the precise role that they play.
Say that an example is on the margin boundary if yn (w⊤ xn + w0 ) = 1. These are the points
that are closest to the hyperplane and have margin 1/||w||2 .

Definition 5.2.1 (Support Vector): A support vector in a hard-margin SVM formulation must
be a data point that is on the margin boundary of the optimal solution, with yn (w⊤ xn + w0 ) = 1
and margin 1/||w||2 .
70 CHAPTER 5. SUPPORT VECTOR MACHINES

Hard Margin SVM Example

Figure 5.3: Example of the resulting hyperplane for a hard margin SVM. The filled in data points
are support vectors in this example. A support vector for the hard-margin formulation must be on
the margin boundary, with a discriminant value of +1 or -1.

In the hard margin case we have constrained the closest data points to have discriminant value
w⊤ xn + w0 = 1 (−1 for a negative example). Figure 5.3 shows a hard margin SVM solution with
an illustration of corresponding support vectors.

⋆ After we have optimized an SVM in the hard margin case, we must have at least two support vectors with
discriminant value that is 1 or -1, and thus a margin of 1/||w||2 .

⋆ Not every example on the margin boundary needs to be a support vector.

5.3 Soft Margin Classifier


Thus far, we’ve been operating under the assumption that our data is linearly separable in feature
space, which afforded us several convenient guarantees in the derivations of the previous section.
For example, given that our data was linearly separable, we could guarantee that every data point
would be on the correct side of the hyperplane, which waswhat allowed us to enforce the constraint
that yn (w⊤ xn + w0 ) ≥ 1. We now seek to generalize the work of the previous section to situations
where our data is not separable.

5.3.1 Why the Soft Margin?


What if our data is not linearly separable in the feature space (even after applying a basis function)?
This is the likely case in real applications! Unfortunately, our current hard margin SVM formulation
would be useless with non-linearly separable data. That is why we need the soft margin SVM.
At a high level, the soft margin SVM allows for some of our data points to be closer to or even
on the incorrect side of the hyperplane. This is desirable if our data set is not linearly separable or
contains outliers, and it is also quite intuitive. Examining Figure 5.4, we see that we have a single
outlier data point. We can still create a good model by just allowing this single data point to be
5.3. SOFT MARGIN CLASSIFIER 71

Figure 5.4: An outlier can make the hard margin formulation impossible or unable to generalize
well.

close to the hyperplane (or in some cases, even on the wrong side of the hyperplane). That is what
the soft margin formulation will allow for.

5.3.2 Updated Optimization Problem for Soft Margins


To enable the soft margin formulation, we introduce what are known as slack variables denoted
ξn ≥ 0, which simply relax the constraint from Equation 5.9 that we imposed in the hard margin
formulation. Say that a data point is “inside the margin region” if its discriminant value is in the
range (−1, +1). There is a slack variable ξn ≥ 0 for every data point xn , and they take the following
values according to how we classify xn :

= 0
 if xn is correctly classified
ξn = ∈ (0, 1] if xn is correctly classified but inside the margin region (5.11)

>1 if xn is incorrectly classified

These slack variable penalize data points on the wrong side of the margin boundary, but they
don’t forbid us from allowing data points to be on the wrong side if this produces the best model. We
now reformulate the optimization problem as follows. This is the soft-margin training problem:

N
1 X
min ||w||22 + C ξn (5.12)
w,w0 2
n=1
s.t. yn (w⊤ xn + w0 ) ≥ 1 − ξn , for all n
ξn ≥ 0, for all n.

Here, C is a regularization parameter that determines how heavily we penalize violations of the
hard margin constraints. A large C penalizes violation of the hard margin constraints more heavily,
which means our model will follow the data closely and have small regularization. A small C won’t
heavily penalize having data points inside the margin region, relaxing the constraint and allowing
our model to somewhat disregard more of the data. This means more regularization.
72 CHAPTER 5. SUPPORT VECTOR MACHINES

Soft Margin SVM Example

Figure 5.5: Example of the resulting hyperplane for a soft margin SVM. The filled in data points
illustrate the support vectors in this example and must be either on the margin boundary or on
the “wrong side” of the margin boundary.

⋆ Unlike most regularization parameters we’ve seen thus far, C increases regularization as it gets smaller.

5.3.3 Soft Margin Support Vectors


Under the the hard margin formulation, the support vectors were some subset of the data points
with yn (w⊤ xn + w0 ) = 1, and so situated exactly on the margin boundary, and they points closest
to the hyperplane.
The support vectors in the soft margin case must be data points that are either on the margin
boundary (and thus yn (w⊤ xn + w0 ) = 1) or for which ξn ∈ (0, 1] and in the margin region, or that
are misclassified (when ξn > 1). We can visualize this in Figure 5.5.

⋆ Not every data point on the margin boundary, in the margin region, or that is misclassified needs to be a support
vector in the soft-margin formulation. But those that become support vectors must meet one of these criteria.

5.4 Conversion to Dual Form


Now that we have the formulation of the optimization problem for SVMs, we need to discuss how
we actually go about optimizing to produce a model solution. This will involve converting to a
dual form of our problem. We will do this in the hard margin case for technical simplicity, but our
solution can be easily modified to also apply to the soft margin formulation.
5.4. CONVERSION TO DUAL FORM 73

⋆ A dual form is an equivalent manner of representing an optimization problem, in this case the quadratic programming
problem we need to optimize. Dual forms can be easier to work with than their initial form (“the primal form.”)

The dual form will be useful because it will allow us to bring in a basis function into the SVM
formulation in a very elegant and computationally efficient way.

5.4.1 Lagrange Multipliers


Before we get into deriving the dual form, we need to be aware of a critical piece of math that will
enable us to solve our optimization problem: Langrange multipliers.
A Lagrange multiplier is used to find optima of a function subject to certain constraints. This
is exactly what we need to solve the optimization problem described by Equation 5.10.
The underlying theory behind Lagrange multipliers is not overly difficult to understand, but it
is beyond the scope of this textbook in its fully generality. We will simply offer the method by
which you can use them to solve the optimization problem of interest here.
Recall from the “Section 0 math” that we understood there how to use the Lagrangian method to
solve optimization problems with equality constraints, for example of the form min f (w)Ps.t. gn (w) =
0, for each constraint n. There, a suitable Lagrangian function was L(w, α) = f (w)+ n αn gn (w),
where αn ∈ R were the Lagrangian variables, and this problem could be solved via the partial
derivatives of L with respect to w and α, and setting them to zero.
Here, we have a slightly different problem to solve, which has the following form (recognizing
that we can write a constraint yn (w⊤ xn + w0 ) ≥ 1 as −yn (w⊤ xn + w0 ) + 1 ≤ 0, and thus as an
“≤ 0” constraint):
min f (w) s.t. gn (w) ≤ 0, all n (5.13)
w
To solve this, we make use of the same kind of Lagrangian function,
X
L(w, α) = f (w) + αn gn (w). (5.14)
n

But for the new “inequality form” of the constrained optimization problem, we also need to
introduce a new subproblem, which for any fixed w, solves

max L(w, α) s.t. αn ≥ 0 for all n (5.15)


α

With this, we consider the problem


" #
X
min max f (w) + αn gn (w) (5.16)
w α, α≥0
n

Now, if w violates one or more constraints in (5.13), then the subproblem (5.15) becomes
unbounded, with αn on the corresponding constraints driven arbitrarily large. Otherwise, if we
have gn (w) < 0 then we will have αn = 0, and we conclude αn gn (x) = 0 in all optimal solutions
to (5.16). Therefore, and assuming that problem (5.13) is feasible, we have L(w, α) = f (w) in an
optimal solution to (5.16). Thus, we establish that (5.16) is an equivalent formulation to (5.13).
Substituting into our problem (5.10), the Lagrangian formulation becomes
" #
1 ⊤ X
min max w w + αn (−yn (w⊤ xn + w0 ) + 1) (5.17)
w,w0 α, α≥0 2
n
74 CHAPTER 5. SUPPORT VECTOR MACHINES

Equivalently, it is convenient to write this as


" #
1 ⊤ X

min max w w − αn (yn (w xn + w0 ) − 1) , (5.18)
w,w0 α, α≥0 2
n

with Lagrangian function


1 X
L(w, α, w0 ) = w⊤ w − αn (yn (w⊤ xn + w0 ) − 1). (5.19)
2 n

5.4.2 Deriving the Dual Formulation


Using this Lagrangian function, we allow ourselves to switch from solving Equation 5.10 to instead
solving:  
min max L(w, α, w0 ) (5.20)
w,w0 α, α≥0

⋆ The ‘ min max’ in Equation 5.20 may be initially confusing. The way to read this is that for any choice of w, w0 ,
w,w0 α≥0
the inner “max” problem then finds values of α to try to “defeat” the outer minimization objective.

We now wish to convert the objective in Equation 5.20 to a dual objective. Under the sufficient
conditions of strong duality which hold for this problem because Equation 5.10 has a quadratic
objective and linear constraints (but whose explanation is beyond the scope of this textbook), we
can equivalently reformulate the optimization problem (5.20) as:
 
max min L(w, α, w0 ) . (5.21)
α, α≥0 w,w0

At this point, we can use first order optimality conditions to solve for w, i.e., the inner mini-
mization problem, for some choice of α values. Taking the gradient, setting them equal to 0, and
solving for w, we have:
N
X
∇L(w, α, w0 ) = w − αn yn xn = 0
n=1
N
X

⇔ w = αn yn xn . (5.22)
n=1

When we do the same thing for w0 , we find the following:


N
∂L(w, α, w0 ) X
=− α n yn = 0
∂w0
n=1
N
X
⇔ αn yn = 0. (5.23)
n=1
P
This is interesting. If n αn yn < 0, then L(·) is increasing with w0 without bound,
P and the
inner-minimization would choose w0 = −∞ and this choice of α cannot solve (5.21). If n αn yn >
0, then L(·) is decreasing with w0 without bound, and the inner-minimization
P would choose w0 =
+∞, and this choice of α cannot solve (5.21). We need n αn yn = 0 for an optimal solution
5.4. CONVERSION TO DUAL FORM 75

to (5.21). So, we don’t yet obtain the optimal value for w0 , but we do gain a new constraint on the
α-values that will need to hold in an optimal solution.
Now we substitute for w∗ into our Lagrangian function, and also assume (5.23), since this will
be adopted as a new constraint in solving the optimization problem. Given this, we obtain:
1 X X X
L(w, α, w0 ) = w⊤ w − w⊤ α n y n x n − w0 αn yn + αn
2 n n n
1 X
= − w⊤ w + αn
2 n
!⊤ !
X 1 X X
= αn − α n yn xn αn′ yn′ xn′ (5.24)
n
2 n ′ n
where the second equation follows from the first by using (5.22) and (5.23), and the third equation
follows by using (5.22).
This is now entirely formulated in terms of α, and provides the hard margin, dual formu-
lation:
X 1 XX
max αn − αn αn′ yn yn′ x⊤
n xn′ (5.25)
α
n
2 n n′
X
s.t. αn yn = 0, for all n
n
αn ≥ 0, for all n
P
Here, we add n αn yn = 0 as a constraint. This is another quadratic objective, subject to
linear constraints. This can be solved via SGD or another approach to solving convex optimization
problems. With a little more work we can use the optimal α values to make predictions.
Although the derivation is out of scope for this textbook, there is also a very similar dual form
for the soft-margin SVM training problem:
X 1 XX
max αn − αn αn′ yn yn′ x⊤
n xn′ (5.26)
α
n
2 n n ′
X
s.t. αn yn = 0, for all n
n
C ≥ αn ≥ 0, for all n
This puts an upper-bound on αn to prevent the dual from being unbounded in the case where
the hard-margin SVM problem is infeasible because the data cannot be separated. It is not yet
clear why any of this has been useful. We will see the value of the dual formulation when working
with basis functions.

⋆ By (5.22) we see how to find the weight vector w from a solution α. We didn’t yet explain how to find w0 . That
will be explained next.

5.4.3 Making Predictions


Substituting the optimal dual solution into the discriminant function, we have
N
X
h(x) = αn yn x⊤
n x + w0 . (5.27)
n=1
76 CHAPTER 5. SUPPORT VECTOR MACHINES

For data points with αn > 0, this is taking a weighted vote over examples in the training data
based on the size of the inner product x⊤ n x.
Since α are Lagrange multipliers, they are non-negative, and moreover, by reasoning about
the “max subproblem” in the min-max formulation (5.16), we know that they take on value zero
whenever yn h(xn ) > 1. The data points for which αn > 0 are known as support vectors, and
they will must be data points that are either on the margin boundary, inside the margin region, or
misclassified. For the hard-margin formulation they must be data points on the margin boundary.
This is a major takeaway for the usefulness of SVMs: once we’ve trained our model, we can
discard most of our data. We only need to keep the support vectors to make predictions. Soon we
also see the “kernel trick.” This also illustrates why we need to solve for the values of α: those
values dictate which data points are the support vectors for our model.

Solving for w0
We can solve for w0 by recognizing that yn (w⊤ xn + w0 ) = 1 for any data point on the margin
boundary. For the hard-margin formulation we can solve for w0 using any example for which
αn > 0. For the soft-margin formulation, it can be shown that the only points with αn = C are
those inside the margin region or misclassified, and so that any point with C > αn > 0 is on the
margin boundary. Any such point can be solved to solve for w0 .

5.4.4 Why is the Dual Formulation Helpful?


Beyond having a succinct formulation of the discriminant function (assuming the number of support
vectors is small), you might be wondering what exactly we gained by moving to the dual formulation
of this problem.
First, the complexity of the optimization problem we’re solving changed from one that is de-
pendent on the number of features D, i.e., the size of the weight vector, to one that is linearly
dependent on the number of data points N . Thus, the number of variables to optimize over is now
independent of the dimensionality of the feature space.
Second, in the dual formulation, we have the opportunity to take advantage of what’s known as
the kernel trick to map our data xn into higher dimensions without incurring performance costs.
This works as follows: notice that the only way in which a feature vector x is accessed during the
training process (5.26) and at prediction time (5.27) is through the inner product x⊤ z between
two data points. Suppose that we are using a basis function ϕ : RD → RM , so that this becomes
ϕ(x)⊤ ϕ(z). Now, define kernel function

K(x, z) = ϕ(x)⊤ ϕ(z). (5.28)

The idea of the kernel trick is that we might be able to compute K(·, ·) without actually working
in the basis function space, RM , but rather be able to compute the Kernel function directly through
algebra in the lower dimensional space, RD .
For example, it can be shown that the polynomial kernel

Kpoly (x, z) = (1 + x⊤ z)q , for integers q ≥ 2 (5.29)

corresponds to computing the inner product with a basis function that makes use of all terms up to
degree q. When q = 2, then it is all constant, linear, and quadratic terms. The polynomial kernel
function does this without needing to actually project the examples to the higher dimensional space.
Rather it takes the inner product in the lower-dimensional space, adds 1 to this scalar, and then
raises it to the power of q. The implicit basis is growing exponentially large in q!
5.4. CONVERSION TO DUAL FORM 77

⋆ The kernel trick can even be used to work in an infinite basis. This is the case with the Gaussian kernel. If that is
of interest, you should look into Taylor series basis expansions and the Gaussian kernel.

The importance of the kernel trick is that when computations can be done efficiently in the
initial space RD , then the training problem can be solved by computing the pairwise K(·, ·) values
for all pairs of training examples, and then using SGD to solve the dual, soft-margin training
problem (with N decision variables).

5.4.5 Kernel Composition


Now that we’ve seen the usefulness of the kernel trick for working in higher dimensional spaces
without incurring performance costs or memory overhead, it’s reasonable to wonder what sort of
valid kernels we can construct.
To be explicit, by ‘kernel’ we mean a function K(x, z) = ϕ(x)⊤ ϕ(z) that produces a scalar
product from two vectors, as determined by some basis function ϕ.
Although it is beyond the scope of this textbook, the condition for a kernel function K(·, ·)
to be valid is that the N × N matrix K, where element Kn,n′ = K(xn , xn′ ) should be positive
semidefinite for any choices of the data set {xn }. This is known as Mercer’s theorem.

⋆ The matrix K of elements K(xn , xn′ ) is known as the Gram matrix or Kernel matrix.

In practice, this provides for a logic for how different valid kernels can be composed (if they
maintain a p.s.d. Gram matrix!).
There exists a set of rules that preserve the validity of kernels through transformations. These
include such things as

• scalar multiplication c · K(·, ·)

• exponentiation exp{K(·, ·)}

• addition K1 (x, z) + K2 (x, z), and

• multiplication K1 (x, z) · K2 (x, z).

It is always possible to test the validity of a given kernel by demonstrating that its Gram matrix
K is positive semidefinite.
Chapter 6

Clustering

In this chapter, we will explore a technique known as clustering. This represents our first foray into
unsupervised machine learning techniques. Unlike the previous four chapters, where we explored
techniques that assumed a data set of inputs and targets, with the goal of eventually making
predictions over unseen data, our data set will no longer contain explicit targets. Instead, these
techniques are motivated by the goal of uncovering structure in our data. Identifying clusters of
similar data points is a useful and ubiquitous unsupervised technique.

6.1 Motivation
The reasons for using an unsupervised technique like clustering are broad. We often don’t have a
specific task in mind; rather, we are trying to uncover more information about a potentially opaque
data set. For clustering specifically, our unsupervised goal is to group data points that are similar.
There are many reasons why we might separate our data by similarity. For organizational
purposes, it’s convenient to have different classes of data. It can be easier for a human to sift
through data if it’s loosely categorized beforehand. It may be a preprocessing step for an inference
method; for example, by creating additional features for a supervised technique. It can help identify
which features make our data points most distinct from one another. It might even provide some
idea of how many distinct data types we have in our set.
This idea of data being ‘similar’ means that we need some measure of distance between our
data points. While there are a variety of clustering algorithms available, the importance of this
distance measurement is consistent between them.
Distance is meant to capture how ‘different’ two data points are from each other. Then, we can
use these distance measurements to determine which data points are similar, and thus should be
clustered together. A common distance measurement for two data points x and x′ is given by:
v
uD
uX
||x − x’||L2 = t (xd − x’d )2 (6.1)
d=1

where D is the dimensionality of our data. This is known as L2 or Euclidean distance, and you
can likely see the similarity to L2 regularization.
There are a variety of distance measurements available for data points living in a D-dimensional
Euclidean space, but for other types of data (such as data with discrete features), we would need
to select a different distance metric. Furthermore, the metrics we choose to use will have an impact
on the final results of our clustering.

78
6.2. K-MEANS CLUSTERING 79

ML Framework Cube: Clustering

In unsupervised learning, the domain refers to the domain of the hidden variable z (which is
analogous to y in supervised learning). In clustering, we have z’s that represent the discrete
clusters. Furthermore, the techniques that we explore in this chapter are fully unsupervised and
non-probabilistic.
Domain Training Probabilistic
Discrete Unsupervised No

6.1.1 Applications
Here are a few specific examples of use cases for clustering:

1. Determining the number of phenotypes in a population.

2. Organizing images into folders according to scene similarity.

3. Grouping financial data as a feature for anticipating extreme market events.

4. Identifying similar individuals based on DNA sequences.

As we mentioned above, there are different methods available for clustering. In this chapter, we
will explore two of the most common techniques: K-Means Clustering and Hierarchical Agglomer-
ative Clustering. We also touch on the flavors available within each of these larger techniques.

6.2 K-Means Clustering


The high level procedure behind K-Means Clustering (known informally as k-means) is as follows:

1. Initialize cluster centers by randomly selecting points in our data set.

2. Using a distance metric of your choosing, assign each data point to the closest cluster.

3. Update the cluster centers based on your assignments and distance metric (for example, when
using L2 distance, we update the cluster centers by averaging the data points assigned to each
cluster).

4. Repeat steps 1 and 2 until convergence.

In the case where we are using the L2 distance metric, this is known as Lloyd’s algorithm, which
we derive in the next section.

6.2.1 Lloyd’s Algorithm


Lloyd’s algorithm, named after Stuart P. Lloyd who first suggested the algorithm in 1957, optimizes
our cluster assignments via a technique known as coordinate descent, which we will learn more about
in later chapters.
80 CHAPTER 6. CLUSTERING

Derivation 6.2.1 (Lloyd’s Algorithm Derivation):

We begin by defining the objective used by Lloyd’s algorithm.

Objective
The loss function for our current assignment of data points to clusters is given by:
N X
C
C  N X
rnc ||xn − µc ||22

L(X, µ c=1
, r n=1 ) = (6.2)
n=1 c=1

where X is our N xD data set (N is the number of data points and D is the dimensionality of
 C
our data), µ c=1 is the CxD matrix of cluster centers (C is the number of clusters we chose),
 N
and r n=1 is our N xC matrix of responsibility vectors. These are one-hot encoded vectors (one
per data point), where the 1 is in the position of the cluster to which we assigned the nth data point.

We now define the algorithmic portion of Lloyd’s clustering procedure.

Algorithm
We first adjust our responsibility vectors to minimize each data point’s distance from its cluster
center. Formally: 
= 1 if c = arg min||xn − µc′ ||
rnc = c′ (6.3)
= 0 otherwise

After updating our responsibility vectors, we now wish to minimize our loss by updating our cluster
centers µc . The cluster centers which minimize our loss can be computed by taking the derivative
of our loss with respect to µc , setting equal to 0, and solving for our new cluster centers µc :
N
∂L X
= −2 rnc (xn − µc )
∂µc
n=1 (6.4)
PN
n=1 rnc xn
µc = P N
n=1 rnc

Intuitively, this is the average of all the data points xn assigned to the cluster center µc .

We then update our responsibility vectors based on the new cluster centers, update the cluster
centers again, and continue this cycle until we have converged on a stable set of cluster centers and
responsibility vectors.

Note that while Lloyd’s algorithm is guaranteed to converge, it is only guaranteed to converge
to a locally optimal solution. Finding the globally optimal set of assignments and cluster centers is
an NP-hard problem. As a result, a common strategy is to execute Lloyd’s algorithm several times
with different random initializations of cluster centers, selecting the assignment that minimizes
loss across the different trials. Furthermore, to avoid nonsensical solutions due to scale mismatch
between features (which would throw our Euclidean distance measurements off), it makes sense to
standardize our data in a preprocessing step. This is as easy as subtracting the mean and dividing
by the standard deviation across each feature.
6.2. K-MEANS CLUSTERING 81

6.2.2 Example of Lloyd’s

For some more clarity on exactly how Lloyd’s algorithm works, let’s walk through an example.

Example 6.1 (Lloyd’s Algorithm Example): We start with a data set of size N = 6. Each
data point is two-dimensional, with each feature taking on a value between -3 and 3. We also
have a ‘Red’ and ‘Green’ cluster. Here is a table and graph of our data points, labelled A through F:

Coordinates Dist. to Red Dist. to Green Cluster Assgn.


A (-3, -3) n/a n/a n/a
B (-1, -3) n/a n/a n/a
C (3, 0) n/a n/a n/a
D (-2, -1) n/a n/a n/a
E (0, 0) n/a n/a n/a
F (-1, -2) n/a n/a n/a

Let’s say we wish to have 2 cluster centers. We then randomly initialize those cluster centers by
selecting two data points. Let’s say we select B and F. We identify our cluster centers with a red
and green ‘X’ respectively:
82 CHAPTER 6. CLUSTERING

We now begin Lloyd’s algorithm by assigning each data point to its closest cluster center:

Coordinates Dist. to Red Dist. to Green Cluster Assgn.


A (-3, -3) n/a n/a Red
B (-1, -3) n/a n/a Red
C (3, 0) n/a n/a Green
D (-2, -1) n/a n/a Green
E (0, 0) n/a n/a Green
F (-1, -2) n/a n/a Green
6.2. K-MEANS CLUSTERING 83

We then update our cluster centers by averaging the data points assigned to each:

Coordinates Dist. to Red Dist. to Green Cluster Assgn.


A (-3, -3) 2.00 2.24 Red
B (-1, -3) 0.00 1.00 Red
C (3, 0) 5.00 4.47 Green
D (-2, -1) 2.24 1.41 Green
E (0, 0) 3.16 2.24 Green
F (-1, -2) 1.00 0.00 Green

We proceed like this, updating our cluster centers and assignments, until convergence. At conver-
gence, we’ve achieved these cluster centers and assignments:
84 CHAPTER 6. CLUSTERING

Coordinates Dist. to Red Dist. to Green Cluster Assgn.


A (-3, -3) 1.46 5.41 Red
B (-1, -3) 1.06 3.91 Red
C (3, 0) 5.26 1.50 Green
D (-2, -1) 1.27 3.64 Red
E (0, 0) 2.85 1.50 Green
F (-1, -2) 0.79 3.20 Red

Where our red cluster is at (-1.75, -2.25) and our green cluster is at (1.5 , 0). Note that for this
random initialization of cluster centers, we deterministically identified the locally optimal set of
assignments and cluster centers. For a specific initialization, running Lloyd’s algorithm will always
identify the same set of assignments and cluster centers. However, different initializations will
produce different results

6.2.3 Number of Clusters


You may have wondered about a crucial, omitted detail: how do we choose the proper number of
clusters for our data set? There doesn’t actually exist a ‘correct’ number of clusters. The fewer
clusters we have, the larger our loss will be, and as we add more clusters, our loss will get strictly
smaller. That being said, there is certainly a tradeoff to be made here.
Having a single cluster is obviously useless - we will group our entire data set into the same
cluster. Having N clusters is equally useless - each data point gets its own cluster.
One popular approach to identifying a good number of clusters is to perform K-Means with a
varying number of clusters, and then to plot the number of clusters against the loss. Typically,
that graph will look like Figure 6.1.
Notice that at x = ... clusters, there appears to be a slight bend in the decrease of our loss.
This is often called the knee, and it is common to choose the number of clusters to be where the
knee occurs.
Intuitively, the idea here is that up to a certain point, adding another cluster significantly
decreases the loss by more properly grouping our data points. However, eventually the benefit of
adding another cluster stops being quite so significant. At this point, we have identified a natural
number of groups for our data set.

6.2.4 Initialization and K-Means++


Up until now, we have assumed that we should randomly initialize our cluster centers and execute
Lloyd’s algorithm until convergence. We also suggested that since Lloyd’s algorithm only produces
a local minimum, it makes sense to perform several random initializations before settling on the
most optimal assignment we’ve identified.
While this is a viable way to perform K-Means, there are other ways of initializing our orig-
inal cluster centers that can help us find more optimal results without needing so many random
initializations. One of those techniques is known as K-Means++.
The idea behind K-Means++ is that our cluster centers will typically be spread out when we’ve
reached convergence. As a result, it might not make sense to initialize those cluster centers in an
entirely random manner. For example, Figure 6.2 would be a poor initialization.
6.2. K-MEANS CLUSTERING 85

Figure 6.1: Finding the Knee.

Figure 6.2: Bad Cluster Initialization.


86 CHAPTER 6. CLUSTERING

Figure 6.3: Good Cluster Initialization.

We would much rather start with a random initialization that looks like Figure 6.3.
We can use the hint that we want our cluster centers somewhat spread out to find a better
random initialization. This is where the initialization algorithm presented by K-Means++ comes
in.
For K-Means++, we choose the first cluster center by randomly selecting a point in our data
set, same as before. However, for all subsequent cluster center initializations, we select points in
our data set with probability proportional to the squared distance from their nearest cluster center.
The effect of this is that we end up with a set of initializations that are relatively far from one
another, as in Figure 6.3.

6.2.5 K-Medoids Alternative


Recall that in the cluster center update step (Equation 6.4) presented in the derivation of Lloyd’s
algorithm, we average the data points assigned to each cluster to compute the new cluster centers.
Note that in some cases, this averaging step doesn’t actually make sense (for example, if we have
categorical variables as part of our feature set). In these cases, we can use an alternative algorithm
known as K-Medoids. The idea behind K-Medoids is simple: instead of averaging the data points
assigned to that cluster, update the new cluster center to be the data point assigned to that cluster
which is most like the others.

6.3 Hierarchical Agglomerative Clustering


The motivating idea behind K-Means was that we could use a distance measurement to assign data
points to a fixed number of clusters, iteratively improving our assignments and cluster locations
6.3. HIERARCHICAL AGGLOMERATIVE CLUSTERING 87

Figure 6.4: Dendrogram Example.

until convergence.
Moving on to Hierarchical Agglomerative Clustering (also known as HAC - pronounced ‘hack’),
the motivating idea is instead to group data from the bottom up. This means every data point
starts as its own cluster, and then we merge clusters together based on a distance metric that
we define. This iterative merging allows us to construct a tree over our data set that describes
relationships between our data. These trees are known as dendrograms, with an example found in
Figure 6.4. Notice that the individual data points are the leaves of our tree, and the trunk is the
cluster that contains the entirety of our data set.
We now formally define the HAC algorithm, and in the process, explain how we construct such
a tree.

6.3.1 HAC Algorithm


1. Start with N clusters, one for each data point.
2. Measure the distance between clusters. This will require an inter-cluster distance measure-
ment that we will define shortly.
3. Merge the two ‘closest’ clusters together, reducing the number of clusters by 1. Record the
distance between these two merged clusters.
4. Repeat step 2 until we’re left with only a single cluster.
In the remainder of the chapter, we’ll describe this procedure in greater detail (including how
to measure the distance between clusters), explain the clustering information produced by the tree,
and discuss how HAC differs from K-Means. But first, to make this algorithm a little more clear,
let’s perform HAC one step at a time on a toy data set, constructing the dendrogram as we go.

Example 6.2 (HAC Algorithm Example): Let’s say we have a data set of five points A, B,
C, D, E that we wish to perform HAC on. These points will simply be scalar data that we can
represent on a number line. We start with 5 clusters and no connections at all:
88 CHAPTER 6. CLUSTERING

We find the closest two clusters to merge first. A and B are nearest (it’s actually tied with C and D,
but we can arbitrarily break these ties), so we start by merging them. Notice that we also annotate
the distance between them in the tree, which in this case is 1:

We now have four clusters: (A, B), C, D, and E. We again find the closest two clusters, which in
this case is C and D:

We now have three remaining clusters: (A, B), (C, D), and E. We proceed as before, identifying
the two closest clusters to be (A, B) and (C, D). Merging them:

Finally we are left with two clusters: (A, B, C, D) and E. The remaining two clusters are obviously
the closest together, so we merge them:
6.3. HIERARCHICAL AGGLOMERATIVE CLUSTERING 89

At this point there is only a single cluster. We have constructed our tree and are finished with
HAC.

Notice how the distance between two merged clusters manifests itself through the height of the
dendrogram where they merge (which is why we tracked those distances as we constructed the
tree). Notice also that we now have many layers of clustering: if we’re only only interested in
clusters whose elements are at least k units away from each other, we can ‘cut’ the dendrogram at
that height and examine all the clusters that exist below that cut point.
Finally, we need to handle the important detail of how to compute the distance between clusters.
In the preceding example, we designated the distance between two clusters to be the minimum
distance between any two data points in the clusters. This is what is known as the Min-Linkage
Criterion. However, there are certainly other ways we could have computed the distance between
clusters, and using a different distance measurement can produce different clustering results. We
now turn to these different methods and the properties of clusters they produce.

6.3.2 Linkage Criterion


Here are a few of the most common linkage criteria.

Min-Linkage Criteria
We’ve already seen the Min-Linkage Criterion in action from the previous example. Formally, the
criterion says that the distance dC,C ′ between each cluster pair C and C ′ is given by

dC,C ′ = min′ ||xk − xk′ || (6.5)


k,k

where xk are data points in cluster C and xk′ are data points in cluster C ′ . After computing these
pairwise distances, we choose to merge the two clusters that are closest together.

Max-Linkage Criterion
We could also imagine defining the distance dC,C ′ between two clusters as being the distance between
the two points that are farthest apart in each cluster. This is known as the Max-Linkage Criterion.
90 CHAPTER 6. CLUSTERING

Figure 6.5: Different Linkage Criteria.

The distance between two clusters is then given by:

dC,C ′ = max

||xk − xk′ || (6.6)
k,k

As with the Min-Linkage Criterion, after computing these pairwise distances, we choose to merge
the two clusters that are closest together.

⋆ Be careful not to confuse the linkage criterion with which clusters we choose to merge. We always merge the clusters
that have the smallest distance between them. How we compute that distance is given by the linkage criterion.

Average-Linkage Criterion
The Average-Linkage Criterion averages the pairwise distance between each point in each cluster.
Formally, this is given by:
K K′
1 XX
dC,C ′ = ||xk − xk′ || (6.7)
KK ′ ′ k=1 k =1

Centroid-Linkage Criterion
The Centroid-Linkage Criterion uses the distance between the centroid of each cluster (which is
the average of the data points in a cluster). Formally, this is given by:

K K ′
1 X 1 X
dC,C ′ = || xk − ′ xk′ || (6.8)
K K ′
k=1 k =1

Different Linkage Criteria Produce Different Clusterings


It’s important to note that the linkage criterion you choose to use will influence your final clustering
results. For example, the min-linkage criterion tends to produce ‘stringy’ clusters, while the max-
linkage criterion tends to produce more compact clusters. You can see the difference between the
results of these two linkage criteria in Figure 6.5.
6.3. HIERARCHICAL AGGLOMERATIVE CLUSTERING 91

⋆ You should convince yourself of the different flavors of linkage criteria. For example, when using the min-linkage
criterion, we get these ‘stringy’ results because we’re most inclined to extend existing clusters by grabbing whichever
data points are closest.

6.3.3 How HAC Differs from K-Means


Now that we’re aware of two distinct clustering techniques and their variants, we consider the
differences between the two methods.
First of all, there is a fundamental difference in determinism between HAC and K-Means. In
general, K-Means incurs a certain amount of randomness and needs to be run multiple times to
ensure a good result. On the other hand, once you’ve selected a linkage criterion for HAC, the
clusters you calculate are deterministic. You only need to run HAC a single time.
Another difference between HAC and K-Means comes from the assumptions we make. For K-
Means, we need to specify the number of clusters up front before running our algorithm, potentially
using something like the knee-method to decide on the number of clusters. On the other hand, you
don’t need to assume anything to run HAC, which simplifies its usage. However, the downside for
HAC is that when you wish to present your final clustering results, you need to decide on the max
distance between elements in each cluster (so that you can cut the dendrogram).
The fact that you need to make a decision about where to cut the dendrogram means that
running HAC once gives you several different clustering options. Furthermore, the dendrogram in
and of itself can be a useful tool for visualizing data. We don’t get the same interactivity from
K-Means clustering.

⋆ We often use dendrograms to visualize evolutionary lineage.


Chapter 7

Dimensionality Reduction

In previous chapters covering supervised learning techniques, we often used basis functions to
project our data into higher dimensions prior to applying an inference technique. This allowed us
to construct more expressive models, which ultimately produced better results. While it may seem
counterintuitive, in this chapter we’re going to focus on doing exactly the opposite: reducing the
dimensionality of our data through a technique known as Principal Component Analysis (PCA).
We will also explore why it is useful to reduce the dimensionality of some data sets.

7.1 Motivation

Real-world data is often very high dimensional, and it’s common that our data sets contain infor-
mation we are unfamiliar with because the dimensionality is too large for us to comb through all
the features the by hand.
In these situations, it can be very difficult to manipulate or utilize our data effectively. We
don’t have a sense for which features are ‘important’ and which ones are just noise. Fitting a
model to the data may be computationally expensive, and even if we were to fit some sort of model
to our data, it may be difficult to interpret why we obtain specific results. It’s also hard to gain
intuition about our data through visualization since humans struggle to think in more than three
dimensions. All of these are good reasons that we may wish to reduce the dimensionality of a data
set.
ML Framework Cube: Dimensionality Reduction

Dimensionality reduction operates primarily on continuous feature spaces, is fully unsupervised,


and is non-probabilistic for the techniques we explore in this chapter.
Domain Training Probabilistic
Continuous Unsupervised No

While dimensionality reduction is considered an unsupervised technique, it might also be


thought of as a tool used to make data more manageable prior to taking some other action. In fact,
it’s an important preprocessing step for a host of use cases.

92
7.2. APPLICATIONS 93

Bear Heights and Weights

Figure 7.1: Graph of bear heights and weights.

7.2 Applications
As described above, we need a tool like dimensionality reduction in situations where high-dimensional
data hinders us. Here are a few specific situations where we would use such a technique:

1. Presenting differences between complex molecules in two dimensions (via a graph).

2. Understanding the results of a credit trustworthiness algorithm.

3. Efficiently training a neural network to predict supermarket sales on a data set with many
input features.

4. Identifying which costly measurements are worth collecting when experimenting with new
chemicals.

With a few of these use cases in mind, we now turn to the math that underpins the dimension-
ality reduction technique known as Priniciple Component Analysis.

7.3 Principal Component Analysis


The main idea behind Principal Component Analysis (PCA) is that we can linearly project our
data set onto a subspace without losing too much information. For example, three-dimensional
data might primarily exist in a subspace that is actually a two-dimensional plane.
One way to think about this is to identify and preserve the features along which there is the
most variance. For example, imagine we had a data set comprised of the heights and weights of
individual bears. As an extreme case, let’s suppose all the bears were exactly the same height but
had a wide range of weights.
94 CHAPTER 7. DIMENSIONALITY REDUCTION

Bear Weights

Figure 7.2: Bear weights.

Converting Between Reduced and Original Data

Figure 7.3: Converting between the reduced data and original data.

To differentiate our data points, we obviously only need to report the weights of the bears. The
variance of the heights is 0, and the variance of the weights is some non-zero number. Intuitively,
the most interesting features from our data sets are those that vary the most.

⋆ In this simple example, the direction of maximal variance occurs exactly along the x1 axis, but in general it will
occur on a plane described by a combination of our input features.

The second way to think about PCA is that we are minimizing the error we incur when we
move from the lower-dimensional representation back to the original representation. This is known
as reconstruction loss. We can consider the meaning of this using our bear example.
Let’s say we project the data set from Figure 7.1 here down to a single dimension by recording
only the weights:
Then, to reconstruct our original graph, we need only to keep track of a slope and bias term
in the form of the familiar equation x2 = mx1 + b. In this case our slope is m = 0 and our bias
b = 3. Note that this storage overhead is constant (just remembering the slope and bias) regardless
of how big our data set gets. Thus we can go from our low-dimensional representation back to our
original data:
It will be our goal to determine a low-dimensional representation of our data that allows us
to return to our high-dimensional data while losing as little information as possible. We wish
to preserve everything salient about the data while discarding as much redundant information as
possible. We now turn to how this can be achieved.

7.3.1 Reconstruction Loss


We identified a key tenet of dimensionality reduction in the previous section: finding subspaces of
our data that preserve as much information as possible. Concretely, this means we want to convert
our original data point xn in D dimensions into a data point x′n in D′ dimensions where D′ < D.
7.3. PRINCIPAL COMPONENT ANALYSIS 95

Data Reconstruction

Figure 7.4: Far left: our original data. Middle: our reduced data in scalar form after the projection
given by x·w. Right: our reconstructed data points given by (x·w)w. Notice that our reconstructed
data points are not the same as our original data.

⋆ We’re going to assume that our data set has been mean-centered such that each feature in xn has mean 0. This
will not affect the application of the method (we can always convert back to the uncentered data by adding back the
mean of each feature), but will make our derivations more convenient to work with.

Let’s consider a simple case first: D′ = 1. This means that we’re projecting our D dimensional
data down onto just a single dimension, or in geometric terms, we’re projecting our data points
xn onto a line through the origin. We can define this line as the unit vector w ∈ RD×1 , and the
projection is given by the dot product x · w.

⋆ The unit vector w onto which we project our data is known as a principal component, from which PCA gets its
name.

This projection produces a scalar, and that scalar defines how far our projection x · w is from
the origin. We can convert this scalar back to D dimensional space by multiplying it with the
unit vector w. This means that (x · w)w is the result of projecting our data point x down into
one-dimension and then converting it to its coordinate location in D dimensions. We refer to these
as our projection vectors, and we can observe what this looks like geometrically in Figure 7.4.
The projection vectors we recover from the expression (x · w)w will be in D dimensions, but
they will obviously not be identical to the original D dimensional vectors (Figure 7.4 demonstrates
why that is the case). This difference between the original and projection vectors can be thought
of as error, since it is information lost from our original data. For a given data point xn and unit
vector w, we can measure this error through the expression:

||xn − (x · w)w||2 (7.1)

which is known as reconstruction loss because it measures the error incurred when reconstructing
our original data from its projection.

Definition 7.3.1 (Reconstruction Loss): Reconstruction loss is the difference (measured via a
distance metric such as Euclidean distance) between an original data set and its reconstruction from
a lower dimensional representation. It indicates how much information is lost during dimensionality
reduction.

Reconstruction loss is then a metric for evaluating how ‘good’ a subspace in D′ dimensions is
96 CHAPTER 7. DIMENSIONALITY REDUCTION

at representing our original data in D dimensions. The better it is, the less information we lose,
and the reconstruction loss is lower as a result.

7.3.2 Minimizing Reconstruction Loss


We now know that our goal is to find a good subspace to project onto, and we also know that
finding this good subspace is equivalent to minimizing the reconstruction loss it incurs. We can
now formalize this as an optimization problem.
First, we simplify the reconstruction loss for a single data point xn as follows:

||xn − (xn · w)w||2 = (xn − (xn · w)w)(xn − (xn · w)w)


= ||xn ||2 − 2(xn · w)2 + (xn · w)2 ||w||2
= ||xn ||2 − (xn · w)2

where ||w||2 = 1 because it is a unit vector. Note that we can define reconstruction loss over our
entire data set as follows:
N
1 X
RL(w) = ||xn ||2 − (xn · w)2 (7.2)
N
n=1
Recall that our goal is to minimize reconstruction loss over our data set by optimizing the subspace
defined by w. Let’s first rewrite Equation 7.2 as:
N N
1 X 1 X
RL(w) = ||xn ||2 − (xn · w)2
N M
n=1 n=1

where we can see that our optimization will depend only on maximizing the second term:
N
1 X
max (xn · w)2 (7.3)
w N
n=1

since it is thePonly one involving w. Recall that the sample mean of a data set is given by the
expression N1 N 2
n=1 xn , and note that Equation 7.3 is the sample mean of (x · w) . Using the
2
definition of variance for a random variable Z (which is given by V ar(Z) = E(Z ) − (E(Z))2 ), we
can rewrite Equation 7.3 as:
N
1 X 2
(xn · w)2 = V ar {xn · w}N N
  
n=1 + E {xn · w}n=1
N
n=1

Recall that we centered our data xn to have mean 0 such that the expression above simplifies to:
N
1 X
(xn · w)2 = V ar {xn · w}N
 
n=1 (7.4)
N
n=1

and therefore V ar {xn · w}N


 
n=1 is the term we wish to maximize. This means that minimizing
the reconstruction loss is equivalent to maximizing the variance of our projections
{xn · w}N
n=1 .
7.3. PRINCIPAL COMPONENT ANALYSIS 97

⋆ Note the intuitiveness of this result. We should like to find a subspace that maintains the spread in our data.

7.3.3 Multiple Principal Components


Up until now, we’ve been considering how we would project onto a single principal component
w ∈ RD×1 . This will reduce our data down to one dimension, just a scalar. In general, we will
wish to preserve more of our data than just a single dimension (in order to capture more of the
variance in our data and reduce reconstruction loss), which means that we will need to have multiple
principal components. For now we’ll assume that our principal components are orthogonal (we will
prove this later), which then allows us to describe the projection of our data xn onto this subspace
as the sum of the projections onto D′ orthogonal vectors:

D ′
X
(xn · wd′ )wd′ (7.5)
d′ =1

7.3.4 Identifying Directions of Maximal Variance in our Data


We now know from the previous section that we find our principal components (and thus the
subspace we will project onto) by identifying the directions of maximum variance in our projected
data set. We know from Equation 7.4 that the variance is equivalent to:

N
2 N 1 X
(xn · w)2
 
σw ≡ V ar {xn · w}n=1 =
N
n=1

Rewriting this in terms of matrix notation we have that:

2 1
σw = (Xw)T (Xw)
N
We can further simplify this:

2 1 T T
σw = w X Xw
N
2 XT X
σw = wT w
N
2
σw = wT Sw

XT X
where, since we assume the design matrix X is mean-centered, then S = N is the empirical
covariance matrix of our data set.
⋆ Notice that by convention we describe the empirical covariance of a data set with the term S instead of the usual
covariance term Σ.

2 = V ar {x · w}N
 
Our goal is to maximize the term σw n n=1 with respect to w. Furthermore, w is
a unit vector, so we must optimize subject to the constraint wT w = 1. Recalling the discussion of
98 CHAPTER 7. DIMENSIONALITY REDUCTION

Lagrange multipliers from Chapter 6 on Support Vector Machines, we incorporate this constraint
by reformulating our optimization problem as the Lagrangian equation:

L(w, λ) = wT Sw − λ(wT w − 1)

As usual, we proceed by taking the derivative with respect to each parameter:

dL(w, λ)
= 2Sw − 2λw
dw
dL(w, λ)
= wT w − 1

We can now set these equal to 0 and solve for the optimal values:

Sw = λw
wT w = 1

This result is very significant! As we knew already, we needed w to be a unit vector. However, we
also see that w is an eigenvector of the empirical covariance matrix w. Futhermore, the eigenvector
that will maximize our quantity of interest σw 2 = wT Sw will be the eigenvector with the largest

eigenvalue λ.
Linear algebra gives us many tools for finding eigenvectors, and as a result we can efficiently
identify our principal components. Note also that the eigenvectors of a symmetric matrix are
orthogonal, which proves our earlier assumption that our principal components are orthogonal.

⋆ Each eigenvector is a principal component. The larger the eigenvalue associated with that principal component,
the more variance there is along that principal component.

To recap, we’ve learned that the optimal principal components (meaning the vectors describing
our projection subspace) are the eigenvectors of the empirical covariance matrix of our data set.
The vector preserving the most variance in our data (and thus minimizing the reconstruction loss)
is given by the eigenvector with the largest eigenvalue, followed by the eigenvector with the next
largest eigenvalue, and so on. Furthermore, while it is somewhat outside the scope of this textbook,
we are guaranteed to have D distinct, orthogonal eigenvectors with eigenvalues ≥ 0. This is a result
of linear algebra that hinges on the fact that our empirical covariance matrix S is symmetric and
positive semi-definite.

7.3.5 Choosing the Optimal Number of Principal Components


We now know that the eigenvectors of the empirical covariance matrix S give us the principal
components that form our projection subspace. Note that the exact procedure for finding these
eigenvectors is a topic better suited for a book on linear algebra, but one approach is to use singular
value decomposition (SVD).
The eigenvalue decomposition of a square matrix is S = VLV⊤ , where the columns of V are
the eigenvectors of S, and L is a diagonal matrix whose entries are the corresponding eigenvalues
{λi }i . The SVD of a real-valued, square or rectangular matrix X is X = UZV⊤ with diagonal
7.3. PRINCIPAL COMPONENT ANALYSIS 99

Four Dimensional Iris Data Set in Three Dimensions

Figure 7.5: PCA applied to Fisher’s Iris data set, which is originally in four dimensions. We reduce
it to three dimensions for visualization purposes and label the different flower types. This example
is taken from the sklearn documentation.

matrix Z, and orthogonal matrices U and V (i.e., U⊤ U = I and V⊤ V = I). Matrix V is the right
singular matrix, and its columns the right singular vectors. The entries in Z are the corresponding
singular values. The column vectors in V also the eigenvectors of X⊤ X. To see this, we have
X⊤ X = VZU⊤ UZV⊤ = VZ2 V⊤ , where we substitute the SVD for X. We also see that the
eigenvalues λi = zi2 , so that they are the squared singular values.
For a mean-centered design matrix X, so that (1/N )X⊤ X = cov(X), we can compute the SVD
on the design matrix X, and then read off the eigenvectors of the covariance as the columns of
the right singular matrix and the√eigenvalues as λi = zi2 /N . Alternatively, you can first divide the
mean-centered design matrix by N before taking its SVD, in which case the square of the singular
values corresponds to the eigenvalues.
Because the principal components are orthogonal, the projections they produce will be entirely
uncorrelated. This means we can project our original data onto each component individually and
then sum those projections to create our lower dimensional data points. Note that it doesn’t make
sense that we would use every one of our D principal components to define our projection subspace,
since that wouldn’t lead to a reduction in the dimensionality of our data at all (the D orthogonal
principal components span the entire D dimensional space of our original data set). We now need
to decide how many principal components we will choose to include, and therefore what subspace
we will be projecting onto.
The ‘right’ number of principal components to use depends on our goals. For example, if we
simply wish to visualize our data, then we would project onto a 2D or 3D space. Therefore, we
would choose the first 2 or 3 principal components, and project our original data onto the subspace
defined by those vectors. This might look something like Figure 7.5.
However, it’s more complicated to choose the optimal number of principal components when our
goal is not simply visualization. We’re now left with the task of trading off how much dimensionality
100 CHAPTER 7. DIMENSIONALITY REDUCTION

Reconstruction Loss vs. Number of Principal Components

Figure 7.6: Reconstruction loss versus the number of principal components. Notice the similarity
to the ‘elbow’ method of K-Means clustering.

reduction we wish to achieve with how much information we want to preserve in our data.
One way to do this is similar to the informal ‘elbow’ method described for K-Means clustering.
We graph our reconstruction loss against the number of principal components used, as seen in
Figure 7.6. The idea is to add principal components to our subspace one at a time, calculating
the reconstruction loss as we go. The first few principal components will greatly reduce the recon-
struction loss, before eventually leveling off. We can identify the ‘elbow’ where the reduction in
loss starts to diminish, and choose to use that number of principal components.
Another way to do this is to consider how much variance we wish to preserve in our data. Each
principal component is associated with an eigenvalue λd that indicates what proportion of the
variance that principal component is responsible for in our data set. Then the fraction of variance
retained from our data set if we choose to keep D′ principal components is given by:
PD′
′ λd′
retained variance = PdD=1 (7.6)
d=1 λd

For different applications, there may be different levels of acceptable variance retention, which can
help us decide how many principal components to keep.
Finally, once we have selected our principal components, we have also defined the subspace onto
which we will be projecting our original data. And although this subspace is defined by the basis
given by our principal components, these principal components are not a unique description of that
subspace. We could choose to use any basis after we’ve identified our subspace through the principal
components. The importance of this idea is simply that although our principal components are
unique, they are not the only basis we could use to define the same projection subspace.
7.4. CONCLUSION 101

7.4 Conclusion
Principal component analysis is useful for visualization purposes, removing redundant informa-
tion, or making a data set more computationally manageable. PCA is also a good tool for data
exploration, particularly when digging into an unfamiliar data set for the first time.
It is not essential that we know by heart the exact derivation for arriving at the principal
components of a data set. The same can be said of the linear algebra machinery needed to compute
principal components. However, it is important to have an intuitve grasp over how variance in
our data set relates to principal components, as well as an understanding of how subspaces in our
data can provide compact representations of that data set. These are critical concepts for working
effectively with real data, and they will motivate related techniques in machine learning.
Chapter 8

Graphical Models

Mathematics, statistics, physics, and other academic fields have useful notational systems. As a
hybrid of these and other disciplines, machine learning borrows from many of the existing systems.
Notational abstractions are important to enable consistent and efficient communication of ideas, for
both teaching and knowledge creation purposes. Much of machine learning revolves around mod-
eling data processes, and then performing inference over those models to generate useful insights.
In this chapter, we will be introducing a notational system known as the directed graphical model
(DGM) that will help us reason about a broad class of models.

8.1 Motivation
Up until this point, we’ve defined notation on the fly, relied on common statistical concepts, and
used diagrams to convey meaning about the problem setup for different techniques. We’ve built
up enough working knowledge and intuition at this point to switch to a more general abstraction
for defining models: directed graphical models (DGMS). DGMs will allow us to both consolidate
our notation and convey information about arbitrary problem formulations. An example of what
a DGM looks like for a linear regression problem setup is given in Figure 8.1, and over the course
of the chapter, we’ll explain how to interpret the symbols in this diagram.
We need graphical models for a couple of reasons. First, and most importantly, a graphical
model unambiguously conveys a problem setup. This is useful both to share models between
people (communication) and to keep all the information in a specific model clear in your own head
(consolidation). Once we understand the meaning of the symbols in a DGM, it will be far easier to
examine one of them than it will be to read several sentences describing the type of model we’re
imagining for a specific problem. Another reason we use DGMs is that they help us reason about
independence properties between different parts of our model. For simple problem setups this may
be easy to keep track of in our heads, but as we introduce more complicated models it will be useful
to reason about independence properties simply by examining the DGM describing that model.
Ultimately, directed graphical models are a tool to boost efficiency and clarity. We’ll examine
the core components of a DGM, as well as some of their properties regarding independence and
model complexity. The machinery we develop here will be used heavily in the coming chapters.

8.2 Directed Graphical Models (Bayesian Networks)


There are a few fundamental components to all of the models we’ve examined so far. In fact, we can
model just about everything we’ve discussed to this point using just random variables, deterministic

102
8.2. DIRECTED GRAPHICAL MODELS (BAYESIAN NETWORKS) 103

Figure 8.1: Linear regression model expressed with a DGM.

Figure 8.2: Random variables are denoted with an open circle, and it is shaded if the variable is
observed.

parameters, and arrows to indicate the relationships between them. Let’s consider linear regression
as a simple but comprehensive example. We have a random variable yn , the object of predictive
interest, which depend on deterministic parameters in the form of data xn and weights w. This
results in the DGM given by Figure 8.1. There are four primary pieces of notation that the linear
regression setup gives rise to, and these four components form the backbone of every DGM we
would like to construct.
First, we the random variable yn represented by an open circle, shown in Figure 8.2. If we
observe a random variable of a given model, then it is shaded. If the random variable is unobserved
(sometimes called latent), it is unshaded.
Second, we have deterministic parameters represented by a tight, small dot, shown in Figure
8.3.
Third, we have arrows that indicate the dependence relationship between different random
variables and parameters, shown in Figure 8.4. Note that an arrow from X into Y means that Y
depends on X.
And finally, we have plate notation to indicate that we have repeated sets of variables in our
model setup, shown in Figure 8.5.

Figure 8.3: Deterministic parameters are denoted with a tight, small dot.
104 CHAPTER 8. GRAPHICAL MODELS

Figure 8.4: Arrows indicate dependence relationships.

Figure 8.5: Plates indicate repeated sets of variables. Often there will be a number in one of the
corners (N in this case) indicating how many times that variable is repeated.
8.2. DIRECTED GRAPHICAL MODELS (BAYESIAN NETWORKS) 105

Figure 8.6: DGM for the joint distribution give by Equation 8.1.

With these four constructs, we can describe complex model setups diagrammatically. We can
have an arbitrary number of components and dependence structure. DGMs can be useful as a
reference while working on a problem, and they also make it easy to iterate on an existing model
setup.

8.2.1 Joint Probability Distributions


We’ll now consider how DGMs simplify the task of reasoning about a joint probability distribution
over multiple random variables. For any joint probability distribution, regardless of our knowledge
about the dependence structure in the model, it’s always valid to write a generic joint probability
distribution as follows:

p(A, B, C)

where in this setup, we are interested in the joint distribution between three random variables
A, B, C. However, this doesn’t tell us anything about the structure of the problem at hand. We
would like to know where there is independence and use that to simplify our model. For example, if
we knew that B and C were independent and we also knew the conditional distribution of A|B, C
then we would much rather setup our joint probability equation as:

p(A, B, C) = p(B)p(C)p(A|B, C) (8.1)

DGMs assist in this process of identifying the appropriate factorization, as their structure allows
us to read off valid factorizations directly. For example, the joint distribution given by Equation
8.1 can be read from Figure 8.6.
We can translate between a DGM and a factorized joint distribution by interpreting the arrows
as dependencies. If a random variable has no dependencies (as neither B nor C do in this example),
they can be written on their own as marginal probabilities p(B) and p(C). Since arrows indicate
dependence of the random variable at the tip on the random variable at the tail, the factorized joint
probability distribution of the dependent random variable is written as a conditional probability,
i.e. P (A|B, C) in 8.6. In this way, we can move back and forth between DGMs and factorized joint
distributions with ease.

8.2.2 Generative Models


DGMs also show us the process by which data comes into existence (sometimes referred to as the
data generating story or generative process). Thus, a DGM conveys the information required to
106 CHAPTER 8. GRAPHICAL MODELS

Figure 8.7: Example of a generative model.

identify how the data are generated, and if we have the proper tools, how we could generate new
data ourselves.

Definition 8.2.1 (Generative Models): A generative model describes the entire process by
which data comes into existence. It enables the creation of new data by sampling from the gener-
ative model, but generative models are not required to make predictions or perform other kinds of
inference.

⋆ Note that we can create graphical models for both generative and discriminative models. Discriminative models
will only model the conditional distribution p(Y |Z), while generative models will model the full joint distribution
p(Z, Y ).

Let’s consider a simple example to see how this works in practice. Consider the flow of in-
formation present in Figure 8.7. First, there is some realization of the random variable Z. Then
conditioned on that value of Z, there is some realization of the random variable Y . The equivalent
joint factorization for this DGM is given by p(Z)p(Y |Z). More intuitively, the data are created by
first sampling from Z’s distribution, and then based on the sampled value Z = z, sampling from
the conditional distribution p(Y |Z = z).
As a concrete example, let Z be a random variable representing dog breed and Y be a random
variable representing snout length of a dog of a given breed, which is conditional on the breed of
dog. Notice that we have not specified the distributional form of either Z or Y ; we only have the
story of how they relate to each other (i.e. the dependence relation).
This story also shows us that if we had some model for Z and Y |Z, we could generate data
points ourselves. Our procedure would simply be to sample from Z and then to sample from Y
conditioned on that value of Z. We could perform this process as many times as we like to generate
new data. This is in contrast to sampling directly from the joint p(Z, Y ), which is difficult if we do
not know the exact form of the joint distribution (which we often do not) or if the joint distribution
is difficult to sample from directly.
The technique of sampling from distributions in the order indicated by their DGM is known as
ancestral sampling, and it is a major benefit of generative models.

Definition 8.2.2 (Ancestral Sampling): Ancestral sampling is a technique used to generate


data from an arbitrary DGM. It works by sampling from the random variables in topological order,
meaning that we first sample from all random variables without any dependence, then from the
8.2. DIRECTED GRAPHICAL MODELS (BAYESIAN NETWORKS) 107

Figure 8.8: Notice we have to sample these in topological order: A, D, B, E, C.

random variables that depend on those intially sampled random variables, and so on until all the
random variables have been sampled. This is demonstrated in Figure 8.8.

8.2.3 Generative Modeling vs. Discriminative Modeling


In the previous section, we described generative modeling, which describes how data come into
existence. Another type of modeling is discriminative modeling, and we have already seen several
examples.
A discriminative model directly models the dependence of a random variable on input random
variables (i.e. it models the conditional p(Y |Z), skipping p(Y, Z)). For example, if we wish to
predict what value a response variable Y will take on, we can consider input values to be fixed
parameters without an underlying distribution, and then our model is simply tasked with predicting
the response variable based on those parameters. This is in contrast to a generative model, which
models how all the values were generated by assigning a distribution to each of them.
To make the distinction more concrete, consider the DGM describing linear regression from
Figure 8.1. Notice that the data points xn are not random variables but instead fixed parameters
of the DGM. To be a generative model, the DGM would need to have the form shown in Figure
8.9.
The difference between these model setups is significant. The generative model for linear re-
gression would allow us to generate new input data points, but it is also significantly more complex
to handle: instead of only predicting a single response variable based on the input values, we also
model the generation of the input data points xn . This may be difficult for both conceptual and
computational reasons. This doesn’t mean we’ll never try to create generative models, but if our
goal is only making predictions about the response variable yn , it is typically unnecessary and less
effective to use a generative model. Another example of discriminative modeling include logistic
regression, and oftentimes SVMs and neural networks are set up as discriminative models as well.
In essence, the distinction between generative and discriminative models lies in whether the
model tries to describe how the data is realized or if the model only performs a specific inference
task without modeling the entirety of the data generating process. Neither type is better; they are
different techniques in your toolkit and should be used depending on your modeling and inferential
108 CHAPTER 8. GRAPHICAL MODELS

Figure 8.9: Linear regression DGM, modeling xn as a random variable.

needs.

8.2.4 Understanding Complexity

We’ve already motivated one of the primary uses of DGMs as being the ability to convert a joint
distribution into a factorization. At the heart of that task was recognizing and exploiting indepen-
dence properties in a model over multiple random variables. Another benefit of this process is that
it allows us to easily reason about the size (also called ‘complexity’) of a joint factorization over
discrete random variables. In other words, it allows us to determine how many parameters we will
have to learn to describe the factorization for a given DGM.
Let us consider an example to make this concept clear. Suppose we have four categorical random
variables A, B, C, D which take on 2, 4, 8, and 16 values respectively. If we were to assume full
dependence between each of these random variables, then a joint distribution table over all of these
random variables would require (2 ∗ 4 ∗ 8 ∗ 16) − 1 = 1023 total parameters (where each parameter
corresponds to the probability of a specific permutation of the values A, B, C, D).

⋆ Notice that the number of parameters we need is (2 ∗ 4 ∗ 8 ∗ 16) − 1 and not (2 ∗ 4 ∗ 8 ∗ 16). This is because if we
know the first (2 ∗ 4 ∗ 8 ∗ 16) − 1 parameters, the probability of the final combination of values is fixed since joint
probabilities sum up to 1.

However, if we knew that some of these random variables were conditionally independent, then
the number of parameters would change. For example, consider the joint distribution given by
p(A, B, C, D) = p(A)p(B|A)p(C|A)p(D|A). This would imply that conditioned on A, each of
B, C, D were conditionally independent. This can also be shown by the DGM in Figure 8.10.
In this case, a table of parameters to describe this joint distribution would only require 2 ∗ ((4 −
1) + (8 − 1) + (16 − 1)) = 50 parameters, which is significantly less. In general, the more conditional
independence we can identify between random variables, the easier they are to model and compute.
8.2. DIRECTED GRAPHICAL MODELS (BAYESIAN NETWORKS) 109

Figure 8.10: Conditioned on A, we have that B, C, and D are independent.

Figure 8.11: The three random variable relationships that will tell us about independence relation-
ships.

8.2.5 Independence and D-Separation


Another major benefit of DGMs is the ability to visually reason about the independence properties
of a given model. The form of a graphical model determines which variables are independent under
specific observation assumptions. With just three basic cases of relationships between three random
variables, we can reason about any arbitrarily complex model. The three cases are shown in Figure
8.11. For each of these cases, there is a notion of information (dependence) either flowing from one
random variable to the other or being ‘blocked’ (implying independence) either by an observation
or the absence of an observation. Specifically, case 1 and case 2 (left and middle DGMs) have
information between nodes A and C ‘blocked’ by observing node B, and case 3 has information
between nodes A and C ‘blocked’ by not observing node B.
To understand what it means to block information flow by observing a random variable, consider
the first example random variable structure, shown in Figure 8.12.
We can factorize this as follows:

p(A, B, C) = p(B)p(A|B)p(C|B)

We know that for this case, A and C are not independent (note that we have not observed B).
Therefore, we say that information flows from A to C. However, once we’ve observed B, we know
that A and B are conditionally independent, which is shown in Figure 8.13.
We now say that the flow of information from A to C is ‘blocked’ by the observation of B.
Intuitively, if we observe A but not B, then we have some information about the value of B, which

Figure 8.12: First structure, unobserved.


110 CHAPTER 8. GRAPHICAL MODELS

Figure 8.13: First structure, observed.

Figure 8.14: Second structure, unobserved.

also gives some information about the value of C. The same applies in the other direction: observing
C but not B has implications on the distribution of B and A.
In the second random variable structure, shown in Figure 8.14, we again consider the unobserved
case first.
This allows us to write the joint distribution as:

p(A, B, C) = p(A)p(B|A)p(C|B)

Again, A and C are dependent if we have not observed B. Information is flowing from A to C
through B. However, once we’ve observed B, then A and C are again conditionally independent,
shown in Figure 8.15.
The flow of information from A to C is ‘blocked’ by the observation of B. Intuitively, if we
observed A but not B, we have some information about what B might be and therefore what C
might be as well. The same applies in the other direction: observing C but not B.
Notice that these first two cases behave in the same manner: observing a random variable
in between two other random variable ‘blocks’ information from flowing between the two outer
random variables. In the third and final case the opposite is true. Not observing data in this case
will ‘block’ information, and we will explain this shift through an idea known as ‘explaining away’.
We have the third and final random variable structure, shown in Figure 8.16. We consider the
unobserved case first.
In this setup, we say that information from A to C is being ‘blocked’ by the unobserved variable
B. Thus A and C are currently conditionally independent. However, once we’ve observed B, as
shown in Figure 8.17, the information flow changes.
Now, information is flowing between A and C through the observed variable B, making A and
C conditionally dependent. This phenomenon, where the observation of the random variable in the
middle creates conditional dependence is known as explaining away. The idea relies on knowledge

Figure 8.15: Second structure, observed.


8.2. DIRECTED GRAPHICAL MODELS (BAYESIAN NETWORKS) 111

Figure 8.16: Third structure, unobserved.

Figure 8.17: Third structure, observed.

of the value for B giving information about how much A or C may have contributed to B adopting
that value.
Consider the following example: let the random variables A correspond to whether or not it
rained on a certain day, B correspond to the lawn being wet, and C correspond to the sprinkler
being on. Let’s say we observe B: the lawn is wet. Then, if we observe variable A: it has not rained
today, we would infer that variable C has the value: the sprinkler has been on, because that’s the
only way for the lawn to be wet. This is the phenomenon of explaining away. Observing B unblocks
the flow of information between A and C because we can now use an observation to ‘explain’ how
B got its value, and therefore determine what the other unobserved value might have been.
Notice that we’ve only described three simple cases relating dependence relationships between
random variables in a DGM, but with just these three cases, we can determine the dependence
structure of any arbitrarily complicated DGM. We just have to consider how information flows
from node to node. If information gets ‘blocked’ at any point in our DGM network because of an
observation (or lack thereof), then we gain some knowledge about independence within our model.
Consider the dependence between random variables A and F in Figure 8.18. Initially, before any
observations are made, we can see that A and F are dependent (information flows from A through
B). However, after observing B, the nodes A and F become independent (because information
blocked at both the observed B and the unobserved D). Finally, after observing D, dependence is
restored between A and F because information flows from A through D.

Figure 8.18: Notice how the independence between A and F depends on observations made within
the network.
112 CHAPTER 8. GRAPHICAL MODELS

Figure 8.19: DGM for Naive Bayes problem setup.

⋆ In some other resources, you’ll come upon the idea of ‘D-separation’ or ‘D-connection’. D-separation is simply
applying the principles outlined above to determine if two nodes are independent, or D-separated. By contrast,
two-nodes that are D-connected are dependent on one another.

8.3 Example: Naive Bayes


Recall from the chapter on classification the Naive Bayes model. As a quick recap, Naive Bayes
makes the assumption that for a single observation coming from a specific class (for example, our
classes could be different dog breeds), the data associated with that observation are independent
(for example, the data could be fur color, snout length, weight, etc.).
We can set up the Naive Bayes problem specification using the DGM form, as we see in Figure
8.19.
Even if we hadn’t heard of Naive Bayes before, we would understand this model and the
implications of the model, simply by an examination of the DGM that describes it. We can
directly read off the factorization described above that is the whole point of Naive Bayes:

p(yn , xn1 , ..., xnJ ) = p(yn )p(xn1 |yn ) · · · p(xnJ |yn )

Writing this factorization is facilitated directly by our DGM, even if we’ve never heard of Naive
Bayes before. It provides a common language for us to move fluidly between detailed probability
factorizations and general modeling intuition.

8.4 Conclusion
Directed graphical models are indispensible for model visualization and effective communication of
modeling ideas. With an understanding of what DGMs represent, it’s much easier to analyze more
complex probabilistic models. In many ways, this chapter is preparation for where we head next.
The topics of the following chapters will rely heavily on DGMs to explain their structure, use cases,
and interesting variants.
Chapter 9

Mixture Models

The real world often generates observable data that falls into a combination of unseen categories.
For example, at a specific moment in time I could record sound waves on a busy street that come
from a combination of cars, pedestrians, and animals. If I were to try to model my data points, it
would be helpful if I could group them by source, even though I didn’t observe where each sound
wave came from individually.
In this chapter we explore what are known as mixture models. Their purpose is to handle data
generated by a combination of unobserved categories. We would like to discover the properties of
these individual categories and determine how they mix together to produce the data we observe.
We consider the statistical ideas underpinning mixture models, as well as how they can be used in
practice.

9.1 Motivation
Mixture models are used to model data involving latent variables.

Definition 9.1.1 (Latent Variable): A latent variable is a piece of data that is not observed,
but that influences the observed data. We often wish to create models that capture the behavior
of our latent variables.

We are sometimes unable to observe all the data present in a given system. For example,
if we measure the snout length of different animals but only get to see the snout measurements
themselves, the latent variable would be the type of animal we are measuring for each data point.
For most data generating processes, we will only have access to a portion of the data and the rest
will be hidden from us. However, if we can find some way to also model the latent variables, our
model will potentially be much richer, and we will also be able to probe it with more interesting
questions. To build some intuition about latent variable models, we present a simple directed
graphical model with a latent variable zn in Figure 9.1.
One common means of modeling data involving latent variables, and the topic of this chapter,
is known as a mixture model.

Definition 9.1.2 (Mixture Model): A mixture model captures the behavior of data coming from
a combination of different distributions.

At a high level, a mixture model operates under the assumption that our data is generated by
first sampling a discrete class, and then sampling a data point from within that category according

113
114 CHAPTER 9. MIXTURE MODELS

Figure 9.1: Directed graphical model with a latent variable z.

to the distribution for that category. For the example of animal snouts, we would first sample a
species of animal, and then based on the distribution of snout lengths in that species, we would
sample an observation to get a complete data point.
Probabilistically, sampling a class (which is our latent variable, since we don’t actually observe
it) happens according to a Categorical distribution, and we typically refer to the latent variable as
z. Thus:
p(z = Ck ; θ) = θk
where Ck is class k, and θ is the parameter to the Categorical distribution that specifies the
probability of drawing each class. We write the latent variable in bold z because we will typically
consider it to be one-hot encoded (of dimension K, for K classes). Then, once we have a class, we
have a distribution for the observed data point coming from that class:
p(x|z = Ck ; w)

⋆ The distribution given by p(x|z = Ck ; w) is known as the class-conditional distribution.

This distribution depends on the type of data we are observing, and is parameterized by an
arbitrary parameter w whose form depends on what is chosen as the class-conditional distribution.
For the case of snout lengths, and many other examples, this conditional distribution is often
modeled using a Gaussian distribution, in which case our model is known as a Gaussian Mixture
Model. We will discuss Gaussian Mixture Models in more detail later in the chapter.
If we can effectively model the distribution of our observed data points and the latent variables
responsible for producing the data, we will be able to ask interesting questions of our model. For
example, upon observing a new data point x′ we will be able to produce a probability that it came
from a specific class z′ = Ck using Bayes’ rule and our model parameters:
p(x′ |z′ = Ck ; w)p(z′ = Ck ; θ)
p(z′ = Ck |x′ ) = P ′ ′ ′
k′ p(x |z = Ck′ ; w)p(z = Ck′ ; θ)
9.2. APPLICATIONS 115

Furthermore, after modeling the generative process, we will be able to generate new data points by
sampling from our categorical class distribution, and then from the class-conditional distribution
for that category:

z ∼ Cat(θ)
x ∼ p(x|z = Ck ; w)

Finally, it will also be possible for us to get a sense of the cardinality of z (meaning the number of
classes our data falls into), even if that was not something we were aware of a priori.
ML Framework Cube: Mixture Models
The classes of data z in a mixture model will typically be discrete. Notice also that this is an
unsupervised technique: while we have a data set X of observations, our goal is not to make
predictions. Rather, we are trying to model the generative process of this data by accounting for
the latent variables that generated the data points. Finally, this is a probabilistic model both for
the latent variables and for our observed data.
Domain Training Probabilistic
Discrete Unsupervised Yes

9.2 Applications
Since much of the data we observe in our world has some sort of unobserved category associated
with it, there are a wide variety of applications for mixture models. Here are just a few:

1. Handwriting image recognition. The categories are given by the characters (letters, numbers,
etc.) and the class-conditional is a distribution over what each of those characters might look
like.

2. Noise classification. The categories are given by the source of a noise (e.g. we could have
different animal noises), and the class-conditional is a distribution over what the sound waves
for each animal noise look like.

3. Vehicle prices. The categories are given by the brand of vehicle (we could alternatively
categorize by size, safety, year, etc.), and the class-conditional is a distribution over the price
of each brand.

9.3 Fitting a Model


We’ve defined the general form of a mixture model: we have a distribution p(z; θ) over our classes
and a distribution p(x|z = Ck ; w) as our class-conditional distribution. A natural approach would
be to compute the maximum likelihood values for our parameters θ and w. Let’s consider how we
might go about this for a mixture model.

9.3.1 Maximum Likelihood for Mixture Models


Our goal is to maximize the likelihood of our observed data. Because we don’t actually observe the
latent variables zn which determine the class of each observed data point xn , we can simply sum
116 CHAPTER 9. MIXTURE MODELS

over the possible classes for each of our N data points as follows:
N X
Y K
p(X; θ, w) = p(xn , zn,k ; θ, w)
n=1 k=1
P
This uses p(xn ; θ, w) = k p(xn , zn,k ; θ, w) (marginalizing out over the latent class). Taking
the logarithm to get our log-likelihood as usual:
N
X K
X 
log p(X; θ, w) = log p(xn , zn,k ; θ, w) (9.1)
n=1 k=1

It may not be immediately obvious, but under this setup, the maximum likelihood calculation for
our parameters θ and w is now intractable. The summation over the K classes of our latent variable
zn , which is required because we don’t actually observe those classes, is inside of the logarithm,
which prevents us from arriving at an analytical solution (it may be helpful to try to solve this
yourself, you’ll realize that consolidating a summation inside of a logarithm is not possible). You
could still try to use gradient descent, but the problem is non-convex and we’ll see a much more
elegant approach. The rest of this chapter will deal with how we can optimize our mixture model
in the face of this challenge.

9.3.2 Complete-Data Log Likelihood


We have a problem with computing the MLE for our model parameters. If we only knew which
classes our data points came from, i.e., if we had zn for each example n, then we would be able to
calculate log p(x, z) with relative ease because we would no longer require a summation inside the
logarithm:
N
X
log p(X, Z) = log p(xn , zn ; θ, w) (9.2)
n=1
XN
= log[p(xn |zn ; w)p(zn ; θ)] (9.3)
n=1
XN
= log p(xn |zn ; w) + log p(zn ; θ) (9.4)
n=1

Notice that because we’ve now observed zn , we don’t have to marginalize over its possible
values. This motivates an interesting approach that takes advantage of our ability to work with
p(x, z) if we only knew z.
The expression p(x, z) is known as the complete-data likelihood because it assumes that we have
both our observation x and the class z that x came from. Our ability to efficiently calculate the
complete-data log likelihood log p(x, z) is the crucial piece of the algorithm we will present to op-
timize our mixture model parameters. This algorithm is known as Expectation-Maximization,
or EM for short.

9.4 Expectation-Maximization (EM)


The motivation for the EM algorithm, as presented in the previous section, is that we do not have
a closed form optimization for our mixture model parameters due to the summation inside of the
9.4. EXPECTATION-MAXIMIZATION (EM) 117

logarithm. This summation was required because we didn’t observe a crucial piece of data, the
class z, and therefore we had to sum over its values.
EM uses an iterative approach to optimize our model parameters. It proposes a soft value for
z using an expectation calculation (we can think about this as giving a distribution on zn for each
n), and then based on that proposed value, it maximizes the expected complete-data log likelihood
with respect to the model parameters θ and w via a standard MLE procedure.
Notice that EM is composed of two distinct steps: an “E step” that finds the expected value of
the latent class variables given the current set of parameters, and an “M step” that improves the
model parameters by maximizing expected complete-data log likelihood given these soft assignments
to class variables. These two steps give the algorithm its name, and more generally, this type of
approach is also referred to as coordinate ascent. The idea behind coordinate ascent is that we
can replace a hard problem (maximizing the log likelihood for our mixture model directly) with
two easier problems, namely the E- and M-step. We alternate between the two easier problems,
executing each of them until we reach a point of convergence or decide that we’ve done enough.
We may also restart because EM will provide a local but not global optimum.
We’ll walk through the details of each of these two steps and then tie them together with the
complete algorithm.

⋆ K-Means, an algorithm we discussed in the context of clustering, is also a form of coordinate ascent. K-Means is
sometimes referred to as a “maximization-maximization” algorithm because we iteratively maximize our assignments
(by assigning each data point to just a single cluster) and then update our cluster centers to maximize their likelihood
with respect to the new assignments. That is, it does a “max” in place of the E-etep, making a hard rather than soft
assignment.

9.4.1 Expectation Step


The purpose of the E-step is to find expected values of the latent variables zn for each example
given the current parameter values. Let’s consider what this looks like with a concrete example.
Let’s say our data points xn can come from one of three classes. Then, we can represent the
latent variable zn associated with each data point using a one-hot encoded vector. For example, if
zn came from class C1 , we would denote this:
 
1
zn = 0

0

As we’ve already explained, we don’t know the true value of this latent variable. Instead, we will
compute its conditional expectation based on the current setting of our model parameters and our
observed data xn . We denote the expectation of our latent variables as qn , and we calculate them
as follows:
   
p(zn = C1 |xn ; θ, w) p(xn |zn = C1 ; w)p(zn = C1 ; θ)
qn = E[zn |xn ] = p(zn = C2 |xn ; θ, w) ∝ p(xn |zn = C2 ; w)p(zn = C2 ; θ)
p(zn = C3 |xn ; θ, w) p(xn |zn = C3 ; w)p(zn = C3 ; θ)

The expectation of a 1-hot encoded vector is equivalent to a distribution on the values that the
latent variable might take on. Notice that we can switch from proportionality in our qn values to
118 CHAPTER 9. MIXTURE MODELS

actual probabilities by simply dividing each unnormalized value by the sum of all the unnormalized
values. Then, our qn values will look something like the following, where a larger number indicates
a stronger belief that the data point xn came from that class:
 
0.8
qn = 0.1

0.1
There are two important things to note about the expectation step. First, the model parameters θ
and w are held fixed. We’re computing the expectation of our latent variables based on the current
setting of those model parameters. Those parameters are randomly initialized if this is our first
time running the expectation step.
Second, we have a value of qn for every data point xn in our data set. As a result, qn are
sometimes called “local parameters,” since there is one assigned to each data point. This is in
contrast to our model parameters θ and w, which are “global parameters.” The size of the global
model parameters doesn’t fluctuate based on the size of our data set.
After performing the E-step, we now have an expectation for our latent variables, given by qn .
In the maximization step, which we describe next, we use these qn values to improve our global
parameters.

9.4.2 Maximization Step


After the expectation step, we have a qn ∈ [0, 1]K (and summing to one) associated with each data
point xn , which describes our belief that the data point came from each class Ck . Now that we
have these expected ‘class assignments’, it’s possible for us to maximize the expected complete-data
likelihood with respect to our model parameters θ and w.
Recall that optimizing our parameters using the complete-data log likelihood is tractable be-
cause we avoid summing over the classes inside of the logarithm. Although we do not have the
actual complete-data, since we don’t know the true zn values, we now have a distribution over these
latent variables (given by qn ).
Notice that our qn values are ‘soft’ assignments - meaning that unlike the zn values, which are
one-hot encodings of assignments to a class, the qn values have a probability that a data point xn
came from each class. Recall the expression for complete-data log likelihood:
N
X
log p(X, Z) = log p(xn |zn ; w) + log p(zn ; θ) (9.5)
n=1
We work instead with the expected complete-data log likelihood, using qn to provide the dis-
tribution on zn for each example n:
N
X 
Ezn |xn [log p(X, Z)] = Ezn |xn log p(xn |zn ; w) + log p(zn ; θ) (9.6)
n=1
N
X  
= Ezn |xn log p(xn |zn ; w) + log p(zn ; θ) (9.7)
n=1
XN X K

= p(zn = Ck |xn ) log p(xn |zn = Ck ; w) + log p(zn = Ck ; θ) (9.8)
n=1 k=1
N X
X K

= qn,k log p(xn |zn = Ck ; w) + log p(zn = Ck ; θ) (9.9)
n=1 k=1
9.4. EXPECTATION-MAXIMIZATION (EM) 119

Notice the crucial difference between this summation and that of Equation 9.1: the summation
over the classes is now outside of the logarithm! Recall that using the log-likelihood directly was
intractable precisely because the summation over the classes was inside of the logarithm. This
maximization became possible by taking the expectation over our latent variables (using the values
we computed in the E-step), which moved the summation over the classes outside of the logarithm.
We can now complete the M-step by maximizing Equation 9.6 with respect to our model pa-
rameters θ and w. This has an analytical solution. We take the derivative with respect to the
parameter of interest, set to 0, solve, and update the parameter with the result.

9.4.3 Full EM Algorithm


Now that we have a grasp on the purpose of the EM algorithm, as well as an understanding of
the expectation and maximization steps individually, we are ready to put everything together to
describe the entire EM algorithm.

1. Begin by initializing our model parameters w and θ, which we can do at random. Since
the EM algorithm is performed over a number of iterative steps, we will denote these initial
parameter values w(0) and θ (0) . We will increment those values as the algorithm proceeds.

2. E-step: compute the values of qn based on the current setting of our model parameters.

p(zn = C1 |xn ; θ (i) , w(i) ) p(xn |zn = C1 ; w(i) )p(zn = C1 ; θ (i) )


   

qn = E[zn |xn ] =  .. ..
∝
   
. . 
(i) (i) (i) (i)
p(zn = CK |xn ; θ , w ) p(xn |zn = CK ; w )p(zn = CK ; θ )

3. M-step: compute the values of w and θ that maximize our expected complete-data log like-
lihood for the current setting of the values of qn :

w(i+1) , θ (i+1) ∈ arg max EZ|X [log p(X, Z; w, θ)] (9.10)


w,θ

4. Return to step 2, repeating this cycle until our likelihood converges. Note that the likelihood
is guaranteed to (weakly) increase at each step using this procedure.

It is also typical to re-start the procedure because we are guaranteed a local but not global
optimum.

9.4.4 The Math of EM


The above sections supply an intuitive sense of why EM should work in scenarios where the latent
variables follow categorical distributions. We now supply a more rigorous proof that iteratively
computing the quantities derived above will increase the data log likelihood in a more general
setting.
120 CHAPTER 9. MIXTURE MODELS

The Evidence Lower Bound (ELBO)


Despite our inability to directly optimize the observed data log likelihood, it remains the quantity
of interest. Note the following equivalence:
N
"K # N
"K #
X X X X
log p(xn , zn = Ck ; θ, w) = log p(xn | zn = Ck ; θ, w)p(zn = Ck | θ, w) (9.11)
n=1 k=1 n=1 k=1
N
X
= log Ezn ∼p(zn |θ) [p(xn | zn , w)] (9.12)
n=1

We adopt this latter expression, since it also holds in cases where zn is not discrete.
As touched on earlier, the main issues with directly optimizing this expression are as follows:
1. If we were to take a gradient, the distribution over which we’re taking an expectation involves
one of the parameters, θ, and hence computing the gradient with respect to θ is difficult.
2. There’s an expectation (previously a sum) inside the logarithm.
To solve this, we ultimately derive a lower bound on the observed data log likelihood which
proves to be computationally tractable to optimize. First, let q(zn ) denote another probability
distribution on zn . Then:
N
"K #
X X
log p(xn , zn = Ck ; θ, w)
n=1 k=1
N K
" #
X X
= log p(xn | zn = Ck ; θ, w)p(zn = Ck | θ, w)
n=1 k=1
N K
" #
X X q(zn = Ck )
= log p(xn | zn = Ck ; θ, w)p(zn = Ck | θ, w) ·
q(zn = Ck )
n=1 k=1
N K
" #
X p(xn | zn = Ck ; θ, w)p(zn = Ck | θ, w)
X
= log · q(zn = Ck )
q(zn = Ck )
n=1 k=1
N  
X p(xn | zn ; θ, w)p(zn | θ, w)
= log Ezn ∼q(zn ) (9.13)
q(zn )
n=1

The above derivation again restricts to discrete zn , but the equivalence between the expressions
in (9.12) and (9.13) is in fact more general, holding whenever p is absolutely continuous with respect
to the chosen q.
In any case, we’ve now fixed the first of the two issues. By introducing the distribution q, the
expectation is no longer over some distribution depending on the parameter θ.
To fix the second issue, we must somehow pass the log into the expectation. This is accomplished
using Jensen’s inequality, at the cost of turning the equality into a lower bound:
N  
X p(xn | zn ; θ, w)p(zn | θ, w)
log Ezn ∼q(zn ) (9.14)
q(zn )
n=1
N   
X p(xn | zn ; θ, w)p(zn | θ, w)
≥ Ezn ∼q(zn ) log (9.15)
q(zn )
n=1
9.4. EXPECTATION-MAXIMIZATION (EM) 121

In summary, the two issues with directly optimizing the observed data log likelihood were
resolved as follows:

• Introducing auxiliary variables q(zn ) allow us to reparametrize the expectation to be over q


rather than p(zn | θ).

• Jensen’s inequality allows the log to pass through the expectation.

However, this came at the cost of converting our objective into a lower bound of the original
quantity, as well as introducing the new parameter that is the distribution q.
It turns out that the iterative process given in the section above amounts to alternating between
optimizing the parameters w, θ and the distribution q.

Optimization
Before delving into the optimization process, we first establish two identities involving the ELBO.
N   
X p(xn | zn ; θ, w)p(zn | θ, w)
ELBO(w, q) = Ezn ∼q(zn ) log (9.16)
q(zn )
n=1

First,
N
X 
ELBO(w, q) = Ezn ∼q(zn ) [log (p(xn | zn ; θ, w)p(zn | θ, w))] − Ezn ∼q(zn ) [log q(zn )]
n=1
XN N
X
= Ezn ∼q(zn ) [log (p(xn | zn ; θ, w)p(zn | θ, w))] − Ezn ∼q(zn ) [log q(zn )]
n=1 n=1
XN N
X
= Ezn ∼q(zn ) [log p(xn , zn | θ, w)] − Ezn ∼q(zn ) [log q(zn )] (9.17)
n=1 n=1

The first term in (9.17) is the expected complete data log likelihood, and the second is the
entropy of the distribution q.
Second, using Bayes’ rule:
N   
X p(xn | θ, w)p(zn | xn ; θ, w)
ELBO(w, q) = Ezn ∼q(zn ) log
q(zn )
n=1
N   
X p(zn | xn ; θ, w)
= log p(xn | θ, w) + Ezn ∼q(zn ) log
q(zn )
n=1
N N  
X X q(zn )
= log p(xn | θ, w) − Ezn ∼q(zn ) log (9.18)
p(zn | xn ; θ, w)
n=1 n=1

The first term is exactly the observed data log likelihood, and the second is known as the KL diver-
gence between q(zn ) and p(zn | xn ; θ, w). In particular, the KL-divergence between distributions
P and Q is defined as  
P (x)
DKL (P ∥Q) = Ex∼P (x) log (9.19)
Q(x)
122 CHAPTER 9. MIXTURE MODELS

The key property of KL-divergence is that it is nonnegative, that is DKL (P ∥Q) ≥ 0 always,
with equality if and only if P = Q:
   
Q(x) Q(x)
−DKL (P ∥Q) = Ex∼P (x) log ≤ log Ex∼P (x)
P (x) P (x)
X Q(x)
≤ log · P (x)
P (x)
≤0

where we have equality when the equality condition of Jensen’s is met, meaning Q(x) = cP (x),
which forces Q(x) = P (x) since both are distributions. The proof above is only for discrete variables,
but can again be extended to the continuous case.
Using this, we rewrite (9.18) as

N
X N
X
ELBO(w, q) = log p(xn | θ, w) − DKL (q(zn )∥p(zn | xn ; θ, w)) (9.20)
n=1 n=1

We now describe the E and M steps of the procedure. Begin with some random assignment of
θ, w, and q.

1. The E-step: Fixing the parameters θ, w, choose the distributions q that maximizes ELBO(w, q).
Observe that only the second term of (9.20) depends on the distributions
PNq. Hence maximizing
the ELBO is equivalent to minimizes the sum of KL divergences n=1 DKL (q(zn )∥p(zn |
xn ; θ, w)). This immediately yields that

q(zn ) = p(zn | xn ; θ, w).

2. The M-step: Fixing the distributions q, maximize the ELBO with respect to the parameters
θ, w.
Observe that only the first term of (9.17) depends on the parameters, and hence this is
equivalent to maximizing the expected complete data log likelihood,

N
X
Ezn ∼q(zn ) [log p(xn , zn | θ, w)] . (9.21)
n=1

An equivalent formulation is then as follows. Define the auxiliary function Q(w, θ | wold , θ old )
as
N
X
old old
Q(w, θ | w ,θ )= Ezn ∼p(zn |xn ;wold ,θold ) [log p(xn , zn | w, θ)] (9.22)
n=1

Then as before, begin with some random assignment of w, θ, q.

1. The E-step: compute the auxiliary function Q(w, θ; wold , θ old ). Note that this essentially
just requires computing the posterior distribution p(zn | xn ; wold , θ old ), as before.

2. The M-step: maximize Q(w, θ; wold , θ old ) with respect to w and θ.


9.4. EXPECTATION-MAXIMIZATION (EM) 123

Why are these equivalent? We’ve dropped the qs by noting that at each step, they will just be
p(zn | xn ; wold , θ old ).
It is important to note the significance of expression (9.20). Since the first term is exactly the
observed data log likelihood, this indicates that the gap between the ELBO and the observed data
log likelihood is exactly the KL divergence between q(zn ) and p(zn | xn ; θ, w). If this gap goes to
zero, then we will have succeeded in maximizing the observed data log likelihood. This is why it is
important to continually maximize over the distributions q - choosing the best q makes the bound
as tight as possible every round.

Correctness

We will only show that EM increases the observed log likelihood at every iteration. We do this
using the second of the two formulations.

N
X N
X
log p(xn | w, θ) − log p(xn | wold , θ old )
n=1 n=1
N
X p(xn | w, θ)
= log
n=1
p(xn | wold , θ old )
N PK
X p(xn , zn | w, θ)
= log k=1
n=1
p(xn | zn ; wold , θ old )
N K
X X p(xn , zn | w, θ)
= log p(zn | wold , θ old )
p(x | z ; w old , θ old )p(z | wold , θ old )
n=1 k=1 n n n
N  
X p(xn , zn | w, θ)
= log Ezn ∼p(zn |wold ,θold )
n=1
p(xn , zn | wold , θ old )
N  
X p(xn , zn | w, θ)
≥ Ezn ∼p(zn |wold ,θold ) log
n=1
p(xn , zn | wold , θ old )
N
X N
X h i
= Ezn ∼p(zn |wold ,θold ) [log p(xn , zn | w, θ)] − Ezn ∼p(zn |wold ,θold ) log p(xn , zn | wold , θ old )
n=1 n=1
old old old old
= Q(w, θ | w , θ) − Q(w ,θ |w , θ)

We are maximizing Q at each step, and thus this difference is strictly positive.

Equivalence to Prior Formulation

It is not hard to see that the formulation presented above, centered around maximizing the ELBO,
is equivalent to the one presented in section 9.4.3. Note that qn,k is exactly p(zn = Ck | xn , w, θ),
and thus the quantity being computed in the M -step (as in (9.10)) is exactly the same as in (9.21).

9.4.5 Connection to K-Means Clustering


At this point, it’s worth considering the similarity between the EM algorithm and another coordi-
nate ascent algorithm that we considered in the context of clustering: K-Means.
124 CHAPTER 9. MIXTURE MODELS

Recall that K-Means proceeds according to a similar iterative algorithm: we first make hard
assignments of data points to existing cluster centers, and then we update the cluster centers based
on the most recent data point assignments.
In fact, the main differences between K-Means clustering and the EM algorithm are that:

1. In the EM setting, we make soft cluster assignments through our qn values, rather than
definitively assigning each data point to only one cluster.

2. The EM algorithm is able to take advantage of flexible, class-conditional distributions to


capture the behavior of the observed data, whereas K-Means clustering relies only on distance
measurements to make assignments and update cluster centers.

In the context of a mixture-of-Gaussian model, which we get to later in the chapter, we can
confirm that K-means is equal to the limiting case of EM where the variance of each class-conditional
Gaussian goes to 0, the prior probability of each class is uniform, and the distributions are spherical.

9.4.6 Dice Example: Mixture of Multinomials


Consider the following example scenario: we have two biased dice (with 6 faces) and one biased
coin (with 2 sides). Data is generated as follows: first, the biased coin is flipped. Suppose it lands
heads. Then dice 1 is rolled c = 10 times. This gives the first example, i.e., x1 would correspond
to the number of times of rolling each of a 1, 2, . . . , 6. Then we repeat, flipping the biased coin.
Suppose it lands tails. Then Dice 2 is rolled 10 times. We record the result of the dice rolls as x2 .
We keep repeating, obtaining additional examples.
For example, our observations for the first 10 rolls may look like: 1, 5, 3, 4, 2, 2, 3, 1, 6, 2 and we’d
record as our first example
 
2
3
 
2
x1 = 
1

 
1
1

We’re going to try to infer the parameters of each of the dice based on these observations. Let’s
consider how this scenario fits into our idea of a mixture model. First, the latent variable zn has a
natural interpretation as being which dice was rolled for the nth observed data point xn . We can
represent zn using a one-hot vector, so that if the nth data point came from Dice 1, we’d denote
that:
 
1
zn =
0

We denote the probability vector associated with the biased coin as θ ∈ [0, 1]2 , summing to
1, with θ1 being the probability of the biased coin landing heads and θ2 being the probability of
the biased coin landing tails. Furthermore, we need parameters to describe the behavior of the
biased dice. We use π1 , π 2 ∈ [0, 1]6 , summing to 1, where each 6-dimensional vector describes the
probability that the respective dice lands on each face.
9.4. EXPECTATION-MAXIMIZATION (EM) 125

For a given dice, this defines a multinomial distribution. For c trials, and counts x1 , . . . , x6 for
each of 6 faces on a 6-sided dice, and probabilities π, this is
c!
p(x; π) = π x1 · . . . · π6x6 (9.23)
x1 ! · . . . · x6 ! 1
For our purposes, let p(xn |zn = Ck ; π 1 , π 2 ) denote the multinomial distribution on observation
xnj when latent vector zn = Ck .
The model parameters are w = {θ, π 1 , π 2 }. We can optimize the model parameters using EM.
We start by initializing the parameters θ (0) , π (0) .
In the E-step, we compute the soft assignment values, qn . For dice k, this given by

qnk =p(zn = Ck |xn ; θ (i) , π (i) ) (9.24)


p(xn |zn = Ck ; π (i) )p(zn = Ck ; θ (i) )
= PK (i)
(9.25)
(i)
k=1 p(xn |zn = Ck ; π )p(zn = Ck ; θ )
(i)
p(xn |zn = Ck ; π (i) )θk
=P (9.26)
K (i) (i)
k=1 p(xn |zn = Ck ; π )θk

We could also use the “product trick” to write a single expression


QK  znk
p(x |z = C ; π (i) )θ (i)
k=1 n n k k
p(zn |xn ; θ (i) , π (i) ) = PK (i)
(i)
k=1 p(xn |zn = Ck ; π )θk
QK (i) znk
QK (i) znk
k=1 p(xn |zn = Ck ; π ) k=1 (θk )
=
(i) (i)
PK
k=1 p(xn |zn = Ck ; π )θk

The vector qn is defined as:

p(zn = C1 |xn ; θ (i) , π (i) )


 
qn = (9.27)
p(zn = C2 |xn ; θ (i) , π (i) )

After computing the values of qn , we are ready to perform the M-step. Recall that we are
maximizing the expected complete-data log likelihood, which takes the form:
N
X 
EZ|X [log p(X, Z)] = Eqn log p(zn ; θ (i+1) , π (i+1) ) + log p(xn |zn ; θ (i+1) , π (i+1) ) (9.28)
n=1
N
X  
(i+1) (i+1) (i+1) (i+1)
= Ezn |xn log p(zn ; θ ,π ) + log p(xn |zn ; θ ,π ) (9.29)
n=1

We can then substitute in for the multinomial expression and simplify, and dropping constants
we have that we’re looking for parameters that solve
(N 2 N X
2
)
 
(i+1) xn ,1 xn ,6
XX X
arg max qn,k log θk + qn,k log πk,1 · . . . · πk,6
θ (i+1) ,π (i+1) n=1 k=1 n=1 k=1
 
XN X
2 N X
2 X
6 
(i+1)
X
= arg max qn,k log θk + qn,k xn,j log(πk,j ) (9.30)
(i+1)
θ ,π (i+1) 
n=1 k=1 n=1 k=1 j=1

126 CHAPTER 9. MIXTURE MODELS

To maximize the expected complete-data log likelihood, it’s necessary to introduce Lagrange
P (i+1) P (i+1)
multipliers to enforce the constraints k θk = 1 and j πk,j = 1, for each k. After doing this,
and solving, we recover the following update equations for the model parameters:
PN
(i+1) qn,k
θk ← n=1
N
PN
(i+1) n=1 qn,k xn
πk ← P ,
c N n=1 qn,k
where c = 10 in out example.
We now have everything we need to perform EM for this setup. After initializing our parameters
w(0) , we perform the E-step by evaluating 9.27. After calculating our values of qn in the E-step, we
update our parameters w = {θ, π 1 , π 2 } in the M-step by maximizing 9.30 with respect to θ, π 1 , π 2 .
We perform these two steps iteratively, until convergence of our parameters. We may also do a
restart.

9.5 Gaussian Mixture Models (GMM)


Our previous example was a simple but somewhat restricted application of the EM algorithm to
solving a latent variable problem. We now turn to a more practical example, used widely in different
contexts, called a Gaussian Mixture Model (GMM). As you might expect, a GMM consists of a
combination of multiple Gaussian distributions. Among other things, it is useful for modeling
scenarios where the observed data is continuous.
Let’s go over a more rigorous formulation of the GMM setup. First, we have observed continous
data xn ∈ Rm and latent variables zn which indicate which Gaussian ‘cluster’ our observed data
point was drawn from. In other words:
p(xn |zn = Ck ) = N (xn ; µk , Σk )
where µk , Σk are the mean and covariance parameters respectively for the k th cluster center.
The data generation process works as follows: we first sample a cluster center from a Categorical
distribution parameterized by θ ∈ RK . Then, based on the sampled cluster center, we sample a
data point xn ∈ Rm , which is the only piece of data that we actually observe. As usual for a mixture
model, it is our goal to use the observed data to determine the cluster means and covariances, as
well as the parameters of the Categorical distribution that selects the cluster centers.
Fortunately, this problem setup is perfectly suited for EM. We can apply the same machinery
we’ve discussed throughout the chapter and used in the previous example.
1. First, we randomly initialize our parameters θ, {µk , Σk }K
k=1 .

2. [E-Step] Calculate the posterior distribution over zn given by qn :


 
p(zn = C1 |xn ; θ1 , µ1 , Σ1 )
qn = E[zn |xn ] = 
 .. 
. 
p(zn = CK |xn ; θK , µK , ΣK )
 
θ1 N (xn ; µ1 , Σ1 )
∝
 .. 
. 
θK N (xn ; µK , ΣK )
9.6. ADMIXTURE MODELS: LATENT DIRICHLET ALLOCATION (LDA) 127

This is the current expectation for our latent variables zn given our data xn and the current
setting of our model parameters θ, {µk , Σk }K
k=1 .

3. [M-Step] Using our values of qn , calculate the expected complete-data log likelihood, and
then use that term to optimize our model parameters:

N
X 
Eqn [log p(X, Z)] = Eqn ln(p(xn , zn ; θ, {µk , Σk }K
k=1 ))
n=1
N X
X K
= qn,k ln θk + qn,k ln N (xn ; µk , Σk )
n=1 k=1

We can then use this expected complete-data log likelihood to optimize our model parameters
θ, {µ , Σ }K by computing the MLE as usual. Using a Lagrange multiplier to enforce
PK k k k=1
k=1 θk = 1, we recover the update equations:

PN
(i+1) n=1 qn,k
θk ←
N
PN
(i+1) n=1 qn,k xn
µk ← PN
n=1 qn,k
PN (i+1) (i+1) T
(i+1) n=1 qn,k (xn −µ )(xn − µk )
Σk ← PN k
n=1 qn,k

4. Return to step 2. Repeat until convergence.

Finally, it’s worth comparing EM and K-Means clustering as applied to GMMs. First, as dis-
cussed previously, EM uses soft assignments of data points to clusters rather than hard assignments.
Second, the standard K-Means algorithm does not estimate the covariance of each cluster. How-
ever, if we enforce as a part of our GMM setup that the covariance matrices of all the clusters are
given by ϵI, then as ϵ → 0, EM and K-Means will in fact produce the same results.

9.6 Admixture Models: Latent Dirichlet Allocation (LDA)


With a grasp on mixture models, it is not too difficult to understand admixture models. In a
sentence: an admixture model is a mixture of mixture models. Latent Dirichlet Allocation (LDA)
is a common form of admixture models, and it is sometimes also referred to as topic modeling, for
reasons that will become apparent shortly. Describing LDA using an example will hopefully make
the idea of an admixture model more concrete.

9.6.1 LDA for Topic Modeling


Consider the following data generating process for a set of text documents. We have a Dirichlet
distribution θ ∼ Dir(α) over the possible topics a document can take on.
128 CHAPTER 9. MIXTURE MODELS

⋆ If you haven’t seen the Dirichlet before, it is a distribution over an n-dimensional vector whose components sum to
1. For example, a sample from a dirichlet distribution in 3-dimensions could produce a sample that is the vector
 
0.2
0.5
0.3
.

We sample from that Dirichlet distribution to determine the mixture of topics θ n in our docu-
ment Dn :

θ n ∼ Dir(α)

Then, for each possible topic, we sample from a Dirichlet distribution to determine the mixture of
words ϕk in that topic:

ϕk ∼ Dir(β)

Then, for each word wn,j in the document Dn , we first sample from a Categorical parameterized
by the topic mixture θ n to determine which topic that word will come from:

zn,j ∼ Cat(θ n )

Then, now that we have a topic given by zn,j for this word wn,j , we sample from a Categorical
parameterized by that topic’s mixture over words given by ϕzn,j :

wn,j ∼ Cat(ϕzn,j )

Notice the mixture of mixtures at play here: we have a mixture model over the topics to produce
each document in our corpus, and then for every word in a given document, we have a mixture
over the topics to generate each individual word.
The indexing is particularly confusing because there are several layers of mixtures here, but to
clarify: n ∈ 1..N indexes each document Dn in our corpus, k ∈ 1..K indexes each possible topic,
and j ∈ 1..J indexes each word wn,j in document Dn , and e ∈ 1..E indexes each word in our
dictionary (note that wn,j ∈ RE ).
θ n specifies the distribution over topics in document Dn , and α is the hyperparameter for the
distribution that produces θ n . Similarly, ϕk specifies the distribution over words for the k th topic,
and β is the hyperparameter for the distribution that produces ϕk .

9.6.2 Applying EM to LDA


Now that the problem setup and notation are taken care of, let’s consider how we can apply EM
to optimize the parameters θ n (the mixture over topics in a document) and ϕk (the mixture over
words for a topic). Note that we can simplify the problem slightly by considering θ n and ϕk to be
deterministic parameters for optimization (rather than random variables parameterized by α and
β). Then, EM proceeds as follows:

1. First, we randomly initialize our parameters {θ n }N K


n=1 , {ϕk }k=1 .
9.6. ADMIXTURE MODELS: LATENT DIRICHLET ALLOCATION (LDA) 129

2. [E-Step] Fix the topic distribution of the document given by θ n and the word distribution
under a topic given by ϕk . Calculate the posterior distribution qn,j = p(zn,j |wn,j ), and note
that this is the distribution over the possible topics of a word:
 
p(zn,j = C1 |wn,j ; θ n , ϕ1 )
qn,j = E[zn,j |wn,j ] = 
 .. 
. 
p(zn,j = CK |wn,j ; θ n , ϕK )
 
p(wn,j |zn,j = C1 ; ϕ1 )p(zn,j = C1 ; θ n )
∝
 .. 
. 
p(wn,j |zn,j = CK ; ϕK )p(zn,j = CK ; θ n )
 
ϕ1,wn,j · θn,1
=
 .. 
. 
ϕK,wn,j · θn,K

3. [M-Step] Using our values of qn , calculate the expected complete-data log likelihood (which
marginalizes over the unknown hidden variables zn,j ), and then use that expression to optimize
our model parameters θ n and ϕk :

N X
X J 
Eqn [log p(W, Z)] = Eqn ln(p(wn,j , zn,j ; {θ n }N K
n=1 , {ϕk }k=1
n=1 j=1
N X
X J X
K
= qn,j,k ln θn,k + qn,j,k ln ϕk,wn,j
n=1 j=1 k=1

We can then use this expected complete-data log likelihood to optimize our model parameters
{θ n }N , {ϕk }K
k=1 by computing the MLE as usual. Using Lagrange multipliers to enforce
Pn=1
∀n k=1 θn,k = 1 and ∀k E
K P
e=1 k,e = 1 (where e indexes each word in our dictionary), we
ϕ
recover the update equations:
PJ
(i+1) j=1 qn,j,k
θn,k ←
J
PN PJ
(i+1) n=1 j=1 qn,j,k wn,j,d
ϕk,d ← PN PJ
n=1 j=1 qn,j,k

4. Return to step 2. Repeat until convergence.

The largest headache for applying the EM algorithm to LDA is keeping all of the indices in
order, and this is the result of working with a mixture of mixtures. Once the bookkeeping is sorted
out, the actual updates are straightforward.
130 CHAPTER 9. MIXTURE MODELS

9.7 Conclusion
Mixture models are one common way of handling data that we believe is generated through a
combination of unobserved, latent variables. We’ve seen that training these models directly is
intractable (due to the marginalization over the latent variables), and so we turned to a coordinate
ascent based algorithm known as Expectation-Maximization to get around this difficulty. We then
explored a couple of common mixture models, including a multinomial mixture, Gaussian Mixture
Model, and an admixture model known as Latent Dirichlet Allocation. Mixture models are a subset
of a broader range of models known as latent variable models, and the examples seen in this chapter
are just a taste of the many different mixture models available to us. Furthermore, EM is just a
single algorithm for optimizing these models. A good grasp on the fundamentals of mixture models
and the EM algorithm will be useful background for expanding to more complicated, expressive
latent variable models.
Chapter 10

Hidden Markov Models

Many of the techniques we’ve considered so far in this book have been motivated by the types of data
we could expect to work with. For example, the supervised learning techniques (forms of regression,
neural networks, support vector machines, etc.) were motivated by the fact that we had labelled
training data. We ventured into clustering to group unlabelled data and discussed dimensionality
reduction to handle overly high-dimensional data. In the previous chapter, we examined techniques
for managing incomplete data with latent variable models. In this chapter we turn to a technique
for handling temporal data.

10.1 Motivation
One major type of data we have not yet paid explicit attention to is time series data. Most of the
information we record comes with some sort of a timestamp. For example, any time we take an
action online, there is a high probability that the database storing the data also tracks it with a
timestamp. Physical sensors in the real world always record timestamps because it would be very
difficult to make sense of their information if it is not indexed by time. When we undergo medical
exams, the results are recorded along with a timestamp. It’s almost inconceivable at this point
that we would record information without also keeping track of when that data was generated, or
at the very least when we saved that data.
For these reasons, its interesting to develop models that are specialized to temporal data.
Certainly, time encodes a lot of information that we take for granted about the physical and digital
worlds. For example, if the sensors on a plane record the position of the plane at a specific point
in time, we would expect the surrounding data points to be relatively similar, or at least move in
a consistent direction. In a more general sense, we expect that time constrains other attributes of
the data in specific ways.
In this chapter, we will focus on one such model known as a Hidden Markov Model or
HMM. At a high level, the goal of an HMM is to model the state of an entity over time, with the
caveat that we never actually observe the state itself. Instead, we observe a data point xt at each
time step (often called an ‘emission’ or ‘observation’) that depends on the state st . For example,
we could model the position of a robot over time given a noisy estimation of the robot’s current
position at each time step. Furthermore, we will assume that one state st transitions to the next
state st+1 according to a probabilistic model. Graphically, an HMM looks like Figure 10.1, which
encodes the relationships between emissions and hidden states. Here, there are n time steps in
total.
We will probe HMMs in more detail over the course of the chapter, but for now let’s consider

131
132 CHAPTER 10. HIDDEN MARKOV MODELS

Figure 10.1: The directed graphical model for an HMM.

their high-level properties.


ML Framework Cube: Hidden Markov Models
First, HMMs handle discrete states, and for the purpose of this text, we will only consider discrete-
valued emissions as well. Second, since a HMM does not have access to the states, these are
hidden (or latent) and in this sense we can view training as “unsupervised.” Finally, we assume a
probabilistic relationship between hidden states and observed emissions as well as the transitions
between hidden states.
Domain Training Probabilistic
Discrete Unsupervised Yes

⋆ Models that treat continuous state variables are commonly referred to as dynamical systems.

10.2 Applications
Unsurprisingly, there are many applications for models like HMMs that explicitly account for time
and unobserved states, especially those that relate to the physical world. Examples include:

1. The position of a robot arm when its movements may be non-deterministic and sensor readings
are noisy. [State = robot arm position; observation = sensor reading]

2. Speech recognition. [State = phoneme; observation = sound]

3. Analyzing sequences that occur in the natural world, such as DNA [State = codon, a genetic
code in a DNA molecule; observation= one of the four bases, i.e., A, C, T, or G]
10.3. HMM DATA, MODEL, AND PARAMETERIZATION 133

10.3 HMM Data, Model, and Parameterization


As explained above, HMMs model the state of an entity over time given some noisy observations,
as shown in Figure 10.1.

10.3.1 HMM Data


The data for a HMM consists of the sequence of one-hot encoded emissions x1 , ..., xn , for n total
time steps. There corresponding states, s1 , ..., sn , are latent and unobserved.
Each state corresponds to one of K possible options, i.e., with one-hot coding, st ∈ {0, 1}K .
Each emission corresponds to one of M possible options, with one-hot coding, xt ∈ {0, 1}M .

⋆ In general the observed emissions don’t have to be discrete, but for the sake of being explicit, we present the discrete
interpretation here.

A data set has N data points, meaning N sequences where each sequence is composed of n
emissions (in general they can be of different lengths, we assume same length for simplicity). To
summarize:
• A data set consists of N sequences.
• Each sequence is composed of n observed emissions x1 , ..., xn .
• Each emission xt takes on one of M possible values.
• Each hidden state st take on one of K possible values.

10.3.2 HMM Model Assumptions


In modeling the joint distribution over hidden states and observed emissions
p(s1 , ..., sn , x1 , ..., xn ) = p(s1 , ..., sn )p(x1 , ..., xn |s1 , ..., sn ), (10.1)
the HMM model makes the following two assumptions:
1. State st+1 depends only on the previous state st :
p(st+1 |s1 , ..., st , x1 , ..., xt ) = p(st+1 |st )
This is the Markov Property, and it means that given knowledge of the state at the previous
time step, we can ignore all other earlier states and emissions. Here, we also assume the
transition is stationary, so that the transition model doesn’t depend on time.
2. At each time step t, the observed emission xt depends only on the current state st :
p(xt |s1 , ..., st , x1 , ..., xt−1 ) = p(xt |st )

⋆ The Markovian assumption for transitions, as well as the fact that we don’t observe the true states, gives rise to
the Hidden Markov Model name.

These two assumptions allow us to factorize the large joint distribution given by Equation 10.1
as follows:
n−1
Y n
Y
p(s1 , ..., sn )p(x1 , ..., xn |s1 , ..., sn ) = p(s1 ) p(st+1 |st ) p(xt |st ) (10.2)
t=1 t=1
This factorization will prove important for making HMM training and inference tractable.
134 CHAPTER 10. HIDDEN MARKOV MODELS

10.3.3 HMM Parameterization


Now that we understand the form of the data as well as the modeling assumptions made by a HMM,
we can specify the model parameterization explicitly. Referencing the factorized joint distribution
from Equation 10.2, we will need three distinct sets of parameters.

1. Parameters
P for the prior over the initial hidden state p(s1 ). This will be denoted θ ∈ [0, 1]K ,
with k θk = 1, such that:

p(s1 = k) = θk .

2. Parameters for the transition probabilities between states p(st+1 |st ). This will be denoted
T ∈ [0, 1]K×K , with j Ti,j = 1 for each i, such that:
P

p(st+1 = j|st = i) = Ti,j ,

where Ti,j is the probability of transitioning from state i to state j.

3. Parameters for the conditional probability of the emission, p(xt |st ), given the state. This will
K×M
P
be denoted π ∈ [0, 1] , with m πk,m = 1 for each k, such that:

p(xt = m|s = k) = πk,m .

For each state, there is a distribution on possible emissions.

In sum, we have three sets of parameters θ ∈ [0, 1]K , T ∈ [0, 1]K×K , and π ∈ [0, 1]K×M that we
need to learn from our data set. Then, using a trained model, we will be able to perform several
types of inference over our hidden states, as detailed next.

10.4 Inference in HMMs


We will see how to estimate the parameters of a HMM (“training”) in Section 10.5. For now we
want to understand how to do efficient inference in HMMs. Given parameters θ, T, and π, and
given a sequence of observations, x1 , ..., xn (or x1 , ..., xt ) there are various queries we may like to
perform:

• “p(seq)” p(x1 , . . . , xn ) (what is the distribution on sequences of emissions?)

• Prediction p(xt+1 |x1 , . . . , xt ) (what is the prediction of the next emission given what is known
so far?)

• Smoothing p(st |x1 , . . . , xn ), t ≤ n (after the fact, what do we predict for some earlier state?)

• Transition p(st , st+1 |x1 , . . . , xn ), t + 1 ≤ n (after the fact, what do we predict for the joint
distribution on some pair of temporally adjacent states?)

• Filtering p(st |x1 , . . . , xt ) (what is the prediction, in real-time, of the current state?)

• Best path max p(s1 , . . . , sn |x1 , . . . , xn ) (after the fact, what is the most likely sequence of
states?)
10.4. INFERENCE IN HMMS 135

For just one example, let’s consider smoothing p(st |x1 , ..., xn ). To compute this would require
marginalizing over all the unobserved states other than st , as follows:

p(st |x1 , ..., xn ) ∝ p(st , x1 , ..., xn )


X
= p(s1 , . . . , sn , x1 , . . . , xn )
s1 ,...,st−1 ,st+1 ,...,sn

X n−1
Y n
Y
= p(s1 ) p(st+1 |st ) p(xt |st ).
s1 ,...,st−1 ,st+1 ,...,sn t=1 t=1

Without making use of variable elimination, this requires summing over all possible states other
than t, which is very costly. Moreover, suppose we then query this for another state. We’d need
to sum again over all the states except for this new state, which duplicates a lot of work. Rather
than performing these summations over and over again, we can instead “memoize” (or reuse) these
kinds of summations using the Forward-Backward algorithm. This algorithm also makes uses of
variable elimination to improve the efficiency of inference.

10.4.1 The Forward-Backward Algorithm

The Forward-Backward algorithm uses variable elimination methods to compute two sets of quanti-
ties, that we refer to as the “alpha” and “beta” values. The algorithm makes elegant use of dynamic
programming (breaking down optimization problem into sub-problems, solving each sub-problem a
single time, and storing the solutions). It can be viewed as a preliminary inference step such that
the alpha and beta values can then be used for all inference tasks of interest as well as within EM
for training a HMM model.
The Forward-Backward algorithm is also an example of a message-passing scheme, which means
we can conceptualize it as passing around compact messages along edges of the graphical model that
corresponds to a HMM. The algorithm passes messages forwards and backwards through ‘time’,
meaning up and down the chain shown in the graphical model representation in Figure 10.1. The
forward messages are defined at each state as αt (st ), while the backward messages are defined at
each state as βt (st ). The overarching idea is to factor the joint distribution

p(x1 , . . . , xn , st ) = p(x1 , . . . , xt , st )p(xt+1 , . . . , xn |st )

because the factored terms can be efficiently computed. Let’s define these α and β terms explicitly.
The αt ’s represent the joint probability of all our observed emissions from time 1, ..., t as well
as the state at time t:

αt (st ) = p(x1 , ..., xt , st ) (10.3)

Graphically, the αt ’s are capturing the portion of the HMM shown in Figure 10.2.
We can factorize this joint probability using what we know about the conditional independence
136 CHAPTER 10. HIDDEN MARKOV MODELS

Figure 10.2: αt ’s capture the joint probability for the boxed portion; shown for α3 (s3 )

properties of HMMs as follows:

αt (st ) = p(x1 , ..., xt , st )


= p(xt |x1 , . . . , xt−1 , st )p(x1 , . . . , xt−1 , st )
X
= p(xt |st ) p(x1 , . . . , xt−1 , st−1 , st ) (10.4)
st−1
X
= p(xt |st ) p(st |x1 , . . . , xt−1 , st−1 )p(x1 , . . . , xt−1 , st−1 ) (10.5)
st−1
X
= p(xt |st ) p(st |st−1 )αt−1 (st−1 ) (10.6)
st−1

The first term in Equation (10.4) follows from the Markov property, and for the second term
we’ve expressed this joint probability by explicitly introducing st−1 and marginalizing out over this
variable. Equation 10.6 follows from the Markov property, and by substituting for the definition of
the alpha value.
Notice that our expression for αt (st ) includes the expression for αt−1 (st−1 ), which is the α from
the previous time step. This means we can define our messages recursively. After we’ve computed
the α’s at one time step, we pass them forwards along the chain and use them in the computation
of alpha values for the next time step. In other words, we compute the α values for period 1, then
pass that message along to compute the α values in period 2, and so forth until we reach the end
of the chain and have all the α’s in hand.

⋆ These α values are used both for inference and training a HMM via EM (in the E-step)

At this point, we’ve handled the forward messages, which send information from the beginning
to the end of the chain. In the backward portion, we also send information from the end of the
chain back to the beginning. In this backward message pass, we will compute our β values. The
βt ’s represent the joint probability over all the observed emissions from time t + 1, ..., n conditioned
on the state at time t:
βt (st ) = p(xt+1 , ..., xn |st ) (10.7)

Graphically, this means that the βt ’s are capturing the portion of the HMM shown in Figure 10.3.
We can factorize Equation 10.7 in a similar way to how we factorized the distribution described
10.4. INFERENCE IN HMMS 137

by the α’s:

βt (st ) = p(xt+1 , ..., xn |st )


X
= p(xt+1 , ..., xn , st+1 |st ) (10.8)
st+1
X
= p(st+1 |st )p(xt+1 |st , st+1 )p(xt+2 , ..., xn |xt+1 , st , st+1 ) (10.9)
st+1
X
= p(st+1 |st )p(xt+1 |st+1 )p(xt+2 , ..., xn |st+1 ) (10.10)
st+1
X
= p(st+1 |st )p(xt+1 |st+1 )βt+1 (st+1 ). (10.11)
st+1

Here, Equation (10.8) introduces st+1 and marginalizes out over this variable. Equation (10.9)
is the product rule (recall that n ≥ t + 1, so the third conditional probability starts at xt+2 ).
Equation (10.10) makes use of the Markov property in the last two terms. Equation (10.11)
substitutes in the expression for βt+1 from Equation (10.7).
As we saw with our calculation of the α’s, we can calculate β recursively. This recursive
definition enables us to propagate messages backward and compute the beta values efficiently in
one pass. In this case, we start at the end of the chain (t = n), and compute our β’s for each state
by passing messages back toward the front.
To summarize, the Forward-Backward algorithm calculates the α and β values as follows:
( P
p(xt |st ) st−1 p(st |st−1 )αt−1 (st−1 ) 1<t≤n
αt (st ) =
p(x1 |s1 )p(s1 ) otherwise
(P
st+1 p(st+1 |st )p(xt+1 |st+1 )βt+1 (st+1 ) 1 ≤ t < n
βt (st ) =
1 otherwise

⋆ Notice that the base case for the β’s is n. This is a quirk of our indexing, and it ensures we have a defined sn when
we pass messages back to calculate sn−1 , sn−2 , . . . .

Figure 10.3: βt ’s capture the joint probability for the boxed portion of the HMM. Shown here for
β2 (s2 )
138 CHAPTER 10. HIDDEN MARKOV MODELS

10.4.2 Using α’s and β’s for Training and Inference


Now that we know how to compute these α and β values, let’s see how to use them for inference.
Consider the product of the α and β value at a specific time t:

αt (st )βt (st ) = p(x1 , ..., xt , st )p(xt+1 , ..., xn |st ) = p(x1 , ..., xn , st ).

This is the joint distribution over all emissions and the state at time t. Using this as a building
block, this can support many kinds of inference.

p(Seq)

For example, we might like to evaluate the joint distribution over the emissions.
X X
p(x1 , ..., xn ) = p(x1 , ..., xn , st ) = αt (st )βt (st ) (10.12)
st st

where we can sum over the possible state values. This calculation be defined for any state st .

Prediction

Another common task is to predict the value of the next emission given the previous emissions.

p(xt+1 |x1 , ..., xt )

To compute this we can sum over state st and the next state st+1 as follows:

p(xt+1 |x1 , ..., xt ) ∝ p(x1 , ..., xt , xt+1 )


XX
= p(x1 , ..., xt , xt+1 , st , st+1 ) (10.13)
st st+1
XX
= p(x1 , ..., xt , st )p(st+1 |x1 , ..., xt , st )p(xt+1 |x1 , ..., xt , st , st+1 ) (10.14)
st st+1
XX
= αt (st )p(st+1 |st )p(xt+1 |st+1 ). (10.15)
st st+1

Here, Equation (10.13) follows by introducing states st and st+1 and marginalizing out over
them. Equation (10.14) follows from the product rule, and Equation (10.15) by using the Markov
property in two places and substituting for αt (st ).

Smoothing

Smoothing is the problem of predicting the state at time t given all the observed emissions. We can
think about this as updating the beliefs that we would have had in real-time, given emissions up
to and including t, given all observed evidence up to period n. Hence the phrasing “smoothing.”
For this, we have
p(st |x1 , ..., xn ) ∝ p(x1 , ..., xn , st ) = αt (st )βt (st ). (10.16)
10.4. INFERENCE IN HMMS 139

Transition
Finally, we may wish to understand the joint distribution on states st and st+1 given all the observed
evidence.

p(st , st+1 |x1 , ..., xn ) ∝ p(st , st+1 , x1 , ..., xn )


= p(x1 , ..., xt , st )p(st+1 |x1 , ..., xt , st )p(xt+1 |x1 , ..., xt , st , st+1 )
p(xt+2 , ..., xn |x1 , ..., xt+1 , st , st+1 ) (10.17)
= αt (st )p(st+1 |st )p(xt+1 |st+1 )βt+1 (st+1 ). (10.18)

Here, Equation (10.17) follows from the product rule, and (10.18) by substituting for αt (st ),
applying the Markov property three times, and substituting for βt+1 (st+1 ).

Filtering
For filtering, we have

p(st |x1 , ..., xt ) ∝ p(st , x1 , ..., xt ) = αt (st ). (10.19)

Best path
For the best path problem, we want to find the sequence of states that is most likely to give rise to
the observed emissions. We solve

arg max p(s|x) = arg max p(s1 , . . . , sn , x1 , . . . , xn ).


s1 ,...,sn s1 ,...,sn

This sometimes referred to as the “decoding” (or explanation) problem. For this, we can define
the following function:

γt (st ) = max p(s1 , . . . , st , x1 , . . . , xt ). (10.20)


s1 ,...,st−1

This is the likelihood of x1 , . . . , xt , if the current state is st , and under the best explanation so
far (we maximized over s1 , . . . , st−1 ). Recall that the recurrence for alpha is as follows:
 P
p(xt |st ) st−1 p(st |st−1 )αt−1 (st−1 ) if 1 < t ≤ n
∀st : αt (st ) =
p(x1 |s1 )p(s1 ) otherwise

Analogously, the recurrence for this γ-value can be shown to be (see if you can derive this for
yourself):

p(xt |st ) maxst−1 p(st |st−1 )γt−1 (st−1 ) if 1 < t ≤ n
∀st : γt (st ) = (10.21)
p(x1 |s1 )p(s1 ) otherwise

To be able to find the optimal sequence, we also store, for each st , the best choice of st−1 :

zt∗ (st ) = arg max [p(st |st−1 )γt−1 (st−1 )] (10.22)


st−1

This recursive procedure is known as the Viterbi algorithm and provides an efficient way to
infer the “best path” through states given a sequence of observations.
140 CHAPTER 10. HIDDEN MARKOV MODELS

10.5 Using EM to Train a HMM


Let’s now turn to using EM to train a HMM. Recall the motivation for the Expectation-Maximization
algorithm from the previous chapter: we had parameters we wished to optimize, but the presence
of unobserved variables made direct optimization of those parameters intractable. We’re faced with
a similar problem in the context of HMMs.
Given a data set of observed emissions {xi }N i
i=1 where each data point x represents a sequence
i i
(x1 , ..., xn ), our goal is to estimate the parameters θ, T, and π (refer back to 10.3.3 if you forget
the meanings of these parameters).
With knowledge of the hidden states, this would be a relatively straightforward problem of
MLE for the parameters. However, the states are latent variables, and for this reason we use the
EM algorithm.
This amounts to computing the probability on hidden states in the E-step, and then based on
these probabilities, we update our parameters by maximizing the expected complete-data likelihood
in the M-step. As usual, we perform these E and M steps iteratively until convergence. We will
make use of the Forward-Backward algorithm for the inference on probabilities of the latent states
in the E-Step.
We initialize the parameters in EM arbitrarily.

10.5.1 E-Step
For the E-Step, we take the parameters θ, T, π as fixed. For each data point xi , we run Forward-
Backward with these parameters to get the α and β values for this data point.
The hidden variables are the states s1 , ..., sn . For each data point xi = (xi1 , ..., xin ), we are
interested in computing the predicted probabilities qi ∈ [0, 1]n×K , for K possible hidden state
values. The qi represent the predicted probability of each hidden state value for each time period.
In particular, we have
i
qt,k = p(sit = k | xi1 , ..., xin ). (10.23)

This is the probability that state sit takes on value k given the data point xi . This is the
smoothing operation described in the previous section and we can use Equation 10.16 to compute
our qi values.
Ordinarily, we’d be done with the E-step after computing the marginal probability of each latent
variable. But in this case we will also want to estimate the transition probabilities between hidden
states, i.e., parameter matrix T. For this, we also need to calculate the joint distribution between
temporally-adjacent pairs of latent variables (e.g. st , st+1 ). For data point xi , and for periods t
and t + 1, we denote this as Qit,t+1 ∈ [0, 1]K×K , where the entries in this matrix sum to 1. This
represents the distribution on pairs of states in periods t and t + 1 for this data point. We write
Qi to denote the corresponding values for all pairs of time periods.
To see how the entries are calculated, we can use Qit,t+1,k,ℓ to denote the transition from state
k at time step t to state ℓ at time step t + 1,

Qit,t+1,k,ℓ = p(sit = k, sit+1 = ℓ | xi1 , ..., xin ). (10.24)

This is exactly the transition inference problem that we described in the previous section.
Because of this, we can directly use our α and β values in the transition operation, as given by
Equation (10.18).
With our qi and Qi values for each data point, we are ready to move on to the maximization
step.
10.6. CONCLUSION 141

10.5.2 M-Step
We now solve for the expected complete-data log likelihood problem, making use of the qi and Qi
values from the E-step.
Given knowledge of states, the complete-data likelihood for one data point with observations x
and states s is
n−1
Y n
Y
p(x, s) = p(s1 ; θ) p(st+1 |st ; T) p(xt |st ; π).
t=1 t=1
With one-hot encoding of xt and st , and taking the log, the expression becomes
K
X n−1
XX K X
K n X
X K M
X
ln[p(x, s)] = s1k ln θk + st,k st+1,ℓ ln Tk,ℓ + st,k xt,m ln(πk,m ).
k=1 t=1 k=1 ℓ=1 t=1 k=1 m=1
To see this, recall that productorials become summations when we take logarithms. The un-
bolded symbols represent single entries of the discrete probability distributions, and we create
additional summation indices k, ℓ over K possible hidden state values for st , st+1 (K is the di-
mension of the one-hot encoding of s). We also create an additional summation index m over M
possible emission values for each xt .
From this, we would be able to solve for the MLE for the parameters for the complete-data log
likelihood. Now, the states are latent variables, and we need to work with the expected complete-data
log likelihood, which for a single data point xi , is
K
X n−1
XX K X
K n X
X K M
X
Esi [ln(p(xi , si ))] = i
q1k ln θk + Qit,t+1,k,ℓ ln Tk,l + i
qt,k xit,m ln πk,m .
k=1 t=1 k=1 ℓ=1 t=1 k=1 m=1
(10.25)
Applying the appropriate Lagrange multipliers, and maximizing with respect to each of the
parameters of interest, we can obtain the following update equations for each of the parameters
(for N data points):
PN i
i=1 q1,k
θk = , for all states k (10.26)
N
PN Pn−1 i
i=1 t=1 Qt,t+1,k,l
Tk,l = PN Pn−1 i , for all states k, ℓ (10.27)
i=1 t=1 qt,k
PN Pn i i
i=1 t=1 qt,k xt,m
πk,m = PN Pn , for all states k, observations m (10.28)
i
i=1 t=1 qt,k
After updating our parameter matrices θ, T, and π, we switch back to the E-step, continuing
in this way until convergence. As with other uses of EM, it provides only a local optimum and it
can be useful to try a few different random initializations.

10.6 Conclusion
The Hidden Markov Model is a type of latent variable model motivated by the combination of time
series and discrete observations and states. We relied on the Expectation-Maximization algorithm
to train a HMM, and developed the Forward-Backward algorithm to make both inference and
training (the E-step) computationally efficient. Many of the ideas developed in this chapter will
offer good intuition for how to develop learning and inference methods for dynamical systems and
other time series models.
Chapter 11

Markov Decision Processes

In the previous chapter we learned about Hidden Markov Models (HMMs) in which we modeled
our underlying environment as being a Markov chain where each state was hidden, but produced
certain observations that we could use to infer the state. The process of learning in this setting
included finding the distribution over initial hidden states, the transition probabilities between
states, and the probabilities of each state producing each measurement. The movement of the
underlying Markov chain from one state to the next was a completely autonomous process and
questions of interest focused mainly on inference.
Markov Decision Processes (MDPs) introduce somewhat of a paradigm shift. Similar to HMMs,
in the MDP setting we model our underlying environment as a Markov chain. However, our
environment doesn’t transition from state to state autonomously like in an HMM, but rather
requires that an agent in the environment takes some action. Thus, the probability of transitioning
to some state at time t + 1 depends on both the state and the action taken at time t. Furthermore,
after taking some action the agent receives a reward which you can think of as characterizing the
“goodness” of performing a certain action in a certain state. Thus, our overall objective in the
MDP setting is different. Rather than trying to perform inference, we rather would like to find a
sequence of actions that maximizes the agent’s total reward.
While MDPs can have hidden states – these are called partially-observable MDPs (POMDPs)
– for the purposes of this course we will only consider MDPs where the agent’s state is known. If,
like in HMMs, we do not have full knowledge of our environment and thus do not know transition
probabilities between states or the rewards associated with each state and action we can try to
find the optimal sequence of actions using reinforcement learning —a subfield of machine learning,
distinct from supervised and unsupervised learning, that is characterized by problems in which an
agent seeks to explore its environment, and simultaneously use that knowledge to perform actions
that exploit its environment and obtain rewards. This chapter discusses how we find the optimal
set of actions given that we do have full knowledge of the environment.

ML Framework Cube: Markov Decision Processes

Domain Training Probabilistic


Discrete or Continuous ? Yes

Note that the question mark under the “training” category means that learning in MDPs is neither
supervised nor unsupervised but rather falls under the umbrella of reinforcement learning - an
entirely separate branch of ML.

142
11.1. FORMAL DEFINITION OF AN MDP 143

11.1 Formal Definition of an MDP


In an MDP, we model our underlying environment as containing any one of a set of states S. At
each state, an agent can take any action from the set A which gives them a reward according to
the function r : S × A → [0, 1], which we will often see written as r(s, a). For the purposes of this
class, we assume S and A are discrete and finite, i.e. contain finitely many different elements, and
also assume that rewards are deterministic, i.e. r(s, a) is fixed for any s and a.
Given a specified state and action (state-action pair ), the environment will transition into a
new state according to a transition function p : S × A → ∆S which we will often seen written as
p(s′ |s, a). That is, from our current state-action pair (s, a), for all s′ ∈ S, the environment will
transition to the new state s′ with probability p(s′ |s, a). If we fix the action taken, our model
behaves like a Markov chain.
We let the action we take at each state be given by a policy π : S → A which we will often
see written as π(a|s). Our goal is to find some optimal policy π ∗ that maximizes total rewards.
Interacting with the environment according to some policy produces a dataset

D = {s0 , a0 , r0 , s1 , a1 , r1 , . . . }

where the subscripts indicate timesteps.

Figure 11.1: A graphical representation of an MDP.

Note that since our state transitions are modeled as a Markov chain, the Markov assumption
holds:

pt (st+1 |s1 , · · · st , a1 , · · · at ) = pt (st+1 |st , at )

In other words, the probability that the environment transitions to state st+1 at time t only depends
on the state the agent was in and the action the agent took at time t. Furthermore, we can assume
that transition probabilities are stationary:

pt (st+1 |st , at ) = p(st+1 |st , at )

In other words, the probability of transitioning from one state to another does not depend on the
current timestep (note how we dropped the subscript t in the expression on the RHS).
We call the process of finding the optimal policy in an MDP given full knowledge of our envi-
ronment planning (when we don’t have prior knowledge about our environment but know that it
behaves according to some MDP with unknown parameters we turn to Reinforcement Learning -
see Chapter 12). Our approach to planning changes depending on whether there is a limit to the
number of timesteps the agent may act for, which we call a finite horizon, or if they may act forever
(or effectively forever), which we call an infinite horizon.
144 CHAPTER 11. MARKOV DECISION PROCESSES

11.2 Finite Horizon Planning


In a finite horizon MDP, our objective function is as follows:
" T #
X
max Es∼p rt |π
π
t=0

Note that the only source of randomness in the system is the transition function.

We find the optimal policy for a finite horizon MDP using dynamic programming. To formalize
this, let’s define the optimal value function V(t)∗ (s) to be the highest value achievable in state s

(i.e. acting under the optimal policy π ∗ ) with t timesteps remaining for the agent to act. Then, we
know

V(1) (s) = max[r(s, a)]
a

since with only one timestep left to act, the best the agent can do is take the action that maximizes
their immediate reward. We can define the optimal value function for t > 1 recursively as follows:
X

V(t+1) (s) = max[r(s, a) + p(s′ |s, a)V(t)

(s′ )]
a
s′ ∈S

In other words, with more than one timestep to go we take the action that maximizes not only
our immediate reward but also our expected future reward. This formulation makes use of the
principle of optimality or the Bellman consistency equation which states that an optimal policy
consists of taking an optimal first action and then following the optimal policy from the successive
state (these concepts will be discussed further in section 11.3.1). Consequently, we have a different
optimal policy at each timestep where

π(1) (s) = arg max[r(s, a)]
a
X

π(t+1) (s) = arg max[r(s, a) + p(s′ |s, a)V(t)

(s′ )]
a
s′ ∈S

The computational complexity of this approach is O(|S|2 |A|T ) since for each state and each
timestep we find the value-maximizing action which involves calculating the expected future reward
requiring a summation over the value function for all states.

11.3 Infinite Horizon Planning


If the agent is able to act forever, a dynamic programming solution is no longer feasible as there is
no base case (there will never only be 1 timestep remaining). Furthermore, we can’t use the same
objective function that we did when working with finite horizon MDPs since our expected reward
could be infinite! Thus, we modify the objective function as follows:
"∞ #
X
max Es∼p γ t rt
π
t=0

where γ ∈ [0, 1] is called the discount factor. Multiplying the reward at time t in the infinite sum
by γ t makes the sum resolve to a finite number, since our rewards are bounded in [0, 1]. As γ → 1,
11.3. INFINITE HORIZON PLANNING 145

rewards further in the future have more of an impact on the expected total reward and thus the
optimal policy will be more “patient”, acting in a way to achieve high rewards in the future rather
than just maximizing short-term rewards. Conversely, as γ → 0, rewards further in the future will
have less of an impact on the expected total reward and thus the optimal policy will be less patient,
preferring to maximize short-term rewards.
The effective time horizon of an infinite horizon MDP is the number of timesteps after which
t
γ becomes so small that rewards after time t are negligible. Using the formula for the sum of a
1
geometric series we find that this is approximately 1−γ . This yields our first approach to planning
1
in an infinite horizon MDP which is to convert it into a finite horizon MDP where T = 1−γ .
However, note that this fraction could be arbitrarily large and that the time complexity of the
dynamic programming solution is linear with respect to T . This implores us to consider alternate
solutions.

11.3.1 Value iteration


Define the value function of an infinite time horizon MDP under the policy π as the expected
discounted value from policy π starting at state s:
"∞ #
X
V π (s) = Es∼p γ t r(st , π(st )|s0 = s, π
t=0


Then, policy π is weakly better than policy π ′ if V π (s) ≥ V π (s) ∀s. The optimal policy π ∗ is that
which is weakly better than all other policies.

Consequently, the optimal value function is the value function following policy π ∗ :

V ∗ (s) = max V π (s)


π

Bellman Consistency Equation and Bellman Optimality


The value iteration algorithm, which we will describe soon, is based on an alternate expression
for the value function. The Bellman consistency equation gives us an alternate expression for
V π (s):

V π (s) = r(s, π(s)) + γEs′ ∼p [V π (s)]

Proof: We begin by expanding our initial expression for V π (s):


"∞ #
X
π t
V (s) = Es∼p γ r(st , π(st )|s0 = s, π
t=0
= r(s, π(s)) + γEs∼p [r(s1 , a1 ) + γr(s2 , a2 ) + · · · |s0 = s, π]

We now apply Adam’s law (dropping the subscripts on the expectations to avoid notational clutter):

= r(s, π(s)) + γE[E[r(s1 , a1 ) + γr(s2 , a2 ) + · · · |s0 = s, s1 = s′ , π]]

By the Markov property:

= r(s, π(s)) + γE[E[r(s1 , a1 ) + γr(s2 , a2 ) + · · · |s1 = s′ , π]]


146 CHAPTER 11. MARKOV DECISION PROCESSES

By definition, this is equal to r(s, π(s)) + γEs′ ∼p [V π (s′ )]. □

The Bellman optimality conditions are a set of two theorems that tell us important proper-
ties about the optimal value function V ∗ .

Theorem 1: V ∗ satisfies the following Bellman consistency equations:


V ∗ (s) = max[r(s, a) + γEs′ ∼p [V ∗ (s′ )]] ∀s
a

This theorem also tells us that if π̂(s) = arg maxa [r(s, a) + γEs′ ∼p [V ∗ (s′ )]] then π̂ is the optimal
policy.
Proof: Let π̂(s) = arg maxa [r(s, a) + γEs′ ∼p [V ∗ (s′ )]]. To prove the claim it suffices to show that
V ∗ (s) ≤ V π̂ (s) ∀s. By the Bellman consistency equation we have:
V ∗ (s) = r(s, π ∗ (s)) + γEs′ ∼p [V ∗ (s′ )]
≤ max[r(s, a) + γEs′ ∼p [V ∗ (s′ )]]
a
= r(s, π̂(s)) + γEs′ ∼p [V ∗ (s′ )]
Proceeding recursively:
≤ r(s, π̂(s)) + γEs′ ∼p [r(s′ , π̂(s′ )) + γEs′′ ∼p [V ∗ (s′′ )]]
···
≤ Es,s′ ,···∼p [r(s, π̂(s)) + γr(s′ , π̂(s′ ) + · · · |π̂]
= V π̂ (s)

Theorem 2: For any value function V , if V (s) = maxa [r(s, a) + γEs′ ∼p [V (s′ )]] ∀s then V = V ∗ .
Proof: First, we define the maximal component distance between V ad V ∗ as follows:
||V − V ∗ ||∞ = max |V (s) − V ∗ (s)|
s

For V which satisfies the Bellman optimality equations, if we could show that ||V − V ∗ ||∞ ≤
γ||V − V ∗ ||∞ then the proof would be complete since it would imply:
||V − V ∗ ||∞ ≤ γ||V − V ∗ ||∞ ≤ γ 2 ||V − V ∗ ||∞ ≤ · · · ≤ lim γ k ||V − V ∗ ||∞ = 0
k→∞

Thus we have:
|V (s) − V ∗ (s)| = |maxa [r(s, a) + γEs′ ∼p [V (s′ )]] − max[r(s, a) + γEs′ ∼p [V ∗ (s′ )]]|
a
≤ max |r(s, a) + γEs′ ∼p [V (s′ )] − r(s, a) − γEs′ ∼p [V ∗ (s′ )]|
a
= γ max |Es′ ∼p [V (s′ )] − Es′ ∼p [V ∗ (s′ )]|
a
≤ max Es′ ∼p [|V (s′ ) − V ∗ (s′ )]
a
≤ γ max

|V (s′ ) − V ∗ (s′ )|
a,s

= γ max

|V (s′ ) − V ∗ (s′ )|
s
= γ||V − V ∗ ||∞
11.3. INFINITE HORIZON PLANNING 147

They key takeaways from the Bellman optimality conditions are twofold. Theorem 1 gives us
a nice expression for the optimal value function that we will soon show allows us to iteratively
improve an arbitrary value function towards optimality. It also tells us what the corresponding
optimal policy is. Theorem 2 tells us that the optimal value function is unique so if we find some
value function that satisfies the Bellman consistency equations given in Theorem 1 then it must be
optimal.

Bellman Operator
In the previous section we mentioned that the formulation of the optimal value function we get
from Theorem 1 will allow us to improve some arbitrary value function iteratively. In this section
we will show why this is the case. Define the Bellman operator B : R|S| → R|S| to be a function
that takes as input a value function and outputs another value function as follows:

B(V (s)) := max[r(s, a) + γEs∼p [V (s′ )]]


a

By the Bellman optimality conditions we know that B(V ∗ ) = V ∗ and that this identity property
doesn’t hold for any other value function. The optimal value function V ∗ is thus the unique fix-point
of the function B which means that passing V ∗ into B always returns V ∗ .

Furthermore, we know that B is a contraction mapping which means that ||B(V ) − B(V ′ )||∞ ≤
γ||V − V ′ ||∞ .
Proof:

|B(V (s)) − B(V ′ (s))| = | max[r(s, a) + γEs′ ∼p [V (s′ )] − max[r(s, a) + γEs′ ∼p [V ′ (s′ )]]]
a a
≤ max |r(s, a) + γEs′ ∼p [V (s′ ) − r(s, a) − γEs′ ∼p V ′ (s′ )]|
a
= γ max |Es′ ∼p [V (s′ ) − V ′ (s′ )]|
a
≤ γ max[Es′ ∼p [|V (s′ ) − V ′ (s′ )|]]
a
≤ γ max |V (s) − V ′ (s)|
a,s

= γ||V − V ′ ||∞

Value Iteration Algorithm


From the previous section we learned that applying the Bellman operator to an arbitrary value
function repeatedly will get it closer and closer to the optimal value function. Once we see that
applying the Bellman operator no longer changes the input value function, we know we have found
the optimal value function. This is the basis for the value iteration algorithm which works as
follows:
1. Initialize V (s) = 0 ∀s
2. Repeat until convergence:P
V ′ (s) ← maxa [r(s, a) + γ s′ ∈S p(s′ |s, a)V (s′ )] ∀s
V (s) ← V ′ (s) ∀s
148 CHAPTER 11. MARKOV DECISION PROCESSES

It is a theorem that value iteration converges to V ∗ asymptotically. We can also extract the
optimal policy from the resulting value function asymptotically. Each iteration of value iteration
takes O(|S|2 |A|) time.

11.3.2 Policy Iteration


An alternate approach to finding the optimal policy in an infinite horizon MDP uses the same
contraction properties of the Bellman operator but iterates on the policy directly rather than on
the value function. The algorithm is as follows:
1. Initialize some arbitrary policy π
2. Repeat until convergence:
(a) Policy evaluation: Calculate V π (s) ∀s
(b) Policy improvement:
π ′ (s) ← arg maxa [r(s, a) + γ s′ ∈S p(s′ |s, a)V π (s′ )] ∀s
P
π(s) ← π ′ (s) ∀s
Here, convergence means that the improvement step does not modify the policy from the previous
iteration. It is a theorem that policy iteration converges to the optimal policy in a finite number
of steps and also yields the optimal value function.

Policy Evaluation
In the previous section, we saw that one of the steps in the policy iteration algorithm was to
calculate the value function for each state following a policy π. There are two ways to do this:
• Exact policy evaluation:
We know that our value function must satisfy the Bellman consistency equations:
X
V π (s) = r(s, π(s)) + γ p(s′ |s, π(s))V π (s′ ) ∀s
s∈S

This is a system of |S| linear equations that has a unique solution. Rearranging terms,
representing functions as matrices and vectors, and replacing sums with matrix multiplication
we get that:
Vπ = (I − γPπ )−1 Rπ
where Vπ is an |S| × 1 vector, I is the identity matrix, Pπ is an |S| × |S| matrix where
Pπs,s′ = p(s′ |s, π(s))), and Rπ is an |S| × 1 vector where Rπs = r(s, π(s)). This subroutine
runs in time O(|S|3 ).
• Iterative policy evaluation:
Rather than solving the system of equations described earlier exactly, we can instead do it
iteratively. We perform the following two steps:
h i
1. Initialize V0 such that ||V0 ||∞ ∈ 0, 1−γ
1

2. Repeat until convergence: Vt+1 ← R + γPπ Vt


 2
|S|
. This algorithm runs in time Õ 1−γ .
Policy iteration requires fewer iterations than value iteration to solve for the optimal policy π ∗ , but
requires more time per iteration: O(|S|2 |A| + |S|3 ) if we use exact policy evaluation.
Chapter 12

Reinforcement Learning

12.1 Motivation
In the last chapter, we discussed the Markov Decision Process (MDP): a framework that models a
learner’s environment as a vector of states, actions, rewards, and transition probabilities. Given this
model, we can solve for an optimal (reward-maximizing) policy using either value iteration or policy
iteration. Sometimes, however, the learner doesn’t have prior knowledge of the rewards they will
get from each state or the probability distribution over states they could end up in after taking some
action from their current state. Is it still possible to learn a policy that will maximize rewards? In
this chapter, we will learn about Reinforcement Learning (RL) - a machine learning technique
that addresses this problem.
ML Framework Cube: Reinforcement Learning

Domain Training Probabilistic


Discrete or Continuous ? Yes

Note that we will only cover RL for discrete state spaces in this chapter. The question mark under
the “training” category means that RL is neither supervised nor unsupervised.

12.2 General Approaches to RL


Imagine that you are vacationing in a town that has 20 restaurants. Suppose that if you had been
able to try the food at every restaurant you would have had a definitive ranking of which was
your favorite, second favorite, third favorite and so on. Your vacation is 10 days long and your
budget allows you to eat at one restaurant per day. Your goal on your vacation is to have the best
experience, which in this case means to eat at the restaurants you like the most. The problem is
that you don’t have enough time to figure out which restaurants you like the most since you don’t
have enough time to eat at each one. In this situation, you have two key considerations to balance:

1. Time spent exploring. If you spend the right amount of time exploring new restaurants,
then there is a good chance that you will find one that you like a lot. However, spending too
much time exploring will result in an average experience overall since some restaurants will
be good and some will be bad.

149
150 CHAPTER 12. REINFORCEMENT LEARNING

2. Time spent exploiting. If you spend the right amount of time exploiting your knowledge of
what restaurants you’ve liked the most from your exploration you will have a good experience.
However, if you spend too long exploiting (and consequently less time exploring) you risk
missing out on eating at all the restaurants that you could have liked more.

This is called the exploration vs exploitation tradeoff. The more time you spend exploring
the less time you can spend exploiting, and vice versa. In order to have the best vacation experi-
ence, you need to find a balance between time spent exploring different restaurants and time spent
exploiting what you have learned by eating at that restaurants that you liked the most thus far.
A reinforcement learner’s task is similar to the vacation restaurant task described above. They
must balance time spent exploring the environment by taking actions and observing rewards and
penalties, and time spent maximizing rewards based on what they know.

We now return to the problem described in the previous section: how do we learn a policy that
will maximize rewards in an MDP where we know nothing about which rewards we will get from
each state and which state we will end up in after taking an action? One approach is to try and
learn the MDP. In other words, we can create a learner that will explore its environment for some
amount of time in order to learn the rewards associated with each state and the transition prob-
abilities associated with each action. Then, they can use an MDP planning algorithm like policy
iteration or value iteration to come up with an optimal policy and maximize their rewards. This is
called model-based learning. The main advantage of model-based learning is that it can inexpen-
sively incorporate changes in the reward structure or transition function of the environment into
the model. However, model-based learning is computationally expensive, and an incorrect model
can yield a sub-optimal policy.

Conversely, model-free learning is a family of RL techniques where a learner tries to learn


the optimal policy directly without modelling the environment. While this sort of learning is less
adaptive to changes in the environment, it is far more computationally efficient than model-based
learning.

12.3 Model-Free Learning


While there are many methods within the family of model-free learning, we will focus on three
value-based methods - SARSA (section 12.3.1), Q-learning (section 12.3.1) and Deep Q-learning
(section 12.3.2) - and one policy-based method - Policy Learning (section 12.3.3).

We will begin by discussing value-based methods. In this family of RL algorithms, the learner tries
to calculate the expected reward they will receive from a state s upon taking action a. Formally,
they are trying to learn a function that maps (s, a) to some value representing the expected reward.
This is called the Q-function and is defined as follows for some policy π:
X
Qπ (s, a) = r(s, a) + γ p(s′ |s, a)V π (s′ ) (12.1)
s′

In words, the approximate expected reward (Q value) for taking action s from state a is the actual
reward received from the environment by doing so in the current iteration plus the expectation
taken over all reachable states of the highest value achievable starting at that state times the
12.3. MODEL-FREE LEARNING 151

discount factor γ. The Q-function following the optimal policy is defined analogously:
X
Q∗ (s, a) = r(s, a) + γ p(s′ |s, a)V ∗ (s′ ) (12.2)
s′

Note, that V ∗ (s′ ) = maxa′ Q∗ (s′ , a′ ) since the highest value achievable from state s′ following policy
∗ is the Q value of taking the optimal action from s′ . Substituting this in, we get the following
Bellman Equation:
X′
Q∗ (s, a) = r(s, a) + γ p(s′ |s, a) max

Q∗ (s′ , a) (12.3)
a
s

Note that we can’t directly calculate the term γ ′s p(s′ |s, a) max′a Q∗ (s′ , a) since we don’t know
P
p(s′ |s, a). We will discuss how this is addressed by the two algorithms we will cover in the value-
based family.

12.3.1 SARSA and Q-Learning


At a high level value-based model-free reinforcement learners work by initializing the values of
Q(s, a) for all states s and all actions a in an s x a matrix. They then repeat the following two
steps until satisfactory performance is achieved:

1. Act based on Q values

2. Use s (current state), a (action), r (reward), s′ (next state), a′ (action taken from next state)
in order to update the approximation of Q(s, a)

We will refer to (s, a, r, s′ , a′ ) as an experience. Let π(s) we the action that a learner takes from
state s. One strategy for acting that attempts to balance exploration and exploitation is called
ϵ-greedy and defined as follows:
(
argmaxa Q(s, a) with probability 1 - ϵ
π(s) = (12.4)
random with probability ϵ

Here, ϵ is some number ∈ [0, 1] which controls how likely the learner is to choose a random action
as opposed to the currently known optimal action. Varying the value of ϵ changes the balance
between exploration and exploitation.
Once the learner has had an experience, they can begin to learn Q∗ . We will now describe two

Figure 12.1: The process of model-free learning

algorithms which perform this update differently for every new experience:
152 CHAPTER 12. REINFORCEMENT LEARNING

1. SARSA: Q(s, a) ← αt (s, a)[r + γQ(s′ , a′ ) − Q(s, a)] where αt (s, a) is the learning rate, a
parameter which controls how much the observation affects Q(s, a).
The expression r + γQ(s′ , a′ ) is a 1-step estimate of Q(s, a). The expression in the square
brackets above is called the temporal difference (TD) error and represents the difference
between the previous estimate of Q(s, a) and the new one. Since the action a′ that we use
for our update is the one that was recommended by the policy π (recall that a′ = π(s′ )),
SARSA is an on-policy algorithm. This means that if there was no epsilon greedy action
selection, the reinforcement learner would always act according to the policy π and SARSA
would converge to V π .

2. Q-learning: Q(s, a) ← Q(s, a) + αt (s, a)[r + γ max′a (s′ , a′ ) − Q(s, a)]


Similar to SARSA, the expression in the square brackets is the TD error. Note, that in
Q-learning the update depends on the reward-maximizing action max′a (s′ , a′ ) corresponding
to the Bellman equation Q∗ rather than the policy-recommended action. This makes it an
off-policy algorithm. In other words, Q-learning is computing V ∗ which may or may not
correspond to actions recommended by the policy π.

This procedure of learning Q by updating Q(s, a) by the TD error is called TD updating.

Convergence Conditions
Let αt (s, a) = 0 for all (s, a) that are not visited at time t. It is a theorem that Q-learning converges
to Q∗ (and hence π converges to π ∗ ) as t → ∞ as long as the following two conditions are met:


P
t αt (s, a) = ∞ ∀s, a
The sum of the learning rate over infinitely many time steps must diverge. In order for this
to happen, each state-action pair (s, a) must be visited infinitely often. Thus, we see the
importance of an ϵ-greedy learner which forces the agent to probablistically take random
actions in order to explore more of the state space.

• 2
P
t α(s, a) < ∞ ∀s, a
The sum of the square of the learning rate over infinitely many time steps must converge.
Note that in order for αt (s, a)2 < ∞, the learning rate must be iteratively reduced. For
1
example, we could set αt (s, a) = Nt (s,a) where Nt (s, a) is the number of times the learner
took action a from state s.

SARSA converges to Q∗ if the above two conditions are met and behavior is greedy in the limit
(ie. ϵ → 0 as t → ∞). One common choice is ϵt (s) = Nt1(s) where Nt (s) is the number of times the
learner visited state s. The notation ϵt (s) implies that a separate ϵ value is maintained for every
state rather than just maintaining one value of ϵ that controls learner behavior across all states
(though this is also an option).

Citation: Chapter 6 of Reinforcement Learning: An Introduction by Sutton and Barto.

12.3.2 Deep Q-Networks


Sometimes, we have scenarios where there are too many state-action pairs to learn via traditional
TD updates in a reasonable amount of time. For example, consider the game of chess where there
are mind bogglingly many possible configurations of the board. In cases like these, we can use a
universal function approximator, like a convolutional neural network, to estimate the Q-function.
12.3. MODEL-FREE LEARNING 153

The neural network takes the learner’s current state as input and outputs the approximated Q-
values for taking each possible action from that state by training parameters w. The learner’s next
action is the maximum over all outputs of the neural network.

Figure 12.2: Q-learning vs Deep Q-learning.(Image citation: https://www.mlq.ai/deep-


reinforcement-learning-q-learning/)

While there are several specific variants depending on the problem that is trying to be solved, the
general loss function that the neural network tries to minimize at iteration i is as follows:

L(wi ) = E[(r + γ max



Q(s′ , a′ ; wi−1 ) − Q(s, a; wi ))2 ] (12.5)
a

Here, (s, a, r, s′ , a′ ) and γ are defined the same as in regular Q-learning. wi are the parameters
that the neural net is training during the current iteration of the RL algorithm, and wi−1 are
the optimal parameters from the previous iteration of the algorithm. The TD error is the term
r + γ maxa′ Q(s′ , a′ ; wi−1 ), and the squared term inside the expectation is the TD target. Since
directly optimizing the loss in equation 12.5 is difficult, gradient descent, specifically stochastic
gradient descent is used (in order to avoid having to calculate the entire expected value term in
12.5).

The Atari deep Q network that solved the game of brickbreaker used a technique called experi-
ence replay to make updates to the network more stable. The experience (s, a, r, s′ ) was put into a
replay buffer which was sampled from to perform minibatch gradient descent to minimize the loss
function.

12.3.3 Policy Learning


There are a few situations in which the RL techniques described thus far don’t work well or work
relatively less well than alternatives:

1. The Q function is far more complex than the policy being learned.

2. The action space is continuous.


154 CHAPTER 12. REINFORCEMENT LEARNING

3. Expert knowledge needs to be introduced.


This presents the need for another model-free method called policy learning in which we adopt
a differentiable policy πθ (a|s) parametrized by θ, and try to learn θ directly without going through
a Q function first.
We update θ iteratively according to the following equation:
θ ← θ + αt ∇θ J(θ)
where J(θ) = E[r(h)], r(h) = t rt , and h = (s, a, r, s′ , a′ , r′ , s′′ , aQ
′′ ...). The gradient term approx-
P
π
imates the performance V . We define the likelihood µθ (h) = t πθ (at |st )p(st+1 |st , a). We can
expand ∇θ (θ) as follows:
Z Z
∇θ J(θ) = ∇θ E[r(n)] = ∇θ µθ (h)r(h)δh = r(h)∇θ µθ (h)δh (12.6)
h h

We seem to have run into an issue here. The term ∇θ µθ (h) depends on the transition probability
p(∗|s, a) by definition, but point of policy learning was to avoid having to learn these transition
probabilities. However, it turns out we can circumvent this issue by applying the useful identity
1
∇ θ µθ = µ θ ∇θ µθ = µθ ∇θ ln µθ (12.7)
µθ
Thus, we can rewrite the previous equation as follows:
Z " #
X X
r(h)µθ (h)∇θ ln πθ (at |st ) + ln p(st+1 |st , a) δh (12.8)
h t t
" #
X
= E r(h) ∇θ ln πθ (at |st ) (12.9)
t
" #
X
=E ∇θ ln πθ (at |st )rt (12.10)
t

Equation 12.10 only involves πθ (at |st ) which we are trying to learn, and rt which we observe from
the environment. We can now perform SGD to find the optimal θ. Note policy learning is on-policy.

12.4 Model-Based Learning

Figure 12.3: The process of model-based learning.

Recall that in model-based learning, the learner tries to learn the parameters of the MDP underlying
their environment and then uses planning to come up with a policy. A model-based learner repeats
the following three steps until satisfactory performance is achieved:
12.5. CONCLUSION 155

1. Act according to some policy for M steps.

2. Learn/update a model of the underlying MDP based on the experiences gained from step 1.

3. Use the model from step 2 to plan for the next M steps.

In many implementations, “learning” a model means coming up with a maximum likelihood esti-
mate of the parameters of the underlying MDP: the reward function and transition probabilities
(we will not cover the specifics of this process in this section). We can now plan according to this
model and follow the recommended policy for the next M steps. One issue with this basic model
is that it doesn’t do anything to handle time spent exploring vs exploiting. While there are many
ways to address this in practice, we will now present three common approaches:

• Incorporate exploration directly into the policy, for example, by using an ϵ-greedy strategy
as discussed in 12.3.1.

• Optimism under uncertainty: Let N (s, a) be the number of times we visited the state-
action pair (s, a). Each “visit” is a time that the agent was in state s and chose to perform
action a. When we are learning, we will assume that if N (s, a) is small, then the next
state will have higher reward than predicted by the maximum likelihood model. Since the
model thinks taking visiting lesser-known state-action pairs will lead to high reward (we are
optimistic when we are uncertain), we’ll have a tendency to explore these lesser-known areas.

• Posterior sampling/Thompson sampling: We maintain a posterior distribution such as


a Dirichlet over the transition function p(∗|s, a) ∀s, a which we update using the experiences
from the previous M steps of acting rather than coming up with a maximum likelihood point
estimate. We then sample from this posterior to get a model, which we proceed to plan with.
The idea is that when we are less certain about the outcome of the transition function, there’s
a greater likelihood of the sampling leading us to explore. When we are more certain about
what the best outcome is, we’re more likely to just stick with that best outcome.

12.5 Conclusion
To recap, reinforcement learning is a machine learning technique that allows a learner to find an
optimal policy in an environment modeled by an MDP where the transition probabilities and reward
function are unknown. The learner gains information by interacting with the environment and
receiving rewards and uses this information to update the model they have of the underlying MDP
(model-based) or their beliefs about the value of taking particular actions from particular states
(value-based model-free). Q-learning and temporal difference updating are key topics in model-
free learning that underlie many techniques in the area including SARSA, Q-learning and deep
Q-learning. Contemporary RL models such as the one used by AlphaGo often use a combination
of supervised and reinforcement learning methods which can be endlessly customized to meet the
needs of a problem.

You might also like