DAG Pattern_base

둘러봐 기술블로그·2023년 9월 11일
0
post-thumbnail

class pattern

This class stores several graph-related methods used in causal discovery.

In causal discovery, DAG patterns are often used rather than DAGs themselves, so this class is named 'pattern' instead of 'graph'. If a pattern contains no link and no cyclic path, it is the same as a DAG.

Common causal inference packages, such as dowhy or cdt, use the networkx package because of its convenient APIs and powerful visualizations. However, picd does NOT use networkx and instead implements its own graph class, because it is difficult to handle a DAG pattern, which is more abstract than a DAG, in networkx. In networkx, there are only undirected and directed graphs. Although it is possible to use a pair of opposite directed edges <-> in place of a link -, there is a risk that the graph is then treated as cyclic, and we would have to keep checking whether each such pair of edges is a substitute for a link or not.

from collections import deque
from itertools import combinations, chain
from collections import defaultdict

class pattern:
    """Container for a DAG pattern: vertices, directed edges and undirected links.

    Edges and links are kept in nested dictionaries whose innermost
    dictionary holds the attributes (for example a weight) of that edge or
    link.
    """

    def __init__(self, vertex=None, edges=None, links=None):
        """Create empty containers, then load the supplied contents.

        ``vertex``, ``edges`` and ``links`` are handed unchanged to
        ``add_vertex``, ``add_edges`` and ``add_links``, in that order.
        """
        # Vertex names, e.g. {'A', 'B', 'C'}.
        self.vertex = set()

        # Why nested dictionaries? (1) Searching efficiency. (2) Additional
        # attributes, such as weight, can be stored on the edge or link.
        #
        # For the edge A -> B (weight = 0.5), as stored by add_edge:
        #   parent['B']['A'] == {'weight': 0.5}   (A is a parent of B)
        #   child['A']['B']  == {'weight': 0.5}   (B is a child of A)
        # For the link A - B (weight = 0.5), as stored by add_link:
        #   link['A']['B'] == link['B']['A'] == {'weight': 0.5}
        self.parent = {}
        self.child = {}
        self.link = {}

        # Counter updated by add_link / remove_links.
        self.link_count = 0

        self.add_vertex(vertex)
        self.add_edges(edges)
        self.add_links(links)

        # Two-level mapping whose missing entries default to an empty set.
        # NOTE(review): presumably filled by the discovery algorithms; it is
        # not used in this part of the file - confirm.
        self.d_separation_set = defaultdict(lambda: defaultdict(set))

Parameter

  • vertex-vertex array like {'A', 'B', 'C'}. If it is None, the vertex set is automatically filled with the vertices included in edges and links.
  • edges-(directed) edge array like {('A', 'B'), ('B', 'C')} ( = A->B, B->C)
  • links-link array like {('A', 'B'), ('B', 'C')} ( = A-B, B-C)

add & remove Methods

Vertex

add_vertex

def add_vertex(self, vertex) -> None:
    """Register every not-yet-known vertex in ``vertex``.

    A falsy argument (``None`` or an empty collection) is a no-op.
    Vertices that are already registered are left untouched, so their
    existing parent/child/link entries are preserved.
    """
    if not vertex:
        return

    for name in vertex:
        if name in self.vertex:
            continue
        self.vertex.add(name)
        # Each new vertex starts with its own empty adjacency dictionaries.
        self.parent[name] = {}
        self.child[name] = {}
        self.link[name] = {}

# Attach the module-level function to the class so it can be called as a method.
pattern.add_vertex = add_vertex

remove_vertex

def remove_vertex(self, vertex) -> None:
    """Remove each vertex in ``vertex`` together with its incident edges and links.

    Parameter
      vertex - iterable of vertex names. Each name must currently be in
               ``self.vertex``; ``set.remove`` raises ``KeyError`` otherwise.
    """
    for v in vertex:
        self.vertex.remove(v)

        # Drop v from the child map of each of its parents ...
        for p in self.parent[v]:
            self.child[p].pop(v, None)  # no-op if v is not a key of child[p]

        # ... and from the parent map of each of its children.
        for c in self.child[v]:
            self.parent[c].pop(v, None)

        # remove_links expects an iterable of (v1, v2) pairs, so the pair is
        # wrapped in a list. It also pops entries from self.link[v], so we
        # iterate over a snapshot of the keys rather than the live dict.
        for neighbour in list(self.link[v]):
            self.remove_links([(neighbour, v)])

        self.parent.pop(v, None)
        self.child.pop(v, None)
        self.link.pop(v, None)

# Attach the module-level function to the class so it can be called as a method.
pattern.remove_vertex = remove_vertex

