From 96b032b90640a2660d4a366d55233bf405143ecc Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Mon, 26 Aug 2024 18:59:19 +0530 Subject: [PATCH 1/4] initial commit Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/generic.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pywhy_graphs/algorithms/generic.py b/pywhy_graphs/algorithms/generic.py index 2c336df06..0b4c30461 100644 --- a/pywhy_graphs/algorithms/generic.py +++ b/pywhy_graphs/algorithms/generic.py @@ -823,3 +823,7 @@ def is_maximal(G, L: Optional[Set] = None, S: Optional[Set] = None): else: continue return True + + +def find_visible_edges(G): + pass \ No newline at end of file From 9431b9dc280b51015c9950ce812c9bb7afc7fc42 Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Wed, 2 Oct 2024 14:42:37 +0530 Subject: [PATCH 2/4] Added function skeleton Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/generic.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/pywhy_graphs/algorithms/generic.py b/pywhy_graphs/algorithms/generic.py index 0b4c30461..36bd96605 100644 --- a/pywhy_graphs/algorithms/generic.py +++ b/pywhy_graphs/algorithms/generic.py @@ -825,5 +825,29 @@ def is_maximal(G, L: Optional[Set] = None, S: Optional[Set] = None): return True -def find_visible_edges(G): - pass \ No newline at end of file +def get_collider_path(G, X, Y): + pass + +def check_visibility(G: PAG, X: str, Y: str): + X_neighbors = G.neighbors(X) + Y_neighbors = G.neighbors(Y) + + only_x_neighbors = X_neighbors - Y_neighbors + + for elem in only_x_neighbors: + if G.has_edge(elem, X, G.directed_edge_name): + return True + + all_nodes = set(G.nodes) + + candidates = all_nodes - Y_neighbors + + for elem in candidates: + collider_path = get_collider_path(G,elem,X) + final_node = collider_path[-2] + if G.has_edge(final_node, X, G.directed_edge_name): + return True + + return False + + \ No newline at end of file From f5975d7e2eea7a4a623403ad514610cc9ce054a3 Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Mon, 7 Oct 2024 20:31:21 +0530 Subject: [PATCH 3/4] add function to check visibility of an edge Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/generic.py | 35 +++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/pywhy_graphs/algorithms/generic.py b/pywhy_graphs/algorithms/generic.py index eea6afd76..cb7a13c27 100644 --- a/pywhy_graphs/algorithms/generic.py +++ b/pywhy_graphs/algorithms/generic.py @@ -855,8 +855,28 @@ def all_vstructures(G: nx.DiGraph, as_edges: bool = False): vstructs.add((p1, node, p2)) # type: ignore return vstructs -def get_collider_path(G, X, Y): - pass +def get_all_collider_paths(G : PAG, X, Y): + + out = [] + + # find all the possible paths from X to Y with only bi-directed edges + + bidirected_edge_graph = G.sub_bidirected_graph + + X_descendants = set(G.sub_directed_graph.neigbors(X)) + + candidate_collider_path_nodes = set(bidirected_edge_graph.nodes).intersection(X_descendants) + + if candidate_collider_path_nodes is None: + return out + + for elem in candidate_collider_path_nodes: + out.extend(nx.all_simple_paths(G, elem, Y)) + + # for path in out: + # path.insert(0,X) + + return out def check_visibility(G: PAG, X: str, Y: str): X_neighbors = G.neighbors(X) @@ -873,10 +893,13 @@ def check_visibility(G: PAG, X: str, Y: str): candidates = all_nodes - Y_neighbors for elem in candidates: - collider_path = get_collider_path(G,elem,X) - final_node = collider_path[-2] - if G.has_edge(final_node, X, G.directed_edge_name): - return True + collider_paths = get_all_collider_paths(G,elem,X) + for path in collider_paths: + for node in path: + if node in G.neighbors(Y): + continue + else: + return True return False From ce989f4e133b8a3f96c1838b5a2d8df7cb17a6e0 Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Sun, 17 Nov 2024 16:29:13 +0530 Subject: [PATCH 4/4] Add some tests Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/generic.py | 10 ++++-- pywhy_graphs/algorithms/tests/test_generic.py | 33 +++++++++++++++++-- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/pywhy_graphs/algorithms/generic.py b/pywhy_graphs/algorithms/generic.py index cb7a13c27..d878c4765 100644 --- a/pywhy_graphs/algorithms/generic.py +++ b/pywhy_graphs/algorithms/generic.py @@ -19,6 +19,7 @@ "dag_to_mag", "is_maximal", "all_vstructures", + "check_visibility" ] @@ -879,17 +880,22 @@ def get_all_collider_paths(G : PAG, X, Y): return out def check_visibility(G: PAG, X: str, Y: str): - X_neighbors = G.neighbors(X) - Y_neighbors = G.neighbors(Y) + + X_neighbors = set(G.neighbors(X)) + Y_neighbors = set(G.neighbors(Y)) only_x_neighbors = X_neighbors - Y_neighbors + for elem in only_x_neighbors: if G.has_edge(elem, X, G.directed_edge_name): return True all_nodes = set(G.nodes) + all_nodes.remove(X) + + candidates = all_nodes - Y_neighbors for elem in candidates: diff --git a/pywhy_graphs/algorithms/tests/test_generic.py b/pywhy_graphs/algorithms/tests/test_generic.py index 09218a334..80b70eb6d 100644 --- a/pywhy_graphs/algorithms/tests/test_generic.py +++ b/pywhy_graphs/algorithms/tests/test_generic.py @@ -2,8 +2,8 @@ import pytest import pywhy_graphs -from pywhy_graphs import ADMG -from pywhy_graphs.algorithms import all_vstructures +from pywhy_graphs import ADMG, PAG +from pywhy_graphs.algorithms import all_vstructures, check_visibility def test_convert_to_latent_confounder_errors(): @@ -496,3 +496,32 @@ def test_all_vstructures(): # Assert that the returned values are as expected assert len(v_structs_edges) == 0 assert len(v_structs_tuples) == 0 + + + +def test_check_visibility(): + + # H <-> K <-> Z <-> X <- Y + + pag = PAG() + pag.add_edge("Y", "X", pag.directed_edge_name) + pag.add_edge("Z", "X", pag.bidirected_edge_name) + pag.add_edge("Z", "K", pag.bidirected_edge_name) + pag.add_edge("K", "H", pag.bidirected_edge_name) + + assert True == check_visibility(pag, "X", "Y") + + pag = PAG() + pag.add_edge("Y", "X", pag.directed_edge_name) + pag.add_edge("Z", "X", pag.bidirected_edge_name) + + assert True == check_visibility(pag, "X", "Y") + + pag = PAG() + pag.add_edge("Y", "X", pag.directed_edge_name) + pag.add_edge("Z", "Y", pag.bidirected_edge_name) + pag.add_edge("Z", "K", pag.bidirected_edge_name) + pag.add_edge("K", "H", pag.bidirected_edge_name) + + assert False == check_visibility(pag, "X", "Y") +