Source code for neuralib.sqlp.table

from __future__ import annotations

import abc
import typing
from typing import NamedTuple, Any, Optional, Generic, TYPE_CHECKING, Callable

from .annotation import *
from .literal import FOREIGN_POLICY, CONFLICT_POLICY

if TYPE_CHECKING:
    from .expr import SqlExpr, SqlField

# Names exported by ``from <this module> import *``.
__all__ = [
    'Field', 'table_name', 'table_field_names', 'table_fields', 'table_field',
    'table_primary_fields',
    'UniqueConstraint', 'table_unique_fields', 'make_unique_constraint',
    # NOTE: 'make_foreign_constrain' (sic) is the actual name of the function defined in this module.
    'ForeignConstraint', 'table_foreign_fields', 'table_foreign_field', 'make_foreign_constrain',
    'CheckConstraint', 'table_check_fields', 'table_check_field', 'make_check_constraint',

]

# Type variables used in the annotations of this module (T: a table class, F: a second table class).
T = typing.TypeVar('T')
F = typing.TypeVar('F')
# Sentinel used as the default of ``Field.f_value``; ``Field.has_default`` tests identity against it.
missing = object()


class Field(NamedTuple):
    """
    Description of a single SQL column, i.e. one field of a table.
    """
    table: type
    """table class the field belongs to"""
    name: str
    """name of the column/field"""
    raw_type: type
    """original (declared) type of the field"""
    sql_type: type
    """type of the column; should be one that SQL supports"""
    f_value: Any = missing
    """default value; the ``missing`` sentinel when none was given."""
    not_null: bool = True
    """whether the column is NOT NULL."""
    annotated: list[Any] = ()
    """annotation objects attached to the field (searched by the ``get_*`` methods)."""

    @property
    def table_name(self) -> str:
        """Name of the table the field belongs to."""
        owner = self.table
        return table_name(owner)

    @property
    def has_default(self) -> bool:
        """
        Whether the field has a default value: either an explicit ``f_value``,
        or a primary annotation whose ``auto_increment`` is set.
        """
        if self.f_value is not missing:
            return True

        primary = self.get_primary()
        if primary is None:
            return False
        return primary.auto_increment

    @property
    def is_primary(self) -> bool:
        """Whether the field carries a ``PRIMARY`` annotation."""
        primary = self.get_primary()
        return primary is not None

    def get_primary(self) -> PRIMARY | None:
        """
        Find the ``PRIMARY`` annotation of the field.

        :return: a new ``PRIMARY()`` when the bare class was used as annotation,
            the annotated ``PRIMARY`` instance when there is one, otherwise None.
        """
        for item in self.annotated:
            if item == PRIMARY:
                return PRIMARY()
            if isinstance(item, PRIMARY):
                return item
        return None

    @property
    def is_unique(self) -> bool:
        """Whether the field carries a ``UNIQUE`` annotation."""
        unique = self.get_unique()
        return unique is not None

    def get_unique(self) -> UNIQUE | None:
        """
        Find the ``UNIQUE`` annotation of the field.

        :return: a new ``UNIQUE()`` when the bare class was used as annotation,
            the annotated ``UNIQUE`` instance when there is one, otherwise None.
        """
        for item in self.annotated:
            if item == UNIQUE:
                return UNIQUE()
            if isinstance(item, UNIQUE):
                return item
        return None

    def get_annotation(self, annotation_type: type[T]) -> T | None:
        """
        Find the first annotation that either equals *annotation_type* or is an instance of it.

        :param annotation_type: annotation class to look for
        :return: the matching annotation as stored (class or instance), or None.
        """
        return next(
            (
                item for item in self.annotated
                if item == annotation_type or isinstance(item, annotation_type)
            ),
            None,
        )

    @typing.overload
    def __call__(self, data: type) -> SqlField:
        pass

    @typing.overload
    def __call__(self, data: Any) -> Any:
        pass

    def __call__(self, data):
        """Read the attribute named after this field from *data* (a table class or an instance)."""
        attribute = self.name
        return getattr(data, attribute)
