From 54e10e661437ca1a97e51763737eb9226da03e7c Mon Sep 17 00:00:00 2001 From: mhesselbarth Date: Thu, 4 Jan 2024 14:11:09 +0100 Subject: [PATCH] All x.var = NULL for partial() --- R/engine_bart.R | 10 ++++++++++ R/engine_gdb.R | 19 +++++++++++++------ R/engine_glm.R | 3 ++- R/engine_glmnet.R | 1 + R/engine_inlabru.R | 10 ++++++---- R/partial.R | 11 +++++------ R/utils-bart.R | 9 +++++---- man/partial.Rd | 2 +- 8 files changed, 43 insertions(+), 22 deletions(-) diff --git a/R/engine_bart.R b/R/engine_bart.R index 270ef025..33741203 100644 --- a/R/engine_bart.R +++ b/R/engine_bart.R @@ -536,9 +536,19 @@ engine_bart <- function(x, # Partial effects partial = function(self, x.var = NULL, constant = NULL, variable_length = 100, values = NULL, newdata = NULL, plot = FALSE, type = NULL, ...){ + model <- self$get_data('fit_best') + assertthat::assert_that(all(x.var %in% attr(model$fit$data@x,'term.labels')) || is.null(x.var), msg = 'Variable not in predicted model' ) + + # Match x.var to argument + if(is.null(x.var)){ + x.var <- attr(model$fit$data@x,'term.labels') + } else { + x.var <- match.arg(x.var, attr(model$fit$data@x,'term.labels'), several.ok = TRUE) + } + if(is.null(newdata)){ bart_partial_effect(model, x.var = x.var, transform = self$settings$data$binary, diff --git a/R/engine_gdb.R b/R/engine_gdb.R index 778fedd2..f53b23d3 100644 --- a/R/engine_gdb.R +++ b/R/engine_gdb.R @@ -557,16 +557,23 @@ engine_gdb <- function(x, return(temp) }, # Partial effect - partial = function(self, x.var, constant = NULL, variable_length = 100, values = NULL, + partial = function(self, x.var = NULL , constant = NULL, variable_length = 100, values = NULL, newdata = NULL, plot = FALSE, type = NULL){ # Assert that variable(s) are in fitted model - assertthat::assert_that( is.character(x.var),inherits(self$get_data('fit_best'), 'mboost'), - is.numeric(variable_length), - is.null(newdata) || is.data.frame(newdata), - all(is.character(x.var))) + assertthat::assert_that(is.character(x.var) || is.null(x.var), + inherits(self$get_data('fit_best'), 'mboost'), + is.numeric(variable_length), + is.null(newdata) || is.data.frame(newdata)) + # Unlike the effects function, build specific predictor for target variable(s) only variables <- mboost::extract(self$get_data('fit_best'),'variable.names') - assertthat::assert_that( all( x.var %in% variables), msg = 'x.var variable not found in model!' ) + + # Match x.var to argument + if(is.null(x.var)) { + x.var <- variables + } else { + assertthat::assert_that(all( x.var %in% variables), msg = 'x.var variable not found in model!' ) + } settings <- self$settings if(is.null(type)) type <- settings$get('type') diff --git a/R/engine_glm.R b/R/engine_glm.R index dae78e9c..dc54054b 100644 --- a/R/engine_glm.R +++ b/R/engine_glm.R @@ -435,7 +435,8 @@ engine_glm <- function(x, settings$set("type", type) # Get data - df <- model$biodiversity[[length( model$biodiversity )]]$predictors + df <- model$biodiversity[[length(model$biodiversity)]]$predictors + df <- subset(df, select = attr(mod$terms, "term.labels")) # Match x.var to argument if(is.null(x.var)){ diff --git a/R/engine_glmnet.R b/R/engine_glmnet.R index 7e46bdd1..2d4aa48b 100644 --- a/R/engine_glmnet.R +++ b/R/engine_glmnet.R @@ -556,6 +556,7 @@ engine_glmnet <- function(x, # Get data df <- model$biodiversity[[length( model$biodiversity )]]$predictors + df <- subset(df, select = all.vars(mod$terms)) # Match x.var to argument if(is.null(x.var)){ diff --git a/R/engine_inlabru.R b/R/engine_inlabru.R index 5d9e1f39..0e09eb87 100644 --- a/R/engine_inlabru.R +++ b/R/engine_inlabru.R @@ -1049,7 +1049,7 @@ engine_inlabru <- function(x, return(out) }, # Partial response - partial = function(self, x.var, constant = NULL, variable_length = 100, + partial = function(self, x.var = NULL, constant = NULL, variable_length = 100, values = NULL, newdata = NULL, plot = TRUE, type = "response"){ # We use inlabru's functionalities to sample from the posterior # a given variable. A prediction is made over a generated fitted data.frame @@ -1057,9 +1057,11 @@ engine_inlabru <- function(x, mod <- self$get_data('fit_best') model <- self$model df <- model$biodiversity[[1]]$predictors + df <- subset(df, select = mod$names.fixed[mod$names.fixed != "Intercept"]) + assertthat::assert_that(inherits(mod,'bru'), 'model' %in% names(self), - is.character(x.var), + is.character(x.var) || is.null(x.var), is.numeric(variable_length), variable_length >=1, is.null(constant) || is.numeric(constant), is.null(newdata) || is.data.frame(newdata), @@ -1067,15 +1069,15 @@ engine_inlabru <- function(x, ) # Match variable name + # MH: This can be an empty list which !is.null? if(!is.null(mod$summary.random)) vn <- names(mod$summary.random) else vn <- "" - if(x.var == ""){ + if(is.null(x.var)) { x.var <- colnames(df) } else { x.var <- match.arg(x.var, c( mod$names.fixed, vn), several.ok = TRUE) } - if(is.null(newdata)){ # Calculate range of predictors if(any(model$predictors_types$type=="factor")){ diff --git a/R/partial.R b/R/partial.R index 1049f4d6..ff414621 100644 --- a/R/partial.R +++ b/R/partial.R @@ -38,8 +38,8 @@ NULL #' @name partial methods::setGeneric( "partial", - signature = methods::signature("mod","x.var"), - function(mod, x.var, constant = NULL, variable_length = 100, values = NULL, newdata = NULL, plot = FALSE, type = "response", ...) standardGeneric("partial")) + signature = methods::signature("mod"), + function(mod, x.var = NULL, constant = NULL, variable_length = 100, values = NULL, newdata = NULL, plot = FALSE, type = "response", ...) standardGeneric("partial")) #' @name partial #' @rdname partial @@ -47,12 +47,11 @@ methods::setGeneric( #' \S4method{partial}{ANY,character,ANY,numeric,ANY,ANY,logical,character}(mod,x.var,constant,variable_length,values,newdata,plot,type,...) methods::setMethod( "partial", - methods::signature(mod = "ANY", x.var = "character"), - function(mod, x.var, constant = NULL, variable_length = 100, + methods::signature(mod = "ANY"), + function(mod, x.var = NULL, constant = NULL, variable_length = 100, values = NULL, newdata = NULL, plot = FALSE, type = "response",...) { - assertthat::assert_that(!missing(x.var),msg = 'Specify a variable name in the model!') assertthat::assert_that(inherits(mod, "DistributionModel"), - is.character(x.var), + is.character(x.var) || is.null(x.var), is.null(constant) || is.numeric(constant), is.numeric(variable_length), is.null(newdata) || is.data.frame(newdata), diff --git a/R/utils-bart.R b/R/utils-bart.R index c87d90c6..4d527cd8 100644 --- a/R/utils-bart.R +++ b/R/utils-bart.R @@ -110,9 +110,9 @@ varimp.bart <- function(model){ #' @aliases bart_partial_effect #' @keywords utils #' @noRd -bart_partial_effect <- function (model, x.var = NULL, equal = FALSE, - smooth = 1, transform = TRUE, values = NULL, - variable_length = 100,plot = TRUE) { +bart_partial_effect <- function(model, x.var, equal = FALSE, + smooth = 1, transform = TRUE, values = NULL, + variable_length = 100,plot = TRUE) { assertthat::assert_that( inherits(model,'bart'), @@ -128,13 +128,14 @@ bart_partial_effect <- function (model, x.var = NULL, equal = FALSE, if (inherits(model,"bart")) { fitobj <- model$fit } + # If no x.vars are specified, use all if(!is.null(values)){ raw <- list() raw[[x.var]] <- values raw <- raw |> as.data.frame() } else { - if (is.null(x.var)) raw <- fitobj$data@x else raw <- fitobj$data@x[, x.var] + raw <- fitobj$data@x[, x.var] } # Define binning in equal area width or not diff --git a/man/partial.Rd b/man/partial.Rd index 6cdea74b..0f48f6fc 100644 --- a/man/partial.Rd +++ b/man/partial.Rd @@ -7,7 +7,7 @@ \usage{ partial( mod, - x.var, + x.var = NULL, constant = NULL, variable_length = 100, values = NULL,