Skip to content

Commit

Permalink
Small fix with weighted mean and multi-band rasters
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin-Jung committed Dec 19, 2023
1 parent 56ecfba commit 325562c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
18 changes: 9 additions & 9 deletions R/ensemble.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ NULL
#' is a fitted [`DistributionModel`] object. Take care not to create an ensemble
#' of models constructed with different link functions, e.g. logistic vs [log].
#' In this case the \code{"normalize"} parameter has to be set.
#' @param ... Provided [`DistributionModel`] objects.
#' @param ... Provided [`DistributionModel`] or [`SpatRaster`] objects.
#' @param method Approach on how the ensemble is to be created. See details for
#' available options (Default: \code{'mean'}).
#' @param weights (*Optional*) weights provided to the ensemble function if
Expand Down Expand Up @@ -133,14 +133,6 @@ methods::setMethod(
# Uncertainty calculation
uncertainty <- match.arg(uncertainty, c('none','sd', 'cv', 'range', 'pca'), several.ok = FALSE)

# Check that weight lengths is equal to provided distribution objects
if(!is.null(weights)) assertthat::assert_that(length(weights) == length(mods))
# If weights vector is numeric, standardize the weights
if(is.numeric(weights)) {
if(any(weights < 0)) weights[weights < 0] <- 0 # Assume those contribute anything
weights <- weights / sum(weights)
}

# For Distribution model ensembles
if( all( sapply(mods, function(z) inherits(z, "DistributionModel")) ) ){
assertthat::assert_that(length(mods)>=2, # Need at least 2 otherwise this does not make sense
Expand Down Expand Up @@ -288,6 +280,14 @@ methods::setMethod(
# If normalize before running an ensemble if parameter set
if(normalize) ras <- predictor_transform(ras, option = "norm")

# Check that weight lengths is equal to provided distribution objects
if(!is.null(weights) && !is.numeric(layer)) assertthat::assert_that(length(weights) == length(mods))
# If weights vector is numeric, standardize the weights
if(is.numeric(weights)) {
if(any(weights < 0)) weights[weights < 0] <- 0 # Assume those contribute anything
weights <- weights / sum(weights)
}

# Now ensemble per layer entry
out <- terra::rast()
for(lyr in layer){
Expand Down
5 changes: 5 additions & 0 deletions tests/testthat/test_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ test_that('Custom functions - Test gridded transformations and ensembles', {
pp <- ensemble(ras, method = "mean", layer = "mean")
)
expect_s4_class(pp, "SpatRaster")
# Also test weighted mean
expect_no_error(
pp <- ensemble(ras, method = "weighted.mean", weights = runif(3, 0.5,1))
)


# Check centroid calculation
expect_s3_class(raster_centroid(r1), "sf")
Expand Down

0 comments on commit 325562c

Please sign in to comment.