Source code for neuralib.sqlp.util

from __future__ import annotations

import datetime
import operator
import re
from pathlib import Path
from typing import TypeVar, Iterable, overload, TYPE_CHECKING, Any, cast

from neuralib.util.table import rich_table

from .expr import SqlExpr, SqlField

if TYPE_CHECKING:
    from .stat import Cursor

# Names exported by this module (what ``from ... import *`` exposes).
__all__ = [
    'str_to_datetime',
    'datetime_to_str',
    'take',
    'infer_eq',
    'infer_cmp',
    'infer_in',
    'resolve_field_type',
    'cast_to_sql',
    'cast_from_sql',
    'get_fields_from_schema',
    'map_foreign',
    'pull_foreign',
    'rich_sql_table'
]

# Generic type variables used by the annotations of the helpers below.
T = TypeVar('T')
V = TypeVar('V')


def str_to_datetime(t: str) -> datetime.datetime:
    """
    Parse a ``YYYY-MM-DD HH:MM:SS`` text into a datetime.

    :param t: text in the ``%Y-%m-%d %H:%M:%S`` format.
    :return: the parsed datetime (no timezone information is attached).
    """
    timestamp_format = '%Y-%m-%d %H:%M:%S'
    parsed = datetime.datetime.strptime(t, timestamp_format)
    return parsed
def datetime_to_str(t: datetime.datetime) -> str:
    """
    Format a datetime as ``YYYY-MM-DD HH:MM:SS`` text.

    :param t: the datetime to format.
    :return: text in the ``%Y-%m-%d %H:%M:%S`` format.
    """
    timestamp_format = '%Y-%m-%d %H:%M:%S'
    text = t.strftime(timestamp_format)
    return text
@overload
def take(index: int, coll: Cursor | Iterable[tuple[T, ...]]) -> list[T]:
    pass


@overload
def take(index: type[V], coll: Cursor | Iterable[tuple[T, ...]]) -> list[V]:
    pass


@overload
def take(index: tuple[int, ...], coll: Cursor | Iterable[tuple[T, ...]]) -> list[tuple[T, ...]]:
    pass


@overload
def take(index: tuple[V], coll: Cursor | Iterable[T]) -> list[tuple[V]]:
    pass


@overload
def take(index: V, coll: Cursor | Iterable[T]) -> list[V]:
    pass


def take(index, coll: Cursor | Iterable):
    """
    Pick one or several columns out of every row of ``coll``.

    ``index`` selects what is extracted from each row:

    * a class: each row is unpacked into ``index(*row)``;
    * an ``int``: the item at that position;
    * a tuple of ``int``: a tuple of the items at those positions;
    * a ``SqlField``: the item at the position of that field's name, looked up
      in ``coll.headers`` for a ``Cursor`` and in the field's table otherwise;
    * a tuple of ``SqlField``: as above, for several fields. For a non-cursor
      collection the table of the first field is used to resolve every name.

    >>> @named_tuple_table_class
    ... class A:
    ...     a: int
    ...     b: str
    >>> take(0, [(0, 'a'), (1, 'b')])
    [0, 1]
    >>> take(A.a, [A(0, 'a'), A(1, 'b')])
    [0, 1]

    :param index: column selector, see above.
    :param coll: a cursor or any iterable of rows.
    :return: a list with one extracted value (or tuple of values) per row.
    :raises TypeError: when ``index`` is none of the supported selector kinds.
    """
    # The ``type`` check stays first so a class is never mistaken for a field.
    if isinstance(index, type):
        return [index(*row) for row in coll]

    if isinstance(index, int):
        return [row[index] for row in coll]

    if isinstance(index, tuple) and all(isinstance(pos, int) for pos in index):
        return [tuple(row[pos] for pos in index) for row in coll]

    # Imported lazily, only once a field-based selector has to be resolved.
    from .table import table_field_names
    from .stat import Cursor

    if isinstance(index, SqlField):
        if isinstance(coll, Cursor):
            position = coll.headers.index(index.field.name)
        else:
            names = table_field_names(index.field.table)
            position = names.index(index.field.name)
        return list(map(operator.itemgetter(position), coll))

    if isinstance(index, tuple):
        selected = cast(tuple[SqlField], index)
        if isinstance(coll, Cursor):
            positions = tuple([coll.headers.index(it.field.name) for it in selected])
        else:
            names = table_field_names(selected[0].field.table)
            positions = tuple([names.index(it.field.name) for it in selected])
        # Positions are plain ints now, so the int-tuple branch finishes the job.
        return take(positions, coll)

    raise TypeError()
