Python使用graphviz绘制模块间数据流

graphviz官方参考链接:
http://www.graphviz.org/documentation/
https://graphviz.readthedocs.io/en/stable/index.html


文章目录

    • 需求描述
    • 环境配置
    • 实现思路
    • 代码实现

需求描述

根据各模块之间的传参关系绘制出数据流,如下图所示:
在这里插入图片描述
并且生成对应的graphviz代码:

digraph my_graph {
        Input [fillcolor=gray70 shape=box style=filled]
        Output [fillcolor=gray70 shape=box style=filled]
        NodeA
        NodeB
        NodeC
        Input -> NodeA [label=0]
        Input -> NodeA [label=1]
        NodeA -> NodeB [label=0]
        NodeA -> NodeC [label=1]
        NodeB -> Output [label=0]
        NodeC -> Output [label=0]
}

环境配置

  • 安装Python中需要使用的graphviz包:
pip install graphviz
  • 安装graphviz工具(可选,如果不安装无法直接使用Python的graphviz包导出图片),例如ubuntu系统安装指令如下,其他系统可参考官方文档https://www.graphviz.org/download/:
sudo apt install graphviz
  • VSCODE安装Graphviz Interactive Preview插件(可选,如果使用vscode开发建议安装此插件,通过此插件可以直接可视化graphviz代码,并保存图片)
    在这里插入图片描述

实现思路

实现一个Node基类,所有的模块实现都继承自该基类。再实现一个Message基类,模块之间传递的数据都继承自该基类。然后在数据传递过程中记录流经的每个模块的名称以及数据的传递方向即可绘制出想要的数据流。


代码实现

下面给出了一个简易的实现方式:

import os
from graphviz import Digraph


__graph_dict__ = {}


class Message:
    def __init__(self, node_name: str, idx: int):
        self.node_name = node_name
        self.idx = idx


class EdgeInfo:
    def __init__(self, start_node_name: str, end_node_name: str, label: str) -> None:
        self.start_node_name = start_node_name
        self.end_node_name = end_node_name
        self.label = label

    def __str__(self):
        return f'{self.start_node_name} -> {self.end_node_name} [label="{self.label}"];'


class Node:
    input_num: int
    output_num: int
    node_name: str

    def __call__(self, *args):
        global __graph_dict__
        assert len(args) == self.input_num

        if self.node_name not in __graph_dict__:
            __graph_dict__[self.node_name] = []

        for input_ in args:
            __graph_dict__[input_.node_name].append(EdgeInfo(input_.node_name,
                                                             self.node_name,
                                                             str(input_.idx)))

        res = tuple(Message(self.node_name, i) for i in range(self.output_num))
        if self.output_num == 1:
            return res[0]
        return res


def export_graphviz(graph, num_input: int, save_path: str):
    base_name = os.path.basename(save_path)
    name, _ = base_name.split(".")
    global __graph_dict__
    __graph_dict__.clear()
    __graph_dict__.update({"Input": [], "Output": []})

    # infer and collect flow info
    input_args = tuple(Message("Input", i) for i in range(num_input))
    outputs = graph(*input_args)
    for ouput_ in outputs:
        if ouput_.node_name not in __graph_dict__:
            __graph_dict__[ouput_.node_name] = []

        __graph_dict__[ouput_.node_name].append(EdgeInfo(ouput_.node_name,
                                                         "Output",
                                                         str(ouput_.idx)))

    # create graph code
    digraph = Digraph(name=name, format="jpg")

    # add nodes
    keys = list(__graph_dict__.keys())
    for k in keys:
        if k in ["Input", "Output"]:
            digraph.node(k, **{"shape": "box", "style": "filled", "fillcolor": "gray70"})
        else:
            digraph.node(k)

    # add edges
    for k in keys:
        for edge_info in __graph_dict__[k]:
            digraph.edge(edge_info.start_node_name,
                         edge_info.end_node_name,
                         edge_info.label)

    # print digraph code
    print(digraph.source)

    # export gv and jpg file
    try:
        digraph.render(directory=os.path.dirname(save_path))
    except Exception as e:
        print(f"export digraph failed, {e}")


class NodeA(Node):
    def __init__(self):
        self.input_num = 2
        self.output_num = 2
        self.node_name = "NodeA"


class NodeB(Node):
    def __init__(self):
        self.input_num = 1
        self.output_num = 1
        self.node_name = "NodeB"


class NodeC(Node):
    def __init__(self):
        self.input_num = 1
        self.output_num = 1
        self.node_name = "NodeC"


class Graph:
    def __init__(self):
        self.node_a = NodeA()
        self.node_b = NodeB()
        self.node_c = NodeC()

    def __call__(self, x0, x1):
        y0, y1 = self.node_a(x0, x1)
        z0 = self.node_b(y0)
        z1 = self.node_c(y1)

        return z0, z1


if __name__ == "__main__":
    graph = Graph()
    export_graphviz(graph, num_input=2, save_path="./my_graph.gv")

执行上述代码后会生成my_graph.gv以及my_graph.gv.jpg两个文件(如果没有安装graphviz工具是不会生成的),其中my_graph.gv是graphviz的代码形式,my_graph.gv.jpg是可视化后的结果。