graphviz官方参考链接:
http://www.graphviz.org/documentation/
https://graphviz.readthedocs.io/en/stable/index.html
文章目录
-
- 需求描述
- 环境配置
- 实现思路
- 代码实现
需求描述
根据各模块之间的传参关系绘制出数据流,如下图所示:
并且生成对应的graphviz代码:
digraph my_graph { Input [fillcolor=gray70 shape=box style=filled] Output [fillcolor=gray70 shape=box style=filled] NodeA NodeB NodeC Input -> NodeA [label=0] Input -> NodeA [label=1] NodeA -> NodeB [label=0] NodeA -> NodeC [label=1] NodeB -> Output [label=0] NodeC -> Output [label=0] }
环境配置
- 安装Python中需要使用的
graphviz 包:
pip install graphviz
- 安装
graphviz 工具(可选,如果不安装无法直接使用Python的graphviz 包导出图片),例如ubuntu系统安装指令如下,其他系统可参考官方文档https://www.graphviz.org/download/:
sudo apt install graphviz
- VSCODE安装
Graphviz Interactive Preview 插件(可选,如果使用vscode开发建议安装此插件,通过此插件可以直接可视化graphviz代码,并保存图片)
实现思路
实现一个Node基类,所有的模块实现都继承自该基类。再实现一个Message基类,模块之间传递的数据都继承自该基类。然后在数据传递过程中记录流经的每个模块的名称以及数据的传递方向即可绘制出想要的数据流。
代码实现
下面给出了一个简易的实现方式:
import os from graphviz import Digraph __graph_dict__ = {} class Message: def __init__(self, node_name: str, idx: int): self.node_name = node_name self.idx = idx class EdgeInfo: def __init__(self, start_node_name: str, end_node_name: str, label: str) -> None: self.start_node_name = start_node_name self.end_node_name = end_node_name self.label = label def __str__(self): return f'{self.start_node_name} -> {self.end_node_name} [label="{self.label}"];' class Node: input_num: int output_num: int node_name: str def __call__(self, *args): global __graph_dict__ assert len(args) == self.input_num if self.node_name not in __graph_dict__: __graph_dict__[self.node_name] = [] for input_ in args: __graph_dict__[input_.node_name].append(EdgeInfo(input_.node_name, self.node_name, str(input_.idx))) res = tuple(Message(self.node_name, i) for i in range(self.output_num)) if self.output_num == 1: return res[0] return res def export_graphviz(graph, num_input: int, save_path: str): base_name = os.path.basename(save_path) name, _ = base_name.split(".") global __graph_dict__ __graph_dict__.clear() __graph_dict__.update({"Input": [], "Output": []}) # infer and collect flow info input_args = tuple(Message("Input", i) for i in range(num_input)) outputs = graph(*input_args) for ouput_ in outputs: if ouput_.node_name not in __graph_dict__: __graph_dict__[ouput_.node_name] = [] __graph_dict__[ouput_.node_name].append(EdgeInfo(ouput_.node_name, "Output", str(ouput_.idx))) # create graph code digraph = Digraph(name=name, format="jpg") # add nodes keys = list(__graph_dict__.keys()) for k in keys: if k in ["Input", "Output"]: digraph.node(k, **{"shape": "box", "style": "filled", "fillcolor": "gray70"}) else: digraph.node(k) # add edges for k in keys: for edge_info in __graph_dict__[k]: digraph.edge(edge_info.start_node_name, edge_info.end_node_name, edge_info.label) # print digraph code print(digraph.source) # export gv and jpg file try: digraph.render(directory=os.path.dirname(save_path)) except Exception as e: print(f"export digraph failed, {e}") class NodeA(Node): def __init__(self): self.input_num = 2 self.output_num = 2 self.node_name = "NodeA" class NodeB(Node): def __init__(self): self.input_num = 1 self.output_num = 1 self.node_name = "NodeB" class NodeC(Node): def __init__(self): self.input_num = 1 self.output_num = 1 self.node_name = "NodeC" class Graph: def __init__(self): self.node_a = NodeA() self.node_b = NodeB() self.node_c = NodeC() def __call__(self, x0, x1): y0, y1 = self.node_a(x0, x1) z0 = self.node_b(y0) z1 = self.node_c(y1) return z0, z1 if __name__ == "__main__": graph = Graph() export_graphviz(graph, num_input=2, save_path="./my_graph.gv")
执行上述代码后会生成