Edge

add_edge

def add_edge(self, v1, v2, **attribute) -> None:
    """Add the directed edge ``v1 -> v2`` with the given keyword attributes.

    Both endpoints are registered via ``add_vertex`` first. The same
    attribute dictionary object is stored in ``parent[v2][v1]`` and in
    ``child[v1][v2]``, so both views share one set of attributes.
    An existing edge between the two vertices is overwritten.
    """
    endpoints = [v1, v2]
    self.add_vertex(endpoints)

    # Chained assignment binds left to right: parent first, then child.
    self.parent[v2][v1] = self.child[v1][v2] = attribute

# Attach the module-level function to the class so it can be called as a method.
pattern.add_edge = add_edge

add_edges

Add edges by arr-like parameter edges

def add_edges(self, edges) -> None:
    """Add several directed edges via ``add_edge``.

    A falsy ``edges`` (``None`` or empty) is a no-op. Each element may be:
      - a dict with keys 'v1' and 'v2'; every other key becomes an edge
        attribute, e.g. {'v1': 'A', 'v2': 'B', 'weight': 0.5};
      - a sequence of length 2, e.g. ('A', 'B');
      - a longer sequence, e.g. ('A', 'B', 0.5, 0.1); the extra values are
        stored as attributes named 'A0', 'A1', ...

    Raises ``KeyError`` when a dict element lacks 'v1' or 'v2'.
    """
    if not edges:
        return

    for e in edges:
        if isinstance(e, dict):
            # Work on a shallow copy so the caller's dictionary keeps its
            # 'v1' / 'v2' entries and can be reused afterwards.
            attribute = dict(e)
            v1 = attribute.pop('v1')
            v2 = attribute.pop('v2')
            self.add_edge(v1, v2, **attribute)
        elif len(e) == 2:
            self.add_edge(*e)
        else:
            # Positional extras get generated attribute names A0, A1, ...
            attribute = {f'A{i}': attr for i, attr in enumerate(e[2:])}
            self.add_edge(e[0], e[1], **attribute)

# Attach the module-level function to the class so it can be called as a method.
pattern.add_edges = add_edges

remove_edges

def remove_edges(self, edges) -> None:
    """Remove the directed edges listed in ``edges``.

    Parameter
      edges - iterable of (parent, child) pairs.

    An edge that is not present is skipped silently. If the edge is
    recorded in only one of ``self.parent`` / ``self.child``, the existing
    half is removed and a warning is printed. ``KeyError`` is still raised
    when an endpoint is not a key of ``self.parent`` / ``self.child``.
    """
    for pa, ch in edges:
        # pop(..., None) both removes the entry and reports whether it was
        # there; stored attribute dictionaries are never None.
        exist1 = self.parent[ch].pop(pa, None)
        exist2 = self.child[pa].pop(ch, None)

        # Exactly one side present means the two maps were out of sync.
        if (exist1 is None) != (exist2 is None):
            print(f'remove_edges : the edge between {pa} and {ch} is not matched with self.parent and self.child!')

# Attach the module-level function to the class so it can be called as a method.
pattern.remove_edges = remove_edges
def add_link(self, v1, v2, **attribute) -> None:
    """Add the undirected link ``v1 - v2`` with the given keyword attributes.

    Both endpoints are registered via ``add_vertex`` first. The same
    attribute dictionary object is stored in ``link[v1][v2]`` and
    ``link[v2][v1]``. Adding a link that already exists replaces its
    attributes; ``link_count`` only grows when the link is new, so it stays
    in step with ``remove_links``, which decrements once per removed link.
    """
    self.add_vertex([v1, v2])

    # Decide before writing, otherwise the link would always look present.
    is_new = v2 not in self.link[v1]

    self.link[v1][v2] = attribute
    self.link[v2][v1] = attribute

    if is_new:
        self.link_count += 1

# Attach the module-level function to the class so it can be called as a method.
pattern.add_link = add_link
def add_links(self, links) -> None:
    """Add several undirected links via ``add_link``.

    A falsy ``links`` (``None`` or empty) is a no-op. Each element may be:
      - a dict with keys 'v1' and 'v2'; every other key becomes a link
        attribute, e.g. {'v1': 'A', 'v2': 'B', 'weight': 0.5};
      - a sequence of length 2, e.g. ('A', 'B');
      - a longer sequence, e.g. ('A', 'B', 0.5, 0.1); the extra values are
        stored as attributes named 'A0', 'A1', ...

    Raises ``KeyError`` when a dict element lacks 'v1' or 'v2'.
    """
    if not links:
        return

    for l in links:
        if isinstance(l, dict):
            # Work on a shallow copy so the caller's dictionary keeps its
            # 'v1' / 'v2' entries and can be reused afterwards.
            attribute = dict(l)
            v1 = attribute.pop('v1')
            v2 = attribute.pop('v2')
            self.add_link(v1, v2, **attribute)
        elif len(l) == 2:
            self.add_link(*l)
        else:
            # Positional extras get generated attribute names A0, A1, ...
            attribute = {f'A{i}': attr for i, attr in enumerate(l[2:])}
            self.add_link(l[0], l[1], **attribute)

