Source code for neuralib.sqlp.dot

import subprocess
import sys
from pathlib import Path
from typing import overload, IO, Any, Union

from .cli import Database
from .table import table_class, table_name, Table

__all__ = ['generate_dot']


@overload
def generate_dot(db: Database, *,
                 graph: dict[str, Any] = None,
                 node: dict[str, Any] = None,
                 edge: dict[str, Any] = None) -> str:
    pass


@overload
def generate_dot(db: Database, file: Union[str, Path], *,
                 graph: dict[str, Any] = None,
                 node: dict[str, Any] = None,
                 edge: dict[str, Any] = None) -> None:
    pass


@overload
def generate_dot(db: Database, file: IO, *,
                 graph: dict[str, Any] = None,
                 node: dict[str, Any] = None,
                 edge: dict[str, Any] = None) -> None:
    pass


[docs] def generate_dot(db: Database, file=None, *, graph: dict[str, Any] = None, node: dict[str, Any] = None, edge: dict[str, Any] = None): """ :param db: :param file: io, or an output file path which support '.puml' file and '.png' file (required ``graphivz``). :param graph: graph attributes :param node: node attributes :param edge: edge attributes :return: :raise FileNotFoundError: Check if graphivz already installed and add to system PATH """ if file is None: from io import StringIO buf = StringIO() _generate_dot(db, buf, graph=graph, node=node, edge=edge) return buf.getvalue() elif isinstance(file, (str, Path)): file = Path(file) if file.suffix in ('.dot',): with file.open() as out: _generate_dot(db, out, graph=graph, node=node, edge=edge) elif file.suffix in ('.png', '.svg', '.ps', '.pdf', '.jpg', '.gif', '.json'): _generate_dot_png(generate_dot(db, graph=graph, node=node, edge=edge), file) else: raise RuntimeError(f'unsupported filetype : {file.suffix}') else: _generate_dot(db, file)
def _generate_dot_png(dot: str, file: Path) -> int: p = subprocess.Popen(['dot', '-T', file.suffix[1:], '-o', str(file)], stdin=subprocess.PIPE) p.communicate(dot.encode()) return p.wait() def _generate_dot(db: Database, file=sys.stdout, *, graph: dict[str, Any] = None, node: dict[str, Any] = None, edge: dict[str, Any] = None): name = db.database_file if name is None: name = ':memory:' print('digraph', f'"{name}"', '{', file=file) _graph = dict() if graph is not None: _graph.update(graph) if len(_graph): print('graph [', file=file) for k, v in _graph.items(): print(k, '=', v, file=file) print('];', file=file) _node = dict(shape='record', rankdir='LR') if node is not None: _node.update(node) if len(_node): print('node [', file=file) for k, v in _node.items(): print(k, '=', v, file=file) print('];', file=file) _edge = dict() if edge is not None: _edge.update(edge) if len(_edge): print('edge [', file=file) for k, v in _edge.items(): print(k, '=', v, file=file) print('];', file=file) _generate_dot_table(db, file) print('}', file=file) def _generate_dot_table(db: Database, file=sys.stdout): for table in db.database_tables: table: Table = table_class(table) print(f'"{table.table_name}"', '[', file=file) ff = [] for field in table.table_fields: at = [f'<{field.name}>'] if field.is_primary: at.append('#') elif field.is_unique: at.append('!') at.append(field.name) at.append(':') at.append(field.sql_type.__name__) if not field.not_null: at.append('?') ff.append(''.join(at)) print('label=', '"', '<0>', table.table_name, '|{', '|'.join(ff), '}"', sep='', file=file) print('];', file=file) for table in db.database_tables: table: Table = table_class(table) for foreign in table.table_foreign_fields: if foreign.fields == foreign.foreign_fields: print(table_name(foreign.table), ':0', '->', table_name(foreign.foreign_table), ':0', file=file) elif len(foreign.fields) == 1: print(table_name(foreign.table), ':', foreign.fields[0], '->', table_name(foreign.foreign_table), ':', foreign.foreign_fields[0], file=file) else: k1 = ', '.join(foreign.fields) k2 = ', '.join(foreign.foreign_fields) print(table_name(foreign.table), ':0', '->', table_name(foreign.foreign_table), ':0', '[', file=file) print('label=', f'"({k1}) to ({k2})"', file=file) print('];', file=file)