CS168: The Modern Algorithmic Toolbox Lecture #5: Generalization (Or, How Much Data Is Enough?)
In this course, you’ll see plenty of lectures in each mode. We’ll also return to the distributional viewpoint in future lectures on techniques for sampling and estimation. Last week (e.g., with k-d trees or dimensionality reduction), and the upcoming weeks on linear algebraic techniques, are more in the first mode.
2 Binary Classification
There are many different types of learning problems. For concreteness, we’ll illustrate our
points using binary classification problems.1 Here are the ingredients of the model:
1. Data points correspond to points in d-dimensional Euclidean space Rd . (For example,
the vector of word frequencies in an email, for d different words that you’re keeping
track of.)
2. There is an unknown ground truth function f : Rd → {0, 1} that gives the correct label of every possible data point (e.g., whether or not an email is spam).2
3. There is an unknown distribution D on Rd (the origin of our training data, see below).
For example, you can think of D as the uniform distribution over a very large finite
subset of Rd (e.g., all emails with at most 10, 000 characters that could conceivably be
received).
Here’s the problem solved by a learning algorithm:3
Input: n data points x1 , . . . , xn ∈ Rd , with each xi drawn independently and iden-
tically (“i.i.d.”) from the distribution D, and the corresponding ground truth labels
f (x1 ), . . . , f (xn ) ∈ {0, 1}. The xi ’s (and their labels) constitute the training data. For
example, the xi ’s could be emails, with each email correctly labeled by a human as
spam or not.4
Output: a prediction function g : Rd → {0, 1}.
Success criterion: the prediction function g is identical to the ground truth function f . (Or later in lecture, g is at least “close” to f .)
1. The lessons learned in this lecture carry over to many other learning problems, such as linear and logistic
regression.
2. Again, one can consider variations, with {0, 1} replaced by [0, 1] (allowing “soft” classifications) or R—the
main points and results of this lecture remain the same.
3. This type of learning problem is called “batch” or “offline” learning, because all of the training data is
available up front in a batch.
4. This type of problem, where the learning algorithm is given labeled examples, is called a supervised
learning problem. Also important are unsupervised learning problems, where an algorithm is given unlabeled
data and is responsible for identifying “interesting patterns.” We’ll talk more about unsupervised learning
next week, when we discuss principal components analysis (PCA).
Figure 1 shows an example with d = 2, with n data points labeled as “+” or “-.” The
learning algorithm needs to label the entire plane with “+”s and “-”s, ideally with 100%
accuracy (with respect to the ground truth f ). Figure 1 also shows a simple example of what
a ground truth function f might look like—a line (or hyperplane in higher dimensions), with
all “+”s on one side and all “-”s on the other. (We’ll return to such “linear classifiers” later
in the lecture, see Section 6.)
Again, we emphasize that f is initially unknown, with the learning algorithm acting as
a detective who is trying to reconstruct f . The learning algorithm does receive some clues
about f , specifically its evaluation at n different data points (the xi ’s). The prediction
function g is the algorithm’s extrapolation of f from these n data points to all of the data
points (including the never-before-seen ones). “Learning” means extrapolating correctly, in
that g should coincide with f .
Are there any learning algorithms that actually achieve this goal? It depends, with two
parameters of the application being crucial:
• The amount of data (i.e., n). The more data you have, the more you know about the ground truth function f , and the better your prospects for figuring it out.
• The number of possible functions that might be the ground truth. The fewer functions that you’re trying to distinguish between, the easier the learning problem.
3 Training Error and Generalization
How do we assess the “goodness” of a prediction function g? What we really care about is
the generalization error, defined as the probability that g disagrees with the ground truth f
on the label of a random sample (from the underlying distribution D):
generalization error(g) = Pr_{x∼D} [g(x) ≠ f (x)].
Our learning goal can be rephrased as identifying a function g with 0% generalization error
(or later in lecture, with very small generalization error).
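To make the definition concrete, here is a minimal Python sketch (not from the original notes) that estimates the generalization error of a prediction function g by Monte Carlo sampling. It assumes we can draw fresh samples from D and query the ground truth f, which a simulation can do even though a learning algorithm cannot.

```python
def estimate_generalization_error(g, f, sample_from_D, num_samples=100_000):
    """Monte Carlo estimate of Pr_{x ~ D}[g(x) != f(x)].

    g, f:          functions mapping a data point to a label in {0, 1}
    sample_from_D: a function that returns one fresh sample from D
    """
    disagreements = 0
    for _ in range(num_samples):
        x = sample_from_D()
        disagreements += (g(x) != f(x))
    return disagreements / num_samples
```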
We’re not in a position to evaluate the generalization error of a prediction function, as
we know neither the ground truth f nor the distribution D. What we do know is n sample
points and their ground truth labels, and it’s natural to use g’s performance on these as a
proxy for g’s generalization error. This proxy is called the training error of g (with respect
to a sample x1 , . . . , xn with ground truth labels f (x1 ), . . . , f (xn )):
training error(g) = (1/n) · [number of xi ’s with g(xi ) ≠ f (xi )].
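In code, the training error is essentially a one-liner; the following sketch (with hypothetical argument names, not anything prescribed by the notes) computes it for a candidate prediction function g:

```python
import numpy as np

def training_error(g, X, y):
    """Fraction of the n training points that g labels incorrectly.

    g: a function mapping a length-d numpy array to a label in {0, 1}
    X: an (n, d) array whose rows are the training points x_1, ..., x_n
    y: a length-n array of the ground truth labels f(x_1), ..., f(x_n)
    """
    predictions = np.array([g(x) for x in X])
    return np.mean(predictions != y)
```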
For any given prediction function g, its expected training error (over the random sample
x1 , . . . , xn ) is exactly its generalization error.5 But is the training error of g on a random
sample very likely to be close to its generalization error, or might there be a big discrepancy?
As a special case of this, is a function g with 0% training error likely to also have 0%
generalization error?
Such questions are usually phrased as: does g generalize? The answer depends on the
amount of training data n (with more being better), and in some cases also on the learning
algorithm used to come up with the prediction function g (see Mini-Project #3). When
a learning algorithm outputs a prediction function g with generalization error much higher
than its training error, then it is said to have overfit the training data. In this case, the
prediction function learned is largely an artifact of the particular sample, and does not
capture the more fundamental patterns in the data distribution.6
5. We have E_{x1 ,...,xn ∼D} [training error(g) on x1 , . . . , xn ] = (1/n) Σ_{i=1}^{n} Pr_{xi ∼D} [g(xi ) ≠ f (xi )] = generalization error(g), where in the first equation we’re using the linearity of expectation.
6. In practice, one checks for generalization/overfitting by dividing the available data set into two parts,
the “training set” and the “test set.” The learning algorithm is given the training set only, and the test
error of the computed prediction function fˆ is then defined as the fraction of test set data points that fˆ
labels incorrectly. If there is a significant difference between the training and test error of the output of the
learning algorithm (generally with the latter much bigger than the former), then you’ve got a problem, and
should consider gathering more data and/or using some of the techniques discussed in the next lecture to
learn a better prediction function.
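As an illustration of the train/test split described in this footnote, here is a rough sketch; the `learn` argument (an assumed interface, not anything from the notes) is any learning algorithm that maps labeled training data to a prediction function:

```python
import numpy as np

def train_test_errors(learn, X, y, test_fraction=0.2, seed=0):
    """Hold out a random test set, fit on the rest, and report both errors."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(y))
    n_test = int(test_fraction * len(y))
    test_idx, train_idx = perm[:n_test], perm[n_test:]
    g = learn(X[train_idx], y[train_idx])   # the algorithm never sees the test set
    def error(idx):
        return np.mean(np.array([g(X[i]) for i in idx]) != y[idx])
    # A large gap between the two numbers suggests overfitting.
    return error(train_idx), error(test_idx)
```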
4 Analysis: The Well-Separated Finite Case
4.1 Assumptions
To develop our intuition, we first consider the learning problem with two additional assump-
tions:
(A1) (Finite) The ground truth function f belongs to a known set {f1 , . . . , fh } of h different
functions. That is, there are only h different possibilities for what f might be, and we
know the options up front.
(A2) (Well-separated) Every function fj other than the ground truth f has generalization error at least ε; that is, Pr_{x∼D} [fj (x) ≠ f (x)] ≥ ε whenever fj ≠ f .
The basic learning algorithm we analyze is the obvious one: output an arbitrary function fj from the known set {f1 , . . . , fh } that has 0% training error on the given sample. (At least one such function always exists, namely the ground truth f itself.)
How well does this algorithm work? Does its output g generalize, in the sense of also having
0% generalization error? The answer depends on n. For example, if n = 1, there are
presumably lots of choices of fj that are correct on this one data point, and all but one of
them have non-zero generalization error. So our goal is to identify a sufficient condition
on the data set size n such that the prediction function output by the learning algorithm
generalizes (i.e., is always correct even on not-yet-seen data points).
So fix a function fj that is not the ground truth f (and hence has generalization error at least ε, by (A2)). The probability of the bad event that fj nonetheless has 0% training error (i.e., agrees with f on all n samples) is
Pr_{x1 ,...,xn ∼D} [fj (xi ) = f (xi ) for all i = 1, 2, . . . , n] = ∏_{i=1}^{n} Pr_{xi ∼D} [fj (xi ) = f (xi )]
≤ (1 − ε)^n
≤ e^{−εn},
where the equation follows from our assumption that the xi ’s are independent samples from
D, the first inequality follows from assumption (A2), and the last inequality follows because
1 + x ≤ e^x for all x ∈ R (see Figure 2; we’re taking x = −ε). As we hoped, this probability
is decreasing (exponentially, in fact) with the number of samples n.
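For instance, a function fj with generalization error ε = 5% agrees with f on n = 100 random samples with probability at most (0.95)^100 ≈ 0.0059, which is indeed at most e^{−5} ≈ 0.0067. A quick numerical sanity check:

```python
import math

eps, n = 0.05, 100
print((1 - eps) ** n)      # ~0.0059: chance a bad function survives all n samples
print(math.exp(-eps * n))  # ~0.0067: the upper bound e^{-eps * n}
```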
Recall the union bound: for any events A1 , . . . , Ah ,
Pr[A1 ∪ A2 ∪ · · · ∪ Ah ] ≤ Pr[A1 ] + Pr[A2 ] + · · · + Pr[Ah ].
Importantly, the events are completely arbitrary, and do not need to be independent. The proof is a one-liner. In terms of Figure 3, the union bound just says that the area (i.e., probability mass) in the union is bounded above by the sum of the areas of the circles.
The bound is tight if the events are disjoint; otherwise the right-hand side is larger, due to
double-counting the overlaps. In most applications, including the present one, the events
A1 , . . . , Ah represent “bad events” that we’re hoping don’t happen; the union bound says
that as long as each event occurs with low probability and there aren’t too many events,
then with high probability none of them occur.
Figure 3: Union bound: area of the union is bounded by the sum of areas of the circles.
Applying the union bound, with Aj the event that a function fj ≠ f has 0% training error (an event of probability at most e^{−εn}, as computed above), the probability that at least one of these at most h bad events occurs is at most h · e^{−εn}. Also, because our learning algorithm can only fail to output the ground truth f when there is a function fj ≠ f with 0% training error (otherwise the only remaining candidate is the correct answer f), we have
Pr[learning algorithm fails to output f] ≤ h · e^{−εn}.
That is, h · e^{−εn} is an upper bound on the failure probability of our learning algorithm. This upper bound increases linearly with the number of possible functions (remember the learning problem is harder as you’re trying to differentiate between more functions) but decreases exponentially with the size of the training data set.
So suppose you want a failure probability of at most δ (say, δ = 1%).7 How much data
do you need? Setting h · e^{−εn} = δ and solving for n, we get the following sufficient condition
for generalization.
Theorem 4.1 Suppose assumptions (A1) and (A2) hold, and assume that
n ≥ (1/ε) · (ln h + ln(1/δ)).   (1)
Then with probability at least 1 − δ over the samples x1 , . . . , xn ∼ D, the output of the
learning algorithm is the ground truth function f .
The sample complexity of a learning task is the minimum number of i.i.d. samples neces-
sary to accomplish it. Thus Theorem 4.1 states that the right-hand side of (1) is an upper
bound on the sample complexity of learning an unknown ground truth function with prob-
ability at least 1 − δ (under assumptions (A1) and (A2)). Let’s inspect this upper bound
more closely. The not-so-good news is that the dependence on 1/ε is linear. So to reduce ε
from 10% to 1%, according to this bound you need 10 times as much data. The much better
news is that the dependence on h and 1/δ is only logarithmic, so there’s generally no problem
with taking δ to be quite small (1% or even less), and large (even exponential-size) values of
h can be accommodated. For example, if we take ε = 5%, δ = 1%, and h = 1000, then our
bound on the sample complexity is something like 230. If we reduce ε to 1% then the bound
shoots up past 1000, but if we reduce δ to 0.1% then the bound remains under 300.
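These numbers are easy to check directly; the following sketch just evaluates the right-hand side of (1):

```python
import math

def sample_complexity_bound(eps, delta, h):
    """Right-hand side of (1): (1/eps) * (ln h + ln(1/delta))."""
    return (math.log(h) + math.log(1.0 / delta)) / eps

print(sample_complexity_bound(0.05, 0.01, 1000))   # ~230
print(sample_complexity_bound(0.01, 0.01, 1000))   # ~1151 (shoots past 1000)
print(sample_complexity_bound(0.05, 0.001, 1000))  # ~276 (still under 300)
```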
Several comments. First, note the differences between the statements in Theorems 4.1
and 5.1: the latter has one less assumption, but compromises by allowing the learning
algorithm to output a function with small (but non-zero) generalization error. Note that the
semantics of the parameter ε have changed: in Theorem 4.1, ε depended on the fj ’s (and D)
and was outside anyone’s control, while in Theorem 5.1, ε is a user-specified parameter (in
the same sense as δ), controlling the trade-off between the sample complexity and the error
of the learned prediction function. The type of guarantee in Theorem 5.1 is often called a
PAC guarantee, which stands for “probably approximately correct” [2]. (The “probably”
refers to the failure probability δ, and the “approximately” to the allowed generalization
error ε.) Finally, note that the proof of Theorem 4.1 also immediately implies Theorem 5.1:
the analysis in Section 4 continues to apply to the functions fj with generalization error at
least ε, so with probability at least 1 − δ, the only hypotheses eligible for having 0% training
error are those with generalization error less than ε. Our learning algorithm then outputs
an arbitrary such hypothesis.
A linear classifier is specified by d real-valued coefficients a = (a1 , . . . , ad ), and labels data points according to
fa ((x1 , . . . , xd )) = 1 if Σ_{i=1}^{d} ai xi ≥ 0, and 0 if Σ_{i=1}^{d} ai xi < 0.   (2)
Note that a linear classifier in d dimensions has d degrees of freedom (the ai ’s). Geometrically,
a linear classifier corresponds to a hyperplane through the origin (with normal vector a),
with all points on the same side of the hyperplane as the normal vector labeled “1” and
all points on the other side labeled “0”. Linear classifiers may feel like a toy example, but
they’re not—for example, “support vector machines (SVMs)” are basically the same as linear
classifiers.
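As a concrete illustration, here is a minimal sketch of a linear classifier of the form (2):

```python
import numpy as np

def linear_classifier(a):
    """Return the function f_a from (2): label 1 iff <a, x> >= 0."""
    def f_a(x):
        return 1 if np.dot(a, x) >= 0 else 0
    return f_a

# Example in d = 2: the normal vector a = (1, -1) labels points with x1 >= x2 as "1".
f = linear_classifier(np.array([1.0, -1.0]))
print(f(np.array([3.0, 1.0])))  # 1, since 3 - 1 >= 0
print(f(np.array([1.0, 3.0])))  # 0, since 1 - 3 < 0
```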
Theorem 6.2 Suppose f is a linear classifier of the form in (2), and assume that
n ≥ (c/ε) · (d + ln(1/δ)),
where c is a sufficiently large constant. Then with probability at least 1 − δ over the samples
x1 , . . . , xn ∼ D, the output of the learning algorithm is a linear classifier with generalization
error less than ε.
The algorithm in Theorem 6.2 is the same as before—just output an arbitrary linear classifier
that has 0% training error. Current proofs of Theorem 6.2 require the constant c to be larger
than one would like. In practice, a reasonable guideline is to think of c as 1.
Upshot: to guarantee generalization, make sure that your training data set size n is
at least linear in the number d of free parameters in the function that you’re trying to
learn.
Mini-Project #3 will drive home this guideline—you’ll observe empirically that learned pre-
diction functions tend to generalize when n ≫ d but not when n ≪ d.
Theorem 6.2 makes the rule of thumb above precise for the problem of learning a linear
classifier. For most other types of prediction functions that you might want to learn (e.g., in
linear or logistic regression, or even with neural nets) there is a sensible notion of “number
of parameters,” and the guideline above is usually pretty accurate.
6.5 FAQ
Three questions you might have at this point are:
1. How do you actually implement the basic learning algorithm?
2. What if, contrary to our standing assumption, no function under consideration has 0%
training error?
3. What should you do if n ≪ d?
The next two sections answer the first two questions; next lecture is devoted to answering
the third.
7 Computational Considerations
How do we actually implement the basic learning algorithm? That is, given a bunch of
correctly labeled data points, how do we compute a candidate function fˆ with 0% training
error? If there are only a finite number of candidates (as in Sections 4–5), then if nothing
else we can resort to exhaustive search. (For some types of functions, there will also be faster
algorithms.) But we can’t very well try each of the infinitely many linear classifiers, can we?
More formally, the relevant algorithmic problem for learning a good linear classifier is:
given a number of “+”s and “-”s in Rd , compute a hyperplane such that all of the “+”s are
on one side and all of the “-”s are on the other (see Figure 4).10
One way to solve the problem is via linear programming, which we’ll discuss in Lec-
ture #18. With this approach you can even maximize the “margin,” meaning the smallest
Euclidean distance between a training point and the hyperplane. (This requires convex pro-
gramming, which is still tractable and again will be discussed in Lecture #18.) This is one of
the key ideas behind support vector machines (SVMs). Alternatively, one can apply iterative
methods (like stochastic gradient descent, the perceptron algorithm, etc.).
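To illustrate the iterative approach, here is a minimal sketch of the perceptron algorithm (one of many possible variants; it assumes the data really is separable by a hyperplane through the origin):

```python
import numpy as np

def perceptron(X, y, max_passes=1000):
    """Look for a vector a whose hyperplane through the origin separates the data.

    X: (n, d) array of data points; y: length-n array of labels in {0, 1}.
    If the data is separable by a hyperplane through the origin, the classifier
    x -> 1[<a, x> >= 0] for the returned a has 0% training error; otherwise we
    give up after max_passes passes over the data.
    """
    n, d = X.shape
    signs = 2 * y - 1                    # recode {0, 1} labels as {-1, +1}
    a = np.zeros(d)
    for _ in range(max_passes):
        mistakes = 0
        for x, s in zip(X, signs):
            if s * np.dot(a, x) <= 0:    # x is misclassified (or on the boundary)
                a = a + s * x            # nudge the hyperplane toward correctness
                mistakes += 1
        if mistakes == 0:
            break                        # 0% training error achieved
    return a
```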
10. We’re glossing over the distinction of whether or not the hyperplane has to go through the origin. It
really doesn’t matter—hyperplanes in d dimensions can be simulated by hyperplanes through the origin in
d + 1 dimensions (by adding a dummy coordinate with value “1” for every data point). More on this next
lecture.
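A sketch of this dummy-coordinate trick:

```python
import numpy as np

def add_dummy_coordinate(X):
    """Append a constant coordinate 1 to every row of the (n, d) data matrix X,
    so that affine hyperplanes in d dimensions become hyperplanes through the
    origin in d + 1 dimensions."""
    return np.hstack([X, np.ones((X.shape[0], 1))])
```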
Figure 4: We want to find a hyperplane that separates the positive points (plus signs) from
the negative points (minus signs).
A first solution to the second question from Section 6.5 (what if no function under consideration has 0% training error?) is to relax the basic learning algorithm:
Output the function fˆ (from the set of functions under consideration) that has
the smallest training error, breaking ties arbitrarily.
This algorithm is often called the ERM algorithm, for “empirical risk minimization.” (Sounds
fancy, but it’s really just the one-line algorithm above.)
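For a finite list of candidate functions, a minimal sketch of the ERM algorithm (with hypothetical argument names, not from the notes) looks like this:

```python
import numpy as np

def erm(candidates, X, y):
    """ERM over a finite list of candidate functions: return the candidate with
    the smallest training error on the labeled sample (X, y), breaking ties by
    whichever candidate appears first in the list."""
    def training_error(g):
        return np.mean(np.array([g(x) for x in X]) != y)
    return min(candidates, key=training_error)
```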
11. A second solution (compatible also with the first) is to enrich your class of functions so that there
is a function with 0% training error (or at least smaller than before). One systematic way of doing this,
discussed in detail next lecture, is to use linear classifiers in a higher-dimensional space (which can correspond
to non-linear classifiers in the original space). There are various ways of adding new dimensions, such
as appending to each data point (x1 , . . . , xd ) new coordinates containing quadratic terms in the original
coordinates (x1^2 , . . . , xd^2 , x1 x2 , x1 x3 , . . .).
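A sketch of this quadratic feature expansion for a single data point:

```python
import numpy as np
from itertools import combinations_with_replacement

def add_quadratic_features(x):
    """Append all degree-2 monomials x_i * x_j (with i <= j) to the point x."""
    d = len(x)
    quadratic = [x[i] * x[j] for i, j in combinations_with_replacement(range(d), 2)]
    return np.concatenate([x, quadratic])

print(add_quadratic_features(np.array([2.0, 3.0])))  # [2. 3. 4. 6. 9.]
```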
The ERM algorithm has a PAC guarantee analogous to that of the basic learning algo-
rithm, although with somewhat larger sample complexity. We state the guarantee for linear
classifiers; the same result holds for arbitrary finite sets of functions (as before with the d
replaced by ln h, where h is the number of candidate functions). The key behind the ERM’s
generalization guarantee is the following theorem, which states that, with high probability
after sufficiently many samples, the training error of every linear classifier is very close to its
actual generalization error.
Theorem 8.1 Assume that
n ≥ (c/ε²) · (d + ln(1/δ)),   (3)
where c is a sufficiently large constant. Then with probability at least 1 − δ over the samples x1 , . . . , xn ∼ D, for every linear classifier fˆ,
|training error of fˆ − generalization error of fˆ| ≤ ε.   (4)
Note that Theorem 8.1 does not assume that the ground truth f is a linear function—it
could be anything. Of course, if f looks nothing like any linear classifier, then all linear
classifiers will have large generalization error and the guarantee in Theorem 8.1 is not very
relevant.12
A special case of Theorem 8.1 is: if fˆ has 0% training error, then it has generalization
error at most ε. So Theorem 8.1 generalizes Theorem 6.2, with the caveat that its sample
complexity is larger by a 1/ε factor.13
Theorem 8.1 is a sufficient condition for generalization, in the following sense.
Corollary 8.2 (Guarantee for the ERM Algorithm) In the setting of Theorem 8.1, let
τ ∈ [0, 1] denote the minimum generalization error of any linear classifier. Assume that
n ≥ (c/ε²) · (d + ln(1/δ)),
12. This highlights two different types of errors in learning. The first type of error is approximation error,
where discrepancies between the ground truth and the set of allowed prediction functions cause all available
functions to have large generalization error. No amount of algorithmic cleverness or data can reduce approx-
imation error; the only solution is to enrich the set of allowable prediction functions (see also the previous
footnote and next lecture). The second type of error, which we’re more focused on here, is estimation error.
The concern is that we choose a prediction function with training error much less than its generalization
error, due to an overly small or unlucky sample.
These two types of errors connect to the “bias-variance” trade-off that you might hear about in a statistics
course; enlarging the set of possible prediction functions tends to decrease the approximation error (reducing
“bias”) but increase estimation error (enlarging the “variance”).
13. The guarantee in (4) is sometimes called a “uniform convergence” result. “Convergence” refers to the
fact that the training error of every classifier fˆ approaches the generalization error (with high probability),
while “uniform” refers to the fact that a single finite set of samples (with size as in (3)) suffices to give the
guarantee in (4) simultaneously for all infinitely many linear classifiers.
where c is a sufficiently large constant. Then with probability at least 1 − δ over the samples
x1 , . . . , xn ∼ D, the output fˆ of the ERM algorithm satisfies
generalization error of fˆ ≤ τ + 2ε.
Thus with high probability, the ERM algorithm learns a linear classifier that is almost
as accurate as the best linear classifier.
Proof of Corollary 8.2: Let f ∗ denote the most accurate linear classifier, with generalization
error τ . Let fˆ denote the linear classifier returned by the ERM algorithm. Theorem 8.1
implies that, with high probability,
training error of f ∗ ≤ τ + ε
and also
training error of fˆ ≥ generalization error of fˆ − ε.
Since the ERM algorithm picked fˆ over f ∗ , we must have
training error of fˆ ≤ training error of f ∗ .
Chaining together the three inequalities, we get
generalization error of fˆ ≤ training error of fˆ + ε ≤ training error of f ∗ + ε ≤ (τ + ε) + ε = τ + 2ε,
which proves the corollary.
Another way to think about the difference is the following. In the first experiment, there
is “one-sided error;” you might mistake a 1%-bias coin for a 0%-bias coin (from a long run
of tails), but never vice versa (once you see a heads you know with certainty that you’re
correct). The second problem, of estimating the bias, has “two-sided error”—no matter
what bias you guess, there’s a chance that your guess is too high, and also that it’s too low.
Needing to hedge between two different types of error causes the sample complexity to jump
by a factor of 1/ε.
Despite all this, the dependence on d in (3) remains linear, and our key take-away from
before remains valid:
• Upshot: to guarantee generalization, make sure your training data set size n is at
least linear in the number d of free parameters in the function that you’re trying to
learn.
Empirically, one usually sees pretty good generalization when n is equal to or a small multiple
of d; see Mini-Project #3.
4. A sufficient condition for generalization is for the data set size to exceed the number
of free parameters in the prediction function being learned. This rule of thumb is
surprisingly robust across learning problems (e.g., linear classifiers, linear and logistic
regression, and even in many cases for neural nets).
5. In practice, you might be stuck with n much smaller than d. Next lecture discusses
additional techniques that are useful for this case, including regularization.
References
[1] Martin Anthony and Peter L. Bartlett. Neural Network Learning: Theoretical Foundations. Cambridge University Press, New York, NY, USA, 1999.
[2] Leslie G. Valiant. A theory of the learnable. Communications of the ACM, 27(11):1134–1142, 1984.
[3] W. Zachary. An information flow model for conflict and fission in small groups. Journal
of Anthropological Research, 33(4):452–473, 1977.