# Attach the module-level function to the class so it can be called as a method.
pattern.add_links = add_links
def remove_links(self, links) -> None:
    """Remove the undirected links listed in ``links``.

    Parameter
      links - iterable of (v1, v2) pairs.

    ``link_count`` is decremented only when both directions of the link
    were stored. When just one direction existed it is still removed and a
    warning is printed; a link that is absent in both directions is ignored.
    """
    for v1, v2 in links:
        forward = self.link[v1].pop(v2, None)
        backward = self.link[v2].pop(v1, None)

        found = (forward is not None, backward is not None)
        if all(found):
            self.link_count -= 1
        elif any(found):
            print(f'remove_links : there are unsymmetric links between {v1} and {v2}!')

# Attach the module-level function to the class so it can be called as a method.
pattern.remove_links = remove_links

It is used in pc algorithm

def full_link(self):
    """Call ``add_link`` once for every unordered pair of distinct vertices.

    The pairs are taken from a snapshot of ``self.vertex`` made before the
    first ``add_link`` call.
    """
    snapshot = list(self.vertex)
    for first, second in combinations(snapshot, 2):
        self.add_link(first, second)

# Attach the module-level function to the class so it can be called as a method.
pattern.full_link = full_link

Example

ptn = pattern()
ptn.add_edge('A', 'B', weight = 10, label = 1, any_other_attribute = 23)
ptn.add_links([('A', 'C'), ('C', 'D')])
ptn.vertex
{'A', 'B', 'C', 'D'}
ptn.child
{'A': {'B': {'weight': 10, 'label': 1, 'any_other_attribute': 23}},
 'B': {},
 'C': {},
 'D': {}}
ptn.link
{'A': {'C': {}}, 'B': {}, 'C': {'A': {}, 'D': {}}, 'D': {'C': {}}}

Methods about the relationship of vertices

Adjacency

is_adjacent/adjacent

Return

  • is_adjacent : True if v1 is adjacent to v2, else False
  • adjacent : every vertex which is adjacent to v1
def is_adjacent(self, v1, v2):
    """Return True if ``v1`` is joined to ``v2`` by a link or by an edge in either direction.

    The lookup is done from ``v2``'s side: ``v1`` is searched in
    ``link[v2]``, then ``child[v2]``, then ``parent[v2]``.
    """
    for relation in (self.link, self.child, self.parent):
        if v1 in relation[v2]:
            return True
    return False

def adjacent(self, v1):
    """Return the set of all vertices other than ``v1`` that are adjacent to it.

    Adjacency of each candidate is decided by ``is_adjacent``.
    """
    neighbours = set()
    for candidate in self.vertex - {v1}:
        if self.is_adjacent(v1, candidate):
            neighbours.add(candidate)
    return neighbours

# Attach the module-level functions to the class so they can be called as methods.
pattern.is_adjacent = is_adjacent
pattern.adjacent = adjacent

Example

ptn = pattern(edges = [('A', 'B'), ('A', 'C'), ('C', 'D')])
ptn.is_adjacent('A', 'D')
False
ptn.adjacent('A')
{'B', 'C'}

Ancestor & Descendant

get_ancestor/get_descendant

BFS algorithm is used

def get_ancestor(self, vertex) -> set:
    """Return every vertex reachable from ``vertex`` by following parent entries.

    Breadth-first search over ``self.parent``. The start vertex is marked
    as seen up front, so it is never part of the result, even when a
    directed cycle leads back to it.
    """
    seen = dict.fromkeys(self.vertex, False)
    seen[vertex] = True

    ancestors = set()
    frontier = deque((vertex,))
    while frontier:
        current = frontier.popleft()
        for upstream in self.parent[current]:
            if seen[upstream]:
                continue
            seen[upstream] = True
            ancestors.add(upstream)
            frontier.append(upstream)

    return ancestors

