Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 47 additions & 11 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,40 @@ bcf <- function(
X_test <- preprocessPredictionData(X_test, X_train_metadata)
}

# Validate up front that the outcome, treatment, propensity, and random
# effects basis inputs are numeric, before any further checks or
# transformations are applied to them.
# The training outcome and treatment are always validated.
if (!is.numeric(y_train)) {
  stop("y_train must be numeric")
}
if (!is.numeric(Z_train)) {
  stop("Z_train must be numeric")
}
# The remaining inputs are only validated when they are supplied (non-NULL);
# `&&` short-circuits, so is.numeric() is never evaluated on a NULL input.
if (!is.null(Z_test) && !is.numeric(Z_test)) {
  stop("Z_test must be numeric")
}
if (!is.null(propensity_train) && !is.numeric(propensity_train)) {
  stop("propensity_train must be numeric")
}
if (!is.null(propensity_test) && !is.numeric(propensity_test)) {
  stop("propensity_test must be numeric")
}
if (!is.null(rfx_basis_train) && !is.numeric(rfx_basis_train)) {
  stop("rfx_basis_train must be numeric")
}
if (!is.null(rfx_basis_test) && !is.numeric(rfx_basis_test)) {
  stop("rfx_basis_test must be numeric")
}

# Convert all input data to matrices if not already converted
# A Z_train without a dim attribute (plain vector) is treated as one column;
# otherwise its existing column count is preserved. Plain if/else is used
# because the condition is a scalar (ifelse() is meant for vectors).
Z_col <- if (is.null(dim(Z_train))) 1 else ncol(Z_train)
Z_train <- matrix(as.numeric(Z_train), ncol = Z_col)
Expand All @@ -786,6 +820,19 @@ bcf <- function(
rfx_basis_test <- as.matrix(rfx_basis_test)
}

# Convert y_train to a vector
# Accepted inputs are a numeric vector or a matrix with exactly one column;
# anything else raises an error. A one-column matrix is flattened with
# as.numeric(), which drops its dim attribute.
if (is.matrix(y_train)) {
  # Reject any column count other than one (including zero columns), so the
  # check agrees with the error message below
  if (ncol(y_train) != 1) {
    stop("y_train must either be a numeric vector or a one-column matrix")
  }
  y_train <- as.numeric(y_train)
} else if (!is.numeric(y_train)) {
  stop("y_train must either be a numeric vector or a one-column matrix")
}

# Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
has_rfx <- FALSE
has_rfx_test <- FALSE
Expand All @@ -808,17 +855,6 @@ bcf <- function(
}
}

# Verify that the outcome and treatment inputs are numeric
if (!is.numeric(y_train)) {
  stop("y_train must be numeric")
}
if (!is.numeric(Z_train)) {
  stop("Z_train must be numeric")
}
# Z_test is optional, so it is only validated when supplied (non-NULL)
if (!is.null(Z_test) && !is.numeric(Z_test)) {
  stop("Z_test must be numeric")
}

