This class stores several graph-related methods used in causal discovery.
In causal discovery, DAG patterns are often used rather than DAGs themselves, so this class is named 'pattern' instead of 'graph'. If the pattern does not contain any link or cyclic path, it is the same as a DAG.
Common causal inference packages, such as dowhy or cdt, use the networkx package because of its convenient APIs and powerful visualizations. However, picd does NOT use networkx and instead implements its own graph class, because it is difficult to represent a DAG pattern, which is more abstract than a DAG, in networkx. In networkx, there are only undirected and directed graphs. Although it is possible to use a pair of opposite directed edges <-> instead of a link -, doing so risks making the graph cyclic, and we would have to keep checking whether each such pair of edges is a substitute for a link or not.
from collections import deque
from itertools import combinations, chain
from collections import defaultdict
class pattern:
    """Container for a DAG pattern: vertices, directed edges and undirected links.

    Adjacency is stored in nested dictionaries so that neighbours can be
    looked up by key and every edge or link can carry extra attributes
    (for example a weight).
    """

    def __init__(self, vertex=None, edges=None, links=None):
        """Create an empty pattern and populate it from the optional arguments.

        Args:
            vertex: vertex names, forwarded to ``add_vertex``.
            edges: directed edge descriptions, forwarded to ``add_edges``.
            links: undirected link descriptions, forwarded to ``add_links``.
        """
        # Vertex names, e.g. {'A', 'B', 'C'}.
        self.vertex = set()
        # parent[v][p] holds the attributes of the edge p -> v (see add_edge).
        # e.g. A -> B (weight = 0.5) => {'B': {'A': {'weight': 0.5}}}
        self.parent = {}
        # child[v][c] holds the attributes of the edge v -> c (see add_edge).
        # e.g. A -> B (weight = 0.5) => {'A': {'B': {'weight': 0.5}}}
        self.child = {}
        # link[v][w] holds the attributes of the link v - w, stored in both
        # directions.
        # e.g. A - B (weight = 0.5)
        #   => {'A': {'B': {'weight': 0.5}}, 'B': {'A': {'weight': 0.5}}}
        self.link = {}
        # Counter maintained by add_link / remove_links.
        self.link_count = 0

        # These methods are not defined in the class body; they are attached
        # to the class afterwards (pattern.add_vertex = add_vertex, ...).
        self.add_vertex(vertex)
        self.add_edges(edges)
        self.add_links(links)

        # Two-level mapping whose innermost values default to empty sets.
        # NOTE(review): presumably filled with separating sets during
        # discovery -- nothing in this part of the file writes to it.
        self.d_separation_set = defaultdict(lambda: defaultdict(set))
def add_vertex(self, vertex) -> None:
    """Register every not-yet-known vertex in ``vertex``.

    Each new vertex gets empty ``parent``, ``child`` and ``link`` tables.
    Vertices that already exist are left untouched, and a falsy argument
    (``None``, empty collection) is a no-op.
    """
    if not vertex:
        return
    for name in vertex:
        if name in self.vertex:
            continue
        self.vertex.add(name)
        self.parent[name] = {}
        self.child[name] = {}
        self.link[name] = {}
pattern.add_vertex = add_vertex
def remove_vertex(self, vertex) -> None:
    """Remove every vertex in ``vertex`` together with its edges and links.

    Args:
        vertex: iterable of vertex names to remove.

    Raises:
        KeyError: if a name is not a known vertex.
    """
    for v in vertex:
        self.vertex.remove(v)
        # Drop v from the child table of each of its parents ...
        for p in self.parent[v]:
            self.child[p].pop(v, None)
        # ... and from the parent table of each of its children.
        for c in self.child[v]:
            self.parent[c].pop(v, None)
        # remove_links expects an iterable of (v1, v2) pairs and mutates
        # self.link[v], so build the complete list of pairs first instead of
        # removing while iterating over self.link[v].
        self.remove_links([(neighbour, v) for neighbour in self.link[v]])
        self.parent.pop(v, None)
        self.child.pop(v, None)
        self.link.pop(v, None)
