-
Notifications
You must be signed in to change notification settings - Fork 3.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[R-package] add a tree plotting function #6729
base: master
Are you sure you want to change the base?
Changes from all commits
6862821
0a7ea0e
5206b11
757dc84
55aba68
85ff97a
b4b648a
ed62441
2710705
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,204 @@ | ||||||||||||||||||||||||||||
#' @name lgb.plot.tree | ||||||||||||||||||||||||||||
#' @title Plot a single LightGBM tree. | ||||||||||||||||||||||||||||
#' @description The \code{lgb.plot.tree} function creates a DiagrammeR plot of a single LightGBM tree. | ||||||||||||||||||||||||||||
#' @param model a \code{lgb.Booster} object. | ||||||||||||||||||||||||||||
#' @param tree an integer specifying the tree to plot. This is 1-based, so e.g. a value of '7' means 'the 7th tree' (tree_index=6 in LightGBM's underlying representation). | ||||||||||||||||||||||||||||
#' @param rules a list of rules to replace the split values with feature levels. | ||||||||||||||||||||||||||||
#' | ||||||||||||||||||||||||||||
#' @return | ||||||||||||||||||||||||||||
#' The \code{lgb.plot.tree} function creates a DiagrammeR plot. | ||||||||||||||||||||||||||||
#' | ||||||||||||||||||||||||||||
#' @details | ||||||||||||||||||||||||||||
#' The \code{lgb.plot.tree} function creates a DiagrammeR plot of a single LightGBM tree. The tree is extracted from the model and displayed as a directed graph. The nodes are labelled with the feature, split value, gain, cover and value. The edges are labelled with the decision type and split value. | ||||||||||||||||||||||||||||
#' | ||||||||||||||||||||||||||||
#' @examples | ||||||||||||||||||||||||||||
#' \donttest{ | ||||||||||||||||||||||||||||
#' # EXAMPLE: use the LightGBM example dataset to build a model with a single tree | ||||||||||||||||||||||||||||
#' data(agaricus.train, package = "lightgbm") | ||||||||||||||||||||||||||||
#' train <- agaricus.train | ||||||||||||||||||||||||||||
#' dtrain <- lgb.Dataset(train$data, label = train$label) | ||||||||||||||||||||||||||||
#' data(agaricus.test, package = "lightgbm") | ||||||||||||||||||||||||||||
#' test <- agaricus.test | ||||||||||||||||||||||||||||
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label) | ||||||||||||||||||||||||||||
#' # define model parameters and build a single tree | ||||||||||||||||||||||||||||
#' params <- list( | ||||||||||||||||||||||||||||
#' objective = "regression", | ||||||||||||||||||||||||||||
#' min_data = 1L, | ||||||||||||||||||||||||||||
#' ) | ||||||||||||||||||||||||||||
#' valids <- list(test = dtest) | ||||||||||||||||||||||||||||
#' model <- lgb.train( | ||||||||||||||||||||||||||||
#' params = params, | ||||||||||||||||||||||||||||
#' data = dtrain, | ||||||||||||||||||||||||||||
#' nrounds = 1L, | ||||||||||||||||||||||||||||
#' valids = valids, | ||||||||||||||||||||||||||||
#' early_stopping_rounds = 1L | ||||||||||||||||||||||||||||
#' ) | ||||||||||||||||||||||||||||
#' # plot the tree and compare to the tree table | ||||||||||||||||||||||||||||
#' # trees start from 0 in lgb.model.dt.tree | ||||||||||||||||||||||||||||
#' tree_table <- lgb.model.dt.tree(model) | ||||||||||||||||||||||||||||
#' lgb.plot.tree(model, 0) | ||||||||||||||||||||||||||||
#' } | ||||||||||||||||||||||||||||
#' | ||||||||||||||||||||||||||||
#' @export | ||||||||||||||||||||||||||||
lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) { | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I can't think of any situation where it would be ok for If not, let's please require callers to provide values explicitly. |
||||||||||||||||||||||||||||
# check model is lgb.Booster | ||||||||||||||||||||||||||||
if (!.is_Booster(x = model)) { | ||||||||||||||||||||||||||||
stop("lgb.plot.tree: model should be an ", sQuote("lgb.Booster")) | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
# check DiagrammeR is available | ||||||||||||||||||||||||||||
if (!requireNamespace("DiagrammeR", quietly = TRUE)) { | ||||||||||||||||||||||||||||
stop("lgb.plot.tree: DiagrammeR package is required", | ||||||||||||||||||||||||||||
call. = FALSE | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
# tree must be numeric | ||||||||||||||||||||||||||||
if (!inherits(tree, "numeric")) { | ||||||||||||||||||||||||||||
stop("lgb.plot.tree: Has to be an integer numeric") | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
# tree must be integer | ||||||||||||||||||||||||||||
if (tree %% 1 != 0) { | ||||||||||||||||||||||||||||
stop("lgb.plot.tree: Has to be an integer numeric") | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
Comment on lines
+54
to
+61
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Let's combine these, and make it clear what has to be an integer. |
||||||||||||||||||||||||||||
# extract data.table model structure | ||||||||||||||||||||||||||||
modelDT <- lgb.model.dt.tree(model) | ||||||||||||||||||||||||||||
# check that tree is less than or equal to the maximum tree index in the model | ||||||||||||||||||||||||||||
if (tree > max(modelDT$tree_index) || tree < 1) { | ||||||||||||||||||||||||||||
warning("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".") | ||||||||||||||||||||||||||||
stop("lgb.plot.tree: Invalid tree number") | ||||||||||||||||||||||||||||
Comment on lines
+66
to
+67
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
What's the reason for having all of the information in a |
||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please modify this error message so that it has enough information for someone to quickly debug the issue, like the provided value of Something like this:
|
||||||||||||||||||||||||||||
# filter modelDT to just the rows for the selected tree | ||||||||||||||||||||||||||||
modelDT <- modelDT[tree_index == tree, ] | ||||||||||||||||||||||||||||
# change the column names to shorter more diagram friendly versions | ||||||||||||||||||||||||||||
data.table::setnames(modelDT | ||||||||||||||||||||||||||||
, old = c("tree_index", "split_feature", "threshold", "split_gain") | ||||||||||||||||||||||||||||
, new = c("Tree", "Feature", "Split", "Gain")) | ||||||||||||||||||||||||||||
Comment on lines
+72
to
+74
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Please, follow the style the rest of the project uses. I suspect that the linting configuration here would have caught this (not sure, as I haven't run it myself and it failed in CI for other unrelated reasons). From this point forward, before you push a commit please run the R-code linting and fix any issues it reports. From the root of the repo: Rscript ./.ci/lint-r-code.R ./R-package |
||||||||||||||||||||||||||||
# assign leaf_value to the Value column in modelDT | ||||||||||||||||||||||||||||
modelDT[, Value := leaf_value] | ||||||||||||||||||||||||||||
# assign new values if NA | ||||||||||||||||||||||||||||
modelDT[is.na(Value), Value := internal_value] | ||||||||||||||||||||||||||||
modelDT[is.na(Gain), Gain := leaf_value] | ||||||||||||||||||||||||||||
modelDT[is.na(Feature), Feature := "Leaf"] | ||||||||||||||||||||||||||||
# assign internal_count to Cover, and if Feature is "Leaf", assign leaf_count to Cover | ||||||||||||||||||||||||||||
modelDT[, Cover := internal_count][Feature == "Leaf", Cover := leaf_count] | ||||||||||||||||||||||||||||
# remove unnecessary columns | ||||||||||||||||||||||||||||
modelDT[, c("leaf_count", "internal_count", "leaf_value", "internal_value") := NULL] | ||||||||||||||||||||||||||||
# assign split_index to Node | ||||||||||||||||||||||||||||
modelDT[, Node := split_index] | ||||||||||||||||||||||||||||
# find the maximum value of Node, if Node is NA, assign max_node + leaf_index + 1 to Node | ||||||||||||||||||||||||||||
max_node <- max(modelDT[["Node"]], na.rm = TRUE) | ||||||||||||||||||||||||||||
modelDT[is.na(Node), Node := max_node + leaf_index + 1] | ||||||||||||||||||||||||||||
# adding ID column | ||||||||||||||||||||||||||||
modelDT[, ID := paste(Tree, Node, sep = "-")] | ||||||||||||||||||||||||||||
# remove unnecessary columns | ||||||||||||||||||||||||||||
modelDT[, c("depth", "leaf_index") := NULL] | ||||||||||||||||||||||||||||
modelDT[, parent := node_parent][is.na(parent), parent := leaf_parent] | ||||||||||||||||||||||||||||
modelDT[, c("node_parent", "leaf_parent", "split_index") := NULL] | ||||||||||||||||||||||||||||
# assign the IDs of the matching parent nodes to Yes and No | ||||||||||||||||||||||||||||
modelDT[, Yes := modelDT$ID[match(modelDT$Node, modelDT$parent)]] | ||||||||||||||||||||||||||||
modelDT <- modelDT[nrow(modelDT):1, ] | ||||||||||||||||||||||||||||
modelDT[, No := modelDT$ID[match(modelDT$Node, modelDT$parent)]] | ||||||||||||||||||||||||||||
# which way do the NA's go (this path will get a thicker arrow) | ||||||||||||||||||||||||||||
# for categorical features, NA gets put into the zero group | ||||||||||||||||||||||||||||
modelDT[default_left == TRUE, Missing := Yes] | ||||||||||||||||||||||||||||
modelDT[default_left == FALSE, Missing := No] | ||||||||||||||||||||||||||||
modelDT[.zero_present(Split), Missing := Yes] | ||||||||||||||||||||||||||||
# create the label text | ||||||||||||||||||||||||||||
modelDT[, label := paste0( | ||||||||||||||||||||||||||||
Feature | ||||||||||||||||||||||||||||
, "\nCover: " | ||||||||||||||||||||||||||||
, Cover | ||||||||||||||||||||||||||||
, ifelse(Feature == "Leaf", "", "\nGain: "), ifelse(Feature == "Leaf" | ||||||||||||||||||||||||||||
, "" | ||||||||||||||||||||||||||||
, round(Gain, 4)) | ||||||||||||||||||||||||||||
, "\nValue: " | ||||||||||||||||||||||||||||
, round(Value, 4) | ||||||||||||||||||||||||||||
)] | ||||||||||||||||||||||||||||
# style the nodes - same format as xgboost | ||||||||||||||||||||||||||||
modelDT[Node == 0, label := paste0("Tree ", Tree, "\n", label)] | ||||||||||||||||||||||||||||
modelDT[, shape := "rectangle"][Feature == "Leaf", shape := "oval"] | ||||||||||||||||||||||||||||
modelDT[, filledcolor := "Beige"][Feature == "Leaf", filledcolor := "Khaki"] | ||||||||||||||||||||||||||||
# in order to draw the first tree on top: | ||||||||||||||||||||||||||||
modelDT <- modelDT[order(-Tree)] | ||||||||||||||||||||||||||||
nodes <- DiagrammeR::create_node_df( | ||||||||||||||||||||||||||||
n = nrow(modelDT) | ||||||||||||||||||||||||||||
, ID = modelDT$ID | ||||||||||||||||||||||||||||
, label = modelDT$label | ||||||||||||||||||||||||||||
, fillcolor = modelDT$filledcolor | ||||||||||||||||||||||||||||
, shape = modelDT$shape | ||||||||||||||||||||||||||||
, data = modelDT$Feature | ||||||||||||||||||||||||||||
, fontcolor = "black" | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
# round the edge labels to 4 s.f. if they are numeric | ||||||||||||||||||||||||||||
# as otherwise get too many decimal places and the diagram looks bad | ||||||||||||||||||||||||||||
# would rather not use suppressWarnings | ||||||||||||||||||||||||||||
numeric_idx <- suppressWarnings(!is.na(as.numeric(modelDT[["Split"]]))) | ||||||||||||||||||||||||||||
modelDT[numeric_idx, Split := round(as.numeric(Split), 4)] | ||||||||||||||||||||||||||||
# replace indices with feature levels if rules supplied | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if (!is.null(rules)) { | ||||||||||||||||||||||||||||
for (f in names(rules)) { | ||||||||||||||||||||||||||||
modelDT[Feature == f & decision_type == "==", Split := .levels.to.names(Split, f, rules)] | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
# replace long split names with a message | ||||||||||||||||||||||||||||
modelDT[nchar(Split) > 500, Split := "Split too long to render"] | ||||||||||||||||||||||||||||
# create the edge labels | ||||||||||||||||||||||||||||
edges <- DiagrammeR::create_edge_df( | ||||||||||||||||||||||||||||
from = match(modelDT[Feature != "Leaf", c(ID)] %>% rep(2), modelDT$ID), | ||||||||||||||||||||||||||||
to = match(modelDT[Feature != "Leaf", c(Yes, No)], modelDT$ID), | ||||||||||||||||||||||||||||
label = modelDT[Feature != "Leaf", paste(decision_type, Split)] %>% | ||||||||||||||||||||||||||||
c(rep("", nrow(modelDT[Feature != "Leaf"]))), | ||||||||||||||||||||||||||||
style = modelDT[Feature != "Leaf", ifelse(Missing == Yes, "bold", "solid")] %>% | ||||||||||||||||||||||||||||
c(modelDT[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")]), | ||||||||||||||||||||||||||||
rel = "leading_to" | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
# create the graph | ||||||||||||||||||||||||||||
graph <- DiagrammeR::create_graph( | ||||||||||||||||||||||||||||
nodes_df = nodes | ||||||||||||||||||||||||||||
, edges_df = edges | ||||||||||||||||||||||||||||
, attr_theme = NULL | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
graph <- DiagrammeR::add_global_graph_attrs( | ||||||||||||||||||||||||||||
graph = graph | ||||||||||||||||||||||||||||
, attr_type = "graph" | ||||||||||||||||||||||||||||
, attr = c("layout", "rankdir") | ||||||||||||||||||||||||||||
, value = c("dot", "LR") | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
graph <- DiagrammeR::add_global_graph_attrs( | ||||||||||||||||||||||||||||
graph = graph | ||||||||||||||||||||||||||||
, attr_type = "node" | ||||||||||||||||||||||||||||
, attr = c("color", "style", "fontname") | ||||||||||||||||||||||||||||
, value = c("DimGray", "filled", "Helvetica") | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
graph <- DiagrammeR::add_global_graph_attrs( | ||||||||||||||||||||||||||||
graph = graph | ||||||||||||||||||||||||||||
, attr_type = "edge" | ||||||||||||||||||||||||||||
, attr = c("color", "arrowsize", "arrowhead", "fontname") | ||||||||||||||||||||||||||||
, value = c("DimGray", "1.5", "vee", "Helvetica") | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
# render the graph | ||||||||||||||||||||||||||||
DiagrammeR::render_graph(graph) | ||||||||||||||||||||||||||||
return(invisible(NULL)) | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
.zero_present <- function(x) { | ||||||||||||||||||||||||||||
sapply(strsplit(as.character(x), "||", fixed = TRUE), function(el) { | ||||||||||||||||||||||||||||
any(el == "0") | ||||||||||||||||||||||||||||
}) | ||||||||||||||||||||||||||||
return(invisible(NULL)) | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
.levels.to.names <- function(x, feature_name, rules) { | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Please avoid using |
||||||||||||||||||||||||||||
lvls <- sort(rules[[feature_name]]) | ||||||||||||||||||||||||||||
result <- strsplit(x, "||", fixed = TRUE) | ||||||||||||||||||||||||||||
result <- lapply(result, as.numeric) | ||||||||||||||||||||||||||||
result <- lapply(result, .levels_to_names) | ||||||||||||||||||||||||||||
result <- lapply(result, paste, collapse = "\n") | ||||||||||||||||||||||||||||
result <- as.character(result) | ||||||||||||||||||||||||||||
return(invisible(NULL)) | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
.levels_to_names <- function(x) { | ||||||||||||||||||||||||||||
names(lvls)[as.numeric(x)] | ||||||||||||||||||||||||||||
return(invisible(NULL)) | ||||||||||||||||||||||||||||
} |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not totally convinced about this idea... it should be possible to recover the feature names from the model directly.
But before you remove this... can you please expand this doc and add examples and tests showing what this would look like? Right now, it's hard for me to understand what the content of
rules
is supposed to be.