def get_descendant(self, vertex) -> set:
    """Return every vertex reachable from ``vertex`` by following child entries.

    Breadth-first search over ``self.child``. The start vertex is marked
    as seen up front, so it is never part of the result, even when a
    directed cycle leads back to it.
    """
    seen = dict.fromkeys(self.vertex, False)
    seen[vertex] = True

    descendants = set()
    frontier = deque((vertex,))
    while frontier:
        current = frontier.popleft()
        for downstream in self.child[current]:
            if seen[downstream]:
                continue
            seen[downstream] = True
            descendants.add(downstream)
            frontier.append(downstream)

    return descendants

# Attach the module-level functions to the class so they can be called as methods.
pattern.get_ancestor = get_ancestor
pattern.get_descendant = get_descendant

Path

get_path

Back Tracking algorithm is used

Return

  • list of every possible directed (or undirected) path between the two given vertices

Parameter

  • source is start point and target is end point
  • If directed is False, the direction of edges is ignored.
def get_path(self, source, target, directed=True):
    """Return the list of paths from ``source`` to ``target``.

    Thin public wrapper that starts the backtracking search in
    ``get_path_``. With ``directed=False`` the direction of edges is
    ignored.
    """
    paths = self.get_path_(source, target, directed=directed)
    return paths

def get_path_(self, v1, v2, trace=None, initial=True, directed=True):
    """Backtracking search for all paths from ``v1`` to ``v2``.

    Parameter
      v1, v2   - current vertex and target vertex.
      trace    - path walked so far; ignored (reset to [v1]) when
                 ``initial`` is true.
      initial  - true for the outermost call, which resets the search
                 state kept on the instance in ``self.visited`` and
                 ``self.result``.
      directed - if true only child entries are followed; otherwise child,
                 parent and link entries are followed, in that order.

    Return
      ``self.result``, a 2d list with one vertex list per path found. The
      search does not continue past the target vertex.
    """
    if initial:
        self.visited = dict.fromkeys(self.vertex, 0)
        self.visited[v1] = 1
        self.result = []
        trace = [v1]

    if directed:
        neighbours = self.child[v1]
    else:
        neighbours = chain(self.child[v1], self.parent[v1], self.link[v1])

    for nxt in neighbours:
        if nxt == v2:
            self.result.append(trace + [v2])
        elif not self.visited[nxt]:
            # Mark, explore, then unmark so other branches may reuse nxt.
            self.visited[nxt] = 1
            self.get_path_(nxt, v2, trace + [nxt], False, directed)
            self.visited[nxt] = 0

    return self.result  # 2d list

# Attach the module-level functions to the class so they can be called as methods.
pattern.get_path = get_path
pattern.get_path_ = get_path_

Other Methods

Cyclic Test

is_cyclic/is_cyclic_util

def is_cyclic(self) -> bool:
    """Return True if a depth-first search started from any unvisited vertex reports a cycle.

    The per-vertex search is delegated to ``is_cyclic_util``, which shares
    the ``visited`` and ``recStack`` bookkeeping created here.
    """
    # Code Resource : https://www.geeksforgeeks.org/detect-cycle-in-a-graph/
    visited = dict.fromkeys(self.vertex, 0)
    recStack = dict.fromkeys(self.vertex, 0)

    # any() stops at the first start vertex whose search finds a cycle.
    return any(
        not visited[v] and self.is_cyclic_util(v, visited, recStack)
        for v in self.vertex
    )
    

def is_cyclic_util(self, v, visited, recStack) -> bool:
    """Depth-first step of the cycle test, following ``self.child`` only.

    Parameter
      v        - vertex to expand.
      visited  - dict marking vertices that have been expanded at least once.
      recStack - dict marking vertices on the current recursion path.

    Return
      True as soon as a child that is still on the recursion path is met
      (a back edge), otherwise False. Both dicts are updated in place.
    """
    visited[v] = recStack[v] = 1

    for successor in self.child[v]:
        if visited[successor]:
            # Already expanded: it only closes a cycle if it is still on
            # the current recursion path.
            if recStack[successor]:
                return True
        elif self.is_cyclic_util(successor, visited, recStack):
            return True

    # Leaving v: it is no longer on the recursion path.
    recStack[v] = 0
    return False

# Attach the module-level functions to the class so they can be called as methods.
pattern.is_cyclic = is_cyclic
pattern.is_cyclic_util = is_cyclic_util

Example

ptn = pattern(edges = [
    ('A', 'B'),
    ('B', 'C'),
    ('C', 'A'),
])
ptn.is_cyclic()
True
ptn = pattern(edges = [
    ('A', 'B'),
    ('B', 'C'),
    ('C', 'D'),
])
ptn.is_cyclic()
False

profile
move out to : https://lobster-tech.com?utm_source=velog

0개의 댓글