improve typing

This commit is contained in:
Josh Hawkins 2024-10-06 07:45:04 -05:00
parent cf7b27875e
commit 0b736c881b

View File

@ -1,19 +1,21 @@
import sqlite3
import sqlite_vec import sqlite_vec
from playhouse.sqliteq import SqliteQueueDatabase from playhouse.sqliteq import SqliteQueueDatabase
class SqliteVecQueueDatabase(SqliteQueueDatabase): class SqliteVecQueueDatabase(SqliteQueueDatabase):
def __init__(self, *args, load_vec_extension=False, **kwargs): def __init__(self, *args, load_vec_extension: bool = False, **kwargs) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.load_vec_extension = load_vec_extension self.load_vec_extension: bool = load_vec_extension
def _connect(self, *args, **kwargs): def _connect(self, *args, **kwargs) -> sqlite3.Connection:
conn = super()._connect(*args, **kwargs) conn: sqlite3.Connection = super()._connect(*args, **kwargs)
if self.load_vec_extension: if self.load_vec_extension:
self._load_vec_extension(conn) self._load_vec_extension(conn)
return conn return conn
def _load_vec_extension(self, conn): def _load_vec_extension(self, conn: sqlite3.Connection) -> None:
conn.enable_load_extension(True) conn.enable_load_extension(True)
sqlite_vec.load(conn) sqlite_vec.load(conn)
conn.enable_load_extension(False) conn.enable_load_extension(False)