Skip to content

Commit

Permalink
* fix Matrix tests and other issues
Browse files Browse the repository at this point in the history
* bump version
* update news
  • Loading branch information
jeremyrcoyle committed Oct 30, 2023
1 parent 48f41e5 commit 977427d
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 37 deletions.
Binary file added .NEWS.md.swp
Binary file not shown.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: hal9001
Title: The Scalable Highly Adaptive Lasso
Version: 0.4.5
Version: 0.4.6
Authors@R: c(
person("Jeremy", "Coyle", email = "jeremyrcoyle@gmail.com",
role = c("aut", "cre"),
Expand Down Expand Up @@ -68,5 +68,5 @@ LinkingTo:
Rcpp,
RcppEigen
VignetteBuilder: knitr
RoxygenNote: 7.2.0
RoxygenNote: 7.2.3
Roxygen: list(markdown = TRUE)
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# hal9001 0.4.6
* Fixed predict method to address changes required by Matrix 1.6.2

# hal9001 0.4.5
* Added multivariate outcome prediction

Expand Down
2 changes: 1 addition & 1 deletion R/formula_hal9001.R
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ h <- function(..., k = NULL, s = NULL, pf = 1,
})


basis_list_item <- hal9001:::make_basis_list(
basis_list_item <- make_basis_list(
X[, col_index, drop = FALSE],
col_index, rep(s, ncol(X))
)
Expand Down
22 changes: 8 additions & 14 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,13 @@ predict.hal9001 <- function(object,
# generate predictions
if (!family %in% c("cox", "mgaussian")) {
if (ncol(object$coefs) > 1) {
preds <- apply(object$coefs, 2, function(hal_coefs) {
as.vector(Matrix::tcrossprod(
x = pred_x_basis,
y = hal_coefs[-1]
) + hal_coefs[1])
})
preds <- pred_x_basis%*%object$coefs[-1,]+
matrix(object$coefs[1,], nrow=nrow(pred_x_basis),
ncol=ncol(object$coefs), byrow = TRUE)
} else {
preds <- as.vector(Matrix::tcrossprod(
x = pred_x_basis,
y = object$coefs[-1]
y = matrix(object$coefs[-1],nrow=1)
) + object$coefs[1])
}
} else {
Expand All @@ -99,16 +96,13 @@ predict.hal9001 <- function(object,
# Note: there is no intercept in the Cox model (built into the baseline
# hazard and would cancel in the partial likelihood).
if (ncol(object$coefs) > 1) {
preds <- apply(object$coefs, 2, function(hal_coefs) {
as.vector(Matrix::tcrossprod(
x = pred_x_basis,
y = hal_coefs
))
})
preds <- pred_x_basis%*%object$coefs[-1,]+
matrix(object$coefs[1,], nrow=nrow(pred_x_basis),
ncol=ncol(object$coefs), byrow = TRUE)
} else {
preds <- as.vector(Matrix::tcrossprod(
x = pred_x_basis,
y = as.vector(object$coefs)
y = as.matrix(object$coefs,nrow = 1)
))
}
} else if (family == "mgaussian") {
Expand Down
20 changes: 0 additions & 20 deletions man/formula_helpers.Rd

This file was deleted.

14 changes: 14 additions & 0 deletions man/generate_all_rules.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 977427d

Please sign in to comment.