diff --git a/sphinx_plotly_directive/utils.py b/sphinx_plotly_directive/utils.py index 9d6d8ec..ec7daeb 100644 --- a/sphinx_plotly_directive/utils.py +++ b/sphinx_plotly_directive/utils.py @@ -30,9 +30,14 @@ def save_plotly_figure(fig, path): >>> path = tempfile.NamedTemporaryFile(suffix=".html").name >>> save_plotly_figure(fig, path) """ - fig_html = plotly.offline.plot(fig, output_type="div", include_plotlyjs="cdn", auto_open=False) - with open(path, "w") as f: - f.write(fig_html) + ext = path.split(".")[-1] + if ext in ["htm", "html"]: + fig_html = plotly.offline.plot(fig, output_type="div", + include_plotlyjs="cdn", auto_open=False) + with open(path, "w") as f: + f.write(fig_html) + else: + fig.write_image(path) def assign_last_line_into_variable(code, variable_name): diff --git a/tests/test_utils.py b/tests/test_utils.py index d25a54f..f9c2b92 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -19,6 +19,14 @@ def tset_save_plotly_figure(tmpdir): save_plotly_figure(fig, out_path) assert os.path.exists(out_path) + out_path = os.path.join(tmpdir.strpath, "fig.png") + save_plotly_figure(fig, out_path) + assert os.path.exists(out_path) + + out_path = os.path.join(tmpdir.strpath, "fig.pdf") + save_plotly_figure(fig, out_path) + assert os.path.exists(out_path) + def test_assign_last_line_into_variable(): code = """