Source code for pasteur.graph.utils

from collections import defaultdict
from typing import TYPE_CHECKING, Sequence
import networkx as nx
from IPython.core.display import display, SVG


[docs] def display_graph(g, prog="dot", graph={}, nodes={}, edges={}): display_pydot(nx.nx_pydot.to_pydot(g), prog, graph, nodes, edges)
[docs] def display_pydot(g, prog="dot", graph={}, nodes={}, edges={}): process_args = lambda args, pref: [f"{pref}{k}={v}" for k, v in args.items()] args = ( process_args(graph, "-G") + process_args(edges, "-E") + process_args(nodes, "-N") ) display(SVG(g.create(format="svg", prog=[prog, *args])))
[docs] def display_induced_graph(g, condensed=True): import pydot if g.is_directed(): graph_type = "digraph" else: graph_type = "graph" strict = nx.number_of_selfloops(g) == 0 and not g.is_multigraph() and not condensed graph_defaults = g.graph.get("graph", {}) o = pydot.Dot(g.name, graph_type=graph_type, strict=strict, **graph_defaults) # Keep how many nodes each attribute has for a cleaner appearance attr_counts = defaultdict(int) attr_vals = defaultdict(set) marked_vals = defaultdict(lambda: False) for node_name, d in g.nodes(data=True): attr_counts[(d["table"], d["order"], d["attr"])] += 1 attr_vals[(d["table"], d["order"], d["attr"])].add(d["value"]) marked_vals[(d["table"], d["order"], d["attr"], d["value"])] |= d.get( "marked", False ) table_subs: dict[tuple, pydot.Graph] = {} attr_subs: dict[tuple, pydot.Graph] = {} val_names: dict[tuple, str] = {} # pick random name for condensed node for node_name, d in g.nodes(data=True): # Set up clusters if (d["table"], d["order"]) not in table_subs: if d["table"] and d["order"] is not None: label = f"{d['table']}[{d['order']}]" elif d["table"]: label = d["table"] else: label = "" name = f"{d['table']}[{d['order']}]" if label: sub = pydot.Cluster(name, label=label) else: sub = pydot.Subgraph(name, label=label) table_subs[(d["table"], d["order"])] = sub o.add_subgraph(sub) if (d["table"], d["order"], d["attr"]) not in attr_subs: name = d["attr"] if attr_counts[(d["table"], d["order"], d["attr"])] <= 1 or ( condensed and len(attr_vals[d["table"], d["order"], d["attr"]]) == 1 ): sub = pydot.Subgraph(name, label="") else: sub = pydot.Cluster(name, label=d["attr"]) attr_subs[(d["table"], d["order"], d["attr"])] = sub table_subs[(d["table"], d["order"])].add_subgraph(sub) # Setup attribute if len(attr_vals[d["table"], d["order"], d["attr"]]) == 1: if attr_counts[(d["table"], d["order"], d["attr"])] > 1 and not condensed: label = "" else: label = d["value"] else: label = d["value"].replace(d["attr"] + "_", "") if not condensed: label += f"[{d['height']}]" new_data = {"label": label} if d.get("marked", False) or ( condensed and marked_vals[d["table"], d["order"], d["attr"], d["value"]] ): new_data["color"] = "green" if ( not condensed or (d["table"], d["order"], d["attr"], d["value"]) not in val_names ): val_names[(d["table"], d["order"], d["attr"], d["value"])] = node_name attr_subs[(d["table"], d["order"], d["attr"])].add_node( pydot.Node(node_name, **new_data) ) for a, b, data in g.edges(data=True): new_data = {} if data.get("immorality", False): new_data["color"] = "red" if data.get("immoral", False): new_data["color"] = "blue" if data.get("triangulated", False): new_data["color"] = "green" if ( g.nodes[a]["table"] != g.nodes[b]["table"] or g.nodes[a]["order"] != g.nodes[b]["order"] ): dst = o elif g.nodes[a]["attr"] != g.nodes[b]["attr"]: dst = table_subs[(g.nodes[a]["table"], g.nodes[a]["order"])] else: dst = attr_subs[ (g.nodes[a]["table"], g.nodes[a]["order"], g.nodes[a]["attr"]) ] if condensed: if ah := g.nodes[a]["height"]: new_data["taillabel"] = str(ah) if bh := g.nodes[b]["height"]: new_data["headlabel"] = str(bh) a = val_names[ ( g.nodes[a]["table"], g.nodes[a]["order"], g.nodes[a]["attr"], g.nodes[a]["value"], ) ] b = val_names[ ( g.nodes[b]["table"], g.nodes[b]["order"], g.nodes[b]["attr"], g.nodes[b]["value"], ) ] if a == b: # Do not add edge to self continue dst.add_edge(pydot.Edge(a, b, **new_data)) display_pydot(o, edges={"labeldistance": 1.5})
[docs] def display_junction_tree( junction: nx.Graph, g: nx.Graph | nx.DiGraph, messages: Sequence[Sequence] | None = None, ): import pydot if messages: message_order = {} for i, generation in enumerate(messages): for message in generation: message_order[message] = i + 1 else: message_order = None o = pydot.Dot(graph_type="graph") for cl in junction.nodes(): label = '<<TABLE CELLBORDER="1" BORDER="0">' for table, order, attr, sel in cl: # @TODO: Integrate table meta into figure if isinstance(sel, int): label += f'<TR><TD ALIGN="LEFT"><B>{attr}</B></TD><TD>{sel}</TD></TR>' elif len(sel) == 1: val, h = next(iter(sel)) label += f'<TR><TD ALIGN="LEFT"><B>{val}</B></TD><TD>{h}</TD></TR>' else: label += f'<TR><TD COLSPAN="2"><B>{attr}</B></TD></TR>' for val, h in sorted(sel): label += f'<TR><TD ALIGN="LEFT">{val}</TD><TD>{h}</TD></TR>' label += "</TABLE>>" o.add_node(pydot.Node(str(cl), label=label, shape="plaintext")) for a, b, d in junction.edges(data=True): new_data = {"label": f"{d['common']} ({d['domain']:,d})"} if message_order: new_data["dir"] = "both" new_data["taillabel"] = f"<<B>{message_order[(b, a)]}</B>>" new_data["headlabel"] = f"<<B>{message_order[(a, b)]}</B>>" o.add_edge(pydot.Edge(str(a), str(b), **new_data)) display_pydot(o, edges={"labeldistance": 1.5})