| Title: | Reinforcement Learning Trees |
|---|---|
| Description: | Random forest with a variety of additional features for regression, classification, and survival analysis. Features include parallel computing with OpenMP, reproducibility with random seeds, variance and confidence band estimations using U-statistics, embedded model for selecting splitting variables and constructing linear combination splits, permutation and distribution-based variable importance, observation and variable weights, and subject tracking across trees. |
| Authors: | Ruoqing Zhu [aut, cre, cph] |
| Maintainer: | Ruoqing Zhu <[email protected]> |
| License: | GPL (>= 3) |
| Version: | 6.0.2 |
| Built: | 2026-06-01 17:03:51 UTC |
| Source: | https://github.com/teazrq/rlt |
Calculate c-index for survival data
cindex(y, censor, pred)
y |
survival time |
censor |
The censoring indicator if survival model is used |
pred |
The predicted value (e.g., a risk score) for each subject |
The c-index (concordance index).
set.seed(42)
n <- 100
x <- matrix(rnorm(n * 5), ncol = 5)
y <- rexp(n, rate = exp(rowSums(x[, 1:2])))
censor <- rbinom(n, 1, 0.7)
fit <- RLT(x, y, censor = censor, model = "survival", ntrees = 100)
# Prediction holds the out-of-bag hazard function; summing it over the
# time points gives the cumulative hazard at the last time point, used as a risk score
risk <- rowSums(fit$Prediction)
cindex(y, censor, risk)
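For reference, here is a minimal base-R sketch of Harrell's concordance index. It assumes that a larger `pred` means a higher risk (as with the cumulative-hazard risk score in the example) and that tied predictions count as one half. The helper name `harrell_c` is hypothetical, and `cindex()` may treat ties or other conventions differently.

```r
# Hypothetical helper: Harrell's C for right-censored data.
# Assumes larger `pred` = higher risk; ties in `pred` count as 1/2.
harrell_c <- function(y, censor, pred) {
  concordant <- 0
  comparable <- 0
  for (i in seq_along(y)) {
    if (censor[i] != 1) next          # only observed failures anchor a pair
    later <- which(y > y[i])          # subjects known to survive past y[i]
    comparable <- comparable + length(later)
    concordant <- concordant + sum(pred[i] > pred[later]) +
      0.5 * sum(pred[i] == pred[later])
  }
  concordant / comparable
}

# Perfectly ordered risks: the earliest failure has the highest risk
harrell_c(y = c(1, 2, 3, 4), censor = c(1, 1, 1, 1), pred = c(4, 3, 2, 1))  # returns 1
```

With the data from the example, `harrell_c(y, censor, risk)` gives a comparable quantity, though it need not match `cindex()` exactly.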
Get the random-forest-induced kernel weight matrix of testing samples,
or between any two sets of data. This is an experimental feature;
use at your own risk.
forest.kernel( object, X1 = NULL, X2 = NULL, vs.train = FALSE, verbose = FALSE, ... )
object |
A fitted RLT object. |
X1 |
The dataset for prediction. This calculates an |
X2 |
The dataset for reference/training.
If |
vs.train |
Whether to calculate the kernel weights with respect to the training data.
This is slightly different from supplying the training data to |
verbose |
Whether progress information should be printed. |
... |
Additional arguments. |
A kernel matrix that contains kernel weights for each observation in X1 with respect to X1
set.seed(42)
x <- matrix(rnorm(200 * 5), ncol = 5)
y <- rowSums(x[, 1:2]) + rnorm(200)
fit <- RLT(x, y, ntrees = 100)
K <- forest.kernel(fit, X1 = x[1:5, ])
print(K$Kernel[1:3, 1:3])
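One common definition of a random-forest kernel is the fraction of trees in which two observations fall into the same terminal node. The sketch below computes that quantity from hypothetical matrices of terminal-node IDs (one column per tree). It is not necessarily the weighting that `forest.kernel()` uses, which may, for example, normalize by terminal-node size; the helper `leaf_kernel` is hypothetical.

```r
# Hypothetical helper: proportion of trees in which two observations share a leaf.
# leaf1: n1 x ntrees matrix of terminal-node IDs; leaf2: n2 x ntrees matrix.
leaf_kernel <- function(leaf1, leaf2 = leaf1) {
  stopifnot(ncol(leaf1) == ncol(leaf2))
  K <- matrix(0, nrow(leaf1), nrow(leaf2))
  for (b in seq_len(ncol(leaf1))) {
    # TRUE where observation i (set 1) and j (set 2) land in the same leaf of tree b
    K <- K + outer(leaf1[, b], leaf2[, b], "==")
  }
  K / ncol(leaf1)
}

# Three observations, two trees (column b holds the leaf IDs from tree b)
leaves <- matrix(c(1, 1, 2,
                   3, 4, 3), ncol = 2)
leaf_kernel(leaves)
```

Passing a second matrix as `leaf2` gives the rectangular kernel between two data sets, mirroring the `X1`/`X2` distinction above.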
Print a single fitted tree from a forest object
get.one.tree(x, tree = 1, ...)
x |
A fitted RLT object |
tree |
the tree number, starting from 1 to |
... |
... |
A data.frame with columns: Node (depth, BFS), NodeType (Split=1, Leaf=-1), SplitVar, SplitValue, LeftNode, RightNode, N (sample count). Model-specific columns include YAvg (regression), Prob (classification), or Hazard/SurvProb (survival).
set.seed(42) x <- matrix(rnorm(300 * 5), ncol = 5) y <- rowSums(x[, 1:2]) + rnorm(300) fit <- RLT(x, y, ntrees = 50) get.one.tree(fit, tree = 1)set.seed(42) x <- matrix(rnorm(300 * 5), ncol = 5) y <- rowSums(x[, 1:2]) + rnorm(300) fit <- RLT(x, y, ntrees = 50) get.one.tree(fit, tree = 1)
Calculate the survival function (two-sided) confidence band from a RLT survival prediction.
get.surv.band( x, i = 0, alpha = 0.05, approach = "naive", nsim = 5000, k_rank = 10, k_mode = c("fixed", "proportion"), k_prop = 0.99, ... )
x |
A RLT prediction object. Must be from a forest with var.mode != "none". |
i |
Index of the observation in the prediction. The default, i = 0, calculates the band for all observations. |
alpha |
Alpha level for the interval (alpha/2, 1 - alpha/2). |
approach |
Confidence band approach:
|
nsim |
Number of simulations for estimating the Monte Carlo critical value. |
k_rank |
Rank truncation K, used both for the smooth low-rank covariance and as the GAM basis size. |
k_mode |
Rank selection mode: "fixed" (use k_rank) or "proportion" (auto-select by eigenvalue cumulative ratio). |
k_prop |
Proportion threshold (0,1] for cumulative eigenvalue ratio when k_mode = "proportion". |
... |
Further arguments (currently not used). |
An object of class c("RLT", "band", "surv") with components: lower (lower bound), upper (upper bound), and timepoints (evaluation grid). If i = 0, a list of such objects for all observations.
set.seed(42)
n <- 100
x <- matrix(rnorm(n * 5), ncol = 5)
y <- rexp(n, rate = exp(rowSums(x[, 1:2])))
censor <- rbinom(n, 1, 0.7)
fit <- RLT(x, y, censor = censor, model = "survival", ntrees = 200, var.mode = "matched")
pred <- predict(fit, testx = x[1:3, ], var.est = TRUE)
band <- get.surv.band(pred, i = 1, alpha = 0.05)
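The `nsim`, `k_rank`, `k_mode`, and `k_prop` arguments suggest a Monte Carlo critical value computed from a rank-truncated covariance. The sketch below shows a generic sup-t simultaneous band under that reading: truncate the eigendecomposition of a covariance matrix over the time grid, simulate Gaussian paths, and take the 1 - alpha quantile of the maximum standardized deviation. The helper `sup_t_band` and its inputs are hypothetical; it omits the smoothing (GAM) step and is not the package's implementation.

```r
# Hypothetical helper: generic sup-t simultaneous band from a covariance matrix
# over a time grid (no smoothing step; not the package's implementation).
sup_t_band <- function(est, Sigma, alpha = 0.05, nsim = 5000,
                       k_mode = c("fixed", "proportion"),
                       k_rank = 10, k_prop = 0.99) {
  k_mode <- match.arg(k_mode)
  eig <- eigen(Sigma, symmetric = TRUE)   # eigenvalues in decreasing order
  lam <- pmax(eig$values, 0)              # clamp small negative eigenvalues
  k <- if (k_mode == "fixed") {
    min(k_rank, length(lam))
  } else {
    which(cumsum(lam) / sum(lam) >= k_prop)[1]   # smallest k reaching k_prop
  }
  # Rank-k square root L, so that L %*% t(L) is the truncated covariance
  L <- eig$vectors[, seq_len(k), drop = FALSE] %*% diag(sqrt(lam[seq_len(k)]), k, k)
  sd_k <- pmax(sqrt(rowSums(L^2)), 1e-12)        # marginal SDs under the truncation
  Z <- L %*% matrix(rnorm(k * nsim), k, nsim)    # nsim Gaussian paths (columns)
  max_dev <- apply(abs(Z) / sd_k, 2, max)        # sup of standardized deviation
  crit <- unname(quantile(max_dev, 1 - alpha))
  list(lower = pmax(est - crit * sd_k, 0),       # survival is clamped to [0, 1]
       upper = pmin(est + crit * sd_k, 1),
       crit = crit, k = k)
}

set.seed(1)
tp <- 1:20
est <- exp(-0.1 * tp)                            # toy survival curve
Sigma <- 0.001 * 0.9^abs(outer(tp, tp, "-"))     # toy AR(1)-type covariance
band <- sup_t_band(est, Sigma, nsim = 2000, k_rank = 5)
band$crit
```

Because the critical value comes from the maximum over the whole grid, it is typically larger than the pointwise normal quantile, which is what makes the band simultaneous rather than pointwise.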
Extract variable importance from a fitted RLT model.
When variance estimation was enabled via var.mode,
standard deviations, Z-scores, and significance codes are
also reported. Negative variance estimates yield NA
for SD, Z, and significance.
importance(object, ...)
object |
A fitted |
... |
Additional arguments (unused). |
A data.frame with columns:
Variable: variable name
VI: variable importance
SD: standard deviation of VI (NA if not estimated or negative variance)
Z: Z-score (VI / SD, NA if SD is NA)
Sig: significance code ("" if not estimated or negative variance)
Significance codes: *** |Z| >= 2.58, ** |Z| >= 1.96, * |Z| >= 1.64.
## Not run:
fit <- RLT(x, y, model = "classification", importance = TRUE, var.mode = TRUE)
importance(fit)
## End(Not run)
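The SD, Z, and Sig columns described above follow a simple rule. A minimal sketch, assuming the variance estimates of VI are available as a numeric vector, and assuming that |Z| below 1.64 gets an empty code (the description does not state this explicitly); the helper `vi_table` is hypothetical.

```r
# Hypothetical helper mirroring the documented SD / Z / Sig columns.
vi_table <- function(vi, vi_var, var_names = paste0("X", seq_along(vi))) {
  # Negative variance estimates give NA; pmax() avoids sqrt() warnings inside ifelse()
  sd <- ifelse(vi_var < 0, NA, sqrt(pmax(vi_var, 0)))
  z <- vi / sd                           # NA whenever SD is NA
  az <- abs(z)
  sig <- ifelse(is.na(az), "",
         ifelse(az >= 2.58, "***",
         ifelse(az >= 1.96, "**",
         ifelse(az >= 1.64, "*", ""))))
  data.frame(Variable = var_names, VI = vi, SD = sd, Z = z, Sig = sig,
             stringsAsFactors = FALSE)
}

vi_table(vi = c(0.5, 0.2, 0.1, 0.3), vi_var = c(0.01, 0.01, 0.01, -0.01))
```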
Predict the outcome (regression, classification or survival) using a fitted RLT object
## S3 method for class 'RLT'
predict( object, testx = NULL, var.est = FALSE, var.mode = NULL, keep.all = FALSE, ncores = 1, verbose = 0, band.grid.size = 0, ... )
object |
A fitted RLT object |
testx |
The testing samples; they must have the same structure as the training samples. |
var.est |
Whether to estimate the variance for each testing observation.
The original forest must be fitted with |
var.mode |
Variance estimation mode for prediction. Can be |
keep.all |
Whether to keep the predictions from all trees. Warning: this can require a large amount of memory, especially for survival models. |
ncores |
Number of cores |
verbose |
Whether to print additional information |
band.grid.size |
An integer specifying the number of time points for confidence band calculation. Default is 0, which uses all unique failure time points. If a positive integer is provided, a subset of time points will be selected using quantiles, skipping the earliest 5% of time points to improve stability. |
... |
... |
An RLT prediction object, constructed as a list consisting of
Prediction |
Prediction |
Variance |
if |
For Survival Forests
hazard |
predicted hazard functions |
CumHazard |
predicted cumulative hazard function |
Survival |
predicted survival function |
Allhazard |
if |
AllCHF |
if |
Cov |
if |
Var |
if |
timepoints |
ordered observed failure times from the training data |
MarginalVar |
if |
MarginalVarSmooth |
if |
CVproj |
if |
CVprojSmooth |
if |
set.seed(42)
x <- matrix(rnorm(300 * 5), ncol = 5)
y <- rowSums(x[, 1:2]) + rnorm(300)
fit <- RLT(x, y, ntrees = 100)
pred <- predict(fit, testx = x[1:5, ])
print(pred$Prediction)
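The `band.grid.size` rule described above (quantile-based selection that skips the earliest 5% of time points) can be sketched as follows. This is a minimal illustration assuming the quantiles are taken over the positions of the ordered failure times and start at the 5% point; the helper `band_grid_index` is hypothetical and the package's exact rounding may differ.

```r
# Hypothetical helper: indices of the time points used for the confidence band.
band_grid_index <- function(timepoints, band.grid.size = 0) {
  n <- length(timepoints)
  if (band.grid.size <= 0) return(seq_len(n))         # 0: use all failure times
  probs <- seq(0.05, 1, length.out = band.grid.size)  # skip the earliest 5%
  idx <- ceiling(probs * n)                           # quantile positions
  unique(pmin(pmax(idx, 1), n))                       # keep indices in range
}

tp <- sort(runif(200, 0, 10))    # stand-in for ordered failure times
idx <- band_grid_index(tp, 20)
tp[idx]
```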
Print method for importance.RLT objects.
## S3 method for class 'importance.RLT'
print(x, digits = 4, ...)
x |
An |
digits |
Number of digits for formatting. Default: 4. |
... |
Additional arguments (unused). |
Print a RLT object
## S3 method for class 'RLT'
print(x, ...)
x |
A fitted RLT object |
... |
... |
Fit models for regression, classification and survival
analysis using reinforced splitting rules. By default, the
function fits a regular random forest model; setting
reinforcement = TRUE activates the embedded model for
splitting variable selection and allows linear combination
splits. To specify parameters of the embedded model, see
the definition of param.control.
RLT( x, y, censor = NULL, model = NULL, ntrees = if (reinforcement) 100 else 500, mtry = max(1, as.integer(ncol(x)/2)), nmin = 5, alpha = 0, nsplit = 0, resample.replace = TRUE, resample.prob = if (resample.replace) 1 else 0.8, resample.preset = NULL, obs.w = NULL, var.prob = NULL, importance = FALSE, reinforcement = FALSE, linear.comb = 1, linear.comb.method = "default", split.rule = "default", var.mode = "none", param.control = list(), ncores = 0, verbose = 0, seed = NULL, ... )
x |
A |
y |
Response variable. a |
censor |
Censoring indicator if survival model is used. |
model |
The model type: |
ntrees |
Number of trees, |
mtry |
Number of randomly selected variables used at each internal node. Default: max(1, floor(p/2)). |
nmin |
Terminal node size. Splitting will stop when the internal
node size is less than or equal to |
alpha |
Minimum proportion of samples (of the parent node) enforced in each child node. Default: 0 (no constraint). Clamped to the range 0 to 0.5. |
nsplit |
Number of random cutting points to compare for each variable at an internal node. Default: 0 (use all unique values, i.e., best split). When nsplit > 0, random cutting points are generated. |
resample.replace |
Whether the in-bag samples are obtained with replacement. |
resample.prob |
Proportion of in-bag samples. |
resample.preset |
A pre-specified in-bag indicator/count matrix.
It must be an |
obs.w |
Observation weights. The weights are used when calculating
the splitting scores, such as a weighted variance reduction
or a weighted Gini index, but not when sampling
observations. To weight the sampling, one can pre-specify
|
var.prob |
Variable probabilities for split variable selection. A
numeric vector of length |
importance |
Whether to calculate variable importance measures. When
set to |
reinforcement |
Whether the reinforcement splitting rule should be used. Default
is |
linear.comb |
Number of variables to combine in each linear combination split.
Default is 1 (standard axis-aligned splits). See also
|
linear.comb.method |
Method for constructing linear combinations:
|
split.rule |
Splitting criterion. Default |
var.mode |
Variance estimation mode. Default is |
param.control |
A list of additional parameters. This can be used to
specify other features in a random forest or set embedded
model parameters for reinforcement splitting rules.
Using
See the `linear.comb` and `linear.comb.method` arguments above.
`split.rule` specifies the splitting criterion for each model type:
- **Regression**: `"var"` (variance reduction; default and only option)
- **Classification**: `"gini"` (Gini index; default and only option)
- **Survival**: `"logrank"` (default), `"suplogrank"`, `"coxgrad"`

Internally mapped to integers: var = 1, gini = 1, logrank = 1, suplogrank = 2, coxgrad = 3.
`resample.track` indicates whether to keep track
of which observations are used in each tree. This is
required for variance estimation (via `var.mode`).
`var.mode` specifies the variance estimation method
to prepare during model fitting. Currently available methods:
- `"none"` (default): no variance estimation.
- `"matched"`: uses the matched-sample U-statistic
  decomposition (Xu, Zhu & Shao, 2023) for prediction
  variance and variable importance variance. It is also used for
  confidence bands in survival models (Formentini, Liang & Zhu, 2022).

Specifying `var.mode = TRUE` is equivalent to
`var.mode = "matched"`.
When `var.mode` is not `"none"`, the following
parameters are automatically adjusted if not already set:
- `resample.preset` is constructed automatically
- `resample.replace` is set to `FALSE`
- `resample.prob` is set to 0.5
- `resample.track` is set to `TRUE`
- `importance` is set to `"distribute"`
It is recommended to use a very large `ntrees`,
e.g., 10000 or larger. For `resample.prob` greater
than 0.5, one should consider the bootstrap
approach in Xu, Zhu & Shao (2023).
`time.grid.size` specifies the number of unique
time points used for survival estimation. By default
(0), all observed failure times are used. Setting a
smaller number (e.g., 50) can speed up computation
for large datasets. The time points are selected at
evenly spaced quantiles of the observed failure times,
always including the minimum and maximum failure times.
|
ncores |
Number of CPU logical cores. Default is 0 (using all available cores). |
verbose |
Whether info should be printed. |
seed |
Random seed number to replicate a previously fitted forest.
Internally, the |
... |
Additional arguments. |
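For regression, `split.rule = "var"` chooses the cutting point that maximizes the variance reduction, and `nsplit` controls whether all unique values (`nsplit = 0`) or a few random cutting points are compared. A minimal single-variable sketch of that search follows. The helpers are hypothetical, and drawing random cutting points uniformly between the minimum and maximum is an assumption about how RLT generates them.

```r
# Hypothetical helper: variance reduction up to a constant.
# SSE(parent) - SSE(left) - SSE(right) = nL*mean_L^2 + nR*mean_R^2 - n*mean^2,
# so maximizing nL*mean_L^2 + nR*mean_R^2 maximizes the reduction.
split_score <- function(x, y, cut) {
  left <- x <= cut
  nl <- sum(left)
  nr <- length(y) - nl
  if (nl == 0 || nr == 0) return(-Inf)   # empty child: invalid split
  nl * mean(y[left])^2 + nr * mean(y[!left])^2
}

# Hypothetical helper: best cutting point for one variable.
find_cut <- function(x, y, nsplit = 0) {
  ux <- sort(unique(x))
  cand <- if (nsplit == 0) {
    ux[-length(ux)]                      # every unique value except the max
  } else {
    runif(nsplit, min(x), max(x))        # assumed: uniform random cutting points
  }
  scores <- vapply(cand, function(cut) split_score(x, y, cut), numeric(1))
  cand[which.max(scores)]
}

x <- 1:10
y <- c(rep(0, 5), rep(1, 5))
find_cut(x, y)                # returns 5 (exhaustive search)
find_cut(x, y, nsplit = 20)   # random search; near 5 with high probability
```

Random cutting points trade some split quality for speed and extra randomization, which matters most for continuous variables with many unique values.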
An RLT fitted object, constructed as a list consisting of
FittedForest |
Fitted tree structures |
VarImp |
Variable importance measures, if |
Prediction |
Out-of-bag prediction |
Error |
Out-of-bag prediction error, adaptive to the model type |
ObsTrack |
Provided if |
For classification forests, the following items are provided in addition to, or in place of, the regression versions
NClass |
The number of classes |
Prob |
Out-of-bag predicted probability |
For survival forests, the following items are provided in addition to, or in place of, the regression versions
timepoints |
ordered observed failure times |
NFail |
The number of observed failure times |
Prediction |
Out-of-bag prediction of hazard function |
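The `time.grid.size` option described under `param.control` selects a subset of failure times at evenly spaced quantiles, always including the minimum and maximum. A minimal sketch of that selection, assuming type-1 empirical quantiles so that every grid point is an observed failure time; the helper `select_time_grid` is hypothetical and may not match the package's rounding.

```r
# Hypothetical helper: failure-time grid at evenly spaced quantiles.
select_time_grid <- function(fail_times, size = 0) {
  ut <- sort(unique(fail_times))
  if (size <= 0 || size >= length(ut)) return(ut)   # 0: keep all failure times
  # Type-1 quantiles return observed values; probs 0 and 1 give the min and max
  q <- quantile(fail_times, probs = seq(0, 1, length.out = size),
                type = 1, names = FALSE)
  unique(q)
}

set.seed(42)
y <- rexp(100)
censor <- rbinom(100, 1, 0.7)
select_time_grid(y[censor == 1], size = 10)   # only observed failures (censor == 1)
```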
Zhu, R., Zeng, D., & Kosorok, M. R. (2015) "Reinforcement Learning Trees." Journal of the American Statistical Association. 110(512), 1770-1784.
Xu, T., Zhu, R., & Shao, X. (2023) "On Variance Estimation of Random Forests with Infinite-Order U-statistics." arXiv preprint arXiv:2202.09008.
Formentini, S. E., Liang, W., & Zhu, R. (2022) "Confidence Band Estimation for Survival Random Forests." arXiv preprint arXiv:2204.12038.
set.seed(42)
x <- matrix(rnorm(300 * 5), ncol = 5)
y <- rowSums(x[, 1:2]) + rnorm(300)
fit <- RLT(x, y, ntrees = 100)
print(fit)
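With `var.mode` other than `"none"`, the documentation states that `resample.preset` is built automatically with `resample.replace = FALSE` and `resample.prob = 0.5`. The sketch below builds an ordinary half-subsampling indicator matrix with those settings. It assumes an n x ntrees layout with one column per tree, which the truncated `resample.preset` description does not confirm, and it is not the matched-sample construction of Xu, Zhu & Shao (2023); the helper name is hypothetical.

```r
# Hypothetical helper: ordinary half-subsampling indicator matrix
# (n x ntrees, one column per tree); NOT the matched-sample scheme.
make_subsample_preset <- function(n, ntrees, prob = 0.5) {
  k <- floor(n * prob)                  # in-bag size per tree
  vapply(seq_len(ntrees), function(b) {
    inbag <- integer(n)                 # 0 = out-of-bag
    inbag[sample.int(n, k)] <- 1L       # drawn without replacement: 1 = in-bag
    inbag
  }, integer(n))
}

set.seed(1)
P <- make_subsample_preset(n = 300, ntrees = 1000)
dim(P)            # 300 x 1000
colSums(P)[1:3]   # 150 in-bag observations per tree
```

Whether RLT accepts a matrix in exactly this form through `resample.preset` is not confirmed here.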