R For Statistical Learning
David Dalpiaz
2020-10-28
Contents

0.0.1 Mathematics
0.0.2 Code

I Prerequisites

1 Overview

2 Probability Review
  2.1 Probability Models
  2.2 Probability Axioms
  2.3 Probability Rules
  2.4 Random Variables
    2.4.1 Distributions
    2.4.2 Discrete Random Variables
    2.4.3 Continuous Random Variables
    2.4.4 Several Random Variables
  2.5 Expectations
  2.6 Likelihood
  2.7 Videos
  2.8 References

3 R, RStudio, RMarkdown
  3.1 Videos
  3.2 Template

4 Modeling Basics in R
  4.1 Visualization for Regression
  4.2 The lm() Function
  4.3 Hypothesis Testing
  4.4 Prediction
  4.5 Unusual Observations
  4.6 Adding Complexity
    4.6.1 Interactions
    4.6.2 Polynomials
    4.6.3 Transformations
  4.7 rmarkdown

II Regression

5 Overview

6 Linear Models
  6.1 Assessing Model Accuracy
  6.2 Model Complexity
  6.3 Test-Train Split
  6.4 Adding Flexibility to Linear Models
  6.5 Choosing a Model

7 k-Nearest Neighbors
  7.1 Parametric versus Non-Parametric Models
  7.2 Local Approaches
    7.2.1 Neighbors
    7.2.2 Neighborhoods
  7.3 k-Nearest Neighbors
  7.4 Tuning Parameters versus Model Parameters
  7.5 KNN in R
  7.6 Choosing k
  7.7 Linear versus Non-Linear
  7.8 Scaling Data
  7.9 Curse of Dimensionality
  7.10 Train Time versus Test Time
  7.11 Interpretability
  7.12 Data Example
  7.13 rmarkdown

8 Bias–Variance Tradeoff
  8.1 Reducible and Irreducible Error
  8.2 Bias-Variance Decomposition
  8.3 Simulation
  8.4 Estimating Expected Prediction Error
  8.5 rmarkdown

III Classification

9 Overview
  9.1 Visualization for Classification
  9.2 A Simple Classifier
  9.3 Metrics for Classification
  9.4 rmarkdown

10 Logistic Regression
  10.1 Linear Regression
  10.2 Bayes Classifier
  10.3 Logistic Regression with glm()
  10.4 ROC Curves
  10.5 Multinomial Logistic Regression
  10.6 rmarkdown

13 Overview
  13.1 Methods
    13.1.1 Principal Component Analysis
    13.1.2 k-Means Clustering
    13.1.3 Hierarchical Clustering
  13.2 Examples
    13.2.1 US Arrests
    13.2.2 Simulated Data
    13.2.3 Iris Data
  13.3 External Links
  13.4 RMarkdown

15 k-Means

V In Practice

18 Overview

20 Resampling
  20.1 Validation-Set Approach
  20.2 Cross-Validation
  20.3 Test Data
  20.4 Bootstrap
  20.5 Which K?
  20.6 Summary
  20.7 External Links
  20.8 rmarkdown

24 Regularization
  24.1 Ridge Regression
  24.2 Lasso
  24.3 broom
  24.4 Simulated Data, p > n
  24.5 External Links
  24.6 rmarkdown

26 Trees
  26.1 Classification Trees
  26.2 Regression Trees
  26.3 rpart Package
  26.4 External Links
  26.5 rmarkdown
Organization
The text is organized into roughly seven parts.
1. Prerequisites
2. (Supervised Learning) Regression
3. (Supervised Learning) Classification
4. Unsupervised Learning
5. (Statistical Learning) in Practice
6. (Statistical Learning) in The Modern Era
7. Appendix
Part 1 details the assumed prerequisite knowledge required to read the text. It
recaps some of the more important bits of information. It is currently rather
sparse.
Parts 2, 3, and 4 discuss the theory of statistical learning. Several methods are
introduced throughout to highlight different theoretical concepts.
Parts 5 and 6 highlight the use of statistical learning in practice. Part 5 focuses
on practical usage of the techniques seen in Parts 2, 3, and 4. Part 6 introduces
techniques that are most commonly used in practice today.
Who?
This book is targeted at advanced undergraduate or first year MS students in
Statistics who have no prior statistical learning experience. While both will be
discussed in great detail, previous experience with both statistical modeling and
R is assumed.
Caveat Emptor
This “book” is under active development. Much of the text was hastily
written during the Spring 2017 run of the course. While together with ISL the
coverage is essentially complete, significant updates are occurring during Fall
2017.
When possible, it would be best to always access the text online to be sure you
are using the most up-to-date version. Also, the html version provides additional
features such as changing text size, font, and colors. If you are in need of a local
copy, a pdf version is continuously maintained. While development is taking
place, formatting in the pdf version may not be as well planned as the html
version since the html version does not need to worry about pagination.
Since this book is under active development you may encounter errors ranging
from typos, to broken code, to poorly explained topics. If you do, please let us
know! Simply send an email and we will make the changes as soon as possible.
(dalpiaz2 AT illinois DOT edu) Or, if you know rmarkdown and are familiar
with GitHub, make a pull request and fix an issue yourself! This process is
partially automated by the edit button in the top-left corner of the html version.
If your suggestion or fix becomes part of the book, you will be added to the list
at the end of this chapter. We’ll also link to your GitHub account, or personal
website upon request.
While development is taking place, you may see “TODO” scattered throughout
the text. These are mostly notes for internal use, but give the reader some idea
of what development is still to come.
Please see the README file on GitHub for notes on the development process.
Conventions
0.0.1 Mathematics
This text uses MathJax to render mathematical notation for the web. Occa-
sionally, but rarely, a JavaScript error will prevent MathJax from rendering
correctly. In this case, you will see the “code” instead of the expected math-
ematical equations. From experience, this is almost always fixed by simply
refreshing the page. You’ll also notice that if you right-click any equation you
can obtain the MathML Code (for copying into Microsoft Word) or the TeX
command used to generate the equation.
a^2 + b^2 = c^2
0.0.2 Code
R code will be typeset using a monospace font which is syntax highlighted.
a = 3
b = 4
sqrt(a ^ 2 + b ^ 2)
R output lines, which would appear in the console will begin with ##. They will
generally not be syntax highlighted.
## [1] 5
For the most part, we will follow the tidyverse style guide, however with one
massive and obvious exception. Instead of the usual assignment operator, <-,
we will use the more visually appealing and easier to type =. Not many
do this, but there are dozens of us.
Acknowledgements
The following is a (likely incomplete) list of helpful contributors.
• James Balamuta, Summer 2016 - ???
• Korawat Tanwisuth, Spring 2017
• Yiming Gao, Spring 2017
• Binxiang Ni, Summer 2017
• Ruiqi (Zoe) Li, Summer 2017
• Haitao Du, Summer 2017
• Rachel Banoff, Fall 2017
• Chenxing Wu, Fall 2017
• Wenting Xu, Fall 2017
• Yuanning Wei, Fall 2017
• Ross Drucker, Fall 2017
• Craig Bonsignore, Fall 2018
• Ashish Kumar, Fall 2018
Your name could be here! If you submit a correction and would like to be listed
below, please provide your name as you would like it to appear, as well as a link
to a GitHub, LinkedIn, or personal website. Pull requests encouraged!
Looking for ways to contribute?
• You’ll notice that a lot of the plotting code is not displayed in the text, but
is available in the source. Currently that code was written to accomplish
a task, but without much thought about the best way to accomplish the
task. Try refactoring some of this code.
• Fix typos. Since the book is actively being developed, typos are getting
added all the time.
• Suggest edits. Good feedback can be just as helpful as actually contribut-
ing code changes.
License
Part I
Prerequisites
Chapter 1
Overview
Chapter 2
Probability Review
Using these axioms, many additional probability rules can easily be derived.

P[A^c] = 1 - P[A]

P[A \cup B] = P[A] + P[B] - P[A \cap B]

If A and B are disjoint, then

P[A \cup B] = P[A] + P[B]

and more generally, for pairwise disjoint events E_1, E_2, \ldots, E_n,

P\left[\bigcup_{i=1}^{n} E_i\right] = \sum_{i=1}^{n} P[E_i]

Often, we would like to understand the probability of an event A, given some information about the outcome of event B. In that case, we have the conditional probability rule, provided P[B] > 0.

P[A \mid B] = \frac{P[A \cap B]}{P[B]}

Rearranging gives the multiplication rule,

P[A \cap B] = P[B] \cdot P[A \mid B]

which extends to several events:

P\left[\bigcap_{i=1}^{n} E_i\right] = P[E_1] \cdot P[E_2 \mid E_1] \cdot P[E_3 \mid E_1 \cap E_2] \cdots P[E_n \mid \bigcap_{i=1}^{n-1} E_i]
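These rules are easy to check by simulation. For instance, a minimal sketch in R verifying the addition rule for a fair die, with A the event of an even roll and B the event of a roll of at least five:

set.seed(42)
rolls = sample(1:6, size = 1e5, replace = TRUE)
A = rolls %% 2 == 0   # event A: the roll is even
B = rolls >= 5        # event B: the roll is at least five
mean(A | B)                       # estimated P[A union B]
mean(A) + mean(B) - mean(A & B)   # P[A] + P[B] - P[A intersect B]; agrees with the line above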
A collection of events A_1, A_2, \ldots, A_n partitions the sample space if the events are disjoint,

A_i \cap A_j = \emptyset \text{ for } i \neq j,

and together cover the sample space,

\bigcup_{i=1}^{n} A_i = \Omega.

Now, let A_1, A_2, \ldots, A_n form a partition of the sample space where P[A_i] > 0 for all i. Then for any event B with P[B] > 0 we have Bayes' Rule:

P[A_j \mid B] = \frac{P[A_j] \cdot P[B \mid A_j]}{P[B]} = \frac{P[A_j] \cdot P[B \mid A_j]}{\sum_{i=1}^{n} P[A_i] \cdot P[B \mid A_i]}

The denominator of the latter equality is often called the law of total probability:

P[B] = \sum_{i=1}^{n} P[A_i] \cdot P[B \mid A_i]
Two events A and B are independent if

P[A \cap B] = P[A] \cdot P[B]

More generally, a collection of events E_1, E_2, \ldots, E_n is mutually independent if, for every subset S of \{1, 2, \ldots, n\},

P\left[\bigcap_{i \in S} E_i\right] = \prod_{i \in S} P[E_i]

and in particular

P\left[\bigcap_{i=1}^{n} E_i\right] = \prod_{i=1}^{n} P[E_i]
2.4.1 Distributions
We often talk about the distribution of a random variable, which can be
thought of as:
This is not a strict mathematical definition, but is useful for conveying the idea.
If the possible values of a random variable are discrete, it is called a discrete random variable. If the possible values of a random variable are continuous, it is called a continuous random variable.

Note that we almost always drop the subscript from the more correct p_X(x) and simply refer to p(x). The relevant random variable is discerned from context.
The most common example of a discrete random variable is a binomial random variable. The mass function of a binomial random variable X is given by

p(x \mid n, p) = \binom{n}{x} p^x (1 - p)^{n - x}, \quad x = 0, 1, \ldots, n, \; n \in \mathbb{N}, \; 0 < p < 1.

We write

X \sim \text{bin}(n, p).

For a continuous random variable X with density f(x), probabilities are found by integration:

P[a < X < b] = \int_a^b f(x) \, dx.

The most common example of a continuous random variable is a normal random variable, with density

f(x \mid \mu, \sigma^2) = \frac{1}{\sigma \sqrt{2\pi}} \exp\left[-\frac{1}{2}\left(\frac{x - \mu}{\sigma}\right)^2\right], \quad -\infty < x < \infty, \; -\infty < \mu < \infty, \; \sigma > 0.

We write

X \sim N(\mu, \sigma^2).
Random variables X and Y are independent if

f(x, y) = f(x) \cdot f(y)

for all x and y. Here f(x, y) is the joint density (mass) function of X and Y. We call f(x) the marginal density (mass) function of X, and f(y) the marginal density (mass) function of Y. The joint density (mass) function f(x, y), together with the possible (x, y) values, specifies the joint distribution of X and Y.

Similar notions exist for more than two variables.
2.5 Expectations
For discrete random variables, we define the expectation of a function of a random variable X as follows.

\mathbb{E}[g(X)] \triangleq \sum_x g(x) \, p(x)

For continuous random variables, the sum is replaced by an integral.

\mathbb{E}[g(X)] \triangleq \int g(x) \, f(x) \, dx

The mean of a random variable X is

\mu_X = \text{mean}[X] \triangleq \mathbb{E}[X].

In the discrete case, this is

\text{mean}[X] = \sum_x x \cdot p(x)

For a continuous random variable we would simply replace the sum by an integral.
The variance of a random variable X is given by

\sigma_X^2 = \text{var}[X] \triangleq \mathbb{E}\left[(X - \mathbb{E}[X])^2\right] = \mathbb{E}[X^2] - \left(\mathbb{E}[X]\right)^2.

The standard deviation is its square root,

\sigma_X = \text{sd}[X] \triangleq \sqrt{\sigma_X^2} = \sqrt{\text{var}[X]}.
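As a quick sketch, these definitions can be evaluated directly in R for the binomial random variable introduced above, using n = 10 and p = 0.3 chosen here purely for illustration:

n = 10
p = 0.3
x = 0:n
p_x = dbinom(x, size = n, prob = p)    # p(x) for each possible value
mean_X = sum(x * p_x)                  # E[X], the sum of x * p(x)
var_X = sum((x - mean_X) ^ 2 * p_x)    # var[X] = E[(X - E[X])^2]
c(mean_X, var_X)                       # matches the known n * p and n * p * (1 - p)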
2.6 Likelihood
Consider n iid random variables X_1, X_2, \ldots, X_n. We can then write their likelihood as

\mathcal{L}(\theta \mid x_1, x_2, \ldots, x_n) = \prod_{i=1}^{n} f(x_i; \theta)

where f(x_i; \theta) is the density (or mass) function of random variable X_i evaluated at x_i with parameter \theta.
Whereas a probability is a function of a possible observed value given a particular
parameter value, a likelihood is the opposite. It is a function of a possible
parameter value given observed data.
Maximizing the likelihood is a common technique for fitting a model to data.
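For example, here is a minimal sketch of maximum likelihood in R, numerically maximizing the log-likelihood of a simulated normal sample over its mean and (log) standard deviation. The data and starting values are chosen only for illustration.

set.seed(42)
x = rnorm(n = 50, mean = 2, sd = 1.5)   # simulated data
# negative log-likelihood; sigma is parameterized on the log scale to keep it positive
neg_log_lik = function(par, data) {
  -sum(dnorm(data, mean = par[1], sd = exp(par[2]), log = TRUE))
}
mle = optim(par = c(0, 0), fn = neg_log_lik, data = x)$par
c(mu_hat = mle[1], sigma_hat = exp(mle[2]))   # close to the sample mean and standard deviation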
2.7 Videos
The YouTube channel mathematicalmonk has a great Probability Primer
playlist containing lectures on many fundamental probability concepts. Some
of the more important concepts are covered in the following videos:
• Conditional Probability
• Independence
• More Independence
• Bayes Rule
2.8 References
Any of the following are either dedicated to, or contain a good coverage of the
details of the topics above.
• Probability Texts
– Introduction to Probability by Dimitri P. Bertsekas and John N. Tsit-
siklis
– A First Course in Probability by Sheldon Ross
• Machine Learning Texts with Probability Focus
– Probability for Statistics and Machine Learning by Anirban Das-
Gupta
– Machine Learning: A Probabilistic Perspective by Kevin P. Murphy
• Statistics Texts with Introduction to Probability
– Probability and Statistical Inference by Robert V. Hogg, Elliot Tanis,
and Dale Zimmerman
– Introduction to Mathematical Statistics by Robert V. Hogg, Joseph
McKean, and Allen T. Craig
Chapter 3
R, RStudio, RMarkdown
Materials for learning R, RStudio, and RMarkdown can be found in another text
from the same author, Applied Statistics with R.
The chapters up to and including Chapter 6 - R Resources contain an introduc-
tion to using R, RStudio, and RMarkdown. This chapter in particular contains
a number of videos to get you up to speed on R, RStudio, and RMarkdown,
which are also linked below. Also linked is an RMarkdown template which is
referenced in the videos.
3.1 Videos
• R and RStudio Playlist
• Data in R Playlist
• RMarkdown Playlist
3.2 Template
• RMarkdown Template
Chapter 4
Modeling Basics in R
After loading data into R, our first step should always be to inspect the data.
We will start by simply printing some observations in order to understand the
basic structure of the data.
Advertising
## # A tibble: 200 x 4
## TV Radio Newspaper Sales
## <dbl> <dbl> <dbl> <dbl>
## 1 230. 37.8 69.2 22.1
## 2 44.5 39.3 45.1 10.4
## 3 17.2 45.9 69.3 9.3
## 4 152. 41.3 58.5 18.5
## 5 181. 10.8 58.4 12.9
## 6 8.7 48.9 75 7.2
## 7 57.5 32.8 23.5 11.8
## 8 120. 19.6 11.6 13.2
## 9 8.6 2.1 1 4.8
## 10 200. 2.6 21.2 10.6
## # ... with 190 more rows
(Figure: scatterplot of Sales against TV.)
pairs(Advertising)
(Figure: pairs plot of TV, Radio, Newspaper, and Sales.)
Often, we will be most interested in only the relationship between each predictor
and the response. For this, we can use the featurePlot() function from the
caret package. (We will use the caret package more and more frequently as
we introduce new topics.)
library(caret)
featurePlot(x = Advertising[ , c("TV", "Radio", "Newspaper")], y = Advertising$Sales)
(Figure: featurePlot of Sales against TV, Radio, and Newspaper.)
The plots also raise the question of the relationship between Radio and TV. To investigate further, we will need to model the data. Note that the commented line is equivalent to the line that is run, but we will often use the response ~ . syntax when possible.
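The additive model is fit with lm(), matching the Call shown in the summary output below.

# mod_1 = lm(Sales ~ TV + Radio + Newspaper, data = Advertising)
mod_1 = lm(Sales ~ ., data = Advertising)
summary(mod_1)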
##
## Call:
## lm(formula = Sales ~ ., data = Advertising)
##
## Residuals:
## Min 1Q Median 3Q Max
## -8.8277 -0.8908 0.2418 1.1893 2.8292
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 2.938889 0.311908 9.422 <2e-16 ***
## TV 0.045765 0.001395 32.809 <2e-16 ***
## Radio 0.188530 0.008611 21.893 <2e-16 ***
## Newspaper -0.001037 0.005871 -0.177 0.86
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 1.686 on 196 degrees of freedom
## Multiple R-squared: 0.8972, Adjusted R-squared: 0.8956
## F-statistic: 570.3 on 3 and 196 DF, p-value: < 2.2e-16
The anova() function is useful for comparing two models. Here we compare
the full additive model, mod_1, to a reduced model mod_0. Essentially we are
testing for the significance of the Newspaper variable in the additive model.
anova(mod_0, mod_1)
Note that hypothesis testing is not our focus, so we omit many details.
4.4 Prediction
The predict() function is an extremely versatile function for prediction.
When used on the result of a model fit using lm() it will, by default, return
predictions for each of the data points used to fit the model. (Here, we limit
the printed result to the first 10.)
head(predict(mod_1), n = 10)
## 1 2 3 4 5 6 7 8
## 20.523974 12.337855 12.307671 17.597830 13.188672 12.478348 11.729760 12.122953
## 9 10
## 3.727341 12.550849
Note that the effect of the predict() function is dependent on the input to the
function. Here, we are supplying as the first argument a model object of class
lm. Because of this, predict() then runs the predict.lm() function. Thus,
we should use ?predict.lm() for details.
We could also specify new data, which should be a data frame or tibble with
the same column names as the predictors.
new_obs = data.frame(TV = 150, Radio = 40, Newspaper = 1)
We can then use the predict() function for point estimates, confidence intervals,
and prediction intervals.
Using only the first two arguments, R will simply return a point estimate, that
is, the “predicted value,” \hat{y}.
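That call supplies just the model and the new data.

predict(mod_1, newdata = new_obs)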
## 1
## 17.34375
If we specify an additional argument interval with a value of "confidence",
R will return a 95% confidence interval for the mean response at the specified
point. Note that here R also gives the point estimate as fit.
predict(mod_1, newdata = new_obs, interval = "confidence")
## 1 2 3 4 5 6
## 1.57602559 -1.93785482 -3.00767078 0.90217049 -0.28867186 -5.27834763
## 7 8 9 10
## 0.07024005 1.07704683 1.07265914 -1.95084872
head(hatvalues(mod_1), n = 10)
## 1 2 3 4 5 6
## 0.025202848 0.019418228 0.039226158 0.016609666 0.023508833 0.047481074
## 7 8 9 10
## 0.014435091 0.009184456 0.030714427 0.017147645
head(rstudent(mod_1), n = 10)
## 1 2 3 4 5 6
## 0.94680369 -1.16207937 -1.83138947 0.53877383 -0.17288663 -3.28803309
## 7 8 9 10
## 0.04186991 0.64099269 0.64544184 -1.16856434
head(cooks.distance(mod_1), n = 10)
## 1 2 3 4 5 6
## 5.797287e-03 6.673622e-03 3.382760e-02 1.230165e-03 1.807925e-04 1.283058e-01
## 7 8 9 10
## 6.452021e-06 9.550237e-04 3.310088e-03 5.945006e-03
4.6.1 Interactions
Interactions can be introduced to the lm() procedure in a number of ways.
We can use the : operator to introduce a single interaction of interest.
mod_2 = lm(Sales ~ . + TV:Newspaper, data = Advertising)
coef(mod_2)
Note that we have only been dealing with numeric predictors. Categorical predictors are often recorded as factor variables in R.
library(tibble)
cat_pred = tibble(
x1 = factor(c(rep("A", 10), rep("B", 10), rep("C", 10))),
x2 = runif(n = 30),
y = rnorm(n = 30)
)
cat_pred
## # A tibble: 30 x 3
## x1 x2 y
## <fct> <dbl> <dbl>
## 1 A 0.898 0.569
## 2 A 0.590 -0.819
## 3 A 0.748 1.43
## 4 A 0.364 -1.98
## 5 A 0.274 -1.65
## 6 A 0.197 0.686
## 7 A 0.384 -0.300
## 8 A 0.335 0.334
## 9 A 0.920 -1.23
## 10 A 0.0780 -0.224
## # ... with 20 more rows
The following two models illustrate the effect of factor variables on linear models.
cat_pred_mod_add = lm(y ~ x1 + x2, data = cat_pred)
coef(cat_pred_mod_add)
4.6.2 Polynomials
Polynomial terms can be specified using the inhibit function I() or through the
poly() function. Note that these two methods produce different coefficients,
but the same residuals! This is due to the poly() function using orthogonal
polynomials by default.
mod_5 = lm(Sales ~ TV + I(TV ^ 2), data = Advertising)
coef(mod_5)
## (Intercept) TV I(TV^2)
## 6.114120e+00 6.726593e-02 -6.846934e-05
mod_6 = lm(Sales ~ poly(TV, degree = 2), data = Advertising)
coef(mod_6)
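One way to check the claim about identical residuals is with all.equal(); this is a sketch, and an assumption about what produced the TRUE below.

all.equal(resid(mod_5), resid(mod_6))  # residuals from the two parameterizations should match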
## [1] TRUE
Polynomials and interactions can be mixed to create even more complex models.
mod_7 = lm(Sales ~ . ^ 2 + poly(TV, degree = 3), data = Advertising)
# mod_7 = lm(Sales ~ . ^ 2 + I(TV ^ 2) + I(TV ^ 3), data = Advertising)
coef(mod_7)
## (Intercept) TV Radio
## 6.206394e+00 2.092726e-02 3.766579e-02
## Newspaper poly(TV, degree = 3)1 poly(TV, degree = 3)2
## 1.405289e-02 NA -9.925605e+00
## poly(TV, degree = 3)3 TV:Radio TV:Newspaper
## 5.309590e+00 1.082074e-03 -5.690107e-05
## Radio:Newspaper
## -9.924992e-05
Notice here that R ignores the first order term from poly(TV, degree = 3) as
it is already in the model. We could consider using the commented line instead.
4.6.3 Transformations
Note that we could also create more complex models, which allow for non-
linearity, using transformations. Be aware, when doing so to the response
variable, that this will affect the units of said variable. You may need to un-
transform to compare to non-transformed models.
mod_8 = lm(log(Sales) ~ ., data = Advertising)
sqrt(mean(resid(mod_8) ^ 2)) # incorrect RMSE for Model 8
## [1] 0.1849483
sqrt(mean(resid(mod_7) ^ 2)) # RMSE for Model 7
## [1] 0.4813215
sqrt(mean(exp(resid(mod_8)) ^ 2)) # correct RMSE for Model 8
## [1] 1.023205
4.7 rmarkdown
The rmarkdown file for this chapter can be found here. The file was created
using R version 4.0.2. The following packages (and their dependencies) were
loaded in this file:
## [1] "tibble" "caret" "ggplot2" "lattice" "readr"
Part II
Regression
Chapter 5
Overview
Supervised learning problems are those where there are both an input and an output. Regression problems are the subset of supervised learning problems with a numeric output.

Often, one of the biggest differences between statistical learning, machine learning, and artificial intelligence is the names used to describe variables and methods.
• The input can be called: input vector, feature vector, or predictors. The
elements of these would be an input, feature, or predictor. The individual
features can be either numeric or categorical.
• The output may be called: output, response, outcome, or target. The
response must be numeric.
As an aside, some textbooks and statisticians use the terms independent and
dependent variables to describe the response and the predictors. However, this
practice can be confusing as those terms have specific meanings in probability
theory.
Our goal is to find a rule, algorithm, or function which takes as input a feature
vector, and outputs a response which is as close to the true value as possible. We
often write the true, unknown relationship between the input and output f(\mathbf{x}). The relationship (model) we learn (fit, train), based on data, is written \hat{f}(\mathbf{x}).

From a statistical learning point-of-view, we write

Y = f(\mathbf{x}) + \epsilon
to indicate that the true response is a function of both the unknown relationship,
as well as some unlearnable noise.
\text{RMSE}\left(\hat{f}, \text{Data}\right) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} \left(y_i - \hat{f}(\mathbf{x}_i)\right)^2}

\text{RMSE}_{\text{Train}} = \text{RMSE}\left(\hat{f}, \text{Train Data}\right) = \sqrt{\frac{1}{n_{\text{Tr}}} \sum_{i \in \text{Train}} \left(y_i - \hat{f}(\mathbf{x}_i)\right)^2}

\text{RMSE}_{\text{Test}} = \text{RMSE}\left(\hat{f}, \text{Test Data}\right) = \sqrt{\frac{1}{n_{\text{Te}}} \sum_{i \in \text{Test}} \left(y_i - \hat{f}(\mathbf{x}_i)\right)^2}
# simulate data
## packages used below (not loaded elsewhere in this excerpt)
library(rpart)
library(FNN)

## signal
f = function(x) {
  x ^ 3
}

## training data and prediction grid
## (reconstructed; sample size, noise level, and grid spacing are assumptions)
set.seed(42)
sim_trn_data = data.frame(x = runif(n = 100, min = -1, max = 1))
sim_trn_data$y = f(sim_trn_data$x) + rnorm(n = 100, mean = 0, sd = 0.15)
x_grid = data.frame(x = seq(-1, 1, by = 0.001))

# fit models
## tree models
tree_fit_l = rpart(y ~ x, data = sim_trn_data,
                   control = rpart.control(cp = 0.500, minsplit = 2))
tree_fit_m = rpart(y ~ x, data = sim_trn_data,
                   control = rpart.control(cp = 0.015, minsplit = 2))
tree_fit_h = rpart(y ~ x, data = sim_trn_data,
                   control = rpart.control(cp = 0.000, minsplit = 2))
## knn models
knn_fit_l = knn.reg(train = sim_trn_data["x"], y = sim_trn_data$y,
                    test = x_grid, k = 40)
knn_fit_m = knn.reg(train = sim_trn_data["x"], y = sim_trn_data$y,
                    test = x_grid, k = 5)
knn_fit_h = knn.reg(train = sim_trn_data["x"], y = sim_trn_data$y,
                    test = x_grid, k = 1)
## polynomial models
poly_fit_l = lm(y ~ poly(x, 1), data = sim_trn_data)
poly_fit_m = lm(y ~ poly(x, 3), data = sim_trn_data)   # degree assumed
poly_fit_h = lm(y ~ poly(x, 22), data = sim_trn_data)  # degree assumed
# get predictions
## tree models
tree_fit_l_pred = predict(tree_fit_l, newdata = x_grid)
tree_fit_m_pred = predict(tree_fit_m, newdata = x_grid)
tree_fit_h_pred = predict(tree_fit_h, newdata = x_grid)
## knn models
knn_fit_l_pred = knn_fit_l$pred
knn_fit_m_pred = knn_fit_m$pred
knn_fit_h_pred = knn_fit_h$pred
## polynomial models
poly_fit_l_pred = predict(poly_fit_l, newdata = x_grid)
poly_fit_m_pred = predict(poly_fit_m, newdata = x_grid)
poly_fit_h_pred = predict(poly_fit_h, newdata = x_grid)
(Figures: the tree, KNN, and polynomial fits at low, medium, and high flexibility, plotted over the simulated data; plotting code not displayed.)
Chapter 6
Linear Models
First, note that a linear model is one of many methods used in regression.
To discuss linear models in the context of prediction, we return to the
Advertising data from the previous chapter.
Advertising
## # A tibble: 200 x 4
## TV Radio Newspaper Sales
## <dbl> <dbl> <dbl> <dbl>
## 1 230. 37.8 69.2 22.1
(Figure: featurePlot of Sales against TV, Radio, and Newspaper.)
\text{RMSE}\left(\hat{f}, \text{Data}\right) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} \left(y_i - \hat{f}(\mathbf{x}_i)\right)^2}
While for the sake of comparing models, the choice between RMSE and MSE
is arbitrary, we have a preference for RMSE, as it has the same units as the
response variable. Also, notice that in the prediction context MSE refers to an
average, whereas in an ANOVA context, the denominator for MSE may not be
𝑛.
For a linear model, the estimate of f, \hat{f}, is given by the fitted regression line.

\hat{y}(\mathbf{x}_i) = \hat{f}(\mathbf{x}_i)
We can write an R function that will be useful for performing this calculation.
rmse = function(actual, predicted) {
sqrt(mean((actual - predicted) ^ 2))
}
We will look at two measures that assess how well a model is predicting, the
train RMSE and the test RMSE.
\text{RMSE}_{\text{Train}} = \text{RMSE}\left(\hat{f}, \text{Train Data}\right) = \sqrt{\frac{1}{n_{\text{Tr}}} \sum_{i \in \text{Train}} \left(y_i - \hat{f}(\mathbf{x}_i)\right)^2}

Here n_{\text{Tr}} is the number of observations in the train set. Train RMSE will still
always go down (or stay the same) as the complexity of a linear model increases.
That means train RMSE will not be useful for comparing models, but checking
that it decreases is a useful sanity check.
\text{RMSE}_{\text{Test}} = \text{RMSE}\left(\hat{f}, \text{Test Data}\right) = \sqrt{\frac{1}{n_{\text{Te}}} \sum_{i \in \text{Test}} \left(y_i - \hat{f}(\mathbf{x}_i)\right)^2}

Here n_{\text{Te}} is the number of observations in the test set. Test RMSE uses the
model fit to the training data, but evaluated on the unused test data. This is a
measure of how well the fitted model will predict in general, not simply how
well it fits data used to train the model, as is the case with train RMSE. What
happens to test RMSE as the size of the model increases? That is what we will
investigate.
We will start with the simplest possible linear model, that is, a model with no
predictors.
fit_0 = lm(Sales ~ 1, data = train_data)
get_complexity(fit_0)
## [1] 0
# train RMSE
sqrt(mean((train_data$Sales - predict(fit_0, train_data)) ^ 2))
## [1] 5.529258
# test RMSE
sqrt(mean((test_data$Sales - predict(fit_0, test_data)) ^ 2))
## [1] 4.914163
The previous two operations obtain the train and test RMSE. Since these are
operations we are about to use repeatedly, we should use the function that we
happen to have already written.
# train RMSE
rmse(actual = train_data$Sales, predicted = predict(fit_0, train_data))
## [1] 5.529258
# test RMSE
rmse(actual = test_data$Sales, predicted = predict(fit_0, test_data))
## [1] 4.914163
This function can actually be improved for the inputs that we are using. We
would like to obtain train and test RMSE for a fitted model, given a train or
test dataset, and the appropriate response variable.
get_rmse = function(model, data, response) {
rmse(actual = subset(data, select = response, drop = TRUE),
predicted = predict(model, data))
}
By using this function, our code becomes easier to read, and it is more obvious
what task we are accomplishing.
get_rmse(model = fit_0, data = train_data, response = "Sales") # train RMSE
## [1] 5.529258
get_rmse(model = fit_0, data = test_data, response = "Sales") # test RMSE
## [1] 4.914163
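The next model uses all three predictors additively. The exact formula is not shown in this excerpt, but a sketch consistent with the complexity of 3 reported below is:

fit_1 = lm(Sales ~ ., data = train_data)
get_complexity(fit_1)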
## [1] 3
get_rmse(model = fit_1, data = train_data, response = "Sales") # train RMSE
## [1] 1.888488
get_rmse(model = fit_1, data = test_data, response = "Sales") # test RMSE
## [1] 1.461661
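The following fit appears to include all two- and three-way interactions, an assumption consistent with the complexity of 7 below and with fit_3 through fit_5 later.

fit_2 = lm(Sales ~ Radio * Newspaper * TV, data = train_data)
get_complexity(fit_2)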
## [1] 7
get_rmse(model = fit_2, data = train_data, response = "Sales") # train RMSE
## [1] 1.016822
get_rmse(model = fit_2, data = test_data, response = "Sales") # test RMSE
## [1] 0.9117228
fit_3 = lm(Sales ~ Radio * Newspaper * TV + I(TV ^ 2), data = train_data)
get_complexity(fit_3)
## [1] 8
get_rmse(model = fit_3, data = train_data, response = "Sales") # train RMSE
## [1] 0.6553091
get_rmse(model = fit_3, data = test_data, response = "Sales") # test RMSE
## [1] 0.6633375
fit_4 = lm(Sales ~ Radio * Newspaper * TV +
I(TV ^ 2) + I(Radio ^ 2) + I(Newspaper ^ 2), data = train_data)
get_complexity(fit_4)
## [1] 10
get_rmse(model = fit_4, data = train_data, response = "Sales") # train RMSE
## [1] 0.6421909
get_rmse(model = fit_4, data = test_data, response = "Sales") # test RMSE
## [1] 0.7465957
fit_5 = lm(Sales ~ Radio * Newspaper * TV +
I(TV ^ 2) * I(Radio ^ 2) * I(Newspaper ^ 2), data = train_data)
get_complexity(fit_5)
## [1] 14
get_rmse(model = fit_5, data = train_data, response = "Sales") # train RMSE
## [1] 0.6120887
get_rmse(model = fit_5, data = test_data, response = "Sales") # test RMSE
## [1] 0.7864181
To better understand the relationship between train RMSE, test RMSE, and
model complexity, we summarize our results, as the above is somewhat cluttered.
We then obtain train RMSE, test RMSE, and model complexity for each.
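To do so, we first collect the five fits in a list; a sketch, assuming the fit_1 through fit_5 objects above:

model_list = list(fit_1, fit_2, fit_3, fit_4, fit_5)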
train_rmse = sapply(model_list, get_rmse, data = train_data, response = "Sales")
test_rmse = sapply(model_list, get_rmse, data = test_data, response = "Sales")
model_complexity = sapply(model_list, get_complexity)
We then plot the results. The train RMSE can be seen in blue, while the test
RMSE is given in orange.
plot(model_complexity, train_rmse, type = "b",
ylim = c(min(c(train_rmse, test_rmse)) - 0.02,
max(c(train_rmse, test_rmse)) + 0.02),
col = "dodgerblue",
xlab = "Model Size",
ylab = "RMSE")
lines(model_complexity, test_rmse, type = "b", col = "darkorange")
(Figure: train RMSE (blue) and test RMSE (orange) plotted against model size.)
We also summarize the results as a table. fit_1 is the least flexible, and fit_5 is the most flexible. We see the Train RMSE decrease as flexibility increases. We see that the Test RMSE is smallest for fit_3, thus it is the model we believe will perform the best on future data not used to train the model. Note this may not be the best possible model, but it is the best of the models we have seen in this example.
To summarize:
Chapter 7
𝑘-Nearest Neighbors
Chapter Status: Under Construction. Main ideas are in place but lack narrative. Functional versions of much of the code exist but will be cleaned up. Some code and simulation examples need to be expanded.
• TODO: last chapter..
• TODO: recall goal
– frame around estimating regression function
f(x) = \mathbb{E}[Y \mid X = x]

f(x) = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \ldots + \beta_p x_p

\hat{f}(x) = \text{average}\left(\{ y_i : x_i = x \}\right)
7.2.1 Neighbors
• example: knn
7.2.2 Neighborhoods
• example: trees
\hat{f}_k(x) = \frac{1}{k} \sum_{i \in \mathcal{N}_k(x, \mathcal{D})} y_i
7.5 KNN in R
library(FNN)
library(MASS)
data(Boston)
set.seed(42)
boston_idx = sample(1:nrow(Boston), size = 250)
trn_boston = Boston[boston_idx, ]
tst_boston = Boston[-boston_idx, ]
X_trn_boston = trn_boston["lstat"]
X_tst_boston = tst_boston["lstat"]
y_trn_boston = trn_boston["medv"]
y_tst_boston = tst_boston["medv"]
To perform KNN for regression, we will need knn.reg() from the FNN package. Notice that we do not load this package, but instead use FNN::knn.reg to access the function. Note that, in the future, we'll need to be careful about loading the FNN package as it also contains a function called knn. This function also appears in the class package which we will likely use later.
knn.reg(train = ?, test = ?, y = ?, k = ?)
INPUT
OUTPUT
We make predictions for a large number of possible values of lstat, for different
values of k. Note that 250 is the total number of observations in this training
dataset.
(Figure: KNN fits of medv against lstat for k = 1, 5, 10, 25, 50, and 250.)
• TODO: Orange “curves” are 𝑓𝑘̂ (𝑥) where 𝑥 are the values we defined in
lstat_grid. So really a bunch of predictions with interpolated lines, but
you can’t really tell…
7.6 Choosing 𝑘
• low k = very complex model. very wiggly. specifically jagged
• high k = very inflexible model. very smooth.
• want: something in the middle which predicts well on unseen data
• that is, we want \hat{f}_k to minimize \mathbb{E}_{X, Y, \mathcal{D}}\left[\left(Y - \hat{f}_k(X)\right)^2\right]
• TODO: Test MSE is an estimate of this. So finding best test RMSE will
be our strategy. (Best test RMSE is same as best MSE, but with more
understandable units.)
rmse = function(actual, predicted) {
sqrt(mean((actual - predicted) ^ 2))
}
# define helper function for getting knn.reg predictions
# note: this function is highly specific to this situation and dataset
make_knn_pred = function(k = 1, training, predicting) {
pred = FNN::knn.reg(train = training["lstat"],
test = predicting["lstat"],
y = training$medv, k = k)$pred
act = predicting$medv
rmse(predicted = pred, actual = act)
}
# define values of k to evaluate
k = c(1, 5, 10, 25, 50, 250)
# get requested train RMSEs
knn_trn_rmse = sapply(k, make_knn_pred,
training = trn_boston,
predicting = trn_boston)
# get requested test RMSEs
knn_tst_rmse = sapply(k, make_knn_pred,
training = trn_boston,
predicting = tst_boston)
# determine "best" k
best_k = k[which.min(knn_tst_rmse)]
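A sketch of assembling the results for display (the column names here are assumptions):

knn_results = data.frame(
  k = k,
  "Train RMSE" = round(knn_trn_rmse, 2),
  "Test RMSE" = round(knn_tst_rmse, 2),
  check.names = FALSE
)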
# display results
knitr::kable(knn_results, escape = FALSE, booktabs = TRUE)
• TODO: What about ties? Why doesn't k = 1 give 0 training error? There are some non-unique x_i values in the training data. How do we predict when this is the case?
(Figures: simulated examples of linear and non-linear relationships, and the scales of two predictors x1 and x2 before and after applying scale().)
7.11 Interpretability
• TODO: lm (high) vs knn (low)
– somewhat generalizes to parametric vs non-parametric
# test rmse
rmse(predicted = scaled_pred, actual = y_tst_boston) # with scaling
7.13 rmarkdown
The rmarkdown file for this chapter can be found here. The file was created
using R version 4.0.2. The following packages (and their dependencies) were
loaded when knitting this file:
## [1] "MASS" "FNN"
Chapter 8
Bias–Variance Tradeoff
Consider the general regression setup where we are given a random pair (𝑋, 𝑌 ) ∈
ℝ𝑝 × ℝ. We would like to “predict” 𝑌 with some function of 𝑋, say, 𝑓(𝑋).
To clarify what we mean by “predict,” we specify that we would like f(X) to be “close” to Y. To further clarify what we mean by “close,” we define the squared error loss of estimating Y using f(X).

L(Y, f(X)) \triangleq (Y - f(X))^2

Now we can clarify the goal of regression, which is to minimize the above loss, on average. We call this the risk of estimating Y using f(X).

R(Y, f(X)) \triangleq \mathbb{E}[L(Y, f(X))] = \mathbb{E}_{X, Y}\left[(Y - f(X))^2\right]

Before attempting to minimize the risk, we first re-write the risk after conditioning on X. Minimizing the conditional risk for each x shows that the risk is minimized by the conditional mean, the regression function

f(x) = \mathbb{E}(Y \mid X = x)
Note that the choice of squared error loss is somewhat arbitrary. Suppose instead
we chose absolute error loss.
𝑓(𝑥) = median(𝑌 ∣ 𝑋 = 𝑥)
Despite this possibility, our preference will still be for squared error loss. The
reasons for this are numerous, including: historical, ease of optimization, and
protecting against large deviations.
Now, given data \mathcal{D} = (x_i, y_i) \in \mathbb{R}^p \times \mathbb{R}, our goal becomes finding some \hat{f} that is a good estimate of the regression function f. We'll see that this amounts to minimizing what we call the reducible error.

\text{EPE}\left(Y, \hat{f}(X)\right) \triangleq \mathbb{E}_{X, Y, \mathcal{D}}\left[\left(Y - \hat{f}(X)\right)^2\right]

\text{EPE}\left(Y, \hat{f}(x)\right) = \mathbb{E}_{Y \mid X, \mathcal{D}}\left[\left(Y - \hat{f}(X)\right)^2 \mid X = x\right] = \underbrace{\mathbb{E}_{\mathcal{D}}\left[\left(f(x) - \hat{f}(x)\right)^2\right]}_{\text{reducible error}} + \underbrace{\mathbb{V}_{Y \mid X}\left[Y \mid X = x\right]}_{\text{irreducible error}}
\text{bias}\left(\hat{\theta}\right) \triangleq \mathbb{E}\left[\hat{\theta}\right] - \theta
Using this, we further decompose the reducible error (mean squared error) into
bias squared and variance.
\text{MSE}\left(f(x), \hat{f}(x)\right) = \mathbb{E}_{\mathcal{D}}\left[\left(f(x) - \hat{f}(x)\right)^2\right] = \underbrace{\left(f(x) - \mathbb{E}\left[\hat{f}(x)\right]\right)^2}_{\text{bias}^2\left(\hat{f}(x)\right)} + \underbrace{\mathbb{E}\left[\left(\hat{f}(x) - \mathbb{E}\left[\hat{f}(x)\right]\right)^2\right]}_{\text{var}\left(\hat{f}(x)\right)}
This is actually a common fact in estimation theory, but we have stated it here specifically for estimation of some regression function f using \hat{f} at some point x.

\text{MSE}\left(f(x), \hat{f}(x)\right) = \text{bias}^2\left(\hat{f}(x)\right) + \text{var}\left(\hat{f}(x)\right)
It turns out, there is a bias-variance tradeoff. That is, often, the more
bias in our estimation, the lesser the variance. Similarly, less variance is often
accompanied by more bias. Complex models tend to be unbiased, but highly
variable. Simple models are often extremely biased, but have low variance.
• Parametric: The form of the model does not incorporate all the necessary
variables, or the form of the relationship is too simple. For example, a
parametric model assumes a linear relationship, but the true relationship
is quadratic.
• Non-parametric: The model provides too much smoothing.
So for us, to select a model that appropriately balances the tradeoff between
bias and variance, and thus minimizes the reducible error, we need to select a
model of the appropriate complexity for the data.
Recall that when fitting models, we’ve seen that train RMSE decreases as model
complexity is increasing. (Technically it is non-increasing.) For test RMSE, we
expect to see a U-shaped curve. Importantly, test RMSE decreases, until a
certain complexity, then begins to increase.
Now we can understand why this is happening. The expected test RMSE is
essentially the expected prediction error, which we now known decomposes into
(squared) bias, variance, and the irreducible Bayes error. The following plots
show three examples of this.
(Figure: squared bias, variance, Bayes error, and expected prediction error (EPE) plotted against model complexity, in three panels.)
The three plots show three examples of the bias-variance tradeoff. In the left
panel, the variance influences the expected prediction error more than the bias.
In the right panel, the opposite is true. The middle panel is somewhat neutral.
In all cases, the difference between the Bayes error (the horizontal dashed grey
line) and the expected prediction error (the solid black curve) is exactly the mean
squared error, which is the sum of the squared bias (blue curve) and variance
(orange curve). The vertical line indicates the complexity that minimizes the
prediction error.
To summarize, if we assume that irreducible error can be written as

\mathbb{V}[Y \mid X = x] = \sigma^2

then we can write the full decomposition of the expected prediction error of predicting Y using \hat{f} when X = x as

\text{EPE}\left(Y, \hat{f}(x)\right) = \underbrace{\text{bias}^2\left(\hat{f}(x)\right) + \text{var}\left(\hat{f}(x)\right)}_{\text{reducible error}} + \sigma^2.

(Figure: expected test error and train error as functions of model complexity.)
8.3 Simulation
We will illustrate these decompositions, most importantly the bias-variance
tradeoff, through simulation. Suppose we would like to train a model to learn

\mathbb{E}[Y \mid X = x] = f(x) = x^2

and

\mathbb{V}[Y \mid X = x] = \sigma^2.

Equivalently, we could write

Y = f(X) + \epsilon

where \mathbb{E}[\epsilon] = 0 and \mathbb{V}[\epsilon] = \sigma^2. In this formulation, we call f(X) the signal and \epsilon the noise.
To carry out a concrete simulation example, we need to fully specify the data
generating process. We do so with the following R code.
get_sim_data = function(f, sample_size = 100) {
x = runif(n = sample_size, min = 0, max = 1)
y = rnorm(n = sample_size, mean = f(x), sd = 0.3)
data.frame(x, y)
}
Also note that if you prefer to think of this situation using the 𝑌 = 𝑓(𝑋) + 𝜖
formulation, the following code represents the same data generating process.
get_sim_data = function(f, sample_size = 100) {
x = runif(n = sample_size, min = 0, max = 1)
eps = rnorm(n = sample_size, mean = 0, sd = 0.3)  # sd matches the first formulation above
y = f(x) + eps
data.frame(x, y)
}
To completely specify the data generating process, we have made more model
assumptions than simply 𝔼[𝑌 ∣ 𝑋 = 𝑥] = 𝑥2 and 𝕍[𝑌 ∣ 𝑋 = 𝑥] = 𝜎2 . In
particular,
• The 𝑥𝑖 in 𝒟 are sampled from a uniform distribution over [0, 1].
• The 𝑥𝑖 and 𝜖 are independent.
• The 𝑦𝑖 in 𝒟 are sampled from the conditional normal distribution.
𝑌 ∣ 𝑋 ∼ 𝑁 (𝑓(𝑥), 𝜎2 )
Using this setup, we will generate datasets, 𝒟, with a sample size 𝑛 = 100 and
fit four models.
To get a sense of the data and these four models, we generate one simulated
dataset, and fit the four models.
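The signal f here is the quadratic mean function from above, which we can define as:

f = function(x) {
  x ^ 2   # the true mean function E[Y | X = x]
}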
set.seed(1)
sim_data = get_sim_data(f)
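The four fits to this single dataset are presumably the same four models used in the simulation study later in the chapter:

fit_0 = lm(y ~ 1, data = sim_data)
fit_1 = lm(y ~ poly(x, degree = 1), data = sim_data)
fit_2 = lm(y ~ poly(x, degree = 2), data = sim_data)
fit_9 = lm(y ~ poly(x, degree = 9), data = sim_data)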
Note that technically we’re being lazy and using orthogonal polynomials, but
the fitted values are the same, so this makes no difference for our purposes.
Plotting these four trained models, we see that the zero predictor model does
very poorly. The first degree model is reasonable, but we can see that the second
degree model fits much better. The ninth degree model seem rather wild.
(Figure: the simulated dataset with the four fitted models, y ~ 1, y ~ poly(x, 1), y ~ poly(x, 2), and y ~ poly(x, 9), along with the truth.)
The following three plots were created using three additional simulated datasets.
The zero predictor and ninth degree polynomial were fit to each.
(Figure: the zero predictor and degree 9 polynomial fits on three additional simulated datasets.)
This plot should make clear the difference between the bias and variance of
these two models. The zero predictor model is clearly wrong, that is, biased,
but nearly the same for each of the datasets, since it has very low variance.
While the ninth degree model doesn’t appear to be correct for any of these three
simulations, we’ll see that on average it is, and thus is performing unbiased
estimation. These plots do however clearly illustrate that the ninth degree
polynomial is extremely variable. Each dataset results in a very different fitted
model. Correct on average isn’t the only goal we’re after, since in practice, we’ll
only have a single dataset. This is why we’d also like our models to exhibit low
variance.
We could have also fit 𝑘-nearest neighbors models to these three datasets.
(Figure: KNN fits with k = 5 and k = 100 on the three simulated datasets.)
Here we see that when 𝑘 = 100 we have a biased model with very low variance.
(It’s actually the same as the 0 predictor linear model.) When 𝑘 = 5, we again
have a highly variable model.
These two sets of plots reinforce our intuition about the bias-variance tradeoff.
Complex models (ninth degree polynomial and 𝑘 = 5) are highly variable, and
often unbiased. Simple models (zero predictor linear model and 𝑘 = 100) are
very biased, but have extremely low variance.
We will now complete a simulation study to understand the relationship between
the bias, variance, and mean squared error for the estimates for 𝑓(𝑥) given by
these four models at the point 𝑥 = 0.90. We use simulation to complete this
task, as performing the analytical calculations would prove to be rather tedious
and difficult.
set.seed(1)
n_sims = 250
n_models = 4
x = data.frame(x = 0.90) # fixed point at which we make predictions
predictions = matrix(0, nrow = n_sims, ncol = n_models)

# repeat the simulation n_sims times (loop opening reconstructed)
for (sim in 1:n_sims) {

  # simulate a new training dataset each iteration
  sim_data = get_sim_data(f)

  # fit models
  fit_0 = lm(y ~ 1, data = sim_data)
  fit_1 = lm(y ~ poly(x, degree = 1), data = sim_data)
  fit_2 = lm(y ~ poly(x, degree = 2), data = sim_data)
  fit_9 = lm(y ~ poly(x, degree = 9), data = sim_data)

  # get predictions
  predictions[sim, 1] = predict(fit_0, x)
  predictions[sim, 2] = predict(fit_1, x)
  predictions[sim, 3] = predict(fit_2, x)
  predictions[sim, 4] = predict(fit_9, x)
}
Note that this is one of many ways we could have accomplished this task using R.
For example we could have used a combination of replicate() and *apply()
functions. Alternatively, we could have used a tidyverse approach, which likely
would have used some combination of dplyr, tidyr, and purrr.
Our approach, which would be considered a base R approach, was chosen to
make it as clear as possible what is being done. The tidyverse approach is
rapidly gaining popularity in the R community, but might make it more diffi-
cult to see what is happening here, unless you are already familiar with that
approach.
Also of note, while it may seem like the output stored in predictions would
meet the definition of tidy data given by Hadley Wickham since each row rep-
resents a simulation, it actually falls slightly short. For our data to be tidy, a
row should store the simulation number, the model, and the resulting predic-
tion. We’ve actually already aggregated one level above this. Our observational
unit is a simulation (with four predictions), but for tidy data, it should be a
single prediction. This may be revised by the author later when there are more
examples of how to do this from the R community.
(Figure: the simulated predictions at x = 0.90 for each polynomial degree, 0, 1, 2, and 9.)
The above plot shows the predictions for each of the 250 simulations of each of the four models of different polynomial degrees. The truth is f(x = 0.90) = 0.90^2 = 0.81.
\text{MSE}\left(f(0.90), \hat{f}_k(0.90)\right) = \underbrace{\left(\mathbb{E}\left[\hat{f}_k(0.90)\right] - f(0.90)\right)^2}_{\text{bias}^2\left(\hat{f}_k(0.90)\right)} + \underbrace{\mathbb{E}\left[\left(\hat{f}_k(0.90) - \mathbb{E}\left[\hat{f}_k(0.90)\right]\right)^2\right]}_{\text{var}\left(\hat{f}_k(0.90)\right)}
We’ll use the empirical results of our simulations to estimate these quantities.
(Yes, we’re using estimation to justify facts about estimation.) Note that we’ve
actually used a rather small number of simulations. In practice we should use
more, but for the sake of computation time, we’ve performed just enough sim-
ulations to obtain the desired results. (Since we’re estimating estimation, the
bigger the sample size, the better.)
To estimate the mean squared error of our predictions, we'll use

\widehat{\text{MSE}}\left(f(0.90), \hat{f}_k(0.90)\right) = \frac{1}{n_{\text{sims}}} \sum_{i=1}^{n_{\text{sims}}} \left(f(0.90) - \hat{f}_k(0.90)\right)^2

We also estimate the bias and variance of these predictions using

\widehat{\text{bias}}\left(\hat{f}(0.90)\right) = \frac{1}{n_{\text{sims}}} \sum_{i=1}^{n_{\text{sims}}} \left(\hat{f}_k(0.90)\right) - f(0.90)

\widehat{\text{var}}\left(\hat{f}(0.90)\right) = \frac{1}{n_{\text{sims}}} \sum_{i=1}^{n_{\text{sims}}} \left(\hat{f}_k(0.90) - \frac{1}{n_{\text{sims}}} \sum_{i=1}^{n_{\text{sims}}} \hat{f}_k(0.90)\right)^2
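The helper functions get_bias() and get_mse() used with apply() below are not defined in this excerpt; sketches consistent with the estimators above and their usage are:

get_bias = function(estimate, truth) {
  mean(estimate) - truth
}
get_mse = function(estimate, truth) {
  mean((estimate - truth) ^ 2)
}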
While there is already an R function for variance, the following is more appropriate in this situation.
get_var = function(estimate) {
mean((estimate - mean(estimate)) ^ 2)
}
To quickly obtain these results for each of the four models, we utilize the apply()
function.
bias = apply(predictions, 2, get_bias, truth = f(x = 0.90))
variance = apply(predictions, 2, get_var)
mse = apply(predictions, 2, get_mse, truth = f(x = 0.90))
## [1] TRUE
all(diff(variance) > 0)
## [1] TRUE
diff(mse) < 0
## 1 2 9
## TRUE TRUE FALSE
The models with polynomial degrees 2 and 9 are both essentially unbiased. We
see some bias here as a result of using simulation. If we increased the number of
simulations, we would see both biases go down. Since they are both unbiased,
the model with degree 2 outperforms the model with degree 9 due to its smaller
variance.
Models with degree 0 and 1 are biased because they assume the wrong form of the regression function. The degree 9 model uses more terms than necessary, but it does include all the necessary polynomial degrees, so it is still unbiased. Since

\mathbb{E}[\hat{\beta}_d] = \beta_d = 0

for d = 3, 4, \ldots, 9, we have

\mathbb{E}\left[\hat{f}_9(x)\right] = \beta_0 + \beta_1 x + \beta_2 x^2
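A naive check of the decomposition with exact equality, presumably something like the following, produces the output below.

bias ^ 2 + variance == mse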
## 0 1 2 9
## FALSE FALSE FALSE FALSE
But wait, this says it isn’t true, except for the degree 9 model? It turns out, this
is simply a computational issue. If we allow for some very small error tolerance,
we see that the bias-variance decomposition is indeed true for predictions from
these for models.
all.equal(bias ^ 2 + variance, mse)
## [1] TRUE
See ?all.equal() for details.
So far, we've focused our efforts on looking at the mean squared error of estimating f(0.90) using \hat{f}(0.90). We could also look at the expected prediction error of using \hat{f}(X) when X = 0.90 to estimate Y.

\text{EPE}\left(Y, \hat{f}_k(0.90)\right) = \mathbb{E}_{Y \mid X, \mathcal{D}}\left[\left(Y - \hat{f}_k(X)\right)^2 \mid X = 0.90\right]
We can estimate this quantity for each of the four models using the simulation
study we already performed.
get_epe = function(realized, estimate) {
mean((realized - estimate) ^ 2)
}
y = rnorm(n = nrow(predictions), mean = f(x = 0.9), sd = 0.3)
epe = apply(predictions, 2, get_epe, realized = y)
epe
## 0 1 2 9
## 0.3180470 0.1104055 0.1095955 0.1205570
What about the unconditional expected prediction error, that is, for any X, not just 0.90? Specifically, we want the expected prediction error of estimating Y using \hat{f}(X). The following (new) simulation study provides an estimate of

\text{EPE}\left(Y, \hat{f}_k(X)\right) = \mathbb{E}_{X, Y, \mathcal{D}}\left[\left(Y - \hat{f}_k(X)\right)^2\right]
# setup reconstructed (values assumed): new (X, Y) pairs from the data generating process
set.seed(1)
X = runif(n = 1000, min = 0, max = 1)
Y = rnorm(n = 1000, mean = f(X), sd = 0.3)
f_hat_X = rep(0, length(X))
for (i in seq_along(X)) {
  sim_data = get_sim_data(f)
  fit_2 = lm(y ~ poly(x, degree = 2), data = sim_data)
  f_hat_X[i] = predict(fit_2, newdata = data.frame(x = X[i]))
}
mean((Y - f_hat_X) ^ 2)
## [1] 0.09997319
Note that in practice, we should use many more simulations in this study.
8.4 Estimating Expected Prediction Error

Assuming, as before, that

\mathbb{V}[Y \mid X = x] = \sigma^2,

we have

\text{EPE}\left(Y, \hat{f}(X)\right) = \mathbb{E}_{X, Y, \mathcal{D}}\left[\left(Y - \hat{f}(X)\right)^2\right] = \underbrace{\mathbb{E}_X\left[\text{bias}^2\left(\hat{f}(X)\right)\right] + \mathbb{E}_X\left[\text{var}\left(\hat{f}(X)\right)\right]}_{\text{reducible error}} + \sigma^2
Suppose the full data \mathcal{D} is split into a disjoint train set \mathcal{D}_{\text{trn}} and test set \mathcal{D}_{\text{tst}}.
Then, if we use 𝒟trn to fit (train) a model, we can use the test mean squared
error
\frac{1}{n_{\text{tst}}} \sum_{i \in \text{tst}} \left(y_i - \hat{f}(x_i)\right)^2

as an estimate of

\mathbb{E}_{X, Y, \mathcal{D}}\left[\left(Y - \hat{f}(X)\right)^2\right],
the expected prediction error. (In practice we prefer RMSE to MSE for com-
paring models and reporting because of the units.)
How good is this estimate? Well, if \mathcal{D} is a random sample from (X, Y), and the test observations are randomly sampled from i = 1, 2, \ldots, n, then it is a reasonable estimate. However, it is rather variable due to the
randomness of selecting the observations for the test set. How variable? It turns
out, pretty variable. While it’s a justified estimate, eventually we’ll introduce
cross-validation as a procedure better suited to performing this estimation to
select a model.
8.5 rmarkdown
The rmarkdown file for this chapter can be found here. The file was created
using R version 4.0.2.
Part III
Classification
Chapter 9
Overview
That is, the classifier \hat{C}(x) returns the predicted category \hat{y}(x).

\hat{y}(x) = \hat{C}(x)
To build our first classifier, we will use the Default dataset from the ISLR
package.
library(ISLR)
library(tibble)
as_tibble(Default)
## # A tibble: 10,000 x 4
## default student balance income
## <fct> <fct> <dbl> <dbl>
## 1 No No 730. 44362.
## 2 No Yes 817. 12106.
## 3 No No 1074. 31767.
## 4 No No 529. 35704.
## 5 No No 786. 38463.
## 6 No Yes 920. 7492.
## 7 No No 826. 24905.
## 8 No Yes 809. 17600.
## 9 No No 1161. 37469.
## 10 No No 0 29275.
## # ... with 9,990 more rows
Our goal is to properly classify individuals as defaulters based on student status,
credit card balance, and income. Be aware that the response default is a factor,
as is the predictor student.
is.factor(Default$default)
## [1] TRUE
is.factor(Default$student)
## [1] TRUE
As we did with regression, we test-train split our data. In this case, using 50%
for each.
set.seed(42)
default_idx = sample(nrow(Default), 5000)
default_trn = Default[default_idx, ]
default_tst = Default[-default_idx, ]
A density plot can often suggest a simple split based on a numeric predictor. Essentially this plot graphs a density estimate

\hat{f}_{X_i}(x_i \mid Y = k)

for each numeric feature, within each class.

featurePlot(x = default_trn[ , c("balance", "income")],
            y = default_trn$default,
            plot = "density",  # leading arguments are an assumed reconstruction
            pch = "|",
            layout = c(2, 1),
            auto.key = list(columns = 2))
(Figures: density plots of balance and income, first grouped by default status, then by student status.)
Above, we create a similar plot, except with student as the response. We see
that students often carry a slightly larger balance, and have far lower income.
This will be useful to know when making more complicated classifiers.
featurePlot(x = default_trn[, c("student", "balance", "income")],
y = default_trn$default,
plot = "pairs",
auto.key = list(columns = 2))
No Yes
40000 60000
60000
40000
income
20000
0 20000
0
2500 1500 2500
2000
1500
balance
1000
500
0 500 1000
0
Yes
Yes
student
No
No
We can use plot = "pairs" to consider multiple variables at the same time.
This plot reinforces using balance to create a classifier, and again shows that
income seems not that useful.
library(ellipse)
featurePlot(x = default_trn[, c("balance", "income")],
y = default_trn$default,
plot = "ellipse",
auto.key = list(columns = 2))
[Figure: ellipse plot of balance and income, grouped by default status.]
Similar to pairs is a plot of type ellipse, which requires the ellipse package.
Here we only use numeric predictors, as essentially we are assuming multivariate
normality. The ellipses mark points of equal density. This will be useful later
when discussing LDA and QDA.

Based on the first plot, we believe we can use balance to create a reasonable classifier. In particular,

$$\hat{C}(x) = \begin{cases} 1 & x > b \\ 0 & x \leq b \end{cases}$$
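The classifier code itself is not shown here; a minimal sketch, assuming a cutoff chosen by eye from the density plot (the value 1400 and the helper name simple_class are assumptions for illustration):

simple_class = function(x, boundary, above = "Yes", below = "No") {
  # classify to "Yes" when the feature exceeds the chosen boundary
  ifelse(x > boundary, above, below)
}

default_trn_pred = simple_class(default_trn$balance, boundary = 1400)
default_tst_pred = simple_class(default_tst$balance, boundary = 1400)
head(default_tst_pred, n = 10)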
## [1] "No" "No" "No" "No" "No" "No" "No" "No" "No" "No"
(trn_tab = table(predicted = default_trn_pred, actual = default_trn$default))
## actual
## predicted No Yes
## No 4354 25
## Yes 476 145
(tst_tab = table(predicted = default_tst_pred, actual = default_tst$default))
## actual
## predicted No Yes
## No 4326 27
## Yes 511 136
Often we give specific names to individual cells of these tables, and in the predictive setting, we would call this table a confusion matrix. Be aware that the placement of Actual and Predicted values affects the names of the cells, and often the matrix may be presented transposed.
In statistics, we label the errors Type I and Type II, but these are hard to
remember. False Positive and False Negative are more descriptive, so we choose
to use these.
The confusionMatrix() function from the caret package can be used to obtain
a wealth of additional information, which we see output below for the test data.
Note that we specify which category is considered “positive.”
trn_con_mat = confusionMatrix(trn_tab, positive = "Yes")
(tst_con_mat = confusionMatrix(tst_tab, positive = "Yes"))
The most basic metric for a classifier is the classification error rate,

$$\text{err}(\hat{C}, \text{Data}) = \frac{1}{n} \sum_{i = 1}^{n} I\left(y_i \neq \hat{C}(x_i)\right)$$

where the indicator function is

$$I\left(y_i \neq \hat{C}(x)\right) = \begin{cases} 1 & y_i \neq \hat{C}(x) \\ 0 & y_i = \hat{C}(x) \end{cases}$$
It is also common to discuss the accuracy, which is simply one minus the error.
Like regression, we often split the data, and then use Train (Classification) Error and Test (Classification) Error as measures of how well a classifier will work on unseen future data.
$$\text{err}_{\text{trn}}(\hat{C}, \text{Train Data}) = \frac{1}{n_{\text{trn}}} \sum_{i \in \text{trn}} I\left(y_i \neq \hat{C}(x_i)\right)$$

$$\text{err}_{\text{tst}}(\hat{C}, \text{Test Data}) = \frac{1}{n_{\text{tst}}} \sum_{i \in \text{tst}} I\left(y_i \neq \hat{C}(x_i)\right)$$
1 - trn_con_mat$overall["Accuracy"]
## Accuracy
## 0.1002
1 - tst_con_mat$overall["Accuracy"]
## Accuracy
## 0.1076
Sometimes guarding against making certain errors, FP or FN, is more important than simply finding the best accuracy. Thus, sometimes we will consider sensitivity and specificity.
$$\text{Sens} = \text{True Positive Rate} = \frac{\text{TP}}{\text{P}} = \frac{\text{TP}}{\text{TP} + \text{FN}}$$
tst_con_mat$byClass["Sensitivity"]
## Sensitivity
## 0.8343558
$$\text{Spec} = \text{True Negative Rate} = \frac{\text{TN}}{\text{N}} = \frac{\text{TN}}{\text{TN} + \text{FP}}$$
tst_con_mat$byClass["Specificity"]
## Specificity
## 0.894356
Like accuracy, these can easily be found using confusionMatrix().
When considering how well a classifier is performing, it is tempting to assume that any accuracy above 0.50 in a binary classification problem indicates a reasonable classifier. This, however, is not the case. We need to consider the balance of the classes. To do so, we look at the prevalence of positive cases.
$$\text{Prev} = \frac{\text{P}}{\text{Total Obs}} = \frac{\text{TP} + \text{FN}}{\text{Total Obs}}$$
trn_con_mat$byClass["Prevalence"]
## Prevalence
## 0.034
tst_con_mat$byClass["Prevalence"]
## Prevalence
## 0.0326
Here, we see an extremely low prevalence, which suggests considering an even simpler classifier than our current one based on balance.
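The even simpler classifier, which always predicts "No", is not shown in code here; a sketch (the object name pred_all_no is hypothetical):

pred_all_no = rep("No", length(default_tst$default))
table(predicted = pred_all_no, actual = default_tst$default)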
## actual
## predicted No Yes
## No 4837 163
The confusionMatrix() function won’t even accept this table as input, because
it isn’t a full matrix, only one row, so we calculate error rates directly. To do
so, we write a function.
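A minimal sketch of such a function (the same helper, calc_class_err(), is defined in the k-nearest neighbors chapter):

calc_class_err = function(actual, predicted) {
  # proportion of predictions that do not match the truth
  mean(actual != predicted)
}
calc_class_err(actual = default_tst$default, predicted = pred_all_no)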
## [1] 0.0326
Here we see that the error rate is exactly the prevalence of the minority class.
table(default_tst$default) / length(default_tst$default)
##
## No Yes
## 0.9674 0.0326
This classifier does better than the previous one. But the point is that, in reality, to create a good classifier, we should obtain a test error better than 0.033, which is achievable by simply predicting the majority class. In the next chapter, we’ll introduce much better classifiers which should have no problem accomplishing this task.
9.4 rmarkdown
The rmarkdown file for this chapter can be found here. The file was created
using R version 4.0.2. The following packages (and their dependencies) were
loaded when knitting this file:
## [1] "ellipse" "caret" "ggplot2" "lattice" "tibble" "ISLR"
Chapter 10
Logistic Regression
## # A tibble: 10,000 x 4
## default student balance income
## <fct> <fct> <dbl> <dbl>
## 1 No No 730. 44362.
## 2 No Yes 817. 12106.
## 3 No No 1074. 31767.
## 4 No No 529. 35704.
## 5 No No 786. 38463.
## 6 No Yes 920. 7492.
## 7 No No 826. 24905.
## 8 No Yes 809. 17600.
## 9 No No 1161. 37469.
## 10 No No 0 29275.
## # ... with 9,990 more rows
Since linear regression expects a numeric response variable, we coerce the response to be numeric. (Notice that we also shift the results, as we require 0 and 1, not 1 and 2.) We have also copied the dataset so that we can return to the original data, with its factors, later.
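The copying step itself is not shown here; it amounts to:

default_trn_lm = default_trn
default_tst_lm = default_tst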
default_trn_lm$default = as.numeric(default_trn_lm$default) - 1
default_tst_lm$default = as.numeric(default_tst_lm$default) - 1

Using least squares, a linear regression would then estimate

$$\hat{\mathbb{E}}[Y \mid X = x] = X\hat{\beta},$$

and since the response takes only the values 0 and 1,

$$\mathbb{E}[Y \mid X = x] = P(Y = 1 \mid X = x).$$
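The model fit and plot code are not shown here; a minimal sketch, assuming balance as the only predictor and mirroring the logistic regression plot later in the chapter:

# linear regression on the 0/1 coded response
model_lm = lm(default ~ balance, data = default_trn_lm)
plot(default ~ balance, data = default_trn_lm,
     col = "darkorange", pch = "|", ylim = c(-0.2, 1),
     main = "Using Linear Regression for Classification")
abline(h = 0, lty = 3)
abline(h = 1, lty = 3)
abline(model_lm, lwd = 3, col = "dodgerblue")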
[Figure: linear regression fit to the 0/1 coded default response, plotted against balance.]
Two issues arise. First, all of the predicted probabilities are below 0.5. That means we would classify every observation as a "No". This is certainly possible, but not what we would expect.
all(predict(model_lm) < 0.5)
## [1] TRUE
The next, and bigger issue, is predicted probabilities less than 0.
any(predict(model_lm) < 0)
## [1] TRUE
$$C^B(x) = \underset{g}{\operatorname{argmax}} \; P(Y = g \mid X = x)$$
To approximate the Bayes classifier, we could estimate

$$\hat{p}(x) = \hat{P}(Y = 1 \mid X = x)$$

and

$$\hat{P}(Y = 0 \mid X = x)$$

and then classify to the larger of the two. We actually only need to consider a single probability, usually $\hat{P}(Y = 1 \mid X = x)$. Since we use it so often, we give it the shorthand notation $\hat{p}(x)$. Then the classifier is written

$$\hat{C}(x) = \begin{cases} 1 & \hat{p}(x) > 0.5 \\ 0 & \hat{p}(x) \leq 0.5 \end{cases}$$
To model

$$p(x) = P(Y = 1 \mid X = x)$$

we turn to logistic regression. The model is written

$$\log\left(\frac{p(x)}{1 - p(x)}\right) = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \cdots + \beta_p x_p.$$
$$p(x) = \frac{1}{1 + e^{-(\beta_0 + \beta_1 x_1 + \beta_2 x_2 + \cdots + \beta_p x_p)}} = \sigma(\beta_0 + \beta_1 x_1 + \beta_2 x_2 + \cdots + \beta_p x_p)$$
Notice that we use the sigmoid function as shorthand notation, which appears often in the deep learning literature. It takes any real input and outputs a number between 0 and 1. How useful! (This is actually a particular sigmoid function called the logistic function, but since it is by far the most popular sigmoid function, “sigmoid function” is often used to refer to the logistic function.)
$$\sigma(x) = \frac{e^x}{1 + e^x} = \frac{1}{1 + e^{-x}}$$
The model is fit by numerically maximizing the likelihood, which we will let R
take care of.
We start with a single predictor example, again using balance as our single
predictor.
Fitting this model looks very similar to fitting a simple linear regression. In-
stead of lm() we use glm(). The only other difference is the use of family
= "binomial" which indicates that we have a two-class categorical response.
Using glm() with family = "gaussian" would perform the usual linear regres-
sion.
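The fit itself is not shown here; assuming balance as the single predictor, it would be, roughly:

model_glm = glm(default ~ balance, data = default_trn, family = "binomial")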
First, we can obtain the fitted coefficients the same way we did with linear
regression.
coef(model_glm)
## (Intercept) balance
## -10.493158288 0.005424994
The next thing we should understand is how the predict() function works with
glm(). So, let’s look at some predictions.
head(predict(model_glm))
By default, predict() for a logistic regression returns predictions on the scale of the linear predictor (the log-odds), not the probabilities

$$\hat{p}(x) = \hat{P}(Y = 1 \mid X = x),$$

which instead require type = "response". The default output can still be used directly for classification, since a log-odds greater than 0 corresponds to a probability greater than 0.5. That is, we can classify using

$$\hat{C}(x) = \begin{cases} 1 & \hat{f}(x) > 0 \\ 0 & \hat{f}(x) \leq 0 \end{cases}$$

where

$$\hat{f}(x) = \hat{\beta}_0 + \hat{\beta}_1 x_1 + \hat{\beta}_2 x_2 + \cdots + \hat{\beta}_p x_p.$$
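The classification code referred to below, including its commented alternative, is not shown here; a sketch consistent with the discussion:

model_glm_pred = ifelse(predict(model_glm, type = "link") > 0, "Yes", "No")
# model_glm_pred = ifelse(predict(model_glm, type = "response") > 0.5, "Yes", "No")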
The commented line, which would give the same results, is performing

$$\hat{C}(x) = \begin{cases} 1 & \hat{p}(x) > 0.5 \\ 0 & \hat{p}(x) \leq 0.5 \end{cases}$$

where

$$\hat{p}(x) = \hat{P}(Y = 1 \mid X = x).$$
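The train error reported below would be obtained with a call along these lines:

calc_class_err(actual = default_trn$default, predicted = model_glm_pred)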
## [1] 0.0284
As we saw previously, the table() and confusionMatrix() functions can be
used to quickly obtain many more metrics.
train_tab = table(predicted = model_glm_pred, actual = default_trn$default)
library(caret)
train_con_mat = confusionMatrix(train_tab, positive = "Yes")
c(train_con_mat$overall["Accuracy"],
train_con_mat$byClass["Sensitivity"],
train_con_mat$byClass["Specificity"])
We could also write a custom function for the error for use with trained logistic regression models.
get_logistic_error = function(mod, data, res = "y", pos = 1, neg = 0, cut = 0.5) {
  probs = predict(mod, newdata = data, type = "response")
  preds = ifelse(probs > cut, pos, neg)
  calc_class_err(actual = data[, res], predicted = preds)
}
This function will be useful later when calculating train and test errors for
several models at the same time.
get_logistic_error(model_glm, data = default_trn,
res = "default", pos = "Yes", neg = "No", cut = 0.5)
## [1] 0.0284
To see how much better logistic regression is for this task, we create the same
plot we used for linear regression.
plot(default ~ balance, data = default_trn_lm,
col = "darkorange", pch = "|", ylim = c(-0.2, 1),
main = "Using Logistic Regression for Classification")
abline(h = 0, lty = 3)
abline(h = 1, lty = 3)
abline(h = 0.5, lty = 2)
curve(predict(model_glm, data.frame(balance = x), type = "response"),
add = TRUE, lwd = 3, col = "dodgerblue")
abline(v = -coef(model_glm)[1] / coef(model_glm)[2], lwd = 2)
[Figure: logistic regression fit for default versus balance, showing the fitted sigmoid curve and the estimated decision boundary.]
The vertical line marks the decision boundary, where

$$\hat{p}(x) = \hat{P}(Y = 1 \mid X = x) = 0.5,$$

which on the log-odds scale is

$$\hat{\beta}_0 + \hat{\beta}_1 x_1 = 0.$$

Thus, for logistic regression with a single predictor, the decision boundary is given by the point

$$x_1 = \frac{-\hat{\beta}_0}{\hat{\beta}_1}.$$
The following is not run, but an alternative way to add the logistic curve to the
plot.
grid = seq(0, max(default_trn$balance), by = 0.01)
sigmoid = function(x) {
  1 / (1 + exp(-x))
}
# assumed final step: draw the sigmoid of the fitted linear predictor over the grid
lines(grid, sigmoid(coef(model_glm)[1] + coef(model_glm)[2] * grid), lwd = 3)
Using the usual formula syntax, it is easy to add or remove complexity from
logistic regressions.
model_1 = glm(default ~ 1, data = default_trn, family = "binomial")
model_2 = glm(default ~ ., data = default_trn, family = "binomial")
model_3 = glm(default ~ . ^ 2 + I(balance ^ 2),
data = default_trn, family = "binomial")
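The error calculations that produce the results discussed next are not shown here; a sketch using the get_logistic_error() helper defined above:

model_list = list(model_1, model_2, model_3)
train_errors = sapply(model_list, get_logistic_error, data = default_trn,
                      res = "default", pos = "Yes", neg = "No", cut = 0.5)
test_errors  = sapply(model_list, get_logistic_error, data = default_tst,
                      res = "default", pos = "Yes", neg = "No", cut = 0.5)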
Here we see the misclassification error rates for each model. The train error decreases as complexity grows, while the test error decreases and then starts to increase. Everything we learned about the bias-variance tradeoff for regression also applies here.
diff(train_errors)

To move away from the default cutoff of 0.5, we could instead classify using a general cutoff $c$,

$$\hat{C}(x) = \begin{cases} 1 & \hat{p}(x) > c \\ 0 & \hat{p}(x) \leq c \end{cases}$$
Let’s use this to obtain predictions using a low, medium, and high cutoff. (0.1,
0.5, and 0.9)
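The helper used below, get_logistic_pred(), is not defined here; a minimal sketch, mirroring get_logistic_error() above:

get_logistic_pred = function(mod, data, res = "y", pos = 1, neg = 0, cut = 0.5) {
  probs = predict(mod, newdata = data, type = "response")
  ifelse(probs > cut, pos, neg)
}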
test_pred_10 = get_logistic_pred(model_glm, data = default_tst, res = "default",
pos = "Yes", neg = "No", cut = 0.1)
test_pred_50 = get_logistic_pred(model_glm, data = default_tst, res = "default",
pos = "Yes", neg = "No", cut = 0.5)
test_pred_90 = get_logistic_pred(model_glm, data = default_tst, res = "default",
pos = "Yes", neg = "No", cut = 0.9)
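Evaluating each cutoff requires a confusion matrix for each set of predictions; that code is not shown here, but it would be along these lines:

test_tab_10 = table(predicted = test_pred_10, actual = default_tst$default)
test_tab_50 = table(predicted = test_pred_50, actual = default_tst$default)
test_tab_90 = table(predicted = test_pred_90, actual = default_tst$default)

test_con_mat_10 = confusionMatrix(test_tab_10, positive = "Yes")
test_con_mat_50 = confusionMatrix(test_tab_50, positive = "Yes")
test_con_mat_90 = confusionMatrix(test_tab_90, positive = "Yes")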
metrics = rbind(
  c(test_con_mat_10$overall["Accuracy"],
    test_con_mat_10$byClass["Sensitivity"],
    test_con_mat_10$byClass["Specificity"]),
  c(test_con_mat_50$overall["Accuracy"],
    test_con_mat_50$byClass["Sensitivity"],
    test_con_mat_50$byClass["Specificity"]),
  c(test_con_mat_90$overall["Accuracy"],
    test_con_mat_90$byClass["Sensitivity"],
    test_con_mat_90$byClass["Specificity"])
)
Note that usually the best accuracy will be seen near 𝑐 = 0.50.
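The ROC curve below is produced with the pROC package; the exact call is not shown here, but a sketch is:

library(pROC)
test_prob = predict(model_glm, newdata = default_tst, type = "response")
test_roc = roc(default_tst$default ~ test_prob, plot = TRUE, print.auc = TRUE)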
[Figure: ROC curve for the logistic regression classifier on the test data, with AUC: 0.949 printed on the plot.]
as.numeric(test_roc$auc)
## [1] 0.9492866
A good model will have a high AUC, that is, as often as possible, a high sensitivity and specificity simultaneously.
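When the response has $G$ possible categories, multinomial logistic regression models the class probabilities directly. The display the following paragraph refers to is not reproduced here; in its standard softmax form it is

$$P(Y = g \mid X = x) = \frac{e^{\beta_{0g} + \beta_{1g} x_1 + \cdots + \beta_{pg} x_p}}{\sum_{h = 1}^{G} e^{\beta_{0h} + \beta_{1h} x_1 + \cdots + \beta_{ph} x_p}}$$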
We will omit the details, as ISL has as well. If you are interested, the Wikipedia
page provides a rather thorough coverage. Also note that the above is an exam-
ple of the softmax function.
As an example of a dataset with a three category response, we use the iris
dataset, which is so famous, it has its own Wikipedia entry. It is also a default
dataset in R, so no need to load it.
Before proceeding, we test-train split this data.
set.seed(430)
iris_obs = nrow(iris)
iris_idx = sample(iris_obs, size = trunc(0.50 * iris_obs))
iris_trn = iris[iris_idx, ]
iris_test = iris[-iris_idx, ]
10.6 rmarkdown
The rmarkdown file for this chapter can be found here. The file was created
using R version 4.0.2. The following packages (and their dependencies) were
loaded when knitting this file:
## [1] "nnet" "pROC" "caret" "ggplot2" "lattice" "tibble" "ISLR"
Chapter 11
Generative Models
Each generative method in this chapter applies Bayes' theorem,

$$p_k(x) = P(Y = k \mid X = x) = \frac{\pi_k \cdot f_k(x)}{\sum_{g = 1}^{G} \pi_g \cdot f_g(x)}$$
We call 𝑝𝑘 (𝑥) the posterior probability, which we will estimate then use to
create classifications. The 𝜋𝑔 are called the prior probabilities for each possible
class 𝑔. That is, 𝜋𝑔 = 𝑃 (𝑌 = 𝑔), unconditioned on 𝑋. The 𝑓𝑔 (𝑥) are called the
likelihoods, which are indexed by 𝑔 to denote that they are conditional on the
classes. The denominator is often referred to as a normalizing constant.
The methods will differ by placing different modeling assumptions on the like-
lihoods, 𝑓𝑔 (𝑥). For each method, the priors could be learned from data or
pre-specified.
For each method, classifications are made to the class with the highest estimated posterior probability, which is equivalent to the class with the largest $\log(\hat{\pi}_k \cdot \hat{f}_k(x))$.
To illustrate these new methods, we return to the iris data, which you may
remember has three classes. After a test-train split, we create a number of plots
to refresh our memory.
set.seed(430)
iris_obs = nrow(iris)
iris_idx = sample(iris_obs, size = trunc(0.50 * iris_obs))
# iris_index = sample(iris_obs, size = trunc(0.10 * iris_obs))
iris_trn = iris[iris_idx, ]
iris_tst = iris[-iris_idx, ]
[Figures: density plots and a pairs plot of the four iris measurements (Sepal.Length, Sepal.Width, Petal.Length, Petal.Width), grouped by Species (setosa, versicolor, virginica).]
Especially based on the pairs plot, we see that it should not be too difficult to
find a good classifier.
Notice that we use caret::featurePlot to access the featurePlot() function
without loading the entire caret package.
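As a sketch (the exact call is not shown here), the pairs plot could be produced with:

caret::featurePlot(x = iris_trn[, 1:4],
                   y = iris_trn$Species,
                   plot = "pairs",
                   auto.key = list(columns = 3))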
Linear discriminant analysis (LDA) assumes the predictors are multivariate normal within each class, with a shared covariance matrix. That is,

$$X \mid Y = k \sim N(\mu_k, \Sigma)$$

$$f_k(\mathbf{x}) = \frac{1}{(2\pi)^{p/2} |\Sigma|^{1/2}} \exp\left[-\frac{1}{2}(\mathbf{x} - \mu_k)' \Sigma^{-1} (\mathbf{x} - \mu_k)\right]$$
Notice that Σ does not depend on 𝑘, that is, we are assuming the same Σ for
each class. We then use information from all the classes to estimate Σ.
To fit an LDA model, we use the lda() function from the MASS package.
library(MASS)
iris_lda = lda(Species ~ ., data = iris_trn)
iris_lda
## Call:
## lda(Species ~ ., data = iris_trn)
##
## Prior probabilities of groups:
## [1] TRUE
names(predict(iris_lda, iris_trn))
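The prediction and train-error code is not shown here; mirroring the QDA code later in the chapter, it would be:

iris_lda_trn_pred = predict(iris_lda, iris_trn)$class
iris_lda_tst_pred = predict(iris_lda, iris_tst)$class
calc_class_err(predicted = iris_lda_trn_pred, actual = iris_trn$Species)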
## [1] 0.02666667
calc_class_err(predicted = iris_lda_tst_pred, actual = iris_tst$Species)
## [1] 0.01333333
As expected, LDA performs well on both the train and test data.
table(predicted = iris_lda_tst_pred, actual = iris_tst$Species)
## actual
## predicted setosa versicolor virginica
## setosa 21 0 0
## versicolor 0 27 0
## virginica 0 1 26
Looking at the test set, we see that we are perfectly predicting setosa. The only error is labeling a versicolor as a virginica.
iris_lda_flat = lda(Species ~ ., data = iris_trn, prior = c(1, 1, 1) / 3)
iris_lda_flat
## Call:
## lda(Species ~ ., data = iris_trn, prior = c(1, 1, 1)/3)
##
## Prior probabilities of groups:
## setosa versicolor virginica
## 0.3333333 0.3333333 0.3333333
##
## Group means:
## Sepal.Length Sepal.Width Petal.Length Petal.Width
## setosa 4.958621 3.420690 1.458621 0.237931
## versicolor 6.063636 2.845455 4.318182 1.354545
## virginica 6.479167 2.937500 5.479167 2.045833
##
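Again the prediction code is not shown here; a sketch:

iris_lda_flat_trn_pred = predict(iris_lda_flat, iris_trn)$class
iris_lda_flat_tst_pred = predict(iris_lda_flat, iris_tst)$class
calc_class_err(predicted = iris_lda_flat_trn_pred, actual = iris_trn$Species)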
## [1] 0.02666667
calc_class_err(predicted = iris_lda_flat_tst_pred, actual = iris_tst$Species)
## [1] 0.01333333
In this case, the flat prior gives the same train and test accuracy as the estimated prior.
Quadratic discriminant analysis (QDA) also assumes multivariate normality within each class, but now with a class-specific covariance matrix. That is,

$$X \mid Y = k \sim N(\mu_k, \Sigma_k)$$

$$f_k(\mathbf{x}) = \frac{1}{(2\pi)^{p/2} |\Sigma_k|^{1/2}} \exp\left[-\frac{1}{2}(\mathbf{x} - \mu_k)' \Sigma_k^{-1} (\mathbf{x} - \mu_k)\right]$$
Notice that now Σ𝑘 does depend on 𝑘, that is, we are allowing a different Σ𝑘
for each class. We only use information from class 𝑘 to estimate Σ𝑘 .
iris_qda = qda(Species ~ ., data = iris_trn)
iris_qda
## Call:
## qda(Species ~ ., data = iris_trn)
##
Here the output is similar to LDA, again giving the estimated $\hat{\pi}_k$ and $\hat{\mu}_k$ for each class. Like lda(), the qda() function is found in the MASS package.
Consider trying to fit QDA again, but this time with a smaller training set. (Use the commented line above to obtain a smaller training set.) This will cause an error because there are not enough observations within each class to estimate the large number of parameters in the Σ𝑘 matrices. This is less of a problem with LDA, since all observations, no matter the class, are used to estimate the shared Σ matrix.
iris_qda_trn_pred = predict(iris_qda, iris_trn)$class
iris_qda_tst_pred = predict(iris_qda, iris_tst)$class
The predict() function operates the same as the predict() function for LDA.
calc_class_err(predicted = iris_qda_trn_pred, actual = iris_trn$Species)
## [1] 0.01333333
calc_class_err(predicted = iris_qda_tst_pred, actual = iris_tst$Species)
## [1] 0.05333333
table(predicted = iris_qda_tst_pred, actual = iris_tst$Species)
## actual
## predicted setosa versicolor virginica
## setosa 21 0 0
## versicolor 0 25 1
## virginica 0 3 25
Also note that QDA creates quadratic decision boundaries, while LDA creates linear decision boundaries. We could also add quadratic terms to LDA to allow it to create quadratic decision boundaries.
Naive Bayes again assumes, for continuous predictors,

$$X \mid Y = k \sim N(\mu_k, \Sigma_k),$$

but adds the assumption that the predictors are independent within each class, so each $\Sigma_k$ is diagonal and the likelihood factors as

$$f_k(x) = \prod_{j = 1}^{p} f_{kj}(x_j)$$
Here, 𝑓𝑘𝑗 (𝑥𝑗 ) is the density for the 𝑗-th predictor conditioned on the 𝑘-th class.
Notice that there is a 𝜎𝑘𝑗 for each predictor for each class.
$$f_{kj}(x_j) = \frac{1}{\sigma_{kj}\sqrt{2\pi}} \exp\left[-\frac{1}{2}\left(\frac{x_j - \mu_{kj}}{\sigma_{kj}}\right)^2\right]$$
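The model fit that produced the output below is not shown here; a sketch using naiveBayes() from the e1071 package:

library(e1071)
iris_nb = naiveBayes(Species ~ ., data = iris_trn)
iris_nb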
##
## Naive Bayes Classifier for Discrete Predictors
##
## Call:
## naiveBayes.default(x = X, y = Y, laplace = laplace)
##
## A-priori probabilities:
## Y
## setosa versicolor virginica
## 0.3866667 0.2933333 0.3200000
##
## Conditional probabilities:
## Sepal.Length
## Y [,1] [,2]
## setosa 4.958621 0.3212890
## versicolor 6.063636 0.5636154
## virginica 6.479167 0.5484993
##
## Sepal.Width
## Y [,1] [,2]
## setosa 3.420690 0.4012296
## versicolor 2.845455 0.3262007
## virginica 2.937500 0.3267927
##
## Petal.Length
## Y [,1] [,2]
## setosa 1.458621 0.1880677
## versicolor 4.318182 0.5543219
## virginica 5.479167 0.4995469
##
## Petal.Width
## Y [,1] [,2]
## setosa 0.237931 0.09788402
## versicolor 1.354545 0.21979920
## virginica 2.045833 0.29039578
Many packages implement naive Bayes. Here we choose to use naiveBayes()
from the package e1071. (The name of this package has an interesting his-
tory. Based on the name you wouldn’t know it, but the package contains many
functions related to machine learning.)
The Conditional probabilities: portion of the output gives the mean and
standard deviation of the normal distribution for each predictor in each class.
Notice how these mean estimates match those for LDA and QDA above.
Note that naiveBayes() will work without a factor response, but functions
much better with one. (Especially when making predictions.) If you are using
a 0 and 1 response, you might consider coercing to a factor first.
head(predict(iris_nb, iris_trn))
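The error calculations are not shown here; a sketch:

iris_nb_trn_pred = predict(iris_nb, iris_trn)
iris_nb_tst_pred = predict(iris_nb, iris_tst)
calc_class_err(predicted = iris_nb_trn_pred, actual = iris_trn$Species)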
## [1] 0.05333333
calc_class_err(predicted = iris_nb_tst_pred, actual = iris_tst$Species)
## [1] 0.02666667
table(predicted = iris_nb_tst_pred, actual = iris_tst$Species)
## actual
## predicted setosa versicolor virginica
## setosa 21 0 0
## versicolor 0 28 2
## virginica 0 0 24
Like LDA, naive Bayes is having some trouble separating versicolor and virginica.
Method Train Error Test Error
LDA 0.0266667 0.0133333
LDA, Flat Prior 0.0266667 0.0133333
QDA 0.0133333 0.0533333
Naive Bayes 0.0533333 0.0266667
Summarizing the results, we see that naive Bayes offers no advantage over LDA and QDA for this data. So why should we care about naive Bayes?
The strength of naive Bayes comes from its ability to handle a large number of
predictors, 𝑝, even with a limited sample size 𝑛. Even with the naive indepen-
dence assumption, naive Bayes works rather well in practice. Also because of
this assumption, we can often train naive Bayes where LDA and QDA may be
impossible to train because of the large number of parameters relative to the
number of observations.
Here naive Bayes doesn’t get a chance to show its strength since LDA and
QDA already perform well, and the number of predictors is low. The choice
between LDA and QDA is mostly down to a consideration about the amount of
complexity needed.
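The modified training data used below is not created here; a sketch that bins Sepal.Width into three levels (the cut points are assumptions for illustration):

iris_trn_mod = iris_trn
iris_trn_mod$Sepal.Width = cut(iris_trn$Sepal.Width,
                               breaks = c(-Inf, 3.0, 3.4, Inf),
                               labels = c("Small", "Medium", "Large"))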
unique(iris_trn_mod$Sepal.Width)
##
## Naive Bayes Classifier for Discrete Predictors
##
## Call:
## naiveBayes.default(x = X, y = Y, laplace = laplace)
##
## A-priori probabilities:
## Y
## setosa versicolor virginica
## 0.3866667 0.2933333 0.3200000
##
## Conditional probabilities:
## Sepal.Length
## Y [,1] [,2]
## setosa 4.958621 0.3212890
## versicolor 6.063636 0.5636154
## virginica 6.479167 0.5484993
##
## Sepal.Width
## Y Large Medium Small
## setosa 0.06896552 0.75862069 0.17241379
## versicolor 0.00000000 0.27272727 0.72727273
## virginica 0.00000000 0.33333333 0.66666667
Naive Bayes makes a somewhat obvious and intelligent choice to model the categorical variable as a multinomial. It then estimates the probability parameters of a multinomial distribution.
lda(Species ~ Sepal.Length + Sepal.Width, data = iris_trn_mod)
## Call:
## lda(Species ~ Sepal.Length + Sepal.Width, data = iris_trn_mod)
##
## Prior probabilities of groups:
## setosa versicolor virginica
## 0.3866667 0.2933333 0.3200000
##
## Group means:
## Sepal.Length Sepal.WidthMedium Sepal.WidthSmall
## setosa 4.958621 0.7586207 0.1724138
## versicolor 6.063636 0.2727273 0.7272727
## virginica 6.479167 0.3333333 0.6666667
##
## Coefficients of linear discriminants:
## LD1 LD2
## Sepal.Length 2.194825 0.7108153
## Sepal.WidthMedium 1.296250 -0.7224618
## Sepal.WidthSmall 2.922089 -2.5286497
##
## Proportion of trace:
## LD1 LD2
## 0.9929 0.0071
LDA, however, creates dummy variables, here with Large as the reference level, and then continues to model them as normally distributed. Not great, but better than not using the categorical variable at all.
11.5 rmarkdown
The rmarkdown file for this chapter can be found here. The file was created
using R version 4.0.2. The following packages (and their dependencies) were
loaded when knitting this file:
## [1] "e1071" "MASS"
Chapter 12
k-Nearest Neighbors
Recall the logistic regression model,

$$\log\left(\frac{p(x)}{1 - p(x)}\right) = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \cdots + \beta_p x_p.$$
In this case, the 𝛽𝑗 are the parameters of the model, which we learned (es-
timated) by training (fitting) the model. Those estimates were then used to
obtain estimates of the probability 𝑝(𝑥) = 𝑃 (𝑌 = 1 ∣ 𝑋 = 𝑥),

$$\hat{p}(x) = \frac{e^{\hat{\beta}_0 + \hat{\beta}_1 x_1 + \hat{\beta}_2 x_2 + \cdots + \hat{\beta}_p x_p}}{1 + e^{\hat{\beta}_0 + \hat{\beta}_1 x_1 + \hat{\beta}_2 x_2 + \cdots + \hat{\beta}_p x_p}}$$
As we saw in regression, 𝑘-nearest neighbors has no such model parameters. In-
stead, it has a tuning parameter, 𝑘. This is a parameter which determines
how the model is trained, instead of a parameter that is learned through train-
ing. Note that tuning parameters are not used exclusively with non-parametric
methods. Later we will see examples of tuning parameters for parametric meth-
ods.
Often when discussing 𝑘-nearest neighbors for classification, it is framed as a
black-box method that directly returns classifications. We will instead frame
it as a non-parametric model for the probabilities 𝑝𝑔(𝑥) = 𝑃(𝑌 = 𝑔 ∣ 𝑋 = 𝑥).
That is a 𝑘-nearest neighbors model using 𝑘 neighbors estimates this probability
as
$$\hat{p}_{kg}(x) = \hat{P}_k(Y = g \mid X = x) = \frac{1}{k} \sum_{i \in \mathcal{N}_k(x, \mathcal{D})} I(y_i = g)$$
This is the same as saying that we classify to the class with the most observations in the 𝑘 nearest neighbors. If more than one class is tied for the highest estimated probability, simply assign a class at random to one of the classes tied for highest. In the binary case this becomes

$$\hat{C}_k(x) = \begin{cases} 1 & \hat{p}_{k1}(x) > 0.5 \\ 0 & \hat{p}_{k1}(x) < 0.5 \end{cases}$$

Again, if the probabilities for class 0 and 1 are equal, simply assign at random.
[Figure: a scatter plot of two classes, Blue (B) and Orange (O), with a new point marked at (x1 = 8, x2 = 6); its 5 nearest neighbors contain 3 blue and 2 orange points.]
$$\hat{p}_{5B}(x_1 = 8, x_2 = 6) = \hat{P}_5(Y = \text{Blue} \mid X_1 = 8, X_2 = 6) = \frac{3}{5}$$

$$\hat{p}_{5O}(x_1 = 8, x_2 = 6) = \hat{P}_5(Y = \text{Orange} \mid X_1 = 8, X_2 = 6) = \frac{2}{5}$$

Thus

$$\hat{C}_5(x_1 = 8, x_2 = 6) = \text{Blue}.$$
We first load some necessary libraries. We’ll begin discussing 𝑘-nearest neighbors
for classification by returning to the Default data from the ISLR package. To
perform 𝑘-nearest neighbors for classification, we will use the knn() function
from the class package.
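The library calls are not shown here; based on the packages reported at the end of the chapter, they would be:

library(ISLR)
library(class)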
Unlike many of our previous methods, such as logistic regression, knn() requires
that all predictors be numeric, so we coerce student to be a 0 and 1 dummy
variable instead of a factor. (We can, and should, leave the response as a
factor.) Numeric predictors are required because of the distance calculations
taking place.
set.seed(42)
Default$student = as.numeric(Default$student) - 1
default_idx = sample(nrow(Default), 5000)
default_trn = Default[default_idx, ]
default_tst = Default[-default_idx, ]
Like we saw with knn.reg() from the FNN package for regression, knn() from class does not utilize the formula syntax; rather, it requires the predictors in their own data frame or matrix, and the class labels as a separate factor variable. Note that the y data should be a factor vector, not a data frame containing a factor vector.
Note that the FNN package also contains a knn() function for classification. We
choose knn() from class as it seems to be much more popular. However, you
should be aware of which packages you have loaded and thus which functions
you are using. They are very similar, but have some differences.
# training data
X_default_trn = default_trn[, -1]
y_default_trn = default_trn$default

# testing data
X_default_tst = default_tst[, -1]
y_default_tst = default_tst$default
Again, there is very little “training” with 𝑘-nearest neighbors. Essentially the only training is to simply remember the inputs. Because of this, we say that 𝑘-nearest neighbors is fast at training time. However, at test time, 𝑘-nearest neighbors is very slow. For each test observation, the method must find the 𝑘 nearest neighbors, which is not computationally cheap. Note that by default, knn() uses Euclidean distance to determine neighbors.
head(knn(train = X_default_trn,
test = X_default_tst,
cl = y_default_trn,
k = 3))
## [1] No No No No No No
## Levels: No Yes
Because of the lack of any need for training, the knn() function immediately
returns classifications. With logistic regression, we needed to use glm() to
fit the model, then predict() to obtain probabilities we would use to make
a classifier. Here, the knn() function directly returns classifications. That is, knn() is essentially 𝐶𝑘̂ (𝑥).
Here, knn() takes four arguments:
• train, the predictors for the train set.
• test, the predictors for the test set. knn() will output results (classifica-
tions) for these cases.
• cl, the true class labels for the train set.
• k, the number of neighbors to consider.
calc_class_err = function(actual, predicted) {
  mean(actual != predicted)
}
We’ll use our usual calc_class_err() function to assess how well knn() works with this data. We use the test data to evaluate.
calc_class_err(actual = y_default_tst,
predicted = knn(train = X_default_trn,
test = X_default_tst,
cl = y_default_trn,
k = 5))
## [1] 0.0312
Often with knn() we need to consider the scale of the predictor variables. If one variable contains much larger numbers because of its units or range, it will dominate the other variables in the distance measurements. A common remedy is to scale the predictors before computing distances.
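The scaled fit that produced the error below is not shown here; a sketch using scale():

calc_class_err(actual = y_default_tst,
               predicted = knn(train = scale(X_default_trn),
                               test = scale(X_default_tst),
                               cl = y_default_trn,
                               k = 5))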
## [1] 0.0284
Here we see the scaling slightly improves the classification accuracy. This may
not always be the case, and often, it is normal to attempt classification with
and without scaling.
How do we choose 𝑘? Try different values and see which works best.
set.seed(42)
k_to_try = 1:100
err_k = rep(x = 0, times = length(k_to_try))
for (i in seq_along(k_to_try)) {
  pred = knn(train = scale(X_default_trn),
             test = scale(X_default_tst),
             cl = y_default_trn,
             k = k_to_try[i])
  err_k[i] = calc_class_err(y_default_tst, pred)
}
The seq_along() function can be very useful for looping over a vector that stores non-consecutive numbers. It often removes the need for an additional counter variable. We actually didn’t need it in the above knn() example, but it is still a good habit. For example, maybe we didn’t want to try every value of 𝑘, but only odd integers, which would prevent ties. Or perhaps we’d only like to check multiples of 5 to further cut down on computation time.
Also, note that we set a seed before running this loop. This is because we are considering even values of 𝑘, thus there are ties which are randomly broken.
Naturally, we plot the 𝑘-nearest neighbor results.
# plot error vs choice of k
plot(err_k, type = "b", col = "dodgerblue", cex = 1, pch = 20,
xlab = "k, number of neighbors", ylab = "classification error",
main = "(Test) Error Rate vs Neighbors")
# add line for min error seen
abline(h = min(err_k), col = "darkorange", lty = 3)
# add line for minority prevalence in test set
abline(h = mean(y_default_tst == "Yes"), col = "grey", lty = 2)
[Figure: test classification error versus k, with horizontal lines at the minimum observed error and at the minority-class prevalence.]
The dotted orange line represents the smallest observed test classification error
rate.
min(err_k)
## [1] 0.025
We then check which values of 𝑘 achieve this lowest error rate.
which(err_k == min(err_k))
## [1] 24
If multiple values of 𝑘 are tied for the lowest error rate, we select the largest, as it is the least variable and has the least chance of overfitting.
max(which(err_k == min(err_k)))
## [1] 24
Recall that defaulters are the minority class. That is, the majority of observa-
tions are non-defaulters.
table(y_default_tst)
## y_default_tst
## No Yes
## 4837 163
Notice that, as 𝑘 increases, eventually the error approaches the test prevalence
of the minority class.
mean(y_default_tst == "Yes")
## [1] 0.0326
Turning to the iris data from the previous chapters, all the predictors here are numeric, so we proceed to splitting the data into predictors and classes.
# training data
X_iris_trn = iris_trn[, -5]
y_iris_trn = iris_trn$Species
# testing data
X_iris_tst = iris_tst[, -5]
y_iris_tst = iris_tst$Species
Like previous methods, we can obtain predicted probabilities given test predictors. To do so, we add the argument prob = TRUE.
iris_pred = knn(train = scale(X_iris_trn),
test = scale(X_iris_tst),
cl = y_iris_trn,
k = 10,
prob = TRUE)
head(iris_pred, n = 50)
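The proportions below are the winning-class vote shares stored by knn(); the extraction code is not shown here, but it amounts to:

head(attributes(iris_pred)$prob, n = 50)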
## [1] 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0
## [20] 1.0 1.0 0.9 0.7 0.7 0.9 0.8 0.8 1.0 0.8 1.0 0.5 1.0 1.0 0.5 0.9 1.0 0.9 0.9
## [39] 0.6 0.8 0.7 1.0 1.0 1.0 1.0 1.0 0.9 0.8 1.0 0.5
12.4 rmarkdown
The rmarkdown file for this chapter can be found here. The file was created
using R version 4.0.2. The following packages (and their dependencies) were
loaded when knitting this file:
## [1] "class" "ISLR"
Part IV
Unsupervised Learning
Chapter 13
Overview
TODO: Move current content into the following placeholder chapters. Add
details.
13.1 Methods
13.1.1 Principal Component Analysis
To perform PCA in R we will use prcomp(). See ?prcomp() for details.
13.2 Examples
13.2.1 US Arrests
library(ISLR)
data(USArrests)
apply(USArrests, 2, mean)
apply(USArrests, 2, sd)
“Before” performing PCA, we will scale the data. (This will actually happen
inside the prcomp() function.)
USArrests_pca = prcomp(USArrests, scale = TRUE)
## Importance of components:
## PC1 PC2 PC3 PC4
## Standard deviation 1.5749 0.9949 0.59713 0.41645
## Proportion of Variance 0.6201 0.2474 0.08914 0.04336
## Cumulative Proportion 0.6201 0.8675 0.95664 1.00000
USArrests_pca$center
We see that $center and $scale give the mean and standard deviations for the
original variables. $rotation gives the loading vectors that are used to rotate
the original data to obtain the principal components.
dim(USArrests_pca$x)
## [1] 50 4
dim(USArrests)
## [1] 50 4
head(USArrests_pca$x)
scale(as.matrix(USArrests))[1, ] %*% USArrests_pca$rotation[, 1]
## [,1]
## [1,] -0.9756604
scale(as.matrix(USArrests))[1, ] %*% USArrests_pca$rotation[, 2]
## [,1]
## [1,] 1.122001
scale(as.matrix(USArrests))[1, ] %*% USArrests_pca$rotation[, 3]
## [,1]
## [1,] -0.4398037
scale(as.matrix(USArrests))[1, ] %*% USArrests_pca$rotation[, 4]
## [,1]
## [1,] 0.1546966
head(scale(as.matrix(USArrests)) %*% USArrests_pca$rotation[,1])
## [,1]
## Alabama -0.9756604
## Alaska -1.9305379
## Arizona -1.7454429
## Arkansas 0.1399989
## California -2.4986128
## Colorado -1.4993407
head(USArrests_pca$x[, 1])
sum(USArrests_pca$rotation[, 1] ^ 2)
## [1] 1
USArrests_pca$rotation[, 1] %*% USArrests_pca$rotation[, 2]
## [,1]
## [1,] -1.665335e-16
USArrests_pca$rotation[, 1] %*% USArrests_pca$rotation[, 3]
## [,1]
## [1,] 1.110223e-16
USArrests_pca$x[, 1] %*% USArrests_pca$x[, 2]
## [,1]
## [1,] -7.938095e-15
USArrests_pca$x[, 1] %*% USArrests_pca$x[, 3]
## [,1]
## [1,] 1.787459e-14
The above verifies some of the “math” of PCA. We see how the loadings obtain
the principal components from the original data. We check that the loading
vectors are normalized. We also check for orthogonality of both the loading
vectors and the principal components. (Note the above inner products aren’t
exactly 0, but that is simply a numerical issue.)
biplot(USArrests_pca, scale = 0, cex = 0.5)
[Figure: biplot of the first two principal components (PC1, PC2) of the USArrests data, showing state scores and the loadings of Murder, Assault, Rape, and UrbanPop.]
A biplot can be used to visualize both the principal component scores and the
principal component loadings. (Note the two scales for each axis.)
USArrests_pca$sdev
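The get_PVE() helper used below is not defined here; a minimal sketch, computing the proportion of variance explained from the component standard deviations:

get_PVE = function(pca_out) {
  pca_out$sdev ^ 2 / sum(pca_out$sdev ^ 2)
}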
pve = get_PVE(USArrests_pca)
pve
[Figure: proportion of variance explained by each principal component.]
We can then plot the proportion of variance explained for each PC. As expected,
we see the PVE decrease.
cumsum(pve)
[Figure: cumulative proportion of variance explained versus principal component.]
## [1] 0.8513357
max(regular_scaling$results$Accuracy)
## [1] 0.8267131
using_pca$preProcess
Above we simulate data for clustering. Note that we did this in a way that will result in three clusters.
true_clusters = c(rep(3, n / 3), rep(1, n / 3), rep(2, n / 3))
We label the true clusters 1, 2, and 3 in a way that will “match” output from
𝑘-means. (Which is somewhat arbitrary.)
kmean_out = kmeans(clust_data, centers = 3, nstart = 10)
names(kmean_out)
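The cluster assignments used below come from the cluster element of this output; the extraction line is not shown here:

kmeans_clusters = kmean_out$cluster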
table(true_clusters, kmeans_clusters)
## kmeans_clusters
## true_clusters 1 2 3
## 1 0 2 58
## 2 0 55 5
## 3 60 0 0
dim(clust_data)
## [1] 180 10
[Figures: scatter plots of the first two simulated variables, first uncolored and then colored by true cluster.]
Even when using their true clusters for coloring, this plot isn’t very helpful.
clust_data_pca = prcomp(clust_data, scale = TRUE)
plot(
clust_data_pca$x[, 1],
clust_data_pca$x[, 2],
pch = 0,
xlab = "First Principal Component",
ylab = "Second Principal Component"
)
[Figure: the simulated data plotted on the first two principal components.]
If we instead plot the first two principal components, we see, even without
coloring, one blob that is clearly separate from the rest.
plot(
clust_data_pca$x[, 1],
clust_data_pca$x[, 2],
col = true_clusters,
pch = 0,
xlab = "First Principal Component",
ylab = "Second Principal Component",
cex = 2
)
points(clust_data_pca$x[, 1], clust_data_pca$x[, 2], col = kmeans_clusters, pch = 20, cex = 1.5)  # cex value assumed; the original line was cut off
[Figure: the first two principal components, with true clusters shown as boxes and k-means clusters as filled circles.]
Now adding the true colors (boxes) and the 𝑘-means results (circles), we obtain
a nice visualization.
clust_data_pve = get_PVE(clust_data_pca)
plot(
cumsum(clust_data_pve),
xlab = "Principal Component",
ylab = "Cumulative Proportion of Variance Explained",
ylim = c(0, 1),
type = 'b'
)
[Figure: cumulative proportion of variance explained for the simulated clustering data.]
The above visualization works well because the first two PCs explain a large
proportion of the variance.
#install.packages('sparcl')
library(sparcl)
[Figure: dendrogram of the scaled simulated data using complete linkage, with branches colored by the three-cluster cut.]
Here we apply hierarchical clustering to the scaled data. The dist() function
is used to calculate pairwise distances between the (scaled in this case) observa-
tions. We use complete linkage. We then use the cutree() function to cluster
the data into 3 clusters. The ColorDendrogram() function is then used to plot
the dendrogram. Note that the branchlength argument is somewhat arbitrary
(the length of the colored bar) and will need to be modified for each dendrogram.
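The complete-linkage code is not shown here; mirroring the average-linkage code below, it would be, roughly:

clust_data_hc = hclust(dist(scale(clust_data)), method = "complete")
clust_data_cut = cutree(clust_data_hc, 3)
ColorDendrogram(clust_data_hc, y = clust_data_cut,
                labels = names(clust_data_cut),
                main = "Simulated Data, Complete Linkage",
                branchlength = 1.5)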
table(true_clusters, clust_data_cut)
## clust_data_cut
## true_clusters 1 2 3
## 1 0 16 44
## 2 2 56 2
## 3 58 2 0
[Figure: dendrogram of the scaled simulated data using single linkage.]
table(true_clusters, clust_data_cut)
## clust_data_cut
## true_clusters 1 2 3
## 1 59 1 0
## 2 59 0 1
## 3 60 0 0
clust_data_hc = hclust(dist(scale(clust_data)), method = "average")
clust_data_cut = cutree(clust_data_hc , 3)
ColorDendrogram(clust_data_hc, y = clust_data_cut,
labels = names(clust_data_cut),
main = "Simulated Data, Average Linkage",
branchlength = 1)
[Figure: dendrogram of the scaled simulated data using average linkage.]
table(true_clusters, clust_data_cut)
## clust_data_cut
## true_clusters 1 2 3
## 1 1 58 1
## 2 1 59 0
## 3 60 0 0
We also try single and average linkage. Single linkage seems to perform poorly
here, while average linkage seems to be working well.
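We now repeat these ideas with the iris data. The PCA itself is not shown here; it would be, roughly:

iris_pca = prcomp(iris[, -5], scale = TRUE)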
[Figures: the iris data plotted on the first two principal components, and on the third and fourth principal components.]
iris_pve = get_PVE(iris_pca)
plot(
cumsum(iris_pve),
xlab = "Principal Component",
ylab = "Cumulative Proportion of Variance Explained",
ylim = c(0, 1),
type = 'b'
)
[Figure: cumulative proportion of variance explained for the iris data.]
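The k-means fit whose results are tabulated below is not shown here; a sketch (the use of scaled data and the nstart value are assumptions):

iris_kmeans = kmeans(scale(iris[, -5]), centers = 3, nstart = 10)
table(iris_kmeans$cluster, iris[, 5])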
##
## setosa versicolor virginica
## 1 50 0 0
## 2 0 2 36
## 3 0 48 14
iris_hc = hclust(dist(scale(iris[,-5])), method = "complete")
iris_cut = cutree(iris_hc , 3)
ColorDendrogram(iris_hc, y = iris_cut,
labels = names(iris_cut),
main = "Iris, Complete Linkage",
branchlength = 1.5)
[Figure: dendrogram of the scaled iris data using complete linkage.]
table(iris_cut, iris[,5])
##
## iris_cut setosa versicolor virginica
## 1 49 0 0
## 2 1 21 2
## 3 0 29 48
table(iris_cut, iris_kmeans$clust)
##
## iris_cut 1 2 3
## 1 49 0 0
## 2 1 0 23
## 3 0 38 39
iris_hc = hclust(dist(scale(iris[,-5])), method = "single")
iris_cut = cutree(iris_hc , 3)
ColorDendrogram(iris_hc, y = iris_cut,
labels = names(iris_cut),
main = "Iris, Single Linkage",
branchlength = 0.3)
[Figures: dendrograms of the scaled iris data using single and average linkage.]
13.4 RMarkdown
The RMarkdown file for this chapter can be found here. The file was created
using R version 4.0.2 and the following packages:
• Base Packages, Attached
## [1] "stats" "graphics" "grDevices" "utils" "datasets" "methods"
## [7] "base"
• Additional Packages, Attached
## [1] "sparcl" "MASS" "mlbench" "caret" "ggplot2" "lattice" "ISLR"
• Additional Packages, Not Attached
## [1] "tidyselect" "xfun" "purrr" "reshape2" "splines"
Chapter 14
Principal Component Analysis
Chapter 15
k-Means
Chapter 16
Mixture Models
Chapter 17
Hierarchical Clustering
Part V
In Practice
Chapter 18
Overview
Chapter 19
Supervised Learning Overview
Bayes Classifier
• Classify to the class with the highest probability given a particular input
𝑥.
$$C^B(\mathbf{x}) = \underset{k}{\operatorname{argmax}} \; P[Y = k \mid \mathbf{X} = \mathbf{x}]$$
– A more complex model than the model with the best test accuracy
is overfitting.
Classification Methods
• Logistic Regression
• Linear Discriminant Analysis (LDA)
• Quadratic Discriminant Analysis (QDA)
• Naive Bayes (NB)
• 𝑘-Nearest Neighbors (KNN)
• For each, we can:
– Obtain predicted probabilities.
– Make classifications.
– Find decision boundaries. (Seen only for some.)
Tuning Parameters
• Specify how to train a model. This in contrast to model parameters,
which are learned through training.
Cross-Validation
• A method to estimate test metrics with training data. Repeats the train-
validate split inside the training data.
Curse of Dimensionality
• As feature space grows, that is as 𝑝 grows, “neighborhoods” must become
much larger to contain “neighbors,” thus local methods are not so local.
No-Free-Lunch Theorem
• There is no one classifier that will be best across all datasets.
19.2 RMarkdown
The RMarkdown file for this chapter can be found here. The file was created
using R version 4.0.2.
Chapter 20
Resampling
$$Y \sim N(\mu = x^3, \sigma^2 = 0.25^2)$$
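The data-generating function used below is not shown here; a minimal sketch consistent with the model above (the range of x is an assumption):

gen_sim_data = function(sample_size) {
  x = runif(n = sample_size, min = -1, max = 1)
  y = rnorm(n = sample_size, mean = x ^ 3, sd = 0.25)
  data.frame(x, y)
}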
We first simulate a single dataset, which we also split into a train and validation
set. Here, the validation set is 20% of the data.
set.seed(42)
sim_data = gen_sim_data(sample_size = 200)
sim_idx = sample(1:nrow(sim_data), 160)
sim_trn = sim_data[sim_idx, ]
sim_val = sim_data[-sim_idx, ]
[Figure: scatter plot of the simulated training and validation data.]
Recall that we needed this validation set because the training error was far too
optimistic for highly flexible models. This would lead us to always use the most
flexible model.
fit = lm(y ~ poly(x, 10), data = sim_trn)
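The calc_rmse() helper used below is not defined here; a minimal sketch:

calc_rmse = function(actual, predicted) {
  sqrt(mean((actual - predicted) ^ 2))
}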
calc_rmse(actual = sim_trn$y, predicted = predict(fit, sim_trn))
## [1] 0.2297774
calc_rmse(actual = sim_val$y, predicted = predict(fit, sim_val))
## [1] 0.2770462
[Figure: simulation results for the validation-set approach; validation RMSE by polynomial degree, and how many times each degree is chosen.]
20.2 Cross-Validation
Instead of using a single test-train split, we instead look to use 𝐾-fold cross-
validation.
$$\text{RMSE-CV}_{K} = \sum_{k = 1}^{K} \frac{n_k}{n} \, \text{RMSE}_k$$

where

$$\text{RMSE}_k = \sqrt{\frac{1}{n_k} \sum_{i \in C_k} \left(y_i - \hat{f}^{-k}(x_i)\right)^2}$$

If each fold contains the same number of observations, this simplifies to

$$\text{RMSE-CV}_{K} = \frac{1}{K} \sum_{k = 1}^{K} \text{RMSE}_k$$
• TODO: create and add graphic that shows the splitting process
• TODO: Can be used with any metric, MSE, RMSE, class-err, class-acc
There are many ways to perform cross-validation in R, depending on the statis-
tical learning method of interest. Some methods, for example glm() through
boot::cv.glm() and knn() through knn.cv() have cross-validation capabilities
built-in. We’ll use glm() for illustration. First we need to convince ourselves
that glm() can be used to perform the same tasks as lm().
glm_fit = glm(y ~ poly(x, 3), data = sim_trn)
coef(glm_fit)
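The leave-one-out cross-validation call itself is not shown here; with boot::cv.glm(), omitting K performs LOOCV, roughly:

sqrt(boot::cv.glm(sim_trn, glm_fit)$delta)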
We are actually given two values. The first is exactly the LOOCV-MSE. The
second is a minor correction that we will not worry about. We take a square
root to obtain LOOCV-RMSE.
We repeat the above simulation study, this time performing 5-fold cross-
validation. With a total sample size of 𝑛 = 200 each validation set has 40
observations, as did the single validation set in the previous simulations.
cv_rmse = matrix(0, ncol = num_degrees, nrow = num_sims)
set.seed(42)
for (i in 1:num_sims) {
# simulate data, use all data for training
sim_trn = gen_sim_data(sample_size = 200)
# fit models and store RMSE
for (j in 1:num_degrees) {
#fit model
fit = glm(y ~ poly(x, degree = j), data = sim_trn)
# calculate error
cv_rmse[i, j] = sqrt(boot::cv.glm(sim_trn, fit, K = 5)$delta[1])
}
}
[Figure: simulation results comparing the validation-set approach and 5-fold cross-validation; how often each polynomial degree is chosen, and the distribution of RMSE by degree.]
Essentially, this example will also show how not to cross-validate properly. It will also show an example of cross-validation in a classification setting.
calc_err = function(actual, predicted) {
  mean(actual != predicted)
}
$$Y \sim \text{bern}(p = 0.5)$$

$$X_j \sim N(\mu = 0, \sigma^2 = 1)$$
Now we would like to train a logistic regression model to predict 𝑌 using the
available predictor data. However, here we have 𝑝 > 𝑛, which prevents us from
fitting logistic regression. To overcome this issue, we will first attempt to find
a subset of relevant predictors. To do so, we’ll simply find the predictors that
are most correlated with the response.
# find correlation between y and each predictor variable
correlations = apply(trn_data[, -1], 2, cor, y = trn_data$y)
[Figure: histogram of the correlations between y and each predictor.]
While many of these correlations are small, many very close to zero, some are
as large as 0.40. Since our training data has 50 observations, we’ll select the 25
predictors with the largest (absolute) correlations.
selected = order(abs(correlations), decreasing = TRUE)[1:25]
correlations[selected]
We subset the training and test sets to contain only the response as well as these
25 predictors.
trn_screen = trn_data[c(1, selected)]
tst_screen = tst_data[c(1, selected)]
Then we finally fit an additive logistic regression using this subset of predictors.
We perform 10-fold cross-validation to obtain an estimate of the classification
error.
add_log_mod = glm(y ~ ., data = trn_screen, family = "binomial")
## [1] 0.3742339
The 10-fold cross-validation is suggesting a classification error estimate of about 37%.
add_log_pred = (predict(add_log_mod, newdata = tst_screen, type = "response") > 0.5) * 1
calc_err(predicted = add_log_pred, actual = tst_screen$y)
## [1] 0.48
However, if we obtain an estimate of the error using the test set, we see an error rate of about 50%. No better than guessing! But since 𝑌 has no relationship with the predictors, this is actually what we would expect. This incorrect method we’ll call screen-then-validate.
Now, we will correctly screen-while-validating. Essentially, instead of simply
cross-validating the logistic regression, we also need to cross validate the screen-
ing process. That is, we won’t simply use the same variables for each fold, we
get the “best” predictors for each fold.
For methods that do not have a built-in ability to perform cross-validation, or
for methods that have limited cross-validation capability, we will need to write
our own code for cross-validation. (Spoiler: This is not completely true, but
let’s pretend it is, so we can see how to perform cross-validation from scratch.)
This essentially amounts to randomly splitting the data, then looping over the
splits. The createFolds() function from the caret package will make this much easier.
caret::createFolds(trn_data$y, k = 10)
## $Fold01
## [1] 17 23 27 44 45 76 85 87 93 97
##
## $Fold02
## [1] 6 14 15 26 37 38 55 68 69 71
##
## $Fold03
## [1] 3 4 7 29 39 52 54 57 59 82
##
## $Fold04
## [1] 19 21 40 46 48 56 73 78 91 96
##
## $Fold05
## [1] 25 34 36 58 61 65 66 75 83 89
##
## $Fold06
## [1] 2 9 10 62 74 79 80 90 92 98
##
## $Fold07
## [1] 8 31 32 41 43 53 60 67 88 95
##
## $Fold08
## [1] 12 18 33 35 42 49 51 64 84 94
##
## $Fold09
## [1] 11 13 16 20 28 47 50 77 99 100
##
## $Fold10
## [1] 1 5 22 24 30 63 70 72 81 86
# use the caret package to obtain 10 "folds"
folds = caret::createFolds(trn_data$y, k = 10)
for (i in seq_along(folds)) {
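  # the body of this loop is not shown in the original text; the following is a
  # sketch of screen-while-validating, where screening is redone inside each fold.
  # It assumes trn_data has a 0/1 response named y in its first column and that
  # fold_err was initialized earlier, e.g. fold_err = rep(0, length(folds))
  trn_fold = trn_data[-folds[[i]], ]
  val_fold = trn_data[ folds[[i]], ]
  # screen: pick the 25 predictors most correlated with y in this fold only
  cors = apply(trn_fold[, -1], 2, cor, y = trn_fold$y)
  screened = names(sort(abs(cors), decreasing = TRUE))[1:25]
  # fit on the fold's training portion, evaluate on the held-out portion
  fold_fit = glm(y ~ ., data = trn_fold[, c("y", screened)], family = "binomial")
  fold_pred = (predict(fold_fit, newdata = val_fold, type = "response") > 0.5) * 1
  fold_err[i] = calc_err(predicted = fold_pred, actual = val_fold$y)
}
fold_err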
## [1] 0.4 0.9 0.6 0.4 0.6 0.3 0.7 0.5 0.6 0.6
# properly cross-validated error
# this roughly matches what we expect in the test set
mean(fold_err)
## [1] 0.56
• TODO: note that, even cross-validated correctly, this isn’t a brilliant vari-
able selection procedure. (it completely ignores interactions and correla-
tions among the predictors. however, if it works, it works.) next chapters…
20.4 Bootstrap
ISL discusses the bootstrap, which is another resampling method. However, it
is less relevant to the statistical learning tasks we will encounter. It could be
used to replace cross-validation, but encounters significantly more computation.
It could be more useful if we were to attempt to calculate the bias and variance
of a prediction (estimate) without access to the data generating process. Return
to the bias-variance tradeoff chapter and think about how the bootstrap could
be used to obtain estimates of bias and variance with a single dataset, instead
of repeated simulated datasets.
20.5 Which 𝐾?
• TODO: LOO vs 5 vs 10
• TODO: bias and variance
20.6 Summary
• TODO: using cross validation for: tuning, error estimation
20.8 rmarkdown
The rmarkdown file for this chapter can be found here. The file was created
using R version 4.0.2.
Chapter 21
The caret Package
Now that we have seen a number of classification and regression methods, and
introduced cross-validation, we see the general outline of a predictive analysis:
• Test-train split the available data
– Consider a method
∗ Decide on a set of candidate models (specify possible tuning pa-
rameters for method)
∗ Use resampling to find the “best model” by choosing the values
of the tuning parameters
– Use chosen model to make predictions
– Calculate relevant metrics on the test data
At face value it would seem like it should be easy to repeat this process for a
number of different methods, however we have run into a number of difficulties
attempting to do so with R.
• The predict() function seems to have a different behavior for each new
method we see.
• Many methods have different cross-validation functions, or worse yet, no
built-in process for cross-validation.
• Not all methods expect the same data format. Some methods do not use
formula syntax.
• Different methods have different handling of categorical predictors. Some
methods cannot handle factor variables.
Thankfully, the R community has essentially provided a silver bullet for these
issues, the caret package. Returning to the above list, we will see that a number
of these tasks are directly addressed in the caret package.
• Test-train split the available data
– createDataPartition() will take the place of our manual data splitting. It will also do some extra work to ensure that the train and test sets have a similar balance of the response.
21.1 Classification
To illustrate caret, first for classification, we will use the Default data from
the ISLR package.
data(Default, package = "ISLR")
library(caret)
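The data split and the call to train() are not shown here; a minimal sketch, where the seed and the 75% split proportion are assumptions:

set.seed(430)
default_idx = createDataPartition(Default$default, p = 0.75, list = FALSE)
default_trn = Default[default_idx, ]
default_tst = Default[-default_idx, ]

default_glm_mod = train(
  form = default ~ .,
  data = default_trn,
  trControl = trainControl(method = "cv", number = 5),
  method = "glm",
  family = "binomial"
)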
Here, we have supplied four arguments to the train() function from the caret package.
• form = default ~ . specifies the default variable as the response. It
also indicates that all available predictors should be used.
• data = default_trn specifies that training will be done with the default_trn data.
• trControl = trainControl(method = "cv", number = 5) specifies
that we will be using 5-fold cross-validation.
• method = glm specifies that we will fit a generalized linear model.
The method essentially specifies both the model (and more specifically the func-
tion to fit said model in R) and package that will be used. The train() function
is essentially a wrapper around whatever method we chose. In this case, the
function is the base R function glm(), so no additional package is required.
When a method requires a function from a certain package, that package will
need to be installed. See the list of available models for package information.
The list that we have passed to the trControl argument is created using the
trainControl() function from caret. The trainControl() function is a pow-
erful tool for specifying a number of the training choices required by train(),
in particular the resampling scheme.
trainControl(method = "cv", number = 5)[1:3]
## $method
## [1] "cv"
##
## $number
## [1] 5
##
## $repeats
## [1] NA
Here we see just the first three elements of this list, which are related to how
the resampling will be done. These are the three elements that we will be most
interested in. Here, only the first two are relevant.
• method specifies how resampling will be done. Examples include cv, boot,
LOOCV, repeatedcv, and oob.
• number specifies the number of times resampling should be done for methods that require resampling, such as cv and boot.
• repeats specifies the number of times to repeat resampling for methods
such as repeatedcv
For details on the full capabilities of this function, see the relevant documenta-
tion. The out-of-bag resampling method, oob, which is a sort of automatic resampling for certain statistical learning methods, will be introduced later.
We’ve also passed an additional argument of "binomial" to family. This isn’t
actually an argument for train(), but an additional argument for the method
glm. In actuality, we don’t need to specify the family. Since default is a factor
variable, caret automatically detects that we are trying to perform classifica-
tion, and would automatically use family = "binomial". This isn’t the case
if we were simply using glm().
default_glm_mod
A number of useful pieces of information are stored in the object returned by train(). Two elements that we will often be interested in are results and finalModel.
finalModel.
default_glm_mod$results
##
## Call: NULL
##
## Coefficients:
## (Intercept) studentYes balance income
## -1.070e+01 -6.992e-01 5.676e-03 4.383e-07
##
## Degrees of Freedom: 7500 Total (i.e. Null); 7497 Residual
## Null Deviance: 2192
## Residual Deviance: 1186 AIC: 1194
The finalModel is a model object, in this case, the object returned from glm().
This final model is fit to all of the supplied training data. This model object
is often used when we call certain relevant functions on the object returned by
train(), such as summary()
summary(default_glm_mod)
##
## Call:
## NULL
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -2.1317 -0.1420 -0.0568 -0.0210 3.7348
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) -1.070e+01 5.607e-01 -19.079 < 2e-16 ***
## studentYes -6.992e-01 2.708e-01 -2.582 0.00984 **
## balance 5.676e-03 2.644e-04 21.471 < 2e-16 ***
## income 4.383e-07 9.389e-06 0.047 0.96276
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 2192.2 on 7500 degrees of freedom
## Residual deviance: 1185.8 on 7497 degrees of freedom
## AIC: 1193.8
##
## Number of Fisher Scoring iterations: 8
We see that this summary is what we had seen previously from objects of type
glm.
calc_acc = function(actual, predicted) {
  mean(actual == predicted)
}
To obtain test accuracy, we will need to make predictions on the test data. With
the object returned by train(), this is extremely easy.
head(predict(default_glm_mod, newdata = default_tst))
## [1] No No No No No No
## Levels: No Yes
calc_acc(actual = default_tst$default,
         predicted = predict(default_glm_mod, newdata = default_tst))
## [1] 0.9735894
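Predicted probabilities are obtained by adding type = "prob"; the call that produced the output below is not shown here, but it would be, roughly:

head(predict(default_glm_mod, newdata = default_tst, type = "prob"))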
## No Yes
## 2 0.9988332 1.166819e-03
## 4 0.9995369 4.630821e-04
## 7 0.9975279 2.472097e-03
## 8 0.9988855 1.114516e-03
## 10 0.9999771 2.290522e-05
## 11 0.9999887 1.134693e-05
Notice that this returns the probabilities for all possible classes, in this case No
and Yes. Again, this will be true for all methods! This is especially useful for
multi-class data!
21.1.1 Tuning
Since logistic regression has no tuning parameters, we haven’t really highlighted
the full potential of caret. We’ve essentially used it to obtain cross-validated
results, and for the more well-behaved predict() function. These are excellent
improvements over our previous methods, but the real power of caret is its ability to provide a framework for tuning models.
To illustrate tuning, we now use knn as our method, which performs 𝑘-nearest
neighbors.
default_knn_mod = train(
default ~ .,
data = default_trn,
method = "knn",
trControl = trainControl(method = "cv", number = 5)
)
First, note that we are using formula syntax here, where previously we needed
to create separate response and predictors matrices. Also, we’re using a factor
variable as a predictor, and caret seems to be taking care of this automatically.
default_knn_mod
## k-Nearest Neighbors
##
## 7501 samples
## 3 predictor
## 2 classes: 'No', 'Yes'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 6001, 6000, 6001, 6001, 6001
## Resampling results across tuning parameters:
##
## k Accuracy Kappa
## 5 0.9677377 0.2125623
## 7 0.9664047 0.1099835
## 9 0.9680044 0.1223319
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was k = 9.
default_knn_mod = train(
default ~ .,
data = default_trn,
method = "knn",
trControl = trainControl(method = "cv", number = 5),
preProcess = c("center", "scale"),
tuneGrid = expand.grid(k = seq(1, 101, by = 2))
)
Here, we’ve specified that we would like to center and scale the data, essentially transforming each predictor to have mean 0 and variance 1. The documentation on the preProcess() function provides examples of additional possible pre-processing. In our call to train() we’re essentially specifying how we would like this function applied to our data.
We’ve also provided a “tuning grid,” in this case, the values of k to try. The
tuneGrid argument expects a data frame, which expand.grid() returns. We
don’t actually need expand.grid() for this example, but it will be a useful
habit to develop when we move to methods with multiple tuning parameters.
head(default_knn_mod$results, 5)
Since we now have a large number of results, displaying them all would create a lot of clutter. Instead, we can plot the tuning results by calling plot() on the object returned by train().
plot(default_knn_mod)
(Plot: cross-validated accuracy versus the number of neighbors.)
By default, caret utilizes the lattice graphics package to create these plots.
Recently, additional support for ggplot2 style graphics has been added for some
plots.
ggplot(default_knn_mod) + theme_bw()
(Plot: the same tuning results, cross-validated accuracy versus the number of neighbors, in ggplot2 style.)
Now that we are dealing with a tuning parameter, train() determines the best value of those considered, by default selecting the one with the highest cross-validated accuracy, and returns that value as bestTune.
default_knn_mod$bestTune
## k
## 6 11
get_best_result = function(caret_fit) {
best = which(rownames(caret_fit$results) == rownames(caret_fit$bestTune))
best_result = caret_fit$results[best, ]
rownames(best_result) = NULL
best_result
}
Sometimes it will be useful to obtain the results for only that value. The above
function does this automatically.
get_best_result(default_knn_mod)
## No Yes
## 1 1.0000000 0.00000000
## 2 1.0000000 0.00000000
## 3 1.0000000 0.00000000
## 4 1.0000000 0.00000000
## 5 0.9090909 0.09090909
## 6 0.9090909 0.09090909
As an example of a multi-class response, consider the following three models fit to the iris data. Note that the first model is essentially “multinomial logistic regression,” but you might notice it also has a tuning parameter now. (Spoiler: It’s actually a neural network, so you’ll need the nnet package.)
iris_log_mod = train(
Species ~ .,
data = iris,
method = "multinom",
trControl = trainControl(method = "cv", number = 5),
trace = FALSE
)
iris_knn_mod = train(
Species ~ .,
data = iris,
method = "knn",
trControl = trainControl(method = "cv", number = 5),
preProcess = c("center", "scale"),
tuneGrid = expand.grid(k = seq(1, 21, by = 2))
)
iris_qda_mod = train(
Species ~ .,
data = iris,
method = "qda",
trControl = trainControl(method = "cv", number = 5)
)
We can obtain predicted probabilities with these three models. Notice that
they give the predicted probability for each class, using the same syntax for
each model.
head(predict(iris_log_mod, type = "prob"))
head(predict(iris_knn_mod, type = "prob"))
head(predict(iris_qda_mod, type = "prob"))
21.2 Regression
To illustrate the use of caret for regression, we’ll consider some simulated data.
gen_some_data = function(n_obs = 50) {
x1 = seq(0, 10, length.out = n_obs)
x2 = runif(n = n_obs, min = 0, max = 2)
x3 = sample(c("A", "B", "C"), size = n_obs, replace = TRUE)
x4 = round(runif(n = n_obs, min = 0, max = 5), 1)
x5 = round(runif(n = n_obs, min = 0, max = 5), 0)
y = round(x1 ^ 2 + x2 ^ 2 + 2 * (x3 == "B") + rnorm(n = n_obs), 3)
data.frame(y, x1, x2, x3, x4, x5)
}
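The training and test sets used below, sim_trn and sim_tst, are generated with this function. The exact sample sizes are not stated; a minimal sketch, assuming a modest training set and the “rather large” test set mentioned later, is:
set.seed(42)
sim_trn = gen_some_data(n_obs = 500)   # assumed training size
sim_tst = gen_some_data(n_obs = 5000)  # assumed (large) test size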
Fitting knn works nearly identically to its use for classification. Really, the only
difference here is that we have a numeric response, which caret understands to
be a regression problem.
sim_knn_mod = train(
y ~ .,
data = sim_trn,
method = "knn",
trControl = trainControl(method = "cv", number = 5),
# preProcess = c("center", "scale"),
tuneGrid = expand.grid(k = seq(1, 31, by = 2))
)
sim_knn_mod$modelType
## [1] "Regression"
Notice that we’ve commented out the line to perform pre-processing. Can you
figure out why?
get_best_result(sim_knn_mod)
A few things to notice in the results. In addition to the usual RMSE, which is used to determine the best model, we also have MAE, the mean absolute error. We are also given standard deviations of both of these metrics.
plot(sim_knn_mod)
(Plot: cross-validated RMSE versus the number of neighbors.)
The following plot adds error bars to the RMSE estimate for each k.
(Plot: cross-validated RMSE versus the number of neighbors, with one-standard-error bars.)
Sometimes, instead of simply picking the model with the best RMSE (or accuracy), we pick the simplest model within one standard error of the model with the best RMSE. Here then, we would consider k = 11 instead of k = 7, since there isn’t a statistically significant difference between them. This is potentially a very good idea in practice. By picking a simpler model, we are at less risk of overfitting, especially since, in practice, future data may be slightly different than the data that we are training on. If you’re trying to win a Kaggle competition, this might not be as useful, since often the test and train data come from the exact same source.
• TODO: additional details about 1-SE rule.
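caret can also apply this rule automatically through the selectionFunction argument of trainControl(). A minimal sketch, reusing the knn setup from above with only the selection rule changed, is:
sim_knn_1se = train(
  y ~ .,
  data = sim_trn,
  method = "knn",
  trControl = trainControl(method = "cv", number = 5,
                           selectionFunction = "oneSE"),
  tuneGrid = expand.grid(k = seq(1, 31, by = 2))
)
sim_knn_1se$bestTune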
calc_rmse = function(actual, predicted) {
sqrt(mean((actual - predicted) ^ 2))
}
Since we simulated this data, we have a rather large test dataset. This allows us to compare our cross-validation error estimate to an estimate using (an impractically large) test set.
get_best_result(sim_knn_mod)$RMSE
## [1] 3.6834
calc_rmse(actual = sim_tst$y,
predicted = predict(sim_knn_mod, sim_tst))
## [1] 3.412332
Here we see that the cross-validated RMSE is a bit of an overestimate, but still
rather close to the test error. The real question is, are either of these any good?
Is this model predicting well? No! Notice that we simulated this data with an
error standard deviation of 1!
21.2.1 Methods
Now that caret has given us a pipeline for a predictive analysis, we can very quickly and easily test new methods. For example, in an upcoming chapter we will discuss boosted tree models, but now that we understand how to use caret, in order to use a boosted tree model we simply need to know the “method” to do so, which in this case is gbm. Beyond knowing that the method exists, we just need to know its tuning parameters; in this case, there are four. We actually could get away with knowing nothing about them and simply specify a tuneLength, in which case caret would automatically try some reasonable values.
Instead, we could read the caret documentation to find the tuning parameters, as well as the required packages. For now, we’ll simply use the following tuning grid. In later chapters, we’ll discuss how these affect the model.
gbm_grid = expand.grid(interaction.depth = c(1, 2, 3),
n.trees = (1:30) * 100,
shrinkage = c(0.1, 0.3),
n.minobsinnode = 20)
head(gbm_grid)
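The call that fits this boosted model is not shown above; a sketch of what it presumably looks like, assuming the same 5-fold cross-validation used throughout this chapter, is:
sim_gbm_mod = train(
  y ~ .,
  data = sim_trn,
  method = "gbm",
  trControl = trainControl(method = "cv", number = 5),
  verbose = FALSE,
  tuneGrid = gbm_grid
)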
We added verbose = FALSE to the train() call to suppress some of the inter-
mediate output of the gbm fitting procedure.
How this training is happening is a bit of a mystery to us right now. What is this method? How does it deal with the factor variable as a predictor? We’ll answer these questions later; for now, we do know how to evaluate how well the method is working.
knitr::kable(head(sim_gbm_mod$results), digits = 3)
(Plot: cross-validated RMSE versus the number of boosting iterations, across the tuning grid.)
sim_gbm_mod$bestTune
## [1] 1.568517
Again, the cross-validated result is overestimating the error a bit. Also, this
model is a big improvement over the knn model, but we can still do better.
sim_lm_mod = train(
y ~ poly(x1, 2) + poly(x2, 2) + x3,
data = sim_trn,
method = "lm",
trControl = trainControl(method = "cv", number = 5)
)
sim_lm_mod$finalModel
##
## Call:
## lm(formula = .outcome ~ ., data = dat)
##
## Coefficients:
## (Intercept) `poly(x1, 2)1` `poly(x1, 2)2` `poly(x2, 2)1` `poly(x2, 2)2`
## 34.75615 645.50804 167.12875 26.00951 6.86587
## x3B x3C
## 1.80700 0.07108
Here we fit a good old linear model, except we specify a very specific formula.
sim_lm_mod$results$RMSE
## [1] 1.046702
calc_rmse(actual = sim_tst$y,
predicted = predict(sim_lm_mod, sim_tst))
## [1] 1.035896
This model dominates the previous two. The gbm model does still have a big
advantage. The lm model needed the correct form of the model, whereas gbm
nearly learned it automatically!
This question of which variables should be included is where we will turn our focus next. We’ll consider both which variables are useful for prediction, and learn tools to assess how useful they are.
21.4 rmarkdown
The rmarkdown file for this chapter can be found here. The file was created
using R version 4.0.2. The following packages (and their dependencies) were
loaded when knitting this file:
## [1] "caret" "ggplot2" "lattice"
Chapter 22
Subset Selection
sum(is.na(Hitters))
## [1] 59
sum(is.na(Hitters$Salary))
## [1] 59
Hitters = na.omit(Hitters)
sum(is.na(Hitters))
## [1] 0
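The subset selection fits referenced below (fit_all, fit_fwd, fit_bwd) come from regsubsets() in the leaps package. A minimal sketch of those fits, with nvmax = 19 (all predictors) as an assumption, is:
library(leaps)
# best subset, forward stepwise, and backward stepwise selection of Salary
fit_all = regsubsets(Salary ~ ., data = Hitters, nvmax = 19)
fit_fwd = regsubsets(Salary ~ ., data = Hitters, nvmax = 19, method = "forward")
fit_bwd = regsubsets(Salary ~ ., data = Hitters, nvmax = 19, method = "backward")
fit_bwd_sum = summary(fit_bwd)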
(Plots: RSS, Adjusted R-squared, Cp, and BIC versus the number of predictors for the subset selection fits.)
coef(fit_fwd, 7)
## [1] 10
coef(fit_bwd, which.min(fit_bwd_sum$cp))
420
400
380
5 10 15
Number of Predictors
which.min(test_err)
## [1] 5
coef(fit_all, which.min(test_err))
class(fit_all)
## [1] "regsubsets"
predict.regsubsets = function(object, newdata, id, ...) {
  form  = as.formula(object$call[[2]])
  mat   = model.matrix(form, newdata)
  coefs = coef(object, id = id)
  xvars = names(coefs)
  mat[, xvars] %*% coefs
}
for (j in 1:num_folds) {
  train_fold    = Hitters[-folds[[j]], ]
  validate_fold = Hitters[ folds[[j]], ]
  # assumed completion: refit on the training folds, then record the validation
  # RMSE for each model size in a pre-allocated fold_error matrix
  fit = regsubsets(Salary ~ ., data = train_fold, nvmax = num_vars)
  for (i in 1:num_vars) {
    pred = predict(fit, validate_fold, id = i)
    fold_error[j, i] = sqrt(mean((validate_fold$Salary - pred) ^ 2))
  }
}
cv_error = colMeans(fold_error)
cv_error
## 1 2 3 4 5 6 7 8
## 381.6473 362.3809 355.9959 354.6139 352.5358 345.6078 352.2963 332.4575
## 9 10 11 12 13 14 15 16
## 342.1292 339.7967 338.0266 338.2973 336.7897 337.6876 340.1955 339.9188
## 17 18 19
## 339.6058 339.6544 339.4893
plot(cv_error, type = 'b', ylab = "Cross-Validated RMSE", xlab = "Number of Predictors")
(Plot: cross-validated RMSE versus the number of predictors.)
22.4 RMarkdown
The RMarkdown file for this chapter can be found here. The file was created
using R version 4.0.2 and the following packages:
• Base Packages, Attached
## [1] "stats" "graphics" "grDevices" "utils" "datasets" "methods"
## [7] "base"
• Additional Packages, Attached
## [1] "leaps"
• Additional Packages, Not Attached
## [1] "tidyselect" "xfun" "purrr" "reshape2" "splines"
## [6] "lattice" "colorspace" "vctrs" "generics" "htmltools"
## [11] "stats4" "yaml" "survival" "prodlim" "rlang"
## [16] "ModelMetrics" "pillar" "glue" "withr" "foreach"
## [21] "lifecycle" "plyr" "lava" "stringr" "timeDate"
## [26] "munsell" "gtable" "recipes" "codetools" "evaluate"
## [31] "knitr" "caret" "class" "Rcpp" "scales"
## [36] "ipred" "ggplot2" "digest" "stringi" "bookdown"
## [41] "dplyr" "grid" "tools" "magrittr" "tibble"
## [46] "crayon" "pkgconfig" "ellipsis" "MASS" "Matrix"
## [51] "data.table" "pROC" "lubridate" "gower" "rmarkdown"
## [56] "iterators" "R6" "rpart" "nnet" "nlme"
## [61] "compiler"
Part VI
Chapter 23
Overview
Chapter 24
Regularization
We use the Hitters dataset from the ISLR package. This dataset has some missing data in the response Salary. We use the na.omit() function to clean the dataset.
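A call like the one used in the next chapter loads the data; the exact form used here is an assumption:
data(Hitters, package = "ISLR")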
sum(is.na(Hitters))
## [1] 59
sum(is.na(Hitters$Salary))
## [1] 59
Hitters = na.omit(Hitters)
sum(is.na(Hitters))
## [1] 0
The predictor variables are offensive and defensive statistics for a number of baseball players.
names(Hitters)
We use the glmnet() and cv.glmnet() functions from the glmnet package to
fit penalized regressions.
library(glmnet)
Unfortunately, the glmnet function does not allow the use of model formulas, so we set up the data for ease of use with glmnet. Eventually we will use train() from caret, which does allow fitting penalized regression with the formula syntax, but to explore some of the details, we first work with the functions from glmnet directly.
X = model.matrix(Salary ~ ., Hitters)[, -1]
y = Hitters$Salary
First, we fit an ordinary linear regression, and note the size of the predictors’
coefficients, and predictors’ coefficients squared. (The two penalties we will use.)
fit = lm(Salary ~ ., Hitters)
coef(fit)
sum(abs(coef(fit)[-1]))
## [1] 238.7295
sum(coef(fit)[-1] ^ 2)
## [1] 18337.3
24.1 Ridge Regression
We first illustrate ridge regression, which can be fit using glmnet() with alpha = 0 and seeks to minimize

$$
\sum_{i=1}^{n} \left( y_i - \beta_0 - \sum_{j=1}^{p} \beta_j x_{ij} \right)^2 + \lambda \sum_{j=1}^{p} \beta_j^2 .
$$
Notice that the intercept is not penalized. Also, note that ridge regression is not scale invariant like the usual unpenalized regression. Thankfully, glmnet() takes care of this internally. It automatically standardizes predictors for fitting, then reports fitted coefficients using the original scale.
The two plots illustrate how much the coefficients are penalized for different values of 𝜆. Notice that none of the coefficients are forced to be zero.
par(mfrow = c(1, 2))
fit_ridge = glmnet(X, y, alpha = 0)
plot(fit_ridge)
plot(fit_ridge, xvar = "lambda", label = TRUE)
(Plots: ridge coefficient paths versus the L1 norm and versus log λ.)
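The cross-validation fit summarized below is not shown above; a sketch of that step, assuming the cv.glmnet() defaults, is:
# assumed: cross-validated ridge fit over a grid of lambda values
fit_ridge_cv = cv.glmnet(X, y, alpha = 0)
plot(fit_ridge_cv)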
(Plot: cross-validated mean-squared error versus log λ for the ridge fit.)
The cv.glmnet() function returns several details of the fit for both 𝜆 values in
the plot. Notice the penalty terms are smaller than the full linear regression.
(As we would expect.)
# fitted coefficients, using 1-SE rule lambda, default behavior
coef(fit_ridge_cv)
## Assists 0.005930000
## Errors -0.087618226
## NewLeagueN 1.836629069
# fitted coefficients, using minimum lambda
coef(fit_ridge_cv, s = "lambda.min")
# penalty term using minimum lambda
sum(coef(fit_ridge_cv, s = "lambda.min")[-1] ^ 2)
## [1] 18367.29
# fitted coefficients, using 1-SE rule lambda
coef(fit_ridge_cv, s = "lambda.1se")
## CAtBat 0.006369369
## CHits 0.024201921
## CHmRun 0.180499284
## CRuns 0.048544437
## CRBI 0.050169414
## CWalks 0.049897906
## LeagueN 1.802540410
## DivisionW -16.185025086
## PutOuts 0.040146198
## Assists 0.005930000
## Errors -0.087618226
## NewLeagueN 1.836629069
# penalty term using 1-SE rule lambda
sum(coef(fit_ridge_cv, s = "lambda.1se")[-1] ^ 2)
## [1] 275.24
# predict using minimum lambda
predict(fit_ridge_cv, X, s = "lambda.min")
## [1] 141009.7
# CV-RMSEs
sqrt(fit_ridge_cv$cvm)
## [1] 343.2022
# CV-RMSE using 1-SE rule lambda
sqrt(fit_ridge_cv$cvm[fit_ridge_cv$lambda == fit_ridge_cv$lambda.1se])
## [1] 379.7137
24.2 Lasso
We now illustrate lasso, which can be fit using glmnet() with alpha = 1 and
seeks to minimize
$$
\sum_{i=1}^{n} \left( y_i - \beta_0 - \sum_{j=1}^{p} \beta_j x_{ij} \right)^2 + \lambda \sum_{j=1}^{p} |\beta_j| .
$$
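A sketch of the lasso fit and coefficient path plots, mirroring the ridge code above (alpha = 1 gives the lasso), is:
par(mfrow = c(1, 2))
fit_lasso = glmnet(X, y, alpha = 1)
plot(fit_lasso)
plot(fit_lasso, xvar = "lambda", label = TRUE)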
(Plots: lasso coefficient paths versus the L1 norm and versus log λ.)
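As with ridge, the cross-validated lasso fit used below is a sketch under the cv.glmnet() defaults:
fit_lasso_cv = cv.glmnet(X, y, alpha = 1)
plot(fit_lasso_cv)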
(Plot: cross-validated mean-squared error versus log λ for the lasso fit.)
cv.glmnet() returns several details of the fit for both 𝜆 values in the plot.
Notice the penalty terms are again smaller than the full linear regression. (As
we would expect.) Some coefficients are 0.
# fitted coefficients, using 1-SE rule lambda, default behavior
coef(fit_lasso_cv)
## Assists .
## Errors .
## NewLeagueN .
# fitted coefficients, using minimum lambda
coef(fit_lasso_cv, s = "lambda.min")
# penalty term using minimum lambda
sum(coef(fit_lasso_cv, s = "lambda.min")[-1] ^ 2)
## [1] 15509.95
# fitted coefficients, using 1-SE rule lambda
coef(fit_lasso_cv, s = "lambda.1se")
## CAtBat .
## CHits .
## CHmRun .
## CRuns 0.16027975
## CRBI 0.33667715
## CWalks .
## LeagueN .
## DivisionW -8.06171247
## PutOuts 0.08393604
## Assists .
## Errors .
## NewLeagueN .
# penalty term using 1-SE rule lambda
sum(coef(fit_lasso_cv, s = "lambda.1se")[-1] ^ 2)
## [1] 69.66661
# predict using minimum lambda
predict(fit_lasso_cv, X, s = "lambda.min")
## [1] 118581.5
# CV-RMSEs
sqrt(fit_lasso_cv$cvm)
## [1] 335.1692
# CV-RMSE using 1-SE rule lambda
sqrt(fit_lasso_cv$cvm[fit_lasso_cv$lambda == fit_lasso_cv$lambda.1se])
## [1] 359.3377
24.3 broom
Sometimes, the output from glmnet() can be overwhelming. The broom pack-
age can help with that.
library(broom)
# the output from the commented line would be immense
# fit_lasso_cv
tidy(fit_lasso_cv)
## # A tibble: 80 x 6
## lambda estimate std.error conf.low conf.high nzero
## <dbl> <dbl> <dbl> <dbl> <dbl> <int>
## 1 255. 202300. 23338. 178962. 225637. 0
## 2 233. 193586. 23779. 169807. 217365. 1
## 3 212. 184510. 22822. 161687. 207332. 2
## 4 193. 177003. 22023. 154980. 199026. 2
## 5 176. 170233. 21430. 148803. 191663. 3
## 6 160. 163228. 21037. 142191. 184265. 4
## 7 146. 156644. 20382. 136262. 177027. 4
## 8 133. 150772. 19854. 130918. 170626. 4
## 9 121. 145915. 19434. 126481. 165349. 4
## 10 111. 141852. 19031. 122821. 160883. 4
## # ... with 70 more rows
# the two lambda values of interest
glance(fit_lasso_cv)
## # A tibble: 1 x 3
## lambda.min lambda.1se nobs
## <dbl> <dbl> <int>
## 1 2.22 69.4 263
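The next example simulates data for penalized logistic regression. The simulation settings are not shown; the sketch below uses hypothetical dimensions, with only the first three predictors given nonzero coefficients, which matches the discussion that follows:
set.seed(42)
n = 1000   # hypothetical number of observations
p = 1000   # hypothetical number of predictors
X = matrix(rnorm(n * p), nrow = n, ncol = p)
beta = c(rep(1, 3), rep(0, p - 3))  # only the first three predictors matter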
z = X %*% beta
prob = exp(z) / (1 + exp(z))
y = as.factor(rbinom(length(z), size = 1, prob = prob))
We then use a lasso penalty to fit penalized logistic regression. This minimizes
$$
\sum_{i=1}^{n} L\left( y_i, \beta_0 + \sum_{j=1}^{p} \beta_j x_{ij} \right) + \lambda \sum_{j=1}^{p} |\beta_j|
$$
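The cross-validated fit, fit_cv, used throughout the rest of this example is not shown; a sketch, assuming the cv.glmnet() defaults, is:
fit_cv = cv.glmnet(X, y, family = "binomial", alpha = 1)
plot(fit_cv)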
(Plot: cross-validated deviance versus log λ for the penalized logistic regression.)
head(coef(fit_cv), n = 10)
## V2 0.56251761
## V3 0.60065105
## V4 .
## V5 .
## V6 .
## V7 .
## V8 .
## V9 .
fit_cv$nzero
## s0 s1 s2 s3 s4 s5 s6 s7 s8 s9 s10 s11 s12 s13 s14 s15 s16 s17 s18 s19
## 0 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
## s20 s21 s22 s23 s24 s25 s26 s27 s28 s29 s30 s31 s32 s33 s34 s35 s36 s37 s38 s39
## 3 3 3 3 3 3 3 3 3 3 4 6 7 10 18 24 35 54 65 75
## s40 s41 s42 s43 s44 s45 s46 s47 s48 s49 s50 s51 s52 s53 s54 s55 s56 s57 s58 s59
## 86 100 110 129 147 168 187 202 221 241 254 269 283 298 310 324 333 350 364 375
## s60 s61 s62 s63 s64 s65 s66 s67 s68 s69 s70 s71 s72 s73 s74 s75 s76 s77 s78 s79
## 387 400 411 429 435 445 453 455 462 466 475 481 487 491 496 498 502 504 512 518
## s80 s81 s82 s83 s84 s85 s86 s87 s88 s89 s90 s91 s92 s93 s94 s95 s96 s97 s98 s99
## 523 526 528 536 543 550 559 561 563 566 570 571 576 582 586 590 596 596 600 599
Notice that only the first three predictors generated are truly significant, and that is exactly what the suggested model finds.
fit_1se = glmnet(X, y, family = "binomial", lambda = fit_cv$lambda.1se)
which(as.vector(as.matrix(fit_1se$beta)) != 0)
## [1] 1 2 3
We can also see in the following plots the three relevant features entering the model well ahead of the irrelevant features.
par(mfrow = c(1, 2))
plot(glmnet(X, y, family = "binomial"))
plot(glmnet(X, y, family = "binomial"), xvar = "lambda")
(Plots: coefficient paths for the penalized logistic regression, versus the L1 norm and versus log λ.)
fit_cv$lambda.min
## [1] 0.03718493
fit_cv$lambda.1se
## [1] 0.0514969
Since cv.glmnet() does not calculate prediction accuracy for classification, we take the 𝜆 values and create a grid for caret to search in order to obtain prediction accuracy with train(). We set 𝛼 = 1 in this grid, as glmnet can actually tune over the 𝛼 parameter as well. (More on that later.)
Note that we have to force y to be a factor, so that train() recognizes we want to have a binomial response. The train() function in caret uses the type of variable in y to determine whether to use family = "binomial" or family = "gaussian".
library(caret)
cv_5 = trainControl(method = "cv", number = 5)
lasso_grid = expand.grid(alpha = 1,
lambda = c(fit_cv$lambda.min, fit_cv$lambda.1se))
lasso_grid
## alpha lambda
## 1 1 0.03718493
## 2 1 0.05149690
sim_data = data.frame(y, X)
fit_lasso = train(
y ~ ., data = sim_data,
method = "glmnet",
trControl = cv_5,
tuneGrid = lasso_grid
)
fit_lasso$results
24.6 rmarkdown
The rmarkdown file for this chapter can be found here. The file was created
using R version 4.0.2. The following packages (and their dependencies) were
loaded when knitting this file:
## [1] "caret" "ggplot2" "lattice" "broom" "glmnet" "Matrix"
Chapter 25
Elastic Net
We again use the Hitters dataset from the ISLR package to explore another
shrinkage method, elastic net, which combines the ridge and lasso methods
from the previous chapter.
data(Hitters, package = "ISLR")
Hitters = na.omit(Hitters)
We again remove the missing data, which was all in the response variable,
Salary.
tibble::as_tibble(Hitters)
## # A tibble: 263 x 20
## AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun CRuns CRBI
## <int> <int> <int> <int> <int> <int> <int> <int> <int> <int> <int> <int>
## 1 315 81 7 24 38 39 14 3449 835 69 321 414
## 2 479 130 18 66 72 76 3 1624 457 63 224 266
## 3 496 141 20 65 78 37 11 5628 1575 225 828 838
## 4 321 87 10 39 42 30 2 396 101 12 48 46
## 5 594 169 4 74 51 35 11 4408 1133 19 501 336
## 6 185 37 1 23 8 21 2 214 42 1 30 9
## 7 298 73 0 24 24 7 3 509 108 0 41 37
## 8 323 81 6 26 32 8 2 341 86 6 32 34
## 9 401 92 17 49 66 65 13 5206 1332 253 784 890
## 10 574 159 21 107 75 59 10 4631 1300 90 702 504
## # ... with 253 more rows, and 8 more variables: CWalks <int>, League <fct>,
## # Division <fct>, PutOuts <int>, Assists <int>, Errors <int>, Salary <dbl>,
## # NewLeague <fct>
dim(Hitters)
## [1] 263 20
Because this dataset isn’t particularly large, we will forego a test-train split, and
simply use all of the data as training data.
library(caret)
library(glmnet)
Since we have loaded caret, we also have access to the lattice package, which has a nice histogram function.
histogram(Hitters$Salary, xlab = "Salary, $1000s",
main = "Baseball Salaries, 1986 - 1987")
(Plot: histogram of baseball salaries, 1986 - 1987, in $1000s.)
25.1 Regression
Like ridge and lasso, we again attempt to minimize the residual sum of squares
plus some penalty term.
$$
\sum_{i=1}^{n} \left( y_i - \beta_0 - \sum_{j=1}^{p} \beta_j x_{ij} \right)^2 + \lambda \left[ (1 - \alpha) \|\beta\|_2^2 / 2 + \alpha \|\beta\|_1 \right]
$$

where

$$
\|\beta\|_1 = \sum_{j=1}^{p} |\beta_j|
$$

and

$$
\|\beta\|_2 = \sqrt{\sum_{j=1}^{p} \beta_j^2}
$$
These both quantify how “large” the coefficients are. Like lasso and ridge, the intercept is not penalized and glmnet takes care of standardization internally. Also, reported coefficients are on the original scale.
The new penalty is 𝜆 ⋅ (1 − 𝛼)/2 times the ridge penalty plus 𝜆 ⋅ 𝛼 times the lasso penalty. (Dividing the ridge penalty by 2 is a mathematical convenience for optimization.) Essentially, with the correct choice of 𝜆 and 𝛼, these two “penalty coefficients” can be any positive numbers.
Often it is more useful to simply think of 𝛼 as controlling the mixing between the two penalties and 𝜆 controlling the amount of penalization. 𝛼 takes values between 0 and 1. Using 𝛼 = 1 gives the lasso that we have seen before. Similarly, 𝛼 = 0 gives ridge. We used these two before with glmnet() to specify which method we wanted. Now we also allow for 𝛼 values in between.
set.seed(42)
cv_5 = trainControl(method = "cv", number = 5)
We first set up our cross-validation strategy, which will be 5-fold. We then use train() with method = "glmnet", which is actually fitting the elastic net.
hit_elnet = train(
Salary ~ ., data = Hitters,
method = "glmnet",
trControl = cv_5
)
First, note that since we are using caret with the formula syntax, it is taking care of dummy variable creation. So unlike before when we used glmnet(), we do not need to manually create a model matrix.
Also note that we have allowed caret to choose the tuning parameters for us.
hit_elnet
## glmnet
##
## 263 samples
## 19 predictor
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 211, 210, 210, 211, 210
## Resampling results across tuning parameters:
##
## alpha lambda RMSE Rsquared MAE
## 0.10 0.5106 335.1 0.4549 235.2
## 0.10 5.1056 332.4 0.4632 231.9
## 0.10 51.0564 339.4 0.4486 231.1
## 0.55 0.5106 334.9 0.4551 234.5
## 0.55 5.1056 332.7 0.4650 230.4
## 0.55 51.0564 343.5 0.4440 235.9
## 1.00 0.5106 334.9 0.4546 234.1
## 1.00 5.1056 336.0 0.4590 230.6
## 1.00 51.0564 353.3 0.4231 244.1
##
## RMSE was used to select the optimal model using the smallest value.
## The final values used for the model were alpha = 0.1 and lambda = 5.106.
Notice a few things with these results. First, we have tried three 𝛼 values, 0.10,
0.55, and 1. It is not entirely clear why caret doesn’t use 0. It likely uses 0.10
to fit a model close to ridge, but with some potential for sparsity.
Here, the best result uses 𝛼 = 0.10, so this result is somewhere between ridge
and lasso, but closer to ridge.
hit_elnet_int = train(
Salary ~ . ^ 2, data = Hitters,
method = "glmnet",
trControl = cv_5,
tuneLength = 10
)
Now we try a much larger model search. First, we’re expanding the feature
space to include all interactions. Since we are using penalized regression, we
don’t have to worry as much about overfitting. If many of the added variables
are not useful, we will likely use a model close to lasso which makes many of
them 0.
We’re also using a larger tuning grid. By setting tuneLength = 10, we will
search 10 𝛼 values and 10 𝜆 values for each. Because of this larger tuning grid,
the results will be very large.
To deal with this, we write a quick helper function to extract the row with the
best tuning parameters.
get_best_result = function(caret_fit) {
best = which(rownames(caret_fit$results) == rownames(caret_fit$bestTune))
best_result = caret_fit$results[best, ]
rownames(best_result) = NULL
best_result
}
get_best_result(hit_elnet_int)$RMSE
## [1] 306.9
The commented line is not run, since it produces a lot of output, but if run, it will show that the vast majority of the coefficients are zero! Also, you’ll notice that cv.glmnet() does not respect the usual predictor hierarchy. Not a problem for prediction, but a massive interpretation issue!
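The cv.glmnet() fit being discussed here is not shown; a sketch, assuming a lasso fit to the same all-two-way-interaction feature space used for hit_elnet_int, is:
X_int = model.matrix(Salary ~ . ^ 2, Hitters)[, -1]
fit_lasso_cv = cv.glmnet(X_int, Hitters$Salary, alpha = 1)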
#coef(fit_lasso_cv)
sum(coef(fit_lasso_cv) != 0)
## [1] 8
sum(coef(fit_lasso_cv) == 0)
## [1] 183
25.2 Classification
Above, we have performed a regression task. But like lasso and ridge, elastic net
can also be used for classification by using the deviance instead of the residual
sum of squares. This essentially happens automatically in caret if the response
variable is a factor.
We’ll test this using the familiar Default dataset, which we first test-train split.
set.seed(42)
default_idx = createDataPartition(Default$default, p = 0.75, list = FALSE)
default_trn = Default[default_idx, ]
default_tst = Default[-default_idx, ]
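The elastic net fit whose results are printed below is not shown; a sketch, parallel to the Hitters fit above, is:
def_elnet = train(
  default ~ ., data = default_trn,
  method = "glmnet",
  trControl = cv_5
)
def_elnet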
## glmnet
##
## 7501 samples
## 3 predictor
## 2 classes: 'No', 'Yes'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 6001, 6001, 6001, 6000, 6001
## Resampling results across tuning parameters:
##
## alpha lambda Accuracy Kappa
## 0.10 0.0001253 0.9732 0.41637
## 0.10 0.0012527 0.9727 0.37280
## 0.10 0.0125270 0.9676 0.07238
## 0.55 0.0001253 0.9735 0.42510
## 0.55 0.0012527 0.9727 0.38012
## 0.55 0.0125270 0.9679 0.09251
## 1.00 0.0001253 0.9736 0.42638
## 1.00 0.0012527 0.9725 0.38888
## 1.00 0.0125270 0.9692 0.16987
##
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were alpha = 1 and lambda = 0.0001253.
Since the best model used 𝛼 = 1, this is a lasso model.
We also try an expanded feature space, and a larger tuning grid.
def_elnet_int = train(
default ~ . ^ 2, data = default_trn,
method = "glmnet",
trControl = cv_5,
tuneLength = 10
)
Since the result here will return 100 models, we again use our helper function to simply extract the best result.
get_best_result(def_elnet_int)
Evaluating the test accuracy of this model, we obtain one of the highest accu-
racies for this dataset of all methods we have tried.
# test acc
calc_acc(actual = default_tst$default,
predicted = predict(def_elnet_int, newdata = default_tst))
## [1] 0.9736
25.4 rmarkdown
The rmarkdown file for this chapter can be found here. The file was created
using R version 4.0.2. The following packages (and their dependencies) were
loaded when knitting this file:
## [1] "glmnet" "Matrix" "caret" "ggplot2" "lattice"
Chapter 26
Trees
Chapter Status: This chapter was originally written using the tree package. Currently being re-written to exclusively use the rpart package, which seems more widely suggested and provides better plotting features.
library(tree)
In this document, we will use the package tree for both classification and
regression trees. Note that there are many packages to do this in R. rpart
may be the most common, however, we will use tree for simplicity.
To understand classification trees, we will use the Carseats dataset from the ISLR package. We will first modify the response variable Sales from its original use as a numerical variable to a categorical variable, with High for high sales and Low for low sales; a sketch of this recoding appears below.
data(Carseats)
#?Carseats
str(Carseats)
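The recoding itself is not shown; a typical version, where the cutoff of 8 follows the ISLR convention and is an assumption here, is:
# recode Sales as a two-level factor: High if above 8 (thousand units), else Low
Carseats$Sales = as.factor(ifelse(Carseats$Sales <= 8, "Low", "High"))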
We first fit an unpruned classification tree using all of the predictors. Details of
this process can be found using ?tree and ?tree.control
seat_tree = tree(Sales ~ ., data = Carseats)
# seat_tree = tree(Sales ~ ., data = Carseats,
# control = tree.control(nobs = nrow(Carseats), minsize = 10))
summary(seat_tree)
##
## Classification tree:
## tree(formula = Sales ~ ., data = Carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "US" "Income" "CompPrice"
## [6] "Population" "Advertising" "Age"
## Number of terminal nodes: 27
## Residual mean deviance: 0.4575 = 170.7 / 373
## Misclassification error rate: 0.09 = 36 / 400
We see this tree has 27 terminal nodes and a misclassification rate of 0.09.
plot(seat_tree)
text(seat_tree, pretty = 0)
title(main = "Unpruned Classification Tree")
(Plot: the unpruned classification tree.)
Above we plot the tree. Below we output the details of the splits.
seat_tree
We now test-train split the data so we can evaluate how well our tree is working.
We use 200 observations for each.
dim(Carseats)
## [1] 400 11
set.seed(2)
seat_idx = sample(1:nrow(Carseats), 200)
seat_trn = Carseats[seat_idx,]
seat_tst = Carseats[-seat_idx,]
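The tree summarized next is refit using only the training data; a sketch of that refit (the text below notes it differs slightly from the tree fit to all of the data) is:
seat_tree = tree(Sales ~ ., data = seat_trn)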
summary(seat_tree)
##
## Classification tree:
Note that the tree is not using all of the available variables.
summary(seat_tree)$used
Also notice that this new tree is slightly different than the tree fit to all of the data.
plot(seat_tree)
text(seat_tree, pretty = 0)
title(main = "Unpruned Classification Tree")
(Plot: the unpruned classification tree fit to the training data.)
When using the predict() function on a tree, the default type is vector, which gives predicted probabilities for both classes. We will use type = "class" to directly obtain classes. We first fit the tree using the training data (above), then obtain predictions on both the train and test sets, then view the confusion matrix for both.
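The prediction step itself is a sketch here, assuming the obvious calls:
seat_trn_pred = predict(seat_tree, seat_trn, type = "class")
seat_tst_pred = predict(seat_tree, seat_tst, type = "class")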
# train confusion
table(predicted = seat_trn_pred, actual = seat_trn$Sales)
## actual
## predicted High Low
## High 67 8
## Low 14 111
# test confusion
table(predicted = seat_tst_pred, actual = seat_tst$Sales)
## actual
## predicted High Low
## High 51 12
## Low 32 105
accuracy = function(actual, predicted) {
mean(actual == predicted)
}
# train acc
accuracy(predicted = seat_trn_pred, actual = seat_trn$Sales)
## [1] 0.89
# test acc
accuracy(predicted = seat_tst_pred, actual = seat_tst$Sales)
## [1] 0.78
Here it is easy to see that the tree has been over-fit. The train set performs
much better than the test set.
## [1] 1
## [1] 21
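The cross-validation object used below, seat_tree_cv, is created with cv.tree(); a sketch, assuming misclassification is used as the pruning criterion, is:
seat_tree_cv = cv.tree(seat_tree, FUN = prune.misclass)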
# misclassification rate of each tree
seat_tree_cv$dev / length(seat_idx)
## [1] 0.375 0.380 0.405 0.405 0.375 0.385 0.390 0.425 0.405
par(mfrow = c(1, 2))
# default plot
plot(seat_tree_cv)
# better plot
plot(seat_tree_cv$size, seat_tree_cv$dev / nrow(seat_trn), type = "b",
xlab = "Tree Size", ylab = "CV Misclassification Rate")
(Plots: the default cv.tree plot, and cross-validated misclassification rate versus tree size.)
It appears that a tree of size 9 has the fewest misclassifications of the considered
trees, via cross-validation.
We use prune.misclass() to obtain that tree from our original tree, and plot
this smaller tree.
seat_tree_prune = prune.misclass(seat_tree, best = 9)
summary(seat_tree_prune)
##
## Classification tree:
(Plot: the pruned classification tree with 9 terminal nodes.)
We again obtain predictions using this smaller tree, and evaluate on the test
and train sets.
# train
seat_prune_trn_pred = predict(seat_tree_prune, seat_trn, type = "class")
table(predicted = seat_prune_trn_pred, actual = seat_trn$Sales)
## actual
## predicted High Low
## High 62 16
## Low 19 103
accuracy(predicted = seat_prune_trn_pred, actual = seat_trn$Sales)
## [1] 0.825
# test
seat_prune_tst_pred = predict(seat_tree_prune, seat_tst, type = "class")
table(predicted = seat_prune_tst_pred, actual = seat_tst$Sales)
## actual
## predicted High Low
## High 58 20
## Low 25 97
accuracy(predicted = seat_prune_tst_pred, actual = seat_tst$Sales)
## [1] 0.775
The train set has performed almost as well as before, and the test set performance is essentially unchanged, but it is still obvious that we have over-fit. Trees tend to do this. We will look at several ways to fix this, including bagging, boosting, and random forests.
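26.2 Regression Trees
We now move to regression trees, using the Boston housing data from the MASS package. A sketch of the setup that produces the summary below, with the split assumed to match the one used in the ensemble methods chapter, is:
library(MASS)
set.seed(18)
boston_idx = sample(1:nrow(Boston), nrow(Boston) / 2)
boston_trn = Boston[boston_idx, ]
boston_tst = Boston[-boston_idx, ]
boston_tree = tree(medv ~ ., data = boston_trn)
summary(boston_tree)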
##
## Regression tree:
## tree(formula = medv ~ ., data = boston_trn)
## Variables actually used in tree construction:
## [1] "lstat" "rm" "dis" "tax" "crim"
## Number of terminal nodes: 8
## Residual mean deviance: 12.2 = 2988 / 245
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -10.25000 -2.35500 -0.06778 0.00000 1.87700 15.31000
plot(boston_tree)
text(boston_tree, pretty = 0)
title(main = "Unpruned Regression Tree")
(Plots: the unpruned regression tree, and cross-validated RMSE versus tree size.)
While the tree of size 9 does have the lowest RMSE, we’ll prune to a size of 7
as it seems to perform just as well. (Otherwise we would not be pruning.) The
pruned tree is, as expected, smaller and easier to interpret.
boston_tree_prune = prune.tree(boston_tree, best = 7)
summary(boston_tree_prune)
##
## Regression tree:
## snip.tree(tree = boston_tree, nodes = 4L)
## Variables actually used in tree construction:
## [1] "lstat" "rm" "tax" "crim"
## Number of terminal nodes: 7
## Residual mean deviance: 13.35 = 3284 / 246
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -10.2500 -2.3680 -0.2229 0.0000 1.8770 17.1000
plot(boston_tree_prune)
text(boston_tree_prune, pretty = 0)
title(main = "Pruned Regression Tree")
(Plot: the pruned regression tree with 7 terminal nodes.)
Let’s compare this regression tree to an additive linear model and use RMSE
as our metric.
rmse = function(actual, predicted) {
sqrt(mean((actual - predicted) ^ 2))
}
We obtain predictions on the train and test sets from the pruned tree. We also
plot actual vs predicted. This plot may look odd. We’ll compare it to a plot for
linear regression below.
# training RMSE two ways
sqrt(summary(boston_tree_prune)$dev / nrow(boston_trn))
## [1] 3.603014
boston_prune_trn_pred = predict(boston_tree_prune, newdata = boston_trn)
rmse(boston_prune_trn_pred, boston_trn$medv)
## [1] 3.603014
# test RMSE
boston_prune_tst_pred = predict(boston_tree_prune, newdata = boston_tst)
rmse(boston_prune_tst_pred, boston_tst$medv)
## [1] 5.477353
plot(boston_prune_tst_pred, boston_tst$medv, xlab = "Predicted", ylab = "Actual")
abline(0, 1)
(Plot: actual versus predicted medv for the pruned tree on the test data.)
Here, using an additive linear regression the actual vs predicted looks much
more like what we are used to.
bostom_lm = lm(medv ~ ., data = boston_trn)
boston_lm_pred = predict(bostom_lm, newdata = boston_tst)
plot(boston_lm_pred, boston_tst$medv, xlab = "Predicted", ylab = "Actual")
abline(0, 1)
(Plot: actual versus predicted medv for the additive linear model on the test data.)
rmse(boston_lm_pred, boston_tst$medv)
## [1] 5.016083
We also see a lower test RMSE. The most obvious linear regression beats the
tree! Again, we’ll improve on this tree soon. Also note the summary of the
additive linear regression below. Which is easier to interpret, that output, or
the small tree above?
coef(bostom_lm)
library(rpart)
set.seed(430)
# Fit a decision tree using rpart
# Note: when you fit a tree using rpart, the fitting routine automatically
# performs 10-fold CV and stores the errors for later use
# (such as for pruning the tree)
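The fit itself and the CV plot call are a sketch here (the rpart defaults are assumed):
seat_rpart = rpart(Sales ~ ., data = seat_trn, method = "class")
# plot the cross-validated relative error against cp
plotcp(seat_rpart)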
(Plot: cross-validated relative error versus cp and tree size, from plotcp().)
# find the cp value with the lowest cross-validated error
min_cp = seat_rpart$cptable[which.min(seat_rpart$cptable[, "xerror"]), "CP"]
min_cp
## [1] 0.03703704
# prune tree using best cp
seat_rpart_prune = prune(seat_rpart, cp = min_cp)
# nicer plots
library(rpart.plot)
prp(seat_rpart_prune)
(Plot: the pruned rpart tree via prp().)
prp(seat_rpart_prune, type = 4)
(Plot: the pruned rpart tree via prp() with type = 4, showing split labels on the branches.)
rpart.plot(seat_rpart_prune)
(Plot: the pruned rpart tree via rpart.plot(), showing the predicted class, class probability, and percentage of observations in each node.)
26.5 rmarkdown
The rmarkdown file for this chapter can be found here. The file was created
using R version 4.0.2. The following packages (and their dependencies) were
loaded when knitting this file:
## [1] "rpart.plot" "rpart" "MASS" "ISLR" "tree"
Chapter 27
Ensemble Methods
27.1 Regression
We first consider the regression case, using the Boston data from the MASS
package. We will use RMSE as our metric, so we write a function which will
help us along the way.
calc_rmse = function(actual, predicted) {
sqrt(mean((actual - predicted) ^ 2))
}
We first test-train split the data and fit a single tree using rpart.
set.seed(18)
boston_idx = sample(1:nrow(Boston), nrow(Boston) / 2)
boston_trn = Boston[boston_idx,]
boston_tst = Boston[-boston_idx,]
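The single tree fit and its test RMSE are a sketch here, with the object names matching those used in the results table at the end of this section:
boston_tree = rpart(medv ~ ., data = boston_trn)
boston_tree_tst_pred = predict(boston_tree, newdata = boston_tst)
(tree_tst_rmse = calc_rmse(boston_tree_tst_pred, boston_tst$medv))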
(Plot: predicted versus actual medv for the single tree on the test data.)
## [1] 5.051138
For comparison, we again fit an additive linear model, which we saw do much better than a single tree in the previous chapter.
boston_lm = lm(medv ~ ., data = boston_trn)
(Plot: predicted versus actual medv for the linear model on the test data.)
## [1] 5.016083
27.1.3 Bagging
We now fit a bagged model, using the randomForest package. Bagging is actu-
ally a special case of a random forest where mtry is equal to 𝑝, the number of
predictors.
boston_bag = randomForest(medv ~ ., data = boston_trn, mtry = 13,
importance = TRUE, ntrees = 500)
boston_bag
##
## Call:
## randomForest(formula = medv ~ ., data = boston_trn, mtry = 13, importance = TR
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 13
##
## Mean of squared residuals: 13.79736
## % Var explained: 82.42
boston_bag_tst_pred = predict(boston_bag, newdata = boston_tst)
plot(boston_bag_tst_pred,boston_tst$medv,
xlab = "Predicted", ylab = "Actual",
main = "Predicted vs Actual: Bagged Model, Test Data",
col = "dodgerblue", pch = 20)
grid()
abline(0, 1, col = "darkorange", lwd = 2)
(Plot: predicted versus actual medv for the bagged model on the test data.)
(bag_tst_rmse = calc_rmse(boston_bag_tst_pred, boston_tst$medv))
## [1] 3.905538
Here we see two interesting results. First, the predicted versus actual plot no longer has only a small number of unique predicted values. Second, our test error has dropped dramatically. Also note that the “Mean of squared residuals” which is output by randomForest is the Out of Bag estimate of the error.
plot(boston_bag, col = "dodgerblue", lwd = 2, main = "Bagged Trees: Error vs Number of Trees")
grid()
(Plot: OOB error versus the number of trees for the bagged model.)
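27.1.4 Random Forest
We now fit a random forest. The call below is reconstructed from the Call echoed in the output that follows (mtry = 4 instead of all 13 predictors):
boston_forest = randomForest(medv ~ ., data = boston_trn, mtry = 4,
                             importance = TRUE, ntrees = 500)
boston_forest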
##
## Call:
## randomForest(formula = medv ~ ., data = boston_trn, mtry = 4, importance = TRUE, ntrees
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 4
##
## Mean of squared residuals: 12.629
## % Var explained: 83.91
importance(boston_forest, type = 1)
## %IncMSE
## crim 14.451052
## zn 2.878652
## indus 10.258393
## chas 1.317298
## nox 12.400294
## rm 27.137361
## age 10.473007
## dis 12.568593
## rad 5.120156
## tax 6.960258
## ptratio 10.684564
## black 7.750034
## lstat 28.943216
varImpPlot(boston_forest, type = 1)
(Plot: variable importance, %IncMSE, for the random forest.)
(Plot: predicted versus actual medv for the random forest on the test data.)
boston_forest_tst_pred = predict(boston_forest, newdata = boston_tst)
(forest_tst_rmse = calc_rmse(boston_forest_tst_pred, boston_tst$medv))
## [1] 4.172905
boston_forest_trn_pred = predict(boston_forest, newdata = boston_trn)
forest_trn_rmse = calc_rmse(boston_forest_trn_pred, boston_trn$medv)
forest_oob_rmse = calc_rmse(boston_forest$predicted, boston_trn$medv)
Here we note three RMSEs: the training RMSE (which is optimistic), the OOB RMSE (which is a reasonable estimate of the test error), and the test RMSE. Also note that variable importance was calculated.
## Data Error
## 1 Training 1.583693
## 2 OOB 3.553731
## 3 Test 4.172905
27.1.5 Boosting
Lastly, we try a boosted model, which by default will produce a nice variable
importance plot as well as plots of the marginal effects of the predictors. We
use the gbm package.
booston_boost = gbm(medv ~ ., data = boston_trn, distribution = "gaussian",
n.trees = 5000, interaction.depth = 4, shrinkage = 0.01)
booston_boost
(Plot: relative influence of each predictor in the boosted model.)
## # A tibble: 13 x 2
## var rel.inf
## <chr> <dbl>
## 1 lstat 44.3
## 2 rm 26.8
## 3 dis 5.70
## 4 crim 5.00
## 5 nox 4.80
## 6 black 3.72
## 7 age 3.16
## 8 ptratio 2.66
## 9 tax 2.11
## 10 indus 0.869
## 11 rad 0.735
## 12 zn 0.165
## 13 chas 0.0440
par(mfrow = c(1, 3))
plot(booston_boost, i = "rm", col = "dodgerblue", lwd = 2)
plot(booston_boost, i = "lstat", col = "dodgerblue", lwd = 2)
plot(booston_boost, i = "dis", col = "dodgerblue", lwd = 2)
(Plots: marginal effects of rm, lstat, and dis from the boosted model.)
boston_boost_tst_pred = predict(booston_boost, newdata = boston_tst, n.trees = 5000)
(boost_tst_rmse = calc_rmse(boston_boost_tst_pred, boston_tst$medv))
## [1] 3.656622
plot(boston_boost_tst_pred, boston_tst$medv,
xlab = "Predicted", ylab = "Actual",
main = "Predicted vs Actual: Boosted Model, Test Data",
col = "dodgerblue", pch = 20)
grid()
abline(0, 1, col = "darkorange", lwd = 2)
(Plot: predicted versus actual medv for the boosted model on the test data.)
27.1.6 Results
(boston_rmse = data.frame(
Model = c("Single Tree", "Linear Model", "Bagging", "Random Forest", "Boosting"),
TestError = c(tree_tst_rmse, lm_tst_rmse, bag_tst_rmse, forest_tst_rmse, boost_tst_rmse)
)
)
## Model TestError
## 1 Single Tree 5.051138
## 2 Linear Model 5.016083
## 3 Bagging 3.905538
## 4 Random Forest 4.172905
## 5 Boosting 3.656622
While a single tree does not beat linear regression, each of the ensemble methods performs much better!
27.2 Classification
We now return to the Carseats dataset and the classification setting. We see
that an additive logistic regression performs much better than a single tree, but
we expect ensemble methods to bring trees closer to the logistic regression. Can
they do better?
set.seed(2)
seat_idx = sample(1:nrow(Carseats), 200)
seat_trn = Carseats[seat_idx,]
seat_tst = Carseats[-seat_idx,]
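The single classification tree plotted below is a sketch, assuming a default rpart fit to the training data:
seat_tree = rpart(Sales ~ ., data = seat_trn, method = "class")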
rpart.plot(seat_tree)
(Plot: the single classification tree for the Carseats training data, via rpart.plot().)
seat_tree_tst_pred = predict(seat_tree, seat_tst, type = "class")
table(predicted = seat_tree_tst_pred, actual = seat_tst$Sales)
## actual
## predicted High Low
## High 58 20
## Low 25 97
(tree_tst_acc = calc_acc(predicted = seat_tree_tst_pred, actual = seat_tst$Sales))
## [1] 0.775
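The confusion matrix and accuracy that follow come from an additive logistic regression whose code is not shown; a sketch, assuming a 0.5 probability cutoff, is:
seat_glm = glm(Sales ~ ., data = seat_trn, family = "binomial")
seat_glm_prob = predict(seat_glm, newdata = seat_tst, type = "response")
# "Low" is the second factor level, so the modeled probability is P(Low)
seat_glm_tst_pred = ifelse(seat_glm_prob > 0.5, "Low", "High")
table(predicted = seat_glm_tst_pred, actual = seat_tst$Sales)
(glm_tst_acc = calc_acc(predicted = seat_glm_tst_pred, actual = seat_tst$Sales))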
## actual
## predicted High Low
## High 72 6
## Low 11 111
## [1] 0.915
27.2.3 Bagging
seat_bag = randomForest(Sales ~ ., data = seat_trn, mtry = 10,
importance = TRUE, ntrees = 500)
seat_bag
##
## Call:
## randomForest(formula = Sales ~ ., data = seat_trn, mtry = 10, importance = TRUE, ntrees
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 10
##
## OOB estimate of error rate: 26%
## Confusion matrix:
## High Low class.error
## High 51 30 0.3703704
## Low 22 97 0.1848739
seat_bag_tst_pred = predict(seat_bag, newdata = seat_tst)
table(predicted = seat_bag_tst_pred, actual = seat_tst$Sales)
## actual
## predicted High Low
## High 62 14
## Low 21 103
(bag_tst_acc = calc_acc(predicted = seat_bag_tst_pred, actual = seat_tst$Sales))
## [1] 0.825
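27.2.4 Random Forest
As in the regression setting, the random forest call is echoed in the output below; reconstructing it (mtry = 3, roughly the square root of the number of predictors):
seat_forest = randomForest(Sales ~ ., data = seat_trn, mtry = 3,
                           importance = TRUE, ntrees = 500)
seat_forest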
##
## Call:
## randomForest(formula = Sales ~ ., data = seat_trn, mtry = 3, importance = TRUE, ntrees =
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 3
##
## OOB estimate of error rate: 28.5%
## Confusion matrix:
## High Low class.error
## High 44 37 0.4567901
## Low 20 99 0.1680672
seat_forest_tst_perd = predict(seat_forest, newdata = seat_tst)
table(predicted = seat_forest_tst_perd, actual = seat_tst$Sales)
## actual
## predicted High Low
## High 58 8
## Low 25 109
(forest_tst_acc = calc_acc(predicted = seat_forest_tst_perd, actual = seat_tst$Sales))
## [1] 0.835
27.2.5 Boosting
To perform boosting, we modify the response to be 0 and 1 to work with gbm.
Later we will use caret to fit gbm models, which will avoid this annoyance.
seat_trn_mod = seat_trn
seat_trn_mod$Sales = as.numeric(ifelse(seat_trn_mod$Sales == "Low", "0", "1"))
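The boosted fit and its test predictions are a sketch here, mirroring the tuning parameters used for the regression gbm above (those values are an assumption for this model):
seat_boost = gbm(Sales ~ ., data = seat_trn_mod, distribution = "bernoulli",
                 n.trees = 5000, interaction.depth = 4, shrinkage = 0.01)
seat_boost_prob = predict(seat_boost, newdata = seat_tst, n.trees = 5000,
                          type = "response")
# the 0/1 coding above sets 1 = "High", so large probabilities map to "High"
seat_boost_tst_pred = ifelse(seat_boost_prob > 0.5, "High", "Low")
table(predicted = seat_boost_tst_pred, actual = seat_tst$Sales)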
## actual
## predicted High Low
## High 68 10
## Low 15 107
(boost_tst_acc = calc_acc(predicted = seat_boost_tst_pred, actual = seat_tst$Sales))
## [1] 0.875
27.2.6 Results
(seat_acc = data.frame(
Model = c("Single Tree", "Logistic Regression", "Bagging", "Random Forest", "Boosting"),
TestAccuracy = c(tree_tst_acc, glm_tst_acc, bag_tst_acc, forest_tst_acc, boost_tst_acc)
)
)
## Model TestAccuracy
## 1 Single Tree 0.775
## 2 Logistic Regression 0.915
## 3 Bagging 0.825
## 4 Random Forest 0.835
## 5 Boosting 0.875
Here we see each of the ensemble methods performing better than a single tree,
however, they still fall behind logistic regression. Sometimes a simple linear
model will beat more complicated models! This is why you should always try a
logistic regression for classification.
27.3 Tuning
So far we fit bagging, boosting and random forest models, but did not tune any
of them, we simply used certain, somewhat arbitrary, parameters. Now we will
see how to modify the tuning parameters to make these models better.
• Bagging: Actually just a subset of Random Forest with mtry = 𝑝.
• Random Forest: mtry
• Boosting: n.trees, interaction.depth, shrinkage, n.minobsinnode
We will use the caret package to accomplish this. Technically ntrees is a
tuning parameter for both bagging and random forest, but caret will use 500
by default and there is no easy way to tune it. This will not make a big difference
since for both we simply need “enough” and 500 seems to do the trick.
While mtry is a tuning parameter, there are suggested values for classification
and regression:
• Regression: mtry = 𝑝/3.
• Classification: mtry = √𝑝.
Also note that with these tree-based ensemble methods there are two resampling
solutions for tuning the model:
• Out of Bag
• Cross-Validation
Using Out of Bag samples is advantageous with these methods as compared to Cross-Validation, since it removes the need to refit the model and is thus much faster.
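The two resampling controls used in the rest of this section, oob and cv_5, are assumed to be set up as follows:
oob  = trainControl(method = "oob")
cv_5 = trainControl(method = "cv", number = 5)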
To tune a Random Forest in caret we will use method = "rf" which uses
the randomForest function in the background. Here we elect to use the OOB
training control that we created. We could also use cross-validation, however it
will likely select a similar model, but require much more time.
We setup a grid of mtry values which include all possible values since there are
10 predictors in the dataset. An mtry of 10 is actually bagging.
dim(seat_trn)
## [1] 200 11
rf_grid = expand.grid(mtry = 1:10)
set.seed(825)
seat_rf_tune = train(Sales ~ ., data = seat_trn,
method = "rf",
trControl = oob,
verbose = FALSE,
tuneGrid = rf_grid)
seat_rf_tune
## Random Forest
##
## 200 samples
## 10 predictor
## 2 classes: 'High', 'Low'
##
## No pre-processing
## Resampling results across tuning parameters:
##
## mtry Accuracy Kappa
## 1 0.695 0.3055556
## 2 0.740 0.4337363
## 3 0.720 0.4001071
## 4 0.740 0.4406798
## 5 0.740 0.4474551
## 6 0.735 0.4333975
## 7 0.735 0.4402197
## 8 0.730 0.4308000
## 9 0.710 0.3836999
## 10 0.740 0.4474551
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.
calc_acc(predict(seat_rf_tune, seat_tst), seat_tst$Sales)
## [1] 0.82
The results returned are based on the OOB samples. (Coincidentally, the test
accuracy is the same as the best accuracy found using OOB samples.) Note
that when using OOB, for some reason the default plot is not what you would
expect and is not at all useful. (Which is why it is omitted here.)
seat_rf_tune$bestTune
## mtry
## 2 2
Based on these results, we would select the random forest model with an mtry
of 2. Note that based on the OOB estimates, the bagging model is expected to
perform worse than this selected model, however, based on our results above,
that is not what we find to be true in our test set.
Also note that method = "ranger" would also fit a random forest model.
Ranger is a newer R package for random forests that has been shown to be
much faster, especially when there are a larger number of predictors.
27.3.2 Boosting
We now tune a boosted tree model. We will use the cross-validation tune control
setup above. We will fit the model using gbm with caret.
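The tuning grid itself is not shown; a hypothetical grid in the spirit of the one used in the caret chapter (the exact values here are assumptions) is:
gbm_grid = expand.grid(interaction.depth = 1:3,
                       n.trees = (1:10) * 100,
                       shrinkage = c(0.01, 0.1),
                       n.minobsinnode = 10)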
We now train the model using all possible combinations of the tuning parameters
we just specified.
seat_gbm_tune = train(Sales ~ ., data = seat_trn,
method = "gbm",
trControl = cv_5,
verbose = FALSE,
tuneGrid = gbm_grid)
The additional verbose = FALSE in the train call suppresses additional output
from each gbm call.
By default, calling plot here will produce a nice graphic summarizing the re-
sults.
plot(seat_gbm_tune)
(Plot: cross-validated accuracy versus boosting iterations, across the tuning grid.)
calc_acc(predict(seat_gbm_tune, seat_tst), seat_tst$Sales)
## [1] 0.84
We see our tuned model does no better on the test set than the arbitrary boosted
model we had fit above, with the slightly different parameters seen below. We
could perhaps try a larger tuning grid, but at this point it seems unlikely that
we could find a much better model. There seems to be no way to get a tree
method to out-perform logistic regression in this dataset.
seat_gbm_tune$bestTune
library(mlbench)
27.4 Tree versus Ensemble Boundaries
set.seed(42)
sim_trn = mlbench.circle(n = 1000, d = 2)
sim_trn = data.frame(sim_trn$x, class = as.factor(sim_trn$classes))
sim_tst = mlbench.circle(n = 1000, d = 2)
sim_tst = data.frame(sim_tst$x, class = as.factor(sim_tst$classes))
(Plot: the simulated circle data, X2 versus X1, colored by class.)
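The tuned single tree referenced below is a sketch, assuming caret's rpart method with the same 5-fold control used for the boosted model:
sim_tree_cv = train(class ~ .,
                    data = sim_trn,
                    method = "rpart",
                    trControl = cv_5)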
library(rpart.plot)
rpart.plot(sim_tree_cv$finalModel)
(Plot: the cross-validated single tree for the simulated data, via rpart.plot().)
sim_gbm_cv = train(class ~ .,
data = sim_trn,
method = "gbm",
trControl = cv_5,
verbose = FALSE,
tuneGrid = gbm_grid)
plot_grid = expand.grid(
X1 = seq(min(sim_tst$X1) - 1, max(sim_tst$X1) + 1, by = 0.01),
X2 = seq(min(sim_tst$X2) - 1, max(sim_tst$X2) + 1, by = 0.01)
)
(Plots: predicted class regions over plot_grid for the fitted models, shown side by side.)
27.6 rmarkdown
The rmarkdown file for this chapter can be found here. The file was created
using R version 4.0.2. The following packages (and their dependencies) were
loaded when knitting this file:
## [1] "mlbench" "ISLR" "MASS" "caret" "ggplot2"
## [6] "lattice" "gbm" "randomForest" "rpart.plot" "rpart"
Chapter 28
Artificial Neural Networks
Part VII
Appendix
Chapter 29
Overview
TODO: Add a section about “coding” tips and tricks. For example: beware
when using code you found on the internet.
TODO: Add a section about ethics in machine learning
• https://www.newyorker.com/news/daily-comment/the-ai-gaydar-study-and-the-real-dangers-of-big-data
• https://www.propublica.org/article/facebook-enabled-advertisers-to-reach-jew-haters
Chapter 30
Non-Linear Models
Chapter 31
Regularized Discriminant
Analysis
We now use the Sonar dataset from the mlbench package to explore a new regularization method, regularized discriminant analysis (RDA), which combines LDA and QDA. This is similar to how elastic net combines the ridge and lasso.
library(mlbench)
library(caret)
library(glmnet)
library(klaR)
data(Sonar)
#View(Sonar)
table(Sonar$Class) / nrow(Sonar)
##
## M R
## 0.5336538 0.4663462
ncol(Sonar) - 1
## [1] 60
31.2 RDA
Regularized discriminant analysis uses the same general setup as LDA and QDA but estimates the covariance in a new way, which combines the covariance of QDA ($\hat{\Sigma}_k$) with the covariance of LDA ($\hat{\Sigma}$) using a tuning parameter 𝜆.
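Concretely, the 𝜆-mixed covariance referred to here takes the usual RDA form (this display is added for clarity):

$$
\hat{\Sigma}_k(\lambda) = (1 - \lambda) \hat{\Sigma}_k + \lambda \hat{\Sigma}
$$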
Using the rda() function from the klaR package, which caret utilizes, makes
an additional modification to the covariance matrix, which also has a tuning
parameter 𝛾.
$$
\hat{\Sigma}_k(\lambda, \gamma) = (1 - \gamma) \hat{\Sigma}_k(\lambda) + \gamma \frac{1}{p} \text{tr}\left(\hat{\Sigma}_k(\lambda)\right) I
$$
Both 𝛾 and 𝜆 can be thought of as mixing parameters, as they both take values
between 0 and 1. For the four extremes of 𝛾 and 𝜆, the covariance structure
reduces to special cases:
• (𝛾 = 0, 𝜆 = 0): QDA - individual covariance for each group.
• (𝛾 = 0, 𝜆 = 1): LDA - a common covariance matrix.
• (𝛾 = 1, 𝜆 = 0): Conditional independent variables - similar to Naive Bayes,
but variable variances within group (main diagonal elements) are all equal.
• (𝛾 = 1, 𝜆 = 1): Classification using euclidean distance - as in previous
case, but variances are the same for all groups. Objects are assigned to
group with nearest mean.
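The grid-search control used below, cv_5_grid, is assumed to be a standard 5-fold setup:
cv_5_grid = trainControl(method = "cv", number = 5)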
set.seed(1337)
fit_rda_grid = train(Class ~ ., data = Sonar, method = "rda", trControl = cv_5_grid)
fit_rda_grid
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 167, 166, 166, 167, 166
## Resampling results across tuning parameters:
##
## gamma lambda Accuracy Kappa
## 0.0 0.0 0.6977933 0.3791172
## 0.0 0.5 0.7644599 0.5259800
## 0.0 1.0 0.7310105 0.4577198
## 0.5 0.0 0.7885017 0.5730052
## 0.5 0.5 0.8271777 0.6502693
## 0.5 1.0 0.7988386 0.5939209
## 1.0 0.0 0.6732869 0.3418352
## 1.0 0.5 0.6780488 0.3527778
## 1.0 1.0 0.6825784 0.3631626
##
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were gamma = 0.5 and lambda = 0.5.
plot(fit_rda_grid)
(Plot: cross-validated accuracy versus gamma, grouped by lambda, for the grid search.)
set.seed(1337)
cv_5_rand = trainControl(method = "cv", number = 5, search = "random")
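The random-search fit compared in the results section below is a sketch; the tuneLength value is an assumption:
set.seed(1337)
fit_rda_rand = train(Class ~ ., data = Sonar, method = "rda",
                     trControl = cv_5_rand, tuneLength = 9)
ggplot(fit_rda_rand)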
(Plot: random-search results, cross-validated accuracy over the gamma-lambda space.)
set.seed(1337)
fit_elnet_int_grid = train(Class ~ . ^ 2, data = Sonar, method = "glmnet",
trControl = cv_5_grid, tuneLength = 10)
31.6 Results
get_best_result = function(caret_fit) {
best_result = caret_fit$results[as.numeric(rownames(caret_fit$bestTune)), ]
rownames(best_result) = NULL
best_result
}
knitr::kable(rbind(
get_best_result(fit_rda_grid),
get_best_result(fit_rda_rand)))
31.8 RMarkdown
The RMarkdown file for this chapter can be found here. The file was created
using R version 4.0.2 and the following packages:
• Base Packages, Attached
## [1] "stats" "graphics" "grDevices" "utils" "datasets" "methods"
## [7] "base"
• Additional Packages, Attached
## [1] "klaR" "MASS" "glmnet" "Matrix" "caret" "ggplot2" "lattice"
## [8] "mlbench"
• Additional Packages, Not Attached
## [1] "splines" "foreach" "prodlim" "shiny" "highr"
## [6] "stats4" "yaml" "ipred" "pillar" "glue"
## [11] "pROC" "digest" "promises" "colorspace" "recipes"
## [16] "htmltools" "httpuv" "plyr" "timeDate" "pkgconfig"
## [21] "labelled" "haven" "questionr" "bookdown" "purrr"
## [26] "xtable" "scales" "later" "gower" "lava"
## [31] "tibble" "combinat" "generics" "farver" "ellipsis"
## [36] "withr" "nnet" "survival" "magrittr" "crayon"
## [41] "mime" "evaluate" "nlme" "forcats" "class"
## [46] "tools" "data.table" "hms" "lifecycle" "stringr"
## [51] "munsell" "compiler" "e1071" "rlang" "grid"
## [56] "iterators" "rstudioapi" "miniUI" "labeling" "rmarkdown"
## [61] "gtable" "ModelMetrics" "codetools" "reshape2" "R6"