From b0ccc4eba8c8c303482d0ee33c63f4551f961337 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 13 May 2025 12:13:02 -0500 Subject: [PATCH 1/5] Fixing bugs in BCF predict and sampler reload functionality --- R/bcf.R | 35 ++-- tools/debug/bart_continue_sampler_debug.R | 84 +++++++++ tools/debug/bcf_401k_data_debug.R | 210 ++++++++++++++++++++++ tools/debug/bcf_continue_sampler_debug.R | 92 ++++++++++ 4 files changed, 406 insertions(+), 15 deletions(-) create mode 100644 tools/debug/bart_continue_sampler_debug.R create mode 100644 tools/debug/bcf_401k_data_debug.R create mode 100644 tools/debug/bcf_continue_sampler_debug.R diff --git a/R/bcf.R b/R/bcf.R index b9842c5d..bcaab160 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -1104,6 +1104,24 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id global_model_config$update_global_error_variance(current_sigma2) } } else if (has_prev_model) { + if (adaptive_coding) { + if (!is.null(previous_b_1_samples)) { + current_b_1 <- previous_b_1_samples[previous_model_warmstart_sample_num] + } + if (!is.null(previous_b_0_samples)) { + current_b_0 <- previous_b_0_samples[previous_model_warmstart_sample_num] + } + tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1 + forest_dataset_train$update_basis(tau_basis_train) + if (has_test) { + tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1 + forest_dataset_test$update_basis(tau_basis_test) + } + forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) + } + y_before <- outcome_train$get_data() + var_y_before <- var(y_before) + cat("Var before = ", var_y_before, "\n") resetActiveForest(active_forest_mu, previous_forest_samples_mu, previous_model_warmstart_sample_num - 1) resetForestModel(forest_model_mu, active_forest_mu, forest_dataset_train, outcome_train, TRUE) resetActiveForest(active_forest_tau, previous_forest_samples_tau, previous_model_warmstart_sample_num - 1) @@ -1122,21 +1140,6 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau) } - if (adaptive_coding) { - if (!is.null(previous_b_1_samples)) { - current_b_1 <- previous_b_1_samples[previous_model_warmstart_sample_num] - } - if (!is.null(previous_b_0_samples)) { - current_b_0 <- previous_b_0_samples[previous_model_warmstart_sample_num] - } - tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1 - forest_dataset_train$update_basis(tau_basis_train) - if (has_test) { - tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1 - forest_dataset_test$update_basis(tau_basis_test) - } - forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) - } if (has_rfx) { if (is.null(previous_rfx_samples)) { warning("`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started") @@ -1618,6 +1621,8 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU # Add propensities to covariate set if necessary if (object$model_params$propensity_covariate != "none") { X_combined <- cbind(X, propensity) + } else { + X_combined <- X } # Create prediction datasets diff --git a/tools/debug/bart_continue_sampler_debug.R b/tools/debug/bart_continue_sampler_debug.R new file mode 100644 index 00000000..b921f979 --- /dev/null +++ 
b/tools/debug/bart_continue_sampler_debug.R @@ -0,0 +1,84 @@ +# Load libraries +library(stochtree) + +# Sampler settings +num_chains <- 1 +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 20 +num_trees <- 100 + +# Generate the data +n <- 500 +p_x <- 10 +snr <- 2 +X <- matrix(runif(n*p_x), ncol = p_x) +f_XW <- sin(4*pi*X[,1]) + cos(4*pi*X[,2]) + sin(4*pi*X[,3]) +cos(4*pi*X[,4]) +noise_sd <- sd(f_XW) / snr +y <- f_XW + rnorm(n, 0, 1)*noise_sd + +# Split data into test and train sets +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- as.data.frame(X[test_inds,]) +X_train <- as.data.frame(X[train_inds,]) +y_test <- y[test_inds] +y_train <- y[train_inds] +f_XW_test <- f_XW[test_inds] +f_XW_train <- f_XW[train_inds] + +# Run the GFR algorithm +general_params <- list(sample_sigma2_global = T) +mean_forest_params <- list(num_trees = num_trees, alpha = 0.95, + beta = 2.0, max_depth = -1, + min_samples_leaf = 1, + sample_sigma2_leaf = F, + sigma2_leaf_init = 1.0/num_trees) +xbart_model <- stochtree::bart( + X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = num_gfr, num_burnin = 0, num_mcmc = 0, + general_params = general_params, + mean_forest_params = mean_forest_params +) + +# Inspect results +plot(rowMeans(xbart_model$y_hat_test), y_test); abline(0,1) +cat(paste0("RMSE = ", sqrt(mean((rowMeans(xbart_model$y_hat_test) - y_test)^2)), "\n")) +cat(paste0("Interval coverage = ", mean((apply(xbart_model$y_hat_test, 1, quantile, probs=0.025) <= f_XW_test) & (apply(xbart_model$y_hat_test, 1, quantile, probs=0.975) >= f_XW_test)), "\n")) +plot(xbart_model$sigma2_global_samples) +xbart_model_string <- stochtree::saveBARTModelToJsonString(xbart_model) + +# Run the BART MCMC sampler, initialized from the XBART sampler +general_params <- list(sample_sigma2_global = T) +mean_forest_params <- list(num_trees = num_trees, alpha = 0.95, + beta = 2.0, max_depth = -1, + min_samples_leaf = 1, + sample_sigma2_leaf = F, + sigma2_leaf_init = 1.0/num_trees) +bart_model <- stochtree::bart( + X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 0, num_burnin = num_burnin, num_mcmc = num_mcmc, + general_params = general_params, mean_forest_params = mean_forest_params, + previous_model_json = xbart_model_string, + previous_model_warmstart_sample_num = num_gfr +) + +# Inspect the results +plot(rowMeans(bart_model$y_hat_test), y_test); abline(0,1) +cat(paste0("RMSE = ", sqrt(mean((rowMeans(bart_model$y_hat_test) - y_test)^2)), "\n")) +cat(paste0("Interval coverage = ", mean((apply(bart_model$y_hat_test, 1, quantile, probs=0.025) <= f_XW_test) & (apply(bart_model$y_hat_test, 1, quantile, probs=0.975) >= f_XW_test)), "\n")) +plot(bart_model$sigma2_global_samples) + +# Compare to a single chain of MCMC samples initialized at root +bart_model_root <- stochtree::bart( + X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, + general_params = general_params, mean_forest_params = mean_forest_params +) +plot(rowMeans(bart_model_root$y_hat_test), y_test); abline(0,1) +cat(paste0("RMSE = ", sqrt(mean((rowMeans(bart_model_root$y_hat_test) - y_test)^2)), "\n")) +cat(paste0("Interval coverage = ", mean((apply(bart_model_root$y_hat_test, 1, quantile, probs=0.025) <= f_XW_test) & (apply(bart_model_root$y_hat_test, 1, quantile, probs=0.975) >= f_XW_test)), "\n")) +plot(bart_model_root$sigma2_global_samples) diff --git 
a/tools/debug/bcf_401k_data_debug.R b/tools/debug/bcf_401k_data_debug.R new file mode 100644 index 00000000..29cd478e --- /dev/null +++ b/tools/debug/bcf_401k_data_debug.R @@ -0,0 +1,210 @@ +################################################################################ +## Investigation of GFR vs MCMC fit issues on the 401k dataset +################################################################################ + +# Load libraries and set seed +library(stochtree) +library(DoubleML) +library(BART) +library(tidyverse) +# seed = 102 +# set.seed(seed) + +# Load 401k data +dat = DoubleML::fetch_401k(return_type = "data.frame") +dat_orig = dat + +# Trim outliers +dat = dat %>% filter(abs(inc)<quantile(abs(inc),0.9)) + +# Isolate covariates and convert to df +x = dat %>% dplyr::select(-c(e401, net_tfa)) + +# Convert to df and define categorical data types +xdf = data.frame(x) +xdf_st = xdf %>% + mutate(age=factor(age, ordered=TRUE), + inc = factor(inc, ordered=TRUE), + educ = factor(educ, ordered=TRUE), + fsize = factor(fsize, ordered=TRUE), + marr=factor(marr, ordered=TRUE), + twoearn=factor(twoearn, ordered=TRUE), + db=factor(db, ordered=TRUE), + pira=factor(pira, ordered=TRUE), + hown=factor(hown, ordered=TRUE)) + +# Isolate treatment and outcome +z = dat %>% dplyr::select(e401) %>% as.matrix() +y = dat %>% dplyr::select(net_tfa) %>% as.matrix() + +# Define a "jittered" version of the original (integer-valued) x columns +# in which all categories are "upper-jittered" with uniform [0, eps] noise +# except for the largest category which is "lower-jittered" with [-eps, 0] noise +x_jitter = x +for (j in 1:ncol(x)) { + min_diff <- min(diff(sort(x[,j]))[diff(sort(x[,j])) > 0]) + jitter_param <- min_diff / 3.0 + has_max_category <- x[,j] == max(x[,j]) + x_jitter[has_max_category,j] <- x[has_max_category,j] + runif(sum(has_max_category), -jitter_param, 0.0) + x_jitter[!has_max_category,j] <- x[!has_max_category,j] + runif(sum(!has_max_category), 0.0, jitter_param) +} +# Visualize jitters +# for (j in 1:ncol(x)) { +# plot(x[,j], x_jitter[,j], ylab = "jittered", xlab = "original") +# unique_xs <- unique(x[,j]) +# for (i in unique_xs) { +# abline(h = unique_xs[i], col = "red", lty = 3) +# } +# } + +# Fit a p(z = 1 | x) model for propensity features +ps_fit = pbart(x.train = xdf, + y.train = z, ntree = 200, numcut=1000, ndpost = 100, + usequants = TRUE, k = 2.0, nskip = 100, keepevery=1) +g = colMeans(pnorm(ps_fit$yhat.train)) +psf = pnorm(ps_fit$yhat.train) + +# Test-train split +n <- nrow(x) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +xdf_st_test <- xdf_st[test_inds,] +xdf_st_train <- xdf_st[train_inds,] +x_test <- x[test_inds,] +x_train <- x[train_inds,] +x_jitter_test <- x_jitter[test_inds,] +x_jitter_train <- x_jitter[train_inds,] +pi_test <- g[test_inds] +pi_train <- g[train_inds] +z_test <- z[test_inds,] +z_train <- z[train_inds,] +y_test <- y[test_inds,] +y_train <- y[train_inds,] +y_train_scale <- scale(y_train) +y_train_sd <- attr(y_train_scale, "scaled:scale") +y_train_mean <- attr(y_train_scale, "scaled:center") +y_test_scale <- (y_test - y_train_mean) / y_train_sd +var(y_train_scale) +var(y_test_scale) + +# Fit BCF with GFR algorithm on the jittered covariates +# and save model to JSON +num_gfr <- 1000 +general_params <- list( + adaptive_coding = FALSE, propensity_covariate = "none", + keep_every = 1, verbose = TRUE, keep_gfr = TRUE +) +bcf_model_gfr <- stochtree::bcf( + X_train = xdf_st_train, Z_train = c(z_train), + y_train = 
c(y_train_scale), propensity_train = pi_train, + X_test = xdf_st_test, Z_test = c(z_test), + propensity_test = pi_test, num_gfr = num_gfr, num_burnin = 0, + num_mcmc = 0, general_params = general_params +) +fit_json_gfr = saveBCFModelToJsonString(bcf_model_gfr) + +# Run MCMC chain from the last GFR sample, setting covariate +# equal to an interpolation between the original x and x_jitter +# (alpha = 0 is 100% x_jitter and alpha = 1 is 100% x) +# alpha <- 1.0 +# x_jitter_new_train <- (alpha) * x_train + (1-alpha) * x_jitter_train +# x_jitter_new_test <- (alpha) * x_test + (1-alpha) * x_jitter_test +x_jitter_new_train <- xdf_st_train +x_jitter_new_test <- xdf_st_test +num_mcmc <- 10000 +bcf_model_mcmc <- stochtree::bcf( + X_train = x_jitter_new_train, Z_train = c(z_train), + y_train = c(y_train_scale), propensity_train = pi_train, + X_test = x_jitter_new_test, Z_test = c(z_test), + propensity_test = pi_test, + num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, + previous_model_json = fit_json_gfr, + previous_model_warmstart_sample_num = num_gfr, + general_params = general_params +) + +# Inspect the "in-sample sigma" via the traceplot +# of the global error variance parameter +combined_sigma <- c(bcf_model_gfr$sigma2_global_samples, + bcf_model_mcmc$sigma2_global_samples) +plot(combined_sigma, ylab = "sigma2", xlab = "sample num", + main = "Global error var traceplot") + +# Inspect the "out-of-sample sigma" by compute the MSE +# of the yhat on the test set +yhat_combined_train <- cbind( + bcf_model_gfr$y_hat_train, + bcf_model_mcmc$y_hat_train +) +yhat_combined_test <- cbind( + bcf_model_gfr$y_hat_test, + bcf_model_mcmc$y_hat_test +) +num_samples <- ncol(yhat_combined_train) +train_mses <- rep(NA, num_samples) +for (i in 1:num_samples) { + train_mses[i] <- mean((yhat_combined_train[,i] - y_train_scale)^2) +} +test_mses <- rep(NA, num_samples) +for (i in 1:num_samples) { + test_mses[i] <- mean((yhat_combined_test[,i] - y_test_scale)^2) +} +max_y <- max(c(max(train_mses, test_mses))) +min_y <- min(c(min(train_mses, test_mses))) +plot(test_mses, ylab = "outcome MSE", xlab = "sample num", + main = "Outcome MSE Traceplot", ylim = c(min_y, max_y)) +points(train_mses, col = "blue") +legend("right", legend = c("Out-of-Sample", "In-Sample"), + col = c("black", "blue"), pch = c(1,1)) + +# Run some one-off pred vs actual plots +plot(yhat_combined[,11000], y_test_scale); abline(0,1,col="red",lty=3) +plot(bcf_model_mcmc$y_hat_train[,10000], y_train_scale); abline(0,1,col="red",lty=3) +plot(bcf_model_mcmc$y_hat_test[,10000], y_test_scale); abline(0,1,col="red",lty=3) +plot(bcf_model_gfr$y_hat_train[,1000], y_train_scale); abline(0,1,col="red",lty=3) +plot(bcf_model_gfr$y_hat_test[,1000], y_test_scale); abline(0,1,col="red",lty=3) +plot(bcf_model_gfr$y_hat_train[,10], y_train_scale); abline(0,1,col="red",lty=3) +plot(bcf_model_gfr$y_hat_test[,10], y_test_scale); abline(0,1,col="red",lty=3) + +# Run MCMC chain from root +num_mcmc <- 10000 +bcf_model_mcmc_root <- stochtree::bcf( + X_train = xdf_st_train, Z_train = c(z_train), + y_train = c(y_train_scale), propensity_train = pi_train, + X_test = xdf_st_test, Z_test = c(z_test), + propensity_test = pi_test, + num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, + general_params = general_params +) + +# Inspect the "in-sample sigma" via the traceplot +# of the global error variance parameter +sigma_trace <- bcf_model_mcmc_root$sigma2_global_samples +plot(sigma_trace, ylab = "sigma2", xlab = "sample num", + main = "Global error var traceplot") + +# Inspect the 
"out-of-sample sigma" by compute the MSE +# of the yhat on the test set +yhat_combined_train <- cbind( + bcf_model_mcmc_root$y_hat_train +) +yhat_combined_test <- cbind( + bcf_model_mcmc_root$y_hat_test +) +num_samples <- ncol(yhat_combined_train) +train_mses <- rep(NA, num_samples) +for (i in 1:num_samples) { + train_mses[i] <- mean((yhat_combined_train[,i] - y_train_scale)^2) +} +test_mses <- rep(NA, num_samples) +for (i in 1:num_samples) { + test_mses[i] <- mean((yhat_combined_test[,i] - y_test_scale)^2) +} +max_y <- max(c(max(train_mses, test_mses))) +min_y <- min(c(min(train_mses, test_mses))) +plot(test_mses, ylab = "outcome MSE", xlab = "sample num", + main = "Test set outcome MSEs", ylim = c(min_y, max_y)) +points(train_mses, col = "blue") diff --git a/tools/debug/bcf_continue_sampler_debug.R b/tools/debug/bcf_continue_sampler_debug.R new file mode 100644 index 00000000..1c46db5c --- /dev/null +++ b/tools/debug/bcf_continue_sampler_debug.R @@ -0,0 +1,92 @@ +# Load libraries +library(stochtree) + +# Sampler settings +num_chains <- 1 +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 20 +num_trees <- 100 + +# Generate the data +n <- 500 +p <- 5 +snr <- 2 +X <- matrix(runif(n*p), ncol = p) +mu_x <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +) +pi_x <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) +) +tau_x <- ( + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) +) +Z <- rbinom(n, 1, pi_x) +f_XZ <- mu_x + tau_x*Z +noise_sd <- sd(f_XZ) / snr +y <- f_XZ + rnorm(n, 0, 1)*noise_sd + +# Split data into test and train sets +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] + +# Run the GFR algorithm +general_params <- list(sample_sigma2_global = T) +xbcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = num_gfr, num_burnin = 0, + num_mcmc = 0, general_params = general_params) + +# Inspect results +plot(rowMeans(xbcf_model$y_hat_test), y_test); abline(0,1) +cat(paste0("RMSE = ", sqrt(mean((rowMeans(xbcf_model$y_hat_test) - y_test)^2)), "\n")) +plot(xbcf_model$sigma2_global_samples) +xbcf_model_string <- stochtree::saveBCFModelToJsonString(xbcf_model) + +# Run the BCF MCMC sampler, initialized from the XBART sampler +general_params <- list(sample_sigma2_global = T) +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 0, num_burnin = num_burnin, + num_mcmc = num_mcmc, general_params = general_params, + previous_model_json = xbcf_model_string, + previous_model_warmstart_sample_num = num_gfr) + +# Inspect the 
results +plot(rowMeans(bcf_model$y_hat_test), y_test); abline(0,1) +cat(paste0("RMSE = ", sqrt(mean((rowMeans(bcf_model$y_hat_test) - y_test)^2)), "\n")) +plot(bcf_model$sigma2_global_samples) + +# Compare to a single chain of MCMC samples initialized at root +bcf_model_root <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 0, num_burnin = num_burnin, + num_mcmc = num_mcmc, general_params = general_params) +plot(rowMeans(bcf_model_root$y_hat_test), y_test); abline(0,1) +plot(bcf_model_root$sigma2_global_samples) +cat(paste0("RMSE = ", sqrt(mean((rowMeans(bcf_model_root$y_hat_test) - y_test)^2)), "\n")) From 9dd22d42ca03e5da3c55893294ac132b023d5c5a Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 13 May 2025 12:13:59 -0500 Subject: [PATCH 2/5] Removed debug statements from BCF code --- R/bcf.R | 3 --- 1 file changed, 3 deletions(-) diff --git a/R/bcf.R b/R/bcf.R index bcaab160..cb441d36 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -1119,9 +1119,6 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) } - y_before <- outcome_train$get_data() - var_y_before <- var(y_before) - cat("Var before = ", var_y_before, "\n") resetActiveForest(active_forest_mu, previous_forest_samples_mu, previous_model_warmstart_sample_num - 1) resetForestModel(forest_model_mu, active_forest_mu, forest_dataset_train, outcome_train, TRUE) resetActiveForest(active_forest_tau, previous_forest_samples_tau, previous_model_warmstart_sample_num - 1) From d5406449c020b123c71e4b7d70a7e1536d3ab8ee Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 20 Nov 2025 23:33:16 -0600 Subject: [PATCH 3/5] Updated BCF data checks on y_train and Z_train / Z_test --- R/bcf.R | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/R/bcf.R b/R/bcf.R index 5a80d5ec..6f7105bb 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -786,6 +786,19 @@ bcf <- function( rfx_basis_test <- as.matrix(rfx_basis_test) } + # Convert y_train to a vector + if (is.matrix(y_train)) { + if (ncol(y_train) > 1) { + stop("y_train must either be a numeric vector of a one-column matrix") + } else { + y_train <- as.numeric(y_train) + } + } else { + if (!is.numeric(y_train)) { + stop("y_train must either be a numeric vector of a one-column matrix") + } + } + # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) 
has_rfx <- FALSE has_rfx_test <- FALSE @@ -808,17 +821,6 @@ bcf <- function( } } - # Check that outcome and treatment are numeric - if (!is.numeric(y_train)) { - stop("y_train must be numeric") - } - if (!is.numeric(Z_train)) { - stop("Z_train must be numeric") - } - if (!is.null(Z_test)) { - if (!is.numeric(Z_test)) stop("Z_test must be numeric") - } - # Data consistency checks if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) { stop("X_train and X_test must have the same number of columns") From e9d28765ae53609b0c6099f704da5caecd4f70c0 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 20 Nov 2025 23:34:50 -0600 Subject: [PATCH 4/5] Make sure debug scripts run (and reformat code) --- tools/debug/bart_continue_sampler_debug.R | 137 ++++++++---- tools/debug/bcf_401k_data_debug.R | 246 ++++++++++++++-------- tools/debug/bcf_continue_sampler_debug.R | 127 +++++++---- 3 files changed, 339 insertions(+), 171 deletions(-) diff --git a/tools/debug/bart_continue_sampler_debug.R b/tools/debug/bart_continue_sampler_debug.R index b921f979..e34b06fb 100644 --- a/tools/debug/bart_continue_sampler_debug.R +++ b/tools/debug/bart_continue_sampler_debug.R @@ -12,19 +12,22 @@ num_trees <- 100 n <- 500 p_x <- 10 snr <- 2 -X <- matrix(runif(n*p_x), ncol = p_x) -f_XW <- sin(4*pi*X[,1]) + cos(4*pi*X[,2]) + sin(4*pi*X[,3]) +cos(4*pi*X[,4]) +X <- matrix(runif(n * p_x), ncol = p_x) +f_XW <- sin(4 * pi * X[, 1]) + + cos(4 * pi * X[, 2]) + + sin(4 * pi * X[, 3]) + + cos(4 * pi * X[, 4]) noise_sd <- sd(f_XW) / snr -y <- f_XW + rnorm(n, 0, 1)*noise_sd +y <- f_XW + rnorm(n, 0, 1) * noise_sd # Split data into test and train sets test_set_pct <- 0.2 -n_test <- round(test_set_pct*n) +n_test <- round(test_set_pct * n) n_train <- n - n_test test_inds <- sort(sample(1:n, n_test, replace = FALSE)) train_inds <- (1:n)[!((1:n) %in% test_inds)] -X_test <- as.data.frame(X[test_inds,]) -X_train <- as.data.frame(X[train_inds,]) +X_test <- as.data.frame(X[test_inds, ]) +X_train <- as.data.frame(X[train_inds, ]) y_test <- y[test_inds] y_train <- y[train_inds] f_XW_test <- f_XW[test_inds] @@ -32,53 +35,113 @@ f_XW_train <- f_XW[train_inds] # Run the GFR algorithm general_params <- list(sample_sigma2_global = T) -mean_forest_params <- list(num_trees = num_trees, alpha = 0.95, - beta = 2.0, max_depth = -1, - min_samples_leaf = 1, - sample_sigma2_leaf = F, - sigma2_leaf_init = 1.0/num_trees) +mean_forest_params <- list( + num_trees = num_trees, + alpha = 0.95, + beta = 2.0, + max_depth = -1, + min_samples_leaf = 1, + sample_sigma2_leaf = F, + sigma2_leaf_init = 1.0 / num_trees +) xbart_model <- stochtree::bart( - X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = num_gfr, num_burnin = 0, num_mcmc = 0, - general_params = general_params, - mean_forest_params = mean_forest_params + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = num_gfr, + num_burnin = 0, + num_mcmc = 0, + general_params = general_params, + mean_forest_params = mean_forest_params ) # Inspect results -plot(rowMeans(xbart_model$y_hat_test), y_test); abline(0,1) -cat(paste0("RMSE = ", sqrt(mean((rowMeans(xbart_model$y_hat_test) - y_test)^2)), "\n")) -cat(paste0("Interval coverage = ", mean((apply(xbart_model$y_hat_test, 1, quantile, probs=0.025) <= f_XW_test) & (apply(xbart_model$y_hat_test, 1, quantile, probs=0.975) >= f_XW_test)), "\n")) +plot(rowMeans(xbart_model$y_hat_test), y_test) +abline(0, 1) +cat(paste0( + "RMSE = ", + sqrt(mean((rowMeans(xbart_model$y_hat_test) - y_test)^2)), + "\n" +)) +cat(paste0( + "Interval 
coverage = ", + mean( + (apply(xbart_model$y_hat_test, 1, quantile, probs = 0.025) <= f_XW_test) & + (apply(xbart_model$y_hat_test, 1, quantile, probs = 0.975) >= f_XW_test) + ), + "\n" +)) plot(xbart_model$sigma2_global_samples) xbart_model_string <- stochtree::saveBARTModelToJsonString(xbart_model) # Run the BART MCMC sampler, initialized from the XBART sampler general_params <- list(sample_sigma2_global = T) -mean_forest_params <- list(num_trees = num_trees, alpha = 0.95, - beta = 2.0, max_depth = -1, - min_samples_leaf = 1, - sample_sigma2_leaf = F, - sigma2_leaf_init = 1.0/num_trees) +mean_forest_params <- list( + num_trees = num_trees, + alpha = 0.95, + beta = 2.0, + max_depth = -1, + min_samples_leaf = 1, + sample_sigma2_leaf = F, + sigma2_leaf_init = 1.0 / num_trees +) bart_model <- stochtree::bart( - X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = num_burnin, num_mcmc = num_mcmc, - general_params = general_params, mean_forest_params = mean_forest_params, - previous_model_json = xbart_model_string, - previous_model_warmstart_sample_num = num_gfr + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = general_params, + mean_forest_params = mean_forest_params, + previous_model_json = xbart_model_string, + previous_model_warmstart_sample_num = num_gfr ) # Inspect the results -plot(rowMeans(bart_model$y_hat_test), y_test); abline(0,1) -cat(paste0("RMSE = ", sqrt(mean((rowMeans(bart_model$y_hat_test) - y_test)^2)), "\n")) -cat(paste0("Interval coverage = ", mean((apply(bart_model$y_hat_test, 1, quantile, probs=0.025) <= f_XW_test) & (apply(bart_model$y_hat_test, 1, quantile, probs=0.975) >= f_XW_test)), "\n")) +plot(rowMeans(bart_model$y_hat_test), y_test) +abline(0, 1) +cat(paste0( + "RMSE = ", + sqrt(mean((rowMeans(bart_model$y_hat_test) - y_test)^2)), + "\n" +)) +cat(paste0( + "Interval coverage = ", + mean( + (apply(bart_model$y_hat_test, 1, quantile, probs = 0.025) <= f_XW_test) & + (apply(bart_model$y_hat_test, 1, quantile, probs = 0.975) >= f_XW_test) + ), + "\n" +)) plot(bart_model$sigma2_global_samples) # Compare to a single chain of MCMC samples initialized at root bart_model_root <- stochtree::bart( - X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, - general_params = general_params, mean_forest_params = mean_forest_params + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 0, + num_mcmc = num_mcmc, + general_params = general_params, + mean_forest_params = mean_forest_params ) -plot(rowMeans(bart_model_root$y_hat_test), y_test); abline(0,1) -cat(paste0("RMSE = ", sqrt(mean((rowMeans(bart_model_root$y_hat_test) - y_test)^2)), "\n")) -cat(paste0("Interval coverage = ", mean((apply(bart_model_root$y_hat_test, 1, quantile, probs=0.025) <= f_XW_test) & (apply(bart_model_root$y_hat_test, 1, quantile, probs=0.975) >= f_XW_test)), "\n")) +plot(rowMeans(bart_model_root$y_hat_test), y_test) +abline(0, 1) +cat(paste0( + "RMSE = ", + sqrt(mean((rowMeans(bart_model_root$y_hat_test) - y_test)^2)), + "\n" +)) +cat(paste0( + "Interval coverage = ", + mean( + (apply(bart_model_root$y_hat_test, 1, quantile, probs = 0.025) <= + f_XW_test) & + (apply(bart_model_root$y_hat_test, 1, quantile, probs = 0.975) >= + f_XW_test) + ), + "\n" +)) plot(bart_model_root$sigma2_global_samples) diff --git a/tools/debug/bcf_401k_data_debug.R b/tools/debug/bcf_401k_data_debug.R index 
29cd478e..a9abdada 100644 --- a/tools/debug/bcf_401k_data_debug.R +++ b/tools/debug/bcf_401k_data_debug.R @@ -5,7 +5,6 @@ # Load libraries and set seed library(stochtree) library(DoubleML) -library(BART) library(tidyverse) # seed = 102 # set.seed(seed) @@ -15,23 +14,25 @@ dat = DoubleML::fetch_401k(return_type = "data.frame") dat_orig = dat # Trim outliers -dat = dat %>% filter(abs(inc)<quantile(abs(inc),0.9)) +dat = dat %>% filter(abs(inc) < quantile(abs(inc), 0.9)) # Isolate covariates and convert to df x = dat %>% dplyr::select(-c(e401, net_tfa)) # Convert to df and define categorical data types xdf = data.frame(x) -xdf_st = xdf %>% - mutate(age=factor(age, ordered=TRUE), - inc = factor(inc, ordered=TRUE), - educ = factor(educ, ordered=TRUE), - fsize = factor(fsize, ordered=TRUE), - marr=factor(marr, ordered=TRUE), - twoearn=factor(twoearn, ordered=TRUE), - db=factor(db, ordered=TRUE), - pira=factor(pira, ordered=TRUE), - hown=factor(hown, ordered=TRUE)) +xdf_st = xdf %>% + mutate( + age = factor(age, ordered = TRUE), + inc = factor(inc, ordered = TRUE), + educ = factor(educ, ordered = TRUE), + fsize = factor(fsize, ordered = TRUE), + marr = factor(marr, ordered = TRUE), + twoearn = factor(twoearn, ordered = TRUE), + db = factor(db, ordered = TRUE), + pira = factor(pira, ordered = TRUE), + hown = factor(hown, ordered = TRUE) + ) # Isolate treatment and outcome z = dat %>% dplyr::select(e401) %>% as.matrix() @@ -42,11 +43,13 @@ y = dat %>% dplyr::select(net_tfa) %>% as.matrix() # except for the largest category which is "lower-jittered" with [-eps, 0] noise x_jitter = x for (j in 1:ncol(x)) { - min_diff <- min(diff(sort(x[,j]))[diff(sort(x[,j])) > 0]) - jitter_param <- min_diff / 3.0 - has_max_category <- x[,j] == max(x[,j]) - x_jitter[has_max_category,j] <- x[has_max_category,j] + runif(sum(has_max_category), -jitter_param, 0.0) - x_jitter[!has_max_category,j] <- x[!has_max_category,j] + runif(sum(!has_max_category), 0.0, jitter_param) + min_diff <- min(diff(sort(x[, j]))[diff(sort(x[, j])) > 0]) + jitter_param <- min_diff / 3.0 + has_max_category <- x[, j] == max(x[, j]) + x_jitter[has_max_category, j] <- x[has_max_category, j] + + runif(sum(has_max_category), -jitter_param, 0.0) + x_jitter[!has_max_category, j] <- x[!has_max_category, j] + + runif(sum(!has_max_category), 0.0, jitter_param) } # Visualize jitters # for (j in 1:ncol(x)) { @@ -58,55 +61,77 @@ for (j in 1:ncol(x)) { # } # Fit a p(z = 1 | x) model for propensity features -ps_fit = pbart(x.train = xdf, - y.train = z, ntree = 200, numcut=1000, ndpost = 100, - usequants = TRUE, k = 2.0, nskip = 100, keepevery=1) -g = colMeans(pnorm(ps_fit$yhat.train)) -psf = pnorm(ps_fit$yhat.train) +general_params <- list( + probit_outcome_model = TRUE, + sample_sigma2_global = FALSE +) +mean_forest_params <- list( + num_trees = 200 +) +propensity_model <- bart( + X_train = xdf, + y_train = z, + general_params = general_params, + mean_forest_params = mean_forest_params +) +propensity = predict( + propensity_model, + X = xdf, + type = "mean", + terms = "y_hat", + scale = "probability" +) # Test-train split n <- nrow(x) test_set_pct <- 0.2 -n_test <- round(test_set_pct*n) +n_test <- round(test_set_pct * n) n_train <- n - n_test test_inds <- sort(sample(1:n, n_test, replace = FALSE)) train_inds <- (1:n)[!((1:n) %in% test_inds)] -xdf_st_test <- xdf_st[test_inds,] -xdf_st_train <- xdf_st[train_inds,] -x_test <- x[test_inds,] -x_train <- x[train_inds,] -x_jitter_test <- x_jitter[test_inds,] -x_jitter_train <- x_jitter[train_inds,] -pi_test <- g[test_inds] -pi_train <- g[train_inds] 
-z_test <- z[test_inds,] -z_train <- z[train_inds,] -y_test <- y[test_inds,] -y_train <- y[train_inds,] +xdf_st_test <- xdf_st[test_inds, ] +xdf_st_train <- xdf_st[train_inds, ] +x_test <- x[test_inds, ] +x_train <- x[train_inds, ] +x_jitter_test <- x_jitter[test_inds, ] +x_jitter_train <- x_jitter[train_inds, ] +pi_test <- propensity[test_inds] +pi_train <- propensity[train_inds] +z_test <- z[test_inds, ] +z_train <- z[train_inds, ] +y_test <- y[test_inds, ] +y_train <- y[train_inds, ] y_train_scale <- scale(y_train) y_train_sd <- attr(y_train_scale, "scaled:scale") y_train_mean <- attr(y_train_scale, "scaled:center") y_test_scale <- (y_test - y_train_mean) / y_train_sd -var(y_train_scale) -var(y_test_scale) # Fit BCF with GFR algorithm on the jittered covariates # and save model to JSON num_gfr <- 1000 general_params <- list( - adaptive_coding = FALSE, propensity_covariate = "none", - keep_every = 1, verbose = TRUE, keep_gfr = TRUE + adaptive_coding = FALSE, + propensity_covariate = "none", + keep_every = 1, + verbose = TRUE, + keep_gfr = TRUE ) bcf_model_gfr <- stochtree::bcf( - X_train = xdf_st_train, Z_train = c(z_train), - y_train = c(y_train_scale), propensity_train = pi_train, - X_test = xdf_st_test, Z_test = c(z_test), - propensity_test = pi_test, num_gfr = num_gfr, num_burnin = 0, - num_mcmc = 0, general_params = general_params + X_train = xdf_st_train, + Z_train = z_train, + y_train = y_train_scale, + propensity_train = pi_train, + X_test = xdf_st_test, + Z_test = z_test, + propensity_test = pi_test, + num_gfr = num_gfr, + num_burnin = 0, + num_mcmc = 0, + general_params = general_params ) fit_json_gfr = saveBCFModelToJsonString(bcf_model_gfr) -# Run MCMC chain from the last GFR sample, setting covariate +# Run MCMC chain from the last GFR sample, setting covariate # equal to an interpolation between the original x and x_jitter # (alpha = 0 is 100% x_jitter and alpha = 1 is 100% x) # alpha <- 1.0 @@ -116,95 +141,136 @@ x_jitter_new_train <- xdf_st_train x_jitter_new_test <- xdf_st_test num_mcmc <- 10000 bcf_model_mcmc <- stochtree::bcf( - X_train = x_jitter_new_train, Z_train = c(z_train), - y_train = c(y_train_scale), propensity_train = pi_train, - X_test = x_jitter_new_test, Z_test = c(z_test), - propensity_test = pi_test, - num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, - previous_model_json = fit_json_gfr, - previous_model_warmstart_sample_num = num_gfr, - general_params = general_params + X_train = x_jitter_new_train, + Z_train = z_train, + y_train = y_train_scale, + propensity_train = pi_train, + X_test = x_jitter_new_test, + Z_test = z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 0, + num_mcmc = num_mcmc, + previous_model_json = fit_json_gfr, + previous_model_warmstart_sample_num = num_gfr, + general_params = general_params ) # Inspect the "in-sample sigma" via the traceplot # of the global error variance parameter -combined_sigma <- c(bcf_model_gfr$sigma2_global_samples, - bcf_model_mcmc$sigma2_global_samples) -plot(combined_sigma, ylab = "sigma2", xlab = "sample num", - main = "Global error var traceplot") +combined_sigma <- c( + bcf_model_gfr$sigma2_global_samples, + bcf_model_mcmc$sigma2_global_samples +) +plot( + combined_sigma, + ylab = "sigma2", + xlab = "sample num", + main = "Global error var traceplot" +) -# Inspect the "out-of-sample sigma" by compute the MSE +# Inspect the "out-of-sample sigma" by compute the MSE # of the yhat on the test set yhat_combined_train <- cbind( - bcf_model_gfr$y_hat_train, - bcf_model_mcmc$y_hat_train + 
bcf_model_gfr$y_hat_train, + bcf_model_mcmc$y_hat_train ) yhat_combined_test <- cbind( - bcf_model_gfr$y_hat_test, - bcf_model_mcmc$y_hat_test + bcf_model_gfr$y_hat_test, + bcf_model_mcmc$y_hat_test ) num_samples <- ncol(yhat_combined_train) train_mses <- rep(NA, num_samples) for (i in 1:num_samples) { - train_mses[i] <- mean((yhat_combined_train[,i] - y_train_scale)^2) + train_mses[i] <- mean((yhat_combined_train[, i] - y_train_scale)^2) } test_mses <- rep(NA, num_samples) for (i in 1:num_samples) { - test_mses[i] <- mean((yhat_combined_test[,i] - y_test_scale)^2) + test_mses[i] <- mean((yhat_combined_test[, i] - y_test_scale)^2) } max_y <- max(c(max(train_mses, test_mses))) min_y <- min(c(min(train_mses, test_mses))) -plot(test_mses, ylab = "outcome MSE", xlab = "sample num", - main = "Outcome MSE Traceplot", ylim = c(min_y, max_y)) +plot( + test_mses, + ylab = "outcome MSE", + xlab = "sample num", + main = "Outcome MSE Traceplot", + ylim = c(min_y, max_y) +) points(train_mses, col = "blue") -legend("right", legend = c("Out-of-Sample", "In-Sample"), - col = c("black", "blue"), pch = c(1,1)) +legend( + "right", + legend = c("Out-of-Sample", "In-Sample"), + col = c("black", "blue"), + pch = c(1, 1) +) # Run some one-off pred vs actual plots -plot(yhat_combined[,11000], y_test_scale); abline(0,1,col="red",lty=3) -plot(bcf_model_mcmc$y_hat_train[,10000], y_train_scale); abline(0,1,col="red",lty=3) -plot(bcf_model_mcmc$y_hat_test[,10000], y_test_scale); abline(0,1,col="red",lty=3) -plot(bcf_model_gfr$y_hat_train[,1000], y_train_scale); abline(0,1,col="red",lty=3) -plot(bcf_model_gfr$y_hat_test[,1000], y_test_scale); abline(0,1,col="red",lty=3) -plot(bcf_model_gfr$y_hat_train[,10], y_train_scale); abline(0,1,col="red",lty=3) -plot(bcf_model_gfr$y_hat_test[,10], y_test_scale); abline(0,1,col="red",lty=3) +plot(yhat_combined_test[, 11000], y_test_scale) +abline(0, 1, col = "red", lty = 3) +plot(bcf_model_mcmc$y_hat_train[, 10000], y_train_scale) +abline(0, 1, col = "red", lty = 3) +plot(bcf_model_mcmc$y_hat_test[, 10000], y_test_scale) +abline(0, 1, col = "red", lty = 3) +plot(bcf_model_gfr$y_hat_train[, 1000], y_train_scale) +abline(0, 1, col = "red", lty = 3) +plot(bcf_model_gfr$y_hat_test[, 1000], y_test_scale) +abline(0, 1, col = "red", lty = 3) +plot(bcf_model_gfr$y_hat_train[, 10], y_train_scale) +abline(0, 1, col = "red", lty = 3) +plot(bcf_model_gfr$y_hat_test[, 10], y_test_scale) +abline(0, 1, col = "red", lty = 3) # Run MCMC chain from root num_mcmc <- 10000 bcf_model_mcmc_root <- stochtree::bcf( - X_train = xdf_st_train, Z_train = c(z_train), - y_train = c(y_train_scale), propensity_train = pi_train, - X_test = xdf_st_test, Z_test = c(z_test), - propensity_test = pi_test, - num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, - general_params = general_params + X_train = xdf_st_train, + Z_train = z_train, + y_train = y_train_scale, + propensity_train = pi_train, + X_test = xdf_st_test, + Z_test = z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 0, + num_mcmc = num_mcmc, + general_params = general_params ) # Inspect the "in-sample sigma" via the traceplot # of the global error variance parameter sigma_trace <- bcf_model_mcmc_root$sigma2_global_samples -plot(sigma_trace, ylab = "sigma2", xlab = "sample num", - main = "Global error var traceplot") +plot( + sigma_trace, + ylab = "sigma2", + xlab = "sample num", + main = "Global error var traceplot" +) -# Inspect the "out-of-sample sigma" by compute the MSE +# Inspect the "out-of-sample sigma" by compute the MSE # of the 
yhat on the test set yhat_combined_train <- cbind( - bcf_model_mcmc_root$y_hat_train + bcf_model_mcmc_root$y_hat_train ) yhat_combined_test <- cbind( - bcf_model_mcmc_root$y_hat_test + bcf_model_mcmc_root$y_hat_test ) num_samples <- ncol(yhat_combined_train) train_mses <- rep(NA, num_samples) for (i in 1:num_samples) { - train_mses[i] <- mean((yhat_combined_train[,i] - y_train_scale)^2) + train_mses[i] <- mean((yhat_combined_train[, i] - y_train_scale)^2) } test_mses <- rep(NA, num_samples) for (i in 1:num_samples) { - test_mses[i] <- mean((yhat_combined_test[,i] - y_test_scale)^2) + test_mses[i] <- mean((yhat_combined_test[, i] - y_test_scale)^2) } max_y <- max(c(max(train_mses, test_mses))) min_y <- min(c(min(train_mses, test_mses))) -plot(test_mses, ylab = "outcome MSE", xlab = "sample num", - main = "Test set outcome MSEs", ylim = c(min_y, max_y)) +plot( + test_mses, + ylab = "outcome MSE", + xlab = "sample num", + main = "Test set outcome MSEs", + ylim = c(min_y, max_y) +) points(train_mses, col = "blue") diff --git a/tools/debug/bcf_continue_sampler_debug.R b/tools/debug/bcf_continue_sampler_debug.R index 1c46db5c..151b40b1 100644 --- a/tools/debug/bcf_continue_sampler_debug.R +++ b/tools/debug/bcf_continue_sampler_debug.R @@ -12,38 +12,35 @@ num_trees <- 100 n <- 500 p <- 5 snr <- 2 -X <- matrix(runif(n*p), ncol = p) -mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) -) -pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) -) -tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + - ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) -) +X <- matrix(runif(n * p), ncol = p) +mu_x <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * + (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) +pi_x <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * + (0.2) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (0.4) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (0.6) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (0.8)) +tau_x <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * + (0.5) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (1.0) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (1.5) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (2.0)) Z <- rbinom(n, 1, pi_x) -f_XZ <- mu_x + tau_x*Z +f_XZ <- mu_x + tau_x * Z noise_sd <- sd(f_XZ) / snr -y <- f_XZ + rnorm(n, 0, 1)*noise_sd +y <- f_XZ + rnorm(n, 0, 1) * noise_sd # Split data into test and train sets test_set_pct <- 0.2 -n_test <- round(test_set_pct*n) +n_test <- round(test_set_pct * n) n_train <- n - n_test test_inds <- sort(sample(1:n, n_test, replace = FALSE)) train_inds <- (1:n)[!((1:n) %in% test_inds)] -X_test <- X[test_inds,] -X_train <- X[train_inds,] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] pi_test <- pi_x[test_inds] pi_train <- pi_x[train_inds] Z_test <- Z[test_inds] @@ -57,36 +54,78 @@ tau_train <- tau_x[train_inds] # Run the GFR algorithm general_params <- list(sample_sigma2_global = T) -xbcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = num_gfr, num_burnin = 0, - num_mcmc = 0, general_params = general_params) 
+xbcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = num_gfr, + num_burnin = 0, + num_mcmc = 0, + general_params = general_params +) # Inspect results -plot(rowMeans(xbcf_model$y_hat_test), y_test); abline(0,1) -cat(paste0("RMSE = ", sqrt(mean((rowMeans(xbcf_model$y_hat_test) - y_test)^2)), "\n")) +plot(rowMeans(xbcf_model$y_hat_test), y_test) +abline(0, 1) +cat(paste0( + "RMSE = ", + sqrt(mean((rowMeans(xbcf_model$y_hat_test) - y_test)^2)), + "\n" +)) plot(xbcf_model$sigma2_global_samples) xbcf_model_string <- stochtree::saveBCFModelToJsonString(xbcf_model) # Run the BCF MCMC sampler, initialized from the XBART sampler general_params <- list(sample_sigma2_global = T) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = num_burnin, - num_mcmc = num_mcmc, general_params = general_params, - previous_model_json = xbcf_model_string, - previous_model_warmstart_sample_num = num_gfr) +bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = general_params, + previous_model_json = xbcf_model_string, + previous_model_warmstart_sample_num = num_gfr +) # Inspect the results -plot(rowMeans(bcf_model$y_hat_test), y_test); abline(0,1) -cat(paste0("RMSE = ", sqrt(mean((rowMeans(bcf_model$y_hat_test) - y_test)^2)), "\n")) +plot(rowMeans(bcf_model$y_hat_test), y_test) +abline(0, 1) +cat(paste0( + "RMSE = ", + sqrt(mean((rowMeans(bcf_model$y_hat_test) - y_test)^2)), + "\n" +)) plot(bcf_model$sigma2_global_samples) # Compare to a single chain of MCMC samples initialized at root -bcf_model_root <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = num_burnin, - num_mcmc = num_mcmc, general_params = general_params) -plot(rowMeans(bcf_model_root$y_hat_test), y_test); abline(0,1) +bcf_model_root <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = general_params +) +plot(rowMeans(bcf_model_root$y_hat_test), y_test) +abline(0, 1) plot(bcf_model_root$sigma2_global_samples) -cat(paste0("RMSE = ", sqrt(mean((rowMeans(bcf_model_root$y_hat_test) - y_test)^2)), "\n")) +cat(paste0( + "RMSE = ", + sqrt(mean((rowMeans(bcf_model_root$y_hat_test) - y_test)^2)), + "\n" +)) From 9969336b20cd05a241219600ffd983b1b79ee105 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 20 Nov 2025 23:58:01 -0600 Subject: [PATCH 5/5] Updated BCF --- R/bcf.R | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/R/bcf.R b/R/bcf.R index 6f7105bb..75c9604d 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -767,6 +767,40 @@ bcf <- function( X_test <- preprocessPredictionData(X_test, X_train_metadata) } + # Check that outcome, treatment, and propensity are numeric before running + # further checks / transformations + if (!is.numeric(y_train)) { + stop("y_train must be numeric") + } + if (!is.numeric(Z_train)) { + 
stop("Z_train must be numeric") + } + if (!is.null(Z_test)) { + if (!is.numeric(Z_test)) { + stop("Z_test must be numeric") + } + } + if (!is.null(propensity_train)) { + if (!is.numeric(propensity_train)) { + stop("propensity_train must be numeric") + } + } + if (!is.null(propensity_test)) { + if (!is.numeric(propensity_test)) { + stop("propensity_test must be numeric") + } + } + if (!is.null(rfx_basis_train)) { + if (!is.numeric(rfx_basis_train)) { + stop("rfx_basis_train must be numeric") + } + } + if (!is.null(rfx_basis_test)) { + if (!is.numeric(rfx_basis_test)) { + stop("rfx_basis_test must be numeric") + } + } + # Convert all input data to matrices if not already converted Z_col <- ifelse(is.null(dim(Z_train)), 1, ncol(Z_train)) Z_train <- matrix(as.numeric(Z_train), ncol = Z_col)