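# rpart, rpart.plot package를 활용한 decision tree plotting 코드이다.
# (Decision tree plotting: trees are fitted with rpart; the plots in this
# section are drawn with partykit + ggparty.)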
# R packages ----
library(partykit)
library(rpart)
library(rpart.plot)
library(dplyr)
library(ggparty)
# Classification ----
# Fit a classification tree predicting Species from all other iris columns,
# render it with ggparty, and save it as a 1300x1000 PNG at 150 dpi.
# NOTE(review): `plot_dir` is not defined in this chunk; presumably it is set
# earlier in the script -- confirm.
rpart_fit <- rpart(Species ~ ., data = iris)

# Build the plot object before opening the device, so that an error while
# constructing it cannot leave a graphics device open.
clf_plot <- ggparty::ggparty(partykit::as.party(rpart_fit)) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  # Per-node bar chart of Species counts. The unnamed fill values are assigned
  # to the Species levels in level order; the legend is dropped.
  geom_node_plot(
    gglist = list(
      geom_bar(aes(x = Species, fill = Species)),
      scale_fill_manual(values = c("deeppink1", "lightgray", "black")),
      theme_bw(),
      theme(legend.position = "none")
    )
  ) +
  ggtitle("Give me title") +
  theme(plot.title = element_text(hjust = 0.5))

png(file.path(plot_dir, "Test_classification2.png"),
    width = 1300, height = 1000, res = 150)
# ggplot objects are only drawn when printed. Top-level auto-printing does not
# happen when the script is run via source() without echo, which would leave
# the PNG blank, so print explicitly. `finally` guarantees the device is
# closed even if drawing fails.
tryCatch(print(clf_plot), finally = dev.off())
# Regression ----
# Fit a regression tree predicting UrbanPop from all other USArrests columns,
# render it with ggparty, and save it as a 1300x1000 PNG at 150 dpi.
# NOTE(review): `plot_dir` is not defined in this chunk; presumably it is set
# earlier in the script -- confirm.
rpart_fit <- rpart(UrbanPop ~ ., data = USArrests)

# Build the plot object before opening the device, so that an error while
# constructing it cannot leave a graphics device open.
reg_plot <- ggparty::ggparty(partykit::as.party(rpart_fit)) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  # Per-node boxplot of UrbanPop. Only `y` is mapped, so the x axis carries no
  # information; its title, text and ticks are blanked.
  geom_node_plot(
    gglist = list(
      geom_boxplot(aes(y = UrbanPop), fill = "darkslategray3", width = 0.5),
      theme_bw(),
      theme(
        axis.title.x = element_blank(),
        axis.text.x = element_blank(),
        axis.ticks.x = element_blank()
      )
    )
  ) +
  ggtitle("Give me title") +
  theme(plot.title = element_text(hjust = 0.5))

png(file.path(plot_dir, "Test_regression2.png"),
    width = 1300, height = 1000, res = 150)
# ggplot objects are only drawn when printed. Top-level auto-printing does not
# happen when the script is run via source() without echo, which would leave
# the PNG blank, so print explicitly. `finally` guarantees the device is
# closed even if drawing fails.
tryCatch(print(reg_plot), finally = dev.off())