def infer_eq(x: T, v: T | str, *, prepend: str = '', append: str = '') -> SqlExpr | None:
    """
    Build a SQL equality-like expression for field ``x`` from a loose value.

    Rules, applied in order:

    * ``None`` gives ``None`` (no expression);
    * an existing ``SqlExpr`` is returned unchanged;
    * a string starting with ``!`` negates the result (the ``!`` is dropped);
    * a string containing ``%`` becomes a ``LIKE`` / ``NOT LIKE`` pattern;
    * otherwise, when ``prepend`` or ``append`` is ``'%'``, a string is wrapped
      as ``prepend + v + append`` and used as a ``LIKE`` / ``NOT LIKE`` pattern;
    * anything else becomes ``x == v`` / ``x != v``.

    >>> infer_eq(A.a, 1)  # doctest: +SKIP
    A.a = 1
    >>> infer_eq(A.a, '!1')  # doctest: +SKIP
    A.a != 1
    >>> infer_eq(A.a, '1%')  # doctest: +SKIP
    A.a LIKE '1%'

    :param x: the table field to compare; must be a ``SqlField``.
    :param v: value, pattern string, or ready-made expression.
    :param prepend: text put before a plain string value; ``'%'`` enables LIKE.
    :param append: text put after a plain string value; ``'%'`` enables LIKE.
    :return: the expression, or ``None`` when ``v`` is ``None``.
    :raises TypeError: when ``x`` is not a ``SqlField``.
    """
    if not isinstance(x, SqlField):
        raise TypeError()
    if v is None:
        return None
    if isinstance(v, SqlExpr):
        return v

    negate = isinstance(v, str) and v.startswith('!')
    if negate:
        v = v[1:]

    if isinstance(v, str):
        use_like = '%' in v
        if not use_like and (prepend == '%' or append == '%'):
            # No explicit wildcard: wrap the text with the requested affixes.
            v = prepend + v + append
            use_like = True

        if use_like:
            from .func_stat import like, not_like
            return not_like(x, v) if negate else like(x, v)

    if negate:
        return x != v
    return x == v
def infer_cmp(x: T, v: T | str | range | slice) -> SqlExpr | None:
    """
    Build a SQL comparison expression for field ``x`` from a loose value.

    Rules, applied in order:

    * ``None`` gives ``None`` (no expression);
    * an existing ``SqlExpr`` is returned unchanged;
    * an ``int`` / ``float`` becomes ``x == v``;
    * any other non-string value (``range``, ``slice``, ``list``, ``tuple``,
      ``datetime``, ``bytes``, ...) is delegated to :func:`infer_in`;
    * a string containing ``%`` becomes a ``LIKE`` pattern;
    * a string starting with ``<=``, ``>=``, ``<`` or ``>`` becomes the matching
      comparison against the rest of the string converted with ``float``;
    * a string starting with ``!`` becomes ``x != rest``;
    * any other string becomes ``x == v``.

    >>> infer_cmp(A.a, range(0, 10))  # doctest: +SKIP
    A.a BETWEEN 0 AND 9
    >>> infer_cmp(A.a, slice(0, 10))  # doctest: +SKIP
    A.a BETWEEN 0 AND 10
    >>> infer_cmp(A.a, '<10')  # doctest: +SKIP
    A.a < 10
    >>> infer_cmp(A.a, 10)  # doctest: +SKIP
    A.a = 10

    :param x: the table field to compare; must be a ``SqlField``.
    :param v: value, comparison string, range/slice, or ready-made expression.
    :return: the expression, or ``None`` when ``v`` is ``None``.
    :raises TypeError: when ``x`` is not a ``SqlField``.
    :raises ValueError: when the operand of ``<``, ``<=``, ``>``, ``>=`` is not
        a number.
    """
    if not isinstance(x, SqlField):
        raise TypeError()
    if v is None:
        return None
    if isinstance(v, SqlExpr):
        return v
    if isinstance(v, (int, float)):
        return x == v

    # Everything below parses ``v`` as text. Non-string values used to reach
    # ``'%' in v`` / ``v.startswith`` and fail with TypeError or AttributeError;
    # infer_in() handles containers and ranges and falls back to equality.
    if not isinstance(v, str):
        return infer_in(x, v)

    if '%' in v:
        from .func_stat import like
        return like(x, v)

    # Two-character operators are tested before their one-character prefixes.
    if v.startswith('<='):
        return x <= float(v[2:])
    elif v.startswith('>='):
        return x >= float(v[2:])
    elif v.startswith('<'):
        return x < float(v[1:])
    elif v.startswith('>'):
        return x > float(v[1:])

    if v.startswith('!'):
        return x != v[1:]
    return x == v