# Data consistency checks
if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) {
stop("X_train and X_test must have the same number of columns")
Expand Down
147 changes: 147 additions & 0 deletions tools/debug/bart_continue_sampler_debug.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Load libraries
library(stochtree)

# Debug script: fit a GFR-only ("XBART") model, continue sampling with MCMC
# from a serialized copy of that model, and compare against an MCMC chain
# that is not given a previous model. Each fit is summarized on a held-out
# test set.

# Sampler settings
num_chains <- 1
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 20
num_trees <- 100

# Plot and print test-set diagnostics for a fitted model.
#
# model:  fitted model exposing `y_hat_test` and `sigma2_global_samples`
#         (assumes `y_hat_test` has one row per test observation and one
#         column per retained sample -- TODO confirm against stochtree docs)
# y_test: observed test-set outcomes
# f_test: true conditional mean for the test set, used for interval coverage
# label:  title placed on both plots so the runs can be told apart
#
# Produces a predicted-vs-observed scatter plot with a 45-degree reference
# line, prints the RMSE of the per-row mean prediction and the coverage of the
# per-row 2.5%-97.5% interval, then plots the global variance samples.
# Returns `model` invisibly.
report_test_fit <- function(model, y_test, f_test, label) {
  y_hat_test <- model$y_hat_test
  y_hat_mean <- rowMeans(y_hat_test)

  plot(
    y_hat_mean,
    y_test,
    main = label,
    xlab = "rowMeans(y_hat_test)",
    ylab = "y_test"
  )
  abline(0, 1)

  # paste0() (rather than multi-argument cat()) keeps full numeric precision
  # in the printed value
  cat(paste0(
    "RMSE = ",
    sqrt(mean((y_hat_mean - y_test)^2)),
    "\n"
  ))

  # Both interval endpoints are computed in a single pass over the rows;
  # the result has one row per requested probability
  bounds <- apply(y_hat_test, 1, quantile, probs = c(0.025, 0.975))
  covered <- (bounds[1, ] <= f_test) & (bounds[2, ] >= f_test)
  cat(paste0(
    "Interval coverage = ",
    mean(covered),
    "\n"
  ))

  plot(
    model$sigma2_global_samples,
    main = label,
    ylab = "sigma2_global_samples"
  )
  invisible(model)
}

# Generate the data
n <- 500
p_x <- 10
snr <- 2
X <- matrix(runif(n * p_x), ncol = p_x)
f_XW <- sin(4 * pi * X[, 1]) +
  cos(4 * pi * X[, 2]) +
  sin(4 * pi * X[, 3]) +
  cos(4 * pi * X[, 4])
# Noise scale is chosen so that sd(signal) / sd(noise) equals `snr`
noise_sd <- sd(f_XW) / snr
y <- f_XW + rnorm(n, 0, 1) * noise_sd

# Split data into test and train sets
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
train_inds <- setdiff(seq_len(n), test_inds)
X_test <- as.data.frame(X[test_inds, ])
X_train <- as.data.frame(X[train_inds, ])
y_test <- y[test_inds]
y_train <- y[train_inds]
f_XW_test <- f_XW[test_inds]
f_XW_train <- f_XW[train_inds]

# Model settings shared by all three runs below
general_params <- list(sample_sigma2_global = TRUE)
mean_forest_params <- list(
  num_trees = num_trees,
  alpha = 0.95,
  beta = 2.0,
  max_depth = -1,
  min_samples_leaf = 1,
  sample_sigma2_leaf = FALSE,
  sigma2_leaf_init = 1.0 / num_trees
)

# Run the GFR algorithm (no burn-in or MCMC iterations)
xbart_model <- stochtree::bart(
  X_train = X_train,
  y_train = y_train,
  X_test = X_test,
  num_gfr = num_gfr,
  num_burnin = 0,
  num_mcmc = 0,
  general_params = general_params,
  mean_forest_params = mean_forest_params
)

# Inspect results
report_test_fit(xbart_model, y_test, f_XW_test, label = "GFR (XBART)")
# Serialize the GFR fit so the next run can be initialized from it
xbart_model_string <- stochtree::saveBARTModelToJsonString(xbart_model)

# Run the BART MCMC sampler, initialized from the XBART sampler
# NOTE(review): `previous_model_warmstart_sample_num = num_gfr` presumably
# selects the last GFR draw (1-indexed) -- confirm against stochtree docs
bart_model <- stochtree::bart(
  X_train = X_train,
  y_train = y_train,
  X_test = X_test,
  num_gfr = 0,
  num_burnin = num_burnin,
  num_mcmc = num_mcmc,
  general_params = general_params,
  mean_forest_params = mean_forest_params,
  previous_model_json = xbart_model_string,
  previous_model_warmstart_sample_num = num_gfr
)

# Inspect the results
report_test_fit(bart_model, y_test, f_XW_test, label = "MCMC from GFR")

# Compare to a single chain of MCMC samples initialized at root
bart_model_root <- stochtree::bart(
  X_train = X_train,
  y_train = y_train,
  X_test = X_test,
  num_gfr = 0,
  num_burnin = 0,
  num_mcmc = num_mcmc,
  general_params = general_params,
  mean_forest_params = mean_forest_params
)
report_test_fit(bart_model_root, y_test, f_XW_test, label = "MCMC from root")
Loading
Loading