diff --git a/R/bcf.R b/R/bcf.R
index 5a80d5ec..75c9604d 100644
--- a/R/bcf.R
+++ b/R/bcf.R
@@ -767,6 +767,40 @@ bcf <- function(
     X_test <- preprocessPredictionData(X_test, X_train_metadata)
   }

+  # Check that outcome, treatment, and propensity are numeric before running
+  # further checks / transformations
+  if (!is.numeric(y_train)) {
+    stop("y_train must be numeric")
+  }
+  if (!is.numeric(Z_train)) {
+    stop("Z_train must be numeric")
+  }
+  if (!is.null(Z_test)) {
+    if (!is.numeric(Z_test)) {
+      stop("Z_test must be numeric")
+    }
+  }
+  if (!is.null(propensity_train)) {
+    if (!is.numeric(propensity_train)) {
+      stop("propensity_train must be numeric")
+    }
+  }
+  if (!is.null(propensity_test)) {
+    if (!is.numeric(propensity_test)) {
+      stop("propensity_test must be numeric")
+    }
+  }
+  if (!is.null(rfx_basis_train)) {
+    if (!is.numeric(rfx_basis_train)) {
+      stop("rfx_basis_train must be numeric")
+    }
+  }
+  if (!is.null(rfx_basis_test)) {
+    if (!is.numeric(rfx_basis_test)) {
+      stop("rfx_basis_test must be numeric")
+    }
+  }
+
   # Convert all input data to matrices if not already converted
   Z_col <- ifelse(is.null(dim(Z_train)), 1, ncol(Z_train))
   Z_train <- matrix(as.numeric(Z_train), ncol = Z_col)
@@ -786,6 +820,19 @@ bcf <- function(
     rfx_basis_test <- as.matrix(rfx_basis_test)
   }

+  # Convert y_train to a vector
+  if (is.matrix(y_train)) {
+    if (ncol(y_train) > 1) {
+      stop("y_train must either be a numeric vector or a one-column matrix")
+    } else {
+      y_train <- as.numeric(y_train)
+    }
+  } else {
+    if (!is.numeric(y_train)) {
+      stop("y_train must either be a numeric vector or a one-column matrix")
+    }
+  }
+
   # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
   has_rfx <- FALSE
   has_rfx_test <- FALSE
@@ -808,17 +855,6 @@ bcf <- function(
     }
   }

-  # Check that outcome and treatment are numeric
-  if (!is.numeric(y_train)) {
-    stop("y_train must be numeric")
-  }
-  if (!is.numeric(Z_train)) {
-    stop("Z_train must be numeric")
-  }
-  if (!is.null(Z_test)) {
-    if (!is.numeric(Z_test)) stop("Z_test must be numeric")
-  }
-
   # Data consistency checks
   if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) {
     stop("X_train and X_test must have the same number of columns")
diff --git a/tools/debug/bart_continue_sampler_debug.R b/tools/debug/bart_continue_sampler_debug.R
new file mode 100644
index 00000000..e34b06fb
--- /dev/null
+++ b/tools/debug/bart_continue_sampler_debug.R
@@ -0,0 +1,147 @@
+# Load libraries
+library(stochtree)
+
+# Sampler settings
+num_chains <- 1
+num_gfr <- 10
+num_burnin <- 0
+num_mcmc <- 20
+num_trees <- 100
+
+# Generate the data
+n <- 500
+p_x <- 10
+snr <- 2
+X <- matrix(runif(n * p_x), ncol = p_x)
+f_XW <- sin(4 * pi * X[, 1]) +
+  cos(4 * pi * X[, 2]) +
+  sin(4 * pi * X[, 3]) +
+  cos(4 * pi * X[, 4])
+noise_sd <- sd(f_XW) / snr
+y <- f_XW + rnorm(n, 0, 1) * noise_sd
+
+# Split data into test and train sets
+test_set_pct <- 0.2
+n_test <- round(test_set_pct * n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) %in% test_inds)]
+X_test <- as.data.frame(X[test_inds, ])
+X_train <- as.data.frame(X[train_inds, ])
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+f_XW_test <- f_XW[test_inds]
+f_XW_train <- f_XW[train_inds]
+
+# Run the GFR algorithm
+general_params <- list(sample_sigma2_global = T)
+mean_forest_params <- list(
+  num_trees = num_trees,
+  alpha = 0.95,
+  beta = 2.0,
+  max_depth = -1,
+  min_samples_leaf = 1,
+  sample_sigma2_leaf = F,
+  sigma2_leaf_init = 1.0 / num_trees
+)
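+# Fit the warm-start model: run num_gfr grow-from-root (GFR) iterations only,
+# with no burn-in or MCMC draws; the fitted model is serialized to a JSON
+# string below and used to initialize the MCMC chain that follows.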
+xbart_model <- stochtree::bart(
+  X_train = X_train,
+  y_train = y_train,
+  X_test = X_test,
+  num_gfr = num_gfr,
+  num_burnin = 0,
+  num_mcmc = 0,
+  general_params = general_params,
+  mean_forest_params = mean_forest_params
+)
+
+# Inspect results
+plot(rowMeans(xbart_model$y_hat_test), y_test)
+abline(0, 1)
+cat(paste0(
+  "RMSE = ",
+  sqrt(mean((rowMeans(xbart_model$y_hat_test) - y_test)^2)),
+  "\n"
+))
+cat(paste0(
+  "Interval coverage = ",
+  mean(
+    (apply(xbart_model$y_hat_test, 1, quantile, probs = 0.025) <= f_XW_test) &
+      (apply(xbart_model$y_hat_test, 1, quantile, probs = 0.975) >= f_XW_test)
+  ),
+  "\n"
+))
+plot(xbart_model$sigma2_global_samples)
+xbart_model_string <- stochtree::saveBARTModelToJsonString(xbart_model)
+
+# Run the BART MCMC sampler, initialized from the XBART sampler
+general_params <- list(sample_sigma2_global = T)
+mean_forest_params <- list(
+  num_trees = num_trees,
+  alpha = 0.95,
+  beta = 2.0,
+  max_depth = -1,
+  min_samples_leaf = 1,
+  sample_sigma2_leaf = F,
+  sigma2_leaf_init = 1.0 / num_trees
+)
+bart_model <- stochtree::bart(
+  X_train = X_train,
+  y_train = y_train,
+  X_test = X_test,
+  num_gfr = 0,
+  num_burnin = num_burnin,
+  num_mcmc = num_mcmc,
+  general_params = general_params,
+  mean_forest_params = mean_forest_params,
+  previous_model_json = xbart_model_string,
+  previous_model_warmstart_sample_num = num_gfr
+)
+
+# Inspect the results
+plot(rowMeans(bart_model$y_hat_test), y_test)
+abline(0, 1)
+cat(paste0(
+  "RMSE = ",
+  sqrt(mean((rowMeans(bart_model$y_hat_test) - y_test)^2)),
+  "\n"
+))
+cat(paste0(
+  "Interval coverage = ",
+  mean(
+    (apply(bart_model$y_hat_test, 1, quantile, probs = 0.025) <= f_XW_test) &
+      (apply(bart_model$y_hat_test, 1, quantile, probs = 0.975) >= f_XW_test)
+  ),
+  "\n"
+))
+plot(bart_model$sigma2_global_samples)
+
+# Compare to a single chain of MCMC samples initialized at root
+bart_model_root <- stochtree::bart(
+  X_train = X_train,
+  y_train = y_train,
+  X_test = X_test,
+  num_gfr = 0,
+  num_burnin = 0,
+  num_mcmc = num_mcmc,
+  general_params = general_params,
+  mean_forest_params = mean_forest_params
+)
+plot(rowMeans(bart_model_root$y_hat_test), y_test)
+abline(0, 1)
+cat(paste0(
+  "RMSE = ",
+  sqrt(mean((rowMeans(bart_model_root$y_hat_test) - y_test)^2)),
+  "\n"
+))
+cat(paste0(
+  "Interval coverage = ",
+  mean(
+    (apply(bart_model_root$y_hat_test, 1, quantile, probs = 0.025) <=
+      f_XW_test) &
+      (apply(bart_model_root$y_hat_test, 1, quantile, probs = 0.975) >=
+        f_XW_test)
+  ),
+  "\n"
+))
+plot(bart_model_root$sigma2_global_samples)
diff --git a/tools/debug/bcf_401k_data_debug.R b/tools/debug/bcf_401k_data_debug.R
new file mode 100644
index 00000000..a9abdada
--- /dev/null
+++ b/tools/debug/bcf_401k_data_debug.R
@@ -0,0 +1,276 @@
+################################################################################
+## Investigation of GFR vs MCMC fit issues on the 401k dataset
+################################################################################
+
+# Load libraries and set seed
+library(stochtree)
+library(DoubleML)
+library(tidyverse)
+# seed = 102
+# set.seed(seed)
+
+# Load 401k data
+dat = DoubleML::fetch_401k(return_type = "data.frame")
+dat_orig = dat
+
+# Trim outliers
+dat = dat %>% filter(abs(inc) < quantile(abs(inc), 0.9))
+
+# Isolate covariates and convert to df
+x = dat %>% dplyr::select(-c(e401, net_tfa))
+
+# Convert to df and define categorical data types
+xdf = data.frame(x)
+xdf_st = xdf %>%
+  mutate(
+    age = factor(age, ordered = TRUE),
+    inc = factor(inc, ordered = TRUE),
+    educ = factor(educ, ordered = TRUE),
+    fsize = factor(fsize, ordered = TRUE),
+    marr = factor(marr, ordered = TRUE),
+    twoearn = factor(twoearn, ordered = TRUE),
+    db = factor(db, ordered = TRUE),
+    pira = factor(pira, ordered = TRUE),
+    hown = factor(hown, ordered = TRUE)
+  )
+
+# Isolate treatment and outcome
+z = dat %>% dplyr::select(e401) %>% as.matrix()
+y = dat %>% dplyr::select(net_tfa) %>% as.matrix()
+
+# Define a "jittered" version of the original (integer-valued) x columns
+# in which all categories are "upper-jittered" with uniform [0, eps] noise
+# except for the largest category which is "lower-jittered" with [-eps, 0] noise
+x_jitter = x
+for (j in 1:ncol(x)) {
+  min_diff <- min(diff(sort(x[, j]))[diff(sort(x[, j])) > 0])
+  jitter_param <- min_diff / 3.0
+  has_max_category <- x[, j] == max(x[, j])
+  x_jitter[has_max_category, j] <- x[has_max_category, j] +
+    runif(sum(has_max_category), -jitter_param, 0.0)
+  x_jitter[!has_max_category, j] <- x[!has_max_category, j] +
+    runif(sum(!has_max_category), 0.0, jitter_param)
+}
+# Visualize jitters
+# for (j in 1:ncol(x)) {
+#   plot(x[,j], x_jitter[,j], ylab = "jittered", xlab = "original")
+#   unique_xs <- unique(x[,j])
+#   for (i in unique_xs) {
+#     abline(h = unique_xs[i], col = "red", lty = 3)
+#   }
+# }
+
+# Fit a p(z = 1 | x) model for propensity features
+general_params <- list(
+  probit_outcome_model = TRUE,
+  sample_sigma2_global = FALSE
+)
+mean_forest_params <- list(
+  num_trees = 200
+)
+propensity_model <- bart(
+  X_train = xdf,
+  y_train = z,
+  general_params = general_params,
+  mean_forest_params = mean_forest_params
+)
+propensity = predict(
+  propensity_model,
+  X = xdf,
+  type = "mean",
+  terms = "y_hat",
+  scale = "probability"
+)
+
+# Test-train split
+n <- nrow(x)
+test_set_pct <- 0.2
+n_test <- round(test_set_pct * n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) %in% test_inds)]
+xdf_st_test <- xdf_st[test_inds, ]
+xdf_st_train <- xdf_st[train_inds, ]
+x_test <- x[test_inds, ]
+x_train <- x[train_inds, ]
+x_jitter_test <- x_jitter[test_inds, ]
+x_jitter_train <- x_jitter[train_inds, ]
+pi_test <- propensity[test_inds]
+pi_train <- propensity[train_inds]
+z_test <- z[test_inds, ]
+z_train <- z[train_inds, ]
+y_test <- y[test_inds, ]
+y_train <- y[train_inds, ]
+y_train_scale <- scale(y_train)
+y_train_sd <- attr(y_train_scale, "scaled:scale")
+y_train_mean <- attr(y_train_scale, "scaled:center")
+y_test_scale <- (y_test - y_train_mean) / y_train_sd
+
+# Fit BCF with the GFR algorithm and save the model to JSON
+num_gfr <- 1000
+general_params <- list(
+  adaptive_coding = FALSE,
+  propensity_covariate = "none",
+  keep_every = 1,
+  verbose = TRUE,
+  keep_gfr = TRUE
+)
+bcf_model_gfr <- stochtree::bcf(
+  X_train = xdf_st_train,
+  Z_train = z_train,
+  y_train = y_train_scale,
+  propensity_train = pi_train,
+  X_test = xdf_st_test,
+  Z_test = z_test,
+  propensity_test = pi_test,
+  num_gfr = num_gfr,
+  num_burnin = 0,
+  num_mcmc = 0,
+  general_params = general_params
+)
+fit_json_gfr = saveBCFModelToJsonString(bcf_model_gfr)
+
+# Run the MCMC chain from the last GFR sample, optionally setting the
+# covariates equal to an interpolation between the original x and x_jitter
+# (alpha = 0 is 100% x_jitter and alpha = 1 is 100% x)
+# alpha <- 1.0
+# x_jitter_new_train <- (alpha) * x_train + (1-alpha) * x_jitter_train
+# x_jitter_new_test <- (alpha) * x_test + (1-alpha) * x_jitter_test
+x_jitter_new_train <- xdf_st_train
+x_jitter_new_test <- xdf_st_test
+num_mcmc <- 10000
+bcf_model_mcmc <- stochtree::bcf(
+  X_train = x_jitter_new_train,
+  Z_train = z_train,
+  y_train = y_train_scale,
+  propensity_train = pi_train,
+  X_test = x_jitter_new_test,
+  Z_test = z_test,
+  propensity_test = pi_test,
+  num_gfr = 0,
+  num_burnin = 0,
+  num_mcmc = num_mcmc,
+  previous_model_json = fit_json_gfr,
+  previous_model_warmstart_sample_num = num_gfr,
+  general_params = general_params
+)
+
+# Inspect the "in-sample sigma" via the traceplot
+# of the global error variance parameter
+combined_sigma <- c(
+  bcf_model_gfr$sigma2_global_samples,
+  bcf_model_mcmc$sigma2_global_samples
+)
+plot(
+  combined_sigma,
+  ylab = "sigma2",
+  xlab = "sample num",
+  main = "Global error var traceplot"
+)
+
+# Inspect the "out-of-sample sigma" by computing the MSE
+# of the yhat on the test set
+yhat_combined_train <- cbind(
+  bcf_model_gfr$y_hat_train,
+  bcf_model_mcmc$y_hat_train
+)
+yhat_combined_test <- cbind(
+  bcf_model_gfr$y_hat_test,
+  bcf_model_mcmc$y_hat_test
+)
+num_samples <- ncol(yhat_combined_train)
+train_mses <- rep(NA, num_samples)
+for (i in 1:num_samples) {
+  train_mses[i] <- mean((yhat_combined_train[, i] - y_train_scale)^2)
+}
+test_mses <- rep(NA, num_samples)
+for (i in 1:num_samples) {
+  test_mses[i] <- mean((yhat_combined_test[, i] - y_test_scale)^2)
+}
+max_y <- max(c(max(train_mses, test_mses)))
+min_y <- min(c(min(train_mses, test_mses)))
+plot(
+  test_mses,
+  ylab = "outcome MSE",
+  xlab = "sample num",
+  main = "Outcome MSE Traceplot",
+  ylim = c(min_y, max_y)
+)
+points(train_mses, col = "blue")
+legend(
+  "right",
+  legend = c("Out-of-Sample", "In-Sample"),
+  col = c("black", "blue"),
+  pch = c(1, 1)
+)
+
+# Run some one-off pred vs actual plots
+plot(yhat_combined_test[, 11000], y_test_scale)
+abline(0, 1, col = "red", lty = 3)
+plot(bcf_model_mcmc$y_hat_train[, 10000], y_train_scale)
+abline(0, 1, col = "red", lty = 3)
+plot(bcf_model_mcmc$y_hat_test[, 10000], y_test_scale)
+abline(0, 1, col = "red", lty = 3)
+plot(bcf_model_gfr$y_hat_train[, 1000], y_train_scale)
+abline(0, 1, col = "red", lty = 3)
+plot(bcf_model_gfr$y_hat_test[, 1000], y_test_scale)
+abline(0, 1, col = "red", lty = 3)
+plot(bcf_model_gfr$y_hat_train[, 10], y_train_scale)
+abline(0, 1, col = "red", lty = 3)
+plot(bcf_model_gfr$y_hat_test[, 10], y_test_scale)
+abline(0, 1, col = "red", lty = 3)
+
+# Run MCMC chain from root
+num_mcmc <- 10000
+bcf_model_mcmc_root <- stochtree::bcf(
+  X_train = xdf_st_train,
+  Z_train = z_train,
+  y_train = y_train_scale,
+  propensity_train = pi_train,
+  X_test = xdf_st_test,
+  Z_test = z_test,
+  propensity_test = pi_test,
+  num_gfr = 0,
+  num_burnin = 0,
+  num_mcmc = num_mcmc,
+  general_params = general_params
+)
+
+# Inspect the "in-sample sigma" via the traceplot
+# of the global error variance parameter
+sigma_trace <- bcf_model_mcmc_root$sigma2_global_samples
+plot(
+  sigma_trace,
+  ylab = "sigma2",
+  xlab = "sample num",
+  main = "Global error var traceplot"
+)
+
+# Inspect the "out-of-sample sigma" by computing the MSE
+# of the yhat on the test set
+yhat_combined_train <- cbind(
+  bcf_model_mcmc_root$y_hat_train
+)
+yhat_combined_test <- cbind(
+  bcf_model_mcmc_root$y_hat_test
+)
+num_samples <- ncol(yhat_combined_train)
+train_mses <- rep(NA, num_samples)
+for (i in 1:num_samples) {
+  train_mses[i] <- mean((yhat_combined_train[, i] - y_train_scale)^2)
+}
+test_mses <- rep(NA, num_samples)
+for (i in 1:num_samples) {
+  test_mses[i] <- mean((yhat_combined_test[, i] - y_test_scale)^2)
+}
+max_y <- max(c(max(train_mses, test_mses)))
+min_y <- min(c(min(train_mses, test_mses)))
+plot(
+  test_mses,
+  ylab = "outcome MSE",
+  xlab = "sample num",
+  main = "Test set outcome MSEs",
+  ylim = c(min_y, max_y)
+)
+points(train_mses, col = "blue")
diff --git a/tools/debug/bcf_continue_sampler_debug.R b/tools/debug/bcf_continue_sampler_debug.R
new file mode 100644
index 00000000..151b40b1
--- /dev/null
+++ b/tools/debug/bcf_continue_sampler_debug.R
@@ -0,0 +1,131 @@
+# Load libraries
+library(stochtree)
+
+# Sampler settings
+num_chains <- 1
+num_gfr <- 10
+num_burnin <- 0
+num_mcmc <- 20
+num_trees <- 100
+
+# Generate the data
+n <- 500
+p <- 5
+snr <- 2
+X <- matrix(runif(n * p), ncol = p)
+mu_x <- (((0 <= X[, 1]) & (0.25 > X[, 1])) *
+  (-7.5) +
+  ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) +
+  ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) +
+  ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5))
+pi_x <- (((0 <= X[, 1]) & (0.25 > X[, 1])) *
+  (0.2) +
+  ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (0.4) +
+  ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (0.6) +
+  ((0.75 <= X[, 1]) & (1 > X[, 1])) * (0.8))
+tau_x <- (((0 <= X[, 2]) & (0.25 > X[, 2])) *
+  (0.5) +
+  ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (1.0) +
+  ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (1.5) +
+  ((0.75 <= X[, 2]) & (1 > X[, 2])) * (2.0))
+Z <- rbinom(n, 1, pi_x)
+f_XZ <- mu_x + tau_x * Z
+noise_sd <- sd(f_XZ) / snr
+y <- f_XZ + rnorm(n, 0, 1) * noise_sd
+
+# Split data into test and train sets
+test_set_pct <- 0.2
+n_test <- round(test_set_pct * n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) %in% test_inds)]
+X_test <- X[test_inds, ]
+X_train <- X[train_inds, ]
+pi_test <- pi_x[test_inds]
+pi_train <- pi_x[train_inds]
+Z_test <- Z[test_inds]
+Z_train <- Z[train_inds]
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+mu_test <- mu_x[test_inds]
+mu_train <- mu_x[train_inds]
+tau_test <- tau_x[test_inds]
+tau_train <- tau_x[train_inds]
+
+# Run the GFR algorithm
+general_params <- list(sample_sigma2_global = T)
+xbcf_model <- bcf(
+  X_train = X_train,
+  Z_train = Z_train,
+  y_train = y_train,
+  propensity_train = pi_train,
+  X_test = X_test,
+  Z_test = Z_test,
+  propensity_test = pi_test,
+  num_gfr = num_gfr,
+  num_burnin = 0,
+  num_mcmc = 0,
+  general_params = general_params
+)
+
+# Inspect results
+plot(rowMeans(xbcf_model$y_hat_test), y_test)
+abline(0, 1)
+cat(paste0(
+  "RMSE = ",
+  sqrt(mean((rowMeans(xbcf_model$y_hat_test) - y_test)^2)),
+  "\n"
+))
+plot(xbcf_model$sigma2_global_samples)
+xbcf_model_string <- stochtree::saveBCFModelToJsonString(xbcf_model)
+
+# Run the BCF MCMC sampler, initialized from the XBCF sampler
+general_params <- list(sample_sigma2_global = T)
+bcf_model <- bcf(
+  X_train = X_train,
+  Z_train = Z_train,
+  y_train = y_train,
+  propensity_train = pi_train,
+  X_test = X_test,
+  Z_test = Z_test,
+  propensity_test = pi_test,
+  num_gfr = 0,
+  num_burnin = num_burnin,
+  num_mcmc = num_mcmc,
+  general_params = general_params,
+  previous_model_json = xbcf_model_string,
+  previous_model_warmstart_sample_num = num_gfr
+)
+
+# Inspect the results
+plot(rowMeans(bcf_model$y_hat_test), y_test)
+abline(0, 1)
+cat(paste0(
+  "RMSE = ",
+  sqrt(mean((rowMeans(bcf_model$y_hat_test) - y_test)^2)),
+  "\n"
+))
+plot(bcf_model$sigma2_global_samples)
+
+# Compare to a single chain of MCMC samples initialized at root
+bcf_model_root <- bcf(
+  X_train = X_train,
+  Z_train = Z_train,
+  y_train = y_train,
+  propensity_train = pi_train,
+  X_test = X_test,
+  Z_test = Z_test,
+  propensity_test = pi_test,
+  num_gfr = 0,
+  num_burnin = num_burnin,
+  num_mcmc = num_mcmc,
+  general_params = general_params
+)
+plot(rowMeans(bcf_model_root$y_hat_test), y_test)
+abline(0, 1)
+plot(bcf_model_root$sigma2_global_samples)
+cat(paste0(
+  "RMSE = ",
+  sqrt(mean((rowMeans(bcf_model_root$y_hat_test) - y_test)^2)),
+  "\n"
+))
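+
+# A minimal side-by-side summary of the two chains (a sketch; it assumes both
+# fits above retain y_hat_test with one column per kept sample, as used above)
+rmse <- function(pred, actual) sqrt(mean((pred - actual)^2))
+cat(paste0(
+  "Warm-start RMSE = ", rmse(rowMeans(bcf_model$y_hat_test), y_test),
+  "; Root-init RMSE = ", rmse(rowMeans(bcf_model_root$y_hat_test), y_test),
+  "\n"
+))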