import contextvars
import sqlite3
from pathlib import Path
from typing import Union, TypeVar, Optional, Any
import polars as pl
from .literal import UPDATE_POLICY
from .stat import SqlStat
from .table import table_name, table_field_names
__all__ = ['Connection', 'get_connection_context']
T = TypeVar('T')
S = TypeVar('S')
[docs]
class Connection:
"""
A sqlite3 connection wrapper.
If as a context manager that put itself in a global context-aware variable.
"""
[docs]
def __init__(self, filename: Union[str, Path] = ':memory:', *,
debug: bool = False):
"""
:param filename: sqlite database filepath. use in-memory database by default.
:param debug: print statement when executing.
"""
self._connection = sqlite3.Connection(str(filename))
self._debug = debug
self._context: Optional[contextvars.Token] = None
@property
def connection(self) -> sqlite3.Connection:
return self._connection
def __enter__(self):
self._context = CONNECTION.set(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_val is None:
self._connection.commit()
else:
self._connection.rollback()
if (token := self._context) is not None:
CONNECTION.reset(token)
self._context = None
def __del__(self):
self._connection.close()
# ====== #
# tables #
# ====== #
[docs]
def list_table(self) -> list[str]:
"""
list table's name stored in the database.
:return: list of table's name
"""
results = self.execute('SELECT name FROM sqlite_master WHERE type = "table"')
return [it[0] for it in results]
[docs]
def table_schema(self, table: Union[str, type]) -> str:
"""
get the table schema stored in the database.
:param table: table name or type.
:return: table schema
"""
if isinstance(table, str):
name = table
else:
name = table_name(table)
if not isinstance(name, str):
raise TypeError()
results = self.execute('SELECT sql FROM sqlite_master WHERE type = "table" AND name = ?', [name]).fetchone()
if results is None:
raise ValueError(f'table {table} not found')
return results[0]
# ======= #
# execute #
# ======= #
[docs]
def commit(self):
self._connection.commit()
[docs]
def rollback(self):
self._connection.rollback()
[docs]
def execute(self, stat: Union[str, SqlStat], parameter: Union[list[Any], dict[str, Any]] = ()) -> sqlite3.Cursor:
"""
execute a statement.
:param stat: a raw SQL statement or a SqlStat
:param parameter: statement variable's value.
:return: a cursor.
"""
if isinstance(stat, SqlStat):
stat._connection = None
stat, _parameter = stat.build()
if len(parameter) == 0:
parameter = _parameter
if self._debug:
print(repr(stat))
try:
ret = self._connection.execute(stat, parameter)
except sqlite3.OperationalError as e:
raise RuntimeError(stat) from e
except sqlite3.ProgrammingError as e:
raise RuntimeError(stat) from e
except sqlite3.InterfaceError as e:
raise RuntimeError(repr(parameter)) from e
return ret
[docs]
def execute_batch(self, stat: Union[str, SqlStat], parameter: list) -> sqlite3.Cursor:
"""
execute a statement in batch mode.
:param stat: a raw SQL statement or a SqlStat
:param parameter: list of statement variable's value.
:return: a cursor.
"""
if isinstance(stat, SqlStat):
stat._connection = None
stat, _ = stat.build()
if self._debug:
print(repr(stat))
try:
ret = self._connection.executemany(stat, parameter)
except sqlite3.OperationalError as e:
raise RuntimeError(stat) from e
except sqlite3.ProgrammingError as e:
raise RuntimeError(f'{stat=}, ?={parameter}') from e
except sqlite3.InterfaceError as e:
raise RuntimeError(repr(parameter)) from e
return ret
[docs]
def execute_script(self, stat: Union[str, list[str], list[SqlStat]]):
"""
execute SQL script.
:param stat: a raw SQL script, or a list of statements/SqlStat.
:param commit: commit script.
:return: a cursor.
"""
script = []
if isinstance(stat, str):
script.append(stat)
else:
for _stat in stat:
if isinstance(_stat, str):
script.append(_stat)
elif isinstance(_stat, SqlStat):
_stat._connection = None
script.append(_stat.build()[0])
else:
raise TypeError()
script = ';\n'.join(script)
if self._debug:
print(repr(stat))
try:
ret = self._connection.executescript(script)
except sqlite3.OperationalError as e:
raise RuntimeError(script) from e
except sqlite3.ProgrammingError as e:
raise RuntimeError(script) from e
except sqlite3.InterfaceError as e:
raise RuntimeError(script) from e
return ret
# ========= #
# functions #
# ========= #
[docs]
def sqlite_compileoption_get(self, n):
"""
* https://www.sqlite.org/lang_corefunc.html#sqlite_compileoption_get
* https://www.sqlite.org/compile.html#omitfeatures
:param n:
:return:
"""
ret, *_ = self._connection.execute("""\
WITH RET(val) AS (
VALUES (sqlite_compileoption_get(?))
)
SELECT * FROM RET
""", (n,)).fetchone()
return ret
[docs]
def sqlite_compileoption_used(self, n) -> bool:
"""
* https://www.sqlite.org/lang_corefunc.html#sqlite_compileoption_used
* https://www.sqlite.org/compile.html#omitfeatures
:param n:
:return:
"""
ret, *_ = self._connection.execute("""\
WITH RET(val) AS (
VALUES (sqlite_compileoption_used(?))
)
SELECT * FROM RET
""", (n,)).fetchone()
return ret > 0
# ============= #
# import/export #
# ============= #
[docs]
def export_dataframe(self, table: Union[str, type[T]]) -> pl.DataFrame:
"""
export a table into a DataFrame.
:param table: table name or type.
:return: polars DataFrame
"""
if isinstance(table, str):
from .util import get_fields_from_schema
fields = get_fields_from_schema(self.table_schema(table))
else:
fields = table_field_names(table)
table = table_name(table)
result = self.connection.execute(f'SELECT * FROM {table}').fetchall()
return pl.DataFrame(result, schema=fields)
[docs]
def import_dataframe(self, table: Union[str, type[T]], df: pl.DataFrame, *,
policy: UPDATE_POLICY = 'REPLACE'):
"""
Import a table from a DataFrame.
:param table: table name or type.
:param df: polars DataFrame
:param policy: insert policy
"""
if isinstance(table, str):
from .util import get_fields_from_schema
fields = get_fields_from_schema(self.table_schema(table))
else:
fields = table_field_names(table)
table = table_name(table)
stat = f'INSERT OR {policy.upper()} INTO {table} VALUES (' + ','.join(['?'] * len(fields)) + ')'
self.execute_batch(stat, [
tuple([row[f] for f in fields])
for row in df.iter_rows(named=True)
])
[docs]
def export_csv(self, table: Union[str, type[T]], file: Union[str, Path]):
"""
export a table into a csv file.
:param table: table name or type.
:param file: csv filepath.
"""
self.export_dataframe(table).write_csv(file)
[docs]
def import_csv(self, table: Union[str, type[T]], file: Union[str, Path]):
"""
Import a table from a csv file.
:param table: table name or type.
:param file: csv filepath
"""
self.import_dataframe(table, pl.read_csv(file))
CONNECTION = contextvars.ContextVar('CONNECTION', default=None)
[docs]
def get_connection_context() -> Optional[Connection]:
"""
Get a connection under the current context.
"""
return CONNECTION.get()