CS181 Textbook of Machine Learning
The initial version of this textbook was created by William J. Deuschle for his senior thesis,
based on his notes of CS181 during the Spring of 2017. This textbook has since been maintained
by the CS181 course staff with bug fixes from many CS181 students.
Contents
2 Regression
  2.1 Defining the Problem
  2.2 Solution Options
    2.2.1 K-Nearest-Neighbors
    2.2.2 Neural Networks
    2.2.3 Random Forests
    2.2.4 Gradient Boosted Trees
    2.2.5 Turning to Linear Regression
  2.3 Introduction to Linear Regression
  2.4 Basic Setup
    2.4.1 Merging of Bias
    2.4.2 Visualization of Linear Regression
  2.5 Finding the Best Fitting Line
    2.5.1 Objective Functions and Loss
    2.5.2 Least Squares Loss
  2.6 Linear Regression Algorithms
    2.6.1 Optimal Weights via Matrix Differentiation
    2.6.2 Bayesian Solution: Maximum Likelihood Estimation
    2.6.3 Alternate Interpretation: Linear Regression as Projection
  2.7 Model Flexibility
    2.7.1 Basis Functions
    2.7.2 Regularization
3 Classification
  3.1 Defining the Problem
  3.2 Solution Options
  3.3 Discriminant Functions
    3.3.1 Basic Setup: Binary Linear Classification
    3.3.2 Multiple Classes
    3.3.3 Basis Changes in Classification
  3.4 Numerical Parameter Optimization and Gradient Descent
    3.4.1 Gradient Descent
    3.4.2 Batch Gradient Descent versus Stochastic Gradient Descent
  3.5 Objectives for Decision Boundaries
    3.5.1 0/1 Loss
    3.5.2 Least Squares Loss
    3.5.3 Hinge Loss
  3.6 Probabilistic Methods
    3.6.1 Probabilistic Discriminative Models
      Logistic Regression
      Multi-Class Logistic Regression and Softmax
    3.6.2 Probabilistic Generative Models
      Classification in the Generative Setting
      MLE Solution
      Naive Bayes
  3.7 Conclusion
4 Neural Networks
  4.1 Motivation
    4.1.1 Comparison to Other Methods
    4.1.2 Universal Function Approximation
  4.2 Feed-Forward Networks
6 Clustering
  6.1 Motivation
    6.1.1 Applications
  6.2 K-Means Clustering
7 Dimensionality Reduction
  7.1 Motivation
  7.2 Applications
  7.3 Principal Component Analysis
    7.3.1 Reconstruction Loss
    7.3.2 Minimizing Reconstruction Loss
    7.3.3 Multiple Principal Components
    7.3.4 Identifying Directions of Maximal Variance in our Data
    7.3.5 Choosing the Optimal Number of Principal Components
  7.4 Conclusion
come away with these two levels of understanding, you will be off to a good start. The third level
of understanding relates to having a derivational awareness of the algorithms and methods we will
make use of. This level of understanding is not strictly necessary to successfully interact with
existing machine learning capabilities, but it will be required if you desire to go further and deepen
existing knowledge. Thus, we will be presenting derivations, but they will be secondary to a high-level understanding of problem types and the practical intuition behind available solutions.
Definition 1.4.1 (Definition Explanation): You will find definitions in these dark gray boxes.
Derivation 1.4.1 (Derivation Explanation): You will find derivations in these light gray boxes.
⋆ You will find explanations for subtle or confusing concepts in these red wrapped boxes.
Chapter 2
Regression
A major component of machine learning, the one that most people associate with ML, is dedicated
to making predictions about a target given some inputs, such as predicting how much money an
individual will earn in their lifetime given their demographic information. In this chapter, we're going to focus on the case where our prediction is a continuous, real number. When the target is a real number, we call this prediction procedure regression.
2. Predicting the amount of time someone will take to pay back a loan given their credit history.
3. Predicting what time a package will arrive given current weather and traffic conditions.
Hopefully you are starting to see the pattern emerging here. Given some inputs, we need to
produce a prediction for a continuous output. That is exactly the purpose of regression. Notice
that regression isn’t any one technique in particular. It’s just a class of methods that helps us
achieve our overall goal of predicting a continuous output.
Definition 2.1.1 (Regression): A class of techniques that seeks to make predictions about un-
known continuous target variables given observed input variables.
2.2.1 K-Nearest-Neighbors
K-Nearest-Neighbors is an extremely intuitive, non-parametric technique for regression or classifi-
cation. It works as follows in the regression case:
1. Identify the K points in our data set that are closest to the new data point. ‘Closest’ is some
measure of distance, usually Euclidean.
2. Average the value of interest for those K data points.
3. Return that averaged value of interest: it is the prediction for our new data point.
⋆ A non-parametric model simply means we don’t make any assumptions about the form of our data. We only need
to use the data itself to make predictions.
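To make the three-step procedure above concrete, here is a minimal K-Nearest-Neighbors regression sketch in Python with NumPy. The toy data, the function name knn_regress, and the choice of Euclidean distance are illustrative assumptions, not part of the original text.

```python
import numpy as np

def knn_regress(X_train, y_train, x_new, K=3):
    """Predict a continuous target for x_new by averaging its K nearest neighbors."""
    # Step 1: Euclidean distance from x_new to every training point.
    dists = np.linalg.norm(X_train - x_new, axis=1)
    # Indices of the K closest training points.
    nearest = np.argsort(dists)[:K]
    # Steps 2 and 3: average the targets of those neighbors and return it.
    return y_train[nearest].mean()

# Toy example: five 1-dimensional points with targets roughly equal to 2x.
X_train = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]])
y_train = np.array([2.1, 3.9, 6.2, 8.1, 9.8])
print(knn_regress(X_train, y_train, np.array([2.5]), K=3))  # average target of the 3 nearest points
```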
⋆ Notice w0 in the expression above, which doesn't have a corresponding x0 value. This is known as the bias term. If you consider the definition of a line y = mx + b, the bias term corresponds to the intercept b. It accounts for data that has a non-zero mean.
Let’s illustrate how linear regression works using an example, considering the case of 10 year old
Sam. She is curious about how tall she will be when she grows up. She has a data set of parents’
heights and the final heights of their children. The inputs x are:
x1 = height of mother (cm)
x2 = height of father (cm)
Using linear regression, she determines the weights w to be:
w = [34, 0.39, 0.33]
Sam’s mother is 165 cm tall and her father is 185 cm tall. Using the results of the linear regression
solution, Sam solves for her expected height:
Sam’s height = 34 + 0.39(165) + 0.33(185) = 159.4 cm
Let’s inspect the categories linear regression falls into for our ML framework cube. First, as we’ve
already stated, linear regression deals with a continuous output domain. Second, our goal is
to make predictions on future data points, and to construct something capable of making those
predictions we first need a labeled data set of inputs and outputs. This makes linear regression a
supervised technique. Third and finally, linear regression is non-probabilistic. Note that there
also exist probabilistic interpretations of linear regression which we will discuss later in the chapter.
Domain: Continuous | Training: Supervised | Probabilistic: No
x = (165, 185)
We now add a 1 in the first position of the data point to make it:
x = (1, 165, 185)
We do this for every point in our data set. This bias trick lets us write our model as a single dot product:
y(x, w) = w⊤x
This is more compact, easier to reason about, and makes the linear algebra nicer for the calculations we will be performing.
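As a quick illustration of the bias trick, the sketch below (reusing the weights from Sam's example above; the second data point and all variable names are our own additions) prepends a column of ones to the data and computes every prediction as a single matrix-vector product.

```python
import numpy as np

# Weights from Sam's example: bias w0 = 34, then weights on mother's and father's heights.
w = np.array([34.0, 0.39, 0.33])

# Raw data points: (mother's height, father's height) in cm.
X = np.array([[165.0, 185.0],
              [170.0, 178.0]])

# Bias trick: prepend a 1 to every data point.
X_bias = np.hstack([np.ones((X.shape[0], 1)), X])

# Every prediction is now a single dot product w^T x.
predictions = X_bias @ w
print(predictions)  # first entry is 34 + 0.39*165 + 0.33*185 = 159.4
```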
Figure 2.2: Data set with clear trend, best fitting line included.
where we need only find a single bias term w0 (which acts as the intercept of the line) and a single weight w1 (which acts as the slope of the line). However, the same principle applies to higher dimensional data as well. We're always fitting the hyperplane that best predicts the data.
⋆ Although our input data points x can take on multiple dimensions, our output data y is always a 1-dimensional
real number when dealing with regression problems.
Now that we have some intuition for what linear regression is, a natural question arises: how
do we find the optimal values for w? That is the remaining focus of this chapter.
Now that we’ve defined our model as a weighted combination of our input variables, we need some
way to choose our value of w. To do this, we need an objective function.
Definition 2.5.1 (Objective Function): A function that measures the ‘goodness’ of a model.
We can optimize this function to identify the best possible model for our data.
As the definition explains, the purpose of an objective function is to measure how good a specific
model is. We can therefore optimize this function to find a good model. Note that in the case of
linear regression, our ‘model’ is just a setting of our parameters w.
An objective function will sometimes be referred to as loss. Loss actually measures how bad a
model is, and then our goal is to minimize it. It is common to think in terms of loss when discussing
linear regression, and we incur loss when the hyperplane we fit is far away from our data.
So how do we compute the loss for a specific setting of w? To do this, we often use residuals.
Definition 2.5.2 (Residual): The residual is the difference between the target (y) and predicted (y(x, w)) value that a model produces:
residual = y − y(x, w)
Commonly, loss is a function of the residuals produced by a model. For example, you can imagine
taking the absolute value of all of the residuals and adding those up to produce a measurement of
loss. This is sometimes referred to as L1 Loss. Or, you might square all of the residuals and then
add those up to produce loss, which is called L2 loss or least squares loss. You might also use some
combination of L1 and L2 loss. For the most part, these are the two most common forms of loss
you will see when discussing linear regression.
When minimized, these distinct measurements of loss will produce solutions for w that have
different properties. For example, L2 loss is not robust to outliers due to the fact that we are
squaring residuals. Furthermore, L2 loss will produce only a single solution while L1 loss can
potentially have many equivalent solutions. Finally, L1 loss produces unstable solutions, meaning
that for small changes in our data set, we may see large changes in our solution w.
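The sketch below, using made-up residual values of our own choosing, simply computes L1 and L2 loss from a vector of residuals; it is only meant to show how a single large residual (an outlier) dominates the squared loss.

```python
import numpy as np

# Residuals y_n - y(x_n, w) for five data points; the last one is an outlier.
residuals = np.array([0.5, -0.3, 0.2, -0.4, 5.0])

l1_loss = np.sum(np.abs(residuals))      # sum of absolute residuals
l2_loss = 0.5 * np.sum(residuals ** 2)   # least squares loss (with a conventional 1/2 factor)

print(l1_loss)  # 6.4   -> the outlier contributes 5.0 of this
print(l2_loss)  # 12.77 -> the outlier contributes 12.5 of this
```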
Loss is a concept that we will come back to very frequently in the context of supervised machine
learning methods. Before exploring exactly how we use loss to fit a line, let’s consider least squares
loss in greater depth.
⋆ The notation Ln(w) is used to indicate the loss incurred by a model w for a single data point (xn, yn). L(w) indicates the loss incurred for an entire data set by the model w. Be aware that this notation is sometimes inconsistent between different sources.
The least squares loss for a model w over our data set is given by:
L(w) = (1/2) ∑_{n=1}^N (yn − w⊤xn)²    (2.4)
There is a satisfying statistical interpretation for using this loss function which we will explain later in this chapter, but for now it will suffice to discuss some of the properties of this loss function that make it desirable.
First, notice that it will always take on positive values. This is convenient because we can focus
exclusively on minimizing our loss, and it also allows us to combine the loss incurred from different
data points without worrying about them cancelling out.
A more subtle but enormously important property of this loss function is that we know a lot about how to efficiently optimize quadratic functions. This is not a textbook about optimization, but some quick and dirty intuition that we will take advantage of throughout this book is that we can easily and reliably take the derivative of quadratic functions because they are continuously differentiable. We also know that optima of a quadratic function will be located at points where the derivative of the function is equal to 0, as seen in Figure 2.3. In contrast, L1 loss is not continuously differentiable over the entirety of its domain.
Figure 2.3: Quadratic function with clear optimum at x = 2, where the derivative of the function
is 0.
⋆ Note that we added a constant 1/2 to the beginning of our loss expression. This scales the loss, which will not change our final result for the optimal parameters. It has the benefit of making our calculations cleaner once we've taken the gradient of the loss.
We now want to solve for the values of w that minimize this expression.
Derivation 2.6.1 (Least Squares Optimal Weights Derivation): We find the optimal
weights w∗ as follows:
Start by taking the gradient of the loss with respect to our parameter w:
∇L(w) = ∑_{n=1}^N (yn − w⊤xn)(−xn)
At this point, it is convenient to rewrite these summations as matrix operations, making use of
design matrix X (N × D) and target values y (N × 1). We have
X⊤y = ∑_{n=1}^N yn xn,    X⊤Xw = ∑_{n=1}^N xn(xn⊤w)
Substituting, we have
X⊤ y − X⊤ Xw = 0.
For this to be well defined, we need X to have full column rank (the features are not collinear) so that X⊤X is positive definite and the inverse exists. Solving for w then gives the optimal weights:
w∗ = (X⊤X)−1X⊤y    (2.7)
The quantity (X⊤X)−1X⊤ in Derivation 2.6.1 has a special name: the Moore-Penrose pseudoinverse. You can think of it as a generalization of the matrix inversion operation to non-square matrices.
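Below is a minimal sketch of the closed-form solution on synthetic data of our own invention; in practice one would solve the normal equations with a linear solver (as here) or use np.linalg.lstsq rather than forming the inverse explicitly.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data: y = 2 + 3x plus Gaussian noise.
N = 50
x = rng.uniform(0, 10, size=N)
y = 2.0 + 3.0 * x + rng.normal(0, 1, size=N)

# Design matrix with the bias trick (column of ones).
X = np.column_stack([np.ones(N), x])

# w* = (X^T X)^{-1} X^T y, i.e. the Moore-Penrose pseudoinverse applied to y.
w_star = np.linalg.solve(X.T @ X, X.T @ y)
print(w_star)  # approximately [2, 3]
```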
yn ∼ N (w⊤ xn , β −1 )
The interpretation of this is that our target value y is generated according to a linear combination
of our inputs x, but there is also some noise in the data generating process described by the variance
parameter β −1 . It’s an acknowledgement that some noise exists naturally in our data.
⋆ It’s common to write variance as an inverse term, such as β −1 . The parameter β is then known as the precision,
which is sometimes easier to work with than the variance.
As before, we now ask the question: how do we solve for the optimal weights w? One approach
we can take is to maximize the probability of observing our target data y. This technique is known
as maximum likelihood estimation.
We then take the logarithm of the likelihood, and since the logarithm is a strictly increasing,
continuous function, this will not change our optimal weights w:
ln p(y|X, w, β) = ∑_{n=1}^N ln N(w⊤xn, β−1)
Notice that this is a quadratic function in w, which means that we can solve for it by taking the
derivative with respect to w, setting that expression to 0, and solving for w:
∂ ln p(y|X, w, β) / ∂w = −β ∑_{n=1}^N (yn − w⊤xn)(−xn)
⇔ ∑_{n=1}^N yn xn − ∑_{n=1}^N (w⊤xn)xn = 0.
Notice that this is exactly the same form as Equation 2.6. Solving for w as before, we have:
Notice that our final solution is exactly the same form as the solution in Equation 2.7, which
we solved for by minimizing the least squares loss! The takeaway here is that minimizing
a least squares loss function is equivalent to maximizing the probability under the
assumption of a linear model with Gaussian noise.
w∗ = (X⊤X)−1X⊤y
In the special case where X is square and invertible, this simplifies as:
w∗ = X−1y
Xw∗ = XX−1y
Xw∗ = y
We were able to recover our targets y exactly because X is an invertible transformation. However, in the general case where X is not invertible and we have to use the approximate pseudoinverse (X⊤X)−1X⊤, we instead recover ŷ = Xw∗ = X(X⊤X)−1X⊤y,
where ŷ can be thought of as the closest projection of y onto the column space of X.
Furthermore, this motivates the intuition that w∗ is the set of coefficients that best transforms
our input space X into our target values y.
Definition 2.7.1 (Basis Function): Typically denoted by the symbol ϕ(·), a basis function is
a transformation applied to an input data point x to move our data into a different input basis,
which is another phrase for input domain.
x = (x(1) , x(2) )′
We may choose our basis function ϕ(x) such that our transformed data point in its new basis is:
ϕ(x) = (x(1), (x(1))², x(2), sin(x(2)))′
Using a basis function is so common that we will sometimes describe our input data points as
ϕ = (ϕ(1) , ϕ(2) , ..., ϕ(D) )′ .
⋆ The notation x = (x(1) , x(2) )′ is a way to describe the dimensions of a single data point x. The term x(1) is the
first dimension of a data point x, while x1 is the first data point in a data set.
Basis functions are very general - they could specify that we just keep our input data the same.
As a result, it’s common to rewrite the least squares loss function from Equation 2.4 for linear
regression in terms of the basis function applied to our input data:
L(w) = (1/2) ∑_{n=1}^N (yn − w⊤ϕn)²    (2.10)
To motivate why we might need basis functions for performing linear regression, let’s consider
this graph of 1-dimensional inputs X along with their target outputs y, presented in Figure 2.4.
As we can see, we’re not going to be able to fit a good line to this data. The best we can hope
to do is something like that of Figure 2.5.
However, if we just apply a simple basis function to our data, in this case the square root function ϕ(x) = (√x(1))′, we then have the red line in Figure 2.6. We now see that we can fit a very good line to our data, thanks to basis functions.
Figure 2.5: Data with no basis function applied, attempt to fit a line.
Figure 2.7: Data set with a clear trend and Gaussian noise.
Still, the logical question remains: how can I choose the appropriate basis function? This toy
example had a very obviously good basis function, but in general with high-dimensional, messy
input data, how do we choose the basis function we need?
The answer is that this is not an easy problem to solve. Often, you may have some domain
specific knowledge that tells you to try a certain basis, such as if you’re working with chemical data
and know that an important equation involves a certain function of one of your inputs. However,
more often than not we won’t have this expert knowledge either. Later, in the chapter on neural
networks, we will discuss methods for discovering the best basis functions for our data automatically.
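Here is a small sketch of fitting with and without a basis function, assuming data generated from a square-root trend similar to the toy example above; the data, function names, and noise level are our own choices.

```python
import numpy as np

rng = np.random.default_rng(1)

# 1-dimensional inputs whose targets follow a square-root trend plus noise.
x = rng.uniform(1, 100, size=80)
y = 4.0 * np.sqrt(x) + rng.normal(0, 1, size=80)

def fit(design, targets):
    """Least squares fit via the normal equations."""
    return np.linalg.solve(design.T @ design, design.T @ targets)

# No basis function: fit a line to x directly.
X_plain = np.column_stack([np.ones_like(x), x])
w_plain = fit(X_plain, y)

# Square-root basis: phi(x) = (sqrt(x))'.
X_basis = np.column_stack([np.ones_like(x), np.sqrt(x)])
w_basis = fit(X_basis, y)

def mse(design, w):
    return np.mean((y - design @ w) ** 2)

print(mse(X_plain, w_plain), mse(X_basis, w_basis))  # the basis fit has noticeably lower error
```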
2.7.2 Regularization
When we introduced the idea of basis functions above, you might have wondered why we didn’t
just try adding many basis transformations to our input data to find a good transformation. For
example, we might use this large basis function on a D-dimensional data point z:
ϕ(z) = (z(1), (z(1))², ..., (z(1))¹⁰⁰, z(2), (z(2))², ..., (z(2))¹⁰⁰, ..., z(D), (z(D))², ..., (z(D))¹⁰⁰)′
where you can see that we expand the dimensions of the data point to be 100 times its original
size.
Let’s say we have an input data point x that is 1-dimensional, and we apply the basis function
described above, so that after the transformation each data point is represented by 100 values. Say
we have 100 data points on which to perform linear regression, and because our transformed input
space has 100 values, we have 100 parameters to fit. In this case, with one parameter per data
point, it’s possible for us to fit our regression line perfectly to our data so that we have no loss!
But is this a desirable outcome? The answer is no, and we’ll provide a visual example to illustrate
that.
Imagine Figure 2.7 is our data set. There is a very clear trend in this data, and you would likely
draw a line that looks something like that of Figure 2.8 to fit it.
However, imagine we performed a large basis transformation like the one described above. If we
do that, it’s possible for us to fit our line perfectly, threading every data point, like that in Figure
2.9.
Let’s see how both of these would perform on new data points. With our first regression line,
if we have a new data point x = (10)′, we would predict a target value of 14.1, which most people would agree is a reasonable prediction. However, with the second regression line, we would
predict a value of 9.5, which most people would agree does not describe the general trend in the
data. So how can we handle this problem elegantly?
Examining our loss function, we see that right now we’re only penalizing predictions that are
not correct in training. However, what we ultimately care about is doing well on new data points,
not just our training set. This leads us to the idea of generalization.
A convoluted line that matches the noise of our training set exactly isn’t going to generalize
well to new data points that don't look exactly like those found in our training set. If we wish to avoid recovering a convoluted line as our solution, we should also penalize the total size of our weights
w. The effect of this is to discourage many complex weight values that produce a messy regression
line. By penalizing large weights, we favor simple regression lines like the one in Figure 2.8 that
take advantage of only the most important basis functions.
The concept that we are introducing, penalizing large weights, is an example of what’s known as
regularization, and it’s one that we will see come up often in different machine learning methods.
There is obviously a tradeoff between how aggressively we regularize our weights and how tightly
our solution fits to our data, and we will formalize this tradeoff in the next section. However, for
now, we will simply introduce a regularization parameter λ to our least squares loss function:
L(w) = (1/2) ∑_{n=1}^N (yn − w⊤ϕn)² + (λ/2) w⊤w    (2.11)
The effect of λ is to penalize large weight parameters. The larger λ is, the more we will favor simple solutions. In the limit λ → ∞, we drive all weights to 0, while with λ = 0 we apply no regularization at all. Notice that we're squaring our weight parameters - this is known as L2 norm regularization or ridge regression. While L2 norm regularization is very common, it is just one example of many ways we can perform regularization.
To build some intuition about the effect of this regularization parameter, examine Figure 2.10.
Notice how larger values of λ produce less complex lines, which is the result of applying more
regularization. This is very nice for the problem we started with - needing a way to choose which
basis functions we wanted to use. With regularization, we can select many basis functions, and then
allow regularization to ‘prune’ the ones that aren’t meaningful (by driving their weight parameters
to 0). While this doesn’t mean that we should use as many basis transformations as possible (there
will be computational overhead for doing this), it does allow us to create a much more flexible
linear regression model without creating a convoluted regression line.
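A sketch of ridge regression in closed form is shown below, assuming a polynomial basis expansion we made up for the example; the regularized normal equations simply add λ to the diagonal. Whether or not to regularize the bias column is a modeling choice we gloss over here.

```python
import numpy as np

def ridge_fit(Phi, y, lam):
    """Closed-form ridge solution: w = (Phi^T Phi + lambda * I)^{-1} Phi^T y."""
    D = Phi.shape[1]
    return np.linalg.solve(Phi.T @ Phi + lam * np.eye(D), Phi.T @ y)

# Example: a degree-9 polynomial basis on a handful of noisy points.
rng = np.random.default_rng(2)
x = np.linspace(0, 1, 12)
y = np.sin(2 * np.pi * x) + rng.normal(0, 0.2, size=x.size)
Phi = np.vander(x, 10, increasing=True)   # columns 1, x, x^2, ..., x^9

w_light = ridge_fit(Phi, y, lam=1e-4)     # very little regularization
w_heavy = ridge_fit(Phi, y, lam=10.0)     # heavy regularization
print(np.abs(w_light).max(), np.abs(w_heavy).max())  # heavier regularization gives smaller weights
```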
Figure 2.10: Effect of different regularization parameter values on final regression solution.
where the (λ/2)w⊤ w term is for the regularization. We can generalize our type of regularization
by writing it as:
L(w) = (1/2) ∑_{n=1}^N (yn − w⊤ϕn)² + (λ/2) ‖w‖_h^h
where h determines the type of regularization we are using and thus the form of the optimal solution
that we recover. For example, if h = 2 then we add λ/2 times the square of the L2 norm. The
three most commonly used forms of regularization are lasso, ridge, and elastic net.
Ridge Regression
This is the case of h = 2, which we’ve already discussed, but what type of solutions does it tend
to recover? Ridge regression prevents any individual weight from growing too large, providing us
with solutions that are generally moderate.
Lasso Regression
Lasso regression is the case of h = 1. Unlike ridge regression, lasso regression will drive some
parameters wi to zero if they aren’t informative for our final solution. Thus, lasso regression is
good if you wish to recover a sparse solution that will allow you to throw out some of your basis
functions. You can see the forms of ridge and lasso regression functions in Figure 2.11. If you think
about how Lasso is L1 Norm (absolute value) and Ridge is L2 Norm (squared distance), you can
think of those shapes as being the set of points (w1, w2) for which the norm takes on a constant value.
Figure 2.11: Form of the ridge (blue) and lasso (red) regression functions.
Elastic Net
Elastic net is a middle ground between ridge and lasso regression, which it achieves by using a linear
combination of the previous two regularization terms. Depending on how heavily each regulariza-
tion term is weighted, this can produce results on a spectrum between lasso and ridge regression.
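For concreteness, one common way to write the elastic net objective is sketched below; the use of two separate strengths λ1 and λ2 is our parameterization (other sources use a single λ with a mixing weight), so treat it as one possible form rather than the definitive one:
L(w) = (1/2) ∑_{n=1}^N (yn − w⊤ϕn)² + λ1 ∑_{d=1}^D |wd| + (λ2/2) ∑_{d=1}^D wd²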
w ∼ N(0, S0−1)
Remember from Equation 2.8 that the distribution over our observed data is Normal as well, written
here in terms of our entire data set:
p(y|X, w, β) = N (Xw, β −1 I)
We want to combine the likelihood and the prior to recover the posterior distribution of w, which follows directly from Bayes' Theorem:
p(w|X, y, β) ∝ p(y|X, w, β) p(w)
We now wish to find the value of w that maximizes the posterior distribution. We can maximize
the log of the posterior with respect to w, which simplifies the problem slightly:
where C2 collects more constant terms that don’t depend on w. Let’s now handle ln p(w):
Since β > 0, notice that maximizing the posterior probability with respect to w is equivalent to minimizing the sum of squared errors (yn − w⊤xn)² plus the regularization term (1/β) w⊤S0 w.
The interpretation of this in the case S0 = sI is that adding a prior over the distribution of our weight parameters w and then maximizing the resulting posterior distribution is equivalent to ridge regression with λ = s/β, as the regularization term simplifies to (s/β) w⊤w. Recall that both s and β are precisions. This says that if the precision β on the observations is small, then the regularization term is relatively large, meaning our posterior leans more towards our prior. If the precision s on our prior is small, then the regularization is small, meaning our posterior leans more towards the data.
Obviously a good solution will fall somewhere in between these two extremes of high variance
and high bias. Indeed, we have techniques like regularization to help us balance the two extremes
(improving generalization), and we have other techniques like cross-validation that help us deter-
mine when we have found a good balance (measuring generalization).
⋆ In case you are not familiar with the terms bias and variance, we provide their statistical definitions here, for an estimator θ̂ of a true quantity θ:
bias(θ̂) = E[θ̂] − θ
Var(θ̂) = E[(θ̂ − E[θ̂])²]
Before we discuss how to effectively mediate between these opposing forces of error in our models,
we will first show that the bias-variance tradeoff is not only conceptual but also has probabilistic
underpinnings. Specifically, any loss that we incur over our training set using a given model can
be described in terms of bias and variance, as we will demonstrate now.
Start with the expected squared error (MSE), where the expectation is taken with respect to both our data set D (variation in our modeling error comes from which data set we get), which is a random variable of (x, y) pairs sampled from a distribution F, and our conditional distribution y|x (there may be additional error because the data are noisy):
MSE = ED,y|x[(y − fD(x))²]
where we use the notation fD to explicitly acknowledge the dependence of our fitted model f on
the dataset D. For reasons that will become clear in a few steps, add and subtract our target mean
ȳ, which is the true conditional mean given by ȳ = Ey|x[y], inside of the squared term:
MSE = ED,y|x[(y − ȳ + ȳ − fD(x))²]
Group together the first two terms and the last two terms, then expand the square:
MSE = ED,y|x [(y − ȳ)2 ] + ED,y|x [(ȳ − fD (x))2 ] + 2ED,y|x [(y − ȳ)(ȳ − fD (x))] (2.12)
Let’s examine the last term, 2E[(y − ȳ)(ȳ − fD (x))]. Notice that (ȳ − fD (x)) does not depend on the
conditional distribution y|x at all. Thus, we are able to move one of those expecations in, which
makes this term:
where we have removed expectations that do not apply (e.g. ȳ does not depend on the dataset D).
We now have two terms contributing to our squared error. We will put aside the first term
Ey|x [(y − ȳ)2 ], as this is unidentifiable noise in our data set. In other words, our data will randomly
deviate from the mean in ways we cannot predict. On the other hand, we can work with the
second term ED[(ȳ − fD(x))²] as it involves our model function f(·).
As before, for reasons that will become clear in a few steps, let’s add and subtract our prediction
mean f¯(·) = ED [fD (x)], which is the expectation of our model function taken with respect to our
random data set.
ED [(ȳ − fD (x))2 ] = (ȳ − f¯(x))2 + ED [(f¯(x) − fD (x))2 ] + 2ED [(ȳ − f¯(x))(f¯(x) − fD (x))]
2ED [(ȳ − f¯(x))(f¯(x) − fD (x))] = 2(ȳ − f¯(x))ED [(f¯(x) − fD (x))] = 2(ȳ − f¯(x))(0) = 0
Notice the form of these two terms. The first one, (ȳ − f¯(x))2 , is the squared bias of our model,
since it is the square of the average difference between our prediction and the true target value.
The second one, ED [(f¯(x) − fD (x))2 ], is the variance of our model, since it is the expected squared
difference between our model and its average value. Thus:
ED[(ȳ − fD(x))²] = (bias)² + variance
Thus, our total squared error, plugging into Equation 2.13, can be written as:
MSE = Ey|x[(y − ȳ)²] + (ȳ − f̄(x))² + ED[(f̄(x) − fD(x))²] = noise + (bias)² + variance
Figure 2.12: Bias and variance both contribute to the overall error of our model.
The key takeaway of the bias-variance decomposition is that the controllable error in our model
is given by the squared bias and variance. Holding our error constant, to decrease bias requires
increasing the variance in our model, and vice-versa. In general, a graph of the source of error in
our model might look something like Figure 2.12.
For a moment, consider what happens on the far left side of this graph. Our variance is very
high, and our bias is very low. In effect, we’re fitting perfectly to all of the data in our data set.
This is exactly why we introduced the idea of regularization from before - we’re fitting a very
convoluted line that is able to pass through all of our data but which doesn’t generalize well to new
data points. There is a name for this: overfitting.
The opposite idea, underfitting, is what happens at the far right of the graph: we have high
bias and aren’t responding to the variation in our data set at all.
Definition 2.8.3 (Underfitting): A phenomenon where we construct a model that doesn’t re-
spond to variation in our data.
So you can hopefully now see that the bias-variance tradeoff is important to managing the
problem of overfitting and underfitting. Too much variance in our model and we’ll overfit to our
data set. Too much bias and we won’t account for the trends in our data set at all.
In general, we would like to find a sweet spot of moderate bias and variance that produces
minimal error. In the next section, we will explore how we find this sweet spot.
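Before moving on, the simulation sketch below (all numbers and model choices are ours, not the textbook's) estimates the bias and variance of two polynomial models at a single test point by repeatedly drawing data sets D, fitting, and comparing predictions to the true conditional mean.

```python
import numpy as np

rng = np.random.default_rng(3)

def true_mean(x):
    return np.sin(2 * np.pi * x)

def draw_dataset(n=20):
    x = rng.uniform(0, 1, size=n)
    y = true_mean(x) + rng.normal(0, 0.3, size=n)
    return x, y

def fit_poly(x, y, degree):
    Phi = np.vander(x, degree + 1, increasing=True)
    w, *_ = np.linalg.lstsq(Phi, y, rcond=None)
    return w

x_test = 0.3
y_bar = true_mean(x_test)          # the true conditional mean at the test point

for degree in (1, 9):
    preds = []
    for _ in range(500):           # many independent data sets D
        x, y = draw_dataset()
        w = fit_poly(x, y, degree)
        preds.append(np.polyval(w[::-1], x_test))
    preds = np.array(preds)
    bias_sq = (y_bar - preds.mean()) ** 2   # (ybar - fbar(x))^2
    variance = preds.var()                  # E_D[(fbar(x) - f_D(x))^2]
    print(degree, bias_sq, variance)
# The degree-1 model shows higher bias; the degree-9 model shows higher variance.
```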
2.8.2 Cross-Validation
We’ve seen that in choosing a model, we incur error that can be described in terms of bias and
variance. We’ve also seen that we can regulate the source of error through regularization, where
heavier regularization increases the bias of our model. A natural question then is how do we know
how much regularization to apply to achieve a good balance of bias and variance?
Another way to look at this is that we’ve traded the question of finding the optimal number of
basis functions for finding the optimal value of the regularization parameter λ, which is often an
easier problem in most contexts.
One very general technique for finding the sweet spot of our regularization parameter, other hy-
perparameters, or even for choosing among entirely different models is known as cross-validation.
Definition 2.8.4 (Cross-Validation): A subsampling procedure used over a data set to tune
hyperparameters and avoid over-fitting. Some portion of a data set (10-20% is common) is set aside,
and training is performed on the remaining, larger portion of data. When training is complete,
the smaller portion of data left out of training is used for testing. The larger portion of data is
sometimes referred to as the training set, and the smaller portion is sometimes referred to as the
validation set.
Cross-validation is often performed more than once for a given setting of hyperparameters to
avoid a skewed set of validation data being selected by chance. In K-folds cross-validation, you perform cross-validation K times, allocating 1/K of your data for the validation set at each iteration.
Let’s tie this back into finding a good regularization parameter. For a given value of λ, we will
incur a certain amount of error in our model. We can measure this error using cross-validation,
where we train our model on the training set and compute the final error using the validation set.
To find the optimal value for λ, we perform cross-validation using different values of λ, eventually
settling on the value that produces the lowest final error. This will effectively trade off bias and
variance, finding the value of λ that minimizes the total error.
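The sketch below uses K-fold cross-validation to choose λ for ridge regression; the synthetic data, the polynomial basis, the grid of λ values, and K = 5 are all assumptions we made for illustration.

```python
import numpy as np

rng = np.random.default_rng(4)
x = np.linspace(0, 1, 30)
y = np.sin(2 * np.pi * x) + rng.normal(0, 0.2, size=x.size)
Phi = np.vander(x, 8, increasing=True)   # polynomial basis of degree 7

def ridge_fit(Phi, y, lam):
    return np.linalg.solve(Phi.T @ Phi + lam * np.eye(Phi.shape[1]), Phi.T @ y)

def kfold_cv_error(Phi, y, lam, K=5):
    """Average validation MSE over K folds for one regularization strength."""
    idx = rng.permutation(len(y))
    folds = np.array_split(idx, K)
    errors = []
    for k in range(K):
        val = folds[k]
        train = np.concatenate([folds[j] for j in range(K) if j != k])
        w = ridge_fit(Phi[train], y[train], lam)
        errors.append(np.mean((y[val] - Phi[val] @ w) ** 2))
    return np.mean(errors)

# Pick the lambda with the lowest cross-validated error.
lambdas = [1e-6, 1e-4, 1e-2, 1e-1, 1.0, 10.0]
best_lam = min(lambdas, key=lambda lam: kfold_cv_error(Phi, y, lam))
print(best_lam)
```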
You might wonder why we need to perform cross-validation at all - why can’t we train on the
entire data set and then compute the error over the entire data set as well?
The answer is again overfitting. If we train over the entire data set and then validate our results
on the exact same data set, we are likely to choose a regularization parameter that encourages our
model to conform to the exact variation in our data set instead of finding the generalizable trends.
By training on one set of data, and then validating on a completely different set of data, we
force our model to find good generalizations in our data set. This ultimately allows us to pick
the regularization term λ that finds the sweet spot between bias and variance, overfitting and
underfitting.
least have confidence your model achieved the best generalizability that could be proven through
cross-validation.
where p(m) is our prior certainty for a given model and p(X|m) is the likelihood of our data set
given that model. The elegance of this approach is that we don’t have to pick any particular model,
instead choosing to marginalize out our uncertainty.
Derivation 2.9.1 (Posterior Predictive Derivation): For the sake of simplicity and ease of
use, we will select our prior over w to be a Normal distribution with mean µ0 and covariance S0−1:
p(w) = N(µ0, S0−1)
Remembering that the observed data is normally distributed, and accounting for Normal-Normal
conjugacy, our posterior distribution will be Normal as well:
p(w|X, y, β) = N(µN, SN−1)
where
SN = S0 + βX⊤X
µN = SN−1(S0µ0 + βX⊤y)
We now have a posterior distribution over w. However, usually this distribution is not what we
care about. We’re actually interested in making a point prediction for the target y ∗ given a new
input x∗ . How do we go from a posterior distribution over w to this prediction?
The answer is using what’s known as the posterior predictive over y ∗ given by:
p(y∗|x∗, X, y) = ∫_w p(y∗|x∗, w) p(w|X, y) dw = ∫_w N(y∗|w⊤x∗, β−1) N(w|µN, SN−1) dw    (2.14)
The idea here is to average the probability of y∗ over all the possible settings of w, weighting the probabilities by how likely each setting of w is according to its posterior distribution.
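The sketch below computes the posterior update and the resulting predictive mean and variance, assuming a zero-mean isotropic prior S0 = sI and synthetic data of our own choosing; the predictive variance β−1 + x∗⊤SN−1x∗ is the standard closed-form Gaussian result under these assumptions.

```python
import numpy as np

rng = np.random.default_rng(5)

# Synthetic data from y = 1 + 2x with Gaussian noise of precision beta.
beta = 25.0                                  # noise precision (variance 0.04)
x = rng.uniform(-1, 1, size=30)
y = 1.0 + 2.0 * x + rng.normal(0, 1 / np.sqrt(beta), size=x.size)
X = np.column_stack([np.ones_like(x), x])    # bias trick

# Prior: w ~ N(0, S0^{-1}) with S0 = s * I (s is the prior precision).
s = 1.0
S0 = s * np.eye(2)

# Posterior precision and mean from Normal-Normal conjugacy.
SN = S0 + beta * X.T @ X
mu_N = np.linalg.solve(SN, beta * X.T @ y)

# Posterior predictive at a new input x*: mean mu_N^T x*, variance 1/beta + x*^T SN^{-1} x*.
x_star = np.array([1.0, 0.5])
pred_mean = mu_N @ x_star
pred_var = 1.0 / beta + x_star @ np.linalg.solve(SN, x_star)
print(pred_mean, pred_var)
```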
2.10 Conclusion
In this chapter, we looked at a specific tool for handling regression problems known as linear regres-
sion. We’ve seen linear regression described in terms of loss functions, probabilistic expressions,
and geometric projections, which reflects the deep body of knowledge that we have around this
very common technique.
We’ve also discussed many concepts in this chapter that will prove useful in other areas of
machine learning, particularly for other supervised techniques: loss functions, regularization, bias
and variance, over and underfitting, posterior distributions, maximum likelihood estimation, and
cross-validation among others. Spending time to develop an understanding of these concepts now
will pay off going forward.
It may or may not be obvious at this point that we are missing a technique for a very large class
of problems: those where the solution is not just a continuous, real number. How do we handle
situations where we need to make a choice between different discrete options? This is the question
we will turn to in the next chapter.
Chapter 3
Classification
In the last chapter we explored ways of predicting a continuous, real-number target. In this chapter,
we’re going to think about a different problem- one where our target output is discrete-valued. This
type of problem, one where we make a prediction by choosing between finite class options, is known
as classification.
The point of classification is hopefully clear: we’re trying to identify the most appropriate class for
an input data point.
Definition 3.1.1 (Classification): A set of problems that seeks to make predictions about un-
observed target classes given observed input variables.
As with linear regression, discriminant functions h(x, w) seek to find a weighted combination of
our input variables to make a prediction about the target class:
h(x, w) = w(0) x(0) + w(1) x(1) + ... + w(D) x(D) (3.1)
where we are using the bias trick of appending x(0) = 1 to all of our data points.
Definition 3.3.1 (Decision Boundary): The decision boundary is the line that divides the input
space into different target classes. It is learned from an initial data set, and then the target class
of new data points can be predicted based on where they fall relative to the decision boundary. At
the decision boundary, the discriminant function takes on a value of 0.
⋆ You will sometimes see the term decision surface in place of decision boundary, particularly if the input space is
larger than two dimensions.
determine whether a given point is more likely to be in class Cj or class Ck . This is known as a
one-versus-one approach, and it also doesn’t work because we again end up with ambiguous regions
as demonstrated in Figure 3.3.
Instead, we can avoid these ambiguities in the multi-class case by using K different linear
classifiers hk (x, wk ), and then assigning new data points to the class Ck for which hk (x, wk ) >
hj (x, wj ) for all j ̸= k. Then, similar to the two-class case, the decision boundaries are described
by the surface along which hk (x, wk ) = hj (x, wj ).
Now that we’ve explored the multi-class generalization, we can consider how to learn the weights
w that define the optimal discriminant functions. However, prior to solving for w, we need to discuss
how basis transformations apply to classification problems.
Figure 3.4: Data set without any basis functions applied, not linearly separable.
now linearly separable by a plane between the two classes. Applying a generic basis change ϕ(·),
we can write our generalized linear model as:
hk(x, wk) = wk⊤ϕ(x) = wk⊤ϕ    (3.2)
For the sake of simplicity in the rest of this chapter, we will leave out any basis changes in our
derivations, but you should recognize that they could be applied to any of our input data to make
the problems more tractable.
⋆ For an input matrix X, there is a matrix generalization of our basis transformed inputs: Φ = ϕ(X), where Φ is
known as the design matrix.
Figure 3.5: Data set with basis functions applied, now linearly separable.
⋆ The terms numerical and analytical procedures come up very frequently in machine learning literature. An analytical
solution typically utilizes a closed form equation that accepts your model and input data and returns a solution in
the form of optimized model parameters. On the other hand, numerical solutions are those that require some
sort of iteration to move toward an ever better solution, eventually stopping once the solution is deemed ‘good
enough’. Analytical solutions are typically more desirable than numerical solutions due to computational efficiency
and performance guarantees, but they often are not possible for complex problems due to non-convexity.
The high level idea behind gradient descent is as follows: to update our parameters, we take
a small step in the opposite direction of the gradient of our objective function with respect to the
weight parameters w(t). Notationally, this looks like the following:
w(t+1) = w(t) − η∇L(w(t))
where w(t) corresponds to the state of the parameters w at time t, ∇L(w(t)) is the gradient of our objective function evaluated at w(t), and η > 0 is known as the learning rate. Note that the parameter values at time t = 0 given by w(0) are often initialized randomly.
⋆ In general, we want a learning rate that is large enough so that we make progress toward reaching a better solution,
but not so large that we take a step that puts us in a worse place in the parameter space than we were at the previous
step. Notice in Figure 3.6 that an appropriately small step size improves our objective function, while a large step
size overshoots the update and leaves us in a worse position.
Why take a step in the opposite direction of the gradient of the objective function? You can
think of the objective function as a hill, and the current state of our parameters w(t) is our position
on that hill. The gradient tells us the steepest direction of increase in the objective function (i.e.
it specifies the direction that will make our model worse). Since we want to minimize the objective
function, we choose to move away from the direction of the gradient, sending our model down the
hill towards an area of lower error. We typically cease optimization when our updates become
sufficiently small, indicating that we’ve reached a local minimum. Note that it’s a good idea to run
gradient descent multiple times to settle on a final value for w, ideally initializing w(0) to a different
starting value each time, because we are optimizing a function with multiple local minima.
and so on. Now that we have the idea of one-hot encoding, we can describe our target classes
for each data point in terms of a one-hot encoded vector, which can then be used in our training
process for least squares.
Each class Ck gets its own linear function with a different set of weights wk :
hk(x, wk) = wk⊤x
We can combine the set of weights for each class into a matrix W, which gives us our linear classifier:
h(x, W) = W⊤ x (3.6)
where each row in the transposed weight matrix W⊤ corresponds to the linear function of an
individual class, and matrix W is D × K. We can use the results derived in the last chapter to find
the solution for W that minimizes the least squares loss function. Assuming a data set of input
data points X and one-hot encoded target vectors Y (where every row is a single target vector, so
that Y is N × K), the optimal solution for W is given by:
W∗ = (X⊤ X)−1 X⊤ Y,
which we can then use in our discriminant function h(x, W∗ ) to make predictions on new data
points.
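A sketch of least squares classification with one-hot targets is given below, using made-up Gaussian class clusters; the prediction for a new point is the class whose linear function is largest.

```python
import numpy as np

rng = np.random.default_rng(7)

# Three Gaussian clusters in 2-D, 50 points each.
means = np.array([[0.0, 0.0], [4.0, 0.0], [2.0, 4.0]])
X = np.vstack([rng.normal(m, 0.7, size=(50, 2)) for m in means])
labels = np.repeat(np.arange(3), 50)

# Bias trick, then one-hot encode the targets into Y (N x K).
X = np.column_stack([np.ones(len(X)), X])
Y = np.eye(3)[labels]

# W* = (X^T X)^{-1} X^T Y, reusing the least squares solution from regression.
W = np.linalg.solve(X.T @ X, X.T @ Y)

# Classify a new point by taking the largest entry of h(x, W) = W^T x.
x_new = np.array([1.0, 3.9, 0.2])   # near the second cluster
print(np.argmax(W.T @ x_new))       # predicted class index; expected: 1
```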
While least squares gives us an analytic solution for our discriminant function, it has significant
limitations when used for classification. For one, least squares penalizes data points that are ‘too
good’, meaning they fall too far on the correct side of the decision boundary. Furthermore, it is not
robust to outliers, meaning the decision boundary significantly changes with the addition of just a
few outlier data points, as seen in Figure 3.7.
We can help remedy the problems with least squares by using an alternative loss function for
determining our weight parameters.
We can use the form of this function to our advantage in constructing the hinge loss by recognizing
that we wish to incur error when we’re wrong (which corresponds to z > 0, the right side of
the graph that is continuously increasing), and we wish to incur 0 error if we are correct (which
corresponds to the left side of the graph where z < 0).
Remember from the previous section on least squares that in the two-class case, we classify a
data point x∗ as being from class 1 if h(x∗ , w) ≥ 0, and class -1 otherwise. We can combine this
logic with ReLU by recognizing that −h(x∗ , w)y ∗ ≥ 0 when there is a classification error, where
y ∗ is the true class of data point x∗ . This has exactly the properties we described above: we incur
error when we misclassify, and otherwise we do not incur error.
where ŷi is our class prediction and yi is the true class value. Notice that misclassified examples
contribute positive loss, as desired. We can take the gradient of this loss function, which will allow
us to optimize it using stochastic gradient descent. The gradient of the loss with respect to our
parameters w is as follows:
∂L(w)/∂w = − ∑_{yi≠ŷi} xi yi
and then our update equation from time t to time t + 1 for a single misclassified example and with
learning rate η is given by:
w(t+1) = w(t) − η ∂L(w)/∂w = w(t) + η xi yi
To sum up, the benefits of the hinge loss function are its differentiability (which allows us to
optimize our weight parameters), the fact that it doesn’t penalize any correctly classified data
points (unlike basic linear classification), and that it penalizes more heavily data points that are
more poorly misclassified.
Using hinge loss with discriminant functions to solve classification tasks (and applying stochastic
gradient descent to optimize the model parameters) is known as the perceptron algorithm. The
perceptron algorithm guarantees that if there is separability between all of our data points and we
run the algorithm for long enough, we will find a setting of parameters that perfectly separates our
data set. The proof for this is beyond the scope of this textbook.
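Below is a minimal perceptron sketch on linearly separable toy data, using labels in {−1, +1} and the update w ← w + ηxiyi for misclassified points as described above; the data, learning rate, and epoch limit are our own choices.

```python
import numpy as np

rng = np.random.default_rng(8)

# Two linearly separable clusters with labels -1 and +1.
X_pos = rng.normal([2.0, 2.0], 0.5, size=(30, 2))
X_neg = rng.normal([-2.0, -2.0], 0.5, size=(30, 2))
X = np.column_stack([np.ones(60), np.vstack([X_pos, X_neg])])   # bias trick
y = np.concatenate([np.ones(30), -np.ones(30)])

eta = 0.1
w = np.zeros(3)
for epoch in range(100):
    mistakes = 0
    for xi, yi in zip(X, y):
        if yi * (w @ xi) <= 0:          # misclassified (or on the boundary)
            w = w + eta * xi * yi       # stochastic update for this example
            mistakes += 1
    if mistakes == 0:                    # the data set is now perfectly separated
        break
print(w, epoch)
```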
Given this problem statement, it makes sense that we might try to model p(y ∗ |x∗ ). In fact,
modeling this conditional distribution directly is what’s known as probabilistic discriminative
modeling.
This means that we will start with the functional form of the generalized linear model described
by Equation 3.2, convert this to a conditional distribution, and then optimize the parameters of
the conditional distribution directly using a maximum likelihood procedure. From here, we will be
able to make predictions on new data points x∗ . The key feature of this procedure, which is known
as discriminative training, is that it optimizes the parameters of a conditional distribution directly.
We describe a specific, common example of this type of procedure called logistic regression in
the next section.
Logistic Regression
One problem we need to face in our discriminative modeling paradigm is that the results of our
generalized linear model are not probabilities; they are simply real numbers. This is why in the
previous paragraph we mentioned needing to convert our generalized linear model to a conditional
distribution. That step boils down to somehow squashing the outputs of our generalized linear
model onto the real numbers between 0 and 1, which will then correspond to probabilities. To do
this, we will apply what is known as the logistic sigmoid function, σ(·).
Definition 3.6.2 (Logistic Sigmoid Function, σ(·)): The logistic sigmoid function is commonly
used to compress the real number line down to values between 0 and 1. It is defined functionally
as:
σ(z) = 1 / (1 + exp(−z))
As you can see in Figure 3.9 where the logistic sigmoid function is graphed, it squashes our output
domain between 0 and 1 as desired for a probability.
⋆ There is a more satisfying derivation for our use of the logistic sigmoid function in logistic regression, but under-
standing its squashing properties as motivation is sufficient for the purposes of this book.
Using the logistic sigmoid function, we now have a means of generating a probability that a new
data point x∗ is part of class y ∗ . Because we are currently operating in the two-class case, which
in this context will be denoted C1 and C2, we'll write the probability for each of these classes as:
p(y = C1|x) = σ(w⊤x)  and  p(y = C2|x) = 1 − σ(w⊤x)
Now that we have such functions, we can apply the maximum likelihood procedure to determine
the optimal parameters for our logistic regression model.
For a data set {xi , yi } where i = 1..N and yi ∈ {0, 1}, the likelihood for our setting of parameters
w can be written as:
p({yi}_{i=1}^N | w) = ∏_{i=1}^N ŷi^{yi} {1 − ŷi}^{1−yi}    (3.11)
ln p({yi}_{i=1}^N | w) = ∑_{i=1}^N {yi ln ŷi + (1 − yi) ln(1 − ŷi)}    (3.12)
As a monotonically increasing function, maximizing the logarithm of the likelihood (called the
log likelihood ) will result in the same optimal setting of parameters as if we had just optimized
the likelihood directly. Furthermore, using the log likelihood has the nice effect of turning what
is currently a product of terms from 1..N to a sum of terms from 1..N , which will make our
calculations nicer.
Second, we will turn our log likelihood into an error function by taking the negative of our log
likelihood expression. Now, instead of maximizing the log likelihood, we will be minimizing the
error function, which will again find us the same setting of parameters.
⋆ It’s worth rereading the above paragraph again to understand the pattern presented there, which we will see several
times throughout this book. Instead of maximizing a likelihood function directly, it is often easier to define an error
function using the negative log likelihood, which we can then minimize to find the optimal setting of parameters for
our model.
After taking the negative logarithm of the likelihood function defined by Equation 3.11, we
are left with the following term, known as the cross-entropy error function, which we will seek to
minimize:
E(w) = − ln p({yi}|w) = − ∑_{i=1}^N {yi ln ŷi + (1 − yi) ln(1 − ŷi)}    (3.13)
where as before ŷi = p(yi = C1|xi) = σ(w⊤xi). The cross-entropy error refers to the negative log likelihood of the labels conditioned on the examples. When used with the specific form of the logistic regression,
this is also the logistic loss. Now, to solve for the optimal setting of parameters using a maximum
likelihood approach as we’ve done previously, we start by taking the gradient of the cross-entropy
error function with respect to w:
∇E(w) = ∑_{i=1}^N (ŷi − yi) xi    (3.14)
which we arrive at by recognizing that the derivative of the logistic sigmoid function can be written
in terms of itself as:
∂σ(z)/∂z = σ(z)(1 − σ(z))
Let’s inspect the form of Equation 3.14 for a moment to understand its implications. First, it’s
a summation over all of our data points, as we would expect. Then, for each data point, we are
taking the difference between our predicted value ŷi and the actual value yi , and multiplying that
difference by the input vector xi .
While a closed form solution does not present itself here as it did in the case of linear regression
due to the nonlinearity of the logistic sigmoid function, we can still optimize the parameters w of
our model using an iterative procedure like gradient descent, where the objective function is defined
by Equation 3.13.
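A sketch of two-class logistic regression trained with batch gradient descent on synthetic data is shown below, using the gradient of Equation 3.14 (averaged over N purely for numerical stability); the data, learning rate, and iteration count are assumptions.

```python
import numpy as np

rng = np.random.default_rng(9)

# Synthetic two-class data with labels y in {0, 1}.
X_c1 = rng.normal([1.5, 1.5], 1.0, size=(100, 2))
X_c0 = rng.normal([-1.5, -1.5], 1.0, size=(100, 2))
X = np.column_stack([np.ones(200), np.vstack([X_c1, X_c0])])   # bias trick
y = np.concatenate([np.ones(100), np.zeros(100)])

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

eta = 0.5
w = np.zeros(3)
for t in range(5000):
    y_hat = sigmoid(X @ w)              # predicted probabilities p(y = C1 | x)
    grad = X.T @ (y_hat - y) / len(y)   # gradient of the cross-entropy error, averaged over N
    w = w - eta * grad
print(w)

# Classify a new point by thresholding the predicted probability at 0.5.
x_new = np.array([1.0, 2.0, 2.0])
print(sigmoid(w @ x_new) > 0.5)         # expected: True
```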
Definition 3.6.3 (Softmax): Softmax is the multi-class generalization of the sigmoidal activa-
tion function. It accepts a vector of activations (inputs) and returns a vector of probabilities
corresponding to those activations. It is defined as follows:
softmaxk(z) = exp(zk) / ∑_{i=1}^K exp(zi),  for all k
Multi-class logistic regression uses softmax over a vector of activations to select the most likely
target class for a new data point. It does this by applying softmax and then assigning the new data
point to the class with the highest probability.
Example 3.1 (Softmax Example): Consider an example that has three classes: C1 , C2 , C3 . Let’s
say we have an activation vector z for our new data point x that we wish to classify, given by:
z = W⊤x = (4, 1, 7)⊤
where zj = wj⊤x. Applying the softmax function to z gives softmax(z) ≈ (0.047, 0.002, 0.950)⊤.
And therefore, we would assign our new data point x to class C3, which has the largest activation and hence the highest probability.
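The computation in Example 3.1 can be checked with a few lines of NumPy; subtracting the maximum activation before exponentiating is a standard numerical-stability trick we add here, not something required by the definition.

```python
import numpy as np

def softmax(z):
    """Return a vector of probabilities corresponding to the activations z."""
    z = z - np.max(z)              # numerical stability; does not change the result
    expz = np.exp(z)
    return expz / expz.sum()

z = np.array([4.0, 1.0, 7.0])
probs = softmax(z)
print(probs)                       # approximately [0.047, 0.002, 0.950]
print(np.argmax(probs))            # 2, i.e. class C3
```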
As in the two-class logistic regression case, we now need to solve for the parameters W of our
model, also written as {wj }. Assume we have an observed data set {xi , yi } for i = 1..N where yi
are one-hot encoded target vectors. We begin this process by writing the likelihood for our data,
which is only slightly modified here to account for multiple classes:
p({yi}_{i=1}^N | W) = ∏_{i=1}^N ∏_{j=1}^K p(yi = Cj|xi)^{yij} = ∏_{i=1}^N ∏_{j=1}^K ŷij^{yij}    (3.15)
We can now take the negative logarithm to get the cross-entropy error function for the multi-
class classification problem:
E(W) = − ln p({yi}_{i=1}^N | W) = − ∑_{i=1}^N ∑_{j=1}^K yij ln ŷij    (3.16)
As in the two-class case, we now take the gradient with respect to one of our weight parameter
vectors wj :
∇wj E(W) = ∑_{i=1}^N (ŷij − yij) xi    (3.17)
which we arrived at by recognizing that the derivative of the softmax function with respect to the
input activations zj can be written in terms of itself:
∂softmaxk(z)/∂zj = softmaxk(z)(Ikj − softmaxj(z))
⋆ Notice that with probabilistic generative modeling, we choose a specific distribution for our class-conditional
densities instead of simply using a generalized linear model combined with a sigmoid/softmax function as we did
in the logistic regression setting. This highlights the difference between discriminative and generative modeling: in
the generative setting, we are modeling the production of the data itself instead of simply optimizing the parameters
of a more general model that predicts class membership directly.
For simplicity, we’ll assume a shared covariance matrix Σ between our two classes. Then, for data
points xi from class C1 , we have:
Now that we have specified the log-likelihood function for our model, we can go about optimizing
our model by maximizing this likelihood. One way to do this is with a straightforward maximum likelihood estimation approach. We will optimize our parameters π, µ1, µ2, and Σ separately, using
the usual procedure of taking the derivative, setting equal to 0, and then solving for the parameter
of interest. We write down this MLE solution in the following section.
MLE Solution
Solving for π
Beginning with π, we’ll concern ourselves only with the terms that depend on π which are:
Σ_{i=1}^N [ y_i ln π + (1 − y_i) ln (1 − π) ]

Taking the derivative with respect to π, setting it equal to 0, and solving gives:

π = N_1 / (N_1 + N_2) = N_1 / N

where N_1 is the number of data points in our data set from class C_1, N_2 is the number of data points from class C_2, and N is just the total number of data points. This means that the maximum likelihood solution for π is the fraction of points that are assigned to class C_1, a fairly intuitive solution and one that will be commonly seen when working with maximum likelihood calculations.
Solving for µ
Let’s now perform the maximization for µ1 . Start by considering the terms from our log likelihood
involving µ1 :
Σ_{i=1}^N y_i ln N(x_i | µ_1, Σ) = −½ Σ_{i=1}^N y_i (x_i − µ_1)^⊤ Σ^{-1} (x_i − µ_1) + c
where c are constants not involving the µ1 term. Taking the derivative with respect to µ1 , setting
equal to 0, and rearranging:
µ_1 = (1/N_1) Σ_{i=1}^N y_i x_i
which is simply the average of all the data points xi assigned to class C1 , a very intuitive result.
By the same derivation, the maximum likelihood solution for µ2 is:
µ_2 = (1/N_2) Σ_{i=1}^N (1 − y_i) x_i
Solving for Σ
We can also compute the maximum likelihood solution for the shared covariance matrix Σ. Start by considering the terms in our log-likelihood expression involving Σ:
−½ Σ_{i=1}^N y_i ln |Σ| − ½ Σ_{i=1}^N y_i (x_i − µ_1)^⊤ Σ^{-1} (x_i − µ_1) − ½ Σ_{i=1}^N (1 − y_i) ln |Σ| − ½ Σ_{i=1}^N (1 − y_i)(x_i − µ_2)^⊤ Σ^{-1} (x_i − µ_2)
We can use the following “matrix cookbook formulas” to help with taking the derivative with respect to Σ. Also, we adopt the convention Z^{-⊤} := (Z^⊤)^{-1}. The two helpful formulas are:

∂(a^⊤ Z^{-1} b) / ∂Z = −Z^{-⊤} a b^⊤ Z^{-⊤}

∂ ln |det(Z)| / ∂Z = Z^{-⊤}.
Now, taking the derivative with respect to Σ and collecting terms, we have:
−½ N Σ^{-⊤} + ½ Σ_{i=1}^N y_i Σ^{-⊤} (x_i − µ_1)(x_i − µ_1)^⊤ Σ^{-⊤} + ½ Σ_{i=1}^N (1 − y_i) Σ^{-⊤} (x_i − µ_2)(x_i − µ_2)^⊤ Σ^{-⊤}
Setting this to zero, multiplying both sides by Σ^⊤ from the left and right (with the effect of retaining Σ^⊤ in only the first term, since Σ^⊤ Σ^{-⊤} = Σ^⊤ (Σ^⊤)^{-1} = I and Σ^{-⊤} Σ^⊤ = (Σ^⊤)^{-1} Σ^⊤ = I), and multiplying by 2, we have:

N Σ^⊤ − Σ_{i=1}^N [ y_i (x_i − µ_1)(x_i − µ_1)^⊤ + (1 − y_i)(x_i − µ_2)(x_i − µ_2)^⊤ ] = 0.
Rearranging to solve for Σ, and recognizing that covariance matrices are symmetric, and so
Σ⊤ = Σ, we have:
Σ = (1/N) Σ_{i=1}^N [ y_i (x_i − µ_1)(x_i − µ_1)^⊤ + (1 − y_i)(x_i − µ_2)(x_i − µ_2)^⊤ ].
This has the intuitive interpretation that the maximum likelihood solution for the shared covariance matrix is the weighted average of the two per-class covariance matrices. Note that (x_i − µ_1)(x_i − µ_1)^⊤ is a matrix (the outer product of the two vectors), and y_i is a scalar, so each term in the sum is a matrix and the whole expression is a sum of matrices. For any point i, only one of the two matrices inside the sum will contribute, due to the use of y_i and (1 − y_i).
It is relatively straightforward to extend these maximum likelihood derivations from their two-
class form to their more general, multi-class form.
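The maximum likelihood solutions above translate directly into code. The following is a small sketch (NumPy; not from the original text) that computes π, µ_1, µ_2, and the shared Σ from a data set X with binary labels y (1 for class C_1, 0 for class C_2).

    import numpy as np

    def fit_shared_covariance_gaussians(X, y):
        """MLE for the two-class generative model with a shared covariance matrix.
        X: (N, D) data, y: (N,) array with 1 for class C1 and 0 for class C2."""
        N = X.shape[0]
        N1 = y.sum()
        N2 = N - N1
        pi = N1 / N                                     # fraction of points in C1
        mu1 = (y[:, None] * X).sum(axis=0) / N1         # mean of the C1 points
        mu2 = ((1 - y)[:, None] * X).sum(axis=0) / N2   # mean of the C2 points
        D1 = X - mu1
        D2 = X - mu2
        # Weighted combination of the per-class scatter matrices.
        S = (y[:, None, None] * np.einsum('nd,ne->nde', D1, D1)
             + (1 - y)[:, None, None] * np.einsum('nd,ne->nde', D2, D2)).sum(axis=0)
        return pi, mu1, mu2, S / N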
Naive Bayes
There exists a further simplification to probabilistic generative modeling in the context of classifi-
cation known as Naive Bayes.
Definition 3.6.4 (Naive Bayes): Naive Bayes is a type of generative model for classification
tasks. It imposes the simplifying rule that for a given class C_k, the features of the data points x generated within that class are assumed to be independent (hence the descriptor ‘naive’). This
means that the conditional distribution p(x|y = Ck ) can be written as:
p(x | y = C_k) = ∏_{i=1}^D p(x_i | y = C_k)
where D is the number of features in our data point x and C_k is the class. Note that Naive Bayes does not specify the form of the model p(x_i | y = C_k); this decision is left up to us.
This is obviously not a realistic simplification for all scenarios, but it can make our calculations
easier and may actually hold true in certain cases. We can build more intuition for how Naive
Bayes works through an example.
Example 3.2 (Naive Bayes Example): Suppose you are given a biased two-sided coin and two
biased dice. The coin has probabilities as follows:
Heads : 30%
Tails : 70%
The dice have the numbers 1 through 6 on them, but they are biased differently. Die 1 has
probabilities as follows:
1 : 40%
2 : 20%
3 : 10%
4 : 10%
5 : 10%
6 : 10%
Die 2 has probabilities as follows:
1 : 20%
2 : 20%
3 : 10%
4 : 30%
5 : 10%
6 : 10%
Your friend is tasked with doing the following. First, they flip the coin. If it lands Heads, they
select Die 1, otherwise they select Die 2. Then, they roll that die 10 times in a row, recording the
results of the die rolls. After they have completed this, you get to observe the aggregated results
from the die rolls. Using this information (and assuming you know the biases associated with the
coin and dice), you must then classify which die the rolls came from. Assume your friend went
through this procedure and produced the following counts:
1:3
2:1
3:2
4:2
5:1
6:1
Determine which die this roll count most likely came from.
Solution:
This problem is situated in the Naive Bayes framework: for a given class (dictated by the coin
flip), the outcomes within that class (each die roll) are independent. Making a classification in this
situation is as simple as computing the probability that the selected die produced the given roll
counts. Let’s start by computing the (unnormalized) probability for Die 1:

p(Heads) × p(counts | Die 1) = 0.3 × (0.4^3 × 0.2^1 × 0.1^2 × 0.1^2 × 0.1^1 × 0.1^1) ≈ 3.84 × 10^{-9}

Notice that we don’t concern ourselves with the normalization constant for the probability of the roll count - this will not differ between the choice of dice and we can thus ignore it for simplicity. Now the probability for Die 2:

p(Tails) × p(counts | Die 2) = 0.7 × (0.2^3 × 0.2^1 × 0.1^2 × 0.3^2 × 0.1^1 × 0.1^1) ≈ 1.01 × 10^{-8}

Therefore, we would classify this roll count as having come from Die 2.
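The arithmetic in this example is easy to reproduce. Here is a small sketch (Python; not part of the original text) that scores each die by multiplying its prior probability by the probability of the observed counts under the Naive Bayes independence assumption, ignoring the shared multinomial normalization constant.

    import numpy as np

    coin = {'die1': 0.3, 'die2': 0.7}                         # P(Heads), P(Tails)
    die_probs = {
        'die1': np.array([0.4, 0.2, 0.1, 0.1, 0.1, 0.1]),
        'die2': np.array([0.2, 0.2, 0.1, 0.3, 0.1, 0.1]),
    }
    counts = np.array([3, 1, 2, 2, 1, 1])                     # observed roll counts

    scores = {die: coin[die] * np.prod(die_probs[die] ** counts)
              for die in die_probs}
    print(scores)                          # die1 ~ 3.84e-09, die2 ~ 1.01e-08
    print(max(scores, key=scores.get))     # 'die2'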
Note that this problem asked us only to make a classification prediction after we already
knew the parameters governing the coin flip and dice rolls. However, given a data set, we could
have also used a maximum likelihood procedure under the Naive Bayes assumption to estimate
the values of the parameters governing the probability of the coin flip and die rolls.
3.7 Conclusion
In this chapter, we looked at different objectives and techniques for solving classification problems,
including discriminant functions, probabilistic discriminative models, and probabilistic generative
models. In particular, we emphasized the distinction between two-class and multi-class problems
as well as the philosophical differences between generative and discriminative modeling.
We also covered several topics that we will make use of in subsequent chapters, including sigmoid
functions and softmax, maximum likelihood solutions, and further use of basis changes.
By now, you have a sound understanding of generative modeling and how it can be applied to classification tasks. In the next chapter, we turn to neural networks, a flexible class of models that can be applied to a still broader class of problems.
Chapter 4
Neural Networks
Despite how seemingly popular neural networks have become recently, they aren’t actually a novel
technique. The first neural networks were described in the early 1940s, and the only reason they
weren’t put into practice shortly thereafter was that we didn’t yet have access to the large amounts of storage and compute that complex neural networks require. Over the last two decades,
and particularly with the advent of cloud computing, we now have more and more access to the
cheap processing power and memory required to make neural networks a viable option for model
building.
As we will come to see in this chapter, neural networks are an extraordinarily flexible class of
models used to solve a variety of different problem types. In fact, this flexibility is both what makes
them so widely applicable and yet so difficult to use properly. We will explore the applications,
underlying theory, and training schemes behind neural networks.
4.1 Motivation
For problems that fall into the category of regression or classification, we’ve already discussed the
utility of basis functions. Sometimes, a problem that is intractable with our raw input data will
be readily solvable with basis-transformed data. We often select these basis changes using expert
knowledge. For example, if we were working with a data set that related to chemical information,
and there were certain equations that a chemist told us to be important for the particular problem
we were trying to solve, we might include a variety of the transformations that are present in those
equations.
However, imagine now that we have a data set with no accompanying expert information. More
often than not, complex problem domains don’t come with a useful set of suggested transformations.
How do we find useful basis functions in these situations? This is exactly the strength of neural
networks - they identify the best basis for a data set!
Neural networks simultaneously solve for our model parameters and the best basis transforma-
tions. This makes them exceedingly flexible. Unfortunately, this flexibility is also the weakness of
neural nets: while it enables us to solve difficult problems, it also creates a host of other complica-
tions. Chief among these complications is the fact that neural networks require a lot of computation
to train. This is a result of the effective model space being so large - to explore it all takes time
and resources. Furthermore, this flexibility can cause rather severe overfitting if we are not careful.
In summary, neural networks identify good basis transformations for our data, and the strengths
and weaknesses of neural networks stem from the same root cause: model flexibility. It will be our
goal then to appropriately harness these properties to create useful models.
In the previous two chapters, we explored two broad problem types: classification and regression,
and it’s natural to wonder where neural networks fit in. The answer is that they are applicable to
both. The flexibility of neural networks even extends to the types of problems they can be made to
handle. Thus, the tasks that we’ve explored over the last two chapters, such as predicting heights
in the regression case or object category in the classification case, can be performed by neural
networks.
Given that neural networks are flexible enough to be used as models for either regression or
classification tasks, this means that every time you’re faced with a problem that falls into one of
these categories, you have a choice to make between the methods we’ve already covered or using
a neural network. Before we’ve explored the specifics of neural networks, how can we discern at a
high level when they will be a good choice for a specific problem?
One simple way to think about this is that if we never needed to use neural networks, we probably
wouldn’t. In other words, if a problem can be solved effectively by one of the techniques we’ve
already described for regression or classification (such as linear regression, discriminant functions,
etc.), we would prefer to use those. The reason is that neural networks are often more memory
and processor intensive than these other techniques, and they are much more complex to train and
debug.
The flip side of this is that hard problems are often too complex, or too difficult to engineer features for, to be solved with a simple regression or classification technique. Indeed, even if you eventually think you
will need to use a neural network to solve a given problem, it makes sense to try a simple technique
first both to get a baseline of performance and because it may just happen to be good enough.
What is so special about neural networks that they can solve problems that the other techniques
we’ve explored may not be able to? And why are they so expensive to train? These questions will
be explored over the course of the chapter, and a good place to start is with the status of neural
networks as universal function approximators.
The flexibility of neural networks is a well-established phenomenon. In fact, neural networks are
what are known as universal function approximators. This means that with a large enough network,
it is possible to approximate any function. The proof of this is beyond the scope of this textbook,
but it provides some context for why flexibility is one of the key attributes of neural networks.
ML Framework Cube: Neural Networks
As universal function approximators, neural networks can operate over discrete or continuous out-
puts. We primarily use neural networks to solve regression or classification problems, which involve
training on data sets with example inputs and outputs, making this a supervised technique. Fi-
nally, while there exist probabilistic extensions for neural networks, they primarily operate in the
non-probabilistic setting.
Domain Training Probabilistic
Continuous/Discrete Supervised No
Figure 4.2: Zooming in on the inputs and the first node of the first layer.
⋆ Every node in every layer has distinct weights associated with it.
This gives us the activation for the first node in the first hidden layer. Once we’ve done this
for every node in the first hidden layer, we make a non-linear transform of these activations, and
then move on to computing the activations for the second hidden layer (which require the outputs
from the first hidden layer, as indicated by the network of connections). We keep pushing values
through the network in this manner until we have our complete output layer, at which point we
are finished.
We’ve skipped over some important details in this high-level overview, but with this general
information about what a neural network looks like and the terminology associated with it, we can
now dig into the details a little deeper.
Recall the linear regression setup with a fixed basis, f(x, w) = w^⊤ ϕ(x), where ϕ = ϕ(x) is the basis transformation function and D is the dimensionality of the data point.
Typically with this linear regression setup, we are training our model to optimize the parameters
w. With neural networks this is no different— we still train to learn those parameters. However,
the difference in the neural network setting is that the basis transformation function ϕ is no longer
fixed. Instead, the transformations are incorporated into the model parameters, and thus learned
at the same time.
This leads to a different functional form for neural networks. A neural network with M nodes
in its first hidden layer performs M linear combinations of an input data point x:
a_j^{(1)} = Σ_{d=1}^D w_{jd}^{(1)} x_d + w_{j0}^{(1)}, ∀ j ∈ 1..M  (4.2)
Here, we use a^{(1)} to denote the activation of a unit in layer 1, and the notation w^{(1)} denotes the weights used to determine the activations in layer 1. We also make the bias explicit. We will still use the bias trick in general, but we’ve left it out here to explicitly illustrate the bias term w_{j0}^{(1)}. Other than this, equation 4.2 describes what we’ve already seen in Figure 4.2. The only difference is that we index each node in the hidden layer (along with its weights) by j.
The M different values a_j^{(1)} are the activations. We transform these activations with a non-linear activation function h(·) to give:

z_j^{(1)} = h(a_j^{(1)})  (4.3)
⋆ Note that we didn’t mention activation functions in the previous section only for the sake of simplicity. These
non-linearities are crucial to the performance of neural networks because they allow for modeling of outcomes that
vary non-linearly with their input variables.
These values z_j^{(1)} correspond to the outputs of the hidden units, each of which is associated with an activation function. The superscript (1) indicates they are the outputs of units in layer 1. A typical activation function is the sigmoid function, but other common choices are the tanh function and the rectified linear unit (ReLU).
These output values z_j^{(1)}, for units j ∈ {1, . . . , M} in layer 1, form the inputs to the next layer. The activation of unit j′ in layer 2 depends on the outputs from layer 1 and the weights w_{j′0}^{(2)}, w_{j′1}^{(2)}, . . . , w_{j′M}^{(2)} that define the linear sum at the input of unit j′:

a_{j′}^{(2)} = Σ_{j=1}^M w_{j′j}^{(2)} z_j^{(1)} + w_{j′0}^{(2)}  (4.4)
We can connect many layers together in this way. They need not all have the same number of
nodes but we will adopt M for the number of nodes in each layer for convenience of exposition.
Eventually, we will reach the output layer, and each output is denoted yk , for k ∈ {1, . . . , K}. The
final activation function may be the sigmoid function, softmax function, or just linear (and thus no
transform).
We can now examine a more complete diagram of a feed-forward neural network, shown in
Figure 4.3. It may be helpful to reread the previous paragraphs and use the diagram to visualize
how a neural network transforms its inputs. This is a single hidden layer, or two-layer, network.
Here, we use z to denote the output values of the units in the hidden layer.
⋆ Different resources choose to count the number of layers in a neural net in different ways. We’ve elected to count
each layer of non-input nodes, thus the two-layer network in Figure 4.3. However, some resources will choose to count
every layer of nodes (three in this case) and still others count only the number of hidden layers (making this a one
layer network).
Combining Figure 4.3 and our preceding functional description, we can describe the operation performed by a two-layer neural network using a single functional transformation (with m to index a unit in the hidden layer):

y_k(x, w) = σ( Σ_{m=1}^M w_{km}^{(2)} h( Σ_{d=1}^D w_{md}^{(1)} x_d + w_{m0}^{(1)} ) + w_{k0}^{(2)} )  (4.5)
where we’ve elected to make the final activation function the sigmoid function σ(·), as is suitable
for binary classification. We use h to denote the non-linear activation function for a hidden unit.
Written like this, a neural network is simply a non-linear function that transforms an input x into
an output y that is controlled by our set of parameters w.
Furthermore, we see now why this basic variety of neural networks is a feed-forward neural
network. We’re simply feeding our input x forward through the network from the first layer to
the last layer. Assuming we have a fully trained network, we can make predictions on new input
data points by propagating them through the network to generate output predictions (“the forward
pass”).
We can also simplify this equation by utilizing the bias trick and appending an x_0 = 1 value to each of our data points such that:

y_k(x, w) = σ( Σ_{m=0}^M w_{km}^{(2)} h( Σ_{d=0}^D w_{md}^{(1)} x_d ) )
Finally, it’s worth considering that while a neural network is a series of linear combinations, it
is special because of the differentiable non-linearities applied at each of the hidden layers. Without
these non-linearities, the successive application of different network weights would be equivalent to
a single large linear combination.
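To ground Equation 4.5, here is a minimal forward-pass sketch (NumPy; not from the original text) for a two-layer network with a non-linear hidden activation h (tanh here) and a sigmoid output, with biases handled explicitly; the weight matrices and shapes are hypothetical placeholders.

    import numpy as np

    def sigmoid(a):
        return 1.0 / (1.0 + np.exp(-a))

    def forward(x, W1, b1, W2, b2, h=np.tanh):
        """Two-layer feed-forward pass (Equation 4.5).
        x: (D,) input, W1: (M, D), b1: (M,), W2: (K, M), b2: (K,)."""
        a1 = W1 @ x + b1        # hidden activations
        z1 = h(a1)              # hidden outputs after the non-linearity h
        a2 = W2 @ z1 + b2       # output activations
        return sigmoid(a2)      # sigmoid output activation, as in Equation 4.5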
For a regression problem, the typical loss function is the least squares loss:

L(w) = ½ Σ_{n=1}^N ( y(x_n, w) − y_n )²,  (4.6)
where yn is the target value on example n. Sometimes we will have a regression problem with
multiple outputs, in which case the loss would also take the sum over these different target values.
For a binary classification problem, which we model through a single sigmoid output activation unit, the negated log-likelihood (or cross-entropy) is the typical loss function:

L(w) = − Σ_{n=1}^N [ y_n ln ŷ_n + (1 − y_n) ln (1 − ŷ_n) ]  (4.7)
For a multiclass classification problem, produced by a softmax function in the output activation
layer, we would use the negated log likelihood (cross entropy) loss:
L(w) = − Σ_{n=1}^N Σ_{k=1}^K y_{kn} ln [ exp(a_k(x_n, w)) / Σ_{j=1}^K exp(a_j(x_n, w)) ]  (4.8)
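For reference, here is a short sketch (NumPy; not part of the original text) of the three loss functions above, written for a full data set of predictions; the small constants added inside the logarithms are purely for numerical safety and are my own addition.

    import numpy as np

    def least_squares_loss(y_hat, y):
        """Equation 4.6: one half the sum of squared errors."""
        return 0.5 * np.sum((y_hat - y) ** 2)

    def binary_cross_entropy(y_hat, y, eps=1e-12):
        """Equation 4.7: negated log-likelihood for sigmoid outputs."""
        return -np.sum(y * np.log(y_hat + eps) + (1 - y) * np.log(1 - y_hat + eps))

    def multiclass_cross_entropy(A, Y, eps=1e-12):
        """Equation 4.8: A is (N, K) output activations, Y is (N, K) one-hot targets."""
        A = A - A.max(axis=1, keepdims=True)                  # numerical stability
        P = np.exp(A) / np.exp(A).sum(axis=1, keepdims=True)  # softmax
        return -np.sum(Y * np.log(P + eps))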
⋆ Loss function and objective function refer to the same concept: the function we optimize to train our model.
4.4.3 Backpropagation
Considering how our feed-forward neural network works, by propagating activations through our
network to produce a final output, it’s not immediately clear how we can compute gradients for
the weights that lie in the middle of our network. There is an elegant solution to this, which comes
from “sending errors backwards” through our network, in a process known as backpropagation.
Backpropagation refers specifically to the portion of neural network training during which we
compute the derivative of the objective function with respect to the weight parameters. This is
done by propagating errors backwards through the network, hence the name.
⋆ Note that we still need to update the value of the weight parameters after computing their derivatives. This is
typically done using gradient descent or some variant of it.
Consider the activation of node j in layer ℓ,

a_j^{(ℓ)} = Σ_{m=1}^M w_{jm}^{(ℓ)} z_m^{(ℓ−1)},  (4.9)

where there are M incoming nodes, each with corresponding output values z_1^{(ℓ−1)}, ..., z_M^{(ℓ−1)}, and with the weights in layer ℓ corresponding to node j denoted by w_{j1}^{(ℓ)}, ..., w_{jM}^{(ℓ)}. This activation is transformed by an activation function h(·) to give unit output z_j^{(ℓ)}:

z_j^{(ℓ)} = h(a_j^{(ℓ)}).  (4.10)
Computing these values as we flow through the network constitutes the forward pass through
our network.
We now wish to begin the process of computing derivatives of the objective function with respect
to our weights. For the sake of simplicity, we’ll assume that the current setting of our parameters
w generates a loss of L for a single data point, as though we were performing stochastic gradient
descent.
Let’s consider how we could compute the derivative of L with respect to an individual weight in our network, w_{jm}^{(ℓ)} (the mth weight for activation j in layer ℓ):

∂L / ∂w_{jm}^{(ℓ)}.  (4.11)
We first need to figure out what the dependence of L is on this weight. This weight contributes to the final result only via its contribution to the activation a_j^{(ℓ)}. This allows us to use the chain rule:
Figure 4.4: Gradient of the loss function in a neural network with respect to a weight. It depends
on the input value zm and the “error” corresponding to the activation value at the output end of
the weight.
∂L / ∂w_{jm}^{(ℓ)} = ( ∂L / ∂a_j^{(ℓ)} ) · ( ∂a_j^{(ℓ)} / ∂w_{jm}^{(ℓ)} ).  (4.12)
The first part of this is the, typically non-linear, dependence of loss on activation. The second
part is the linear dependence of activation on weight. Using Equation 4.9, we have that:
∂a_j^{(ℓ)} / ∂w_{jm}^{(ℓ)} = z_m^{(ℓ−1)},
which is just the value of the input from the previous layer. We now introduce the following notation for the first term,

δ_j^{(ℓ)} = ∂L / ∂a_j^{(ℓ)},  (4.13)
where the δ_j^{(ℓ)} values are referred to as errors. We rewrite Equation 4.12 as:

∂L / ∂w_{jm}^{(ℓ)} = δ_j^{(ℓ)} z_m^{(ℓ−1)}.  (4.14)
The implications of Equation 4.14 are significant for understanding backpropagation. The derivative of the loss with respect to an arbitrary weight in the network can be calculated as the product of the error δ_j^{(ℓ)} at the “output end of that weight” and the value z_m^{(ℓ−1)} at the “input end of the weight.” We visualize this property in Figure 4.4 (dropping the layer subscripting).
To compute the derivatives, it suffices to compute the values of δj for each node, also saving
the output values zm during the forward pass through the network (to be multiplied by the values
of δj to get partials).
⋆ We will only have “error values” δ_j for the hidden and output units of our network. This is logical because there is no notion of applying an error to our input data, which we have no control over.
We now consider how to compute these error values. For a unit in the output layer, indexing
it here by k, and assuming the output activation function is linear and adopting least squares loss
(i.e., regression), we have for the dependence of loss on the activation of this unit,
δ_k^{(ℓ)} = ∂L / ∂a_k^{(ℓ)} = ∂L / ∂ŷ_k = d[ ½ (ŷ_k − y_k)² ] / dŷ_k = ŷ_k − y_k.
Here, we use the shorthand a_k^{(ℓ)} = ŷ_k, providing the kth dimension of the prediction of the model. Although this is a regression problem, we’re imagining here that there are multiple regression targets (say, the height, weight, and blood pressure of an individual). Here, y_k is the true target value for this data point. Note that this is for OLS. The expression would be different for a classification problem with negated log likelihood as the loss.
To compute the error δ_j^{(ℓ)} for a hidden unit j in a layer ℓ, we again make use of the chain rule, and write:

δ_j^{(ℓ)} = ∂L / ∂a_j^{(ℓ)} = Σ_{m=1}^M ( ∂L / ∂a_m^{(ℓ+1)} ) ( ∂a_m^{(ℓ+1)} / ∂a_j^{(ℓ)} ),  (4.15)
where the summation runs over all of the M nodes to which the node j in layer ℓ sends connections,
as seen in Figure 4.5. This expression recognizes that the activation value of this unit contributes
only via its contribution to the activation value of each unit to which it is connected in the next layer.
The first term in one of the products in the summation is the, typically non-linear, dependence
between loss and activation value of a unit in the next layer. The second term in one of the products
captures the relationship between this activation and the subsequent activation.
Now, we can simplify by noticing that:
∂L / ∂a_m^{(ℓ+1)} = δ_m^{(ℓ+1)}  {by definition}  (4.16)

∂a_m^{(ℓ+1)} / ∂a_j^{(ℓ)} = ( dh(a_j^{(ℓ)}) / da_j^{(ℓ)} ) · w_{mj}^{(ℓ+1)} = h′(a_j^{(ℓ)}) · w_{mj}^{(ℓ+1)}.  {chain rule}  (4.17)
Substituting, and pulling forward the derivative of the activation function, we can rewrite the
expression for the error on a hidden unit j in layer ℓ as:
δ_j^{(ℓ)} = h′(a_j^{(ℓ)}) Σ_{m=1}^M w_{mj}^{(ℓ+1)} δ_m^{(ℓ+1)}.  (4.18)
This is very useful, and is the key insight in backpropagation. It means that the value of the errors can be computed by “passing back” (backpropagating) the errors for nodes farther up in the network! Since we know the values of δ for the final layer of output nodes, we can recursively apply Equation 4.18 to compute the values of δ for all the nodes in the network.
Remember that all of these calculations were done for a single input data point that generated the loss L. If we were using SGD with mini-batches, then we would perform the same calculation for each data point in the mini-batch B, and average the gradients as follows:
∂L / ∂w_{jm}^{(ℓ)} = (1 / |B|) Σ_{n∈B} ∂L_n / ∂w_{jm}^{(ℓ)},  (4.19)
Figure 4.5: Summation over the nodes (blue) in layer ℓ + 1 to which node j in layer ℓ (gold) sends
connections (green). Note: read this as m and m′ .
Example 4.1 (Backpropagation Example): Imagine the case of a simple two layer neural net-
work as in Figure 4.3, with K outputs (we denote them ŷ1 through ŷK for a given data point). We
imagine this is a regression problem, but one with multiple dimensions to the output, and assume
OLS. For a given data point, the loss is computed as:
L = Σ_{k=1}^K ½ (ŷ_k − y_k)²,
where we write yk for the kth dimension of the target value. For a unit in the hidden layer, with
activation value a, we make use of the sigmoid activation, with
z = σ(a) = 1 / (1 + exp(−a)),

whose derivative is given by:

dσ(a)/da = σ(a)(1 − σ(a)).
For an input data point x, we forward propagate through the network to get the activations of the
hidden layer, and for each m in this layer we have:
a_m^{(1)} = Σ_{d=0}^D w_{md}^{(1)} x_d,
given weights w_{m0}^{(1)}, . . . , w_{mD}^{(1)}, and with output value from unit m as,

z_m^{(1)} = σ(a_m^{(1)}).
We propagate these output values forward to get the outputs, and for each output unit k, we have:
ŷ_k = Σ_{m=0}^M w_{km}^{(2)} z_m^{(1)},

where w_{k0}^{(2)}, . . . , w_{kM}^{(2)} are the weights for unit k.
Now that we’ve propagated forward, we propagate our errors backwards! We start by computing
the errors for the output layer as follows:
δ_k^{(2)} = ∂L / ∂ŷ_k = ŷ_k − y_k.
We then backpropagate these errors back to each hidden unit m in layer 1 as follows:
δ_m^{(1)} = ∂L / ∂a_m^{(1)} = h′(a_m^{(1)}) Σ_{k=1}^K w_{km}^{(2)} δ_k^{(2)}

       = σ(a_m^{(1)})(1 − σ(a_m^{(1)})) Σ_{k=1}^K w_{km}^{(2)} (ŷ_k − y_k)

       = z_m^{(1)}(1 − z_m^{(1)}) Σ_{k=1}^K w_{km}^{(2)} (ŷ_k − y_k).
And now that we have our errors for the hidden and output layers, we can compute the derivative
of the loss with respect to our weights as follows, for the dth weight on the mth unit in layer 1,
and the mth weight on the kth unit in layer 2:
∂L / ∂w_{md}^{(1)} = δ_m^{(1)} x_d,    ∂L / ∂w_{km}^{(2)} = δ_k^{(2)} z_m^{(1)}.
We then use these derivatives along with an optimization technique such as stochastic gradient
descent to improve the model weights.
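The entire worked example can be condensed into a few lines of code. The following sketch (NumPy; not part of the original text) runs the forward pass, backpropagates the errors, and returns the gradients for both weight matrices; for simplicity the biases are left out (or assumed absorbed via the bias trick).

    import numpy as np

    def sigmoid(a):
        return 1.0 / (1.0 + np.exp(-a))

    def backprop_single_point(x, y, W1, W2):
        """Gradients for the two-layer regression network of Example 4.1.
        x: (D,) input, y: (K,) target, W1: (M, D), W2: (K, M)."""
        # Forward pass
        a1 = W1 @ x                  # hidden activations
        z1 = sigmoid(a1)             # hidden outputs
        y_hat = W2 @ z1              # linear outputs

        # Backward pass
        delta2 = y_hat - y                          # output errors
        delta1 = z1 * (1 - z1) * (W2.T @ delta2)    # hidden errors (Equation 4.18)

        grad_W2 = np.outer(delta2, z1)   # dL/dW2[k, m] = delta2_k * z1_m
        grad_W1 = np.outer(delta1, x)    # dL/dW1[m, d] = delta1_m * x_d
        return grad_W1, grad_W2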
Figure 4.6: Networks with different structures and numbers of internal nodes.
In earlier chapters, we used cross-validation to select model complexity while avoiding overfitting. We can use a similar process to identify a reasonable network structure.
First of all, the input and output parameters of a neural network are generally decided for us:
the dimensionality of our input data dictates the number of input units and the dimensionality of
the required output dictates the number of output units. For example, if we have an 8-by-8 pixel
image and need to predict whether it is a ‘0’ or a ‘1’, our input dimensions are fixed at 64 and our
output dimensions are fixed at 2. Depending on whether you wish to perform some sort of pre or
post-processing on the inputs/outputs of your network, this might not actually be the case, but
in general when choosing a network structure we don’t consider the first or last layer of nodes as
being a relevant knob that we can tune.
That leaves us to choose the structure of the hidden layers in our network. Unsurprisingly, the
more hidden layers we have and the more nodes we have in each of those layers, the more variation
we will produce in our results and the closer we will come to overfitting.
Thus, we can use cross validation in the same way we’ve done before: train our model with
differing numbers of internal units and structures (as in Figure 4.6) and then select the model that
performs best on the validation set.
⋆ There are other considerations at play beyond performance when choosing a network structure. For example, the
more internal units you have in your network, the more storage and compute time you will need to train them. If
either training time or response time after training a model is critical, you may need to consider consolidating your
network at the expense of some performance.
Regularization
You can also apply regularization to the weights in your network to help prevent overfitting. For example, we could introduce a simple quadratic regularizer of the form (λ/2) w^⊤ w to our objective function. There are other considerations to be made here; for example, we would like our regularizer to be invariant to scaling, meaning that multiplying our input data by a constant would produce a proportionally equivalent network after training. The quadratic regularizer is not invariant to scaling, but the basic concept of avoiding extreme weights is the same nonetheless.
Data Augmentation
We can use transformations to augment our data sets, which helps prevent overfitting. This technique is not specific to neural networks, but the types of unstructured data for which we often use neural networks (such as images) are particularly well suited to it.
Definition 4.5.1 (Data Augmentation): Data augmentation refers to the practice of increasing
the size and diversity of your training data by applying transformations to the initial data set.
For example, if we are working with image data, we might choose to rotate or reflect the image,
depending on the type of network we are trying to build and whether or not this would preserve
the integrity of the image. We might also change something like the brightness or density of the
image data. In this way, we can produce more and more varied training points, thus reducing the
likelihood of overfitting.
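For image data, a minimal augmentation sketch (NumPy; not part of the original text) might look like the following; whether each transform preserves the label is a judgment you make per task.

    import numpy as np

    def augment_image(img, rng):
        """Return a randomly transformed copy of a 2-D image array."""
        out = img.copy()
        if rng.random() < 0.5:
            out = np.fliplr(out)                         # horizontal reflection
        out = np.rot90(out, k=int(rng.integers(0, 4)))   # rotation by a multiple of 90 degrees
        out = out * rng.uniform(0.8, 1.2)                # small brightness change
        return out

    # Hypothetical usage:
    # rng = np.random.default_rng(0)
    # augmented = [augment_image(x, rng) for x in images]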
A recurrent neural network (RNN) can make better predictions on sequential data because it’s remembering what came before. We don’t have this ability with a feed-forward network,
which by design only propagates information forward through the network. RNNs add backward
passing of activations into their network structure to improve predictions on data where there is
some temporal dependence on what came previously.
Chapter 5
Support Vector Machines
In this chapter, we will explore what are known as support vector machines, or SVMs for short.
SVMs are broadly useful for problems in classification and regression, and they are part of a
family of techniques known as margin methods. The defining goal of margin methods, and SVMs
specifically, is to put as much distance as possible between data points and decision boundaries.
We will dig deeper into what exactly this means over the course of the chapter. One of the most
appealing aspects of SVMs is that they can be solved as convex optimization problems, for which
we can find a global optimum with relative ease. We will explore the mathematical underpinnings
of SVMs, which can be slightly more challenging than our previous topics, as well as their typical
use cases.
5.1 Motivation
While SVMs can be used for classification or regression, we will reason about them in the classifi-
cation case as it is more straightforward.
The grand idea behind SVMs is that we should construct a linear hyperplane in our feature
space that maximally separates our classes, which means that the different classes should be as
far from that hyperplane as possible. The distance of our data from the hyperplane is known as
margin.
Definition 5.1.1 (Margin): Margin is the distance of the nearest data point from the separating
hyperplane of an SVM model, as seen in Figure 5.1. Larger margins often lead to more generalizable
models.
A larger margin tends to mean that our model will generalize better, since it provides more
wiggle room to correctly classify unseen data (think about new data being a perturbation on current
data).
This idea of the margin of a separator is quite intuitive. If you were presented with Figure 5.1
and were asked to separate the two classes, you would likely draw the line that keeps data points
as far from it as possible. SVMs and other margin-based methods will attempt to algorithmically
recreate this intuition.
Margin methods like SVMs aim to maximize the distance between the training data and the decision boundary, with the idea that this leads to model generalizability.
Other max margin methods are outside the scope of this textbook. These alternative methods
may differ from SVMs in a non-trivial way. For example, SVMs do not produce probabilities on
different classes, but rather decision rules for handling new data points. If you needed probabilities,
there are other max margin methods that can be used for the task.
ML Framework Cube: Support Vector Machines
SVMs are typically used in settings with discrete outputs. We need labeled training data to identify
the relevant hyperplane in an SVM model. Finally, SVMs operate in a non-probabilistic setting.
Domain Training Probabilistic
Discrete Supervised No
5.1.2 Applications
The theory behind SVMs has been around for quite some time (since 1963), and prior to the rise of
neural networks and other more computationally intensive techniques, SVMs were used extensively
for image recognition, object categorization, and other typical machine learning tasks.
SVMs are still widely used in practice, for example for the classification problem known as anomaly detection.
⋆ The purpose of anomaly detection is to identify unusual data points. For example, if we are manufacturing shoes,
we may wish to inspect and flag any shoe that seems atypical with respect to the rest of the shoes we produce.
Anomaly detection can be as simple as a binary classification problem where the data set is
comprised of anomalous and non-anomalous data points. As we will see, an SVM can be constructed
from this data set to identify future anomalous points very efficiently. SVMs extend beautifully to
settings where we want to use basis functions, and thus non-linear interactions on features. For
this reason, they continue to be competitive in many real-world situations where these kinds of
interactions are important to work with.
⋆ The expression ‘hard margin’ simply means that we don’t allow any data to be classified incorrectly. If it’s not
possible to find a hyperplane that perfectly separates the data based on class, then the hard margin classifier will
return no solution.
We specify our model as the linear discriminant

h(x) = w^⊤ x + w_0  (5.1)

This is the discriminant function, and we classify a new example to class 1 or −1 according to the sign produced by our trained model h(x). Later, we will also make this more general by using a basis function ϕ(x) to transform to a higher dimensional feature space.
By specifying our model this way, we have implicity defined a hyperplane separating our two
classes given by:
w⊤ x + w0 = 0 (5.2)
Derivation 5.2.1 (Hyperplane Orthogonal to w): Imagine two data points x_1 and x_2 on the hyperplane defined by w^⊤ x + w_0 = 0. When we project their difference onto our model w, we find:

w^⊤ (x_1 − x_2) = (w^⊤ x_1 + w_0) − (w^⊤ x_2 + w_0) = 0 − 0 = 0,

which means that w is orthogonal to our hyperplane. We can visualize this in Figure 5.2.
Remember that we’re trying to maximize the margin between our training data and the hyper-
plane. The fact that w is orthogonal to our hyperplane will help with this.
To determine the distance between a data point x and the hyperplane, which we denote d, we
need the distance in the direction of w between the point and the hyperplane. We denote xp to be
the projection of x onto the hyperplane, which allows us to decompose x as the following:
x = x_p + d ( w / ||w||_2 )  (5.4)
which is the sum of the portion of the projection of x onto the hyperplane and the portion of x
that is parallel to w (and orthogonal to the hyperplane). From here we can solve for d:
w^⊤ x = w^⊤ x_p + d ( w^⊤ w / ||w||_2 )

w^⊤ x = −w_0 + d ||w||_2

Rearranging:

d = ( w^⊤ x + w_0 ) / ||w||_2 .
For each data point x, we now have the signed distance of that data point from the hyperplane.
For an example that is classified correctly, this signed distance d will be positive for class
yn = 1, and negative for class yn = −1. Given this, we can make the distance unsigned (and
always positive) for a correctly classified data point by multiplying by y_n. Then, the margin for a correctly classified data point (x_n, y_n) is given by:

y_n (w^⊤ x_n + w_0) / ||w||_2 .  (5.5)
The margin for an entire data set is given by the margin to the closest point in the data set:

min_n [ y_n (w^⊤ x_n + w_0) / ||w||_2 ].  (5.6)
Then, it is our goal to maximize this margin with respect to our model parameters w and w0 .
This is given by:
max_{w,w_0} { (1/||w||_2) min_n [ y_n (w^⊤ x_n + w_0) ] }  (5.7)
Here, we pull the 1/||w||2 term forward. Note carefully that w0 does not play a role in the
denominator ||w||2 .
This is a hard problem to optimize, but we can make it more tractable by recognizing some
important features of Equation 5.7. First, rescaling w → αw and w0 → αw0 , for any α > 0, has
no impact on the margin for any correctly classified data point xn . This is because the effect of α
cancels out in the numerator and denominator of Equation 5.5.
We can use this rescaling liberty to enforce

y_n (w^⊤ x_n + w_0) ≥ 1, for all n.  (5.8)
This does not change the optimal margin because we can always scale up both w and w0 by
α > 0 to achieve yn (w⊤ xn + w0 ) ≥ 1, and without affecting the margin. Moreover, since the
problem is to maximize 1/||w||2 , an optimal solution will want ||w||2 to be as small as possible, and
thus at least one of these constraints (5.8) will be binding and equal to one in an optimal solution.
Thus our optimization problem now looks like:
max_{w,w_0} 1/||w||_2   s.t.   y_n (w^⊤ x_n + w_0) ≥ 1, for all n.  (5.9)

Here, we recognized that min_n y_n (w^⊤ x_n + w_0) = 1 in an optimal solution with these new constraints, and adopted this in the objective. This simplifies considerably, removing the “min_n” part of the objective.
Notice that maximizing 1/||w||_2 is equivalent to minimizing ||w||_2^2. We will also add a constant term ½ for convenience, leaving the hard-margin formulation of the training problem:

min_{w,w_0} ½ ||w||_2^2   s.t.   y_n (w^⊤ x_n + w_0) ≥ 1, for all n.  (5.10)
Note that Equation 5.10 is now a quadratic programming problem, which means we wish to
optimize a quadratic function subject to a set of linear constraints on our parameters. Arriving at
this form was the motivation for the preceding mathematical manipulations. We will discuss shortly
how we actually optimize this function.
Definition 5.2.1 (Support Vector): A support vector in a hard-margin SVM formulation must
be a data point that is on the margin boundary of the optimal solution, with yn (w⊤ xn + w0 ) = 1
and margin 1/||w||2 .
Figure 5.3: Example of the resulting hyperplane for a hard margin SVM. The filled in data points
are support vectors in this example. A support vector for the hard-margin formulation must be on
the margin boundary, with a discriminant value of +1 or -1.
In the hard margin case we have constrained the closest data points to have discriminant value
w⊤ xn + w0 = 1 (−1 for a negative example). Figure 5.3 shows a hard margin SVM solution with
an illustration of corresponding support vectors.
⋆ After we have optimized an SVM in the hard margin case, we must have at least two support vectors with
discriminant value that is 1 or -1, and thus a margin of 1/||w||2 .
Figure 5.4: An outlier can make the hard margin formulation impossible or unable to generalize
well.
Sometimes the best solution will allow some data points to be close to the hyperplane (or in some cases, even on the wrong side of the hyperplane). That is what the soft margin formulation will allow for, by introducing a slack variable ξ_n ≥ 0 for each data point.
These slack variables penalize data points on the wrong side of the margin boundary, but they don’t forbid data points from being on the wrong side if this produces the best model. We now reformulate the optimization problem as follows. This is the soft-margin training problem:
min_{w,w_0} ½ ||w||_2^2 + C Σ_{n=1}^N ξ_n  (5.12)
s.t. y_n (w^⊤ x_n + w_0) ≥ 1 − ξ_n, for all n
     ξ_n ≥ 0, for all n.
Here, C is a regularization parameter that determines how heavily we penalize violations of the
hard margin constraints. A large C penalizes violation of the hard margin constraints more heavily,
which means our model will follow the data closely and have small regularization. A small C won’t
heavily penalize having data points inside the margin region, relaxing the constraint and allowing
our model to somewhat disregard more of the data. This means more regularization.
Figure 5.5: Example of the resulting hyperplane for a soft margin SVM. The filled in data points
illustrate the support vectors in this example and must be either on the margin boundary or on
the “wrong side” of the margin boundary.
⋆ Unlike most regularization parameters we’ve seen thus far, C increases regularization as it gets smaller.
⋆ Not every data point on the margin boundary, in the margin region, or that is misclassified needs to be a support
vector in the soft-margin formulation. But those that become support vectors must meet one of these criteria.
⋆ A dual form is an equivalent manner of representing an optimization problem, in this case the quadratic programming
problem we need to optimize. Dual forms can be easier to work with than their initial form (“the primal form.”)
The dual form will be useful because it will allow us to bring in a basis function into the SVM
formulation in a very elegant and computationally efficient way.
But for the new “inequality form” of the constrained optimization problem, we also need to
introduce a new subproblem, which for any fixed w, solves
Now, if w violates one or more constraints in (5.13), then the subproblem (5.15) becomes unbounded, with the α_n on the corresponding constraints driven arbitrarily large. Otherwise, if we have g_n(w) < 0 then we will have α_n = 0, and we conclude α_n g_n(w) = 0 in all optimal solutions to (5.16). Therefore, and assuming that problem (5.13) is feasible, we have L(w, α) = f(w) in an optimal solution to (5.16). Thus, we establish that (5.16) is an equivalent formulation to (5.13).
Substituting into our problem (5.10), the Lagrangian formulation becomes
" #
1 ⊤ X
min max w w + αn (−yn (w⊤ xn + w0 ) + 1) (5.17)
w,w0 α, α≥0 2
n
⋆ The ‘min_{w,w_0} max_{α≥0}’ in Equation 5.20 may be initially confusing. The way to read this is that for any choice of w, w_0, the inner “max” problem then finds values of α to try to “defeat” the outer minimization objective.
We now wish to convert the objective in Equation 5.20 to a dual objective. Under the sufficient
conditions of strong duality which hold for this problem because Equation 5.10 has a quadratic
objective and linear constraints (but whose explanation is beyond the scope of this textbook), we
can equivalently reformulate the optimization problem (5.20) as:
max_{α, α≥0} min_{w,w_0} L(w, α, w_0).  (5.21)
At this point, we can use first order optimality conditions to solve for w, i.e., the inner mini-
mization problem, for some choice of α values. Taking the gradient, setting them equal to 0, and
solving for w, we have:
∇_w L(w, α, w_0) = w − Σ_{n=1}^N α_n y_n x_n = 0

⇔ w* = Σ_{n=1}^N α_n y_n x_n.  (5.22)
Taking the derivative with respect to w_0 and setting it equal to 0 yields

Σ_{n=1}^N α_n y_n = 0,  (5.23)

which must hold in any optimal solution to (5.21). So, we don’t yet obtain the optimal value for w_0, but we do gain a new constraint on the α-values that will need to hold in an optimal solution.
Now we substitute for w∗ into our Lagrangian function, and also assume (5.23), since this will
be adopted as a new constraint in solving the optimization problem. Given this, we obtain:
L(w, α, w_0) = ½ w^⊤ w − w^⊤ Σ_n α_n y_n x_n − w_0 Σ_n α_n y_n + Σ_n α_n
            = −½ w^⊤ w + Σ_n α_n
            = Σ_n α_n − ½ ( Σ_n α_n y_n x_n )^⊤ ( Σ_{n′} α_{n′} y_{n′} x_{n′} )  (5.24)
where the second equation follows from the first by using (5.22) and (5.23), and the third equation
follows by using (5.22).
This is now entirely formulated in terms of α, and provides the hard margin, dual formu-
lation:
max_α  Σ_n α_n − ½ Σ_n Σ_{n′} α_n α_{n′} y_n y_{n′} x_n^⊤ x_{n′}  (5.25)
s.t. Σ_n α_n y_n = 0
     α_n ≥ 0, for all n

Here, we add Σ_n α_n y_n = 0 as a constraint. This is another quadratic objective, subject to
linear constraints. This can be solved via SGD or another approach to solving convex optimization
problems. With a little more work we can use the optimal α values to make predictions.
Although the derivation is out of scope for this textbook, there is also a very similar dual form
for the soft-margin SVM training problem:
max_α  Σ_n α_n − ½ Σ_n Σ_{n′} α_n α_{n′} y_n y_{n′} x_n^⊤ x_{n′}  (5.26)
s.t. Σ_n α_n y_n = 0
     C ≥ α_n ≥ 0, for all n
This puts an upper-bound on αn to prevent the dual from being unbounded in the case where
the hard-margin SVM problem is infeasible because the data cannot be separated. It is not yet
clear why any of this has been useful. We will see the value of the dual formulation when working
with basis functions.
⋆ By (5.22) we see how to find the weight vector w from a solution α. We didn’t yet explain how to find w0 . That
will be explained next.
Substituting (5.22) into the discriminant function gives

h(x) = Σ_n α_n y_n x_n^⊤ x + w_0.

For data points with α_n > 0, this is taking a weighted vote over examples in the training data based on the size of the inner product x_n^⊤ x.
Since α are Lagrange multipliers, they are non-negative, and moreover, by reasoning about
the “max subproblem” in the min-max formulation (5.16), we know that they take on value zero
whenever y_n h(x_n) > 1. The data points for which α_n > 0 are known as support vectors, and they must be data points that are either on the margin boundary, inside the margin region, or misclassified. For the hard-margin formulation, they must be data points on the margin boundary.
This is a major takeaway for the usefulness of SVMs: once we’ve trained our model, we can
discard most of our data. We only need to keep the support vectors to make predictions. Soon we
also see the “kernel trick.” This also illustrates why we need to solve for the values of α: those
values dictate which data points are the support vectors for our model.
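Once the α values are known, prediction uses only the support vectors. Here is a minimal sketch (NumPy; not from the original text), where alphas, X_sv, y_sv, and w0 are hypothetical variables assumed to come from a solved dual problem.

    import numpy as np

    def svm_predict(x, alphas, X_sv, y_sv, w0):
        """Discriminant h(x) = sum_n alpha_n y_n (x_n^T x) + w0, summed over support vectors.
        alphas: (S,), X_sv: (S, D), y_sv: (S,) labels in {-1, +1}, w0: scalar."""
        return np.sum(alphas * y_sv * (X_sv @ x)) + w0

    # Hypothetical usage: label = np.sign(svm_predict(x_new, alphas, X_sv, y_sv, w0))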
Solving for w0
We can solve for w0 by recognizing that yn (w⊤ xn + w0 ) = 1 for any data point on the margin
boundary. For the hard-margin formulation we can solve for w0 using any example for which
αn > 0. For the soft-margin formulation, it can be shown that the only points with αn = C are
those inside the margin region or misclassified, and so that any point with C > αn > 0 is on the
margin boundary. Any such point can be used to solve for w_0.
The idea of the kernel trick is that we might be able to compute K(·, ·) without actually working in the basis function space, R^M, but rather be able to compute the kernel function directly through algebra in the lower dimensional space, R^D.
For example, it can be shown that the polynomial kernel

K(x, x′) = (x^⊤ x′ + 1)^q
corresponds to computing the inner product with a basis function that makes use of all terms up to
degree q. When q = 2, then it is all constant, linear, and quadratic terms. The polynomial kernel
function does this without needing to actually project the examples to the higher dimensional space.
Rather it takes the inner product in the lower-dimensional space, adds 1 to this scalar, and then
raises it to the power of q. The implicit basis is growing exponentially large in q!
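To see the kernel trick at work, the following sketch (NumPy; not part of the original text) checks that the degree-2 polynomial kernel equals an explicit inner product in the expanded feature space containing the constant, linear, and quadratic terms (with the appropriate scaling factors); the input vectors are arbitrary examples.

    import numpy as np

    def poly_kernel(x, xp, q=2):
        """Polynomial kernel K(x, x') = (x^T x' + 1)^q, computed in the original space."""
        return (x @ xp + 1.0) ** q

    def phi_degree2(x):
        """Explicit degree-2 feature map for a 2-D input: constant, linear, and quadratic terms."""
        x1, x2 = x
        return np.array([1.0,
                         np.sqrt(2) * x1, np.sqrt(2) * x2,
                         x1 ** 2, x2 ** 2,
                         np.sqrt(2) * x1 * x2])

    x = np.array([0.5, -1.0])
    xp = np.array([2.0, 0.3])
    print(poly_kernel(x, xp))                 # 2.89, computed in R^2
    print(phi_degree2(x) @ phi_degree2(xp))   # 2.89, via the explicit basis in R^6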
⋆ The kernel trick can even be used to work in an infinite basis. This is the case with the Gaussian kernel. If that is
of interest, you should look into Taylor series basis expansions and the Gaussian kernel.
The importance of the kernel trick is that when computations can be done efficiently in the
initial space R^D, then the training problem can be solved by computing the pairwise K(·, ·) values
for all pairs of training examples, and then using SGD to solve the dual, soft-margin training
problem (with N decision variables).
⋆ The matrix K of elements K(xn , xn′ ) is known as the Gram matrix or Kernel matrix.
In practice, this provides for a logic for how different valid kernels can be composed (if they
maintain a p.s.d. Gram matrix!).
There exists a set of rules that preserve the validity of kernels through transformations. These include such things as multiplying a valid kernel by a positive constant, adding two valid kernels, multiplying two valid kernels, and applying a valid kernel to a fixed transformation of the inputs.
It is always possible to test the validity of a given kernel by demonstrating that its Gram matrix
K is positive semidefinite.
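A simple empirical check of this (a NumPy sketch, not part of the original text): build the Gram matrix on a sample of points and verify that its eigenvalues are non-negative up to numerical tolerance. A finite sample cannot prove validity in general, but it will catch many invalid kernels.

    import numpy as np

    def gram_matrix(X, kernel):
        """Gram matrix K with K[n, m] = kernel(x_n, x_m)."""
        N = X.shape[0]
        return np.array([[kernel(X[n], X[m]) for m in range(N)] for n in range(N)])

    def looks_psd(K, tol=1e-8):
        """True if all eigenvalues of the symmetric Gram matrix are >= -tol."""
        return bool(np.all(np.linalg.eigvalsh(K) >= -tol))

    # Hypothetical usage with a degree-2 polynomial kernel:
    # X = np.random.default_rng(0).normal(size=(20, 2))
    # print(looks_psd(gram_matrix(X, lambda a, b: (a @ b + 1.0) ** 2)))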
Chapter 6
Clustering
In this chapter, we will explore a technique known as clustering. This represents our first foray into
unsupervised machine learning techniques. Unlike the previous four chapters, where we explored
techniques that assumed a data set of inputs and targets, with the goal of eventually making
predictions over unseen data, our data set will no longer contain explicit targets. Instead, these
techniques are motivated by the goal of uncovering structure in our data. Identifying clusters of
similar data points is a useful and ubiquitous unsupervised technique.
6.1 Motivation
The reasons for using an unsupervised technique like clustering are broad. We often don’t have a
specific task in mind; rather, we are trying to uncover more information about a potentially opaque
data set. For clustering specifically, our unsupervised goal is to group data points that are similar.
There are many reasons why we might separate our data by similarity. For organizational
purposes, it’s convenient to have different classes of data. It can be easier for a human to sift
through data if it’s loosely categorized beforehand. It may be a preprocessing step for an inference
method; for example, by creating additional features for a supervised technique. It can help identify
which features make our data points most distinct from one another. It might even provide some
idea of how many distinct data types we have in our set.
This idea of data being ‘similar’ means that we need some measure of distance between our
data points. While there are a variety of clustering algorithms available, the importance of this
distance measurement is consistent between them.
Distance is meant to capture how ‘different’ two data points are from each other. Then, we can
use these distance measurements to determine which data points are similar, and thus should be
clustered together. A common distance measurement for two data points x and x′ is given by:
||x − x′||_{L2} = √( Σ_{d=1}^D (x_d − x′_d)² )  (6.1)
where D is the dimensionality of our data. This is known as L2 or Euclidean distance, and you
can likely see the similarity to L2 regularization.
There are a variety of distance measurements available for data points living in a D-dimensional
Euclidean space, but for other types of data (such as data with discrete features), we would need
to select a different distance metric. Furthermore, the metrics we choose to use will have an impact
on the final results of our clustering.
In unsupervised learning, the domain refers to the domain of the hidden variable z (which is
analogous to y in supervised learning). In clustering, we have z’s that represent the discrete
clusters. Furthermore, the techniques that we explore in this chapter are fully unsupervised and
non-probabilistic.
Domain Training Probabilistic
Discrete Unsupervised No
6.1.1 Applications
Here are a few specific examples of use cases for clustering:
As we mentioned above, there are different methods available for clustering. In this chapter, we
will explore two of the most common techniques: K-Means Clustering and Hierarchical Agglomer-
ative Clustering. We also touch on the flavors available within each of these larger techniques.
1. Randomly initialize the cluster centers (for example, by selecting data points at random to serve as the initial centers).
2. Using a distance metric of your choosing, assign each data point to the closest cluster.
3. Update the cluster centers based on your assignments and distance metric (for example, when
using L2 distance, we update the cluster centers by averaging the data points assigned to each
cluster).
In the case where we are using the L2 distance metric, this is known as Lloyd’s algorithm, which
we derive in the next section.
Objective
The loss function for our current assignment of data points to clusters is given by:
L(X, {µ_c}_{c=1}^C, {r_n}_{n=1}^N) = Σ_{n=1}^N Σ_{c=1}^C r_{nc} ||x_n − µ_c||_2^2  (6.2)
where X is our N×D data set (N is the number of data points and D is the dimensionality of our data), {µ_c}_{c=1}^C is the C×D matrix of cluster centers (C is the number of clusters we chose), and {r_n}_{n=1}^N is our N×C matrix of responsibility vectors. These are one-hot encoded vectors (one per data point), where the 1 is in the position of the cluster to which we assigned the nth data point.
Algorithm
We first adjust our responsibility vectors to minimize each data point’s distance from its cluster
center. Formally:
r_{nc} = 1 if c = argmin_{c′} ||x_n − µ_{c′}||, and r_{nc} = 0 otherwise.  (6.3)
After updating our responsibility vectors, we now wish to minimize our loss by updating our cluster
centers µc . The cluster centers which minimize our loss can be computed by taking the derivative
of our loss with respect to µc , setting equal to 0, and solving for our new cluster centers µc :
∂L/∂µ_c = −2 Σ_{n=1}^N r_{nc} (x_n − µ_c)

µ_c = ( Σ_{n=1}^N r_{nc} x_n ) / ( Σ_{n=1}^N r_{nc} )  (6.4)
Intuitively, this is the average of all the data points xn assigned to the cluster center µc .
We then update our responsibility vectors based on the new cluster centers, update the cluster
centers again, and continue this cycle until we have converged on a stable set of cluster centers and
responsibility vectors.
Note that while Lloyd’s algorithm is guaranteed to converge, it is only guaranteed to converge
to a locally optimal solution. Finding the globally optimal set of assignments and cluster centers is
an NP-hard problem. As a result, a common strategy is to execute Lloyd’s algorithm several times
with different random initializations of cluster centers, selecting the assignment that minimizes
loss across the different trials. Furthermore, to avoid nonsensical solutions due to scale mismatch
between features (which would throw our Euclidean distance measurements off), it makes sense to
standardize our data in a preprocessing step. This is as easy as subtracting the mean and dividing
by the standard deviation across each feature.
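Here is a compact sketch of Lloyd's algorithm (NumPy; not part of the original text), including the standardization preprocessing step mentioned above; the number of clusters C and the iteration cap are hypothetical choices.

    import numpy as np

    def lloyds_algorithm(X, C, n_iters=100, seed=0):
        """K-means with L2 distance (Lloyd's algorithm).
        X: (N, D) data. Returns centers (C, D) in standardized space and assignments (N,).
        For simplicity we assume no cluster ever becomes empty."""
        rng = np.random.default_rng(seed)
        X = (X - X.mean(axis=0)) / X.std(axis=0)                  # standardize features
        centers = X[rng.choice(len(X), size=C, replace=False)]    # initialize from random data points
        for _ in range(n_iters):
            dists = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
            assign = dists.argmin(axis=1)                         # responsibility update (Eq. 6.3)
            new_centers = np.array([X[assign == c].mean(axis=0) for c in range(C)])
            if np.allclose(new_centers, centers):                 # converged
                break
            centers = new_centers                                 # center update (Eq. 6.4)
        return centers, assign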
For some more clarity on exactly how Lloyd’s algorithm works, let’s walk through an example.
Example 6.1 (Lloyd’s Algorithm Example): We start with a data set of size N = 6. Each
data point is two-dimensional, with each feature taking on a value between -3 and 3. We also
have a ‘Red’ and ‘Green’ cluster. Here is a table and graph of our data points, labelled A through F:
Let’s say we wish to have 2 cluster centers. We then randomly initialize those cluster centers by
selecting two data points. Let’s say we select B and F. We identify our cluster centers with a red
and green ‘X’ respectively:
We now begin Lloyd’s algorithm by assigning each data point to its closest cluster center:
We then update our cluster centers by averaging the data points assigned to each:
We proceed like this, updating our cluster centers and assignments, until convergence. At conver-
gence, we’ve achieved these cluster centers and assignments:
Our red cluster center is at (-1.75, -2.25) and our green cluster center is at (1.5, 0). Note that for this random initialization of cluster centers, we deterministically identified the locally optimal set of assignments and cluster centers. For a specific initialization, running Lloyd's algorithm will always identify the same set of assignments and cluster centers. However, different initializations will produce different results.
We would much rather start with a random initialization that looks like Figure 6.3.
We can use the hint that we want our cluster centers somewhat spread out to find a better
random initialization. This is where the initialization algorithm presented by K-Means++ comes
in.
For K-Means++, we choose the first cluster center by randomly selecting a point in our data
set, same as before. However, for all subsequent cluster center initializations, we select points in
our data set with probability proportional to the squared distance from their nearest cluster center.
The effect of this is that we end up with a set of initializations that are relatively far from one
another, as in Figure 6.3.
After initializing the cluster centers this way, we run Lloyd's algorithm as usual until convergence.
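A minimal sketch of this initialization scheme might look as follows; the helper name is ours, and the result would then be handed to Lloyd's algorithm as its starting centers.

```python
import numpy as np

def kmeans_pp_init(X, C, seed=0):
    """K-Means++ seeding: first center uniform at random, later centers chosen with
    probability proportional to squared distance from the nearest center so far."""
    rng = np.random.default_rng(seed)
    centers = [X[rng.integers(len(X))]]
    for _ in range(C - 1):
        d = np.linalg.norm(X[:, None, :] - np.array(centers)[None, :, :], axis=2)
        d2 = (d ** 2).min(axis=1)              # squared distance to the nearest chosen center
        centers.append(X[rng.choice(len(X), p=d2 / d2.sum())])
    return np.array(centers)
```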
Moving on to Hierarchical Agglomerative Clustering (also known as HAC - pronounced ‘hack’),
the motivating idea is instead to group data from the bottom up. This means every data point
starts as its own cluster, and then we merge clusters together based on a distance metric that
we define. This iterative merging allows us to construct a tree over our data set that describes
relationships between our data. These trees are known as dendrograms, with an example found in
Figure 6.4. Notice that the individual data points are the leaves of our tree, and the trunk is the
cluster that contains the entirety of our data set.
We now formally define the HAC algorithm, and in the process, explain how we construct such
a tree.
Example 6.2 (HAC Algorithm Example): Let’s say we have a data set of five points A, B,
C, D, E that we wish to perform HAC on. These points will simply be scalar data that we can
represent on a number line. We start with 5 clusters and no connections at all:
We find the closest two clusters to merge first. A and B are nearest (it’s actually tied with C and D,
but we can arbitrarily break these ties), so we start by merging them. Notice that we also annotate
the distance between them in the tree, which in this case is 1:
We now have four clusters: (A, B), C, D, and E. We again find the closest two clusters, which in
this case is C and D:
We now have three remaining clusters: (A, B), (C, D), and E. We proceed as before, identifying
the two closest clusters to be (A, B) and (C, D). Merging them:
Finally we are left with two clusters: (A, B, C, D) and E. The remaining two clusters are obviously
the closest together, so we merge them:
At this point there is only a single cluster. We have constructed our tree and are finished with
HAC.
Notice how the distance between two merged clusters manifests itself through the height of the
dendrogram where they merge (which is why we tracked those distances as we constructed the
tree). Notice also that we now have many layers of clustering: if we're only interested in
clusters whose elements are at least k units away from each other, we can ‘cut’ the dendrogram at
that height and examine all the clusters that exist below that cut point.
Finally, we need to handle the important detail of how to compute the distance between clusters.
In the preceding example, we designated the distance between two clusters to be the minimum
distance between any two data points in the clusters. This is what is known as the Min-Linkage
Criterion. However, there are certainly other ways we could have computed the distance between
clusters, and using a different distance measurement can produce different clustering results. We
now turn to these different methods and the properties of clusters they produce.
Min-Linkage Criteria
We’ve already seen the Min-Linkage Criterion in action from the previous example. Formally, the
criterion says that the distance $d_{C,C'}$ between each cluster pair $C$ and $C'$ is given by
$$d_{C,C'} = \min_{k,k'} \|x_k - x_{k'}\| \qquad (6.5)$$
where $x_k$ are data points in cluster $C$ and $x_{k'}$ are data points in cluster $C'$. After computing these
pairwise distances, we choose to merge the two clusters that are closest together.
Max-Linkage Criterion
We could also imagine defining the distance dC,C ′ between two clusters as being the distance between
the two points that are farthest apart in each cluster. This is known as the Max-Linkage Criterion.
$$d_{C,C'} = \max_{k,k'} \|x_k - x_{k'}\| \qquad (6.6)$$
As with the Min-Linkage Criterion, after computing these pairwise distances, we choose to merge
the two clusters that are closest together.
⋆ Be careful not to confuse the linkage criterion with which clusters we choose to merge. We always merge the clusters
that have the smallest distance between them. How we compute that distance is given by the linkage criterion.
Average-Linkage Criterion
The Average-Linkage Criterion averages the pairwise distance between each point in each cluster.
Formally, this is given by:
$$d_{C,C'} = \frac{1}{K K'} \sum_{k=1}^{K} \sum_{k'=1}^{K'} \|x_k - x_{k'}\| \qquad (6.7)$$
Centroid-Linkage Criterion
The Centroid-Linkage Criterion uses the distance between the centroid of each cluster (which is
the average of the data points in a cluster). Formally, this is given by:
$$d_{C,C'} = \left\| \frac{1}{K} \sum_{k=1}^{K} x_k - \frac{1}{K'} \sum_{k'=1}^{K'} x_{k'} \right\| \qquad (6.8)$$
⋆ You should convince yourself of the different flavors of linkage criteria. For example, when using the min-linkage
criterion, we get these ‘stringy’ results because we’re most inclined to extend existing clusters by grabbing whichever
data points are closest.
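The four criteria are easy to state in code. Below is a small NumPy sketch (the function names are our own); whichever criterion we use, HAC always merges the pair of clusters with the smallest resulting distance. Libraries such as SciPy implement these same criteria under names like 'single', 'complete', 'average', and 'centroid'.

```python
import numpy as np

def pairwise_dists(C1, C2):
    """All pairwise Euclidean distances between clusters C1 (K, D) and C2 (K', D)."""
    return np.linalg.norm(C1[:, None, :] - C2[None, :, :], axis=2)

def min_linkage(C1, C2):      return pairwise_dists(C1, C2).min()    # Eq. 6.5
def max_linkage(C1, C2):      return pairwise_dists(C1, C2).max()    # Eq. 6.6
def average_linkage(C1, C2):  return pairwise_dists(C1, C2).mean()   # Eq. 6.7
def centroid_linkage(C1, C2): return np.linalg.norm(C1.mean(axis=0) - C2.mean(axis=0))  # Eq. 6.8
```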
Chapter 7
Dimensionality Reduction
In previous chapters covering supervised learning techniques, we often used basis functions to
project our data into higher dimensions prior to applying an inference technique. This allowed us
to construct more expressive models, which ultimately produced better results. While it may seem
counterintuitive, in this chapter we’re going to focus on doing exactly the opposite: reducing the
dimensionality of our data through a technique known as Principal Component Analysis (PCA).
We will also explore why it is useful to reduce the dimensionality of some data sets.
7.1 Motivation
Real-world data is often very high dimensional, and it’s common that our data sets contain infor-
mation we are unfamiliar with because the dimensionality is too large for us to comb through all
the features by hand.
In these situations, it can be very difficult to manipulate or utilize our data effectively. We
don’t have a sense for which features are ‘important’ and which ones are just noise. Fitting a
model to the data may be computationally expensive, and even if we were to fit some sort of model
to our data, it may be difficult to interpret why we obtain specific results. It’s also hard to gain
intuition about our data through visualization since humans struggle to think in more than three
dimensions. All of these are good reasons that we may wish to reduce the dimensionality of a data
set.
ML Framework Cube: Dimensionality Reduction
7.2 Applications
As described above, we need a tool like dimensionality reduction in situations where high-dimensional
data hinders us. Here are a few specific situations where we would use such a technique:
3. Efficiently training a neural network to predict supermarket sales on a data set with many
input features.
4. Identifying which costly measurements are worth collecting when experimenting with new
chemicals.
With a few of these use cases in mind, we now turn to the math that underpins the dimension-
ality reduction technique known as Principal Component Analysis.
Figure 7.3: Converting between the reduced data and original data.
To differentiate our data points, we obviously only need to report the weights of the bears. The
variance of the heights is 0, and the variance of the weights is some non-zero number. Intuitively,
the most interesting features from our data sets are those that vary the most.
⋆ In this simple example, the direction of maximal variance occurs exactly along the x1 axis, but in general it will
occur on a plane described by a combination of our input features.
The second way to think about PCA is that we are minimizing the error we incur when we
move from the lower-dimensional representation back to the original representation. This is known
as reconstruction loss. We can consider the meaning of this using our bear example.
Let’s say we project the data set from Figure 7.1 here down to a single dimension by recording
only the weights:
Then, to reconstruct our original graph, we need only to keep track of a slope and bias term
in the form of the familiar equation x2 = mx1 + b. In this case our slope is m = 0 and our bias
b = 3. Note that this storage overhead is constant (just remembering the slope and bias) regardless
of how big our data set gets. Thus we can go from our low-dimensional representation back to our
original data:
It will be our goal to determine a low-dimensional representation of our data that allows us
to return to our high-dimensional data while losing as little information as possible. We wish
to preserve everything salient about the data while discarding as much redundant information as
possible. We now turn to how this can be achieved.
Figure 7.4: Far left: our original data. Middle: our reduced data in scalar form after the projection given by $x \cdot w$. Right: our reconstructed data points given by $(x \cdot w)w$. Notice that our reconstructed data points are not the same as our original data.
⋆ We’re going to assume that our data set has been mean-centered such that each feature in xn has mean 0. This
will not affect the application of the method (we can always convert back to the uncentered data by adding back the
mean of each feature), but will make our derivations more convenient to work with.
Let’s consider a simple case first: D′ = 1. This means that we’re projecting our D dimensional
data down onto just a single dimension, or in geometric terms, we’re projecting our data points
xn onto a line through the origin. We can define this line as the unit vector w ∈ RD×1 , and the
projection is given by the dot product x · w.
⋆ The unit vector w onto which we project our data is known as a principal component, from which PCA gets its
name.
This projection produces a scalar, and that scalar defines how far our projection x · w is from
the origin. We can convert this scalar back to D dimensional space by multiplying it with the
unit vector w. This means that (x · w)w is the result of projecting our data point x down into
one-dimension and then converting it to its coordinate location in D dimensions. We refer to these
as our projection vectors, and we can observe what this looks like geometrically in Figure 7.4.
The projection vectors we recover from the expression (x · w)w will be in D dimensions, but
they will obviously not be identical to the original D dimensional vectors (Figure 7.4 demonstrates
why that is the case). This difference between the original and projection vectors can be thought
of as error, since it is information lost from our original data. For a given data point xn and unit
vector $w$, we can measure this error through the expression
$$\|x_n - (x_n \cdot w)\, w\|^2,$$
which is known as reconstruction loss because it measures the error incurred when reconstructing
our original data from its projection.
Definition 7.3.1 (Reconstruction Loss): Reconstruction loss is the difference (measured via a
distance metric such as Euclidean distance) between an original data set and its reconstruction from
a lower dimensional representation. It indicates how much information is lost during dimensionality
reduction.
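To make the projection and reconstruction concrete, here is a short NumPy sketch on synthetic data (the data and the choice of w are arbitrary placeholders):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
X = X - X.mean(axis=0)                    # mean-center, as assumed in the text

w = rng.normal(size=5)
w = w / np.linalg.norm(w)                 # a unit vector defining the 1-D subspace

scalars = X @ w                           # x . w : scalar position of each projection
X_hat = np.outer(scalars, w)              # (x . w) w : projection vectors in D dimensions
recon_loss = np.mean(np.sum((X - X_hat) ** 2, axis=1))   # average ||x_n - (x_n . w) w||^2
```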
Reconstruction loss is then a metric for evaluating how ‘good’ a subspace in D′ dimensions is
at representing our original data in D dimensions. The better it is, the less information we lose,
and the reconstruction loss is lower as a result.
Expanding the squared norm gives $\|x_n - (x_n \cdot w)w\|^2 = \|x_n\|^2 - 2(x_n \cdot w)^2 + (x_n \cdot w)^2 \|w\|^2 = \|x_n\|^2 - (x_n \cdot w)^2$, where $\|w\|^2 = 1$ because it is a unit vector. Note that we can define reconstruction loss over our entire data set as follows:
$$RL(w) = \frac{1}{N} \sum_{n=1}^{N} \|x_n\|^2 - (x_n \cdot w)^2 \qquad (7.2)$$
Recall that our goal is to minimize reconstruction loss over our data set by optimizing the subspace
defined by w. Let’s first rewrite Equation 7.2 as:
$$RL(w) = \frac{1}{N} \sum_{n=1}^{N} \|x_n\|^2 - \frac{1}{N} \sum_{n=1}^{N} (x_n \cdot w)^2$$
where we can see that our optimization will depend only on maximizing the second term:
$$\max_{w} \frac{1}{N} \sum_{n=1}^{N} (x_n \cdot w)^2 \qquad (7.3)$$
since it is the only one involving $w$. Recall that the sample mean of a data set is given by the expression $\frac{1}{N}\sum_{n=1}^{N} x_n$, and note that Equation 7.3 is the sample mean of $(x \cdot w)^2$. Using the definition of variance for a random variable $Z$ (which is given by $\mathrm{Var}(Z) = E(Z^2) - (E(Z))^2$), we can rewrite Equation 7.3 as:
$$\frac{1}{N} \sum_{n=1}^{N} (x_n \cdot w)^2 = \mathrm{Var}\left\{x_n \cdot w\right\}_{n=1}^{N} + \left(E\left\{x_n \cdot w\right\}_{n=1}^{N}\right)^2$$
Recall that we centered our data xn to have mean 0 such that the expression above simplifies to:
$$\frac{1}{N} \sum_{n=1}^{N} (x_n \cdot w)^2 = \mathrm{Var}\left\{x_n \cdot w\right\}_{n=1}^{N} \qquad (7.4)$$
⋆ Note the intuitiveness of this result. We should like to find a subspace that maintains the spread in our data.
More generally, if we project onto $D'$ principal components $w_1, \dots, w_{D'}$, the reconstruction of $x_n$ is given by
$$\sum_{d'=1}^{D'} (x_n \cdot w_{d'})\, w_{d'} \qquad (7.5)$$
Returning to the single-component case, we can write the variance of the projections in matrix form using the design matrix $X$:
$$\sigma_w^2 \equiv \mathrm{Var}\left\{x_n \cdot w\right\}_{n=1}^{N} = \frac{1}{N} \sum_{n=1}^{N} (x_n \cdot w)^2 = \frac{1}{N} (Xw)^{T} (Xw)$$
We can further simplify this:
$$\sigma_w^2 = \frac{1}{N} w^{T} X^{T} X w = w^{T} \frac{X^{T} X}{N} w = w^{T} S w$$
where, since we assume the design matrix $X$ is mean-centered, $S = \frac{X^{T} X}{N}$ is the empirical covariance matrix of our data set.
⋆ Notice that by convention we describe the empirical covariance of a data set with the term S instead of the usual
covariance term Σ.
Our goal is to maximize the term $\sigma_w^2 = \mathrm{Var}\{x_n \cdot w\}_{n=1}^{N}$ with respect to $w$. Furthermore, $w$ is a unit vector, so we must optimize subject to the constraint $w^{T} w = 1$. Recalling the discussion of Lagrange multipliers from the chapter on Support Vector Machines, we incorporate this constraint
by reformulating our optimization problem as the Lagrangian equation:
$$\mathcal{L}(w, \lambda) = w^{T} S w - \lambda (w^{T} w - 1)$$
$$\frac{\partial \mathcal{L}(w, \lambda)}{\partial w} = 2 S w - 2 \lambda w$$
$$\frac{\partial \mathcal{L}(w, \lambda)}{\partial \lambda} = -(w^{T} w - 1)$$
We can now set these equal to 0 and solve for the optimal values:
$$S w = \lambda w$$
$$w^{T} w = 1$$
This result is very significant! As we knew already, we needed $w$ to be a unit vector. However, we also see that $w$ is an eigenvector of the empirical covariance matrix $S$. Furthermore, the eigenvector that will maximize our quantity of interest $\sigma_w^2 = w^{T} S w$ will be the eigenvector with the largest eigenvalue $\lambda$.
Linear algebra gives us many tools for finding eigenvectors, and as a result we can efficiently
identify our principal components. Note also that the eigenvectors of a symmetric matrix are
orthogonal, which proves our earlier assumption that our principal components are orthogonal.
⋆ Each eigenvector is a principal component. The larger the eigenvalue associated with that principal component,
the more variance there is along that principal component.
To recap, we’ve learned that the optimal principal components (meaning the vectors describing
our projection subspace) are the eigenvectors of the empirical covariance matrix of our data set.
The vector preserving the most variance in our data (and thus minimizing the reconstruction loss)
is given by the eigenvector with the largest eigenvalue, followed by the eigenvector with the next
largest eigenvalue, and so on. Furthermore, while it is somewhat outside the scope of this textbook,
we are guaranteed to have D distinct, orthogonal eigenvectors with eigenvalues ≥ 0. This is a result
of linear algebra that hinges on the fact that our empirical covariance matrix S is symmetric and
positive semi-definite.
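Putting the pieces together, a minimal sketch of PCA via an eigendecomposition of the empirical covariance might look like this (function names are our own):

```python
import numpy as np

def pca_components(X, n_components):
    """Return the top principal components (columns of W) and all eigenvalues, sorted."""
    X = X - X.mean(axis=0)                    # mean-center each feature
    S = X.T @ X / len(X)                      # empirical covariance matrix S
    eigvals, eigvecs = np.linalg.eigh(S)      # eigh: symmetric matrices, ascending eigenvalues
    order = np.argsort(eigvals)[::-1]         # sort by decreasing eigenvalue (variance)
    W = eigvecs[:, order[:n_components]]      # (D, D') principal components
    return W, eigvals[order]

# Reduced representation and reconstruction:
#   Z     = (X - X.mean(axis=0)) @ W      # (N, D') projections
#   X_hat = Z @ W.T + X.mean(axis=0)      # (N, D) reconstruction
```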
Figure 7.5: PCA applied to Fisher’s Iris data set, which is originally in four dimensions. We reduce
it to three dimensions for visualization purposes and label the different flower types. This example
is taken from the sklearn documentation.
Computationally, the principal components are often found via the singular value decomposition (SVD). The SVD factors the design matrix as $X = UZV^{\top}$ for a diagonal matrix $Z$ and orthogonal matrices $U$ and $V$ (i.e., $U^{\top}U = I$ and $V^{\top}V = I$). Matrix $V$ is the right singular matrix, and its columns are the right singular vectors. The entries in $Z$ are the corresponding singular values. The column vectors in $V$ are also the eigenvectors of $X^{\top}X$. To see this, we have $X^{\top}X = VZU^{\top}UZV^{\top} = VZ^{2}V^{\top}$, where we substitute the SVD for $X$. We also see that the eigenvalues are $\lambda_i = z_i^2$, so that they are the squared singular values.
For a mean-centered design matrix $X$, so that $(1/N)X^{\top}X = \mathrm{cov}(X)$, we can compute the SVD of the design matrix $X$, and then read off the eigenvectors of the covariance as the columns of the right singular matrix and the eigenvalues as $\lambda_i = z_i^2 / N$. Alternatively, you can first divide the mean-centered design matrix by $\sqrt{N}$ before taking its SVD, in which case the squares of the singular values correspond to the eigenvalues.
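A quick numerical check of this relationship, assuming nothing beyond NumPy:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
X = X - X.mean(axis=0)                                # mean-centered design matrix

eigvals, eigvecs = np.linalg.eigh(X.T @ X / len(X))   # route 1: covariance eigendecomposition
U, z, Vt = np.linalg.svd(X, full_matrices=False)      # route 2: SVD, X = U diag(z) V^T

# Eigenvalues agree with squared singular values divided by N (up to ordering).
assert np.allclose(np.sort(z ** 2 / len(X)), np.sort(eigvals))
```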
Because the principal components are orthogonal, the projections they produce will be entirely
uncorrelated. This means we can project our original data onto each component individually and
then sum those projections to create our lower dimensional data points. Note that it doesn’t make
sense that we would use every one of our D principal components to define our projection subspace,
since that wouldn’t lead to a reduction in the dimensionality of our data at all (the D orthogonal
principal components span the entire D dimensional space of our original data set). We now need
to decide how many principal components we will choose to include, and therefore what subspace
we will be projecting onto.
The ‘right’ number of principal components to use depends on our goals. For example, if we
simply wish to visualize our data, then we would project onto a 2D or 3D space. Therefore, we
would choose the first 2 or 3 principal components, and project our original data onto the subspace
defined by those vectors. This might look something like Figure 7.5.
However, it’s more complicated to choose the optimal number of principal components when our
goal is not simply visualization. We’re now left with the task of trading off how much dimensionality
Figure 7.6: Reconstruction loss versus the number of principal components. Notice the similarity
to the ‘elbow’ method of K-Means clustering.
reduction we wish to achieve with how much information we want to preserve in our data.
One way to do this is similar to the informal ‘elbow’ method described for K-Means clustering.
We graph our reconstruction loss against the number of principal components used, as seen in
Figure 7.6. The idea is to add principal components to our subspace one at a time, calculating
the reconstruction loss as we go. The first few principal components will greatly reduce the recon-
struction loss, before eventually leveling off. We can identify the ‘elbow’ where the reduction in
loss starts to diminish, and choose to use that number of principal components.
Another way to do this is to consider how much variance we wish to preserve in our data. Each
principal component is associated with an eigenvalue λd that indicates what proportion of the
variance that principal component is responsible for in our data set. Then the fraction of variance
retained from our data set if we choose to keep D′ principal components is given by:
$$\text{retained variance} = \frac{\sum_{d'=1}^{D'} \lambda_{d'}}{\sum_{d=1}^{D} \lambda_{d}} \qquad (7.6)$$
For different applications, there may be different levels of acceptable variance retention, which can
help us decide how many principal components to keep.
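In code, choosing D' by a retained-variance threshold is a short computation over the sorted eigenvalues; the function name and the 95% threshold below are just illustrative choices.

```python
import numpy as np

def choose_num_components(eigvals, threshold=0.95):
    """Smallest D' whose top-D' eigenvalues retain at least `threshold` of the variance (Eq. 7.6)."""
    eigvals = np.sort(np.asarray(eigvals))[::-1]
    retained = np.cumsum(eigvals) / eigvals.sum()
    return int(np.argmax(retained >= threshold)) + 1
```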
Finally, once we have selected our principal components, we have also defined the subspace onto
which we will be projecting our original data. And although this subspace is defined by the basis
given by our principal components, these principal components are not a unique description of that
subspace. We could choose to use any basis after we’ve identified our subspace through the principal
components. The importance of this idea is simply that although our principal components are
unique, they are not the only basis we could use to define the same projection subspace.
7.4 Conclusion
Principal component analysis is useful for visualization purposes, removing redundant informa-
tion, or making a data set more computationally manageable. PCA is also a good tool for data
exploration, particularly when digging into an unfamiliar data set for the first time.
It is not essential that we know by heart the exact derivation for arriving at the principal
components of a data set. The same can be said of the linear algebra machinery needed to compute
principal components. However, it is important to have an intuitve grasp over how variance in
our data set relates to principal components, as well as an understanding of how subspaces in our
data can provide compact representations of that data set. These are critical concepts for working
effectively with real data, and they will motivate related techniques in machine learning.
Chapter 8
Graphical Models
Mathematics, statistics, physics, and other academic fields have useful notational systems. As a
hybrid of these and other disciplines, machine learning borrows from many of the existing systems.
Notational abstractions are important to enable consistent and efficient communication of ideas, for
both teaching and knowledge creation purposes. Much of machine learning revolves around mod-
eling data processes, and then performing inference over those models to generate useful insights.
In this chapter, we will be introducing a notational system known as the directed graphical model
(DGM) that will help us reason about a broad class of models.
8.1 Motivation
Up until this point, we’ve defined notation on the fly, relied on common statistical concepts, and
used diagrams to convey meaning about the problem setup for different techniques. We’ve built
up enough working knowledge and intuition at this point to switch to a more general abstraction
for defining models: directed graphical models (DGMs). DGMs will allow us to both consolidate
our notation and convey information about arbitrary problem formulations. An example of what
a DGM looks like for a linear regression problem setup is given in Figure 8.1, and over the course
of the chapter, we’ll explain how to interpret the symbols in this diagram.
We need graphical models for a couple of reasons. First, and most importantly, a graphical
model unambiguously conveys a problem setup. This is useful both to share models between
people (communication) and to keep all the information in a specific model clear in your own head
(consolidation). Once we understand the meaning of the symbols in a DGM, it will be far easier to
examine one of them than it will be to read several sentences describing the type of model we’re
imagining for a specific problem. Another reason we use DGMs is that they help us reason about
independence properties between different parts of our model. For simple problem setups this may
be easy to keep track of in our heads, but as we introduce more complicated models it will be useful
to reason about independence properties simply by examining the DGM describing that model.
Ultimately, directed graphical models are a tool to boost efficiency and clarity. We’ll examine
the core components of a DGM, as well as some of their properties regarding independence and
model complexity. The machinery we develop here will be used heavily in the coming chapters.
8.2 Directed Graphical Models (Bayesian Networks)
Figure 8.2: Random variables are denoted with an open circle, and it is shaded if the variable is
observed.
A DGM consists of nodes that represent random variables and deterministic parameters, and arrows that indicate the relationships between them. Let's consider linear regression as a simple but comprehensive example. We have a random variable $y_n$, the object of predictive interest, which depends on deterministic parameters in the form of data $x_n$ and weights $w$. This
results in the DGM given by Figure 8.1. There are four primary pieces of notation that the linear
regression setup gives rise to, and these four components form the backbone of every DGM we
would like to construct.
First, we have the random variable $y_n$, represented by an open circle, shown in Figure 8.2. If we
observe a random variable of a given model, then it is shaded. If the random variable is unobserved
(sometimes called latent), it is unshaded.
Second, we have deterministic parameters represented by a tight, small dot, shown in Figure
8.3.
Third, we have arrows that indicate the dependence relationship between different random
variables and parameters, shown in Figure 8.4. Note that an arrow from X into Y means that Y
depends on X.
And finally, we have plate notation to indicate that we have repeated sets of variables in our
model setup, shown in Figure 8.5.
Figure 8.3: Deterministic parameters are denoted with a tight, small dot.
Figure 8.5: Plates indicate repeated sets of variables. Often there will be a number in one of the
corners (N in this case) indicating how many times that variable is repeated.
Figure 8.6: DGM for the joint distribution given by Equation 8.1.
With these four constructs, we can describe complex model setups diagrammatically. We can
have an arbitrary number of components and dependence structure. DGMs can be useful as a
reference while working on a problem, and they also make it easy to iterate on an existing model
setup.
Consider, for example, a joint distribution
$$p(A, B, C)$$
where in this setup, we are interested in the joint distribution over three random variables $A, B, C$. However, this alone doesn't tell us anything about the structure of the problem at hand. We would like to know where there is independence and use that to simplify our model. For example, if we knew that $B$ and $C$ were independent and we also knew the conditional distribution of $A \mid B, C$, then we would much rather set up our joint probability equation as:
$$p(A, B, C) = p(B)\, p(C)\, p(A \mid B, C) \qquad (8.1)$$
DGMs assist in this process of identifying the appropriate factorization, as their structure allows
us to read off valid factorizations directly. For example, the joint distribution given by Equation
8.1 can be read from Figure 8.6.
We can translate between a DGM and a factorized joint distribution by interpreting the arrows
as dependencies. If a random variable has no dependencies (as neither B nor C do in this example),
they can be written on their own as marginal probabilities p(B) and p(C). Since arrows indicate
dependence of the random variable at the tip on the random variable at the tail, the factorized joint
probability distribution of the dependent random variable is written as a conditional probability,
i.e. P (A|B, C) in 8.6. In this way, we can move back and forth between DGMs and factorized joint
distributions with ease.
A DGM also tells a generative story about our data: it helps us identify how the data are generated, and, if we have the proper tools, how we could generate new data ourselves.
Definition 8.2.1 (Generative Models): A generative model describes the entire process by
which data comes into existence. It enables the creation of new data by sampling from the gener-
ative model, but generative models are not required to make predictions or perform other kinds of
inference.
⋆ Note that we can create graphical models for both generative and discriminative models. Discriminative models
will only model the conditional distribution p(Y |Z), while generative models will model the full joint distribution
p(Z, Y ).
Let’s consider a simple example to see how this works in practice. Consider the flow of in-
formation present in Figure 8.7. First, there is some realization of the random variable Z. Then
conditioned on that value of Z, there is some realization of the random variable Y . The equivalent
joint factorization for this DGM is given by p(Z)p(Y |Z). More intuitively, the data are created by
first sampling from Z’s distribution, and then based on the sampled value Z = z, sampling from
the conditional distribution p(Y |Z = z).
As a concrete example, let Z be a random variable representing dog breed and Y be a random
variable representing snout length of a dog of a given breed, which is conditional on the breed of
dog. Notice that we have not specified the distributional form of either Z or Y ; we only have the
story of how they relate to each other (i.e. the dependence relation).
This story also shows us that if we had some model for Z and Y |Z, we could generate data
points ourselves. Our procedure would simply be to sample from Z and then to sample from Y
conditioned on that value of Z. We could perform this process as many times as we like to generate
new data. This is in contrast to sampling directly from the joint p(Z, Y ), which is difficult if we do
not know the exact form of the joint distribution (which we often do not) or if the joint distribution
is difficult to sample from directly.
The technique of sampling from distributions in the order indicated by their DGM is known as
ancestral sampling, and it is a major benefit of generative models.
In general, ancestral sampling first samples the random variables that have no dependencies, then samples the random variables that depend on those initially sampled random variables, and so on until all the random variables have been sampled. This is demonstrated in Figure 8.8.
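A tiny sketch of ancestral sampling for the dog-breed story above; the breed probabilities and snout-length distributions are made up purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

breeds = ["dachshund", "beagle", "borzoi"]
p_z = np.array([0.5, 0.3, 0.2])              # hypothetical p(Z) over breeds
snout = {"dachshund": (12.0, 1.0),           # hypothetical (mean, std) of p(Y | Z = z), in cm
         "beagle": (15.0, 1.5),
         "borzoi": (24.0, 2.0)}

def ancestral_sample():
    z = rng.choice(breeds, p=p_z)            # sample the parent node Z first
    mean, std = snout[z]
    y = rng.normal(mean, std)                # then sample Y conditioned on the realized Z = z
    return z, y

samples = [ancestral_sample() for _ in range(5)]
```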
We’ve already motivated one of the primary uses of DGMs as being the ability to convert a joint
distribution into a factorization. At the heart of that task was recognizing and exploiting indepen-
dence properties in a model over multiple random variables. Another benefit of this process is that
it allows us to easily reason about the size (also called ‘complexity’) of a joint factorization over
discrete random variables. In other words, it allows us to determine how many parameters we will
have to learn to describe the factorization for a given DGM.
Let us consider an example to make this concept clear. Suppose we have four categorical random
variables A, B, C, D which take on 2, 4, 8, and 16 values respectively. If we were to assume full
dependence between each of these random variables, then a joint distribution table over all of these
random variables would require (2 ∗ 4 ∗ 8 ∗ 16) − 1 = 1023 total parameters (where each parameter
corresponds to the probability of a specific permutation of the values A, B, C, D).
⋆ Notice that the number of parameters we need is (2 ∗ 4 ∗ 8 ∗ 16) − 1 and not (2 ∗ 4 ∗ 8 ∗ 16). This is because if we
know the first (2 ∗ 4 ∗ 8 ∗ 16) − 1 parameters, the probability of the final combination of values is fixed since joint
probabilities sum up to 1.
However, if we knew that some of these random variables were conditionally independent, then
the number of parameters would change. For example, consider the joint distribution given by
p(A, B, C, D) = p(A)p(B|A)p(C|A)p(D|A). This would imply that conditioned on A, each of
B, C, D were conditionally independent. This can also be shown by the DGM in Figure 8.10.
In this case, a table of parameters to describe this joint distribution would only require $(2-1) + 2 \cdot ((4-1) + (8-1) + (16-1)) = 51$ parameters (one for $p(A)$, plus the conditional tables for $B$, $C$, and $D$ given $A$), which is significantly fewer. In general, the more conditional
independence we can identify between random variables, the easier they are to model and compute.
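The counting argument is easy to check mechanically; a short sketch:

```python
from math import prod

sizes = {"A": 2, "B": 4, "C": 8, "D": 16}

# Fully dependent joint table: one free parameter per joint setting, minus 1 for normalization.
full_joint = prod(sizes.values()) - 1                                        # 1023

# Factorization p(A) p(B|A) p(C|A) p(D|A): (2 - 1) free parameters for p(A), plus
# (size - 1) free parameters per conditional distribution, one per value of A.
factored = (sizes["A"] - 1) + sizes["A"] * sum(sizes[v] - 1 for v in "BCD")  # 51

print(full_joint, factored)
```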
Figure 8.11: The three random variable relationships that will tell us about independence relation-
ships.
In the first random variable structure, we consider the unobserved case first, which allows us to write the joint distribution as:
$$p(A, B, C) = p(B)\, p(A \mid B)\, p(C \mid B)$$
We know that for this case, A and C are not independent (note that we have not observed B).
Therefore, we say that information flows from A to C. However, once we’ve observed B, we know
that A and C are conditionally independent, which is shown in Figure 8.13.
We now say that the flow of information from A to C is ‘blocked’ by the observation of B.
Intuitively, if we observe A but not B, then we have some information about the value of B, which
also gives some information about the value of C. The same applies in the other direction: observing
C but not B has implications on the distribution of B and A.
In the second random variable structure, shown in Figure 8.14, we again consider the unobserved
case first.
This allows us to write the joint distribution as:
p(A, B, C) = p(A)p(B|A)p(C|B)
Again, A and C are dependent if we have not observed B. Information is flowing from A to C
through B. However, once we’ve observed B, then A and C are again conditionally independent,
shown in Figure 8.15.
The flow of information from A to C is ‘blocked’ by the observation of B. Intuitively, if we
observed A but not B, we have some information about what B might be and therefore what C
might be as well. The same applies in the other direction: observing C but not B.
Notice that these first two cases behave in the same manner: observing a random variable
in between two other random variable ‘blocks’ information from flowing between the two outer
random variables. In the third and final case the opposite is true. Not observing data in this case
will ‘block’ information, and we will explain this shift through an idea known as ‘explaining away’.
We have the third and final random variable structure, shown in Figure 8.16. We consider the
unobserved case first.
In this setup, we say that information from A to C is being ‘blocked’ by the unobserved variable
B. Thus A and C are currently independent. However, once we've observed B, as
shown in Figure 8.17, the information flow changes.
Now, information is flowing between A and C through the observed variable B, making A and
C conditionally dependent. This phenomenon, where the observation of the random variable in the
middle creates conditional dependence is known as explaining away. The idea relies on knowledge
of the value for B giving information about how much A or C may have contributed to B adopting
that value.
Consider the following example: let the random variables A correspond to whether or not it
rained on a certain day, B correspond to the lawn being wet, and C correspond to the sprinkler
being on. Let’s say we observe B: the lawn is wet. Then, if we observe variable A: it has not rained
today, we would infer that the sprinkler must have been on (the value of variable C), because that's the
only way for the lawn to be wet. This is the phenomenon of explaining away. Observing B unblocks
the flow of information between A and C because we can now use an observation to ‘explain’ how
B got its value, and therefore determine what the other unobserved value might have been.
Notice that we’ve only described three simple cases relating dependence relationships between
random variables in a DGM, but with just these three cases, we can determine the dependence
structure of any arbitrarily complicated DGM. We just have to consider how information flows
from node to node. If information gets ‘blocked’ at any point in our DGM network because of an
observation (or lack thereof), then we gain some knowledge about independence within our model.
Consider the dependence between random variables A and F in Figure 8.18. Initially, before any
observations are made, we can see that A and F are dependent (information flows from A through
B). However, after observing B, the nodes A and F become independent (because information is blocked at both the observed B and the unobserved D). Finally, after observing D, dependence is
restored between A and F because information flows from A through D.
Figure 8.18: Notice how the independence between A and F depends on observations made within
the network.
⋆ In some other resources, you’ll come upon the idea of ‘D-separation’ or ‘D-connection’. D-separation is simply
applying the principles outlined above to determine if two nodes are independent, or D-separated. By contrast,
two nodes that are D-connected are dependent on one another.
Writing this factorization is facilitated directly by our DGM, even if we’ve never heard of Naive
Bayes before. It provides a common language for us to move fluidly between detailed probability
factorizations and general modeling intuition.
8.4 Conclusion
Directed graphical models are indispensable for model visualization and effective communication of
modeling ideas. With an understanding of what DGMs represent, it’s much easier to analyze more
complex probabilistic models. In many ways, this chapter is preparation for where we head next.
The topics of the following chapters will rely heavily on DGMs to explain their structure, use cases,
and interesting variants.
Chapter 9
Mixture Models
The real world often generates observable data that falls into a combination of unseen categories.
For example, at a specific moment in time I could record sound waves on a busy street that come
from a combination of cars, pedestrians, and animals. If I were to try to model my data points, it
would be helpful if I could group them by source, even though I didn’t observe where each sound
wave came from individually.
In this chapter we explore what are known as mixture models. Their purpose is to handle data
generated by a combination of unobserved categories. We would like to discover the properties of
these individual categories and determine how they mix together to produce the data we observe.
We consider the statistical ideas underpinning mixture models, as well as how they can be used in
practice.
9.1 Motivation
Mixture models are used to model data involving latent variables.
Definition 9.1.1 (Latent Variable): A latent variable is a piece of data that is not observed,
but that influences the observed data. We often wish to create models that capture the behavior
of our latent variables.
We are sometimes unable to observe all the data present in a given system. For example,
if we measure the snout length of different animals but only get to see the snout measurements
themselves, the latent variable would be the type of animal we are measuring for each data point.
For most data generating processes, we will only have access to a portion of the data and the rest
will be hidden from us. However, if we can find some way to also model the latent variables, our
model will potentially be much richer, and we will also be able to probe it with more interesting
questions. To build some intuition about latent variable models, we present a simple directed
graphical model with a latent variable zn in Figure 9.1.
One common means of modeling data involving latent variables, and the topic of this chapter,
is known as a mixture model.
Definition 9.1.2 (Mixture Model): A mixture model captures the behavior of data coming from
a combination of different distributions.
At a high level, a mixture model operates under the assumption that our data is generated by
first sampling a discrete class, and then sampling a data point from within that category according
to the distribution for that category. For the example of animal snouts, we would first sample a
species of animal, and then based on the distribution of snout lengths in that species, we would
sample an observation to get a complete data point.
Probabilistically, sampling a class (which is our latent variable, since we don’t actually observe
it) happens according to a Categorical distribution, and we typically refer to the latent variable as
z. Thus:
p(z = Ck ; θ) = θk
where Ck is class k, and θ is the parameter to the Categorical distribution that specifies the
probability of drawing each class. We write the latent variable in bold z because we will typically
consider it to be one-hot encoded (of dimension K, for K classes). Then, once we have a class, we
have a distribution for the observed data point coming from that class:
p(x|z = Ck ; w)
This distribution depends on the type of data we are observing, and is parameterized by an
arbitrary parameter w whose form depends on what is chosen as the class-conditional distribution.
For the case of snout lengths, and many other examples, this conditional distribution is often
modeled using a Gaussian distribution, in which case our model is known as a Gaussian Mixture
Model. We will discuss Gaussian Mixture Models in more detail later in the chapter.
If we can effectively model the distribution of our observed data points and the latent variables
responsible for producing the data, we will be able to ask interesting questions of our model. For
example, upon observing a new data point x′ we will be able to produce a probability that it came
from a specific class z′ = Ck using Bayes’ rule and our model parameters:
$$p(z' = C_k \mid x') = \frac{p(x' \mid z' = C_k; w)\, p(z' = C_k; \theta)}{\sum_{k'} p(x' \mid z' = C_{k'}; w)\, p(z' = C_{k'}; \theta)}$$
Furthermore, after modeling the generative process, we will be able to generate new data points by
sampling from our categorical class distribution, and then from the class-conditional distribution
for that category:
z ∼ Cat(θ)
x ∼ p(x|z = Ck ; w)
Finally, it will also be possible for us to get a sense of the cardinality of z (meaning the number of
classes our data falls into), even if that was not something we were aware of a priori.
ML Framework Cube: Mixture Models
The classes of data z in a mixture model will typically be discrete. Notice also that this is an
unsupervised technique: while we have a data set X of observations, our goal is not to make
predictions. Rather, we are trying to model the generative process of this data by accounting for
the latent variables that generated the data points. Finally, this is a probabilistic model both for
the latent variables and for our observed data.
Domain: Discrete. Training: Unsupervised. Probabilistic: Yes.
9.2 Applications
Since much of the data we observe in our world has some sort of unobserved category associated
with it, there are a wide variety of applications for mixture models. Here are just a few:
1. Handwriting image recognition. The categories are given by the characters (letters, numbers,
etc.) and the class-conditional is a distribution over what each of those characters might look
like.
2. Noise classification. The categories are given by the source of a noise (e.g. we could have
different animal noises), and the class-conditional is a distribution over what the sound waves
for each animal noise look like.
3. Vehicle prices. The categories are given by the brand of vehicle (we could alternatively
categorize by size, safety, year, etc.), and the class-conditional is a distribution over the price
of each brand.
To fit the model parameters, we can write the likelihood of our observed data by marginalizing over the possible classes for each of our $N$ data points as follows:
$$p(X; \theta, w) = \prod_{n=1}^{N} \sum_{k=1}^{K} p(x_n, z_{n,k}; \theta, w)$$
This uses $p(x_n; \theta, w) = \sum_{k} p(x_n, z_{n,k}; \theta, w)$ (marginalizing out over the latent class). Taking the logarithm to get our log-likelihood as usual:
$$\log p(X; \theta, w) = \sum_{n=1}^{N} \log \sum_{k=1}^{K} p(x_n, z_{n,k}; \theta, w) \qquad (9.1)$$
It may not be immediately obvious, but under this setup, the maximum likelihood calculation for
our parameters θ and w is now intractable. The summation over the K classes of our latent variable
zn , which is required because we don’t actually observe those classes, is inside of the logarithm,
which prevents us from arriving at an analytical solution (it may be helpful to try to solve this yourself; you'll realize that consolidating a summation inside of a logarithm is not possible). You
could still try to use gradient descent, but the problem is non-convex and we’ll see a much more
elegant approach. The rest of this chapter will deal with how we can optimize our mixture model
in the face of this challenge.
Suppose for a moment that we had also observed the class $z_n$ for each data point. Notice that because we've now observed $z_n$, we don't have to marginalize over its possible values. This motivates an interesting approach that takes advantage of our ability to work with $p(x, z)$ if we only knew $z$.
The expression p(x, z) is known as the complete-data likelihood because it assumes that we have
both our observation x and the class z that x came from. Our ability to efficiently calculate the
complete-data log likelihood log p(x, z) is the crucial piece of the algorithm we will present to op-
timize our mixture model parameters. This algorithm is known as Expectation-Maximization,
or EM for short.
Recall that the difficulty with the observed-data log likelihood was the summation over classes inside the logarithm. This summation was required because we didn't observe a crucial piece of data, the class $z$, and therefore we had to sum over its values.
EM uses an iterative approach to optimize our model parameters. It proposes a soft value for
z using an expectation calculation (we can think about this as giving a distribution on zn for each
n), and then based on that proposed value, it maximizes the expected complete-data log likelihood
with respect to the model parameters θ and w via a standard MLE procedure.
Notice that EM is composed of two distinct steps: an “E step” that finds the expected value of
the latent class variables given the current set of parameters, and an “M step” that improves the
model parameters by maximizing expected complete-data log likelihood given these soft assignments
to class variables. These two steps give the algorithm its name, and more generally, this type of
approach is also referred to as coordinate ascent. The idea behind coordinate ascent is that we
can replace a hard problem (maximizing the log likelihood for our mixture model directly) with
two easier problems, namely the E- and M-step. We alternate between the two easier problems,
executing each of them until we reach a point of convergence or decide that we’ve done enough.
We may also restart because EM will provide a local but not global optimum.
We’ll walk through the details of each of these two steps and then tie them together with the
complete algorithm.
⋆ K-Means, an algorithm we discussed in the context of clustering, is also a form of coordinate ascent. K-Means is
sometimes referred to as a “maximization-maximization” algorithm because we iteratively maximize our assignments
(by assigning each data point to just a single cluster) and then update our cluster centers to maximize their likelihood
with respect to the new assignments. That is, it does a “max” in place of the E-step, making a hard rather than soft
assignment.
As we’ve already explained, we don’t know the true value of this latent variable. Instead, we will
compute its conditional expectation based on the current setting of our model parameters and our
observed data xn . We denote the expectation of our latent variables as qn , and we calculate them
as follows:
$$q_n = E[z_n \mid x_n] = \begin{bmatrix} p(z_n = C_1 \mid x_n; \theta, w) \\ p(z_n = C_2 \mid x_n; \theta, w) \\ p(z_n = C_3 \mid x_n; \theta, w) \end{bmatrix} \propto \begin{bmatrix} p(x_n \mid z_n = C_1; w)\, p(z_n = C_1; \theta) \\ p(x_n \mid z_n = C_2; w)\, p(z_n = C_2; \theta) \\ p(x_n \mid z_n = C_3; w)\, p(z_n = C_3; \theta) \end{bmatrix}$$
The expectation of a 1-hot encoded vector is equivalent to a distribution on the values that the
latent variable might take on. Notice that we can switch from proportionality in our qn values to
actual probabilities by simply dividing each unnormalized value by the sum of all the unnormalized
values. Then, our qn values will look something like the following, where a larger number indicates
a stronger belief that the data point xn came from that class:
$$q_n = \begin{bmatrix} 0.8 \\ 0.1 \\ 0.1 \end{bmatrix}$$
There are two important things to note about the expectation step. First, the model parameters θ
and w are held fixed. We’re computing the expectation of our latent variables based on the current
setting of those model parameters. Those parameters are randomly initialized if this is our first
time running the expectation step.
Second, we have a value of qn for every data point xn in our data set. As a result, qn are
sometimes called “local parameters,” since there is one assigned to each data point. This is in
contrast to our model parameters θ and w, which are “global parameters.” The size of the global
model parameters doesn’t fluctuate based on the size of our data set.
After performing the E-step, we now have an expectation for our latent variables, given by qn .
In the maximization step, which we describe next, we use these qn values to improve our global
parameters.
Notice the crucial difference between this summation and that of Equation 9.1: the summation
over the classes is now outside of the logarithm! Recall that using the log-likelihood directly was
intractable precisely because the summation over the classes was inside of the logarithm. This
maximization became possible by taking the expectation over our latent variables (using the values
we computed in the E-step), which moved the summation over the classes outside of the logarithm.
We can now complete the M-step by maximizing Equation 9.6 with respect to our model pa-
rameters θ and w. This has an analytical solution. We take the derivative with respect to the
parameter of interest, set to 0, solve, and update the parameter with the result.
1. Begin by initializing our model parameters w and θ, which we can do at random. Since
the EM algorithm is performed over a number of iterative steps, we will denote these initial
parameter values w(0) and θ (0) . We will increment those values as the algorithm proceeds.
2. E-step: compute the values of qn based on the current setting of our model parameters.
$$q_n = E[z_n \mid x_n] = \begin{bmatrix} p(z_n = C_1 \mid x_n; \theta^{(i)}, w^{(i)}) \\ \vdots \\ p(z_n = C_K \mid x_n; \theta^{(i)}, w^{(i)}) \end{bmatrix} \propto \begin{bmatrix} p(x_n \mid z_n = C_1; w^{(i)})\, p(z_n = C_1; \theta^{(i)}) \\ \vdots \\ p(x_n \mid z_n = C_K; w^{(i)})\, p(z_n = C_K; \theta^{(i)}) \end{bmatrix}$$
3. M-step: compute the values of w and θ that maximize our expected complete-data log like-
lihood for the current setting of the values of qn :
4. Return to step 2, repeating this cycle until our likelihood converges. Note that the likelihood
is guaranteed to (weakly) increase at each step using this procedure.
It is also typical to re-start the procedure because we are guaranteed a local but not global
optimum.
We adopt this latter expression, since it also holds in cases where zn is not discrete.
As touched on earlier, the main issues with directly optimizing this expression are as follows:
1. If we were to take a gradient, the distribution over which we’re taking an expectation involves
one of the parameters, θ, and hence computing the gradient with respect to θ is difficult.
2. There’s an expectation (previously a sum) inside the logarithm.
To solve this, we ultimately derive a lower bound on the observed data log likelihood which
proves to be computationally tractable to optimize. First, let q(zn ) denote another probability
distribution on zn . Then:
$$\sum_{n=1}^{N} \log \left[ \sum_{k=1}^{K} p(x_n, z_n = C_k; \theta, w) \right]$$
$$= \sum_{n=1}^{N} \log \left[ \sum_{k=1}^{K} p(x_n \mid z_n = C_k; \theta, w)\, p(z_n = C_k \mid \theta, w) \right]$$
$$= \sum_{n=1}^{N} \log \left[ \sum_{k=1}^{K} p(x_n \mid z_n = C_k; \theta, w)\, p(z_n = C_k \mid \theta, w) \cdot \frac{q(z_n = C_k)}{q(z_n = C_k)} \right]$$
$$= \sum_{n=1}^{N} \log \left[ \sum_{k=1}^{K} \frac{p(x_n \mid z_n = C_k; \theta, w)\, p(z_n = C_k \mid \theta, w)}{q(z_n = C_k)} \cdot q(z_n = C_k) \right]$$
$$= \sum_{n=1}^{N} \log \mathbb{E}_{z_n \sim q(z_n)} \left[ \frac{p(x_n \mid z_n; \theta, w)\, p(z_n \mid \theta, w)}{q(z_n)} \right] \qquad (9.13)$$
The above derivation again restricts to discrete zn , but the equivalence between the expressions
in (9.12) and (9.13) is in fact more general, holding whenever p is absolutely continuous with respect
to the chosen q.
In any case, we’ve now fixed the first of the two issues. By introducing the distribution q, the
expectation is no longer over some distribution depending on the parameter θ.
To fix the second issue, we must somehow pass the log into the expectation. This is accomplished
using Jensen’s inequality, at the cost of turning the equality into a lower bound:
$$\sum_{n=1}^{N} \log \mathbb{E}_{z_n \sim q(z_n)} \left[ \frac{p(x_n \mid z_n; \theta, w)\, p(z_n \mid \theta, w)}{q(z_n)} \right] \qquad (9.14)$$
$$\geq \sum_{n=1}^{N} \mathbb{E}_{z_n \sim q(z_n)} \left[ \log \frac{p(x_n \mid z_n; \theta, w)\, p(z_n \mid \theta, w)}{q(z_n)} \right] \qquad (9.15)$$
In summary, the two issues with directly optimizing the observed data log likelihood were resolved as follows:
1. By introducing the distribution q, the expectation is no longer taken over a distribution that depends on the parameter θ.
2. By applying Jensen's inequality, the logarithm was moved inside the expectation.
However, this came at the cost of converting our objective into a lower bound of the original
quantity, as well as introducing the new parameter that is the distribution q.
It turns out that the iterative process given in the section above amounts to alternating between
optimizing the parameters w, θ and the distribution q.
Optimization
Before delving into the optimization process, we first establish two identities involving the ELBO (the evidence lower bound), which is the quantity on the right-hand side of (9.15):
$$\mathrm{ELBO}(w, q) = \sum_{n=1}^{N} \mathbb{E}_{z_n \sim q(z_n)} \left[ \log \frac{p(x_n \mid z_n; \theta, w)\, p(z_n \mid \theta, w)}{q(z_n)} \right] \qquad (9.16)$$
First,
$$\mathrm{ELBO}(w, q) = \sum_{n=1}^{N} \left( \mathbb{E}_{z_n \sim q(z_n)} \left[ \log \left( p(x_n \mid z_n; \theta, w)\, p(z_n \mid \theta, w) \right) \right] - \mathbb{E}_{z_n \sim q(z_n)} \left[ \log q(z_n) \right] \right)$$
$$= \sum_{n=1}^{N} \mathbb{E}_{z_n \sim q(z_n)} \left[ \log \left( p(x_n \mid z_n; \theta, w)\, p(z_n \mid \theta, w) \right) \right] - \sum_{n=1}^{N} \mathbb{E}_{z_n \sim q(z_n)} \left[ \log q(z_n) \right]$$
$$= \sum_{n=1}^{N} \mathbb{E}_{z_n \sim q(z_n)} \left[ \log p(x_n, z_n \mid \theta, w) \right] - \sum_{n=1}^{N} \mathbb{E}_{z_n \sim q(z_n)} \left[ \log q(z_n) \right] \qquad (9.17)$$
The first term in (9.17) is the expected complete data log likelihood, and the second is the
entropy of the distribution q.
Second, using Bayes’ rule:
$$\mathrm{ELBO}(w, q) = \sum_{n=1}^{N} \mathbb{E}_{z_n \sim q(z_n)} \left[ \log \frac{p(x_n \mid \theta, w)\, p(z_n \mid x_n; \theta, w)}{q(z_n)} \right]$$
$$= \sum_{n=1}^{N} \left( \log p(x_n \mid \theta, w) + \mathbb{E}_{z_n \sim q(z_n)} \left[ \log \frac{p(z_n \mid x_n; \theta, w)}{q(z_n)} \right] \right)$$
$$= \sum_{n=1}^{N} \log p(x_n \mid \theta, w) - \sum_{n=1}^{N} \mathbb{E}_{z_n \sim q(z_n)} \left[ \log \frac{q(z_n)}{p(z_n \mid x_n; \theta, w)} \right] \qquad (9.18)$$
The first term is exactly the observed data log likelihood, and the second is known as the KL diver-
gence between q(zn ) and p(zn | xn ; θ, w). In particular, the KL-divergence between distributions
P and Q is defined as
$$D_{KL}(P \,\|\, Q) = \mathbb{E}_{x \sim P(x)} \left[ \log \frac{P(x)}{Q(x)} \right] \qquad (9.19)$$
The key property of KL-divergence is that it is nonnegative, that is DKL (P ∥Q) ≥ 0 always,
with equality if and only if P = Q:
$$-D_{KL}(P \,\|\, Q) = \mathbb{E}_{x \sim P(x)} \left[ \log \frac{Q(x)}{P(x)} \right] \leq \log \mathbb{E}_{x \sim P(x)} \left[ \frac{Q(x)}{P(x)} \right] = \log \sum_{x} \frac{Q(x)}{P(x)} \cdot P(x) = \log \sum_{x} Q(x) = \log 1 = 0$$
where we have equality when the equality condition of Jensen’s is met, meaning Q(x) = cP (x),
which forces Q(x) = P (x) since both are distributions. The proof above is only for discrete variables,
but can again be extended to the continuous case.
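As a sanity check on the definition and this nonnegativity property, here is a tiny numeric example with two made-up discrete distributions (all entries assumed strictly positive):

```python
import numpy as np

def kl_divergence(p, q):
    """D_KL(P || Q) = sum_x P(x) log(P(x) / Q(x)) for discrete distributions."""
    p, q = np.asarray(p, float), np.asarray(q, float)
    return float(np.sum(p * np.log(p / q)))

p = [0.7, 0.2, 0.1]
q = [0.3, 0.4, 0.3]
print(kl_divergence(p, q))   # strictly positive, since P != Q
print(kl_divergence(p, p))   # 0.0, the equality case P = Q
```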
Using this, we rewrite (9.18) as
$$\mathrm{ELBO}(w, q) = \sum_{n=1}^{N} \log p(x_n \mid \theta, w) - \sum_{n=1}^{N} D_{KL}\left( q(z_n) \,\|\, p(z_n \mid x_n; \theta, w) \right) \qquad (9.20)$$
We now describe the E and M steps of the procedure. Begin with some random assignment of
θ, w, and q.
1. The E-step: Fixing the parameters θ, w, choose the distributions q that maximize ELBO(w, q). Observe that only the second term of (9.20) depends on the distributions q. Hence maximizing the ELBO is equivalent to minimizing the sum of KL divergences $\sum_{n=1}^{N} D_{KL}(q(z_n) \,\|\, p(z_n \mid x_n; \theta, w))$. This immediately yields that the optimal choice is $q(z_n) = p(z_n \mid x_n; \theta, w)$ for each $n$, since each KL divergence is minimized (and equal to zero) when the two distributions are equal.
2. The M-step: Fixing the distributions q, maximize the ELBO with respect to the parameters
θ, w.
Observe that only the first term of (9.17) depends on the parameters, and hence this is
equivalent to maximizing the expected complete data log likelihood,
$$\sum_{n=1}^{N} \mathbb{E}_{z_n \sim q(z_n)} \left[ \log p(x_n, z_n \mid \theta, w) \right]. \qquad (9.21)$$
An equivalent formulation is then as follows. Define the auxiliary function Q(w, θ | wold , θ old )
as
$$Q(w, \theta \mid w^{old}, \theta^{old}) = \sum_{n=1}^{N} \mathbb{E}_{z_n \sim p(z_n \mid x_n; w^{old}, \theta^{old})} \left[ \log p(x_n, z_n \mid w, \theta) \right] \qquad (9.22)$$
1. The E-step: compute the auxiliary function Q(w, θ | w^{old}, θ^{old}). Note that this essentially just requires computing the posterior distribution $p(z_n \mid x_n; w^{old}, \theta^{old})$, as before.
2. The M-step: maximize Q(w, θ | w^{old}, θ^{old}) with respect to w and θ, and take the maximizing values as the new w^{old}, θ^{old}.
Why are these equivalent? We’ve dropped the qs by noting that at each step, they will just be
p(zn | xn ; wold , θ old ).
It is important to note the significance of expression (9.20). Since the first term is exactly the
observed data log likelihood, this indicates that the gap between the ELBO and the observed data
log likelihood is exactly the KL divergence between q(zn ) and p(zn | xn ; θ, w). If this gap goes to
zero, then we will have succeeded in maximizing the observed data log likelihood. This is why it is
important to continually maximize over the distributions q - choosing the best q makes the bound
as tight as possible every round.
Correctness
We will only show that EM increases the observed log likelihood at every iteration. We do this
using the second of the two formulations.
$$\sum_{n=1}^{N} \log p(x_n \mid w, \theta) - \sum_{n=1}^{N} \log p(x_n \mid w^{old}, \theta^{old}) = \sum_{n=1}^{N} \log \frac{p(x_n \mid w, \theta)}{p(x_n \mid w^{old}, \theta^{old})}$$
$$= \sum_{n=1}^{N} \log \frac{\sum_{k=1}^{K} p(x_n, z_n = C_k \mid w, \theta)}{p(x_n \mid w^{old}, \theta^{old})}$$
$$= \sum_{n=1}^{N} \log \sum_{k=1}^{K} \frac{p(x_n, z_n = C_k \mid w, \theta)}{p(x_n \mid w^{old}, \theta^{old})\, p(z_n = C_k \mid x_n; w^{old}, \theta^{old})}\, p(z_n = C_k \mid x_n; w^{old}, \theta^{old})$$
$$= \sum_{n=1}^{N} \log \mathbb{E}_{z_n \sim p(z_n \mid x_n; w^{old}, \theta^{old})} \left[ \frac{p(x_n, z_n \mid w, \theta)}{p(x_n, z_n \mid w^{old}, \theta^{old})} \right]$$
$$\geq \sum_{n=1}^{N} \mathbb{E}_{z_n \sim p(z_n \mid x_n; w^{old}, \theta^{old})} \left[ \log \frac{p(x_n, z_n \mid w, \theta)}{p(x_n, z_n \mid w^{old}, \theta^{old})} \right]$$
$$= \sum_{n=1}^{N} \mathbb{E}_{z_n \sim p(z_n \mid x_n; w^{old}, \theta^{old})} \left[ \log p(x_n, z_n \mid w, \theta) \right] - \sum_{n=1}^{N} \mathbb{E}_{z_n \sim p(z_n \mid x_n; w^{old}, \theta^{old})} \left[ \log p(x_n, z_n \mid w^{old}, \theta^{old}) \right]$$
$$= Q(w, \theta \mid w^{old}, \theta^{old}) - Q(w^{old}, \theta^{old} \mid w^{old}, \theta^{old})$$
We are maximizing Q at each step, and thus this difference is nonnegative, so the observed data log likelihood (weakly) increases at every iteration.
It is not hard to see that the formulation presented above, centered around maximizing the ELBO,
is equivalent to the one presented in section 9.4.3. Note that qn,k is exactly p(zn = Ck | xn , w, θ),
and thus the quantity being computed in the M -step (as in (9.10)) is exactly the same as in (9.21).
Recall that K-Means proceeds according to a similar iterative algorithm: we first make hard
assignments of data points to existing cluster centers, and then we update the cluster centers based
on the most recent data point assignments.
In fact, the main differences between K-Means clustering and the EM algorithm are that:
1. In the EM setting, we make soft cluster assignments through our qn values, rather than
definitively assigning each data point to only one cluster.
In the context of a mixture-of-Gaussian model, which we get to later in the chapter, we can
confirm that K-means is equal to the limiting case of EM where the variance of each class-conditional
Gaussian goes to 0, the prior probability of each class is uniform, and the distributions are spherical.
Imagine we flip a biased coin to choose one of two biased dice, roll the chosen die $c = 10$ times, and record only the face counts; we repeat this process for $N$ observations. We're going to try to infer the parameters of each of the dice based on these observations. Let's
consider how this scenario fits into our idea of a mixture model. First, the latent variable zn has a
natural interpretation as being which die was rolled for the nth observed data point x_n. We can represent z_n using a one-hot vector, so that if the nth data point came from Die 1, we'd denote that:
\[
z_n = \begin{bmatrix} 1 \\ 0 \end{bmatrix}
\]
We denote the probability vector associated with the biased coin as θ ∈ [0, 1]2 , summing to
1, with θ1 being the probability of the biased coin landing heads and θ2 being the probability of
the biased coin landing tails. Furthermore, we need parameters to describe the behavior of the
biased dice. We use π_1, π_2 ∈ [0, 1]^6, each summing to 1, where each 6-dimensional vector describes the probability that the respective die lands on each face.
For a given die, this defines a multinomial distribution. For c trials, counts x_1, . . . , x_6 for each of the 6 faces, and probabilities π, this is
\[
p(x; \pi) = \frac{c!}{x_1! \cdots x_6!} \, \pi_1^{x_1} \cdots \pi_6^{x_6} \tag{9.23}
\]
For our purposes, let p(x_n | z_n = C_k; π_1, π_2) denote the multinomial distribution on observation x_n when the latent vector z_n = C_k.
The model parameters are w = {θ, π_1, π_2}. We can optimize the model parameters using EM. We start by initializing the parameters θ^{(0)}, π_1^{(0)}, π_2^{(0)}.
In the E-step, we compute the soft assignment values q_n. For die k, this is given by
\[
q_{n,k} = p(z_n = C_k \mid x_n) = \frac{\theta_k \, p(x_n \mid z_n = C_k; \pi_k)}{\sum_{k'=1}^{2} \theta_{k'} \, p(x_n \mid z_n = C_{k'}; \pi_{k'})}. \tag{9.27}
\]
After computing the values of qn , we are ready to perform the M-step. Recall that we are
maximizing the expected complete-data log likelihood, which takes the form:
\[
E_{Z|X}[\log p(X, Z)] = \sum_{n=1}^{N} E_{q_n} \left[ \log p(z_n; \theta^{(i+1)}, \pi^{(i+1)}) + \log p(x_n \mid z_n; \theta^{(i+1)}, \pi^{(i+1)}) \right] \tag{9.28}
\]
\[
= \sum_{n=1}^{N} E_{z_n \mid x_n} \left[ \log p(z_n; \theta^{(i+1)}, \pi^{(i+1)}) + \log p(x_n \mid z_n; \theta^{(i+1)}, \pi^{(i+1)}) \right] \tag{9.29}
\]
We can then substitute in the multinomial expression and simplify; dropping constants, we are looking for parameters that solve
\begin{align}
&\arg\max_{\theta^{(i+1)}, \pi^{(i+1)}} \left\{ \sum_{n=1}^{N} \sum_{k=1}^{2} q_{n,k} \log \theta_k^{(i+1)} + \sum_{n=1}^{N} \sum_{k=1}^{2} q_{n,k} \log\left( \pi_{k,1}^{x_{n,1}} \cdots \pi_{k,6}^{x_{n,6}} \right) \right\} \nonumber \\
&= \arg\max_{\theta^{(i+1)}, \pi^{(i+1)}} \sum_{n=1}^{N} \sum_{k=1}^{2} q_{n,k} \log \theta_k^{(i+1)} + \sum_{n=1}^{N} \sum_{k=1}^{2} \sum_{j=1}^{6} q_{n,k} \, x_{n,j} \log(\pi_{k,j}) \tag{9.30}
\end{align}
To maximize the expected complete-data log likelihood, it is necessary to introduce Lagrange multipliers to enforce the constraints \sum_k \theta_k^{(i+1)} = 1 and \sum_j \pi_{k,j}^{(i+1)} = 1 for each k. After doing this and solving, we recover the following update equations for the model parameters:
\[
\theta_k^{(i+1)} \leftarrow \frac{\sum_{n=1}^{N} q_{n,k}}{N}
\]
\[
\pi_k^{(i+1)} \leftarrow \frac{\sum_{n=1}^{N} q_{n,k} \, x_n}{c \sum_{n=1}^{N} q_{n,k}},
\]
where c = 10 in our example (the number of rolls per observation).
We now have everything we need to perform EM for this setup. After initializing our parameters w^{(0)}, we perform the E-step by evaluating (9.27). After calculating our values of q_n in the E-step, we update our parameters w = {θ, π_1, π_2} in the M-step by maximizing (9.30) with respect to θ, π_1, π_2. We perform these two steps iteratively until our parameters converge. Because EM finds only a local optimum, we may also perform random restarts and keep the solution with the highest likelihood.
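To make the loop concrete, here is a minimal sketch of this EM procedure in Python/numpy. It assumes the observations are stored as an N × 6 array X of face counts (each row summing to c = 10 rolls) and that there are two dice; the function and variable names are our own and not part of the text.

import numpy as np

def em_dice_mixture(X, n_iters=100, seed=0):
    N, F = X.shape                          # N observations, F = 6 faces
    c = X[0].sum()                          # rolls per observation (c = 10 here)
    rng = np.random.default_rng(seed)
    theta = np.array([0.5, 0.5])            # mixing weights (the biased coin)
    pi = rng.dirichlet(np.ones(F), size=2)  # per-die face probabilities

    for _ in range(n_iters):
        # E-step (9.27): q_{n,k} proportional to theta_k * prod_j pi_{k,j}^{x_{n,j}}
        log_lik = X @ np.log(pi).T + np.log(theta)     # (N, 2), constants dropped
        log_lik -= log_lik.max(axis=1, keepdims=True)  # for numerical stability
        q = np.exp(log_lik)
        q /= q.sum(axis=1, keepdims=True)

        # M-step: the closed-form updates derived above
        theta = q.sum(axis=0) / N
        pi = (q.T @ X) / (c * q.sum(axis=0)[:, None])
    return theta, pi, q

Each iteration performs the soft E-step followed by the closed-form M-step; a practical implementation would add a convergence check on the log likelihood and, as noted above, random restarts.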
This is the current expectation for our latent variables z_n given our data x_n and the current setting of our model parameters θ, {μ_k, Σ_k}_{k=1}^{K}.
3. [M-Step] Using our values of q_n, calculate the expected complete-data log likelihood, and then use that term to optimize our model parameters:
\begin{align*}
E_{q_n}[\log p(X, Z)] &= \sum_{n=1}^{N} E_{q_n} \left[ \ln p(x_n, z_n; \theta, \{\mu_k, \Sigma_k\}_{k=1}^{K}) \right] \\
&= \sum_{n=1}^{N} \sum_{k=1}^{K} q_{n,k} \ln \theta_k + q_{n,k} \ln \mathcal{N}(x_n; \mu_k, \Sigma_k)
\end{align*}
We can then use this expected complete-data log likelihood to optimize our model parameters θ, {μ_k, Σ_k}_{k=1}^{K} by computing the MLE as usual. Using a Lagrange multiplier to enforce \sum_{k=1}^{K} \theta_k = 1, we recover the update equations:
\[
\theta_k^{(i+1)} \leftarrow \frac{\sum_{n=1}^{N} q_{n,k}}{N}
\]
\[
\mu_k^{(i+1)} \leftarrow \frac{\sum_{n=1}^{N} q_{n,k} \, x_n}{\sum_{n=1}^{N} q_{n,k}}
\]
\[
\Sigma_k^{(i+1)} \leftarrow \frac{\sum_{n=1}^{N} q_{n,k} \, (x_n - \mu_k^{(i+1)})(x_n - \mu_k^{(i+1)})^T}{\sum_{n=1}^{N} q_{n,k}}
\]
Finally, it’s worth comparing EM and K-Means clustering as applied to GMMs. First, as dis-
cussed previously, EM uses soft assignments of data points to clusters rather than hard assignments.
Second, the standard K-Means algorithm does not estimate the covariance of each cluster. How-
ever, if we enforce as a part of our GMM setup that the covariance matrices of all the clusters are
given by ϵI, then as ϵ → 0, EM and K-Means will in fact produce the same results.
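As a concrete illustration of these updates, here is a minimal sketch of a single EM iteration for a GMM in Python/numpy, using scipy's multivariate normal density. The function name and array layout (X as an N × D data matrix, mus as a K × D array, Sigmas as a length-K list of covariances) are our own assumptions.

import numpy as np
from scipy.stats import multivariate_normal

def em_gmm_step(X, theta, mus, Sigmas):
    N, D = X.shape
    K = len(theta)

    # E-step: responsibilities q_{n,k} proportional to theta_k * N(x_n; mu_k, Sigma_k)
    q = np.zeros((N, K))
    for k in range(K):
        q[:, k] = theta[k] * multivariate_normal.pdf(X, mean=mus[k], cov=Sigmas[k])
    q /= q.sum(axis=1, keepdims=True)

    # M-step: the closed-form updates above
    Nk = q.sum(axis=0)                           # effective counts per cluster
    theta_new = Nk / N
    mus_new = (q.T @ X) / Nk[:, None]
    Sigmas_new = []
    for k in range(K):
        diff = X - mus_new[k]
        Sigmas_new.append((q[:, k, None] * diff).T @ diff / Nk[k])
    return theta_new, mus_new, Sigmas_new, q

Iterating this step until the log likelihood stops improving gives the usual GMM fit; forcing every covariance to ϵI and letting ϵ → 0 recovers the K-Means behavior described above.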
⋆ If you haven't seen the Dirichlet before, it is a distribution over an n-dimensional vector whose components sum to 1. For example, a sample from a Dirichlet distribution in 3 dimensions could be the vector (0.2, 0.5, 0.3)^T.
We sample from that Dirichlet distribution to determine the mixture of topics θ n in our docu-
ment Dn :
θ n ∼ Dir(α)
Then, for each possible topic, we sample from a Dirichlet distribution to determine the mixture of
words ϕk in that topic:
ϕk ∼ Dir(β)
Then, for each word wn,j in the document Dn , we first sample from a Categorical parameterized
by the topic mixture θ n to determine which topic that word will come from:
zn,j ∼ Cat(θ n )
Then, now that we have a topic given by zn,j for this word wn,j , we sample from a Categorical
parameterized by that topic’s mixture over words given by ϕzn,j :
wn,j ∼ Cat(ϕzn,j )
Notice the mixture of mixtures at play here: each document in our corpus has its own mixture over topics, and each topic is in turn a mixture over words that is used to generate each individual word in the document.
The indexing is particularly involved because there are several layers of mixtures here, but to clarify: n ∈ 1..N indexes each document D_n in our corpus, k ∈ 1..K indexes each possible topic, j ∈ 1..J indexes each word w_{n,j} in document D_n, and e ∈ 1..E indexes each word in our dictionary (note that w_{n,j} ∈ R^E is a one-hot vector over the dictionary).
θ n specifies the distribution over topics in document Dn , and α is the hyperparameter for the
distribution that produces θ n . Similarly, ϕk specifies the distribution over words for the k th topic,
and β is the hyperparameter for the distribution that produces ϕk .
2. [E-Step] Fix the topic distribution of the document given by θ n and the word distribution
under a topic given by ϕk . Calculate the posterior distribution qn,j = p(zn,j |wn,j ), and note
that this is the distribution over the possible topics of a word:
\[
q_{n,j} = E[z_{n,j} \mid w_{n,j}] = \begin{bmatrix} p(z_{n,j} = C_1 \mid w_{n,j}; \theta_n, \phi_1) \\ \vdots \\ p(z_{n,j} = C_K \mid w_{n,j}; \theta_n, \phi_K) \end{bmatrix}
\propto \begin{bmatrix} p(w_{n,j} \mid z_{n,j} = C_1; \phi_1) \, p(z_{n,j} = C_1; \theta_n) \\ \vdots \\ p(w_{n,j} \mid z_{n,j} = C_K; \phi_K) \, p(z_{n,j} = C_K; \theta_n) \end{bmatrix}
= \begin{bmatrix} \phi_{1, w_{n,j}} \cdot \theta_{n,1} \\ \vdots \\ \phi_{K, w_{n,j}} \cdot \theta_{n,K} \end{bmatrix}
\]
3. [M-Step] Using our values of qn , calculate the expected complete-data log likelihood (which
marginalizes over the unknown hidden variables zn,j ), and then use that expression to optimize
our model parameters θ n and ϕk :
\begin{align*}
E_{q_n}[\log p(W, Z)] &= \sum_{n=1}^{N} \sum_{j=1}^{J} E_{q_n} \left[ \ln p(w_{n,j}, z_{n,j}; \{\theta_n\}_{n=1}^{N}, \{\phi_k\}_{k=1}^{K}) \right] \\
&= \sum_{n=1}^{N} \sum_{j=1}^{J} \sum_{k=1}^{K} q_{n,j,k} \ln \theta_{n,k} + q_{n,j,k} \ln \phi_{k, w_{n,j}}
\end{align*}
We can then use this expected complete-data log likelihood to optimize our model parameters {θ_n}_{n=1}^{N}, {φ_k}_{k=1}^{K} by computing the MLE as usual. Using Lagrange multipliers to enforce \sum_{k=1}^{K} \theta_{n,k} = 1 for every n and \sum_{e=1}^{E} \phi_{k,e} = 1 for every k (where e indexes each word in our dictionary), we recover the update equations:
\[
\theta_{n,k}^{(i+1)} \leftarrow \frac{\sum_{j=1}^{J} q_{n,j,k}}{J}
\]
\[
\phi_{k,d}^{(i+1)} \leftarrow \frac{\sum_{n=1}^{N} \sum_{j=1}^{J} q_{n,j,k} \, w_{n,j,d}}{\sum_{n=1}^{N} \sum_{j=1}^{J} q_{n,j,k}}
\]
The largest headache for applying the EM algorithm to LDA is keeping all of the indices in
order, and this is the result of working with a mixture of mixtures. Once the bookkeeping is sorted
out, the actual updates are straightforward.
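To help keep the indices straight, the following is a minimal sketch of one E-step and M-step of the point-estimate EM described above. It assumes every document has exactly J words stored as integer dictionary indices in an (N, J) array docs, with K topics and dictionary size E; the helper name and this representation are our own.

import numpy as np

def lda_em_step(docs, theta, phi):
    N, J = docs.shape
    K, E = phi.shape

    # E-step: q_{n,j,k} proportional to phi_{k, w_{n,j}} * theta_{n,k}
    q = phi[:, docs].transpose(1, 2, 0) * theta[:, None, :]   # shape (N, J, K)
    q /= q.sum(axis=2, keepdims=True)

    # M-step: per-document topic mixtures and per-topic word distributions
    theta_new = q.sum(axis=1) / J                             # (N, K)
    phi_new = np.zeros((K, E))
    for n in range(N):
        for j in range(J):
            phi_new[:, docs[n, j]] += q[n, j]                 # accumulate counts
    phi_new /= phi_new.sum(axis=1, keepdims=True)
    return theta_new, phi_new, q

Repeating these two steps until theta and phi stabilize gives the point-estimate EM fit; the bookkeeping in code mirrors the triple index (n, j, k) in the update equations above.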
9.7 Conclusion
Mixture models are one common way of handling data that we believe is generated through a
combination of unobserved, latent variables. We’ve seen that training these models directly is
intractable (due to the marginalization over the latent variables), and so we turned to a coordinate
ascent based algorithm known as Expectation-Maximization to get around this difficulty. We then
explored a couple of common mixture models, including a multinomial mixture, Gaussian Mixture
Model, and an admixture model known as Latent Dirichlet Allocation. Mixture models are a subset
of a broader range of models known as latent variable models, and the examples seen in this chapter
are just a taste of the many different mixture models available to us. Furthermore, EM is just a
single algorithm for optimizing these models. A good grasp on the fundamentals of mixture models
and the EM algorithm will be useful background for expanding to more complicated, expressive
latent variable models.
Chapter 10
Hidden Markov Models
Many of the techniques we’ve considered so far in this book have been motivated by the types of data
we could expect to work with. For example, the supervised learning techniques (forms of regression,
neural networks, support vector machines, etc.) were motivated by the fact that we had labelled
training data. We ventured into clustering to group unlabelled data and discussed dimensionality
reduction to handle overly high-dimensional data. In the previous chapter, we examined techniques
for managing incomplete data with latent variable models. In this chapter we turn to a technique
for handling temporal data.
10.1 Motivation
One major type of data we have not yet paid explicit attention to is time series data. Most of the
information we record comes with some sort of a timestamp. For example, any time we take an
action online, there is a high probability that the database storing the data also tracks it with a
timestamp. Physical sensors in the real world always record timestamps because it would be very
difficult to make sense of their information if it is not indexed by time. When we undergo medical
exams, the results are recorded along with a timestamp. It’s almost inconceivable at this point
that we would record information without also keeping track of when that data was generated, or
at the very least when we saved that data.
For these reasons, it's interesting to develop models that are specialized to temporal data.
Certainly, time encodes a lot of information that we take for granted about the physical and digital
worlds. For example, if the sensors on a plane record the position of the plane at a specific point
in time, we would expect the surrounding data points to be relatively similar, or at least move in
a consistent direction. In a more general sense, we expect that time constrains other attributes of
the data in specific ways.
In this chapter, we will focus on one such model known as a Hidden Markov Model or
HMM. At a high level, the goal of an HMM is to model the state of an entity over time, with the
caveat that we never actually observe the state itself. Instead, we observe a data point xt at each
time step (often called an ‘emission’ or ‘observation’) that depends on the state st . For example,
we could model the position of a robot over time given a noisy estimation of the robot’s current
position at each time step. Furthermore, we will assume that one state st transitions to the next
state st+1 according to a probabilistic model. Graphically, an HMM looks like Figure 10.1, which
encodes the relationships between emissions and hidden states. Here, there are n time steps in
total.
We will probe HMMs in more detail over the course of the chapter, but for now let’s consider
⋆ Models that treat continuous state variables are commonly referred to as dynamical systems.
10.2 Applications
Unsurprisingly, there are many applications for models like HMMs that explicitly account for time
and unobserved states, especially those that relate to the physical world. Examples include:
1. The position of a robot arm when its movements may be non-deterministic and sensor readings
are noisy. [State = robot arm position; observation = sensor reading]
3. Analyzing sequences that occur in the natural world, such as DNA [State = codon, a genetic
code in a DNA molecule; observation= one of the four bases, i.e., A, C, T, or G]
⋆ In general the observed emissions don’t have to be discrete, but for the sake of being explicit, we present the discrete
interpretation here.
A data set has N data points, meaning N sequences, where each sequence is composed of n emissions (in general the sequences can have different lengths; we assume the same length for simplicity). To summarize:
• A data set consists of N sequences.
• Each sequence is composed of n observed emissions x1 , ..., xn .
• Each emission xt takes on one of M possible values.
• Each hidden state s_t takes on one of K possible values.
⋆ The Markovian assumption for transitions, as well as the fact that we don’t observe the true states, gives rise to
the Hidden Markov Model name.
These two assumptions allow us to factorize the large joint distribution given by Equation 10.1
as follows:
\[
p(s_1, ..., s_n) \, p(x_1, ..., x_n \mid s_1, ..., s_n) = p(s_1) \prod_{t=1}^{n-1} p(s_{t+1} \mid s_t) \prod_{t=1}^{n} p(x_t \mid s_t) \tag{10.2}
\]
This factorization will prove important for making HMM training and inference tractable.
1. Parameters for the prior over the initial hidden state p(s_1). This will be denoted θ ∈ [0, 1]^K, with \sum_k \theta_k = 1, such that:
\[
p(s_1 = k) = \theta_k.
\]
2. Parameters for the transition probabilities between states p(s_{t+1} | s_t). This will be denoted T ∈ [0, 1]^{K×K}, with \sum_j T_{i,j} = 1 for each i, such that:
\[
p(s_{t+1} = j \mid s_t = i) = T_{i,j}.
\]
3. Parameters for the conditional probability of the emission, p(x_t | s_t), given the state. This will be denoted π ∈ [0, 1]^{K×M}, with \sum_m \pi_{k,m} = 1 for each k, such that:
\[
p(x_t = m \mid s_t = k) = \pi_{k,m}.
\]
In sum, we have three sets of parameters θ ∈ [0, 1]^K, T ∈ [0, 1]^{K×K}, and π ∈ [0, 1]^{K×M} that we
need to learn from our data set. Then, using a trained model, we will be able to perform several
types of inference over our hidden states, as detailed next.
• Prediction p(xt+1 |x1 , . . . , xt ) (what is the prediction of the next emission given what is known
so far?)
• Smoothing p(st |x1 , . . . , xn ), t ≤ n (after the fact, what do we predict for some earlier state?)
• Transition p(st , st+1 |x1 , . . . , xn ), t + 1 ≤ n (after the fact, what do we predict for the joint
distribution on some pair of temporally adjacent states?)
• Filtering p(st |x1 , . . . , xt ) (what is the prediction, in real-time, of the current state?)
• Best path max p(s1 , . . . , sn |x1 , . . . , xn ) (after the fact, what is the most likely sequence of
states?)
For just one example, let's consider smoothing p(s_t | x_1, ..., x_n). To compute this would require marginalizing over all the unobserved states other than s_t, as follows:
\[
p(s_t, x_1, ..., x_n) = \sum_{s_1, ..., s_{t-1}, s_{t+1}, ..., s_n} p(s_1) \prod_{t'=1}^{n-1} p(s_{t'+1} \mid s_{t'}) \prod_{t'=1}^{n} p(x_{t'} \mid s_{t'}).
\]
Without making use of variable elimination, this requires summing over all possible values of the states other than s_t, which is very costly (the number of terms grows exponentially in n). Moreover, suppose we then query this for another state. We'd need
to sum again over all the states except for this new state, which duplicates a lot of work. Rather
than performing these summations over and over again, we can instead “memoize” (or reuse) these
kinds of summations using the Forward-Backward algorithm. This algorithm also makes use of
variable elimination to improve the efficiency of inference.
The Forward-Backward algorithm uses variable elimination methods to compute two sets of quanti-
ties, that we refer to as the “alpha” and “beta” values. The algorithm makes elegant use of dynamic
programming (breaking down an optimization problem into sub-problems, solving each sub-problem a
single time, and storing the solutions). It can be viewed as a preliminary inference step such that
the alpha and beta values can then be used for all inference tasks of interest as well as within EM
for training a HMM model.
The Forward-Backward algorithm is also an example of a message-passing scheme, which means
we can conceptualize it as passing around compact messages along edges of the graphical model that
corresponds to a HMM. The algorithm passes messages forwards and backwards through ‘time’,
meaning up and down the chain shown in the graphical model representation in Figure 10.1. The
forward messages are defined at each state as αt (st ), while the backward messages are defined at
each state as β_t(s_t). The overarching idea is to factor the joint distribution as
\[
p(x_1, ..., x_n, s_t) = p(x_1, ..., x_t, s_t) \, p(x_{t+1}, ..., x_n \mid s_t) = \alpha_t(s_t) \, \beta_t(s_t),
\]
because the factored terms can be efficiently computed. Let's define these α and β terms explicitly.
The α_t's represent the joint probability of all our observed emissions from time 1, ..., t as well as the state at time t:
\[
\alpha_t(s_t) = p(x_1, ..., x_t, s_t).
\]
Graphically, the αt ’s are capturing the portion of the HMM shown in Figure 10.2.
We can factorize this joint probability using what we know about the conditional independence structure of the HMM:
\begin{align}
\alpha_t(s_t) &= p(x_t \mid s_t) \sum_{s_{t-1}} p(x_1, ..., x_{t-1}, s_{t-1}, s_t) \tag{10.4} \\
&= p(x_t \mid s_t) \sum_{s_{t-1}} p(s_t \mid x_1, ..., x_{t-1}, s_{t-1}) \, p(x_1, ..., x_{t-1}, s_{t-1}) \tag{10.5} \\
&= p(x_t \mid s_t) \sum_{s_{t-1}} p(s_t \mid s_{t-1}) \, \alpha_{t-1}(s_{t-1}) \tag{10.6}
\end{align}
Figure 10.2: αt ’s capture the joint probability for the boxed portion; shown for α3 (s3 )
The first term in Equation (10.4) follows from the Markov property, and for the second term
we’ve expressed this joint probability by explicitly introducing st−1 and marginalizing out over this
variable. Equation 10.6 follows from the Markov property, and by substituting for the definition of
the alpha value.
Notice that our expression for αt (st ) includes the expression for αt−1 (st−1 ), which is the α from
the previous time step. This means we can define our messages recursively. After we’ve computed
the α’s at one time step, we pass them forwards along the chain and use them in the computation
of alpha values for the next time step. In other words, we compute the α values for period 1, then
pass that message along to compute the α values in period 2, and so forth until we reach the end
of the chain and have all the α’s in hand.
⋆ These α values are used both for inference and training a HMM via EM (in the E-step)
At this point, we’ve handled the forward messages, which send information from the beginning
to the end of the chain. In the backward portion, we also send information from the end of the
chain back to the beginning. In this backward message pass, we will compute our β values. The
βt ’s represent the joint probability over all the observed emissions from time t + 1, ..., n conditioned
on the state at time t:
βt (st ) = p(xt+1 , ..., xn |st ) (10.7)
Graphically, this means that the βt ’s are capturing the portion of the HMM shown in Figure 10.3.
We can factorize Equation 10.7 in a similar way to how we factorized the distribution described by the α's:
\begin{align}
\beta_t(s_t) &= \sum_{s_{t+1}} p(x_{t+1}, ..., x_n, s_{t+1} \mid s_t) \tag{10.8} \\
&= \sum_{s_{t+1}} p(s_{t+1} \mid s_t) \, p(x_{t+1} \mid s_{t+1}, s_t) \, p(x_{t+2}, ..., x_n \mid x_{t+1}, s_{t+1}, s_t) \tag{10.9} \\
&= \sum_{s_{t+1}} p(s_{t+1} \mid s_t) \, p(x_{t+1} \mid s_{t+1}) \, p(x_{t+2}, ..., x_n \mid s_{t+1}) \tag{10.10} \\
&= \sum_{s_{t+1}} p(s_{t+1} \mid s_t) \, p(x_{t+1} \mid s_{t+1}) \, \beta_{t+1}(s_{t+1}) \tag{10.11}
\end{align}
Here, Equation (10.8) introduces st+1 and marginalizes out over this variable. Equation (10.9)
is the product rule (recall that n ≥ t + 1, so the third conditional probability starts at xt+2 ).
Equation (10.10) makes use of the Markov property in the last two terms. Equation (10.11)
substitutes in the expression for βt+1 from Equation (10.7).
As we saw with our calculation of the α’s, we can calculate β recursively. This recursive
definition enables us to propagate messages backward and compute the beta values efficiently in
one pass. In this case, we start at the end of the chain (t = n), and compute our β’s for each state
by passing messages back toward the front.
To summarize, the Forward-Backward algorithm calculates the α and β values as follows:
\[
\alpha_t(s_t) = \begin{cases} p(x_t \mid s_t) \sum_{s_{t-1}} p(s_t \mid s_{t-1}) \, \alpha_{t-1}(s_{t-1}) & 1 < t \leq n \\ p(x_1 \mid s_1) \, p(s_1) & \text{otherwise} \end{cases}
\]
\[
\beta_t(s_t) = \begin{cases} \sum_{s_{t+1}} p(s_{t+1} \mid s_t) \, p(x_{t+1} \mid s_{t+1}) \, \beta_{t+1}(s_{t+1}) & 1 \leq t < n \\ 1 & \text{otherwise} \end{cases}
\]
⋆ Notice that the base case for the β’s is n. This is a quirk of our indexing, and it ensures we have a defined sn when
we pass messages back to calculate sn−1 , sn−2 , . . . .
Figure 10.3: βt ’s capture the joint probability for the boxed portion of the HMM. Shown here for
β2 (s2 )
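The recursions above translate directly into code. Below is a minimal sketch of the Forward-Backward algorithm in Python/numpy; the parameter layout (theta, T, and pi as arrays, x as a vector of emission indices) is our own assumption, and for long sequences one would rescale the messages at each step to avoid numerical underflow.

import numpy as np

def forward_backward(x, theta, T, pi):
    # theta: (K,) initial-state prior; T: (K, K) with T[i, j] = p(s_{t+1}=j | s_t=i)
    # pi: (K, M) emission probabilities; x: (n,) observed emission indices
    n, K = len(x), len(theta)
    alpha = np.zeros((n, K))
    beta = np.ones((n, K))

    # Forward pass: alpha_t(s) = p(x_1..x_t, s_t = s)
    alpha[0] = pi[:, x[0]] * theta
    for t in range(1, n):
        alpha[t] = pi[:, x[t]] * (alpha[t - 1] @ T)

    # Backward pass: beta_t(s) = p(x_{t+1}..x_n | s_t = s), with beta_n = 1
    for t in range(n - 2, -1, -1):
        beta[t] = T @ (pi[:, x[t + 1]] * beta[t + 1])
    return alpha, beta

# Example use for smoothing: p(s_t | x_1..x_n) is proportional to alpha_t * beta_t,
# so gamma = alpha * beta; gamma /= gamma.sum(axis=1, keepdims=True).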
Multiplying the forward and backward messages at time t gives
\[
\alpha_t(s_t) \, \beta_t(s_t) = p(x_1, ..., x_t, s_t) \, p(x_{t+1}, ..., x_n \mid s_t) = p(x_1, ..., x_n, s_t).
\]
This is the joint distribution over all emissions and the state at time t. Using this as a building block, it can support many kinds of inference.
p(Seq)
For example, we might like to evaluate the joint distribution over the emissions.
\[
p(x_1, ..., x_n) = \sum_{s_t} p(x_1, ..., x_n, s_t) = \sum_{s_t} \alpha_t(s_t) \, \beta_t(s_t) \tag{10.12}
\]
where we sum over the possible values of the state s_t. This calculation can be carried out using any time step t.
Prediction
Another common task is to predict the value of the next emission given the previous emissions. To compute this, we can sum over the state s_t and the next state s_{t+1} as follows:
\begin{align}
p(x_{t+1} \mid x_1, ..., x_t) \propto p(x_1, ..., x_t, x_{t+1}) &= \sum_{s_t} \sum_{s_{t+1}} p(x_1, ..., x_{t+1}, s_t, s_{t+1}) \tag{10.13} \\
&= \sum_{s_t} \sum_{s_{t+1}} p(x_1, ..., x_t, s_t) \, p(s_{t+1} \mid x_1, ..., x_t, s_t) \, p(x_{t+1} \mid x_1, ..., x_t, s_t, s_{t+1}) \tag{10.14} \\
&= \sum_{s_t} \sum_{s_{t+1}} \alpha_t(s_t) \, p(s_{t+1} \mid s_t) \, p(x_{t+1} \mid s_{t+1}) \tag{10.15}
\end{align}
Here, Equation (10.13) follows by introducing states st and st+1 and marginalizing out over
them. Equation (10.14) follows from the product rule, and Equation (10.15) by using the Markov
property in two places and substituting for αt (st ).
Smoothing
Smoothing is the problem of predicting the state at time t given all the observed emissions. We can think of this as updating the beliefs we would have held in real time (using emissions up to and including t) in light of all the evidence observed up to period n; hence the phrasing "smoothing."
For this, we have
\[
p(s_t \mid x_1, ..., x_n) \propto p(x_1, ..., x_n, s_t) = \alpha_t(s_t) \, \beta_t(s_t). \tag{10.16}
\]
Transition
Finally, we may wish to understand the joint distribution on states s_t and s_{t+1} given all the observed evidence:
\begin{align}
p(s_t, s_{t+1} \mid x_1, ..., x_n) \propto p(x_1, ..., x_n, s_t, s_{t+1}) &= p(x_1, ..., x_t, s_t) \, p(s_{t+1} \mid x_1, ..., x_t, s_t) \, p(x_{t+1} \mid x_1, ..., x_t, s_t, s_{t+1}) \, p(x_{t+2}, ..., x_n \mid x_1, ..., x_{t+1}, s_t, s_{t+1}) \tag{10.17} \\
&= \alpha_t(s_t) \, p(s_{t+1} \mid s_t) \, p(x_{t+1} \mid s_{t+1}) \, \beta_{t+1}(s_{t+1}) \tag{10.18}
\end{align}
Here, Equation (10.17) follows from the product rule, and (10.18) by substituting for αt (st ),
applying the Markov property three times, and substituting for βt+1 (st+1 ).
Filtering
For filtering, we have
\[
p(s_t \mid x_1, ..., x_t) \propto p(x_1, ..., x_t, s_t) = \alpha_t(s_t).
\]
Best path
For the best path problem, we want to find the sequence of states that is most likely to give rise to the observed emissions. We solve
\[
\max_{s_1, ..., s_n} p(s_1, ..., s_n \mid x_1, ..., x_n).
\]
This is sometimes referred to as the "decoding" (or explanation) problem. For this, we can define the following function:
\[
\gamma_t(s_t) = \max_{s_1, ..., s_{t-1}} p(x_1, ..., x_t, s_1, ..., s_{t-1}, s_t).
\]
This is the likelihood of x1 , . . . , xt , if the current state is st , and under the best explanation so
far (we maximized over s1 , . . . , st−1 ). Recall that the recurrence for alpha is as follows:
\[
\forall s_t: \quad \alpha_t(s_t) = \begin{cases} p(x_t \mid s_t) \sum_{s_{t-1}} p(s_t \mid s_{t-1}) \, \alpha_{t-1}(s_{t-1}) & \text{if } 1 < t \leq n \\ p(x_1 \mid s_1) \, p(s_1) & \text{otherwise} \end{cases}
\]
Analogously, the recurrence for this γ-value can be shown to be (see if you can derive this for yourself):
\[
\forall s_t: \quad \gamma_t(s_t) = \begin{cases} p(x_t \mid s_t) \max_{s_{t-1}} p(s_t \mid s_{t-1}) \, \gamma_{t-1}(s_{t-1}) & \text{if } 1 < t \leq n \\ p(x_1 \mid s_1) \, p(s_1) & \text{otherwise} \end{cases} \tag{10.21}
\]
To be able to find the optimal sequence, we also store, for each s_t, the best choice of s_{t-1}, namely \arg\max_{s_{t-1}} p(s_t \mid s_{t-1}) \, \gamma_{t-1}(s_{t-1}).
This recursive procedure is known as the Viterbi algorithm and provides an efficient way to
infer the “best path” through states given a sequence of observations.
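Below is a minimal sketch of the Viterbi recursion (10.21) with backpointers, using the same assumed parameter layout as the Forward-Backward sketch and working in log space to avoid underflow; the function name is ours.

import numpy as np

def viterbi(x, theta, T, pi):
    n, K = len(x), len(theta)
    log_gamma = np.zeros((n, K))
    backptr = np.zeros((n, K), dtype=int)

    log_gamma[0] = np.log(pi[:, x[0]]) + np.log(theta)
    for t in range(1, n):
        # scores[i, j] = log p(s_t=j | s_{t-1}=i) + log gamma_{t-1}(i)
        scores = np.log(T) + log_gamma[t - 1][:, None]
        backptr[t] = scores.argmax(axis=0)        # best previous state for each s_t
        log_gamma[t] = np.log(pi[:, x[t]]) + scores.max(axis=0)

    # Trace the stored backpointers from the best final state to recover the path.
    path = np.zeros(n, dtype=int)
    path[-1] = log_gamma[-1].argmax()
    for t in range(n - 1, 0, -1):
        path[t - 1] = backptr[t, path[t]]
    return path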
10.5.1 E-Step
For the E-Step, we take the parameters θ, T, π as fixed. For each data point xi , we run Forward-
Backward with these parameters to get the α and β values for this data point.
The hidden variables are the states s_1, ..., s_n. For each data point x^i = (x^i_1, ..., x^i_n), we are interested in computing the predicted probabilities q^i ∈ [0, 1]^{n×K}, for K possible hidden state values. The q^i represent the predicted probability of each hidden state value for each time period. In particular, we have
\[
q_{t,k}^{i} = p(s_t^i = k \mid x_1^i, ..., x_n^i). \tag{10.23}
\]
This is the probability that state s^i_t takes on value k given the data point x^i. This is the smoothing operation described in the previous section, and we can use Equation 10.16 to compute our q^i values.
Ordinarily, we’d be done with the E-step after computing the marginal probability of each latent
variable. But in this case we will also want to estimate the transition probabilities between hidden
states, i.e., the parameter matrix T. For this, we also need to calculate the joint distribution between temporally adjacent pairs of latent variables (e.g., s_t, s_{t+1}). For data point x^i, and for periods t and t + 1, we denote this as Q^i_{t,t+1} ∈ [0, 1]^{K×K}, where the entries in this matrix sum to 1. This represents the distribution on pairs of states in periods t and t + 1 for this data point. We write Q^i to denote the corresponding values for all pairs of time periods.
To see how the entries are calculated, we can use Qit,t+1,k,ℓ to denote the transition from state
k at time step t to state ℓ at time step t + 1,
This is exactly the transition inference problem that we described in the previous section.
Because of this, we can directly use our α and β values in the transition operation, as given by
Equation (10.18).
With our q^i and Q^i values for each data point, we are ready to move on to the maximization step.
10.5.2 M-Step
We now solve the expected complete-data log likelihood maximization problem, making use of the q^i and Q^i values from the E-step.
Given knowledge of states, the complete-data likelihood for one data point with observations x
and states s is
\[
p(x, s) = p(s_1; \theta) \prod_{t=1}^{n-1} p(s_{t+1} \mid s_t; T) \prod_{t=1}^{n} p(x_t \mid s_t; \pi).
\]
With one-hot encoding of xt and st , and taking the log, the expression becomes
\[
\ln p(x, s) = \sum_{k=1}^{K} s_{1,k} \ln \theta_k + \sum_{t=1}^{n-1} \sum_{k=1}^{K} \sum_{\ell=1}^{K} s_{t,k} \, s_{t+1,\ell} \ln T_{k,\ell} + \sum_{t=1}^{n} \sum_{k=1}^{K} \sum_{m=1}^{M} s_{t,k} \, x_{t,m} \ln(\pi_{k,m}).
\]
To see this, recall that products become sums when we take logarithms. The unbolded symbols represent single entries of the discrete probability distributions, and we create
additional summation indices k, ℓ over K possible hidden state values for st , st+1 (K is the di-
mension of the one-hot encoding of s). We also create an additional summation index m over M
possible emission values for each xt .
From this, we would be able to solve for the MLE of the parameters under the complete-data log likelihood. Now, the states are latent variables, and we need to work with the expected complete-data log likelihood, which for a single data point x^i is
\[
E_{s^i}[\ln p(x^i, s^i)] = \sum_{k=1}^{K} q_{1,k}^{i} \ln \theta_k + \sum_{t=1}^{n-1} \sum_{k=1}^{K} \sum_{\ell=1}^{K} Q_{t,t+1,k,\ell}^{i} \ln T_{k,\ell} + \sum_{t=1}^{n} \sum_{k=1}^{K} \sum_{m=1}^{M} q_{t,k}^{i} \, x_{t,m}^{i} \ln \pi_{k,m}. \tag{10.25}
\]
Applying the appropriate Lagrange multipliers, and maximizing with respect to each of the
parameters of interest, we can obtain the following update equations for each of the parameters
(for N data points):
\[
\theta_k = \frac{\sum_{i=1}^{N} q_{1,k}^{i}}{N}, \quad \text{for all states } k \tag{10.26}
\]
\[
T_{k,\ell} = \frac{\sum_{i=1}^{N} \sum_{t=1}^{n-1} Q_{t,t+1,k,\ell}^{i}}{\sum_{i=1}^{N} \sum_{t=1}^{n-1} q_{t,k}^{i}}, \quad \text{for all states } k, \ell \tag{10.27}
\]
\[
\pi_{k,m} = \frac{\sum_{i=1}^{N} \sum_{t=1}^{n} q_{t,k}^{i} \, x_{t,m}^{i}}{\sum_{i=1}^{N} \sum_{t=1}^{n} q_{t,k}^{i}}, \quad \text{for all states } k \text{ and observations } m \tag{10.28}
\]
After updating our parameter matrices θ, T, and π, we switch back to the E-step, continuing
in this way until convergence. As with other uses of EM, it provides only a local optimum and it
can be useful to try a few different random initializations.
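Putting the E-step and M-step together, here is a minimal sketch of one EM iteration for an HMM, reusing the forward_backward sketch from earlier. The data layout (a list of N emission-index sequences) and function name are our own assumptions.

import numpy as np

def hmm_em_iteration(data, theta, T, pi):
    K, M = pi.shape
    theta_num = np.zeros(K)
    T_num, T_den = np.zeros((K, K)), np.zeros(K)
    pi_num, pi_den = np.zeros((K, M)), np.zeros(K)

    for x in data:
        n = len(x)
        alpha, beta = forward_backward(x, theta, T, pi)

        # E-step: smoothing marginals q_t and pairwise marginals Q_{t,t+1} (Eq. 10.16, 10.18)
        q = alpha * beta
        q /= q.sum(axis=1, keepdims=True)
        theta_num += q[0]
        for t in range(n - 1):
            Q = alpha[t][:, None] * T * (pi[:, x[t + 1]] * beta[t + 1])[None, :]
            Q /= Q.sum()
            T_num += Q
            T_den += q[t]
        for t in range(n):
            pi_num[:, x[t]] += q[t]
            pi_den += q[t]

    # M-step: normalized counts as in (10.26)-(10.28)
    theta = theta_num / len(data)
    T = T_num / T_den[:, None]
    pi = pi_num / pi_den[:, None]
    return theta, T, pi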
10.6 Conclusion
The Hidden Markov Model is a type of latent variable model motivated by the combination of time
series and discrete observations and states. We relied on the Expectation-Maximization algorithm
to train a HMM, and developed the Forward-Backward algorithm to make both inference and
training (the E-step) computationally efficient. Many of the ideas developed in this chapter will
offer good intuition for how to develop learning and inference methods for dynamical systems and
other time series models.
Chapter 11
Markov Decision Processes
In the previous chapter we learned about Hidden Markov Models (HMMs) in which we modeled
our underlying environment as being a Markov chain where each state was hidden, but produced
certain observations that we could use to infer the state. The process of learning in this setting
included finding the distribution over initial hidden states, the transition probabilities between
states, and the probabilities of each state producing each measurement. The movement of the
underlying Markov chain from one state to the next was a completely autonomous process and
questions of interest focused mainly on inference.
Markov Decision Processes (MDPs) introduce somewhat of a paradigm shift. Similar to HMMs,
in the MDP setting we model our underlying environment as a Markov chain. However, our
environment doesn’t transition from state to state autonomously like in an HMM, but rather
requires that an agent in the environment takes some action. Thus, the probability of transitioning
to some state at time t + 1 depends on both the state and the action taken at time t. Furthermore,
after taking some action the agent receives a reward which you can think of as characterizing the
“goodness” of performing a certain action in a certain state. Thus, our overall objective in the
MDP setting is different. Rather than trying to perform inference, we would like to find a sequence of actions that maximizes the agent's total reward.
While MDPs can have hidden states – these are called partially-observable MDPs (POMDPs)
– for the purposes of this course we will only consider MDPs where the agent’s state is known. If,
like in HMMs, we do not have full knowledge of our environment and thus do not know transition
probabilities between states or the rewards associated with each state and action we can try to
find the optimal sequence of actions using reinforcement learning —a subfield of machine learning,
distinct from supervised and unsupervised learning, that is characterized by problems in which an
agent seeks to explore its environment, and simultaneously use that knowledge to perform actions
that exploit its environment and obtain rewards. This chapter discusses how we find the optimal
set of actions given that we do have full knowledge of the environment.
Note that the question mark under the “training” category means that learning in MDPs is neither
supervised nor unsupervised but rather falls under the umbrella of reinforcement learning - an
entirely separate branch of ML.
D = {s0 , a0 , r0 , s1 , a1 , r1 , . . . }
Note that since our state transitions are modeled as a Markov chain, the Markov assumption holds:
\[
p(s_{t+1} \mid s_t, a_t, s_{t-1}, a_{t-1}, \ldots, s_0, a_0) = p(s_{t+1} \mid s_t, a_t).
\]
In other words, the probability that the environment transitions to state st+1 at time t only depends
on the state the agent was in and the action the agent took at time t. Furthermore, we can assume
that transition probabilities are stationary:
\[
p(s_{t+1} = s' \mid s_t = s, a_t = a) = p(s' \mid s, a).
\]
In other words, the probability of transitioning from one state to another does not depend on the
current timestep (note how we dropped the subscript t in the expression on the RHS).
We call the process of finding the optimal policy in an MDP given full knowledge of our envi-
ronment planning (when we don’t have prior knowledge about our environment but know that it
behaves according to some MDP with unknown parameters we turn to Reinforcement Learning -
see Chapter 12). Our approach to planning changes depending on whether there is a limit to the
number of timesteps the agent may act for, which we call a finite horizon, or if they may act forever
(or effectively forever), which we call an infinite horizon.
Note that the only source of randomness in the system is the transition function.
We find the optimal policy for a finite horizon MDP using dynamic programming. To formalize this, let's define the optimal value function V^*_{(t)}(s) to be the highest value achievable in state s (i.e., acting under the optimal policy π^*) with t timesteps remaining for the agent to act. Then, we know
\[
V^*_{(1)}(s) = \max_a [r(s, a)]
\]
since with only one timestep left to act, the best the agent can do is take the action that maximizes their immediate reward. We can define the optimal value function for t > 1 recursively as follows:
\[
V^*_{(t+1)}(s) = \max_a \left[ r(s, a) + \sum_{s' \in S} p(s' \mid s, a) \, V^*_{(t)}(s') \right]
\]
In other words, with more than one timestep to go we take the action that maximizes not only
our immediate reward but also our expected future reward. This formulation makes use of the
principle of optimality or the Bellman consistency equation which states that an optimal policy
consists of taking an optimal first action and then following the optimal policy from the successive
state (these concepts will be discussed further in section 11.3.1). Consequently, we have a different
optimal policy at each timestep where
\[
\pi^*_{(1)}(s) = \arg\max_a [r(s, a)]
\]
\[
\pi^*_{(t+1)}(s) = \arg\max_a \left[ r(s, a) + \sum_{s' \in S} p(s' \mid s, a) \, V^*_{(t)}(s') \right]
\]
The computational complexity of this approach is O(|S|^2 |A| T), since for each state and each timestep we find the value-maximizing action, which involves calculating the expected future reward and hence a summation over the value function for all states.
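As a sketch of this dynamic program, the following code computes V^*_{(t)} and π^*_{(t)} for t = 1, ..., T. It assumes the MDP is encoded as a reward array R[s, a] = r(s, a) and a transition array P[s, a, s'] = p(s' | s, a), which is our own choice of representation.

import numpy as np

def finite_horizon_plan(R, P, T):
    S, A = R.shape
    V = np.zeros(S)                        # value with 0 steps remaining
    policies = []
    for t in range(1, T + 1):
        Q = R + P @ V                      # Q[s, a] = r(s, a) + sum_s' p(s'|s,a) V(s')
        policies.append(Q.argmax(axis=1))  # optimal policy with t steps to go
        V = Q.max(axis=1)                  # V*_{(t)}
    return V, policies                     # policies[t-1] is pi*_{(t)}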
In the infinite horizon setting we evaluate a policy by its expected discounted total reward, E[\sum_{t=0}^{\infty} \gamma^t r(s_t, a_t)], where γ ∈ [0, 1) is called the discount factor. Multiplying the reward at time t in the infinite sum by γ^t makes the sum resolve to a finite number, since our rewards are bounded in [0, 1]. As γ → 1,
rewards further in the future have more of an impact on the expected total reward and thus the
optimal policy will be more “patient”, acting in a way to achieve high rewards in the future rather
than just maximizing short-term rewards. Conversely, as γ → 0, rewards further in the future will
have less of an impact on the expected total reward and thus the optimal policy will be less patient,
preferring to maximize short-term rewards.
The effective time horizon of an infinite horizon MDP is the number of timesteps after which γ^t becomes so small that rewards after time t are negligible. Using the formula for the sum of a geometric series, we find that this is approximately 1/(1 − γ). This yields our first approach to planning in an infinite horizon MDP, which is to convert it into a finite horizon MDP with T = 1/(1 − γ).
However, note that this fraction can be arbitrarily large, and the time complexity of the dynamic programming solution is linear in T. This motivates us to consider alternative solutions.
Then, policy π is weakly better than policy π' if V^π(s) ≥ V^{π'}(s) for all s. The optimal policy π^* is the one that is weakly better than all other policies.
Consequently, the optimal value function is the value function following policy π^*:
\[
V^*(s) = V^{\pi^*}(s).
\]
We now apply Adam's law (dropping the subscripts on the expectations to avoid notational clutter) to obtain the Bellman consistency equation:
\[
V^{\pi}(s) = r(s, \pi(s)) + \gamma E_{s' \sim p}[V^{\pi}(s')].
\]
The Bellman optimality conditions are a set of two theorems that tell us important properties about the optimal value function V^*.
Theorem 1: The optimal value function satisfies V^*(s) = \max_a [r(s, a) + \gamma E_{s' \sim p}[V^*(s')]] for all s.
This theorem also tells us that if \hat{\pi}(s) = \arg\max_a [r(s, a) + \gamma E_{s' \sim p}[V^*(s')]], then \hat{\pi} is the optimal policy.
Proof: Let \hat{\pi}(s) = \arg\max_a [r(s, a) + \gamma E_{s' \sim p}[V^*(s')]]. To prove the claim, it suffices to show that V^*(s) \leq V^{\hat{\pi}}(s) for all s. By the Bellman consistency equation we have:
\begin{align*}
V^*(s) &= r(s, \pi^*(s)) + \gamma E_{s' \sim p}[V^*(s')] \\
&\leq \max_a [r(s, a) + \gamma E_{s' \sim p}[V^*(s')]] \\
&= r(s, \hat{\pi}(s)) + \gamma E_{s' \sim p}[V^*(s')]
\end{align*}
Proceeding recursively:
\begin{align*}
&\leq r(s, \hat{\pi}(s)) + \gamma E_{s' \sim p}\left[ r(s', \hat{\pi}(s')) + \gamma E_{s'' \sim p}[V^*(s'')] \right] \\
&\;\;\vdots \\
&\leq E_{s', s'', \ldots \sim p}\left[ r(s, \hat{\pi}(s)) + \gamma r(s', \hat{\pi}(s')) + \cdots \mid \hat{\pi} \right] \\
&= V^{\hat{\pi}}(s)
\end{align*}
Theorem 2: For any value function V, if V(s) = \max_a [r(s, a) + \gamma E_{s' \sim p}[V(s')]] for all s, then V = V^*.
Proof: First, we define the maximal component distance between V and V^* as follows:
\[
\|V - V^*\|_\infty = \max_s |V(s) - V^*(s)|
\]
For V which satisfies the Bellman optimality equations, if we could show that \|V - V^*\|_\infty \leq \gamma \|V - V^*\|_\infty, then the proof would be complete, since it would imply:
\[
\|V - V^*\|_\infty \leq \gamma \|V - V^*\|_\infty \leq \gamma^2 \|V - V^*\|_\infty \leq \cdots \leq \lim_{k \to \infty} \gamma^k \|V - V^*\|_\infty = 0
\]
Thus we have:
\begin{align*}
|V(s) - V^*(s)| &= \left| \max_a [r(s, a) + \gamma E_{s' \sim p}[V(s')]] - \max_a [r(s, a) + \gamma E_{s' \sim p}[V^*(s')]] \right| \\
&\leq \max_a \left| r(s, a) + \gamma E_{s' \sim p}[V(s')] - r(s, a) - \gamma E_{s' \sim p}[V^*(s')] \right| \\
&= \gamma \max_a \left| E_{s' \sim p}[V(s')] - E_{s' \sim p}[V^*(s')] \right| \\
&\leq \gamma \max_a E_{s' \sim p}\left[ |V(s') - V^*(s')| \right] \\
&\leq \gamma \max_{a, s'} |V(s') - V^*(s')| \\
&= \gamma \max_{s'} |V(s') - V^*(s')| \\
&= \gamma \|V - V^*\|_\infty
\end{align*}
Taking the maximum over s on the left-hand side gives \|V - V^*\|_\infty \leq \gamma \|V - V^*\|_\infty, which completes the proof.
The key takeaways from the Bellman optimality conditions are twofold. Theorem 1 gives us a nice expression for the optimal value function that, as we will soon show, allows us to iteratively improve an arbitrary value function towards optimality. It also tells us what the corresponding optimal policy is. Theorem 2 tells us that the optimal value function is the unique solution of the Bellman optimality equation, so if we find some value function that satisfies the equation given in Theorem 1, then it must be optimal.
Bellman Operator
In the previous section we mentioned that the formulation of the optimal value function we get
from Theorem 1 will allow us to improve some arbitrary value function iteratively. In this section
we will show why this is the case. Define the Bellman operator B : R^{|S|} → R^{|S|} to be a function that takes as input a value function and outputs another value function as follows:
\[
B(V)(s) = \max_a \left[ r(s, a) + \gamma E_{s' \sim p}[V(s')] \right].
\]
By the Bellman optimality conditions, we know that B(V^*) = V^* and that this identity property doesn't hold for any other value function. The optimal value function V^* is thus the unique fixed point of the function B, which means that passing V^* into B always returns V^*.
Furthermore, we know that B is a contraction mapping which means that ||B(V ) − B(V ′ )||∞ ≤
γ||V − V ′ ||∞ .
Proof:
\begin{align*}
|B(V)(s) - B(V')(s)| &= \left| \max_a [r(s, a) + \gamma E_{s' \sim p}[V(s')]] - \max_a [r(s, a) + \gamma E_{s' \sim p}[V'(s')]] \right| \\
&\leq \max_a \left| r(s, a) + \gamma E_{s' \sim p}[V(s')] - r(s, a) - \gamma E_{s' \sim p}[V'(s')] \right| \\
&= \gamma \max_a \left| E_{s' \sim p}[V(s') - V'(s')] \right| \\
&\leq \gamma \max_a E_{s' \sim p}\left[ |V(s') - V'(s')| \right] \\
&\leq \gamma \max_{a, s'} |V(s') - V'(s')| \\
&= \gamma \|V - V'\|_\infty
\end{align*}
Value iteration starts from an arbitrary value function V_0 and repeatedly applies the Bellman operator, setting V_{k+1} = B(V_k). Because B is a contraction mapping, it is a theorem that value iteration converges to V^* asymptotically. We can also extract the optimal policy from the resulting value function asymptotically (by acting greedily with respect to it). Each iteration of value iteration takes O(|S|^2|A|) time.
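Below is a minimal sketch of value iteration as repeated application of the Bellman operator, using the same assumed R (S × A) and P (S × A × S) arrays as the earlier planning sketch and stopping when successive value functions are within a small tolerance.

import numpy as np

def value_iteration(R, P, gamma, tol=1e-8):
    S, A = R.shape
    V = np.zeros(S)
    while True:
        Q = R + gamma * (P @ V)        # Q[s, a] = r(s, a) + gamma * E_{s'}[V(s')]
        V_new = Q.max(axis=1)          # apply the Bellman operator B
        if np.max(np.abs(V_new - V)) < tol:
            break
        V = V_new
    return V_new, Q.argmax(axis=1)     # approximate V* and a greedy policy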
Policy Evaluation
In the previous section, we saw that one of the steps in the policy iteration algorithm was to
calculate the value function for each state following a policy π. There are two ways to do this:
• Exact policy evaluation:
We know that our value function must satisfy the Bellman consistency equations:
\[
V^\pi(s) = r(s, \pi(s)) + \gamma \sum_{s' \in S} p(s' \mid s, \pi(s)) \, V^\pi(s') \quad \forall s
\]
This is a system of |S| linear equations that has a unique solution. Rearranging terms, representing functions as matrices and vectors, and replacing sums with matrix multiplication, we get that:
\[
\mathbf{V}^\pi = (I - \gamma P^\pi)^{-1} R^\pi
\]
where V^π is an |S| × 1 vector, I is the identity matrix, P^π is an |S| × |S| matrix with P^π_{s,s'} = p(s' | s, π(s)), and R^π is an |S| × 1 vector with R^π_s = r(s, π(s)). This subroutine runs in time O(|S|^3). (A short code sketch of this computation appears after this list.)
• Iterative policy evaluation:
Rather than solving the system of equations above exactly, we can instead do so iteratively. We perform the following two steps:
1. Initialize V_0 such that ||V_0||_∞ ∈ [0, 1/(1 − γ)].
2. Repeatedly apply the consistency update V_{k+1}(s) = r(s, π(s)) + γ \sum_{s' ∈ S} p(s' | s, π(s)) V_k(s') for all s, until successive iterates are sufficiently close.
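Referring back to the exact method above, here is a minimal sketch of the linear-system computation. The array encoding of the MDP (R[s, a] and P[s, a, s']) is the same assumed layout as in the earlier planning sketches, and we solve the system directly rather than explicitly forming the matrix inverse.

import numpy as np

def policy_evaluation_exact(R, P, policy, gamma):
    S = R.shape[0]
    idx = np.arange(S)
    P_pi = P[idx, policy]              # (S, S): P_pi[s, s'] = p(s' | s, policy[s])
    R_pi = R[idx, policy]              # (S,):   R_pi[s] = r(s, policy[s])
    # Solve (I - gamma * P_pi) V = R_pi
    return np.linalg.solve(np.eye(S) - gamma * P_pi, R_pi)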
Chapter 12
Reinforcement Learning
12.1 Motivation
In the last chapter, we discussed the Markov Decision Process (MDP): a framework that models a learner's environment in terms of states, actions, rewards, and transition probabilities.
model, we can solve for an optimal (reward-maximizing) policy using either value iteration or policy
iteration. Sometimes, however, the learner doesn’t have prior knowledge of the rewards they will
get from each state or the probability distribution over states they could end up in after taking some
action from their current state. Is it still possible to learn a policy that will maximize rewards? In
this chapter, we will learn about Reinforcement Learning (RL) - a machine learning technique
that addresses this problem.
ML Framework Cube: Reinforcement Learning
Note that we will only cover RL for discrete state spaces in this chapter. The question mark under
the “training” category means that RL is neither supervised nor unsupervised.
1. Time spent exploring. If you spend the right amount of time exploring new restaurants,
then there is a good chance that you will find one that you like a lot. However, spending too
much time exploring will result in an average experience overall since some restaurants will
be good and some will be bad.
2. Time spent exploiting. If you spend the right amount of time exploiting your knowledge of
what restaurants you’ve liked the most from your exploration you will have a good experience.
However, if you spend too long exploiting (and consequently less time exploring) you risk
missing out on eating at all the restaurants that you could have liked more.
This is called the exploration vs exploitation tradeoff. The more time you spend exploring
the less time you can spend exploiting, and vice versa. In order to have the best vacation experi-
ence, you need to find a balance between time spent exploring different restaurants and time spent exploiting what you have learned by eating at the restaurants that you have liked the most thus far.
A reinforcement learner’s task is similar to the vacation restaurant task described above. They
must balance time spent exploring the environment by taking actions and observing rewards and
penalties, and time spent maximizing rewards based on what they know.
We now return to the problem described in the previous section: how do we learn a policy that
will maximize rewards in an MDP where we know nothing about which rewards we will get from
each state and which state we will end up in after taking an action? One approach is to try and
learn the MDP. In other words, we can create a learner that will explore its environment for some
amount of time in order to learn the rewards associated with each state and the transition prob-
abilities associated with each action. Then, they can use an MDP planning algorithm like policy
iteration or value iteration to come up with an optimal policy and maximize their rewards. This is
called model-based learning. The main advantage of model-based learning is that it can inexpen-
sively incorporate changes in the reward structure or transition function of the environment into
the model. However, model-based learning is computationally expensive, and an incorrect model
can yield a sub-optimal policy.
We will begin by discussing value-based methods. In this family of RL algorithms, the learner tries
to calculate the expected reward they will receive from a state s upon taking action a. Formally,
they are trying to learn a function that maps (s, a) to some value representing the expected reward.
This is called the Q-function and is defined as follows for some policy π:
\[
Q^\pi(s, a) = r(s, a) + \gamma \sum_{s'} p(s' \mid s, a) \, V^\pi(s') \tag{12.1}
\]
In words, the Q-value for taking action a from state s is the immediate reward r(s, a) received from the environment, plus the discounted (by γ) expectation, taken over all reachable next states, of the value achievable from that next state under the policy. The Q-function following the optimal policy is defined analogously:
\[
Q^*(s, a) = r(s, a) + \gamma \sum_{s'} p(s' \mid s, a) \, V^*(s') \tag{12.2}
\]
Note that V^*(s') = \max_{a'} Q^*(s', a'), since the highest value achievable from state s' following policy π^* is the Q-value of taking the optimal action from s'. Substituting this in, we get the following Bellman equation:
\[
Q^*(s, a) = r(s, a) + \gamma \sum_{s'} p(s' \mid s, a) \max_{a'} Q^*(s', a') \tag{12.3}
\]
Note that we can't directly calculate the term γ \sum_{s'} p(s' | s, a) \max_{a'} Q^*(s', a'), since we don't know p(s' | s, a). We will discuss how this is addressed by the two algorithms we will cover in the value-based family.
2. Use s (current state), a (action), r (reward), s′ (next state), a′ (action taken from next state)
in order to update the approximation of Q(s, a)
We will refer to (s, a, r, s', a') as an experience. Let π(s) be the action that a learner takes from state s. One strategy for acting that attempts to balance exploration and exploitation is called ϵ-greedy and is defined as follows:
\[
\pi(s) = \begin{cases} \arg\max_a Q(s, a) & \text{with probability } 1 - \epsilon \\ \text{a random action} & \text{with probability } \epsilon \end{cases} \tag{12.4}
\]
Here, ϵ is some number ∈ [0, 1] which controls how likely the learner is to choose a random action
as opposed to the currently known optimal action. Varying the value of ϵ changes the balance
between exploration and exploitation.
Once the learner has had an experience, they can begin to learn Q∗ . We will now describe two
algorithms which perform this update differently for every new experience:
1. SARSA: Q(s, a) ← Q(s, a) + α_t(s, a)[r + γQ(s', a') − Q(s, a)], where α_t(s, a) is the learning rate, a parameter which controls how much the observation affects Q(s, a).
The expression r + γQ(s', a') is a 1-step estimate of Q(s, a). The expression in the square brackets above is called the temporal difference (TD) error and represents the difference between the new 1-step estimate and the previous estimate of Q(s, a). Since the action a' that we use for our update is the one that was recommended by the policy π (recall that a' = π(s')), SARSA is an on-policy algorithm. This means that if there were no ϵ-greedy action selection, the reinforcement learner would always act according to the policy π and SARSA would converge to Q^π.
2. Q-learning: Q(s, a) ← Q(s, a) + α_t(s, a)[r + γ max_{a'} Q(s', a') − Q(s, a)].
Here the update uses the greedy action at the next state rather than the action actually taken, so Q-learning is an off-policy algorithm: it converges to Q^* regardless of the policy being followed, as long as the conditions below hold.
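To make these updates concrete, here is a minimal sketch of tabular Q-learning with ϵ-greedy action selection and the decaying learning-rate and exploration schedules discussed in the next subsection. The environment interface (reset() returning a state, step(a) returning (s', r, done)) is hypothetical and not part of the text.

import numpy as np

def q_learning(env, S, A, gamma=0.99, episodes=1000, seed=0):
    rng = np.random.default_rng(seed)
    Q = np.zeros((S, A))
    visits = np.zeros((S, A))

    for _ in range(episodes):
        s = env.reset()                                # hypothetical environment API
        done = False
        while not done:
            eps = 1.0 / (1.0 + visits[s].sum())        # decaying exploration rate
            if rng.random() < eps:
                a = rng.integers(A)                    # explore: random action
            else:
                a = int(Q[s].argmax())                 # exploit: greedy action
            s_next, r, done = env.step(a)              # hypothetical environment API
            visits[s, a] += 1
            alpha = 1.0 / visits[s, a]                 # alpha_t(s, a) = 1 / N_t(s, a)
            td_error = r + gamma * Q[s_next].max() - Q[s, a]
            Q[s, a] += alpha * td_error                # Q-learning update
            s = s_next
    return Q

Replacing the max over Q[s_next] with the Q-value of the action actually chosen at s_next turns this into the on-policy SARSA update.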
Convergence Conditions
Let αt (s, a) = 0 for all (s, a) that are not visited at time t. It is a theorem that Q-learning converges
to Q∗ (and hence π converges to π ∗ ) as t → ∞ as long as the following two conditions are met:
• \sum_t \alpha_t(s, a) = \infty for all s, a.
The sum of the learning rates over infinitely many time steps must diverge. In order for this to happen, each state-action pair (s, a) must be visited infinitely often. Thus, we see the importance of an ϵ-greedy learner, which forces the agent to probabilistically take random actions in order to explore more of the state space.
• \sum_t \alpha_t(s, a)^2 < \infty for all s, a.
The sum of the squared learning rates over infinitely many time steps must converge. Note that in order for \sum_t \alpha_t(s, a)^2 < \infty, the learning rate must be iteratively reduced. For example, we could set \alpha_t(s, a) = 1/N_t(s, a), where N_t(s, a) is the number of times the learner has taken action a from state s.
SARSA converges to Q^* if the above two conditions are met and behavior is greedy in the limit (i.e., ϵ → 0 as t → ∞). One common choice is ϵ_t(s) = 1/N_t(s), where N_t(s) is the number of times the learner has visited state s. The notation ϵ_t(s) implies that a separate ϵ value is maintained for every state, rather than just maintaining one value of ϵ that controls learner behavior across all states (though this is also an option).
The neural network takes the learner's current state as input and, by training parameters w, outputs the approximated Q-values for taking each possible action from that state. The learner's next action is the one whose output Q-value is largest.
While there are several specific variants depending on the problem being solved, the general loss function that the neural network tries to minimize at iteration i is as follows:
\[
L_i(w_i) = E_{(s, a, r, s')}\left[ \left( r + \gamma \max_{a'} Q(s', a'; w_{i-1}) - Q(s, a; w_i) \right)^2 \right] \tag{12.5}
\]
Here, (s, a, r, s') and γ are defined the same as in regular Q-learning. w_i are the parameters that the neural net is training during the current iteration of the RL algorithm, and w_{i−1} are the parameters from the previous iteration of the algorithm. The TD target is the term r + γ \max_{a'} Q(s', a'; w_{i−1}), and the squared term inside the expectation is the TD error. Since directly optimizing the loss in Equation 12.5 is difficult, gradient descent, specifically stochastic gradient descent, is used (in order to avoid having to calculate the entire expectation in 12.5).
The Atari deep Q-network that mastered the brick-breaking game Breakout used a technique called experience replay to make updates to the network more stable. Each experience (s, a, r, s') was put into a replay buffer, which was then sampled from to perform minibatch gradient descent to minimize the loss function.
1. The Q function is far more complex than the policy being learned.
We seem to have run into an issue here. The term ∇_θ μ_θ(h) depends on the transition probability p(∗|s, a) by definition, but the point of policy learning was to avoid having to learn these transition probabilities. However, it turns out we can circumvent this issue by applying the useful identity
\[
\nabla_\theta \mu_\theta = \mu_\theta \frac{1}{\mu_\theta} \nabla_\theta \mu_\theta = \mu_\theta \nabla_\theta \ln \mu_\theta \tag{12.7}
\]
Thus, we can rewrite the previous equation as follows:
\[
\int_h r(h) \, \mu_\theta(h) \, \nabla_\theta \left[ \sum_t \ln \pi_\theta(a_t \mid s_t) + \sum_t \ln p(s_{t+1} \mid s_t, a_t) \right] dh \tag{12.8}
\]
\[
= E\left[ r(h) \sum_t \nabla_\theta \ln \pi_\theta(a_t \mid s_t) \right] \tag{12.9}
\]
\[
= E\left[ \sum_t \nabla_\theta \ln \pi_\theta(a_t \mid s_t) \, r_t \right] \tag{12.10}
\]
Here the transition term drops out of (12.9) because \ln p(s_{t+1} \mid s_t, a_t) does not depend on θ, so its gradient is zero.
Equation 12.10 only involves π_θ(a_t | s_t), which we are trying to learn, and r_t, which we observe from the environment. We can now perform SGD to find the optimal θ. Note that policy learning is on-policy.
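The gradient estimate in Equation 12.10 leads to the REINFORCE-style update sketched below for a tabular softmax policy π_θ(a|s) ∝ exp(θ[s, a]). This is our own illustration: it uses the common reward-to-go variant, weighting each ∇_θ ln π_θ(a_t|s_t) by the discounted return from time t, and the environment interface is the same hypothetical one as in the Q-learning sketch.

import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

def reinforce(env, S, A, gamma=0.99, lr=0.01, episodes=1000, seed=0):
    rng = np.random.default_rng(seed)
    theta = np.zeros((S, A))

    for _ in range(episodes):
        # Roll out one trajectory h = (s_0, a_0, r_0, s_1, ...) under pi_theta.
        s, done, traj = env.reset(), False, []
        while not done:
            probs = softmax(theta[s])
            a = rng.choice(A, p=probs)
            s_next, r, done = env.step(a)
            traj.append((s, a, r))
            s = s_next

        # Gradient estimate: sum_t grad log pi_theta(a_t|s_t) * discounted return from t.
        G = 0.0
        grad = np.zeros_like(theta)
        for s, a, r in reversed(traj):
            G = r + gamma * G                     # discounted return from time t
            dlog = -softmax(theta[s])             # gradient of log softmax w.r.t. theta[s, :]
            dlog[a] += 1.0
            grad[s] += dlog * G
        theta += lr * grad                        # stochastic gradient ascent step
    return theta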
Recall that in model-based learning, the learner tries to learn the parameters of the MDP underlying their environment and then uses planning to come up with a policy. A model-based learner repeats the following three steps until satisfactory performance is achieved:
1. Act in the environment for some number of steps (for example, following the most recently planned policy), collecting experiences.
2. Learn/update a model of the underlying MDP based on the experiences gained from step 1.
3. Use the model from step 2 to plan for the next M steps.
In many implementations, “learning” a model means coming up with a maximum likelihood esti-
mate of the parameters of the underlying MDP: the reward function and transition probabilities
(we will not cover the specifics of this process in this section). We can now plan according to this
model and follow the recommended policy for the next M steps. One issue with this basic model
is that it doesn’t do anything to handle time spent exploring vs exploiting. While there are many
ways to address this in practice, we will now present three common approaches:
• Incorporate exploration directly into the policy, for example, by using an ϵ-greedy strategy
as discussed in 12.3.1.
• Optimism under uncertainty: Let N (s, a) be the number of times we visited the state-
action pair (s, a). Each “visit” is a time that the agent was in state s and chose to perform
action a. When we are learning, we will assume that if N (s, a) is small, then the next
state will have higher reward than predicted by the maximum likelihood model. Since the model thinks visiting lesser-known state-action pairs will lead to high reward (we are optimistic when we are uncertain), we'll have a tendency to explore these lesser-known areas.
12.5 Conclusion
To recap, reinforcement learning is a machine learning technique that allows a learner to find an
optimal policy in an environment modeled by an MDP where the transition probabilities and reward
function are unknown. The learner gains information by interacting with the environment and
receiving rewards and uses this information to update the model they have of the underlying MDP
(model-based) or their beliefs about the value of taking particular actions from particular states
(value-based model-free). Temporal-difference updating is a key idea in model-free learning that underlies many techniques in the area, including SARSA, Q-learning, and deep Q-learning. Contemporary RL models such as the one used by AlphaGo often use a combination
of supervised and reinforcement learning methods which can be endlessly customized to meet the
needs of a problem.