Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmdstanpy/cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,7 @@ def compose_command(
cmd.append(f'init={self.inits[idx]}')
cmd.append('output')
cmd.append(f'file={csv_file}')
cmd.append('save_cmdstan_config=1')
if diagnostic_file:
cmd.append(f'diagnostic_file={diagnostic_file}')
if profile_file:
Expand Down
46 changes: 29 additions & 17 deletions cmdstanpy/stanfit/runset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,29 @@ def __init__(
)
self._stdout_files, self._profile_files = [], []
self._csv_files, self._diagnostic_files = [], []
self._config_files = []

# per-process output files
if one_process_per_chain and chains > 1:
self._stdout_files = [
self.gen_file_name(".txt", extra="stdout", id=id)
for id in self._chain_ids
]
self._config_files = [
os.path.join(
self._outdir, f"{self._base_outfile}_{id}_config.json"
)
for id in self._chain_ids
]

if args.save_profile:
self._profile_files = [
self.gen_file_name(".csv", extra="profile", id=id)
for id in self._chain_ids
]
else:
self._stdout_files = [self.gen_file_name(".txt", extra="stdout")]
self._config_files = [self.gen_file_name(".json", extra="config")]
if args.save_profile:
self._profile_files = [
self.gen_file_name(".csv", extra="profile")
Expand All @@ -93,25 +102,21 @@ def __init__(
]

def __repr__(self) -> str:
repr = 'RunSet: chains={}, chain_ids={}, num_processes={}'.format(
self._chains, self._chain_ids, self._num_procs
)
repr = '{}\n cmd (chain 1):\n\t{}'.format(repr, self.cmd(0))
repr = '{}\n retcodes={}'.format(repr, self._retcodes)
repr = f'{repr}\n per-chain output files (showing chain 1 only):'
repr = '{}\n csv_file:\n\t{}'.format(repr, self._csv_files[0])
lines = [
f"RunSet: chains={self._chains}, chain_ids={self._chain_ids}, "
f"num_processes={self._num_procs}",
f" cmd (chain 1):\n\t{self.cmd(0)}",
f" retcodes={self._retcodes}",
" per-chain output files (showing chain 1 only):",
f" csv_file:\n\t{self._csv_files[0] if self._csv_files else ''}",
]
if self._args.save_latent_dynamics:
repr = '{}\n diagnostics_file:\n\t{}'.format(
repr, self._diagnostic_files[0]
)
lines.append(f" diagnostics_file:\n\t{self._diagnostic_files[0]}")
if self._args.save_profile:
repr = '{}\n profile_file:\n\t{}'.format(
repr, self._profile_files[0]
)
repr = '{}\n console_msgs (if any):\n\t{}'.format(
repr, self._stdout_files[0]
)
return repr
lines.append(f" profile_file:\n\t{self._profile_files[0]}")
lines.append(f" console_msgs (if any):\n\t{self._stdout_files[0]}")
lines.append(f" config_files:\n\t{self._config_files[0]}")
return '\n'.join(lines)

@property
def model(self) -> str:
Expand Down Expand Up @@ -196,6 +201,13 @@ def stdout_files(self) -> list[str]:
"""
return self._stdout_files

@property
def config_files(self) -> list[str]:
"""
List of paths to CmdStan config json files.
"""
return self._config_files

def _check_retcodes(self) -> bool:
"""Returns ``True`` when all chains have retcode 0."""
return all(retcode == 0 for retcode in self._retcodes)
Expand Down
12 changes: 12 additions & 0 deletions test/test_cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,3 +808,15 @@ def test_args_pathfinder_bad(arg: str, require_int: bool) -> None:
args = PathfinderArgs(**{arg: 1.1}) # type: ignore
with pytest.raises(ValueError):
args.validate()


def test_save_cmdstan_config() -> None:
sampler_args = SamplerArgs()
cmdstan_args = CmdStanArgs(
model_name='bernoulli',
model_exe='',
chain_ids=[1, 2, 3, 4],
method_args=sampler_args,
)
command = cmdstan_args.compose_command(0, csv_file="foo")
assert "save_cmdstan_config=1" in command
9 changes: 9 additions & 0 deletions test/test_runset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def test_check_repr() -> None:
assert 'csv_file' in repr(runset)
assert 'console_msgs' in repr(runset)
assert 'diagnostics_file' not in repr(runset)
assert 'config_file' in repr(runset)


def test_check_retcodes() -> None:
Expand Down Expand Up @@ -106,6 +107,11 @@ def test_output_filenames_one_proc_per_chain() -> None:
stdout_file.endswith(f"_stdout_{id}.txt")
for id, stdout_file in zip(chain_ids, runset.stdout_files)
)
assert len(runset.config_files) == len(chain_ids)
assert all(
config_file.endswith(f"_{id}_config.json")
for id, config_file in zip(chain_ids, runset.config_files)
)

cmdstan_args_other_files = CmdStanArgs(
model_name='bernoulli',
Expand Down Expand Up @@ -153,6 +159,8 @@ def test_output_filenames_threading() -> None:
)
assert len(runset.stdout_files) == 1
assert runset.stdout_files[0].endswith("_stdout.txt")
assert len(runset.config_files) == 1
assert runset.config_files[0].endswith("_config.json")

cmdstan_args_other_files = CmdStanArgs(
model_name='bernoulli',
Expand Down Expand Up @@ -198,6 +206,7 @@ def test_output_filenames_single_chain() -> None:
runset = RunSet(args=cmdstan_args, chains=1, one_process_per_chain=True)
base_file = runset._base_outfile
assert runset.stdout_files[0].endswith(f"{base_file}_stdout.txt")
assert runset.config_files[0].endswith(f"{base_file}_config.json")

cmdstan_args_other_files = CmdStanArgs(
model_name='bernoulli',
Expand Down
24 changes: 24 additions & 0 deletions test/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -2204,3 +2204,27 @@ def test_no_output_draws() -> None:
mcmc = model.sample(data=data, iter_sampling=0, save_warmup=False, chains=2)
draws = mcmc.draws()
assert np.array_equal(draws, np.empty((0, 2, len(mcmc.column_names))))


def test_config_output() -> None:
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
model = CmdStanModel(stan_file=stan)
fit = model.sample(
data=jdata,
chains=2,
seed=12345,
iter_warmup=100,
iter_sampling=200,
)
assert all(os.path.exists(cf) for cf in fit.runset.config_files)

# Config file naming differs when only a single chain is output
fit_one_chain = model.sample(
data=jdata,
chains=1,
seed=12345,
iter_warmup=100,
iter_sampling=200,
)
assert all(os.path.exists(cf) for cf in fit_one_chain.runset.config_files)