class Table(Generic[T], metaclass=abc.ABCMeta):
    """
    Abstract description of a SQL table bound to a table class.
    """

    table_type: type[T]
    """table class this description belongs to"""

    table_name: str
    """SQL name of the table."""

    @abc.abstractmethod
    def table_seq(self, instance: T, fields: list[str] = None) -> tuple[Any, ...]:
        """
        Turn an instance into a tuple usable as SQL parameters.

        :param instance: instance of the table class
        :param fields: optional list of field names (how it is used is up to the implementation)
        :return: tuple of values
        """
        pass

    @abc.abstractmethod
    def table_new(self, *args) -> T:
        """Create an instance of the table class from *args*."""
        pass

    @property
    def table_field_names(self) -> list[str]:
        """Names of all fields of the table, in ``table_fields`` order."""
        names = []
        for field in self.table_fields:
            names.append(field.name)
        return names

    @property
    @abc.abstractmethod
    def table_fields(self) -> list[Field]:
        """All fields of the table."""
        pass

    @property
    def table_primary_fields(self) -> list[Field]:
        """The fields of the table that are primary."""
        primaries = []
        for candidate in self.table_fields:
            if candidate.is_primary:
                primaries.append(candidate)
        return primaries

    def table_field(self, name: str) -> Field:
        """
        Look a field up by name.

        :param name: field name
        :return: the first field with that name
        :raise RuntimeError: no such field.
        """
        for candidate in self.table_fields:
            if candidate.name == name:
                return candidate

        raise RuntimeError(f'{self.table_name} no such field {name}')

    @property
    @abc.abstractmethod
    def table_unique_fields(self) -> list[UniqueConstraint]:
        """Unique constraints declared on the table."""
        pass

    @property
    @abc.abstractmethod
    def table_foreign_fields(self) -> list[ForeignConstraint]:
        """Foreign constraints declared on the table."""
        pass

    @property
    @abc.abstractmethod
    def table_check_fields(self) -> dict[Optional[str], CheckConstraint]:
        """Check constraints declared on the table, as a dict keyed by ``Optional[str]``."""
        pass


def table_class(cls: type[T]) -> Table[T]:
    """Return the ``Table`` stored on *cls* as its ``_sql_table`` attribute."""
    return cls._sql_table
def table_name(table: type[T]) -> str:
    """Return the SQL name of the table class *table*."""
    info = table_class(table)
    return info.table_name
class UniqueConstraint(NamedTuple):
    """SQL unique constraint."""
    name: str
    """name of the constraint"""
    table: type
    """table class the constraint belongs to"""
    fields: list[str]
    """names of the fields covered by the constraint"""
    conflict: CONFLICT_POLICY | None
    """conflict policy, or None when not given"""
class ForeignConstraint(NamedTuple):
    """A SQL foreign constraint between two tables."""
    name: str
    """name of the constraint"""
    table: type
    """table class that owns the constraint"""
    fields: list[str]
    """names of the fields on the owning table"""
    foreign_table: type
    """table class being referred to"""
    foreign_fields: list[str]
    """names of the fields on the referred table"""
    on_update: FOREIGN_POLICY
    """foreign policy for updates"""
    on_delete: FOREIGN_POLICY
    """foreign policy for deletes"""

    @property
    def table_name(self) -> str:
        """Name of the owning table."""
        owner = self.table
        return table_name(owner)

    @property
    def foreign_table_name(self) -> str:
        """Name of the referred table."""
        referred = self.foreign_table
        return table_name(referred)
class CheckConstraint(NamedTuple):
    """A SQL check constraint."""
    name: str
    """name of the constraint"""
    table: type
    """table class the constraint belongs to"""
    field: str | None
    """name of the associated field, or None."""
    expression: SqlExpr
    """the expression being checked"""
