Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremyrcoyle committed Oct 30, 2023
1 parent 977427d commit ebe87bf
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 17 deletions.
Binary file removed .NEWS.md.swp
Binary file not shown.
27 changes: 15 additions & 12 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,7 @@ predict.hal9001 <- function(object,
# hazard and would cancel in the partial likelihood).
# Note: there is no intercept in the Cox model (built into the baseline
# hazard and would cancel in the partial likelihood).
if (ncol(object$coefs) > 1) {
preds <- pred_x_basis%*%object$coefs[-1,]+
matrix(object$coefs[1,], nrow=nrow(pred_x_basis),
ncol=ncol(object$coefs), byrow = TRUE)
} else {
preds <- as.vector(Matrix::tcrossprod(
x = pred_x_basis,
y = as.matrix(object$coefs,nrow = 1)
))
}
preds <- pred_x_basis%*%object$coefs
} else if (family == "mgaussian") {
preds <- stats::predict(
object$lasso_fit, newx = pred_x_basis, s = object$lambda_star
Expand All @@ -129,9 +120,21 @@ predict.hal9001 <- function(object,
preds <- inverse_link_fun(preds)
} else {
if (family == "binomial") {
preds <- stats::plogis(preds)
transform <- stats::plogis
} else if (family %in% c("poisson", "cox")) {
preds <- exp(preds)
transform <- exp
} else if(family%in%c("gaussian","mgaussian")){
transform <- identity
} else{
stop("unsupported family")
}

if(length(ncol(preds))){
# apply along only the first dimension (to handle n-d arrays)
margin <- seq(length(dim(preds)))[-1]
preds <- apply(preds, margin, transform)
} else {
preds <- transform(preds)
}
}

Expand Down
11 changes: 6 additions & 5 deletions tests/testthat/test-reduce_basis_filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ system.time({
mse <- mean(se)
se[c(current_i, new_i)] <- 0
new_i <- which.max(se)
print(sprintf("%f, %f", old_mse, mse))
#print(sprintf("%f, %f", old_mse, mse))
continue <- mse < 1.1 * old_mse
if (mse < old_mse) {
good_i <- unique(c(good_i, new_i))
offset <- preds
old_mse <- mse
coefs <- as.vector(coef(screen_glmnet, s = "lambda.min"))[-1]
# old_basis <- union(old_basis,c(old_basis,b1)[which(coefs!=0)])
print(length(old_basis))
#print(length(old_basis))
old_basis <- unique(c(old_basis, b1))
}

Expand All @@ -72,7 +72,7 @@ system.time({
if (is.na(rate)) {
rate <- -Inf
}
print(rate)
# print(rate)
continue <- (-1 * rate) > 1e-4
continue <- TRUE
continue <- length(current_i) < n
Expand Down Expand Up @@ -108,13 +108,13 @@ b1 <- coef(fit)

fit <- glmnet(
x = x_basis, y = y, family = "gaussian", offset = offset,
intercept = FALSE, maxit = 1, thresh = 1, lambda = 0.03
intercept = FALSE, maxit = 2, thresh = 1, lambda = 0.03
)
b2 <- coef(fit)

fit <- glmnet(
x = x_basis, y = y, family = "gaussian", offset = offset,
intercept = FALSE, maxit = 1, thresh = 1, lambda = 0.03
intercept = FALSE, maxit = 2, thresh = 1, lambda = 0.03
)
b3 <- coef(fit)

Expand Down Expand Up @@ -164,3 +164,4 @@ test_that("Predictions are not too different when reducing basis functions", {
# ensure hal fit with reduce_basis works with new data for prediction
newx <- matrix(rnorm(n * p), n, p)
hal_pred_reduced_newx <- predict(hal_fit_reduced, new_data = newx)

0 comments on commit ebe87bf

Please sign in to comment.