Source code for neuralib.sqlp.cli

from __future__ import annotations

import abc
import functools
import re
import sqlite3
import sys
from pathlib import Path
from typing import Optional, Literal, TYPE_CHECKING

import polars as pl

from argclz import AbstractParser, argument

if TYPE_CHECKING:
    from .connection import Connection

__all__ = ['Database', 'transaction']


[docs] def transaction(): """ A decorator that decorate a function as a transaction session that open a connection and set to the global variable. >>> @transaction ... def example(self): ... self.do_something() equalivent to >>> def example(self): ... with self.open_connection(): ... self.do_something() """ def _decorator(f): @functools.wraps(f) def _transaction(self: Database, *args, **kwargs): with self.open_connection(): return f(self, *args, **kwargs) return _transaction return _decorator
[docs] class Database(metaclass=abc.ABCMeta): """ A class for the common usage of using sqlite. """ sqlp_debug_mode = False @property @abc.abstractmethod def database_file(self) -> Optional[Path]: """sqlite database filepath""" pass @property @abc.abstractmethod def database_tables(self) -> list[type]: """supporting tables""" pass
[docs] def open_connection(self) -> Connection: """ open a connection to the database. """ if (database_file := self.database_file) is None: database_file = ':memory:' else: database_file.parent.mkdir(exist_ok=True) from .connection import Connection ret = Connection(database_file, debug=self.sqlp_debug_mode) cls = type(self) if getattr(cls, '__first_connect_init', True): from .stat_start import create_table with ret: for table in self.database_tables: create_table(table).submit() setattr(cls, '__first_connect_init', False) return ret
class CliDatabase(Database, AbstractParser): """ Wrap Database to support commandline interface. """ sqlp_debug_mode: bool = argument('--debug') DB_FILE: str = argument('-d', '--database', metavar='FILE', default=None) DB_STAT: list[str] = argument(metavar='STATS', nargs='*', action='extend') list_table: str = argument('--table', metavar='NAME', default=None, const='', nargs='?') from_file: bool = argument('-f', '--file') action: Literal['import', 'export'] = argument('--action', default=None) pretty: bool = argument('-p', '--pretty', help='as polars dataframe') print_all: bool = argument('-a', '--all', help='print all dataframe if --pretty') USAGE = """\ %(prog)s -d FILE --table [NAME] %(prog)s -d FILE STAT ... %(prog)s -d FILE --file SCRIPT ... %(prog)s -d FILE --action=(import|export) --table NAME FILE """ def __init__(self, database: Optional[Database] = None): Database.__init__(self) self.database = database @property def database_file(self) -> Optional[Path]: if self.database is not None: return self.database.database_file if self.DB_FILE is None: return None return Path(self.DB_FILE) @property def database_tables(self) -> list[type]: if self.database is not None: return self.database.database_tables return [] def open_connection(self) -> Connection: if self.database is not None: self.database.sqlp_debug_mode = self.sqlp_debug_mode return self.database.open_connection() return super().open_connection() def run(self): if self.action is not None: return self.run_action() if self.list_table is not None: return self.run_list_table() if self.from_file: return self.run_script() return self.run_statement() def run_list_table(self): with self.open_connection() as connection: if len(self.list_table) == 0: print(connection.list_table()) else: print(connection.table_schema(self.list_table)) def run_action(self): if self.action == 'export': if self.list_table is None or len(self.list_table) == 0: raise ValueError('missing --table') if len(self.DB_STAT) == 0: out = None elif len(self.DB_STAT) == 1: out = self.DB_STAT[0] else: raise RuntimeError(f'too many arguments : {self.DB_STAT[1:]}') with self.open_connection() as connection: ret = connection.export_dataframe(self.list_table).write_csv(out) if out is None: print(ret) elif self.action == 'import': if self.list_table is None or len(self.list_table) == 0: raise ValueError('missing --table') if len(self.DB_STAT) == 0: data = pl.read_csv(sys.stdin) elif len(self.DB_STAT) == 1: data = pl.read_csv(self.DB_STAT[0]) else: raise RuntimeError(f'too many arguments : {self.DB_STAT[1:]}') with self.open_connection() as connection: connection.import_dataframe(self.list_table, data) else: raise ValueError(f'unknown action={self.action}') def run_script(self): args = list(self.DB_STAT[1:]) with self.open_connection() as connection: stat = [] with Path(self.DB_STAT[0]).open() as file: for line in file: stat.append(line) if line.partition('--')[0].rstrip().endswith(';'): result = self._execute_script(connection, ''.join(stat), args) stat = [] self._print_result(connection, result) if len(stat): result = self._execute_script(connection, ''.join(stat), args) self._print_result(connection, result) if len(args): raise RuntimeError(f'remaining {args=}') def _execute_script(self, connection, script: str, args: list[str]) -> sqlite3.Cursor: try: result = connection.execute(script, parameter=args) args.clear() return result except RuntimeError as e: if not isinstance(ex := e.__cause__, sqlite3.ProgrammingError): raise # sqlite3.ProgrammingError: Incorrect number of bindings supplied. The current statement uses 5, and there are 6 supplied. message: str = ex.args[0] m = re.search(r'The current statement uses (\d+), and there are (\d+) supplied.', message) if m is None: raise need = int(m.group(1)) total = int(m.group(2)) if total != len(args): raise take = args[:need] result = connection.execute(script, parameter=take) del args[:need] return result def run_statement(self): with self.open_connection() as connection: result = connection.execute(' '.join(self.DB_STAT)) self._print_result(connection, result) def _print_result(self, connection, cursor: sqlite3.Cursor): if self.pretty: from neuralib.sqlp.stat import Cursor df = Cursor(connection, cursor).fetch_polars() if self.print_all: from neuralib.util.verbose import printdf printdf(df) else: print(df) else: header = cursor.description if header is not None: print('--', tuple([it[0] for it in header])) for data in cursor: print(data) if __name__ == '__main__': CliDatabase().main()