def table_field_names(table: type[T]) -> list[str]:
    """Return the names of all fields of the table class *table*."""
    info = table_class(table)
    return info.table_field_names
def table_fields(table: type[T]) -> list[Field]:
    """Return all fields of the table class *table*."""
    info = table_class(table)
    return info.table_fields
def table_field(table: type[T], field: str) -> Field:
    """
    Look up a field of a table by its name.

    :param table: table class
    :param field: field name
    :return: the matching field
    :raise RuntimeError: no such field.
    """
    info = table_class(table)
    return info.table_field(field)
def table_primary_fields(table: type[T]) -> list[Field]:
    """Return the primary fields (``Field`` objects) of the table class *table*."""
    info = table_class(table)
    return info.table_primary_fields
def table_unique_fields(table: type[T]) -> list[UniqueConstraint]:
    """Return the unique constraints declared on the table class *table*."""
    info = table_class(table)
    return info.table_unique_fields
def table_foreign_fields(table: type[T]) -> list[ForeignConstraint]:
    """Return the foreign constraints declared on the table class *table*."""
    info = table_class(table)
    return info.table_foreign_fields
@typing.overload
def table_foreign_field(table: Callable) -> ForeignConstraint | None:
    pass


@typing.overload
def table_foreign_field(table: type[T], target: type[F]) -> ForeignConstraint | None:
    pass


def table_foreign_field(table: type[T] | Callable, target: type[F] = None) -> ForeignConstraint | None:
    """
    Find a foreign constraint.

    Two call forms are supported:

    * ``table_foreign_field(table, target)`` with two classes: search the foreign
      constraints of *table* for the first one whose foreign table equals *target*.
    * ``table_foreign_field(prop)`` without *target*: read the constraint stored on
      *prop* as its ``_sql_foreign`` attribute.

    :param table: table or a foreign constraint function/property (the function decorated by @foreign)
    :param target: referred table
    :return: the foreign constraint, or None when there is none.
    :raise TypeError: *target* is given but the arguments are not two classes.
    """
    if isinstance(table, type) and isinstance(target, type):
        candidates = table_foreign_fields(table)
        return next((it for it in candidates if it.foreign_table == target), None)

    if target is not None:
        raise TypeError(repr(table))

    attached = getattr(table, '_sql_foreign', None)
    return attached if isinstance(attached, ForeignConstraint) else None
def table_check_fields(table: type[T]) -> dict[Optional[str], CheckConstraint]:
    """Return the dict of check constraints declared on the table class *table*."""
    info = table_class(table)
    return info.table_check_fields
def table_check_field(table: type[T], field: Optional[str]) -> Optional[CheckConstraint]:
    """
    Return the check constraint stored under the key *field* for the table class *table*.

    :param table: table class
    :param field: key to look up in the table's check-constraint dict (a field name or None)
    :return: the check constraint, or None when there is no entry.
    """
    constraints = table_class(table).table_check_fields
    return constraints.get(field)
