複数のモデルを管理する
purrrとbroomの使い方をマスターするために。以下の記事をトレースする。 内容自体は引用元記事の方がちゃんとしているのでそちらを読んでもらいたい。本記事は、読むにあたって理解が薄い人(=自分)用の補足メモを書きながらのトレースとなる。そのため、本筋外の解説が多く、ごちゃごちゃしていることに留意してもらいたい。
purrrとbroomで複数の回帰モデルを効率的に管理する - Dropout
引用記事にもあるが、複数のモデルを作成する際に個々でオブジェクトを作っていくとデータなどに変更があった際に全てに忘れずに適用しなければいけなくなったり、どのオブジェクトがどの係数になっているかという係数用オブジェクトを各モデル毎に更にオブジェクトで作っていくと更にややこしいことになる。また、モデルの比較も面倒。
そのため、各モデルを1つの計算処理でおこなえるようにしつつ、結果を1つのオブジェクトで一元管理ができるならば非常に便利。
これを実現するために、purrr::map()とbroom::tidy(), broom::glance()を用いる。
library(tidyverse) library(tidymodels) library(patchwork) # 可視化用 theme_scatter = theme_minimal() + theme(panel.grid.minor = element_blank(), axis.text = element_text(color = "black")) theme_minimal2 = theme_scatter + theme(panel.grid.major.x = element_blank()) df = diamonds
データの構造を見る
目的変数を正規分布にする
今回、予測モデルとしてダイヤモンドのpriceを、carat, clarity(透明度ランク)を用いてOLSで推定する。
その際、(本論のモデル管理とは別だが)、OLSを用いるためpriceの確率分布を(できる限り)正規分布にするためにpriceに対数を取って推定する必要がある。
# 正規分布にする df %>% ggplot(aes(log(price))) + geom_histogram(fill = "#56B4E9", color = "white") + theme_minimal2
log(price)とcaratは非線形関係なので、caratにも対数をとり線形にする必要がある。
# 線形回帰のために線形関係にする # カラットvs値段 # カラットが上がるほど値段があがる df %>% sample_frac(0.1) %>% #データが多いので減らす ggplot(aes(log(carat), log(price))) + geom_point(color = "#0072B2", alpha = 0.5) + theme_scatter
直感的には透明度が上がると値段が上がりそうだが、データをみると逆の関係になっている。
これは、透明度が上がるほどカラット数が小さくなる...おそらく、カラット数が大きく透明度が高いダイヤは作りづらい(もしくは需要が少ないからあまり作らない)ということが反映されている。
そのため、カラット数が交絡となり、値段と透明度に負の相関が現れる結果となっている。
# 透明度vs値段 # => 透明度が上がるほど値段がさがる g1 = df %>% ggplot(aes(clarity, log(price))) + geom_boxplot(fill = "#56B4E9") + theme_minimal2 # 透明度vsカラット # =>透明度が上がるほどカラットが下がる g2 = df %>% ggplot(aes(clarity, log(carat))) + geom_boxplot(fill = "#56B4E9") + theme_minimal2 g1 | g2
モデリング
使用するモデルの作成
上記の結果、
- 価格、透明度は対数を取る
- 交絡を考慮し、モデルにはカラット数を入れる
このことを踏まえ、検証できるように以下の3モデルを作成する。
- 1.価格の対数を、透明度のみで説明するモデル
- 2.価格の対数を、透明度とカラット数で説明するモデル
- 3.価格の対数を、透明度とカラット数の対数で説明するモデル
# 回帰用前処理 df_input = df %>% mutate_if(is.ordered, factor, ordered = FALSE)# factor(ordered = FALSE)に変換 formulas = c(log(price) ~ clarity, log(price) ~ clarity + carat, log(price) ~ clarity + log(carat)) %>% enframe("model_no", "formula") #vactor to DF # model_no formula # <int> <list> # 1 1 <S3: formula> # 2 2 <S3: formula> # 3 3 <S3: formula>
enframe()
enframe(x, name = "name", value = "value")
はvector/listをtibbleに変換する関数。上記コードでは、パイプでxを渡し、引数1,2はそれぞれmodel_no, formulaを指定している。これは列名に対応している(デフォルトだと、name, valueになる)。
以下はhelpのサンプルコード
enframe(1:3) # => # # A tibble: 3 x 2 # name value # <int> <int> # 1 1 1 # 2 2 2 # 3 3 3 enframe(c(a = 5, b = 7)) # => # # A tibble: 2 x 2 # name value # <chr> <dbl> # 1 a 5 # 2 b 7 enframe(list(one = 1, two = 2:3, three = 4:6)) # => # # A tibble: 3 x 2 # name value # <chr> <list> # 1 one <dbl [1]> # 2 two <int [2]> # 3 three <int [3]>
回帰モデルの結果の整理
先程作成したformulasをmap()
を用いてデータにfitさせる。
# モデルとその評価を管理しやすい形にする df_result = formulas %>% mutate(model = map(formula, lm, data = df_input), # lm(data = df_input, 各formulas$formula) tidied = map(model, tidy), # 回帰モデルの係数 to tidy tibble glanced = map(model, glance)) # 回帰モデルの評価 # => # # A tibble: 3 x 5 # model_no formula model tidied glanced # <int> <list> <list> <list> <list> # 1 1 <S3: formula> <S3: lm> <tibble [8 × 5]> <tibble [1 × 11]> # 2 2 <S3: formula> <S3: lm> <tibble [9 × 5]> <tibble [1 × 11]> # 3 3 <S3: formula> <S3: lm> <tibble [9 × 5]> <tibble [1 × 11]>
map()
map(.x, .f, ...)
関数を用いると、xのそれぞれのデータに対して関数/formula fを適応させる。また、f以降の引数はfの引数として扱われる。
今回の処理をmap関数を用いないで、formulasの1行目(log(price) ~ clarity)だけに適応させると以下になる。
lm(data = df_input, log(price) ~ clarity) # => # Call: # lm(formula = log(price) ~ clarity, data = df_input) # # Coefficients: # (Intercept) claritySI2 claritySI1 clarityVS2 clarityVS1 clarityVVS2 clarityVVS1 # 8.0276 0.1392 -0.1797 -0.2646 -0.3029 -0.4968 -0.7047 # clarityIF # -0.6225
そのため、model列はformulasの各行に入っているformulaに対して適応させた結果が入っている。
tidy()
tidy(x, ...)
はxのオブジェクトをtidyな形のtibbleに変換する。今回の場合は、model列の内容(係数)をxとして変換している。
今回のtidied列に入る結果もmap関数を用いない場合以下のようになる。
lm(data = df_input, log(price) ~ clarity) %>% #model列 tidy() # modelの結果の係数をtidyなtibble # A tibble: 8 x 5 # term estimate std.error statistic p.value # <chr> <dbl> <dbl> <dbl> <dbl> # 1 (Intercept) 8.03 0.0363 221. 0. # 2 claritySI2 0.139 0.0377 3.69 2.28e- 4 # 3 claritySI1 -0.180 0.0373 -4.81 1.49e- 6 # 4 clarityVS2 -0.265 0.0374 -7.08 1.50e-12 # 5 clarityVS1 -0.303 0.0379 -7.99 1.41e-15 # 6 clarityVVS2 -0.497 0.0389 -12.8 2.44e-37 # 7 clarityVVS1 -0.705 0.0398 -17.7 7.20e-70 # 8 clarityIF -0.622 0.0432 -14.4 5.03e-47
glance
glance(x, ...)
はxのmodelオブジェクトのsummaryを1行に変換する。今回の場合は、model列の内容をxとして変換している。
今回のglanced列に入る結果もmap関数を用いない場合以下のようになる。
lm(data = df_input, log(price) ~ clarity) %>% #model列 glance() # modelの結果のモデル評価(summaryの一部)を1行で # => # # A tibble: 1 x 11 # r.squared adj.r.squared sigma statistic p.value df logLik AIC BIC deviance # <dbl> <dbl> <dbl> <dbl> <dbl> <int> <dbl> <dbl> <dbl> <dbl> # 1 0.0511 0.0510 0.988 415. 0 8 -75907. 1.52e5 1.52e5 52694. # # … with 1 more variable: df.residual <int>
モデルの係数比較
先程の結果df_resultの係数部分(tidied列)の中身を見るために`unnest()'する。
# model結果の係数を下二桁 df_coef = df_result %>% select(model_no, tidied) %>% unnest() %>% mutate_if(is.double, round, digits=2) # => # # A tibble: 26 x 6 # model_no term estimate std.error statistic p.value # <int> <chr> <dbl> <dbl> <dbl> <dbl> # 1 1 (Intercept) 8.03 0.04 221. 0 # 2 1 claritySI2 0.14 0.04 3.69 0 # 3 1 claritySI1 -0.18 0.04 -4.81 0 # 4 1 clarityVS2 -0.26 0.04 -7.08 0 # 5 1 clarityVS1 -0.3 0.04 -7.99 0 # 6 1 clarityVVS2 -0.5 0.04 -12.8 0 # 7 1 clarityVVS1 -0.7 0.04 -17.7 0 # 8 1 clarityIF -0.62 0.04 -14.4 0 # 9 2 (Intercept) 5.36 0.01 372. 0 # 10 2 claritySI2 0.570 0.01 40.1 0 # # … with 16 more rows
このままでは比較がしづらいので、横持ちに変換する。
# 係数を横持ち化 df_coef %>% mutate(term = fct_inorder(term)) %>% # デフォルト(アルファベット順)から出てきた順にする dplyr::select(model_no, term, estimate) %>% spread(model_no, estimate) # => # # A tibble: 10 x 4 # term `1` `2` `3` # <fct> <dbl> <dbl> <dbl> # 1 (Intercept) 8.03 5.36 7.77 # 2 claritySI2 0.14 0.570 0.48 # 3 claritySI1 -0.18 0.72 0.62 # 4 clarityVS2 -0.26 0.82 0.78 # 5 clarityVS1 -0.3 0.86 0.82 # 6 clarityVVS2 -0.5 0.93 0.98 # 7 clarityVVS1 -0.7 0.92 1.03 # 8 clarityIF -0.62 1 1.11 # 9 carat NA 2.08 NA # 10 log(carat) NA NA 1.81
結果を見ると、
- mode1では、カラット数の交絡の影響で、透明度が上がるほど大きな負の係数となる
- model2では、カラット数の交絡の影響を統制したため、透明度が上がるほど大きな正の係数となる。
- model3では、model2とほぼ同様だがcaratに対数を取って統制をとったことで透明度の影響(係数)にやや変化が出ている。
modelの評価
モデルの評価が入っているglanced列に対しても同様に比較する。
# モデルの評価 df_result %>% dplyr::select(model_no, glanced) %>% unnest() %>% mutate_if(is.double, round, digits=2) # => # # A tibble: 3 x 12 # model_no r.squared adj.r.squared sigma statistic p.value df logLik AIC BIC # <int> <dbl> <dbl> <dbl> <dbl> <dbl> <int> <dbl> <dbl> <dbl> # 1 1 0.05 0.05 0.99 415. 0 8 -75907. 1.52e5 1.52e5 # 2 2 0.87 0.87 0.37 43737. 0 9 -23024. 4.61e4 4.62e4 # 3 3 0.97 0.97 0.19 187918. 0 9 13378. -2.67e4 -2.66e4 # # … with 2 more variables: deviance <dbl>, df.residual <int>
このとき、自由度調整済み決定係数adj.r.squaredを見るとmodel3が1番性能が高くなっている。model2と比較すると、OLSを用いているので価格とカラット数を線形関係に変換した方が性能が高くなる、というあ妥当な結果になっている。
まとめ
今回のコードをまとめると以下のようになる。
流れとしては、 - 1.回帰モデルを1オブジェクト(formulas)に作成する - 2.map関数を用いることで、formulasに対しfit, 結果, 評価を整理して1オブジェクトに格納 - 3.見たい部分に対してunnest()で結果を見る
# モデル比較の下準備------- # 回帰用前処理 df_input = df %>% mutate_if(is.ordered, factor, ordered = FALSE)# factor(ordered = FALSE)に変換 # 回帰モデル # 1.価格を透明度のみで説明するモデル # 2.透明度とカラット数で説明するモデル # 3.上と同じだが、カラット数に対数をとったモデル formulas = c(log(price) ~ clarity, log(price) ~ clarity + carat, log(price) ~ clarity + log(carat)) %>% enframe("model_no", "formula") #vactor to DF # モデルとその評価を管理しやすい形にする df_result = formulas %>% mutate(model = map(formula, lm, data = df_input), # lm(data = df_input, 各formulas$formula) tidied = map(model, tidy), # 回帰モデルの係数 to tidy tibble glanced = map(model, glance)) # 回帰モデルの評価 # モデルの係数比較---------- # model結果の係数を下二桁 df_coef = df_result %>% dplyr::select(model_no, tidied) %>% unnest() %>% mutate_if(is.double, round, digits=2) # 係数を横持ち化 df_coef %>% mutate(term = fct_inorder(term)) %>% # デフォルト(アルファベット順)から出てきた順にする dplyr::select(model_no, term, estimate) %>% spread(model_no, estimate) # モデルの評価比較------- # model結果の評価を下二桁 df_result %>% dplyr::select(model_no, glanced) %>% unnest() %>% mutate_if(is.double, round, digits=2)