How to Overlay Histograms in Decision Tree Nodes in R (Using rpart.plot)
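Histograms can be placed directly inside the node boxes of an rpart.plot tree. One caveat about node.fun: the rpart.plot documentation describes it as function(x, labs, digits, varlen) that returns a vector of label strings, so it customizes node text but receives no coordinates to draw with. Let's break down how to make the histograms work with concrete examples.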
Step 1: Set Up Your Environment and Model
First, load the required packages and train a basic decision tree. We’ll use the mtcars dataset for demonstration, but you can swap in your own data.
library(rpart)
library(rpart.plot)
Train a tree (we’ll predict transmission type am using mpg, wt, and hp):
fit <- rpart(am ~ mpg + wt + hp, data = mtcars)
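One detail worth checking at this point, sketched under the assumption that a classification tree is what you want: am is stored as a number in mtcars, so rpart fits a regression (anova) tree to it. Wrapping the response in factor() switches to classification. The names fit_num and fit_cls are illustrative only.

```r
library(rpart)

# am is numeric (0/1), so rpart defaults to a regression tree
fit_num <- rpart(am ~ mpg + wt + hp, data = mtcars)
print(fit_num$method)

# Converting the response to a factor switches rpart to classification
fit_cls <- rpart(factor(am) ~ mpg + wt + hp, data = mtcars)
print(fit_cls$method)
```

Either kind of fit can be used for the histograms, since they are built from the 0/1 values of am in the data rather than from the fitted values.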
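Step 2: Define a Helper That Draws the Histograms
Because node.fun only returns label text, the drawing happens after the tree has been plotted. rpart.plot() invisibly returns a list describing the layout, and its boxes component holds the edges of every node box. The helper below takes that return value and draws one small histogram per node with rect(), which adds to the existing plot instead of starting a new one.
# Draw one histogram inside each node box of an existing rpart.plot.
#   p      : value returned by rpart.plot() (boxes is assumed to hold x1, y1, x2, y2;
#            check str(p$boxes) on your installed version)
#   fit    : the rpart model
#   values : response values, in the same row order as the training data
draw_node_hists <- function(p, fit, values, breaks, col) {
  node_ids <- as.numeric(rownames(fit$frame))  # rpart node numbers, one per frame row
  obs_leaf <- node_ids[fit$where]              # fit$where holds frame row indices of leaves
  for (i in seq_along(node_ids)) {
    # Walk each row's leaf number up the tree (the parent of node j is j %/% 2)
    # until it is no larger than this node's number; a match means the row is in the node.
    anc <- obs_leaf
    while (any(anc > node_ids[i])) {
      anc <- ifelse(anc > node_ids[i], anc %/% 2, anc)
    }
    node_values <- values[anc == node_ids[i]]
    counts <- hist(node_values, breaks = breaks, plot = FALSE)$counts
    # Bars split the box width evenly; heights are scaled to the node's sample count
    x1 <- p$boxes$x1[i]; x2 <- p$boxes$x2[i]
    y1 <- p$boxes$y1[i]; y2 <- p$boxes$y2[i]
    bar_w <- (x2 - x1) / length(counts)
    left  <- x1 + (seq_along(counts) - 1) * bar_w
    rect(left, y1, left + bar_w,
         y1 + (y2 - y1) * counts / length(node_values),
         col = col, border = "white")
    # Optional: add the sample count just inside the top of the box
    text((x1 + x2) / 2, y2, paste("n =", length(node_values)), pos = 1, cex = 0.8)
  }
}
Step 3: Plot the Tree with Overlaid Histograms
Call rpart.plot, keep its return value, and pass it to the helper. type = 2 labels every node (type = 0 labels only the leaves), and space and yspace pad the boxes to make room for the bars. do.par = FALSE keeps rpart.plot from changing and then restoring the graphics parameters, which would otherwise leave the coordinate system unreliable for anything drawn afterwards:
par(mar = c(1, 1, 1, 1), xpd = NA)     # margins and clipping that rpart.plot would otherwise set
p <- rpart.plot(fit,
                type = 2,              # label all nodes, not just leaves
                extra = 0,             # no extra statistics in the labels
                space = 3, yspace = 5, # pad the boxes (tune to taste)
                box.palette = "Blues", # node fill palette
                branch.lty = 1,
                cex = 0.8,
                do.par = FALSE)
draw_node_hists(p, fit, mtcars$am,
                breaks = c(-0.5, 0.5, 1.5),  # one bin each for 0 and 1
                col = adjustcolor(c("#66c2a5", "#fc8d62"), alpha.f = 0.6))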
For Continuous Response Variables
If you’re predicting a continuous outcome (like mpg), adjust the histogram to fit continuous data:
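# Train a regression tree
fit_reg <- rpart(mpg ~ wt + hp + disp, data = mtcars)
# Plot it and keep the returned layout; do.par = FALSE leaves the coordinate system usable afterwards
par(mar = c(1, 1, 1, 1), xpd = NA)
p_reg <- rpart.plot(fit_reg, type = 2, extra = 0, space = 3, yspace = 5,
                    box.palette = "Purples", cex = 0.8, do.par = FALSE)
# Shared bins across nodes so the histograms are comparable
breaks <- pretty(range(mtcars$mpg), n = 8)
node_ids <- as.numeric(rownames(fit_reg$frame))  # rpart node numbers
obs_leaf <- node_ids[fit_reg$where]              # leaf node number of each training row
b <- p_reg$boxes                                 # assumed to hold x1, y1, x2, y2 per node
for (i in seq_along(node_ids)) {
  # Rows in this node: walk each leaf number up to its ancestors (the parent of j is j %/% 2)
  anc <- obs_leaf
  while (any(anc > node_ids[i])) anc <- ifelse(anc > node_ids[i], anc %/% 2, anc)
  node_mpg <- mtcars$mpg[anc == node_ids[i]]
  counts <- hist(node_mpg, breaks = breaks, plot = FALSE)$counts
  # Bars fill the box width; heights are scaled to the node's sample count
  bar_w <- (b$x2[i] - b$x1[i]) / length(counts)
  left  <- b$x1[i] + (seq_along(counts) - 1) * bar_w
  rect(left, b$y1[i], left + bar_w,
       b$y1[i] + (b$y2[i] - b$y1[i]) * counts / length(node_mpg),
       col = adjustcolor("#8da0cb", alpha.f = 0.6), border = "white")
  # Add sample count and mean value
  xm <- (b$x1[i] + b$x2[i]) / 2
  text(xm, b$y2[i], paste("n =", length(node_mpg)), pos = 1, cex = 0.8)
  text(xm, b$y1[i], paste("mean =", round(mean(node_mpg), 1)), pos = 3, cex = 0.7)
}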
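The n and mean annotations can be cross-checked against the fitted tree without plotting anything. The sketch below assumes the same regression model as above and only covers the leaves; summary_tab and the other variable names are illustrative.

```r
library(rpart)

fit_reg <- rpart(mpg ~ wt + hp + disp, data = mtcars)

# fit_reg$where holds, per training row, the fit_reg$frame row index of its leaf
leaf_n    <- table(fit_reg$where)
leaf_mean <- tapply(mtcars$mpg, fit_reg$where, mean)

leaf_rows <- as.integer(names(leaf_mean))
summary_tab <- data.frame(
  node = as.numeric(rownames(fit_reg$frame))[leaf_rows],  # rpart node numbers
  n    = as.vector(leaf_n),
  mean = round(as.vector(leaf_mean), 1)
)
print(summary_tab)
```

For a regression tree, the per-leaf means computed this way should agree with the yval column of fit_reg$frame, which is a quick way to confirm that rows are being assigned to the right nodes.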
Key Notes
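- fit$where gives, for each training row, the row index in fit$frame of the leaf that row falls into. It is not the node number shown on the plot; the node numbers are rownames(fit$frame). Rows belonging to an internal node are found by walking leaf node numbers up to their ancestors (the parent of node j is j %/% 2).
- hist() with its default plot = TRUE starts a new plot, so setting par(usr) beforehand does not place a histogram inside a node box. Computing the counts with hist(..., plot = FALSE) and drawing the bars with rect() adds them to the existing tree.
- You can tweak colors, bin widths, and added text to match your visualization needs.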
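The node-membership rule can be verified on its own. This is a minimal sketch, assuming the model from Step 1; rows_in_node is a hypothetical helper name, not part of rpart. It counts the rows assigned to every node, internal ones included, and compares the result with the n column that rpart stores in fit$frame.

```r
library(rpart)

fit <- rpart(am ~ mpg + wt + hp, data = mtcars)

node_ids <- as.numeric(rownames(fit$frame))  # node numbers as used by rpart
obs_leaf <- node_ids[fit$where]              # leaf node number for each training row

# TRUE for rows whose leaf is `node` itself or one of its descendants.
# Hypothetical helper for illustration only.
rows_in_node <- function(node) {
  anc <- obs_leaf
  while (any(anc > node)) anc <- ifelse(anc > node, anc %/% 2, anc)
  anc == node
}

counts <- vapply(node_ids, function(k) sum(rows_in_node(k)), integer(1))
print(data.frame(node = node_ids, n_frame = fit$frame$n, n_rows = counts))
```

If the two count columns match for every node, the same logical index can be used to subset the data for each node's histogram.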
The question comes from Stack Exchange; asked by Ushuaia81.




