Skip to content

Commit

Permalink
All x.var = NULL for partial()
Browse files Browse the repository at this point in the history
  • Loading branch information
mhesselbarth committed Jan 4, 2024
1 parent 081e0fc commit 54e10e6
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 22 deletions.
10 changes: 10 additions & 0 deletions R/engine_bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -536,9 +536,19 @@ engine_bart <- function(x,
# Partial effects
partial = function(self, x.var = NULL, constant = NULL, variable_length = 100,
values = NULL, newdata = NULL, plot = FALSE, type = NULL, ...){

model <- self$get_data('fit_best')

assertthat::assert_that(all(x.var %in% attr(model$fit$data@x,'term.labels')) || is.null(x.var),
msg = 'Variable not in predicted model' )

# Match x.var to argument
if(is.null(x.var)){
x.var <- attr(model$fit$data@x,'term.labels')
} else {
x.var <- match.arg(x.var, attr(model$fit$data@x,'term.labels'), several.ok = TRUE)
}

if(is.null(newdata)){
bart_partial_effect(model, x.var = x.var,
transform = self$settings$data$binary,
Expand Down
19 changes: 13 additions & 6 deletions R/engine_gdb.R
Original file line number Diff line number Diff line change
Expand Up @@ -557,16 +557,23 @@ engine_gdb <- function(x,
return(temp)
},
# Partial effect
partial = function(self, x.var, constant = NULL, variable_length = 100, values = NULL,
partial = function(self, x.var = NULL , constant = NULL, variable_length = 100, values = NULL,
newdata = NULL, plot = FALSE, type = NULL){
# Assert that variable(s) are in fitted model
assertthat::assert_that( is.character(x.var),inherits(self$get_data('fit_best'), 'mboost'),
is.numeric(variable_length),
is.null(newdata) || is.data.frame(newdata),
all(is.character(x.var)))
assertthat::assert_that(is.character(x.var) || is.null(x.var),
inherits(self$get_data('fit_best'), 'mboost'),
is.numeric(variable_length),
is.null(newdata) || is.data.frame(newdata))

# Unlike the effects function, build specific predictor for target variable(s) only
variables <- mboost::extract(self$get_data('fit_best'),'variable.names')
assertthat::assert_that( all( x.var %in% variables), msg = 'x.var variable not found in model!' )

# Match x.var to argument
if(is.null(x.var)) {
x.var <- variables
} else {
assertthat::assert_that(all( x.var %in% variables), msg = 'x.var variable not found in model!' )
}

settings <- self$settings
if(is.null(type)) type <- settings$get('type')
Expand Down
3 changes: 2 additions & 1 deletion R/engine_glm.R
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,8 @@ engine_glm <- function(x,
settings$set("type", type)

# Get data
df <- model$biodiversity[[length( model$biodiversity )]]$predictors
df <- model$biodiversity[[length(model$biodiversity)]]$predictors
df <- subset(df, select = attr(mod$terms, "term.labels"))

# Match x.var to argument
if(is.null(x.var)){
Expand Down
1 change: 1 addition & 0 deletions R/engine_glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,7 @@ engine_glmnet <- function(x,

# Get data
df <- model$biodiversity[[length( model$biodiversity )]]$predictors
df <- subset(df, select = all.vars(mod$terms))

# Match x.var to argument
if(is.null(x.var)){
Expand Down
10 changes: 6 additions & 4 deletions R/engine_inlabru.R
Original file line number Diff line number Diff line change
Expand Up @@ -1049,33 +1049,35 @@ engine_inlabru <- function(x,
return(out)
},
# Partial response
partial = function(self, x.var, constant = NULL, variable_length = 100,
partial = function(self, x.var = NULL, constant = NULL, variable_length = 100,
values = NULL, newdata = NULL, plot = TRUE, type = "response"){
# We use inlabru's functionalities to sample from the posterior
# a given variable. A prediction is made over a generated fitted data.frame
# Check that provided model exists and variable exist in model
mod <- self$get_data('fit_best')
model <- self$model
df <- model$biodiversity[[1]]$predictors
df <- subset(df, select = mod$names.fixed[mod$names.fixed != "Intercept"])

assertthat::assert_that(inherits(mod,'bru'),
'model' %in% names(self),
is.character(x.var),
is.character(x.var) || is.null(x.var),
is.numeric(variable_length), variable_length >=1,
is.null(constant) || is.numeric(constant),
is.null(newdata) || is.data.frame(newdata),
is.null(values) || is.numeric(values)
)

# Match variable name
# MH: This can be an empty list which !is.null?
if(!is.null(mod$summary.random)) vn <- names(mod$summary.random) else vn <- ""

if(x.var == ""){
if(is.null(x.var)) {
x.var <- colnames(df)
} else {
x.var <- match.arg(x.var, c( mod$names.fixed, vn), several.ok = TRUE)
}


if(is.null(newdata)){
# Calculate range of predictors
if(any(model$predictors_types$type=="factor")){
Expand Down
11 changes: 5 additions & 6 deletions R/partial.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,20 @@ NULL
#' @name partial
methods::setGeneric(
"partial",
signature = methods::signature("mod","x.var"),
function(mod, x.var, constant = NULL, variable_length = 100, values = NULL, newdata = NULL, plot = FALSE, type = "response", ...) standardGeneric("partial"))
signature = methods::signature("mod"),
function(mod, x.var = NULL, constant = NULL, variable_length = 100, values = NULL, newdata = NULL, plot = FALSE, type = "response", ...) standardGeneric("partial"))

#' @name partial
#' @rdname partial
#' @usage
#' \S4method{partial}{ANY,character,ANY,numeric,ANY,ANY,logical,character}(mod,x.var,constant,variable_length,values,newdata,plot,type,...)
methods::setMethod(
"partial",
methods::signature(mod = "ANY", x.var = "character"),
function(mod, x.var, constant = NULL, variable_length = 100,
methods::signature(mod = "ANY"),
function(mod, x.var = NULL, constant = NULL, variable_length = 100,
values = NULL, newdata = NULL, plot = FALSE, type = "response",...) {
assertthat::assert_that(!missing(x.var),msg = 'Specify a variable name in the model!')
assertthat::assert_that(inherits(mod, "DistributionModel"),
is.character(x.var),
is.character(x.var) || is.null(x.var),
is.null(constant) || is.numeric(constant),
is.numeric(variable_length),
is.null(newdata) || is.data.frame(newdata),
Expand Down
9 changes: 5 additions & 4 deletions R/utils-bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ varimp.bart <- function(model){
#' @aliases bart_partial_effect
#' @keywords utils
#' @noRd
bart_partial_effect <- function (model, x.var = NULL, equal = FALSE,
smooth = 1, transform = TRUE, values = NULL,
variable_length = 100,plot = TRUE) {
bart_partial_effect <- function(model, x.var, equal = FALSE,
smooth = 1, transform = TRUE, values = NULL,
variable_length = 100,plot = TRUE) {

assertthat::assert_that(
inherits(model,'bart'),
Expand All @@ -128,13 +128,14 @@ bart_partial_effect <- function (model, x.var = NULL, equal = FALSE,
if (inherits(model,"bart")) {
fitobj <- model$fit
}

# If no x.vars are specified, use all
if(!is.null(values)){
raw <- list()
raw[[x.var]] <- values
raw <- raw |> as.data.frame()
} else {
if (is.null(x.var)) raw <- fitobj$data@x else raw <- fitobj$data@x[, x.var]
raw <- fitobj$data@x[, x.var]
}

# Define binning in equal area width or not
Expand Down
2 changes: 1 addition & 1 deletion man/partial.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 54e10e6

Please sign in to comment.