from __future__ import annotations
import functools
import sqlite3
import warnings
from collections.abc import Iterator
from typing import overload, TYPE_CHECKING, Any, TypeVar, Generic, Optional, Literal, Union, cast, Callable
import polars as pl
from typing_extensions import Self
from .expr import *
from .expr import SqlStatBuilder, SqlRemovePlaceHolder
from .table import *
if TYPE_CHECKING:
    # Only needed for annotations; the import is skipped at runtime.
    from .connection import Connection

# Public API of this module.
__all__ = [
    'SqlStat', 'SqlSelectStat', 'SqlInsertStat',
    'SqlUpdateStat', 'SqlDeleteStat', 'Cursor',
]

# Row type of a statement / cursor (the table class bound to it).
T = TypeVar('T')
# Type of a table passed to ``SqlSelectStat.join``.
S = TypeVar('S')
def catch_error(f=None, *, attr: str = None):
    """
    Decorator running a method inside the ``with`` block of a statement.

    The statement is ``self`` itself, or ``getattr(self, attr)`` when *attr* is
    given. ``SqlStat.__exit__`` drops the statement when an exception escapes,
    so a builder call that fails does not leave a half-built statement behind
    to be auto-submitted later. The exception itself still propagates.

    Usable bare (``@catch_error``) or with arguments (``@catch_error(attr='_stat')``).
    """

    def decorate(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            stat = self if attr is None else getattr(self, attr)
            with stat:
                return func(self, *args, **kwargs)

        return wrapper

    return decorate if f is None else decorate(f)
class SqlStat(Generic[T]):
    """
    Abstract SQL statement.

    Tokens are accumulated with :meth:`add` and turned into SQL by :meth:`build`.
    A statement bound to a connection submits itself when it is garbage
    collected (see :meth:`__del__`), unless it was already submitted, built or
    dropped.
    """

    def __init__(self, table: Optional[type[T]]):
        """
        :param table: table class the result rows are cast to, or ``None``.
        """
        self.table = table
        # SQL tokens; ``None`` once the statement is built or dropped (closed).
        self._stat: Optional[list[Any]] = []
        from .connection import get_connection_context
        # connection returned by get_connection_context(); reset to ``None`` by
        # submit()/drop() so the statement is not executed twice.
        self._connection: Optional[Connection] = get_connection_context()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Drop a statement whose construction failed so that __del__ does not
        # submit it; returning None lets the exception propagate.
        if exc_val is not None:
            self.drop()

    def build(self) -> tuple[str, list]:
        """
        Build the SQL statement and close it.

        :return: the SQL text and the builder's parameter list (``builder.para``).
        """
        builder = SqlStatBuilder()
        builder.add(self)
        # a built statement is closed: no more tokens, no auto-submit in __del__
        self._stat = None
        return builder.build(), builder.para

    def __str__(self) -> str:
        table = table_name(self.table) if self.table is not None else '...'
        # the runtime class name; ``type(Self)`` would name the typing special form
        return f'{type(self).__name__}[{table}]'

    def __repr__(self) -> str:
        # render the SQL text without closing the statement (unlike build())
        builder = SqlStatBuilder()
        builder._deparameter = 'repr'
        builder.add(self)
        return builder.build()

    def submit(self) -> Cursor[T]:
        """
        Build the SQL statement and execute it.

        The statement is detached from its connection afterwards, so
        :meth:`__del__` does not submit it again.

        :return: a cursor over the results.
        :raise RuntimeError: current statement does not bind with a connection.
        """
        if (connection := self._connection) is None:
            raise RuntimeError('Do not in a connection context')
        ret = Cursor(connection, connection.execute(self), self.table)
        self._connection = None
        return ret

    def drop(self):
        """Discard the statement: detach it from its connection and close it."""
        self._connection = None
        self._stat = None

    def __del__(self):
        # Auto-submit: a statement still bound to a connection and holding
        # tokens submits itself when nobody refers to it any more, so users do
        # not need to call submit() explicitly for every statement.
        # FIXME errors raised while constructing a statement should be taken into account.
        # getattr(): __init__ may have raised before these attributes were set.
        if getattr(self, '_connection', None) is not None and getattr(self, '_stat', None):
            try:
                self.submit()
            except BaseException as e:
                # exceptions cannot propagate out of __del__; report them instead
                warnings.warn(repr(e))

    def add(self, stat) -> Self:
        """
        Add SQL token.

        A nested :class:`SqlStat` is detached from its connection so it is only
        executed as part of this statement.

        :return: this statement.
        :raise RuntimeError: the statement is closed (built or dropped).
        """
        if self._stat is None:
            raise RuntimeError("Statement is closed.")
        self._stat.append(stat)
        if isinstance(stat, SqlStat):
            stat._connection = None
        return self
class Cursor(Generic[T]):
    """
    A SQL cursor wrapper.

    When constructed with a table class, rows are converted into instances of
    that table through its ``table_new``; otherwise rows are returned as the
    sqlite3 cursor produces them.
    """

    def __init__(self, connection: Connection, cursor: sqlite3.Cursor, table: Optional[type[T]] = None):
        """
        :param connection: connection the cursor belongs to.
        :param cursor: the executed sqlite3 cursor.
        :param table: table class to convert rows into; ``None`` keeps raw rows.
        """
        self._connection = connection
        self._cursor = cursor
        if table is not None:
            from .table import table_class
            try:
                table_cls = table_class(table)
            except AttributeError:
                # not a table class: leave rows unconverted
                pass
            else:
                cursor.row_factory = lambda _, row: table_cls.table_new(*row)

    @property
    def headers(self) -> list[str]:
        """Column names of the result; empty when sqlite3 reports no description."""
        header = self._cursor.description
        return [it[0] for it in header] if header is not None else []

    def __del__(self):
        # getattr(): __init__ may have raised before _cursor was set
        cursor = getattr(self, '_cursor', None)
        if cursor is not None:
            cursor.close()
        self._connection = None

    def fetchall(self) -> list[T]:
        """fetch all results."""
        return list(self)

    def fetchone(self) -> Optional[T]:
        """fetch the next result, or ``None`` when there is none."""
        return self._cursor.fetchone()

    def __iter__(self) -> Iterator[T]:
        """iterate the results."""
        yield from self._cursor

    def fetch_polars(self) -> pl.DataFrame:
        """
        Fetch all remaining results as a polars DataFrame named by :attr:`headers`.

        ``orient='row'`` is explicit: every fetched item is one row. Without it,
        polars guesses the orientation and reads the data as columns when the
        number of rows equals the number of columns, transposing the result.
        """
        return pl.DataFrame(list(self._cursor), schema=self.headers, orient='row')
class SqlWhereStat:
    """statement with **WHERE** support."""

    @catch_error
    def where(self, *expr: Union[bool, SqlCompareOper, SqlExpr, None]) -> Self:
        """
        ``WHERE`` clause: https://www.sqlite.org/lang_select.html#whereclause

        The conditions are combined with ``and_``; ``None`` items are skipped and
        no clause is emitted when nothing remains.

        >>> select_from(A).where(  # doctest: +SKIP
        ...     A.a == 1, A.b == 2
        ... ).build()
        SELECT * FROM A
        WHERE (A.a = 1) AND (A.b = 2)

        :param expr: conditions; ``None`` entries are ignored.
        :return: this statement.
        """
        zelf = cast(SqlStat, self)
        conditions = [it for it in expr if it is not None]
        if conditions:
            from .func_stat import and_
            zelf.add('WHERE')
            zelf.add(and_(*conditions))
        return self
class SqlLimitStat:
    """statement with **LIMIT** and **ORDER BY** support."""

    @overload
    def limit(self, n: int) -> Self:
        pass

    @overload
    def limit(self, row_count: int, offset: int) -> Self:
        pass

    @catch_error
    def limit(self, *args: int) -> Self:
        """
        ``LIMIT``: https://www.sqlite.org/lang_select.html#limitoffset

        >>> select_from(A).limit(10).build()  # doctest: +SKIP
        SELECT * FROM A LIMIT 10
        >>> select_from(A).limit(10, 10).build()  # doctest: +SKIP
        SELECT * FROM A LIMIT 10 OFFSET 10

        **NOTE**

        LIMIT on UPDATE/DELETE need compile flag ``SQLITE_ENABLE_UPDATE_DELETE_LIMIT``.

        >>> assert Connection().sqlite_compileoption_used('SQLITE_ENABLE_UPDATE_DELETE_LIMIT')  # doctest: +SKIP

        :raise TypeError: neither one nor two arguments are given.
        """
        zelf = cast(SqlStat, self)
        if len(args) == 1:
            n, = args
            zelf.add(['LIMIT', str(n)])
        elif len(args) == 2:
            row_count, offset = args
            zelf.add(['LIMIT', str(row_count), 'OFFSET', str(offset)])
        else:
            raise TypeError(f'limit() takes 1 or 2 arguments ({len(args)} given)')
        return self

    @catch_error
    def order_by(self, *by: Union[int, str, SqlExpr, Any]) -> Self:
        """
        ``ORDER BY``: https://www.sqlite.org/lang_select.html#orderby

        >>> select_from(A).order_by(A.a).build()  # doctest: +SKIP
        SELECT * FROM A ORDER BY A.a

        **possible ordering**

        >>> select_from(A).order_by(  # doctest: +SKIP
        ...     asc(A.a), desc(A.b), nulls_first(A.c), asc(A.d).nulls_last(),
        ... )
        SELECT * FROM A ORDER BY
        A.a ASC, A.b DESC, A.c NULLS FIRST, A.d ASC NULLS LAST

        Integers and strings are emitted verbatim; an alias is referred to by
        its name. Calling without terms is a no-op.

        **NOTE**

        ORDER BY on UPDATE/DELETE need compile flag ``SQLITE_ENABLE_UPDATE_DELETE_LIMIT``.

        >>> assert Connection().sqlite_compileoption_used('SQLITE_ENABLE_UPDATE_DELETE_LIMIT')  # doctest: +SKIP
        """
        if not by:
            # a bare ``ORDER BY`` would be invalid SQL
            return self
        zelf = cast(SqlStat, self)
        zelf.add('ORDER BY')
        fields = []
        for it in by:
            if isinstance(it, int):
                fields.append(SqlLiteral(str(it)))
            elif isinstance(it, str):
                fields.append(SqlLiteral(it))
            elif isinstance(it, SqlAlias):
                fields.append(SqlLiteral(it._name))
            elif isinstance(it, SqlExpr):
                fields.append(it)
            else:
                fields.append(wrap(it))
        zelf.add(SqlConcatOper(fields, ','))
        return self
class SqlSelectStat(SqlStat[T], SqlWhereStat, SqlLimitStat, Generic[T]):
    """**SELECT** statement."""

    def __init__(self, table: Optional[type[T]]):
        super().__init__(table)
        # items referenced by this statement; the join helpers append to it and
        # foreign-key joins look up the already-involved side here.
        self._involved: list[Union[type, SqlAlias[type]]] = []

    def __matmul__(self, other: str) -> SqlAlias[SqlSubQuery]:
        """wrap itself as a subquery with an alias name."""
        return SqlAlias(SqlSubQuery(self), other)

    def fetchall(self) -> list[T]:
        """submit and fetch all result."""
        return self.submit().fetchall()

    def fetchone(self) -> Optional[T]:
        """submit and fetch the first result."""
        return self.submit().fetchone()

    def __iter__(self) -> Iterator[T]:
        """submit and iterate the results"""
        return iter(self.submit())

    def fetch_polars(self) -> pl.DataFrame:
        """submit and fetch all as a polar dataframe."""
        return self.submit().fetch_polars()

    @catch_error
    def windows(self, *windows: Union[SqlWindowDef, SqlAlias[SqlWindowDef]], **window_ks: SqlWindowDef) -> Self:
        """
        ``WINDOW`` clause: define named windows.

        Positional windows are a named :class:`SqlWindowDef` or an
        :class:`SqlAlias` wrapping one; keyword windows are named by their keyword.

        :raise RuntimeError: a positional SqlWindowDef has no name.
        :raise TypeError: an argument is not a window definition.
        """
        if len(windows) == 0 and len(window_ks) == 0:
            return self
        self.add('WINDOW')
        for window in windows:
            if isinstance(window, SqlWindowDef):
                if window.name is None:
                    raise RuntimeError('?? AS ' + repr(window))
                self.add([window.name, 'AS'])
                self.add(window)
            elif isinstance(window, SqlAlias) and isinstance(window._value, SqlWindowDef):
                # SqlAlias keeps its name/value in _name/_value; emit
                # ``name AS definition`` rather than the alias object itself.
                self.add([window._name, 'AS'])
                self.add(window._value)
            else:
                raise TypeError(f'WINDOW {window!r}')
            self.add(',')
        for name, window in window_ks.items():
            self.add([name, 'AS'])
            self.add(window)
            self.add(',')
        self._stat.pop()  # trailing ','
        return self

    BY = Literal['left', 'right', 'inner', 'full outer', 'cross']

    @overload
    def join(self, constraint: Union[Callable, ForeignConstraint], *, by: BY = None) -> SqlSelectStat[tuple]:
        pass

    @overload
    def join(self, table: Union[type[S], SqlAlias[S]],
             constraint: Union[Callable, ForeignConstraint], *,
             by: BY = None) -> SqlSelectStat[tuple]:
        pass

    @overload
    def join(self, table: Union[type[S], SqlSelectStat[S], SqlAlias[S], SqlCteExpr],
             *field: Union[bool, Any],
             by: BY = None) -> SqlSelectStat[tuple]:
        pass

    @overload
    def join(self, *field: bool | Any, by: BY = None) -> SqlSelectStat[tuple]:
        pass

    @catch_error
    def join(self, *args, by: BY = None) -> SqlSelectStat[tuple]:
        """
        ``JOIN`` https://www.sqlite.org/lang_select.html#strange_join_names

        >>> select_from(A.a, B.b).join(A.a == B.a)  # doctest: +SKIP
        SELECT A.a, B.b FROM A
        JOIN B ON A.a = B.a

        Accepted forms:

        * ``join(table_or_alias, constraint)``: join through a foreign constraint
          (a ForeignConstraint, or a callable resolved by ``table_foreign_field``);
        * ``join(constraint)``: join the other side of a foreign constraint;
        * ``join(table | alias | cte | select, *fields)``: ``ON`` when any field is a
          comparison, ``USING`` otherwise;
        * ``join(*comparisons)``: the joined table is found with ``use_table``.

        :param by: join operator prefix (``left``, ``right``, ``inner``, ``full outer``, ``cross``).
        :raise RuntimeError: not a foreign constraint, or no join table found.
        """
        if by is not None:
            self.add(by.upper())
        self.add('JOIN')
        # joined rows no longer map onto a single table class
        self.table = None

        if len(args) == 2 and isinstance(table := args[0], type) \
                and (callable(constraint := args[1]) or isinstance(constraint, ForeignConstraint)):
            self.__join_foreign(self._as_foreign_constraint(constraint), table)
        elif len(args) > 0 and isinstance(table := args[0], type):
            self.__join(table, *args[1:])
        elif len(args) > 0 and isinstance(expr := args[0], SqlCteExpr):
            # the CTE definition goes in front of the whole statement
            self._stat.insert(0, expr)
            self.__join(expr, *args[1:])
        elif len(args) > 0 and isinstance(stat := args[0], SqlSelectStat):
            self.add(stat)
            self.__join(None, *args[1:])
        elif len(args) == 2 and isinstance(table := args[0], SqlAlias) \
                and (callable(constraint := args[1]) or isinstance(constraint, ForeignConstraint)):
            self.__join_foreign(self._as_foreign_constraint(constraint), table)
        elif len(args) > 0 and isinstance(table := args[0], SqlAlias) \
                and isinstance(table._value, type) and isinstance(table._name, str):
            self.__join(table, *args[1:])
        elif len(args) > 0 and isinstance(table := args[0], SqlAlias) \
                and isinstance(expr := table._value, SqlCteExpr) and isinstance(table._name, str):
            self._stat.insert(0, expr)
            self.__join(table, *args[1:])
        elif len(args) > 0 and isinstance(table := args[0], SqlAlias) \
                and isinstance(subq := table._value, SqlSubQuery) \
                and isinstance(name := table._name, str) \
                and isinstance(stat := subq.stat, SqlSelectStat):
            self.add(stat)
            self.add('AS')
            self.add(name)
            self.__join(None, *args[1:])
        elif len(args) == 1 and (isinstance(args[0], ForeignConstraint) or callable(args[0])):
            self.__join_foreign(self._as_foreign_constraint(args[0]), None)
        else:
            table = None
            for field in args:
                if isinstance(field, SqlExpr):
                    table = use_table(field, self._involved)
                    if table is not None:
                        break
            if table is None:
                raise RuntimeError('no join table')
            self.__join(table, *args)
        return self

    @staticmethod
    def _as_foreign_constraint(constraint) -> ForeignConstraint:
        """Return *constraint* as a ForeignConstraint, resolving callables with ``table_foreign_field``."""
        if isinstance(constraint, ForeignConstraint):
            return constraint
        if (resolved := table_foreign_field(constraint)) is None:
            raise RuntimeError('not a foreign constraint')
        return resolved

    @staticmethod
    def _table_ref(item) -> Optional[tuple[type, str]]:
        """``(table class, SQL name referring to it)`` for a table class or an aliased table class, else ``None``."""
        if isinstance(item, type):
            return item, table_name(item)
        if isinstance(item, SqlAlias) and isinstance(item._value, type):
            return item._value, item._name
        return None

    def __join(self, right: Union[type, SqlAlias, SqlCteExpr, None], *fields):
        """
        Emit the joined item (``None``: already emitted by the caller), then its
        ``ON``/``USING`` constraint, and record *right* as involved.
        """
        if right is None:
            pass
        elif isinstance(right, type):
            self.add(table_name(right))
        elif isinstance(right, SqlCteExpr):
            self.add(right._name)
        elif isinstance(right, SqlAlias) and isinstance(table := right._value, type):
            self.add(table_name(table))
            self.add(right._name)
        elif isinstance(right, SqlAlias) and isinstance(expr := right._value, SqlCteExpr):
            self.add(expr._select)
            self.add(right._name)
        else:
            raise TypeError(f'JOIN {right}')
        if len(fields):
            if any(isinstance(it, SqlCompareOper) for it in fields):
                self.__join_on(*fields)
            else:
                self.__join_use(*fields)
        self._involved.append(right)

    def __join_use(self, *fields):
        """Emit ``USING (f1, f2, ...)``."""
        self.add(['USING', '('])
        for field in fields:
            # one branch per field; previously str/Field fell through to the
            # final ``else`` and raised TypeError
            if isinstance(field, str):
                self.add(repr(field))
            elif isinstance(field, (Field, SqlField)):
                self.add(field.name)
            else:
                raise TypeError('USING ' + repr(field))
            self.add(',')
        self._stat.pop()  # trailing ','
        self.add(')')

    def __join_on(self, *exprs: bool | SqlExpr):
        """Emit ``ON`` with the expressions combined by ``and_``."""
        self.add('ON')
        from .func_stat import and_
        self.add(and_(*exprs))

    def __join_foreign(self, constraint: ForeignConstraint, right: type | SqlAlias | None):
        """
        Join along a foreign constraint.

        One side of *constraint* must already be involved in the statement; when
        *right* is ``None`` the other side is joined. Emits ``USING (...)`` when
        both sides use the same column names, ``ON (...)`` otherwise.

        :raise RuntimeError: no involved table takes part in the constraint.
        :raise TypeError: *right* is not a table class or an aliased table class.
        """
        # SQL names referring to constraint.table (this) and constraint.foreign_table (that)
        this = that = None
        for involved in self._involved:
            ref = self._table_ref(involved)
            if ref is None:
                continue
            table, name = ref
            if table == constraint.table:
                this = name
                if right is None:
                    right = constraint.foreign_table
                break
            if table == constraint.foreign_table:
                that = name
                if right is None:
                    right = constraint.table
                break
        else:
            raise RuntimeError('improper foreign constraint')

        ref = self._table_ref(right)
        if ref is None:
            raise TypeError(f'JOIN {right}')
        right_table, right_name = ref
        self.add(table_name(right_table))
        if isinstance(right, SqlAlias):
            self.add(right_name)
        if this is None and right_table == constraint.table:
            this = right_name
        if that is None and right_table == constraint.foreign_table:
            that = right_name
        if this is None or that is None:
            raise RuntimeError('improper foreign constraint')

        if constraint.fields == constraint.foreign_fields:
            self.add(['USING', '('])
            for field in constraint.fields:
                self.add(field)
                self.add(',')
            self._stat.pop()  # trailing ','
            self.add(')')
        else:
            self.add(['ON', '('])
            for af, bf in zip(constraint.fields, constraint.foreign_fields):
                self.add([f'{this}.{af} = {that}.{bf}'])
                self.add('AND')
            self._stat.pop()  # trailing 'AND'
            self.add(')')
        self._involved.append(right)

    @catch_error
    def group_by(self, *by) -> Self:
        """
        ``GROUP BY`` https://www.sqlite.org/lang_select.html#resultset

        :raise RuntimeError: no grouping term is given.
        :raise TypeError: unsupported grouping term.
        """
        if len(by) == 0:
            raise RuntimeError('GROUP BY requires at least one term')
        self.add('GROUP BY')
        for field in by:
            if isinstance(field, (int, float, bool)):
                self.add(str(field))
            elif isinstance(field, str):
                self.add(repr(field))
            elif isinstance(field, Field):
                self.add(f'{field.table_name}.{field.name}')
            elif isinstance(field, SqlField):
                self.add(f'{field.table_name}.{field.name}')
            elif isinstance(field, SqlAlias):
                self.add(field._name)
            elif isinstance(field, SqlExpr):
                self.add(field)
            else:
                raise TypeError('GROUP BY ' + repr(field))
            self.add(',')
        self._stat.pop()  # trailing ','
        return self

    @catch_error
    def having(self, *exprs: Union[bool, SqlExpr]) -> Self:
        """
        ``HAVING`` https://www.sqlite.org/lang_select.html#resultset

        No clause is emitted when no expression is given.
        """
        if len(exprs) == 0:
            return self
        from .func_stat import and_
        self.add('HAVING')
        self.add(and_(*exprs))
        return self

    def _compound(self, stat: SqlStat, *keywords: str) -> Self:
        """
        Append the compound-select *keywords* followed by the tokens of *stat*.

        *stat* is consumed: detached from its connection and closed, so it is
        neither submitted on its own nor reused.

        :raise RuntimeError: *stat* is already closed.
        """
        if stat._stat is None:
            raise RuntimeError("Statement is closed.")
        for keyword in keywords:
            self.add(keyword)
        self._stat.extend(stat._stat)
        stat._connection = None
        stat._stat = None
        return self

    @catch_error
    def intersect(self, stat: SqlStat) -> Self:
        """
        ``INTERSECT`` https://www.sqlite.org/lang_select.html#compound_select_statements
        """
        return self._compound(stat, 'INTERSECT')

    def __and__(self, other: SqlStat) -> Self:
        """
        ``INTERSECT`` https://www.sqlite.org/lang_select.html#compound_select_statements
        """
        return self.intersect(other)

    @catch_error
    def union(self, stat: SqlStat, all=False) -> Self:
        """
        ``UNION`` (``UNION ALL`` when *all*) https://www.sqlite.org/lang_select.html#compound_select_statements
        """
        if all:
            return self._compound(stat, 'UNION', 'ALL')
        return self._compound(stat, 'UNION')

    def __or__(self, other: SqlStat) -> Self:
        """
        ``UNION`` https://www.sqlite.org/lang_select.html#compound_select_statements
        """
        return self.union(other)

    @catch_error
    def except_(self, stat: SqlStat) -> Self:
        """
        ``EXCEPT`` https://www.sqlite.org/lang_select.html#compound_select_statements
        """
        return self._compound(stat, 'EXCEPT')

    def __sub__(self, other: SqlStat) -> Self:
        """
        ``EXCEPT`` https://www.sqlite.org/lang_select.html#compound_select_statements
        """
        return self.except_(other)
class SqlReturnStat:
    """statement with **RETURNING** support."""

    @catch_error
    def returning(self, *expr: Union[str, Any]) -> Self:
        """
        ``RETURNING`` clause: https://www.sqlite.org/lang_returning.html

        Without arguments ``RETURNING *`` is emitted. Items may be raw column
        strings, SqlField (by bare name), or an SqlAlias of a SqlField / SqlExpr.

        The result rows are no longer table instances: for INSERT the return
        table is cleared and the returned column names are recorded in
        ``_fields`` (``None`` for ``*``); otherwise ``table`` is cleared.

        :raise TypeError: unsupported item.
        """
        zelf = cast(SqlStat, self)
        zelf.add('RETURNING')
        # returned column names; stays None for ``RETURNING *`` (previously
        # this name was unbound in that case and raised UnboundLocalError)
        fields: Optional[list[str]] = None
        if len(expr) == 0:
            zelf.add('*')
        else:
            fields = []
            for exp in expr:
                if isinstance(exp, str):
                    zelf.add(exp)
                    fields.append(exp)
                elif isinstance(exp, SqlField):
                    zelf.add(exp.name)
                    fields.append(exp.name)
                elif isinstance(exp, SqlAlias) and isinstance(field := exp._value, SqlField):
                    zelf.add([field.name, 'AS', exp._name])
                    fields.append(exp._name)
                elif isinstance(exp, SqlAlias) and isinstance(_expr := exp._value, SqlExpr):
                    zelf.add(SqlRemovePlaceHolder(_expr))
                    zelf.add(['AS', exp._name])
                    fields.append(exp._name)
                else:
                    raise TypeError(f'RETURNING ({exp})')
                zelf.add(',')
            zelf._stat.pop()  # trailing ','
        if isinstance(self, SqlInsertStat):
            self._return_table = None
            self._fields = fields
        else:
            self.table = None
        return self
class SqlInsertStat(SqlStat[T], SqlReturnStat, Generic[T]):
    """**INSERT** statement."""

    def __init__(self, table: type[T], fields: Optional[list[str]] = None, *, named: bool = False):
        """
        :param table: target table class.
        :param fields: columns listed in ``VALUES``; all table fields when ``None``.
        :param named: emit ``:field`` placeholders instead of ``?``.
        """
        super().__init__(table)
        self._fields = fields
        # fields bound to a placeholder, in emission order; filled by _set_values()
        self._use_fields: list[str] = []
        # None: the VALUES / DEFAULT VALUES clause has already been emitted
        # 'DEFAULT': DEFAULT VALUES requested, not emitted yet
        # {field->value}: VALUES not emitted yet ('?' literal = placeholder)
        self._values: Union[Literal['DEFAULT'], dict[str, Any], None] = {
            name: SqlLiteral('?')
            for name in table_field_names(table)
        }
        self._named = named
        self._returning = False
        self._return_table = table

    @overload
    def select_from(self, table: type[T], *, distinct: bool = False) -> SqlSelectStat[T]:
        pass

    @overload
    def select_from(self, *field, distinct: bool = False,
                    from_table: Union[str, type, SqlAlias, SqlSelectStat] = None) -> SqlSelectStat[tuple]:
        pass

    @catch_error
    def select_from(self, *args, distinct: bool = False,
                    from_table: Union[str, type, SqlAlias, SqlSelectStat] = None) -> SqlSelectStat[T]:
        """
        Continue as ``INSERT INTO ... (columns) SELECT ...``.

        The column list is the ``fields`` given at construction, or all table
        fields. The returned SELECT statement starts with this statement's tokens.

        :raise RuntimeError: the values were already emitted or ``defaults()`` was requested.
        """
        if self._values is None or isinstance(self._values, str):
            raise RuntimeError('INSERT ... SELECT cannot follow VALUES or DEFAULT VALUES')
        columns = list(self._values) if self._fields is None else self._fields
        # only a prefix of the returned SELECT: never submit this statement itself
        self._connection = None
        self.add('(')
        for i, field in enumerate(columns):
            if i > 0:
                self.add(',')
            self.add(field)
        self.add(')')
        from .stat_start import select_from
        ret = select_from(*args, distinct=distinct, from_table=from_table)
        ret._stat = [*self._stat, *ret._stat]
        return ret

    @catch_error
    def values(self, *args, **kwargs: Union[str, SqlExpr]) -> Self:
        """
        Set explicit values for some fields instead of a placeholder.

        Positional arguments are ``Table.field == value`` comparisons; keyword
        arguments map field names to values.

        :raise RuntimeError: values already emitted, ``defaults()`` requested, or an unknown keyword field.
        :raise TypeError: a positional argument is not a ``field == value`` comparison.
        """
        # TODO support direct set values
        if self._values is None or isinstance(self._values, str):
            raise RuntimeError()
        for expr in args:
            if isinstance(expr, SqlCompareOper) and expr.oper == '=' and isinstance(expr.left, SqlField):
                self._values[expr.left.field.name] = expr.right
            else:
                raise TypeError(expr)
        for field, expr in kwargs.items():
            if field not in self._values:
                raise RuntimeError(f'{table_name(self.table)}.{field} not found')
            self._values[field] = expr
        return self

    def defaults(self) -> Self:
        """insert default values"""
        self._values = 'DEFAULT'
        return self

    @staticmethod
    def _sql_literal(value) -> str:
        """
        Render a Python constant as an SQL literal.

        Strings are single-quoted with embedded quotes doubled. Python's ``repr``
        is not SQL quoting: it may choose double quotes (identifiers in SQL) or
        backslash escapes (not interpreted by SQL).
        """
        if value is None:
            return 'NULL'
        if isinstance(value, str):
            return "'" + value.replace("'", "''") + "'"
        if isinstance(value, bytes):
            return "X'" + value.hex() + "'"
        return repr(value)

    @catch_error
    def _set_values(self):
        """
        Emit ``DEFAULT VALUES`` or ``VALUES (...)``; calls after the first do nothing.

        Placeholder fields are emitted as ``?`` (``:field`` when ``named``) and
        recorded, in order, in ``_use_fields``.
        """
        if isinstance(self._values, str):
            self.add(['DEFAULT', 'VALUES'])
        elif isinstance(self._values, dict):
            self._use_fields = []
            fields = list(self._values.keys()) if self._fields is None else self._fields
            self.add(['VALUES', '('])
            for field in fields:
                value = self._values[field]
                if isinstance(value, (int, float, bool, str)):
                    self.add(self._sql_literal(value))
                elif isinstance(value, SqlLiteral):
                    if value.value == '?':
                        self.add(f':{field}' if self._named else '?')
                        self._use_fields.append(field)
                    else:
                        self.add(value.value)
                elif isinstance(value, SqlPlaceHolder):
                    self.add(self._sql_literal(value.value))
                elif isinstance(value, SqlStat):
                    self.add(value)
                elif isinstance(value, SqlExpr):
                    self.add(SqlRemovePlaceHolder(value))
                else:
                    raise TypeError(repr(value))
                self.add(',')
            self._stat.pop()  # trailing ','
            self.add(')')
        self._values = None

    @catch_error
    def on_conflict(self, *conflict, where: Union[bool, SqlCompareOper] = None) -> SqlUpsertStat[T]:
        """Emit the values, then start an ``ON CONFLICT`` clause."""
        self._set_values()
        return SqlUpsertStat(self, *conflict, where=where)

    @catch_error
    def returning(self, *expr: Union[str, SqlExpr]) -> SqlInsertStat[tuple]:
        """Emit the values, then a ``RETURNING`` clause (at most one row may be submitted)."""
        self._set_values()
        super().returning(*expr)
        self._returning = True
        return self

    def build(self) -> tuple[str, list]:
        self._set_values()
        return super().build()

    def submit(self, parameter: list[T] = ()) -> Cursor[T]:
        """
        Build and execute the INSERT.

        :param parameter: rows bound to the placeholders: table instances, or
            sequences (``?``) / mappings (``:field``) of values.
        :return: a cursor (over the RETURNING rows, if any).
        :raise RuntimeError: not bound to a connection, or several rows with RETURNING.
        """
        if (connection := self._connection) is None:
            raise RuntimeError('Do not in a connection context')
        self._set_values()
        from .table import table_class
        try:
            table_cls = table_class(self.table)
        except AttributeError:
            pass
        else:
            # take values for exactly the placeholder fields, in placeholder order
            if self._named:
                def mapper(p):
                    if isinstance(p, self.table):
                        return dict(zip(self._use_fields, table_cls.table_seq(p, self._use_fields)))
                    return dict(p)
            else:
                def mapper(p):
                    if isinstance(p, self.table):
                        return table_cls.table_seq(p, self._use_fields)
                    return tuple(p)
            parameter = list(map(mapper, parameter))
        if len(parameter):
            if self._returning:
                if len(parameter) > 1:
                    raise RuntimeError('only support return one data')
                cur = connection.execute(self, parameter[0])
            else:
                cur = connection.execute_batch(self, parameter)
        else:
            cur = connection.execute(self, parameter)
        ret = Cursor(connection, cur, self._return_table)
        self._connection = None
        return ret
class SqlUpsertStat(Generic[T]):
    """
    Upsert (``ON CONFLICT``) clause of an INSERT statement.

    https://www.sqlite.org/syntax/upsert-clause.html
    """

    def __init__(self, stat: SqlInsertStat[T], *conflict, where: Union[bool, SqlCompareOper] = None):
        """
        Append ``ON CONFLICT [(columns) [WHERE cond]]`` to *stat*.

        :param stat: the INSERT statement being extended.
        :param conflict: conflict-target columns of ``stat.table``.
        :param where: conflict-target condition; requires *conflict*.
        :raise RuntimeError: a column belongs to another table or is unsupported,
            or *where* is given without conflict columns.
        """
        self._stat = stat
        self._conflict = conflict
        self._stat.add(['ON', 'CONFLICT'])
        self._do_where = False
        if len(conflict):
            from .stat_start import select_from_fields
            table, fields = select_from_fields(*conflict)
            if table is None:
                table = stat.table
            elif table != stat.table:
                raise RuntimeError(f'ON CONFLICT columns must belong to {table_name(stat.table)}')
            self._stat.add('(')
            for i, field in enumerate(fields):
                if i > 0:
                    self._stat.add(',')
                if isinstance(field, SqlField):
                    if field.table != table:
                        raise RuntimeError(f'field {field.table_name}.{field.name} not belong to {table.__name__}')
                    self._stat.add(f'{field.name}')
                elif isinstance(field, SqlLiteral) and isinstance(field.value, str):
                    self._stat.add(field.value)
                else:
                    raise RuntimeError(f'ON CONFLICT ({field}:{type(field).__name__})')
            self._stat.add(')')
            # the conflict-target condition follows the column list:
            # ON CONFLICT (col, ...) WHERE cond DO ...
            if where is not None:
                self._stat.add('WHERE')
                self._stat.add(SqlRemovePlaceHolder(where))
        elif where is not None:
            raise RuntimeError('ON CONFLICT WHERE requires conflict columns')

    @catch_error(attr='_stat')
    def do_nothing(self) -> SqlInsertStat[T]:
        """``DO NOTHING``; returns the INSERT statement."""
        self._stat.add(['DO', 'NOTHING'])
        return self._stat

    @catch_error(attr='_stat')
    def do_update(self, *args: Union[bool, SqlCompareOper], where: Union[bool, SqlCompareOper] = None) -> SqlInsertStat[T]:
        """
        ``DO UPDATE SET ... [WHERE ...]``; returns the INSERT statement.

        :param args: ``field == value`` assignments (converted by ``SqlCompareOper.as_set_expr``).
        :param where: optional condition of the update.
        :raise RuntimeError: no assignment is given.
        """
        if not args:
            raise RuntimeError('DO UPDATE SET requires at least one assignment')
        self._stat.add(['DO', 'UPDATE', 'SET'])
        self._stat.add(SqlRemovePlaceHolder(SqlVarArgOper(',', [
            SqlCompareOper.as_set_expr(it) for it in args
        ])))
        if where is not None:
            self._stat.add('WHERE')
            self._stat.add(SqlRemovePlaceHolder(where))
        return self._stat
class SqlUpdateStat(SqlStat[T], SqlWhereStat, SqlLimitStat, SqlReturnStat, Generic[T]):
    """**UPDATE** statement."""

    @catch_error
    def from_(self, query: Union[SqlStat, SqlAlias[SqlSubQuery]]) -> Self:
        """
        ``FROM`` clause of UPDATE: https://www.sqlite.org/lang_update.html

        *query* is a statement, a :class:`SqlSubQuery`, or an :class:`SqlAlias`
        of a subquery (emitted as ``<select> AS name``).

        :raise TypeError: *query* is none of the above.
        """
        self.add('FROM')
        if isinstance(query, SqlStat):
            self.add(query)
            return self
        if isinstance(query, SqlSubQuery):
            self.add(query.stat)
            return self
        if isinstance(query, SqlAlias) and isinstance(query._value, SqlSubQuery):
            self.add(query._value.stat)
            self.add(['AS', query._name])
            return self
        raise TypeError(repr(query))
class SqlDeleteStat(SqlStat[T], SqlWhereStat, SqlLimitStat, SqlReturnStat, Generic[T]):
    """**DELETE** statement; WHERE, LIMIT/ORDER BY and RETURNING come from the mixins."""