pattern.remove_vertex = remove_vertex
def add_edge(self, v1, v2, **attribute) -> None:
    """Add the directed edge ``v1 -> v2``.

    Both endpoints are registered as vertices if necessary. The keyword
    arguments are stored as the edge's attributes; the same dictionary
    object is referenced from ``child[v1][v2]`` and ``parent[v2][v1]``.
    """
    self.add_vertex([v1, v2])
    edge_attributes = attribute
    self.child[v1][v2] = edge_attributes
    self.parent[v2][v1] = edge_attributes
pattern.add_edge = add_edge
Add several edges at once from the array-like parameter `edges`.
def add_edges(self, edges) -> None:
    """Add several directed edges at once.

    Each item of ``edges`` may be one of:

    * a dict with keys ``'v1'`` and ``'v2'``; every other key becomes an
      edge attribute, e.g. ``{'v1': 'A', 'v2': 'B', 'weight': 0.5}``;
    * a pair ``('A', 'B')``;
    * a longer sequence ``('A', 'B', 0.5, 0.1)`` whose extra values are
      stored under the attribute names ``'A0'``, ``'A1'``, ...

    A falsy ``edges`` (``None``, empty collection) is a no-op.

    Raises:
        KeyError: if a dict item lacks ``'v1'`` or ``'v2'``.
    """
    if not edges:
        return
    for e in edges:
        if isinstance(e, dict):
            # Build a filtered copy instead of deleting 'v1'/'v2' from the
            # caller's dict, so the same specification can be reused.
            attrs = {k: val for k, val in e.items() if k not in ('v1', 'v2')}
            self.add_edge(e['v1'], e['v2'], **attrs)
        elif len(e) == 2:
            self.add_edge(*e)
        else:
            attrs = {f'A{i}': val for i, val in enumerate(e[2:])}
            self.add_edge(e[0], e[1], **attrs)
pattern.add_edges = add_edges
def remove_edges(self, edges) -> None:
    """Remove the directed edges given as ``(parent, child)`` pairs.

    An edge that does not exist is skipped silently. If the edge is recorded
    in only one of ``self.parent`` / ``self.child``, the remaining half is
    removed and a warning is printed.

    Raises:
        KeyError: if either endpoint is not a known vertex.
    """
    for pa, ch in edges:
        # pop() both removes the entry and reports whether it existed; a
        # preceding ``del`` would make this check meaningless and would raise
        # on a missing edge.
        exist1 = self.parent[ch].pop(pa, None)
        exist2 = self.child[pa].pop(ch, None)
        if exist1 is not None and exist2 is not None:
            continue
        if not (exist1 is None and exist2 is None):
            print(f'remove_edges : the edge between {pa} and {ch} is not matched with self.parent and self.child!')
pattern.remove_edges = remove_edges
def add_link(self, v1, v2, **attribute) -> None:
    """Add (or overwrite) the undirected link ``v1 - v2``.

    Both endpoints are registered as vertices if necessary. The keyword
    arguments are stored as the link's attributes; the same dictionary
    object is referenced from ``link[v1][v2]`` and ``link[v2][v1]``.
    ``link_count`` grows only when the link did not exist before, so adding
    the same link again just replaces its attributes.
    """
    self.add_vertex([v1, v2])
    # Check before writing: remove_links decrements the counter once per
    # link, so each link must also be counted only once here.
    is_new = v2 not in self.link[v1]
    self.link[v1][v2] = attribute
    self.link[v2][v1] = attribute
    if is_new:
        self.link_count += 1
pattern.add_link = add_link
def add_links(self, links) -> None:
    """Add several undirected links at once.

    Each item of ``links`` may be one of:

    * a dict with keys ``'v1'`` and ``'v2'``; every other key becomes a
      link attribute;
    * a pair ``('A', 'B')``;
    * a longer sequence ``('A', 'B', 0.5, 0.1)`` whose extra values are
      stored under the attribute names ``'A0'``, ``'A1'``, ...

    A falsy ``links`` (``None``, empty collection) is a no-op.

    Raises:
        KeyError: if a dict item lacks ``'v1'`` or ``'v2'``.
    """
    if not links:
        return
    for l in links:
        if isinstance(l, dict):
            # Build a filtered copy instead of deleting 'v1'/'v2' from the
            # caller's dict, so the same specification can be reused.
            attrs = {k: val for k, val in l.items() if k not in ('v1', 'v2')}
            self.add_link(l['v1'], l['v2'], **attrs)
        elif len(l) == 2:
            self.add_link(*l)
        else:
            attrs = {f'A{i}': val for i, val in enumerate(l[2:])}
            self.add_link(l[0], l[1], **attrs)
