Source code for neuralib.sqlp.stat_start

import datetime
from types import UnionType
from typing import overload, Any, TypeVar, Union, get_origin, get_args

from .annotation import *
from .expr import *
from .expr import use_table_first, SqlRemovePlaceHolder
from .literal import UPDATE_POLICY
from .stat import *
from .table import *

__all__ = [
    'create_table',
    'insert_into',
    'replace_into',
    'select_from',
    'update',
    'delete_from'
]

T = TypeVar('T')
S = TypeVar('S')


@overload
def select_from(table: type[T], *, distinct: bool = False) -> SqlSelectStat[T]:
    """
    >>> select_from(Table) # SELECT * FROM Table
    """
    pass


@overload
def select_from(table: SqlCteExpr, *, distinct: bool = False) -> SqlSelectStat[tuple]:
    pass


@overload
def select_from(*field, distinct: bool = False,
                from_table: Union[str, type, SqlAlias, SqlSelectStat] = None) -> SqlSelectStat[tuple]:
    """
    >>> select_from('a', 'b') # SELECT a, b FROM Table
    """
    pass


[docs] def select_from(*args, distinct: bool = False, from_table: Union[str, type, SqlAlias, SqlSelectStat] = None) -> SqlSelectStat: """ ``SELECT``: https://www.sqlite.org/lang_select.html Select all fields from a table >>> select_from(A).build() # doctest: SKIP SELECT * FROM A >>> select_from(A).fetchall() # doctest: SKIP [A(...), A(...), ...] Select subset of fields from A >>> select_from(A.a, A.b).build() # doctest: SKIP SELECT A.a, A.b FROM A >>> select_from(A.a, A.b).fetchall() # doctest: SKIP [('a', 1), ('b', 2), ...] With a literal value >>> select_from(A.a, 0).build() # doctest: SKIP SELECT A.a, 0 FROM A >>> select_from(A.a, 0).fetchall() # doctest: SKIP [('a', 0), ('b', 0), ...] With SQL functions >>> select_from(A.a, count()).build() # doctest: SKIP SELECT A.a, COUNT(*) FROM A Use table alias >>> a = alias(A, 'a') # doctest: SKIP >>> select_from(a.a).build() # doctest: SKIP SELECT a.a from A a join other tables >>> select_from(A.a, B.b).join(A.c == B.c).build() # doctest: SKIP SELECT A.a, B.b FROM A JOIN B ON A.c = B.c **features supporting** * `SELECT DISTINCT` * `FROM` * `WHERE` * `GROUP BY` * `HAVING` * `WINDOW` * compound-operator: `UNION [ALL]`, `INTERSECT` and `EXCEPT` * `ORDER BY` * `LIMIT [OFFSET]` **features not supporting** * `WITH [RECURSIVE]` * `SELECT ALL` * `VALUES` :param args: :param distinct: :param from_table: :return: """ pre_stat = ['SELECT'] if distinct: pre_stat.append('DISTINCT') if len(args) == 1 and isinstance(table := args[0], type): self = SqlSelectStat(table) self._involved.append(table) self.add(pre_stat) self.add('*') elif len(args) == 1 and isinstance(table := args[0], SqlCteExpr): self = SqlSelectStat(None) self.add(table) self.add(pre_stat) self.add('*') else: table, fields = select_from_fields(*args) if from_table is not None: table = from_table if table is None: raise RuntimeError('need to provide from_table') if isinstance(table, type): self = SqlSelectStat(None) self._involved.append(table) elif isinstance(table, SqlAlias) and isinstance(table._value, type): self = SqlSelectStat(None) self._involved.append(table) elif isinstance(table, SqlCteExpr): self = SqlSelectStat(None) self.add(table) else: self = SqlSelectStat(None) self.add(pre_stat) for i, field in enumerate(fields): if i > 0: self.add(',') if isinstance(field, SqlField): self.add(f'{field.table_name}.{field.name}') elif isinstance(field, SqlAlias) and isinstance(_field := field._value, SqlField): self.add([f'{_field.table_name}.{_field.name}', 'AS', repr(field._name)]) elif isinstance(field, SqlAlias) and isinstance(expr := field._value, SqlExpr): self.add(expr) self.add(['AS', repr(field._name)]) elif isinstance(field, SqlAliasField): self.add(field) elif isinstance(field, SqlFuncOper): self.add(field) elif isinstance(field, SqlLiteral): self.add(field) elif isinstance(field, SqlExpr): self.add(field) else: self.drop() raise TypeError('SELECT ' + repr(field)) with self: self.add('FROM') if isinstance(table, str): self.add(table) elif isinstance(table, type): self.add(table_name(table)) elif isinstance(table, SqlStat): self.add('(') self.add(table) self.add(')') elif isinstance(table, SqlAlias) and isinstance(_table := table._value, type): self.add([table_name(_table), table._name]) elif isinstance(table, SqlAlias) and isinstance(subq := table._value, SqlSubQuery): self.add('(') self.add(subq.stat) self.add([')', 'AS', repr(table._name)]) elif isinstance(table, SqlCteExpr): self.add(table._name) else: raise TypeError('FROM ' + repr(table)) return self
def select_from_fields(*args) -> tuple[Union[type, SqlAlias, None], list[SqlExpr]]: if len(args) == 0: raise RuntimeError('empty field') table = None fields = [] for arg in args: if isinstance(arg, (int, float, bool, str)): fields.append(SqlLiteral(repr(arg))) elif isinstance(arg, Field): if table is None: table = arg.table fields.append(SqlField(arg)) elif isinstance(arg, SqlField): if table is None: table = arg.table fields.append(arg) elif isinstance(arg, SqlAlias) and isinstance(arg._value, SqlField): if table is None: table = arg._value.table fields.append(arg) elif isinstance(arg, SqlAliasField) and isinstance(field_table := arg.table, type): if table is None: table = SqlAlias(field_table, arg.name) fields.append(arg) elif isinstance(arg, SqlAliasField) and isinstance(cte := arg.table, SqlCteExpr): if table is None: table = cte fields.append(arg) elif isinstance(arg, SqlExpr): if table is None: table = use_table_first(arg) fields.append(arg) elif isinstance(arg, property): if (expr_table := getattr(arg.fget, '_sql_owner', None)) is None: raise RuntimeError(f'{arg} not a property from a Table') if table is None: table = expr_table if isinstance(expr := arg.fget(expr_table), SqlExpr): fields.append(expr) else: raise RuntimeError(f'{arg} does not return a sql expression.') else: raise TypeError(repr(arg)) return table, fields @overload def insert_into(table: type[T], *, policy: UPDATE_POLICY = None, named=False) -> SqlInsertStat[T]: pass @overload def insert_into(*field, policy: UPDATE_POLICY = None, named=False) -> SqlInsertStat[T]: pass
[docs] def insert_into(*args, policy: UPDATE_POLICY = None, named=False) -> SqlInsertStat[T]: """ ``INSERT``: https://www.sqlite.org/lang_insert.html insert values >>> insert_into(A, policy='REPLACE').build() # doctest: SKIP INSERT OR REPLACE INTO A VALUES (?) >>> insert_into(A, policy='REPLACE').submit([A(1), A(2)]) # doctest: SKIP insert values with field overwrite >>> insert_into(A, policy='REPLACE').values(a='1').build() # doctest: SKIP INSERT OR REPLACE INTO A VALUES (1) insert values from a table >>> insert_into(A, policy='IGNORE').select_from(B).build() # doctest: SKIP INSERT OR IGNORE INTO A SELECT * FROM B **features supporting** * `INSERT [OR ...]` * `VALUES` * `DEFAULT VALUES` * `SELECT` * upsert clause * returning clause **features not supporting** * `WITH [RECURSIVE]` :param table: :param policy: :param named: :return: """ if len(args) == 1 and isinstance(args[0], type): self = SqlInsertStat((table := args[0]), named=named) else: table, fields = select_from_fields(*args) if table is None: raise RuntimeError() for i, field in enumerate(list(fields)): if isinstance(field, SqlField): if field.table != table: raise RuntimeError(f'field {field.table_name}.{field.name} not belong to {table.__name__}') fields[i] = field.name else: raise TypeError() self = SqlInsertStat(table, fields, named=named) with self: self.add('INSERT') if policy is not None: self.add(['OR', policy.upper()]) self.add(['INTO', table_name(table)]) if self._fields is not None: self.add('(') for i, field in enumerate(self._fields): if i > 0: self.add(',') self.add(field) self.add(')') return self
@overload def replace_into(table: type[T], *, named=False) -> SqlInsertStat[T]: pass @overload def replace_into(*field: Any, named=False) -> SqlInsertStat[T]: pass
[docs] def replace_into(*args, named=False) -> SqlInsertStat[T]: return insert_into(*args, policy='REPLACE', named=named)
[docs] def update(table: type[T], *args: Union[bool, SqlCompareOper], **kwargs) -> SqlUpdateStat[T]: """ ``UPDATE``: https://www.sqlite.org/lang_update.html >>> update(A, A.a==1).where(A.b==2).build() # doctest: SKIP UPDATE A SET A.a = 1 WHERE A.b = 2 **features supporting** * `UPDATE [OR ...]` * `SET COLUMN = EXPR` * `FROM` * `WHERE` * `ON CONFLICT (COLUMNS) SET (COLUMNS) = EXPR` * returning clause **features not supporting** * `WITH [RECURSIVE]` * (qualified table name) `INDEXED BY` * (qualified table name) `NOT INDEXED` :param table: :param args: :param kwargs: :return: """ with SqlUpdateStat(table) as self: self.add(['UPDATE', table_name(table), 'SET']) if len(args): for arg in args: self.add(SqlCompareOper.as_set_expr(arg)) self.add(',') for term, value in kwargs.items(): table_field(table, term) self.add(SqlLiteral(term) == value) self.add(',') self._stat.pop() return self
[docs] def delete_from(table: type[T]) -> SqlDeleteStat[T]: """ ``DELETE``: https://www.sqlite.org/lang_delete.html >>> delete_from(A).where(A.b > 2).build() # doctest: SKIP DELETE FROM A WHERE A.b > 2 **features supporting** * `DELETE FROM` * `WHERE` * `ORDER BY` * `LIMIT [OFFSET]` * returning clause **features not supporting** * `WITH [RECURSIVE]` * (qualified table name) `INDEXED BY` * (qualified table name) `NOT INDEXED` :param table: :return: """ with SqlDeleteStat(table) as self: self.add(['DELETE', 'FROM', table_name(table)]) return self
[docs] def create_table(table: type[T], *, if_not_exists=True) -> SqlStat[T]: """ ``CREATE``: https://www.sqlite.org/lang_createtable.html >>> @named_tuple_table_class # doctest: SKIP ... class A(NamedTuple): ... a: int >>> create_table(A) # doctest: SKIP CREATE TABLE IF NOT EXISTS A (a INT NOT NULL) **features supporting** * `IF NOT EXISTS` * column constraint `NOT NULL` * column constraint `PRIMARY KEY` * column constraint `UNIQUE` * column constraint `CHECK` * column constraint `DEFAULT value` * table constraint `PRIMARY KEY` * table constraint `UNIQUE` * table constraint `CHECK` * table constraint `FOREIGN KEY` **features not supporting** * `CREATE TEMP|TEMPORARY` * `CREATE TEMP` * `AS SELECT` * column constraint `CONSTRAINT` * column constraint `NOT NULL ON CONFLICT` * column constraint `DEFAULT (EXPR)` * column constraint `COLLATE` * column constraint `REFERENCES` * column constraint `[GENERATED ALWAYS] AS` * table constraint `CONSTRAINT` * `WITHOUT ROWID` * `STRICT` :param table: :return: """ with SqlStat(table) as self: self.add(['CREATE', 'TABLE']) if if_not_exists: self.add('IF NOT EXISTS') self.add(table_name(table)) self.add('(') n_primary_key = len(primary_keys := table_primary_fields(table)) for i, field in enumerate(table_fields(table)): if n_primary_key == 1 and field.is_primary: column_def(self, field, field.get_primary()) else: column_def(self, field) self.add(',') if n_primary_key > 1: self.add(['PRIMARY KEY', '(', ' , '.join([it.name for it in primary_keys]), ')']) for it in primary_keys: if (conflict := it.get_primary().conflict) is not None: self.add(['ON CONFLICT', conflict.upper()]) break self.add(',') for unique in table_unique_fields(table): if len(unique.fields) > 1: self.add(['UNIQUE', '(', ' , '.join(unique.fields), ')']) if (conflict := unique.conflict) is not None: self.add(['ON CONFLICT', conflict.upper()]) self.add(',') for foreign_key in table_foreign_fields(table): foreign_constraint(self, foreign_key) self.add(',') if (check := table_check_field(table, None)) is not None: check_constraint(self, check) self.add(',') self._stat.pop() self.add(')') return self
def column_def(self: SqlStat, field: Field, primary: PRIMARY = None): self.add(f"[{field.name}]") # union types handle sql_type = field.sql_type if get_origin(sql_type) in (Union, UnionType): union_args = get_args(sql_type) non_none_types = [arg for arg in union_args if arg is not type(None)] if non_none_types: sql_type = non_none_types[0] if sql_type == Any: pass elif sql_type == int: self.add('INTEGER') elif sql_type == float: self.add('FLOAT') elif sql_type == bool: self.add('BOOLEAN') elif sql_type == bytes: self.add('BLOB') elif sql_type == str: self.add('TEXT') elif sql_type == datetime.time: self.add('DATETIME') elif sql_type == datetime.date: self.add('DATETIME') elif sql_type == datetime.datetime: self.add('DATETIME') else: raise RuntimeError(f'field type {field.sql_type}') if field.not_null: if not field.has_default or field.f_value is not None: self.add('NOT NULL') if primary is not None: self.add(['PRIMARY', 'KEY']) if primary.order is not None: self.add(primary.order.upper()) if primary.conflict is not None: self.add(['ON CONFLICT', primary.conflict.upper()]) if primary.auto_increment: self.add('AUTOINCREMENT') elif (unique := field.get_unique()) is not None: self.add('UNIQUE') if unique.conflict is not None: self.add(['ON CONFLICT', unique.conflict.upper()]) from .table import missing if field.f_value is missing: pass elif field.f_value is None: self.add('DEFAULT NULL') elif field.f_value == CURRENT_DATE: self.add('DEFAULT CURRENT_DATE') elif field.f_value == CURRENT_TIME: self.add('DEFAULT CURRENT_TIME') elif field.f_value == CURRENT_TIMESTAMP: self.add('DEFAULT CURRENT_TIMESTAMP') else: self.add(f'DEFAULT {repr(field.f_value)}') if (check := table_check_field(field.table, field.name)) is not None: check_constraint(self, check) def foreign_constraint(self: SqlStat, foreign: ForeignConstraint): self.add(['FOREIGN KEY']) self.add('(') self.add(' , '.join(foreign.fields)) self.add(')') self.add('REFERENCES') self.add(table_name(foreign.foreign_table)) self.add('(') self.add(' , '.join(foreign.foreign_fields)) self.add(')') if (policy := foreign.on_update) is not None: self.add(['ON UPDATE', policy]) if (policy := foreign.on_delete) is not None: self.add(['ON DELETE', policy]) def check_constraint(self: SqlStat, check: CheckConstraint): self._deparameter = True self.add(['CHECK', '(']) self.add(SqlRemovePlaceHolder(check.expression)) self.add(')') self._deparameter = False