From e8af8506c5c44b97dc8256c4a71c692ff7261a0c Mon Sep 17 00:00:00 2001 From: mhesselbarth Date: Mon, 8 Jan 2024 08:31:57 +0100 Subject: [PATCH] Fix stan partial --- R/engine_stan.R | 99 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 64 insertions(+), 35 deletions(-) diff --git a/R/engine_stan.R b/R/engine_stan.R index 9dc34e89..642c33a7 100644 --- a/R/engine_stan.R +++ b/R/engine_stan.R @@ -691,7 +691,7 @@ engine_stan <- function(x, }, # Partial effect - partial = function(self, x.var, constant = NULL, variable_length = 100, + partial = function(self, x.var = NULL, constant = NULL, variable_length = 100, values = NULL, newdata = NULL, plot = FALSE, type = "predictor"){ # Get model and intercept if present mod <- self$get_data('fit_best') @@ -699,18 +699,29 @@ engine_stan <- function(x, has_intercept <- attr(stats::terms(model$biodiversity[[1]]$equation), "intercept") if(is.null(type)) type <- self$settings$get("type") assertthat::assert_that(inherits(mod,'stanfit'), - is.character(x.var), + is.character(x.var) || is.null(x.var), is.numeric(variable_length) && variable_length > 1, is.null(newdata) || is.data.frame(newdata), is.null(constant) || is.numeric(constant) ) - # Check that given variable is in x.var - assertthat::assert_that(x.var %in% model$predictors_names) + + # get variable names + variables <- model$predictors_names + + # Match x.var to argument + if(is.null(x.var)) { + x.var <- variables + } else { + x.var <- match.arg(x.var, variables, several.ok = TRUE) + } if(is.null(newdata)){ - # Calculate - rr <- sapply(model$predictors, function(x) range(x, na.rm = TRUE)) |> as.data.frame() + + # get range of data + rr <- sapply(model$predictors[, variables], function(x) range(x, na.rm = TRUE)) |> as.data.frame() + df_partial <- list() + if(!is.null(values)){ variable_length <- length(values) } # Add all others as constant @@ -719,22 +730,16 @@ engine_stan <- function(x, } else { for(n in names(rr)) df_partial[[n]] <- rep( constant, variable_length ) } - if(!is.null(values)){ - df_partial[[x.var]] <- values - } else { - df_partial[[x.var]] <- seq(rr[1,x.var], rr[2,x.var], length.out = variable_length) - } - df_partial <- df_partial |> as.data.frame() + df_partial <- do.call(cbind, df_partial) |> as.data.frame() } else { - df_partial <- newdata |> dplyr::select( - dplyr::any_of(model$predictors_names)) + df_partial <- dplyr::select(newdata, dplyr::any_of(model$predictors_names)) } # For Integrated model, follow poisson fam <- ifelse(length(model$biodiversity)>1, "poisson", model$biodiversity[[1]]$family) # Add intercept if present - if(has_intercept==1) df_partial$Intercept <- 1 + if(has_intercept == 1) df_partial$Intercept <- 1 # If poipo, add w to prediction container bd_poipo <- sapply(model$biodiversity, function(x) x$type) == "poipo" if(any(bd_poipo)){ @@ -747,30 +752,54 @@ engine_stan <- function(x, # FIXME: Taken here as average. To be re-evaluated as use case develops! if(utils::hasName(df_partial,"w")) df_partial$w <- df_partial$w + mean(model$offset,na.rm = TRUE) else df_partial$w <- mean(model$offset,na.rm = TRUE) } - # Simulate from the posterior - pred_part <- posterior_predict_stanfit(obj = mod, - form = to_formula(paste0("observed ~ ", - ifelse(has_intercept==1, "Intercept +", ""), - paste(model$predictors_names,collapse = " + "))), - newdata = df_partial, - offset = df_partial$w, - family = fam, - mode = type # Linear predictor - ) - # Also attach the partial variable - pred_part <- cbind("partial_effect" = df_partial[[x.var]], pred_part) + + # create list to store results + o <- vector(mode = "list", length = length(x.var)) + names(o) <- x.var + + # loop through x.var + for(v in x.var) { + + df_temp <- df_partial + + if(!is.null(values)){ + df_temp[,v ] <- values + } else { + df_temp[, v] <- seq(rr[1, v], rr[2, v], length.out = variable_length) + } + + # Simulate from the posterior + pred_part <- posterior_predict_stanfit(obj = mod, + form = to_formula(paste0("observed ~ ", + ifelse(has_intercept==1, "Intercept +", ""), + paste(model$predictors_names,collapse = " + "))), + newdata = df_temp, + offset = df_temp$w, + family = fam, + mode = type) # Linear predictor + + # FIXME: Something wrong here I guess + # Also attach the partial variable + pred_part <- cbind("variable" = v, "partial_effect" = df_temp[, v], + pred_part) + + o[[v]] <- pred_part + + } + + o <- do.call(what = rbind, args = c(o, make.row.names = FALSE)) + if(plot){ - o <- pred_part - pm <- ggplot2::ggplot(data = o, ggplot2::aes(x = partial_effect, y = mean, - ymin = mean-stats::sd, - ymax = mean+stats::sd) ) + + pm <- ggplot2::ggplot(data = o, ggplot2::aes(x = partial_effect)) + ggplot2::theme_classic() + - ggplot2::geom_ribbon(fill = "grey85") + - ggplot2::geom_line() + - ggplot2::labs(x = x.var, y = "Partial effect") + ggplot2::geom_ribbon(ggplot2::aes(ymin = mean - sd, ymax = mean + sd), fill = "grey85") + + ggplot2::geom_line(ggplot2::aes(y = mean)) + + ggplot2::facet_wrap(. ~ variable, scales = "free") + + ggplot2::labs(x = "Variable", y = "Partial effect") + print(pm) } - return(pred_part) # Return the partial data + return(o) # Return the partial data }, # Spatial partial effect plots spartial = function(self, x.var, constant = NULL, newdata = NULL,