def infer_in(x: T, v: T | str | list[str] | slice | range) -> SqlExpr | None:
    """
    Build a SQL membership-like expression for field ``x`` from a loose value.

    * ``None`` gives ``None``; an existing ``SqlExpr`` is returned unchanged;
    * a ``list`` / ``tuple`` is passed to ``x.contains``;
    * a ``range`` / ``slice`` is passed to ``x.between``;
    * anything else is handed to :func:`infer_eq`.

    >>> infer_in(A.a, '1')  # doctest: +SKIP
    A.a == '1'
    >>> infer_in(A.a, range(0, 10))  # doctest: +SKIP
    A.a BETWEEN 0 AND 9
    >>> infer_in(A.a, ['a', 'b'])  # doctest: +SKIP
    A.a IN ('a', 'b')

    :param x: the table field to test; must be a ``SqlField``.
    :param v: single value, collection, range/slice, or ready-made expression.
    :return: the expression, or ``None`` when ``v`` is ``None``.
    :raises TypeError: when ``x`` is not a ``SqlField``.
    """
    if not isinstance(x, SqlField):
        raise TypeError()

    # Both "nothing to filter" and "already an expression" pass straight through.
    if v is None or isinstance(v, SqlExpr):
        return v

    if isinstance(v, (list, tuple)):
        expr = x.contains(v)
    elif isinstance(v, (range, slice)):
        expr = x.between(v)
    else:
        expr = infer_eq(x, v)
    return expr
[docs] def resolve_field_type(f_type: type) -> tuple[type, type, bool]: """ SQL primary types: * bool: BOOLEAN * int: INT * float: FLOAT * str: TEXT * bytes: BLOB * datetime.date: DATETIME * datetime.datetime: DATETIME Python type mapping * `T|None`: `resolve_field_type(T)` null-able * `T|V` : supported not * `Literal`: `str` * `Path`: `str` :param f_type: :return: (raw_type, sql_type, not_null) """ import typing sql_type = f_type o = typing.get_origin(f_type) if o == typing.Annotated: return resolve_field_type(typing.get_args(f_type)[0]) elif o == typing.Union: a = typing.get_args(f_type) if len(a) == 2: try: i = a.index(type(None)) except ValueError as e: raise RuntimeError('Union type is not supported now') from e else: f_type = a[1 - i] _, sql_type, _ = resolve_field_type(f_type) return f_type, sql_type, False elif o == typing.Literal: return str, str, True elif f_type == Path: return f_type, str, True return f_type, sql_type, True
def cast_to_sql(raw_type: type[T], sql_type: type[V], value: T) -> V:
    """
    Convert a Python-side field value into the value handed to SQL.

    ``None`` stays ``None``. When the SQL type is ``str`` the value is
    stringified with ``str()``; every other value is passed through unchanged.

    :param raw_type: the Python type of the field (it does not affect the result).
    :param sql_type: the SQL-side type of the field.
    :param value: the value to convert.
    :return: the converted value.
    """
    if value is None:
        return None

    # Only text columns need a conversion; everything else goes through as is.
    return str(value) if sql_type == str else value
def cast_from_sql(raw_type: type[T], sql_type: type[V], value: V) -> T:
    """
    Convert a value read from SQL back into the field's Python type.

    Rules, applied in order:

    * ``None`` stays ``None``;
    * a raw type of ``Any`` or a SQL type of ``int`` / ``float`` / ``bool``
      returns the value unchanged;
    * text for a ``datetime.datetime`` field is parsed as
      ``%Y-%m-%d %H:%M:%S``; when that fails, ISO 8601 text (for example with
      fractional seconds) is accepted through ``datetime.fromisoformat``;
    * text for a ``datetime.date`` field is parsed as ``%Y-%m-%d``;
    * identical raw and SQL types return the value unchanged;
    * a ``Path`` field wraps the value in ``Path``;
    * any other callable raw type is called with the value;
    * otherwise the value is returned unchanged.

    :param raw_type: the Python type of the field.
    :param sql_type: the SQL-side type of the field.
    :param value: the value read from the database.
    :return: the converted value.
    :raises ValueError: when date/time text matches none of the accepted forms.
    """
    if value is None:
        return None
    if raw_type == Any:
        return value
    if sql_type in (int, float, bool):
        return value

    if isinstance(value, str):
        if raw_type == datetime.datetime:
            try:
                return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
            except ValueError:
                # Text such as '2020-01-02 03:04:05.123456' does not fit the
                # fixed format above; fall back to ISO parsing instead of failing.
                return datetime.datetime.fromisoformat(value)
        if raw_type == datetime.date:
            return datetime.datetime.strptime(value, '%Y-%m-%d').date()

    if raw_type == sql_type:
        return value
    if raw_type == Path:
        return Path(value)
    if callable(raw_type):
        return raw_type(value)
    return value