def make_foreign_constrain(table: Table,
                           prop: callable,
                           fields: list,
                           update: FOREIGN_POLICY,
                           delete: FOREIGN_POLICY) -> ForeignConstraint:
    """
    Build the foreign constraint that *prop* declares on *table*.

    *fields* describes the referred side and takes one of three forms:

    * a list of ``str``: field names; the referred table is *table* itself;
    * a list holding one class: the referred table; its primary fields are referred;
    * a list of ``SqlField``, all belonging to the same table.

    *prop* is called with ``table.table_type`` and has to return a ``SqlField``
    or a tuple of ``SqlField`` belonging to *table*; they form the referring side.
    The constraint is named after ``prop.__name__``.

    :param table: table that owns the constraint
    :param prop: function declaring the referring fields
    :param fields: referred fields, see above
    :param update: policy stored as ``on_update``
    :param delete: policy stored as ``on_delete``
    :return: the foreign constraint
    :raise RuntimeError: *fields* is empty, the referred fields span several tables,
        or a referring field does not belong to *table*.
    :raise TypeError: *fields* or the result of *prop* contains a value of an unsupported type.
    """
    # Cheap argument check first; nothing below makes sense without referred fields.
    if len(fields) == 0:
        raise RuntimeError('empty fields')

    # Local import: this module only imports .expr under TYPE_CHECKING
    # (presumably to avoid an import cycle).
    from .expr import SqlField

    # Used in error messages only; fall back to repr() so a diagnostic never fails itself.
    label = getattr(prop, '__name__', repr(prop))

    head = fields[0]
    if isinstance(head, str):
        # Plain names refer to fields of the owning table itself.
        if not all(isinstance(it, str) for it in fields):
            raise TypeError(f'{label}: foreign fields mix str with other types: {fields!r}')
        foreign_table = table.table_type
        foreign_fields = list(fields)
    elif len(fields) == 1 and isinstance(head, type):
        # A lone table class refers to that table's primary fields.
        foreign_table = head
        foreign_fields = [it.name for it in table_primary_fields(foreign_table)]
    else:
        foreign_table = None
        foreign_fields = []
        for field in fields:
            if not isinstance(field, SqlField):
                raise TypeError(f'{label}: foreign field is not a SqlField: {field!r}')
            if foreign_table is None:
                foreign_table = field.table
            elif foreign_table != field.table:
                raise RuntimeError(f'{label}: foreign fields refer to more than one table')
            foreign_fields.append(field.name)

    ret = prop(table.table_type)
    if not isinstance(ret, tuple):
        ret = [ret]

    self_fields = []
    for field in ret:
        if not isinstance(field, SqlField):
            raise TypeError(f'{label}: not a SqlField: {field!r}')
        if table.table_type != field.table:
            raise RuntimeError(f'{label}: field {field.name!r} does not belong to {table.table_name}')
        self_fields.append(field.name)

    return ForeignConstraint(prop.__name__, table.table_type, self_fields,
                             foreign_table, foreign_fields, update, delete)
def make_check_constraint(table: Table, prop: callable, field: str | None) -> CheckConstraint:
    """
    Build the check constraint that *prop* declares on *table*.

    *prop* is called with ``table.table_type``; its result is passed through
    ``wrap`` (from ``.expr``) and stored as the constraint's expression.
    The constraint is named after ``prop.__name__``.

    :param table: table that owns the constraint
    :param prop: function producing the checking expression
    :param field: name of the associated field, or None
    :return: the check constraint
    """
    from .expr import wrap

    raw_expression = prop(table.table_type)
    expression = wrap(raw_expression)
    return CheckConstraint(prop.__name__, table.table_type, field, expression)
def make_unique_constraint(table: Table, prop: callable, conflict: CONFLICT_POLICY = None) -> UniqueConstraint:
    """
    Build the unique constraint that *prop* declares on *table*.

    *prop* is called with ``table.table_type`` and has to return a ``SqlField``
    or a tuple of ``SqlField`` belonging to *table*. The constraint is named
    after ``prop.__name__``.

    :param table: table that owns the constraint
    :param prop: function declaring the unique fields
    :param conflict: conflict policy stored on the constraint, or None
    :return: the unique constraint
    :raise RuntimeError: a returned field does not belong to *table*.
    :raise TypeError: *prop* returned something that is not a ``SqlField``.
    """
    from .expr import SqlField

    ret = prop(table.table_type)
    if not isinstance(ret, tuple):
        ret = [ret]

    # Used in error messages only; fall back to repr() so a diagnostic never fails itself.
    label = getattr(prop, '__name__', repr(prop))

    fields = []
    for field in ret:
        if not isinstance(field, SqlField):
            raise TypeError(f'{label}: not a SqlField: {field!r}')
        if table.table_type != field.table:
            raise RuntimeError(f'{label}: field {field.name!r} does not belong to {table.table_name}')
        fields.append(field.name)

    return UniqueConstraint(prop.__name__, table.table_type, fields, conflict)