Skip to content

Commit

Permalink
Adding engine_scampr to the package (#107)
Browse files Browse the repository at this point in the history
* First push of `scampr` engine. Probably still buggy as hell.

* Small helper function for combining `formula` objects

* Small update to correct equations

* Adding projections and partials to `engine_scampr()`

* Updated `engine_scampr` to make use of offsets.

* Small 🐛 fix to make coefficients work
  • Loading branch information
Martin-Jung authored Mar 20, 2024
1 parent 86c40db commit fd8c41b
Show file tree
Hide file tree
Showing 26 changed files with 1,253 additions and 54 deletions.
1 change: 1 addition & 0 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ jobs:
any::glmnetUtils,
any::pdp,
stan-dev/cmdstanr,
ElliotDovers/scampr,
any::igraph,
any::lwgeom,
any::ncmeta,
Expand Down
3 changes: 3 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ Collate:
'utils-inla.R'
'engine_inla.R'
'engine_inlabru.R'
'engine_scampr.R'
'engine_stan.R'
'engine_xgboost.R'
'ensemble.R'
Expand Down Expand Up @@ -166,3 +167,5 @@ Collate:
'validate.R'
'write_output.R'
'zzz.R'
Remotes:
ElliotDovers/scampr
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ export(alignRasters)
export(as.Id)
export(bivplot)
export(check)
export(combine_formulas)
export(distribution)
export(emptyraster)
export(engine_bart)
Expand All @@ -91,6 +92,7 @@ export(engine_glm)
export(engine_glmnet)
export(engine_inla)
export(engine_inlabru)
export(engine_scampr)
export(engine_stan)
export(engine_xgboost)
export(ensemble)
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

#### New features
* Added functions that create an HTML file based on a `DistributionModel`.
* Added new engine `engine_scampr()` for model-based integration.

#### Minor improvements and bug fixes
* Small fixes to ensure `boruta` filtering works (again)?
* Small fix to parameter in `train()` #102 @jeffreyhanson
* Added a small helper function, `combine_formulas()`, for combining two different formula objects.
* Small bug fixes dealing with `scenario()` projections and limits, plus unit tests #104
* Fixed bug with adding `predictor_derivate()` to scenario predictors and added unit tests #106

Expand Down
88 changes: 59 additions & 29 deletions R/class-distributionmodel.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,14 @@ DistributionModel <- R6::R6Class(
# Check whether threshold has been calculated
has_threshold <- grep('threshold',self$show_rasters(),value = TRUE)[1]

# Get model
obj <- self$get_data('fit_best')

# FIXME: Have engine-specific code moved to engine
if( self$get_name() == 'INLA-Model' || self$get_name() == 'INLABRU-Model'){
if( length( self$fits ) != 0 ){
# Get strongest effects
ms <- subset(tidy_inla_summary(self$get_data('fit_best')),
ms <- subset(tidy_inla_summary(obj),
select = c('variable', 'mean'))
ms <- ms[order(ms$mean,decreasing = TRUE),] # Sort

Expand All @@ -80,9 +83,7 @@ DistributionModel <- R6::R6Class(
} else if( self$get_name() == 'GDB-Model' ) {

# Get Variable importance
vi <- mboost::varimp(
self$get_data('fit_best')
)
vi <- mboost::varimp(obj)
vi <- sort( vi[which(vi>0)],decreasing = TRUE )

message(paste0(
Expand All @@ -98,7 +99,7 @@ DistributionModel <- R6::R6Class(
))
} else if( self$get_name() == 'BART-Model' ) {
# Calculate variable importance from the posterior trees
vi <- varimp.bart(self$get_data('fit_best'))
vi <- varimp.bart(obj)

message(paste0(
'Trained ',self$name,' (',self$show(),')',
Expand All @@ -113,7 +114,7 @@ DistributionModel <- R6::R6Class(
))
} else if( self$get_name() == 'STAN-Model' ) {
# Calculate variable importance from the posterior
vi <- rstan::summary(self$get_data('fit_best'))$summary |> as.data.frame() |>
vi <- rstan::summary(obj)$summary |> as.data.frame() |>
tibble::rownames_to_column(var = "parameter") |> as.data.frame()
# Get beta coefficients only
vi <- vi[grep("beta", vi$parameter,ignore.case = TRUE),]
Expand All @@ -133,7 +134,7 @@ DistributionModel <- R6::R6Class(
'\n \033[31mNegative:\033[39m ', name_atomic(vi$parameter[vi$mean<0])
))
} else if( self$get_name() == 'XGBOOST-Model' ) {
vi <- xgboost::xgb.importance(model = self$get_data('fit_best'),)
vi <- xgboost::xgb.importance(model = obj)

message(paste0(
'Trained ',self$name,' (',self$show(),')',
Expand All @@ -147,7 +148,6 @@ DistributionModel <- R6::R6Class(
"")
))
} else if( self$get_name() == 'BREG-Model' ) {
obj <- self$get_data('fit_best')
# Summarize the beta coefficients from the posterior
ms <- posterior::summarise_draws(obj$beta) |>
subset(select = c('variable', 'mean'))
Expand All @@ -167,8 +167,6 @@ DistributionModel <- R6::R6Class(
"")
))
} else if( self$get_name() == 'GLMNET-Model') {
obj <- self$get_data('fit_best')

# Summarise coefficients within 1 standard deviation
ms <- tidy_glmnet_summary(obj)

Expand All @@ -186,8 +184,6 @@ DistributionModel <- R6::R6Class(
))

} else if( self$get_name() == 'GLM-Model' ) {
obj <- self$get_data('fit_best')

# Summarise coefficients within 1 standard deviation
ms <- tidy_glm_summary(obj)

Expand All @@ -204,6 +200,31 @@ DistributionModel <- R6::R6Class(
"")
))

} else if( self$get_name() == 'SCAMPR-Model' ) {
# Get the fixed effects (estimates and standard errors) as a data.frame
ms <- obj$fixed.effects |>
as.data.frame() |>
tibble::rownames_to_column(var = "variable")
# Remove intercept
int <- grep("Intercept",ms$variable,ignore.case = TRUE)
if(length(int)>0) ms <- ms[-int,]

# Rename the estimate and std.error column
ms <- ms |> dplyr::rename(mean = "Estimate", se = "Std. Error")

message(paste0(
'Trained ',self$name,' (',self$show(),')',
'\n \033[1mStrongest summary effects:\033[22m',
'\n \033[34mPositive:\033[39m ', name_atomic(ms$variable[ms$mean>0]),
'\n \033[31mNegative:\033[39m ', name_atomic(ms$variable[ms$mean<0]),
ifelse(has_prediction,
paste0("\n Prediction fitted: ",text_green("yes")),
""),
ifelse(!is.na(has_threshold),
paste0("\n Threshold created: ",text_green("yes")),
"")
))

} else {
message(paste0(
'Trained distribution model (',self$show(),')',
Expand Down Expand Up @@ -335,6 +356,16 @@ DistributionModel <- R6::R6Class(
tidy_glmnet_summary(self$get_data(obj))
} else if( self$get_name() == 'GLM-Model'){
tidy_glm_summary(self$get_data(obj))
} else if( self$get_name() == 'SCAMPR-Model'){
# Get the fixed effects (estimates and standard errors) as a data.frame
ms <- self$get_data(obj)$fixed.effects |>
as.data.frame() |>
tibble::rownames_to_column(var = "variable")
# Remove intercept
int <- grep("Intercept",ms$variable,ignore.case = TRUE)
if(length(int)>0) ms <- ms[-int,]
# Rename the estimate and std.error column
ms |> dplyr::rename(mean = "Estimate", se = "Std. Error")
}
},

Expand All @@ -346,44 +377,42 @@ DistributionModel <- R6::R6Class(
#' @return A graphical representation of the coefficients.
effects = function(x = 'fit_best', what = 'fixed', ...){
assertthat::assert_that(is.character(what))
# Get model
obj <- self$get_data(x)
if( self$get_name() == 'GDB-Model'){
# How many effects
n <- length( stats::coef( self$get_data(x) ))
n <- length( stats::coef( obj ))
# Use the base plotting
par.ori <- graphics::par(no.readonly = TRUE)
graphics::par(mfrow = c(ceiling(n/3),3))

mboost:::plot.mboost(x = self$get_data(x),
type = 'b',cex.axis=1.5, cex.lab=1.5)

mboost:::plot.mboost(x = obj, type = 'b',cex.axis=1.5, cex.lab=1.5)
graphics::par(par.ori)#dev.off()
} else if( self$get_name() == 'INLA-Model') {
plot_inla_marginals(self$get_data(x),what = what)
plot_inla_marginals(obj, what = what)
} else if( self$get_name() == 'GLMNET-Model') {
if(what == "fixed"){
ms <- tidy_glm_summary(mod)
ms <- tidy_glm_summary(obj)
graphics::dotchart(ms$mean,
labels = ms$variable,
frame.plot = FALSE,
color = "grey20")
} else{ plot(self$get_data(x)) }
} else{ plot(obj) }
} else if( self$get_name() == 'GLM-Model') {
if(what == "fixed"){
glmnet:::plot.glmnet(self$get_data(x)$glmnet.fit, xvar = "lambda") # Deviance explained
} else{ plot(self$get_data(x)) }
glmnet:::plot.glmnet(obj$glmnet.fit, xvar = "lambda") # Deviance explained
} else{ plot(obj) }
} else if( self$get_name() == 'STAN-Model') {
# Get true beta parameters
ra <- grep("beta", names(self$get_data(x)),value = TRUE) # Get range
rstan::stan_plot(self$get_data(x), pars = ra)
ra <- grep("beta", names(obj),value = TRUE) # Get range
rstan::stan_plot(obj, pars = ra)
} else if( self$get_name() == 'INLABRU-Model') {
# Use inlabru effect plot
ggplot2::ggplot() +
inlabru::gg(self$get_data(x)$summary.fixed, bar = TRUE)
inlabru::gg(obj$summary.fixed, bar = TRUE)
} else if( self$get_name() == 'BART-Model'){
message('Calculating partial dependence plots')
self$partial(self$get_data(x), x.var = what, ...)
self$partial(obj, x.var = what, ...)
} else if( self$get_name() == 'BREG-Model'){
obj <- self$get_data(x)
if(what == "fixed") what <- "coefficients"
what <- match.arg(what, choices = c("coefficients", "scaled.coefficients","residuals",
"size", "fit", "help", "inclusion"), several.ok = FALSE)
Expand All @@ -400,11 +429,12 @@ DistributionModel <- R6::R6Class(
vi <- self$summary(x)
xgboost::xgb.ggplot.importance(vi)
} else {
obj <- self$get_data(x)
xgboost::xgb.plot.multi.trees(obj)
}
} else if( self$get_name() == "SCAMPR-Model"){
dotchart(obj$fixed.effects[,1])
} else {
self$partial(self$get_data(x), x.var = NULL)
self$partial(obj, x.var = NULL)
}
},

Expand Down
Loading

0 comments on commit fd8c41b

Please sign in to comment.