-
Notifications
You must be signed in to change notification settings - Fork 3.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[R-package] add a tree plotting function #6729
base: master
Are you sure you want to change the base?
Conversation
Added DiagrammeR as suggested in DESCRIPTION Added lgb.plot.tree in _pkgdown.yml Roxygenized.
@microsoft-github-policy-service agree |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your interest in LightGBM.
As I mentioned in the discussion on #1222, I'm supportive of trying to add something like this (especially since xgboost has it as well).
But I hope you'll see from the first round of suggestions I left here... significant work remains before I'd support merging this change into the package. If you are willing to work with us on this and go through multiple rounds of reviews and suggestions, we'd be grateful for the help! But if you don't have the time/interest to get this ready for inclusion in the package, please let me know and we'll close this PR and leave #1222 open for someone else to pick up.
R-package/R/lgb.plot.tree.R
Outdated
@@ -0,0 +1,184 @@ | |||
#' @name lgb.plot.tree | |||
#' @title Plot a single LightGBM tree using DiagrammeR. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#' @title Plot a single LightGBM tree using DiagrammeR. | |
#' @title Plot a single LightGBM tree. |
Let's simplify this, please.
R-package/R/lgb.plot.tree.R
Outdated
|
||
# function to plot a single LightGBM tree using DiagrammeR |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# function to plot a single LightGBM tree using DiagrammeR |
We do not need to repeat in a comment here the same information that's already in the roxygen comments.
R-package/R/lgb.plot.tree.R
Outdated
if (!inherits(model, "lgb.Booster")) { | ||
stop("model: Has to be an object of class lgb.Booster") | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (!inherits(model, "lgb.Booster")) { | |
stop("model: Has to be an object of class lgb.Booster") | |
} | |
if (!.is_Booster(x = model)) { | |
stop("lgb.plot.tree: model should be an ", sQuote("lgb.Booster")) | |
} |
Please follow the patterns used elsewhere in the library for this:
LightGBM/R-package/R/lgb.restore_handle.R
Lines 42 to 44 in 83c0ff3
if (!.is_Booster(x = model)) { | |
stop("lgb.restore_handle: model should be an ", sQuote("lgb.Booster")) | |
} |
R-package/R/lgb.plot.tree.R
Outdated
stop("tree: Has to be an integer numeric") | ||
} | ||
# extract data.table model structure | ||
dt <- lgb.model.dt.tree(model) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dt <- lgb.model.dt.tree(model) | |
modelDT <- lgb.model.dt.tree(model) |
Please don't use the name dt
. That is a function in the {stats}
package (for finding the density of a t-distribution)... try ?dt
to see that.
Shadowing names from the standard library can lead to confusing errors. Please use modelDT
as the name for this data.table
instead.
R-package/R/lgb.plot.tree.R
Outdated
nodes_df = nodes, | ||
edges_df = edges, | ||
attr_theme = NULL | ||
) %>% |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this project, by convention we:
- do not use the
%>%
operator - use comma-first style everywhere
Please update this code and all the other code you're adding to follow that. Keeping all of the code looking the same across the codebase helps us to develop and review changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
xgboost
's implementation of similar functionality might be useful as a reference. See https://github.com/dmlc/xgboost/blob/e988b7cf1515b08ad0f949c26beb043ce0b33fe8/R-package/R/xgb.plot.tree.R#L159-L181
@@ -0,0 +1,59 @@ | |||
test_that("lgb.plot.tree works as expected"){ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please also add tests for the other types of machine learning tasks LightGBM can be used for:
- binary classification
- multiclass classification (where, please note, there are
num_classes
trees produced per iteration) - learning-to-rank
And for the following model situations:
- uses categorical features
These are all cases that could affect the code as written... for example, categorical features have different splitting rules.
R-package/R/lgb.plot.tree.R
Outdated
dt[, Value := 0.0] | ||
dt[, Value := leaf_value] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dt[, Value := 0.0] | |
dt[, Value := leaf_value] | |
dt[, Value := leaf_value] |
I don't understand this... what's the purpose of setting all rows to 0.0
and then immediately overwriting them? It seems to me that the 0.0
could probably be removed.
R-package/R/lgb.plot.tree.R
Outdated
dt[is.na(Value), Value := internal_value] | ||
dt[is.na(Gain), Gain := leaf_value] | ||
dt[is.na(Feature), Feature := "Leaf"] | ||
dt[, Cover := internal_count][Feature == "Leaf", Cover := leaf_count] | ||
dt[, c("leaf_count", "internal_count", "leaf_value", "internal_value") := NULL] | ||
dt[, Node := split_index] | ||
max_node <- max(dt[["Node"]], na.rm = TRUE) | ||
dt[is.na(Node), Node := max_node + leaf_index + 1] | ||
dt[, ID := paste(Tree, Node, sep = "-")] | ||
dt[, c("depth", "leaf_index") := NULL] | ||
dt[, parent := node_parent][is.na(parent), parent := leaf_parent] | ||
dt[, c("node_parent", "leaf_parent", "split_index") := NULL] | ||
dt[, Yes := dt$ID[match(dt$Node, dt$parent)]] | ||
dt <- dt[nrow(dt):1, ] | ||
dt[, No := dt$ID[match(dt$Node, dt$parent)]] | ||
# which way do the NA's go (this path will get a thicker arrow) | ||
# for categorical features, NA gets put into the zero group | ||
dt[default_left == TRUE, Missing := Yes] | ||
dt[default_left == FALSE, Missing := No] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please add some comments to make it a bit easier to understand what's happening in this wall of code? It's very difficult to read (at least for me) as currently written).
# trees start from 0 in lgb.model.dt.tree | ||
tree_table <- lgb.model.dt.tree(model) | ||
expect_error({ | ||
lgb.plot.tree(model, 999)TRUE |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgb.plot.tree(model, 999)TRUE | |
lgb.plot.tree(model, 999) |
This looks like it was included accidentally?
DiagrammeR in CI. Error messages. Default parameters. Changed tests.
…ree.R) Now tests regressions, binary, multiclass classification and ranks.
Added a warning to functions and shorter stop message to make tests work.
Thanks for the review @jameslamb, helped me a lot! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. I left a few more comments for your consideration. To be clear, I still haven't very-thoroughly reviewed so this is not a comprehensive list... these are just quick things I noticed in the few minutes I had to review.
In addition to those... it'd be helpful if you could include some screenshots of what the plots look like, in the description of the PR. That'll really help me understand what the goal is here, without needing to run this code myself.
#' } | ||
#' | ||
#' @export | ||
lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) { | |
lgb.plot.tree <- function(model, tree, rules = NULL) { |
I can't think of any situation where it would be ok for model
or tree
to be NULL
, can you?
If not, let's please require callers to provide values explicitly.
# tree must be numeric | ||
if (!inherits(tree, "numeric")) { | ||
stop("lgb.plot.tree: Has to be an integer numeric") | ||
} | ||
# tree must be integer | ||
if (tree %% 1 != 0) { | ||
stop("lgb.plot.tree: Has to be an integer numeric") | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# tree must be numeric | |
if (!inherits(tree, "numeric")) { | |
stop("lgb.plot.tree: Has to be an integer numeric") | |
} | |
# tree must be integer | |
if (tree %% 1 != 0) { | |
stop("lgb.plot.tree: Has to be an integer numeric") | |
} | |
# tree must be numeric | |
tree <- as.integer(tree) | |
if (length(tree) != 1L || tree < 1L) { | |
stop(sprintf("lgb.plot.tree: 'tree' must be a single, positive integer.) | |
} |
Let's combine these, and make it clear what has to be an integer.
warning("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".") | ||
stop("lgb.plot.tree: Invalid tree number") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warning("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".") | |
stop("lgb.plot.tree: Invalid tree number") | |
stop("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".") |
What's the reason for having all of the information in a warning()
and then immediately raising an error after? If there isn't a specific reason, then let's please combine these for simplicity and to make the logs easier for users to understand.
return(invisible(NULL)) | ||
} | ||
|
||
.levels.to.names <- function(x, feature_name, rules) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
.levels.to.names <- function(x, feature_name, rules) { | |
.levels_to_names <- function(x, feature_name, rules) { |
Please avoid using .
in any of these private functions' names.
#' @description The \code{lgb.plot.tree} function creates a DiagrammeR plot of a single LightGBM tree. | ||
#' @param model a \code{lgb.Booster} object. | ||
#' @param tree an integer specifying the tree to plot. This is 1-based, so e.g. a value of '7' means 'the 7th tree' (tree_index=6 in LightGBM's underlying representation). | ||
#' @param rules a list of rules to replace the split values with feature levels. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not totally convinced about this idea... it should be possible to recover the feature names from the model directly.
But before you remove this... can you please expand this doc and add examples and tests showing what this would look like? Right now, it's hard for me to understand what the content of rules
is supposed to be.
data.table::setnames(modelDT | ||
, old = c("tree_index", "split_feature", "threshold", "split_gain") | ||
, new = c("Tree", "Feature", "Split", "Gain")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
data.table::setnames(modelDT | |
, old = c("tree_index", "split_feature", "threshold", "split_gain") | |
, new = c("Tree", "Feature", "Split", "Gain")) | |
data.table::setnames( | |
modelDT | |
, old = c("tree_index", "split_feature", "threshold", "split_gain") | |
, new = c("Tree", "Feature", "Split", "Gain") | |
) |
Please, follow the style the rest of the project uses. I suspect that the linting configuration here would have caught this (not sure, as I haven't run it myself and it failed in CI for other unrelated reasons).
From this point forward, before you push a commit please run the R-code linting and fix any issues it reports.
From the root of the repo:
Rscript ./.ci/lint-r-code.R ./R-package
Feature requested in #1222
Added a R function to plot trees.
Basically used the code posted in #1222 by @SpeckledJim2 and followed the given instruction.