Skip to content

Commit

Permalink
Fix stan partial
Browse files Browse the repository at this point in the history
  • Loading branch information
mhesselbarth committed Jan 8, 2024
1 parent 434deab commit e8af850
Showing 1 changed file with 64 additions and 35 deletions.
99 changes: 64 additions & 35 deletions R/engine_stan.R
Original file line number Diff line number Diff line change
Expand Up @@ -691,26 +691,37 @@ engine_stan <- function(x,

},
# Partial effect
partial = function(self, x.var, constant = NULL, variable_length = 100,
partial = function(self, x.var = NULL, constant = NULL, variable_length = 100,
values = NULL, newdata = NULL, plot = FALSE, type = "predictor"){
# Get model and intercept if present
mod <- self$get_data('fit_best')
model <- self$model
has_intercept <- attr(stats::terms(model$biodiversity[[1]]$equation), "intercept")
if(is.null(type)) type <- self$settings$get("type")
assertthat::assert_that(inherits(mod,'stanfit'),
is.character(x.var),
is.character(x.var) || is.null(x.var),
is.numeric(variable_length) && variable_length > 1,
is.null(newdata) || is.data.frame(newdata),
is.null(constant) || is.numeric(constant)
)
# Check that given variable is in x.var
assertthat::assert_that(x.var %in% model$predictors_names)

# get variable names
variables <- model$predictors_names

# Match x.var to argument
if(is.null(x.var)) {
x.var <- variables
} else {
x.var <- match.arg(x.var, variables, several.ok = TRUE)
}

if(is.null(newdata)){
# Calculate
rr <- sapply(model$predictors, function(x) range(x, na.rm = TRUE)) |> as.data.frame()

# get range of data
rr <- sapply(model$predictors[, variables], function(x) range(x, na.rm = TRUE)) |> as.data.frame()

df_partial <- list()

if(!is.null(values)){ variable_length <- length(values) }

# Add all others as constant
Expand All @@ -719,22 +730,16 @@ engine_stan <- function(x,
} else {
for(n in names(rr)) df_partial[[n]] <- rep( constant, variable_length )
}
if(!is.null(values)){
df_partial[[x.var]] <- values
} else {
df_partial[[x.var]] <- seq(rr[1,x.var], rr[2,x.var], length.out = variable_length)
}
df_partial <- df_partial |> as.data.frame()
df_partial <- do.call(cbind, df_partial) |> as.data.frame()
} else {
df_partial <- newdata |> dplyr::select(
dplyr::any_of(model$predictors_names))
df_partial <- dplyr::select(newdata, dplyr::any_of(model$predictors_names))
}

# For Integrated model, follow poisson
fam <- ifelse(length(model$biodiversity)>1, "poisson", model$biodiversity[[1]]$family)

# Add intercept if present
if(has_intercept==1) df_partial$Intercept <- 1
if(has_intercept == 1) df_partial$Intercept <- 1
# If poipo, add w to prediction container
bd_poipo <- sapply(model$biodiversity, function(x) x$type) == "poipo"
if(any(bd_poipo)){
Expand All @@ -747,30 +752,54 @@ engine_stan <- function(x,
# FIXME: Taken here as average. To be re-evaluated as use case develops!
if(utils::hasName(df_partial,"w")) df_partial$w <- df_partial$w + mean(model$offset,na.rm = TRUE) else df_partial$w <- mean(model$offset,na.rm = TRUE)
}
# Simulate from the posterior
pred_part <- posterior_predict_stanfit(obj = mod,
form = to_formula(paste0("observed ~ ",
ifelse(has_intercept==1, "Intercept +", ""),
paste(model$predictors_names,collapse = " + "))),
newdata = df_partial,
offset = df_partial$w,
family = fam,
mode = type # Linear predictor
)
# Also attach the partial variable
pred_part <- cbind("partial_effect" = df_partial[[x.var]], pred_part)

# create list to store results
o <- vector(mode = "list", length = length(x.var))
names(o) <- x.var

# loop through x.var
for(v in x.var) {

df_temp <- df_partial

if(!is.null(values)){
df_temp[,v ] <- values
} else {
df_temp[, v] <- seq(rr[1, v], rr[2, v], length.out = variable_length)
}

# Simulate from the posterior
pred_part <- posterior_predict_stanfit(obj = mod,
form = to_formula(paste0("observed ~ ",
ifelse(has_intercept==1, "Intercept +", ""),
paste(model$predictors_names,collapse = " + "))),
newdata = df_temp,
offset = df_temp$w,
family = fam,
mode = type) # Linear predictor

# FIXME: Something wrong here I guess
# Also attach the partial variable
pred_part <- cbind("variable" = v, "partial_effect" = df_temp[, v],
pred_part)

o[[v]] <- pred_part

}

o <- do.call(what = rbind, args = c(o, make.row.names = FALSE))

if(plot){
o <- pred_part
pm <- ggplot2::ggplot(data = o, ggplot2::aes(x = partial_effect, y = mean,
ymin = mean-stats::sd,
ymax = mean+stats::sd) ) +
pm <- ggplot2::ggplot(data = o, ggplot2::aes(x = partial_effect)) +
ggplot2::theme_classic() +
ggplot2::geom_ribbon(fill = "grey85") +
ggplot2::geom_line() +
ggplot2::labs(x = x.var, y = "Partial effect")
ggplot2::geom_ribbon(ggplot2::aes(ymin = mean - sd, ymax = mean + sd), fill = "grey85") +
ggplot2::geom_line(ggplot2::aes(y = mean)) +
ggplot2::facet_wrap(. ~ variable, scales = "free") +
ggplot2::labs(x = "Variable", y = "Partial effect")

print(pm)
}
return(pred_part) # Return the partial data
return(o) # Return the partial data
},
# Spatial partial effect plots
spartial = function(self, x.var, constant = NULL, newdata = NULL,
Expand Down

0 comments on commit e8af850

Please sign in to comment.