Fitting several regression models with dplyr

I want to use dplyr to fit a model for each hour (a factor variable). I get an error and I'm not quite sure what is going wrong.

library(dplyr)

df.h <- data.frame(
hour     = factor(rep(1:24, each = 21)),
price    = runif(504, min = -10, max = 125),
wind     = runif(504, min = 0, max = 2500),
temp     = runif(504, min = -10, max = 25)
)


df.h <- tbl_df(df.h)  # note: tbl_df() is deprecated in current dplyr; as_tibble() replaces it
df.h <- group_by(df.h, hour)


group_size(df.h) # checks out, 21 obs. for each factor level


# different attempts:
reg.models <- do(df.h, formula = price ~ wind + temp)


reg.models <- do(df.h, .f = lm(price ~ wind + temp, data = df.h))

I've tried all sorts of different approaches, but I can't get it to work.


From the documentation of do:

.f: a function to apply to each piece. The first unnamed argument supplied to .f will be a data frame.

So:

reg.models <- do(df.h,
.f=function(data){
lm(price ~ wind + temp, data=data)
})

It may also help to store the hour alongside each fitted model:

reg.models <- do(df.h,
.f=function(data){
m <- lm(price ~ wind + temp, data=data)
m$hour <- unique(data$hour)
m
})
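For comparison, the same "apply a function to each piece" idea can be sketched in base R without dplyr. This swaps in `split()` + `lapply()` rather than the old `do(.f = ...)` interface, so treat it as an illustration of the concept, not as the answer's method; the seed is an arbitrary assumption for reproducibility.

```r
# Base-R sketch: split() produces one data frame per hour, and lapply()
# hands each piece to the function as its first argument, much like .f above.
set.seed(1)  # arbitrary seed, only so the example is reproducible
df.h <- data.frame(
  hour  = factor(rep(1:24, each = 21)),
  price = runif(504, min = -10, max = 125),
  wind  = runif(504, min = 0, max = 2500),
  temp  = runif(504, min = -10, max = 25)
)

pieces <- split(df.h, df.h$hour)  # named list of 24 data frames ("1" ... "24")

reg.models <- lapply(pieces, function(data) {
  m <- lm(price ~ wind + temp, data = data)
  m$hour <- unique(data$hour)     # keep the hour next to the fitted model
  m
})

length(reg.models)        # 24 models, one per hour
coef(reg.models[["1"]])   # (Intercept), wind and temp for hour 1
```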

I think you can use dplyr more idiomatically here; you don't need to define a function as in @Fabian's answer.

results<-df.h %.%
group_by(hour) %.%
do(failwith(NULL, lm), formula = price ~ wind + temp)

or

results<-do(group_by(tbl_df(df.h), hour),
failwith(NULL, lm), formula = price ~ wind + temp)

Edit: of course, it also works without failwith:

results<-df.h %.%
group_by(hour) %.%
do(lm, formula = price ~ wind + temp)

or

results<-do(group_by(tbl_df(df.h), hour),
lm, formula = price ~ wind + temp)
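`failwith(NULL, lm)` wraps `lm()` so that a group whose fit throws an error yields `NULL` instead of aborting the whole computation. As a hedged sketch of the same idea with current tooling, `purrr::possibly()` can play that role; here one hour is deliberately given only missing prices so that its fit fails. The fitting call is wrapped in a small helper (`fit_hour`, a hypothetical name) rather than passing `lm` directly, so that `lm()` is evaluated where the per-group data frame is visible.

```r
library(purrr)

set.seed(2)  # arbitrary seed
df.h <- data.frame(
  hour  = factor(rep(1:24, each = 21)),
  price = runif(504, min = -10, max = 125),
  wind  = runif(504, min = 0, max = 2500),
  temp  = runif(504, min = -10, max = 25)
)

# Break one group on purpose: with every price missing, lm() stops with
# "0 (non-NA) cases" for hour 3.
df.h$price[df.h$hour == "3"] <- NA

# Hypothetical helper: fit one hour; possibly() turns an error into NULL,
# the job failwith(NULL, lm) did in the answer above.
fit_hour <- possibly(function(d) lm(price ~ wind + temp, data = d),
                     otherwise = NULL)

fits <- lapply(split(df.h, df.h$hour), fit_hour)

sum(vapply(fits, is.null, logical(1)))  # hours whose fit failed (only hour 3)
```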

In dplyr 0.4 you can do:

df.h %>% do(model = lm(price ~ wind + temp, data = .))
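With a named argument, `do()` returns one row per group and stores each fitted model in a list-column (here called `model`). A minimal sketch, assuming the grouping is applied explicitly with `group_by(hour)` (in the question `df.h` was already grouped); `do()` is still available in dplyr 1.x but is superseded.

```r
library(dplyr)

set.seed(3)  # arbitrary seed
df.h <- data.frame(
  hour  = factor(rep(1:24, each = 21)),
  price = runif(504, min = -10, max = 125),
  wind  = runif(504, min = 0, max = 2500),
  temp  = runif(504, min = -10, max = 25)
)

models <- df.h %>%
  group_by(hour) %>%
  do(model = lm(price ~ wind + temp, data = .))

models                    # 24 rows: hour plus a list-column of lm objects
coef(models$model[[1]])   # coefficients of the first hour's model
```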

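The easiest way to do this, as of around May 2015, is to use broom. broom provides three functions for handling the complex objects returned by grouped statistical operations: tidy (the coefficient vectors), glance (the summary statistics) and augment (the observation-level results).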
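To see what each of the three verbs returns before applying them per group, it can help to call them on a single model. A minimal sketch, assuming a fit to hour 1 only:

```r
library(broom)

set.seed(4)  # arbitrary seed
df.h <- data.frame(
  hour  = factor(rep(1:24, each = 21)),
  price = runif(504, min = -10, max = 125),
  wind  = runif(504, min = 0, max = 2500),
  temp  = runif(504, min = -10, max = 25)
)

# One model, fitted to the 21 rows of hour 1
fit <- lm(price ~ wind + temp, data = subset(df.h, hour == "1"))

tidy(fit)     # one row per coefficient: (Intercept), wind, temp
glance(fit)   # a single row of model-level statistics (r.squared, AIC, ...)
augment(fit)  # one row per observation, with .fitted, .resid, ...
```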

Here is a demonstration of extracting per-group linear regression results into tidy data frames:

  1. tidy:

    library(dplyr)
    library(broom)
    
    
    df.h = data.frame(
    hour     = factor(rep(1:24, each = 21)),
    price    = runif(504, min = -10, max = 125),
    wind     = runif(504, min = 0, max = 2500),
    temp     = runif(504, min = - 10, max = 25)
    )
    
    
    dfHour = df.h %>% group_by(hour) %>%
    do(fitHour = lm(price ~ wind + temp, data = .))
    
    
    # get the coefficients by group in a tidy data_frame
    dfHourCoef = tidy(dfHour, fitHour)
    dfHourCoef
    

    which gives:

    Source: local data frame [72 x 6]
    Groups: hour
    
    
    hour        term     estimate   std.error  statistic     p.value
    1     1 (Intercept) 53.336069324 21.33190104  2.5002961 0.022294293
    2     1        wind -0.008475175  0.01338668 -0.6331053 0.534626575
    3     1        temp  1.180019541  0.79178607  1.4903262 0.153453756
    4     2 (Intercept) 77.737788772 23.52048754  3.3051096 0.003936651
    5     2        wind -0.008437212  0.01432521 -0.5889765 0.563196358
    6     2        temp -0.731265113  1.00109489 -0.7304653 0.474506855
    7     3 (Intercept) 38.292039924 17.55361626  2.1814331 0.042655670
    8     3        wind  0.005422492  0.01407478  0.3852630 0.704557388
    9     3        temp  0.426765270  0.83672863  0.5100402 0.616220435
    10    4 (Intercept) 30.603119492 21.05059583  1.4537888 0.163219027
    ..  ...         ...          ...         ...        ...         ...
    
  2. augment:

    # get the predictions by group in a tidy data_frame
    dfHourPred = augment(dfHour, fitHour)
    dfHourPred
    

    which gives:

    Source: local data frame [504 x 11]
    Groups: hour
    
    
    hour       price      wind      temp  .fitted  .se.fit     .resid       .hat   .sigma      .cooksd .std.resid
    1     1  83.8414055   67.3780 -6.199231 45.44982 22.42649  38.391590 0.27955950 42.24400 0.1470891067  1.0663820
    2     1   0.3061628 2073.7540 15.134085 53.61916 14.10041 -53.312993 0.11051343 41.43590 0.0735584714 -1.3327207
    3     1  80.3790032  520.5949 24.711938 78.08451 20.03558   2.294497 0.22312869 43.64059 0.0003606305  0.0613746
    4     1 121.9023855 1618.0864 12.382588 54.23420 10.31293  67.668187 0.05911743 40.23212 0.0566557575  1.6447224
    5     1  -0.4039594 1542.8150 -5.544927 33.71732 14.53349 -34.121278 0.11740628 42.74697 0.0325125137 -0.8562896
    6     1  29.8269832  396.6951  6.134694 57.21307 16.04995 -27.386085 0.14318542 43.05124 0.0271028701 -0.6975290
    7     1  -7.1865483 2009.9552 -5.657871 29.62495 16.93769 -36.811497 0.15946292 42.54487 0.0566686969 -0.9466312
    8     1  -7.8548693 2447.7092 22.043029 58.60251 19.94686 -66.457379 0.22115706 39.63999 0.2983443034 -1.7753911
    9     1  94.8736726 1525.3144 24.484066 69.30044 15.93352  25.573234 0.14111563 43.12898 0.0231796755  0.6505701
    10    1  54.4643001 2473.2234 -7.656520 23.34022 21.83043  31.124076 0.26489650 42.74790 0.0879837510  0.8558507
    ..  ...         ...       ...       ...      ...      ...        ...        ...      ...          ...        ...
    
  3. glance:

    # get the summary statistics by group in a tidy data_frame
    dfHourSumm = glance(dfHour, fitHour)
    dfHourSumm
    

    which gives:

    Source: local data frame [24 x 12]
    Groups: hour
    
    
    hour  r.squared adj.r.squared    sigma statistic    p.value df    logLik      AIC      BIC deviance df.residual
    1     1 0.12364561    0.02627290 42.41546 1.2698179 0.30487225  3 -106.8769 221.7538 225.9319 32383.29          18
    2     2 0.03506944   -0.07214506 36.79189 0.3270961 0.72521125  3 -103.8900 215.7799 219.9580 24365.58          18
    3     3 0.02805424   -0.07993974 39.33621 0.2597760 0.77406651  3 -105.2942 218.5884 222.7665 27852.07          18
    4     4 0.17640603    0.08489559 41.37115 1.9277147 0.17434859  3 -106.3534 220.7068 224.8849 30808.30          18
    5     5 0.12575453    0.02861615 42.27865 1.2945915 0.29833246  3 -106.8091 221.6181 225.7962 32174.72          18
    6     6 0.08114417   -0.02095092 35.80062 0.7947901 0.46690268  3 -103.3164 214.6328 218.8109 23070.31          18
    7     7 0.21339168    0.12599076 32.77309 2.4415266 0.11529934  3 -101.4609 210.9218 215.0999 19333.36          18
    8     8 0.21655629    0.12950699 40.92788 2.4877430 0.11119114  3 -106.1272 220.2543 224.4324 30151.65          18
    9     9 0.23388711    0.14876346 35.48431 2.7476160 0.09091487  3 -103.1300 214.2601 218.4381 22664.45          18
    10   10 0.18326177    0.09251307 40.77241 2.0194425 0.16171339  3 -106.0472 220.0945 224.2726 29923.01          18
    ..  ...        ...           ...      ...       ...        ... ..       ...      ...      ...      ...         ...
    

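As of mid-2020 (and updated for dplyr 1.0+ as of 2022-04), the broom answer above fails. To work around the new way broom and dplyr interact, you can use the following combination of broom::tidy, broom::augment and broom::glance. We just have to use them together with nest_by() and summarize() (previously they were used inside do() and followed by unnest()).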

library(dplyr)
library(broom)
library(tidyr)


set.seed(42)
df.h = data.frame(
hour     = factor(rep(1:24, each = 21)),
price    = runif(504, min = -10, max = 125),
wind     = runif(504, min = 0, max = 2500),
temp     = runif(504, min = - 10, max = 25)
)


df.h %>%
nest_by(hour) %>%
mutate(mod = list(lm(price ~ wind + temp, data = data))) %>%
summarize(tidy(mod))
# # A tibble: 72 × 6
# # Groups:   hour [24]
#    hour  term        estimate std.error statistic   p.value
#    <fct> <chr>          <dbl>     <dbl>     <dbl>     <dbl>
# 1  1     (Intercept) 87.4       15.8        5.55  0.0000289
# 2  1     wind        -0.0129     0.0120    -1.08  0.296
# 3  1     temp         0.588      0.849      0.693 0.497
# 4  2     (Intercept) 92.3       21.6        4.27  0.000466
# 5  2     wind        -0.0227     0.0134    -1.69  0.107
# 6  2     temp        -0.216      0.841     -0.257 0.800
# 7  3     (Intercept) 61.1       18.6        3.29  0.00409
# 8  3     wind         0.00471    0.0128     0.367 0.718
# 9  3     temp         0.425      0.964      0.442 0.664
# 10 4     (Intercept) 31.6       15.3        2.07  0.0529


df.h %>%
nest_by(hour) %>%
mutate(mod = list(lm(price ~ wind + temp, data = data))) %>%
summarize(augment(mod))
# # A tibble: 504 × 10
# # Groups:   hour [24]
#    hour   price  wind   temp .fitted .resid   .hat .sigma  .cooksd .std.resid
#    <fct>  <dbl> <dbl>  <dbl>   <dbl>  <dbl>  <dbl>  <dbl>    <dbl>      <dbl>
#  1 1     113.    288. -1.75     82.7  30.8  0.123    37.8 0.0359       0.877
#  2 1     117.   2234. 18.4      69.5  47.0  0.201    36.4 0.165        1.40
#  3 1      28.6  1438.  4.75     71.7 -43.1  0.0539   37.1 0.0265      -1.18
#  4 1     102.    366.  9.77     88.5  13.7  0.151    38.4 0.00926      0.395
#  5 1      76.6  2257. -4.69     55.6  21.0  0.245    38.2 0.0448       0.644
#  6 1      60.1   633. -3.18     77.4 -17.3  0.0876   38.4 0.00749     -0.484
#  7 1      89.4   376. -4.16     80.1   9.31 0.119    38.5 0.00314      0.264
#  8 1       8.18 1921. 19.2      74.0 -65.9  0.173    34.4 0.261       -1.93
#  9 1      78.7   575. -6.11     76.4   2.26 0.111    38.6 0.000170     0.0640
# 10 1      85.2   763. -0.618    77.2   7.94 0.0679   38.6 0.00117      0.219
# # … with 494 more rows


df.h %>%
nest_by(hour) %>%
mutate(mod = list(lm(price ~ wind + temp, data = data))) %>%
summarize(glance(mod))
# # A tibble: 24 × 13
# # Groups:   hour [24]
#    hour  r.squared adj.r.squared sigma statistic p.value    df logLik   AIC
#    <fct>     <dbl>         <dbl> <dbl>     <dbl>   <dbl> <dbl>  <dbl> <dbl>
#  1 1        0.0679       -0.0357  37.5     0.655   0.531     2  -104.  217.
#  2 2        0.139         0.0431  42.7     1.45    0.261     2  -107.  222.
#  3 3        0.0142       -0.0953  43.1     0.130   0.879     2  -107.  222.
#  4 4        0.0737       -0.0293  36.7     0.716   0.502     2  -104.  216.
#  5 5        0.213         0.126   37.8     2.44    0.115     2  -104.  217.
#  6 6        0.0813       -0.0208  33.5     0.796   0.466     2  -102.  212.
#  7 7        0.0607       -0.0437  40.7     0.582   0.569     2  -106.  220.
#  8 8        0.153         0.0592  36.3     1.63    0.224     2  -104.  215.
#  9 9        0.166         0.0736  36.5     1.79    0.195     2  -104.  216.
# 10 10       0.110         0.0108  40.0     1.11    0.351     2  -106.  219.
# # … with 14 more rows, and 4 more variables: BIC <dbl>, deviance <dbl>,
# #   df.residual <int>, nobs <int>

Thanks to Bob Muenchen's blog for the inspiration.
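The key step in the nest_by() code above is that `nest_by()` returns a rowwise data frame with one row per hour and a list-column `data` holding that hour's rows, so `mutate()` evaluates `lm(..., data = data)` once per hour. A short sketch for inspecting that structure (assumes dplyr 1.0.0 or later, where `nest_by()` was introduced):

```r
library(dplyr)

set.seed(42)
df.h <- data.frame(
  hour  = factor(rep(1:24, each = 21)),
  price = runif(504, min = -10, max = 125),
  wind  = runif(504, min = 0, max = 2500),
  temp  = runif(504, min = -10, max = 25)
)

nested <- df.h %>% nest_by(hour)

nested            # 24 rows: `hour` plus a list-column `data`
nested$data[[1]]  # hour 1's 21 rows; the grouping column is dropped by default

# Rowwise mutate(): `data` refers to the current row's tibble, so one model
# is fitted per hour; list() is needed to store a model in a column.
fits <- nested %>% mutate(mod = list(lm(price ~ wind + temp, data = data)))
class(fits$mod[[1]])  # "lm"
```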

I believe there is a more concise answer than the nest_by() one above, which also avoids the deprecated/superseded do():

library(dplyr)
library(broom)
library(tidyr)


h.lm <- df.h %>%
nest_by(hour) %>%
mutate(fitHour = list(lm(price ~ wind + temp, data = data))) %>%
summarise(tidy_out = list(tidy(fitHour)),
glance_out = list(glance(fitHour)),
augment_out = list(augment(fitHour))) %>%
ungroup()


h.lm
# # A tibble: 24 x 4
#    hour  tidy_out         glance_out        augment_out
#    <fct> <list>           <list>            <list>
#  1 1     <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
#  2 2     <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
#  3 3     <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
#  4 4     <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
#  5 5     <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
#  6 6     <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
#  7 7     <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
#  8 8     <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
#  9 9     <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
# 10 10    <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
# # … with 14 more rows

As in that answer, to access the results just unnest whichever component you need:

unnest(select(h.lm, hour, tidy_out), cols = tidy_out)  # tidyr >= 1.0 expects `cols`
# # A tibble: 72 x 6
#    hour  term        estimate std.error statistic p.value
#    <fct> <chr>          <dbl>     <dbl>     <dbl>   <dbl>
#  1 1     (Intercept) 63.2       20.9        3.02  0.00728
#  2 1     wind        -0.00237    0.0139    -0.171 0.866
#  3 1     temp        -0.266      0.950     -0.280 0.783
#  4 2     (Intercept) 65.1       23.0        2.83  0.0111
#  5 2     wind         0.00691    0.0129     0.535 0.599
#  6 2     temp        -0.448      0.877     -0.510 0.616
#  7 3     (Intercept) 65.2       17.8        3.67  0.00175
#  8 3     wind         0.00515    0.0112     0.458 0.652
#  9 3     temp        -1.87       0.695     -2.69  0.0148
# 10 4     (Intercept) 49.7       17.6        2.83  0.0111
# # … with 62 more rows

As of dplyr 1.0.0, group_split() provides a convenient shortcut for this operation:

library(dplyr)
library(broom)
library(purrr)
df.h <- data.frame(
hour     = factor(rep(1:24, each = 21)),
price    = runif(504, min = -10, max = 125),
wind     = runif(504, min = 0, max = 2500),
temp     = runif(504, min = - 10, max = 25)
)


df.g <- group_split(df.h, hour)
map_dfr(df.g, function(x) tidy(lm(price ~ wind + temp, data=x)))
#> # A tibble: 72 x 5
#>    term        estimate std.error statistic p.value
#>    <chr>          <dbl>     <dbl>     <dbl>   <dbl>
#>  1 (Intercept) -10.4      20.3       -0.512 0.615
#>  2 wind          0.0377    0.0117     3.23  0.00467
#>  3 temp          1.34      0.890      1.50  0.150
#>  4 (Intercept)  34.6      18.6        1.86  0.0799
#>  5 wind          0.0214    0.0125     1.71  0.104
#>  6 temp          0.332     0.865      0.384 0.706
#>  7 (Intercept)  42.5      15.3        2.79  0.0122
#>  8 wind          0.0103    0.0116     0.888 0.386
#>  9 temp         -0.542     0.736     -0.736 0.471
#> 10 (Intercept)  64.1      18.8        3.41  0.00312
#> # … with 62 more rows

Created on 2021-03-04 by the reprex package (v1.0.0)
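`tidy()` knows nothing about the group a model came from, so the `map_dfr()` result above has no `hour` column. One way to keep it, sketched under the assumption that each piece holds a single hour (which `group_split()` guarantees), is to add the hour back to each tidied table; `mutate(.before = )` assumes dplyr 1.0.0 or later.

```r
library(dplyr)
library(broom)
library(purrr)

set.seed(5)  # arbitrary seed
df.h <- data.frame(
  hour  = factor(rep(1:24, each = 21)),
  price = runif(504, min = -10, max = 125),
  wind  = runif(504, min = 0, max = 2500),
  temp  = runif(504, min = -10, max = 25)
)

df.g <- group_split(df.h, hour)  # each piece still contains its `hour` column

res <- map_dfr(df.g, function(x) {
  fit <- lm(price ~ wind + temp, data = x)
  # every row of x shares one hour, so take the first and put it in front
  mutate(tidy(fit), hour = x$hour[1], .before = 1)
})

res  # 72 rows: hour, term, estimate, std.error, statistic, p.value
```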

A few tidyverse revisions later, the do() operator has been superseded, and we can fit one model per group with one fewer line of code:

library("broom")
library("tidyverse")


df.h <- data.frame(
hour     = factor(rep(1:24, each = 21)),
price    = runif(504, min = -10, max = 125),
wind     = runif(504, min = 0, max = 2500),
temp     = runif(504, min = -10, max = 25)
)


df.h %>%
group_by(hour) %>%
group_modify(
# Use `tidy`, `glance` or `augment` to extract different information from the fitted models.
~ tidy(lm(price ~ wind + temp, data = .))
)
#> # A tibble: 72 × 6
#> # Groups:   hour [24]
#>    hour  term        estimate std.error statistic  p.value
#>    <fct> <chr>          <dbl>     <dbl>     <dbl>    <dbl>
#>  1 1     (Intercept) 73.9      16.3         4.52  0.000266
#>  2 1     wind        -0.0256    0.0119     -2.15  0.0456
#>  3 1     temp         1.72      0.861       2.00  0.0604
#>  4 2     (Intercept) 81.5      18.4         4.42  0.000331
#>  5 2     wind        -0.0111    0.00973    -1.14  0.270
#>  6 2     temp        -1.60      0.763      -2.09  0.0506
#>  7 3     (Intercept) 59.9      16.1         3.73  0.00154
#>  8 3     wind         0.00358   0.0102      0.349 0.731
#>  9 3     temp        -1.82      0.664      -2.74  0.0134
#> 10 4     (Intercept) 49.6      18.5         2.69  0.0151
#> # … with 62 more rows

Created on 2022-04-20 by the reprex package (v2.0.1)
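As the comment in that code suggests, swapping `tidy` for `glance` gives one row of fit statistics per hour instead of one row per coefficient. Inside `group_modify()`, `.x` is the group's data (without the grouping column) and `.y` holds the group key; the function must return a data frame, which `glance()` does. A minimal sketch:

```r
library(dplyr)
library(broom)

set.seed(6)  # arbitrary seed
df.h <- data.frame(
  hour  = factor(rep(1:24, each = 21)),
  price = runif(504, min = -10, max = 125),
  wind  = runif(504, min = 0, max = 2500),
  temp  = runif(504, min = -10, max = 25)
)

summ <- df.h %>%
  group_by(hour) %>%
  group_modify(~ glance(lm(price ~ wind + temp, data = .x)))

summ  # 24 rows: hour followed by r.squared, adj.r.squared, sigma, ...
```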