-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathpartialPlotBagging.R
98 lines (98 loc) · 3.34 KB
/
partialPlotBagging.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#' Partial dependence of a fitted ensemble on one predictor
#'
#' For every grid value of the chosen predictor, the corresponding column of
#' `pred.data` is overwritten with that value, the model is asked for
#' predictions, and the weighted mean of the response is recorded. The
#' resulting curve (or bars, for an unordered factor) is optionally drawn.
#'
#' When `x` inherits from class "bagging", `predict(x, newdata)$prob` is
#' used and the response is the log-probability of the class of interest
#' minus the row mean of the log-probabilities over all classes (zero
#' probabilities are replaced by `.Machine$double.eps` before taking logs).
#' Otherwise `predict(x, newdata)` is averaged directly.
#'
#' @param x Fitted model; must have a non-NULL `trees` component. For
#'   "bagging" objects `colnames(x$votes)` supplies the class labels.
#' @param pred.data Data frame used to build the prediction data sets.
#' @param x.var Predictor to examine: a bare name, a character string, or
#'   an expression. A bare name is always taken literally (deparsed); any
#'   other expression is evaluated inside this function.
#' @param which.class Class label to focus on ("bagging" objects only);
#'   matched with `charmatch()`. Defaults to the first class.
#' @param w Weights passed to `weighted.mean()`; defaults to 1 for each row
#'   of `pred.data`.
#' @param plot Draw a new plot? Ignored for the curve/bars when `add` is
#'   `TRUE`.
#' @param add Add to the current plot instead of starting a new one?
#' @param n.pt Number of grid points for a numeric or ordered predictor.
#' @param rug Draw rug marks (numeric/ordered predictor, and only when
#'   `plot` is `TRUE`): deciles if `n.pt > 10`, otherwise the unique values.
#' @param xlab,ylab,main Axis labels and title for a new plot.
#' @param ... Further arguments forwarded to the plotting call
#'   (`points`, `barplot`, `lines` or `plot`).
#' @return Invisibly, a list with components `x` (grid values or factor
#'   levels) and `y` (the averaged response at each of them).
partialPlotBagging <- function(
    x, pred.data, x.var, which.class, w, plot = TRUE, add = FALSE,
    n.pt = min(length(unique(pred.data[, xname])), 51), rug = TRUE,
    xlab = deparse(substitute(x.var)), ylab = "",
    main = paste("Partial Dependence on", deparse(substitute(x.var))),
    ...) {
  # inherits() stays a single TRUE/FALSE even when the object carries
  # several classes, unlike class(x) == "bagging".
  classBAG <- inherits(x, "bagging")
  if (is.null(x$trees)) {
    stop("The object must contain the trees.\n")
  }

  # Keep the names `x.var` and `xname`: the lazily evaluated defaults of
  # `n.pt`, `xlab` and `main` refer to them.
  x.var <- substitute(x.var)
  xname <- if (is.character(x.var)) {
    x.var
  } else if (is.name(x.var)) {
    deparse(x.var)
  } else {
    eval(x.var)
  }

  xv <- pred.data[, xname]
  n <- nrow(pred.data)
  if (missing(w)) {
    w <- rep(1, n)
  }

  if (classBAG) {
    if (missing(which.class)) {
      focus <- 1
    } else {
      focus <- charmatch(which.class, colnames(x$votes))
      # charmatch() gives NA for "no match" and 0 for an ambiguous partial
      # match; neither is a usable column index.
      if (is.na(focus) || focus == 0) {
        stop(which.class, " is not one of the class labels.")
      }
    }
  }

  eps <- .Machine$double.eps
  # Weighted mean response of the model on `newdata` (see header for the
  # definition of the response in the "bagging" case).
  avg_response <- function(newdata) {
    if (!classBAG) {
      return(weighted.mean(predict(x, newdata), w, na.rm = TRUE))
    }
    pr <- predict(x, newdata)$prob
    colnames(pr) <- colnames(x$votes)
    log_pr <- log(ifelse(pr > 0, pr, eps))
    weighted.mean(log_pr[, focus] - rowMeans(log_pr), w, na.rm = TRUE)
  }

  if (is.factor(xv) && !is.ordered(xv)) {
    # Unordered factor: one evaluation per level.
    x.pt <- levels(xv)
    y.pt <- numeric(length(x.pt))
    for (i in seq_along(x.pt)) {
      x.data <- pred.data
      x.data[, xname] <- factor(rep(x.pt[i], n), levels = x.pt)
      y.pt[i] <- avg_response(x.data)
    }
    if (add) {
      points(seq_along(x.pt), y.pt, type = "h", lwd = 2, ...)
    } else if (plot) {
      barplot(y.pt, width = rep(1, length(y.pt)), col = "blue",
              xlab = xlab, ylab = ylab, main = main, names.arg = x.pt,
              ...)
    }
  } else {
    # Numeric predictor, or ordered factor handled through its integer
    # codes: evaluate on an equally spaced grid over the observed range.
    if (is.ordered(xv)) {
      xv <- as.numeric(xv)
    }
    x.pt <- seq(min(xv), max(xv), length.out = n.pt)
    y.pt <- numeric(length(x.pt))
    for (i in seq_along(x.pt)) {
      x.data <- pred.data
      x.data[, xname] <- rep(x.pt[i], n)
      y.pt[i] <- avg_response(x.data)
    }
    if (add) {
      lines(x.pt, y.pt, ...)
    } else if (plot) {
      plot(x.pt, y.pt, type = "l", xlab = xlab, ylab = ylab,
           main = main, ...)
    }
    if (rug && plot) {
      if (n.pt > 10) {
        rug(quantile(xv, seq(0.1, 0.9, by = 0.1)), side = 1)
      } else {
        rug(unique(xv), side = 1)
      }
    }
  }

  invisible(list(x = x.pt, y = y.pt))
}