We are making predictions all the time, often without realizing it. For example, imagine we are waiting at a bus stop and want to guess how long it will be before a bus arrives. We can combine many sources of evidence,
To think about the process formally, we could imagine a vector \(\mathbf{x}_i \in \mathbb{R}^{D}\) reflecting \(D\) characteristics of our environment. If we collected data about how long we actually had to wait, call it \(y_i\), for every day in a year, then we would have a dataset \[\begin{align*} \left(\mathbf{x}_1, y_1\right) \\ \left(\mathbf{x}_2, y_2\right) \\ \vdots \\ \left(\mathbf{x}_{365}, y_{365}\right) \\ \end{align*}\] and we could try to summarize the relationship \(\mathbf{x}_i \to y_i\). Methods for making this process automatic, based simply on a training dataset, are called supervised learning methods.
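As a small, purely hypothetical sketch, here is how such a dataset might be stored as a table; the column names and values below are made up for illustration.

```python
import pandas as pd

# Hypothetical bus-stop data: one row per day.
waiting = pd.DataFrame({
    "people_at_stop": [3, 0, 7],        # a count feature
    "weather": ["rain", "sun", "sun"],  # a categorical feature
    "wait_minutes": [12.5, 4.0, 8.5],   # the response y_i
})
waiting
```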
In the above example, the inputs were a mix of count (the number of people at the stop) and categorical (the weather) data types, and our response was a nonnegative continuous value. In general, we could have arbitrary data types for either the input or the response variables. A few types of outputs are so common that they come with their own names,
For example,
There are in fact many other types of responses (ordinal, multiresponse, survival, functional, image-to-image, …), each of which comes with its own name and set of methods, but for our purposes, it’s enough to focus on regression and classification.
There is a nice geometric way of thinking about supervised learning. For regression, think of the inputs on the \(x\)-axis and the response on the \(y\)-axis. Regression then becomes the problem of estimating a one-dimensional curve from data.
In higher dimensions, this becomes a surface.
If some of the inputs are categorical (e.g., poor vs. good weather), then the regression function is no longer a continuous curve, but we can still identify group means.
For classification, the view in higher dimensions is analogous. We just want to find boundaries between regions with clearly distinct colors. For example, for disease recurrence, blood pressure and resting heart rate might be enough to make a good guess about whether a patient will have a recurrence or not.
When we have many input features, the equivalent formula is \[\begin{align*} f_{b}\left(x\right) = b_0 + b_1 x_1 + \dots + b_{D}x_{D} := b^{T}x, \end{align*}\] where I’ve used the dot product from linear algebra to simplify notation (after having appended a 1 to \(x\)). This kind of model is called a linear regression model.
To estimate \(b\), we search for the coefficients that minimize the squared-error loss \[\begin{align*} L\left(b\right) = \sum_{i = 1}^{N} \left(y_i - b^{T}x_{i}\right)^2. \end{align*}\]
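To see the pieces concretely, here is a small numpy sketch (with made-up numbers) that appends a 1 to each input, computes \(b^{T}x_i\), and evaluates the squared-error loss.

```python
import numpy as np

# Made-up data: N = 3 samples with D = 2 features each.
X = np.array([[0.5, 1.0], [2.0, 0.0], [1.0, 3.0]])
y = np.array([1.0, 2.0, 4.0])
b = np.array([0.5, 1.0, 0.8])  # b_0 (intercept), b_1, b_2

# Append a column of 1s so that b^T x includes the intercept b_0.
X1 = np.column_stack([np.ones(len(X)), X])
y_hat = X1 @ b                   # predictions b^T x_i
loss = np.sum((y - y_hat) ** 2)  # squared-error loss L(b)
```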
To describe this, we need to define a direction \(b\) perpendicular to the boundary. We will say that whenever \[\begin{align*} f_{b}\left(x\right) = \frac{1}{1 + \exp\left(-b^{T} x\right)} \end{align*}\] is larger than 0.5, we’re in the red region, and whenever it’s smaller than 0.5, we’re in the purple region. This kind of model is called a logistic regression model.
To estimate \(b\), we can instead minimize the cross-entropy loss, \[\begin{align*} -\left[\sum_{i = 1}^{N} y_i \log\left(f_{b}\left(x_i\right)\right) + \left(1 - y_i\right) \log\left(1 - f_{b}\left(x_i\right)\right)\right]. \end{align*}\]
To understand this loss, note that each term decomposes into either the blue or red curve, depending on whether the \(y_i\) is 1 or 0.
If 1 is predicted with probability 1, then there is no loss (and conversely for 0). The loss increases the further the predicted probability is from the true class.
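Here is a small numpy sketch (with made-up probabilities and labels) that evaluates the cross-entropy loss; confident, correct predictions contribute almost nothing, while confident, wrong predictions contribute a lot.

```python
import numpy as np

# Made-up predicted probabilities f_b(x_i) and true labels y_i.
f = np.array([0.9, 0.2, 0.6])
y = np.array([1, 0, 1])

# Cross-entropy loss from the formula above.
loss = -np.sum(y * np.log(f) + (1 - y) * np.log(1 - f))
```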
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import datasets, linear_model

X, y = datasets.load_diabetes(return_X_y=True)
X[:4, :5]  # first five predictors
```

```
array([[ 0.03807591,  0.05068012,  0.06169621,  0.02187239, -0.0442235 ],
       [-0.00188202, -0.04464164, -0.05147406, -0.02632753, -0.00844872],
       [ 0.08529891,  0.05068012,  0.04445121, -0.00567042, -0.04559945],
       [-0.08906294, -0.04464164, -0.01159501, -0.03665608,  0.01219057]])
```

```python
y[:4]  # example response
```

```
array([151.,  75., 141., 206.])
```
Let’s now fit a linear model from \(\mathbf{x}_1, \dots, \mathbf{x}_{N}\) to \(y\). The first line below tells Python that we are using a `LinearRegression` model class. The second searches over coefficients \(b\) to minimize the squared-error loss between the \(b^T x_i\) and \(y_i\). The third line prints out the fitted coefficients \(\hat{b}\).
```python
model = linear_model.LinearRegression()
model.fit(X, y)
```

```
LinearRegression()
```

```python
model.coef_  # fitted b coefficients
```

```
array([ -10.0098663 , -239.81564367,  519.84592005,  324.3846455 ,
       -792.17563855,  476.73902101,  101.04326794,  177.06323767,
        751.27369956,   67.62669218])
```
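As a quick sanity check (a sketch reusing the fitted `model` and `X` above), the predictions from `model.predict` agree with the linear formula \(b_0 + b_1 x_1 + \dots + b_{D}x_{D}\).

```python
# The fitted model applies exactly the linear formula.
manual = model.intercept_ + X @ model.coef_
np.allclose(manual, model.predict(X))  # True
```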
For classification, let’s use the penguins dataset (from the `palmerspenguins` package). We’ll read the data from a public link and print the first few rows.
```python
import pandas as pd

penguins = pd.read_csv("https://raw.githubusercontent.com/krisrs1128/stat679_code/0330ce6257ff077c5d4ed9f102af6be089f5c486/examples/week6/week6-4/penguins.csv")
penguins.head()
```

```
  species     island  bill_length_mm  ...  body_mass_g     sex  year
0  Adelie  Torgersen            39.1  ...         3750    male  2007
1  Adelie  Torgersen            39.5  ...         3800  female  2007
2  Adelie  Torgersen            40.3  ...         3250  female  2007
3  Adelie  Torgersen            36.7  ...         3450  female  2007
4  Adelie  Torgersen            39.3  ...         3650    male  2007

[5 rows x 8 columns]
```
We’ll predict `species` using just bill length and depth. First, let’s make a plot to see how easy / difficult it will be to create a decision boundary.
```r
ggplot(py$penguins) +
  geom_point(aes(bill_length_mm, bill_depth_mm, col = species)) +
  scale_color_manual(values = c("#3DD9BC", "#6DA671", "#F285D5")) +
  labs(x = "Bill Length", y = "Bill Depth")
```
```python
model = linear_model.LogisticRegression()
penguins = penguins.dropna()
X, y = penguins[["bill_length_mm", "bill_depth_mm"]], penguins["species"]
model.fit(X, y)
```

```
LogisticRegression()
```

```python
penguins["y_hat"] = model.predict(X)
```
The plot below compares the predicted class (left, middle, and right panels) with the true class (color). We get most of the samples correct, but have a few misclassifications near the boundaries.
```r
ggplot(py$penguins) +
  geom_point(aes(bill_length_mm, bill_depth_mm, col = species)) +
  scale_color_manual(values = c("#3DD9BC", "#6DA671", "#F285D5")) +
  labs(x = "Bill Length", y = "Bill Depth") +
  facet_wrap(~ y_hat)
```
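To quantify those misclassifications, here is a quick sketch that tabulates true against predicted species with a confusion matrix.

```python
from sklearn.metrics import confusion_matrix

# Rows index the true species, columns the predicted species;
# off-diagonal entries count the misclassified penguins.
confusion_matrix(penguins["species"], penguins["y_hat"])
```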
Exercise: Repeat this classification, but using at least two additional predictors.
To fit a sparse linear model in `sklearn`, we can use the `ElasticNet` class. We’ll work with a dataset of American baseball statistics. The task is to predict each player’s salary based on their batting statistics.

```python
import pandas as pd

baseball = pd.read_csv("https://github.com/krisrs1128/naamii_summer_2023/raw/main/assets/baseball.csv")
baseball.head()
```
```
               player    salary     AtBat  ...   Assists    Errors  NewLeagueN
0         -Alan Ashby -0.135055 -0.601753  ... -0.522196  0.212946    1.073007
1        -Alvin Davis -0.123972  0.511566  ... -0.253380  0.818404   -0.928417
2       -Andre Dawson -0.079637  0.626971  ... -0.742763 -0.846605    1.073007
3   -Andres Galarraga -0.985164 -0.561022  ... -0.542874 -0.695240    1.073007
4    -Alfredo Griffin  0.474541  1.292248  ...  2.083253  2.483412   -0.928417

[5 rows x 21 columns]
```
```python
X, y = baseball.iloc[:, 2:], baseball["salary"]
y = (y - y.mean()) / y.std()  # standardize y
```
The block below fits the Elastic Net model and saves the coefficients \(\hat{b}\). Notice that most of them are 0 – only a few of the features make a big difference in the salary.
```python
from sklearn import datasets, linear_model

model = linear_model.ElasticNet(alpha=1e-1, l1_ratio=0.5)  # in real life, have to tune these parameters
model.fit(X, y)
```

```
ElasticNet(alpha=0.1)
```
```python
y_hat = model.predict(X)
beta_hat = model.coef_  # notice the sparsity
beta_hat
```

```
array([ 0.        ,  0.17423133,  0.        ,  0.        ,  0.        ,
        0.1050172 ,  0.        ,  0.        ,  0.07963389,  0.06285448,
        0.12726386,  0.16911053,  0.        ,  0.        , -0.0988662 ,
        0.12489593, -0.        , -0.        ,  0.        ])
```
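In practice, `alpha` and `l1_ratio` would be tuned by cross-validation rather than fixed by hand. Here is a minimal sketch using sklearn's `ElasticNetCV`, with an arbitrary grid of `l1_ratio` values chosen only for illustration.

```python
from sklearn.linear_model import ElasticNetCV

# Cross-validate over a small (arbitrary) grid of l1_ratio values;
# candidate alphas are generated automatically for each ratio.
cv_model = ElasticNetCV(l1_ratio=[0.1, 0.5, 0.9], cv=5)
cv_model.fit(X, y)
cv_model.alpha_, cv_model.l1_ratio_  # selected regularization settings
```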
We can confirm that the predictions correlate relatively well with the truth.
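One quick way to check this (a sketch) is to compute the correlation between the true and predicted salaries.

```python
import numpy as np

# Correlation between the standardized salaries and the predictions.
np.corrcoef(y, y_hat)[0, 1]
```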
```r
baseball <- data.frame(py$X, y = py$y, y_hat = py$y_hat)
ggplot(baseball) +
  geom_point(aes(y, y_hat)) +
  geom_abline(slope = 1, col = "red") +
  labs(x = "True Salary", y = "Predicted Salary")
```
Notice that we can use the same logic to do either regression or classification. For regression, we associate each “leaf” at the bottom of the tree with a continuous prediction. For classification, we associate leaves with probabilities for different classes. It turns out that we can train these models using squared error and cross-entropy losses as before, though the details are beyond the scope of these notes.
If we split the same variable deeper, it creates more steps,
What if we had two variables? Depending on the order of the splits, we create different axis-aligned partitions,
Q: What would be the diagram if I had switched the order of the splits (traffic before rain)?
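To make the splitting idea concrete, here is a small sketch on made-up data with the two variables discussed above (rain and traffic); sklearn's `export_text` prints the axis-aligned splits of the fitted tree.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, export_text

# Made-up data: waiting time as a function of rain (0/1) and traffic level.
rng = np.random.default_rng(0)
rain = rng.integers(0, 2, size=200)
traffic = rng.uniform(0, 10, size=200)
wait = 5 + 4 * rain + 1.5 * (traffic > 6) + rng.normal(0, 0.5, size=200)

# A shallow tree keeps the printed partition easy to read.
tree = DecisionTreeRegressor(max_depth=2)
tree.fit(np.column_stack([rain, traffic]), wait)
print(export_text(tree, feature_names=["rain", "traffic"]))
```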
A very common variation on tree-based models computes a large ensemble of trees and then combines their curves in some way. How exactly they are combined is beyond the scope of these notes, but this is what random forests and gradient boosted decision trees are doing in the background.
We can implement these models in `sklearn` using `RandomForestRegressor` / `RandomForestClassifier`, and `GradientBoostingRegressor` / `GradientBoostingClassifier`. Let’s just see an example of a boosting classifier using the penguins dataset. The fitting / prediction code is very similar to what we used for the sparse regression.
```python
from sklearn.ensemble import GradientBoostingClassifier

model = GradientBoostingClassifier()
X, y = penguins[["bill_length_mm", "bill_depth_mm"]], penguins["species"]
model.fit(X, y)
```

```
GradientBoostingClassifier()
```

```python
penguins["y_hat"] = model.predict(X)
```
We use the same visualization code to check predictions against the truth. The boosting classifier makes no mistakes on the training data.
```r
ggplot(py$penguins) +
  geom_point(aes(bill_length_mm, bill_depth_mm, col = species)) +
  scale_color_manual(values = c("#3DD9BC", "#6DA671", "#F285D5")) +
  labs(x = "Bill Length", y = "Bill Depth") +
  facet_wrap(~ y_hat)
```
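We can confirm this numerically with a quick sketch: `score` reports the mean accuracy on the data we pass in, which here is the training data.

```python
# Training accuracy; 1.0 corresponds to no mistakes on the training data.
model.score(X, y)
```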
| Method | Strengths | Weaknesses |
|---|---|---|
| Linear / Logistic Regression | Often easy to interpret; no tuning parameters; very fast to train | Unstable when many features to pick from; can only fit linear curves / boundaries (though, see featurization notes) |
| Sparse Linear / Logistic Regression | Often easy to interpret; stable even when many features to pick from; very fast to train | Can only fit linear curves / boundaries |
| Tree-based Classification / Regression | Can fit nonlinear functions of inputs | Can be slow to train; somewhat harder to interpret |
The answers are,