pattern.add_links = add_links
def remove_links(self, links) -> None:
    """Remove the undirected links given as ``(v1, v2)`` pairs.

    ``link_count`` is decremented only when both directions of the link were
    stored. If exactly one direction existed, it is still removed and a
    warning is printed. A link that does not exist at all is ignored.

    Raises:
        KeyError: if either endpoint is not a known vertex.
    """
    for v1, v2 in links:
        forward = self.link[v1].pop(v2, None)
        backward = self.link[v2].pop(v1, None)
        # Number of directions that were actually present (0, 1 or 2).
        halves_found = (forward is not None) + (backward is not None)
        if halves_found == 2:
            self.link_count -= 1
        elif halves_found == 1:
            print(f'remove_links : there are unsymmetric links between {v1} and {v2}!')
pattern.remove_links = remove_links
This method is used in the PC algorithm.
def full_link(self):
    """Link every unordered pair of distinct vertices via ``add_link``.

    Each pair is passed to ``add_link`` exactly once, without attributes.
    """
    # Snapshot the vertex set so the pairs are fixed before any add_link call.
    snapshot = list(self.vertex)
    for first, second in combinations(snapshot, 2):
        self.add_link(first, second)
pattern.full_link = full_link
ptn = pattern()
ptn.add_edge('A', 'B', weight = 10, label = 1, any_other_attribute = 23)
ptn.add_links([('A', 'C'), ('C', 'D')])
ptn.vertex
{'A', 'B', 'C', 'D'}
ptn.child
{'A': {'B': {'weight': 10, 'label': 1, 'any_other_attribute': 23}},
'B': {},
'C': {},
'D': {}}
ptn.link
{'A': {'C': {}}, 'B': {}, 'C': {'A': {}, 'D': {}}, 'D': {'C': {}}}
Return
def is_adjacent(self, v1, v2):
    """Return True if ``v1`` is a link neighbour, child or parent of ``v2``.

    Raises:
        KeyError: if ``v2`` is not a known vertex.
    """
    neighbour_tables = (self.link, self.child, self.parent)
    # any() stops at the first table that contains v1, checking them in the
    # order link, child, parent.
    return any(v1 in table[v2] for table in neighbour_tables)
def adjacent(self, v1):
    """Return the set of vertices other than ``v1`` that are adjacent to it.

    Adjacency is decided by ``is_adjacent(v1, other)`` for every other
    known vertex.
    """
    neighbours = set()
    for other in self.vertex - {v1}:
        if self.is_adjacent(v1, other):
            neighbours.add(other)
    return neighbours
pattern.is_adjacent = is_adjacent
pattern.adjacent = adjacent
Example
ptn = pattern(edges = [('A', 'B'), ('A', 'C'), ('C', 'D')])
ptn.is_adjacent('A', 'D')
False
ptn.adjacent('A')
{'B', 'C'}
BFS algorithm is used
def get_ancestor(self, vertex) -> set:
    """Return all ancestors of ``vertex`` using a breadth-first search.

    The search repeatedly follows ``self.parent``. The start vertex itself
    is never part of the result, even when it lies on a directed cycle.
    """
    seen = dict.fromkeys(self.vertex, False)
    seen[vertex] = True
    ancestors = set()
    frontier = deque((vertex,))
    while frontier:
        current = frontier.popleft()
        for pa in self.parent[current]:
            if seen[pa]:
                continue
            seen[pa] = True
            ancestors.add(pa)
            frontier.append(pa)
    return ancestors
def get_descendant(self, vertex) -> set:
    """Return all descendants of ``vertex`` using a breadth-first search.

    The search repeatedly follows ``self.child``. The start vertex itself
    is never part of the result, even when it lies on a directed cycle.
    """
    seen = dict.fromkeys(self.vertex, False)
    seen[vertex] = True
    descendants = set()
    frontier = deque((vertex,))
    while frontier:
        current = frontier.popleft()
        for ch in self.child[current]:
            if seen[ch]:
                continue
            seen[ch] = True
            descendants.add(ch)
            frontier.append(ch)
    return descendants