def get_fields_from_schema(schema: str) -> list[str]:
    """
    Extract the column names from a ``CREATE TABLE`` style schema text.

    The text between the first ``(`` and the last ``)`` is taken as the column
    list; when there is no ``)`` everything after the first ``(`` is used.
    Parenthesised groups inside it (for example ``DECIMAL(10,2)``) are removed,
    then the remainder is split on commas. The first whitespace-separated word
    of each piece is a column name, with surrounding ``[...]`` stripped.
    Scanning stops at the first piece starting with ``FOREIGN``, ``UNIQUE``,
    ``PRIMARY`` or ``CHECK``.

    :param schema: the schema text.
    :return: column names in declaration order.
    :raises ValueError: when ``schema`` contains no ``(``.
    """
    columns = schema[schema.index('(') + 1:]

    # str.rfind() returns -1 instead of raising, so the missing-paren case has
    # to be tested explicitly; slicing with -1 would drop the last character.
    closing = columns.rfind(')')
    if closing >= 0:
        columns = columns[:closing]

    columns = re.sub(r'\(.+?\)', '', columns.strip())

    found = []
    for column in columns.split(','):
        # Split on any whitespace so tabs/newlines between the name and its
        # type do not end up inside the column name.
        words = column.split()
        if not words:
            continue

        name = words[0]
        if name in ('FOREIGN', 'UNIQUE', 'PRIMARY', 'CHECK'):
            break

        if name.startswith('[') and name.endswith(']'):
            name = name[1:-1]
        found.append(name)

    return found
def map_foreign(value: T, foreign: type[V]) -> Cursor[V]:
    """
    Follow a foreign constraint from a ``T`` row to the ``V`` rows it refers to.

    Let table ``T`` declare a foreign constraint referring to table ``V``.
    Given one ``T`` row, this selects from the constraint's foreign table where
    every foreign field equals the value of the matching field on ``value``.

    :param value: a row of table ``T``.
    :param foreign: a foreign constraint
    :return: the submitted select statement.
    :raises RuntimeError: when no foreign constraint is found for ``foreign``
        on the table of ``value``.
    """
    from .table import table_foreign_field
    from .stat_start import select_from

    constraint = table_foreign_field(type(value), foreign)
    if constraint is None:
        raise RuntimeError(f'not a foreign constraint : {foreign}')

    # SELECT * FROM V
    # WHERE AND*([V.field == t.field for field in constraint])
    referred = constraint.foreign_table
    statement = select_from(referred)
    conditions = [
        getattr(referred, foreign_name) == getattr(value, own_name)
        for own_name, foreign_name in zip(constraint.fields, constraint.foreign_fields)
    ]
    return statement.where(*conditions).submit()
def pull_foreign(target: type[T], foreign: V) -> Cursor[T]:
    """
    Collect the ``T`` rows that refer to a given ``V`` row.

    Let table ``T`` declare a foreign constraint referring to table ``V``.
    Given one ``V`` row, this selects from the constraint's own table where
    every constrained field equals the value of the matching foreign field on
    ``foreign``.

    :param target: target table ``T``
    :param foreign: a foreign data ``V`` referred to.
    :return: the submitted select statement.
    :raises RuntimeError: when ``target`` has no foreign constraint for the
        type of ``foreign``.
    """
    from .table import table_foreign_field
    from .stat_start import select_from

    constraint = table_foreign_field(target, type(foreign))
    if constraint is None:
        raise RuntimeError('not a foreign constraint')

    # SELECT * FROM T
    # WHERE AND*([T.field == v.field for field in constraint])
    owner = constraint.table
    statement = select_from(owner)
    conditions = [
        getattr(owner, own_name) == getattr(foreign, foreign_name)
        for own_name, foreign_name in zip(constraint.fields, constraint.foreign_fields)
    ]
    return statement.where(*conditions).submit()
def rich_sql_table(table: type[T], value: list[T]):
    """
    Feed rows of a table into a ``rich_table`` context.

    The field names of ``table`` are passed to ``rich_table``, then every row
    of ``value`` is unpacked into the callable the context yields.

    :param table: the table class whose field names are used.
    :param value: the rows to add; each row must be unpackable.
    """
    from .table import table_field_names

    headers = table_field_names(table)
    with rich_table(*headers) as add_row:
        for row in value:
            add_row(*row)