Skip to content

Commit

Permalink
Make sure partial is consistent across methods and works for several x.var
Browse files Browse the repository at this point in the history
  • Loading branch information
mhesselbarth committed Dec 21, 2023
1 parent 1d073d1 commit 8364d4f
Show file tree
Hide file tree
Showing 9 changed files with 171 additions and 137 deletions.
4 changes: 2 additions & 2 deletions R/bdproto-distributionmodel.R
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ DistributionModel <- bdproto(
inlabru::gg(self$get_data(x)$summary.fixed, bar = TRUE)
} else if(inherits(self, 'BART-Model')){
message('Calculating partial dependence plots')
self$partial(self$get_data(x), x.vars = what, ...)
self$partial(self$get_data(x), x.var = what, ...)
} else if(inherits(self, 'BREG-Model')){
obj <- self$get_data(x)
if(what == "fixed") what <- "coefficients"
Expand All @@ -367,7 +367,7 @@ DistributionModel <- bdproto(
xgboost::xgb.plot.multi.trees(obj)
}
} else {
self$partial(self$get_data(x), x.vars = NULL)
self$partial(self$get_data(x), x.var = NULL)
}
},
# Get equation
Expand Down
10 changes: 5 additions & 5 deletions R/engine_bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -537,18 +537,18 @@ engine_bart <- function(x,
partial = function(self, x.var = NULL, constant = NULL, variable_length = 100,
values = NULL, newdata = NULL, plot = FALSE, type = NULL, ...){
model <- self$get_data('fit_best')
assertthat::assert_that(x.var %in% attr(model$fit$data@x,'term.labels') || is.null(x.var),
assertthat::assert_that(all(x.var %in% attr(model$fit$data@x,'term.labels')) || is.null(x.var),
msg = 'Variable not in predicted model' )
if(is.null(newdata)){
bart_partial_effect(model, x.vars = x.var,
bart_partial_effect(model, x.var = x.var,
transform = self$settings$data$binary,
variable_length = variable_length,
values = values,
equal = TRUE,
plot = plot )
} else {
# Set the values to newdata
bart_partial_effect(model, x.vars = x.var,
bart_partial_effect(model, x.var = x.var,
transform = self$settings$data$binary,
values = newdata[[x.var]],
plot = plot)
Expand All @@ -563,10 +563,10 @@ engine_bart <- function(x,
predictors <- model$predictors_object$get_data()
} else {
predictors <- newdata
assertthat::assert_that(x.var %in% colnames(predictors),
assertthat::assert_that(all(x.var %in% colnames(predictors)),
msg = 'Variable not in provided data!')
}
assertthat::assert_that(x.var %in% attr(fit$fit$data@x,'term.labels'),
assertthat::assert_that(all(x.var %in% attr(fit$fit$data@x,'term.labels')),
msg = 'Variable not in predicted model' )

if( model$biodiversity[[1]]$family != 'binomial' && transform) warning('Check whether transform should not be set to False!')
Expand Down
90 changes: 54 additions & 36 deletions R/engine_breg.R
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ engine_breg <- function(x,
if(is.null(x.var)){
x.var <- colnames(df)
} else {
x.var <- match.arg(x.var, colnames(df), several.ok = FALSE)
x.var <- match.arg(x.var, colnames(df), several.ok = TRUE)
}

if(is.null(newdata)){
Expand All @@ -569,67 +569,85 @@ engine_breg <- function(x,
} else {
rr <- sapply(df, function(x) range(x, na.rm = TRUE)) |> as.data.frame()
}
assertthat::assert_that(nrow(rr)>1, ncol(rr)>=1)

df_partial <- list()
if(!is.null(values)){ assertthat::assert_that(length(values) >= 1) }

# Add all others as constant
if(is.null(constant)){
for(n in names(rr)) df_partial[[n]] <- rep( mean(df[[n]], na.rm = TRUE), variable_length )
} else {
for(n in names(rr)) df_partial[[n]] <- rep( constant, variable_length )
}
} else {
df_partial <- newdata |> dplyr::select(dplyr::any_of(names(df)))
}

# create list to store results
o <- vector(mode = "list", length = length(x.var))
names(o) <- x.var

# loop through x.var
for(v in x.var) {

df2 <- df_partial

if(!is.null(values)){
df_partial[[x.var]] <- values
df2[[v]] <- values
} else {
df_partial[[x.var]] <- seq(rr[1,x.var], rr[2,x.var], length.out = variable_length)
df2[[v]] <- seq(rr[1, v], rr[2, v], length.out = variable_length)
}
df2 <- as.data.frame(df2)

df_partial <- df_partial |> as.data.frame()
if(any(model$predictors_types$type=="factor")){
lvl <- levels(model$predictors[[model$predictors_types$predictors[model$predictors_types$type=="factor"]]])
df_partial[model$predictors_types$predictors[model$predictors_types$type=="factor"]] <-
factor(lvl[1], levels = lvl)
df2[model$predictors_types$predictors[model$predictors_types$type=="factor"]] <- factor(lvl[1], levels = lvl)
}
} else {
df_partial <- newdata |> dplyr::select(dplyr::any_of(names(df)))
}

# For Integrated model, take the last one
fam <- model$biodiversity[[length(model$biodiversity)]]$family
# For Integrated model, take the last one
fam <- model$biodiversity[[length(model$biodiversity)]]$family

pred_breg <- predict_boom(
obj = mod,
newdata = df_partial,
w = unique(w)[2], # The second entry of unique contains the non-observed variables
fam = fam,
params = settings$data # Use the settings as list
) # Also attach the partial variable
pred_breg <- predict_boom(
obj = mod,
newdata = df2,
w = unique(w)[2], # The second entry of unique contains the non-observed variables
fam = fam,
params = settings$data # Use the settings as list
) # Also attach the partial variable

# Summarize the partial effect
pred_part <- cbind(
matrixStats::rowMeans2(pred_breg, na.rm = TRUE),
matrixStats::rowSds(pred_breg, na.rm = TRUE),
matrixStats::rowQuantiles(pred_breg, probs = c(.05,.5,.95), na.rm = TRUE),
apply(pred_breg, 1, mode)
) |> as.data.frame()
names(pred_part) <- c("mean", "sd", "q05", "q50", "q95", "mode")
pred_part$cv <- pred_part$sd / pred_part$mean
# And attach the variable
pred_part <- cbind("partial_effect" = df_partial[[x.var]], pred_part)
# Summarize the partial effect
pred_part <- cbind(
matrixStats::rowMeans2(pred_breg, na.rm = TRUE),
matrixStats::rowSds(pred_breg, na.rm = TRUE),
matrixStats::rowQuantiles(pred_breg, probs = c(.05,.5,.95), na.rm = TRUE),
apply(pred_breg, 1, mode)
) |> as.data.frame()

names(pred_part) <- c("mean", "sd", "q05", "q50", "q95", "mode")
pred_part$cv <- pred_part$sd / pred_part$mean
pred_part <- cbind("partial_effect" = df2[[v]], pred_part)

# Add variable name for consistency
pred_part <- cbind("variable" = v, pred_part)

o[[v]] <- pred_part

}

o <- do.call(what = rbind, args = c(o, make.row.names = FALSE))

if(plot){
# Make a plot
g <- ggplot2::ggplot(data = pred_part,
ggplot2::aes(x = partial_effect, y = q50, ymin = q05, ymax = q95)) +
g <- ggplot2::ggplot(data = o, ggplot2::aes(x = partial_effect)) +
ggplot2::theme_classic(base_size = 18) +
ggplot2::geom_ribbon(fill = 'grey90') +
ggplot2::geom_line() +
ggplot2::labs(x = paste0("partial of ",x.var), y = expression(hat(y)))
ggplot2::geom_ribbon(aes(ymin = q05, ymax = q95), fill = "grey90") +
ggplot2::geom_line(aes(y = mean)) +
ggplot2::facet_wrap(. ~ variable, scales = "free") +
ggplot2::labs(x = "", y = "Partial effect")
print(g)
}
# Return the data
return(pred_part)
return(o)
},
# Spatial partial dependence plot
spartial = function(self, x.var, constant = NULL, newdata = NULL,
Expand Down
11 changes: 6 additions & 5 deletions R/engine_gdb.R
Original file line number Diff line number Diff line change
Expand Up @@ -659,11 +659,12 @@ engine_gdb <- function(x,

# If plot, make plot, otherwise
if(plot){
par.ori <- graphics::par(no.readonly = TRUE)
graphics::par(mfrow = c(1,2))
mboost::plot.mboost(self$get_data('fit_best'), which = x.var, newdata = dummy)
if(utils::hasName(par.ori, "pin")) par.ori$pin <- NULL
graphics::par(par.ori)
g <- ggplot2::ggplot(data = out, ggplot2::aes(x = partial_effect)) +
ggplot2::theme_classic() +
ggplot2::geom_line(aes(y = mean)) +
ggplot2::facet_wrap(. ~ variable, scales = "free") +
ggplot2::labs(x = "", y = "Partial effect")
print(g)
}
return(out)
},
Expand Down
14 changes: 7 additions & 7 deletions R/engine_glm.R
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ engine_glm <- function(x,
if(is.null(x.var)){
x.var <- colnames(df)
} else {
x.var <- match.arg(x.var, names(df), several.ok = FALSE)
x.var <- match.arg(x.var, names(df), several.ok = TRUE)
}

# Calculate range of predictors
Expand Down Expand Up @@ -512,19 +512,19 @@ engine_glm <- function(x,
}
p1 <- p1[,c(v, "yhat")]
names(p1) <- c("partial_effect", "mean")
p1$variable <- v
p1 <- cbind(variable = v, p1)
pp <- rbind(pp, p1)
rm(p1)
if(length(x.var) > 1) pb$tick()
}

if(plot){
# Make a plot
g <- ggplot2::ggplot(data = pp, ggplot2::aes(x = partial_effect, y = mean)) +
ggplot2::theme_classic(base_size = 18) +
ggplot2::geom_line() +
ggplot2::labs(x = "", y = expression(hat(y))) +
ggplot2::facet_wrap(~variable,scales = 'free')
g <- ggplot2::ggplot(data = pp, ggplot2::aes(x = partial_effect)) +
ggplot2::theme_classic() +
ggplot2::geom_line(aes(y = mean)) +
ggplot2::facet_wrap(. ~ variable, scales = "free") +
ggplot2::labs(x = "", y = "Partial effect")
print(g)
}
return(pp)
Expand Down
14 changes: 7 additions & 7 deletions R/engine_glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ engine_glmnet <- function(x,
if(is.null(x.var)){
x.var <- colnames(df)
} else {
x.var <- match.arg(x.var, names(df), several.ok = FALSE)
x.var <- match.arg(x.var, names(df), several.ok = TRUE)
}

# Calculate range of predictors
Expand Down Expand Up @@ -634,19 +634,19 @@ engine_glmnet <- function(x,
}
p1 <- p1[,c(v, "yhat")]
names(p1) <- c("partial_effect", "mean")
p1$variable <- v
p1 <- cbind(variable = v, p1)
pp <- rbind(pp, p1)
rm(p1)
if(length(x.var) > 1) pb$tick()
}

if(plot){
# Make a plot
g <- ggplot2::ggplot(data = pp, ggplot2::aes(x = partial_effect, y = mean)) +
ggplot2::theme_classic(base_size = 18) +
ggplot2::geom_line() +
ggplot2::labs(x = "", y = expression(hat(y))) +
ggplot2::facet_wrap(~variable,scales = 'free')
g <- ggplot2::ggplot(data = pp, ggplot2::aes(x = partial_effect)) +
ggplot2::theme_classic() +
ggplot2::geom_line(aes(y = mean)) +
ggplot2::facet_wrap(. ~ variable, scales = "free") +
ggplot2::labs(x = "", y = "Partial effect")
print(g)
}
return(pp)
Expand Down
94 changes: 58 additions & 36 deletions R/engine_inlabru.R
Original file line number Diff line number Diff line change
Expand Up @@ -1068,10 +1068,16 @@ engine_inlabru <- function(x,

# Match variable name
if(!is.null(mod$summary.random)) vn <- names(mod$summary.random) else vn <- ""
x.var <- match.arg(x.var, c( mod$names.fixed, vn), several.ok = FALSE)

if(x.var == ""){
x.var <- colnames(df)
} else {
x.var <- match.arg(x.var, c( mod$names.fixed, vn), several.ok = TRUE)
}


if(is.null(newdata)){
# Make a prediction via inlabru
# Calculate range of predictors
if(any(model$predictors_types$type=="factor")){
rr <- sapply(df[model$predictors_types$predictors[model$predictors_types$type=="numeric"]],
function(x) range(x, na.rm = TRUE)) |> as.data.frame()
Expand All @@ -1089,55 +1095,71 @@ engine_inlabru <- function(x,
} else {
for(n in names(rr)) df_partial[[n]] <- rep( constant, variable_length )
}
} else {
df_partial <- newdata |> dplyr::select(dplyr::any_of(names(df)))
}

# create list to store results
o <- vector(mode = "list", length = length(x.var))
names(o) <- x.var

# loop through x.var
for(v in x.var) {

df2 <- df_partial

if(!is.null(values)){
df_partial[[x.var]] <- values
df2[[v]] <- values
} else {
df_partial[[x.var]] <- seq(rr[1,x.var], rr[2,x.var], length.out = variable_length)
df2[[v]] <- seq(rr[1, v], rr[2, v], length.out = variable_length)
}
df_partial <- df_partial |> as.data.frame()
df2 <- as.data.frame(df2)

if(any(model$predictors_types$type=="factor")){
lvl <- levels(model$predictors[[model$predictors_types$predictors[model$predictors_types$type=="factor"]]])
df_partial[model$predictors_types$predictors[model$predictors_types$type=="factor"]] <-
factor(lvl[1], levels = lvl)
df2[model$predictors_types$predictors[model$predictors_types$type=="factor"]] <- factor(lvl[1], levels = lvl)
}
} else {
df_partial <- newdata |> dplyr::select(dplyr::any_of(names(df)))
}

## plot the unique effect of the covariate
fun <- ifelse(length(model$biodiversity) == 1 && model$biodiversity[[1]]$type == 'poipa', "logistic", "exp")
pred_cov <- predict(object = mod,
newdata = df_partial,
formula = stats::as.formula( paste("~ ",fun,"(", paste(mod$names.fixed,collapse = " + ") ,")") ),
n.samples = 100,
probs = c(0.05,0.5,0.95)
)
pred_cov$cv <- pred_cov$sd / pred_cov$mean
## plot the unique effect of the covariate
fun <- ifelse(length(model$biodiversity) == 1 && model$biodiversity[[1]]$type == 'poipa', "logistic", "exp")
pred_cov <- predict(object = mod,
newdata = df2,
formula = stats::as.formula( paste("~ ",fun,"(", paste(mod$names.fixed,collapse = " + ") ,")") ),
n.samples = 100,
probs = c(0.05,0.5,0.95)
)

pred_cov$cv <- pred_cov$sd / pred_cov$mean
names(pred_cov)[grep(v, names(pred_cov))] <- "partial_effect"

if(utils::packageVersion("inlabru") <= '2.5.2'){
# Older version where probs are ignored
pred_cov <- subset(pred_cov, select = c("partial_effect", "mean", "sd", "median", "q0.025", "q0.975", "cv"))
names(pred_cov) <- c("partial_effect", "mean", "sd", "median", "lower", "upper", "cv")
} else {
pred_cov <- subset(pred_cov, select = c("partial_effect", "mean", "sd", "q0.05", "q0.5", "q0.95", "cv"))
names(pred_cov) <- c("partial_effect", "mean", "sd", "lower", "median", "upper", "cv")
}

pred_cov <- cbind(variable = v, pred_cov)

o[[v]] <- pred_cov

o <- pred_cov
names(o)[grep(x.var, names(o))] <- "partial_effect"
if(utils::packageVersion("inlabru") <= '2.5.2'){
# Older version where probs are ignored
o <- subset(o, select = c("partial_effect", "mean", "sd", "median", "q0.025", "q0.975", "cv"))
names(o) <- c("partial_effect", "mean", "sd", "median", "lower", "upper", "cv")
} else {
o <- subset(o, select = c("partial_effect", "mean", "sd", "q0.05", "q0.5", "q0.95", "cv"))
names(o) <- c("partial_effect", "mean", "sd", "lower", "median", "upper", "cv")
}

o <- do.call(what = rbind, args = c(o, make.row.names = FALSE))

# Do plot and return result
if(plot){
pm <- ggplot2::ggplot(data = o, ggplot2::aes(x = partial_effect, y = median,
ymin = lower,
ymax = upper) ) +
g <- ggplot2::ggplot(data = o, ggplot2::aes(x = partial_effect)) +
ggplot2::theme_classic() +
ggplot2::geom_ribbon(fill = "grey90") +
ggplot2::geom_line() +
ggplot2::labs(x = x.var, y = "Partial effect")
print(pm)
ggplot2::geom_ribbon(aes(ymin = lower, ymax = upper), fill = "grey90") +
ggplot2::geom_line(aes(y = mean)) +
ggplot2::facet_wrap(. ~ variable, scales = "free") +
ggplot2::labs(x = "", y = "Partial effect")
print(g)
}
return(o |> as.data.frame() )
return(o)
},
# (S)partial effect
spartial = function(self, x.var, constant = NULL, newdata = NULL,
Expand Down
Loading

0 comments on commit 8364d4f

Please sign in to comment.