pattern.get_ancestor = get_ancestor
pattern.get_descendant = get_descendant
Back Tracking algorithm is used
Return
Parameter
def get_path(self, source, target, directed = True):
    """Return the paths from ``source`` to ``target`` found by ``get_path_``.

    Args:
        source: start vertex.
        target: end vertex.
        directed: forwarded to ``get_path_``; when False, parent entries and
            links are traversed in addition to child entries.

    Returns:
        Whatever ``get_path_`` returns: a list of paths, each a list of
        vertices.
    """
    paths = self.get_path_(source, target, directed=directed)
    return paths
def get_path_(self, v1, v2, trace = None, initial = True, directed = True):
    """Collect every simple path from ``v1`` to ``v2`` by backtracking.

    Args:
        v1: vertex currently being expanded (the source on the first call).
        v2: target vertex.
        trace: path walked so far; replaced by ``[v1]`` on the initial call.
        initial: True only for the outermost call, which resets
            ``self.visited`` and ``self.result``.
        directed: when True only ``self.child`` is followed; when False
            ``self.parent`` and ``self.link`` are followed as well.

    Returns:
        ``self.result``, a list of paths (each a list of vertices).

    The working state is stored on the instance as ``self.visited`` and
    ``self.result`` and is overwritten by every initial call.
    """
    if initial:
        self.visited = {v: 0 for v in self.vertex}
        self.visited[v1] = 1
        self.result = []
        trace = [v1]

    # Neighbour tables in traversal order; each is looked up only when its
    # turn comes.
    if directed:
        tables = (self.child,)
    else:
        tables = (self.child, self.parent, self.link)

    for table in tables:
        for nxt in table[v1]:
            if nxt == v2:
                self.result.append(trace + [v2])
                continue
            if self.visited[nxt]:
                continue
            # Mark, explore, then unmark so nxt can appear on other paths.
            self.visited[nxt] = 1
            self.get_path_(nxt, v2, trace + [nxt], False, directed)
            self.visited[nxt] = 0

    return self.result #2d list
pattern.get_path = get_path
pattern.get_path_ = get_path_
def is_cyclic(self) -> bool:
    """Return True if the directed edges contain a cycle.

    Starts a depth-first search (``is_cyclic_util``) from every vertex that
    earlier searches have not yet visited. Only ``self.child`` is followed
    by that helper, so undirected links are not considered.
    """
    # Code Resource : https://www.geeksforgeeks.org/detect-cycle-in-a-graph/
    visited = dict.fromkeys(self.vertex, 0)
    recStack = dict.fromkeys(self.vertex, 0)
    # The generator is lazy: visited[v] is re-checked just before each start,
    # after previous searches have updated it.
    return any(
        self.is_cyclic_util(v, visited, recStack)
        for v in self.vertex
        if not visited[v]
    )
def is_cyclic_util(self, v, visited, recStack) -> bool:
    """Depth-first helper for ``is_cyclic``.

    Args:
        v: vertex to expand.
        visited: dict vertex -> 0/1, set to 1 once a vertex has been expanded.
        recStack: dict vertex -> 0/1, 1 while the vertex is on the current
            recursion path.

    Returns:
        True if a child reachable from ``v`` is still on the current
        recursion path (i.e. a directed cycle was found), otherwise False.
    """
    visited[v] = 1
    recStack[v] = 1
    for nxt in self.child[v]:
        if visited[nxt]:
            # Already expanded: a cycle only if it is on the current path.
            if recStack[nxt]:
                return True
            continue
        if self.is_cyclic_util(nxt, visited, recStack):
            return True
    # v is fully explored and leaves the recursion path.
    recStack[v] = 0
    return False
pattern.is_cyclic = is_cyclic
pattern.is_cyclic_util = is_cyclic_util
Example
ptn = pattern(edges = [
('A', 'B'),
('B', 'C'),
('C', 'A'),
])
ptn.is_cyclic()
True
ptn = pattern(edges = [
('A', 'B'),
('B', 'C'),
('C', 'D'),
])
ptn.is_cyclic()
False