How to Overlay Histograms on Decision Tree Nodes in R (with rpart.plot)

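rpart.plot can show a histogram inside each node of a decision tree, but node.fun alone cannot draw it. node.fun only generates label text: rpart.plot calls it with the arguments x, labs, digits and varlen and expects one string per node. The histograms have to be drawn with base graphics after the tree is plotted, using the node box coordinates that rpart.plot returns invisibly. Let's break down how to make this work with concrete examples.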

Step 1: Set Up Your Environment and Model

First, load the required packages and train a basic decision tree. We’ll use the mtcars dataset for demonstration, but you can swap in your own data.

library(rpart)
library(rpart.plot)

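Train a classification tree that predicts transmission type am (0 = automatic, 1 = manual) from mpg, wt, and hp. Because am is stored as a number, wrap it in factor() and set method = "class"; otherwise rpart fits a regression tree:

fit <- rpart(factor(am) ~ mpg + wt + hp, data = mtcars, method = "class")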
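To see how rpart identifies nodes, the sketch below converts fit$where (each observation's row position in fit$frame) into node numbers and counts the training rows per leaf. It assumes the classification tree from Step 1; node_ids and leaf_node are illustrative names, not part of rpart.

```r
library(rpart)

# Classification tree as in Step 1 (am: 0 = automatic, 1 = manual)
fit <- rpart(factor(am) ~ mpg + wt + hp, data = mtcars, method = "class")

# Node numbers are stored as the row names of fit$frame
node_ids <- as.numeric(rownames(fit$frame))   # illustrative name

# fit$where is a row position in fit$frame, so use it to index node_ids
leaf_node <- node_ids[fit$where]              # illustrative name

# Rows per leaf; compare with the n column of fit$frame for the leaves
print(table(leaf_node))
print(fit$frame[fit$frame$var == "<leaf>", c("var", "n")])
```

Only leaves appear in leaf_node, because rpart records each observation's terminal node only.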

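Step 2: Define a node.fun and a Histogram-Drawing Helper

node.fun cannot draw by itself, so the work is split in three: node.fun pads each label with blank lines so the boxes have room, rpart.plot draws the tree and invisibly returns its layout, and rect() then draws bars inside each box. Two details of rpart's bookkeeping matter here. fit$where stores each observation's row in fit$frame rather than its node number, and it only records leaves, so an internal node's samples must be gathered from every leaf beneath it (rpart numbers the children of node n as 2n and 2n + 1).

# Which training rows fall in `node` (works for leaves and internal nodes)
rows_in_node <- function(fit, node) {
  leaf <- as.numeric(rownames(fit$frame))[fit$where]  # leaf node number per row
  vapply(leaf, function(m) {
    while (m > node) m <- m %/% 2   # step up to the parent until m <= node
    m == node
  }, logical(1))
}

# node.fun: keep the default label, then append lines holding a single space
# so each box is drawn taller, leaving room underneath for the bars
pad_label <- function(x, labs, digits, varlen) {
  paste0(labs, "\n \n \n \n ")
}

# Draw a histogram of `values` in the lower part of the box (x1, y1)-(x2, y2)
draw_box_hist <- function(values, x1, y1, x2, y2, breaks, col,
                          label_frac = 0.45) {
  counts <- hist(values, breaks = breaks, plot = FALSE)$counts
  xl <- min(x1, x2); xr <- max(x1, x2)
  yb <- min(y1, y2); yt <- max(y1, y2)
  bar_w <- (xr - xl) / length(counts)
  max_h <- (yt - yb) * (1 - label_frac)   # top of the box stays free for text
  left <- xl + (seq_along(counts) - 1) * bar_w
  rect(left, yb, left + bar_w, yb + max_h * counts / max(counts),
       col = col, border = "white")
}

Step 3: Plot the Tree, Then Draw Into Each Box

Call rpart.plot with the padding node.fun and keep its return value. According to the prp help page, the returned list has a boxes component with the box corners x1, y1, x2, y2, in the same order as the rows of fit$frame (run str(p$boxes) to confirm on your version):

p <- rpart.plot(fit,
                type = 2,              # label every node, not just the leaves
                extra = 1,             # show the number of observations (per class)
                node.fun = pad_label,
                box.palette = "Blues", # box fill colors
                cex = 0.8)

nodes <- as.numeric(rownames(fit$frame))
for (i in seq_along(nodes)) {
  if (is.na(p$boxes$x1[i])) next       # skip nodes without a drawn box
  draw_box_hist(mtcars$am[rows_in_node(fit, nodes[i])],
                p$boxes$x1[i], p$boxes$y1[i], p$boxes$x2[i], p$boxes$y2[i],
                breaks = c(-0.5, 0.5, 1.5),        # one bar for 0, one for 1
                col = c("#66c2a5", "#fc8d62"))
}

If the bars overlap the label text, add more padding lines in pad_label or raise label_frac.

For Continuous Response Variables

If you're predicting a continuous outcome (like mpg), only the data and the bins change. Giving every node the same breaks keeps the histograms comparable, and with extra = 1 the node label already shows the node mean and the sample count:

# Train a regression tree
fit_reg <- rpart(mpg ~ wt + hp + disp, data = mtcars)
mpg_breaks <- pretty(range(mtcars$mpg), n = 6)   # shared bins covering all mpg values

p_reg <- rpart.plot(fit_reg, type = 2, extra = 1,
                    node.fun = pad_label, box.palette = "Purples", cex = 0.8)

nodes_reg <- as.numeric(rownames(fit_reg$frame))
for (i in seq_along(nodes_reg)) {
  if (is.na(p_reg$boxes$x1[i])) next
  draw_box_hist(mtcars$mpg[rows_in_node(fit_reg, nodes_reg[i])],
                p_reg$boxes$x1[i], p_reg$boxes$y1[i],
                p_reg$boxes$x2[i], p_reg$boxes$y2[i],
                breaks = mpg_breaks, col = "#8da0cb")
}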
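To compare nodes on the same scale, it helps to bin every node with one shared set of breaks. The sketch below refits the regression tree so it runs on its own, then counts mpg per leaf into shared bins with hist(..., plot = FALSE), which is a quick way to check what each node histogram will show. The names mpg_breaks, leaf_node and bin_counts are illustrative assumptions.

```r
library(rpart)

# Regression tree for mpg, as in the continuous example
fit_reg <- rpart(mpg ~ wt + hp + disp, data = mtcars)

# One set of bins for every node (illustrative name); pretty() covers the range
mpg_breaks <- pretty(range(mtcars$mpg), n = 6)

# Leaf node number of each training row (illustrative name)
leaf_node <- as.numeric(rownames(fit_reg$frame))[fit_reg$where]

# One row of bin counts per leaf (illustrative name); plot = FALSE only counts
bin_counts <- t(sapply(split(mtcars$mpg, leaf_node), function(v)
  hist(v, breaks = mpg_breaks, plot = FALSE)$counts))
print(bin_counts)
```

Each row sums to the number of cars in that leaf, so the table and the drawn bars should agree.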

Key Notes

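  • fit$where records, for each observation, the row of fit$frame that holds its leaf, not the node number; convert it with as.numeric(rownames(fit$frame))[fit$where]. It never lists internal nodes, so an internal node's samples are the union of the leaves beneath it.
  • node.fun only controls the label text: it must take the arguments x, labs, digits and varlen and return one label per node. Anything drawn inside a box has to be added afterwards with base graphics (for example rect()), using the box coordinates in rpart.plot's invisible return value.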
  • You can tweak colors, bin widths, and added text to match your visualization needs.
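rpart numbers the children of node n as 2n and 2n + 1, so an observation belongs to node n whenever its leaf number reduces to n by repeated integer halving. The sketch below checks that rule against rpart's own bookkeeping: counting the rows assigned to every node, internal nodes included, should reproduce the n column of fit$frame. It assumes the classification tree from Step 1; node_count is a hypothetical helper, not an rpart function.

```r
library(rpart)

fit <- rpart(factor(am) ~ mpg + wt + hp, data = mtcars, method = "class")

node_ids  <- as.numeric(rownames(fit$frame))
leaf_node <- node_ids[fit$where]   # leaf node number of each training row

# Hypothetical helper: number of training rows under `node`. A row counts if
# its leaf number reaches `node` when repeatedly replaced by its parent (m %/% 2)
node_count <- function(node) {
  sum(vapply(leaf_node, function(m) {
    while (m > node) m <- m %/% 2
    m == node
  }, logical(1)))
}

counts <- vapply(node_ids, node_count, integer(1))
print(data.frame(node = node_ids, counted = counts, rpart_n = fit$frame$n))
```

If the counted and rpart_n columns match, the same membership rule can safely select the data for each node's histogram.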

The question was sourced from Stack Exchange; it was asked by Ushuaia81.
