diff --git a/.gitignore b/.gitignore index ee84b1fd3..38b401b58 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ coverage.xml .vscode results.xml venv +.venv diff --git a/duckdb_engine/datatypes.py b/duckdb_engine/datatypes.py index 363f6e878..30f22652e 100644 --- a/duckdb_engine/datatypes.py +++ b/duckdb_engine/datatypes.py @@ -11,6 +11,7 @@ from typing import Any, Callable, Dict, Optional, Type import duckdb +import sqlalchemy from packaging.version import Version from sqlalchemy import exc from sqlalchemy.dialects.postgresql.base import PGIdentifierPreparer, PGTypeCompiler @@ -26,8 +27,10 @@ (BigInteger, SmallInteger) # pure reexport duckdb_version = duckdb.__version__ +sqlalchemy_version = sqlalchemy.__version__ IS_GT_1 = Version(duckdb_version) > Version("1.0.0") +IS_SQLA_GT_2 = Version(sqlalchemy_version) > Version("2.0.0") class UInt64(Integer): @@ -206,7 +209,6 @@ def __init__(self, fields: Dict[str, TV]): "timetz": sqltypes.TIME, "timestamptz": sqltypes.TIMESTAMP, "float4": sqltypes.FLOAT, - "float8": sqltypes.FLOAT, "usmallint": USmallInteger, "uinteger": UInteger, "ubigint": UBigInteger, @@ -219,6 +221,9 @@ def __init__(self, fields: Dict[str, TV]): } if IS_GT_1: ISCHEMA_NAMES["varint"] = VarInt +if IS_SQLA_GT_2: + ISCHEMA_NAMES["float8"] = sqltypes.DOUBLE # type: ignore[attr-defined] + ISCHEMA_NAMES["double"] = sqltypes.DOUBLE # type: ignore[attr-defined] def register_extension_types() -> None: diff --git a/duckdb_engine/tests/test_datatypes.py b/duckdb_engine/tests/test_datatypes.py index 6149fcfe9..af27596ad 100644 --- a/duckdb_engine/tests/test_datatypes.py +++ b/duckdb_engine/tests/test_datatypes.py @@ -167,6 +167,26 @@ def test_double_in_sqla_v2(engine: Engine) -> None: con.execute(t.select()) +def test_double(engine: Engine, session: Session) -> None: + sqlalchemy = importorskip("sqlalchemy", "2.0.0") + base = declarative_base() + + class Entry(base): + __tablename__ = "test_double" + + id = Column(Integer, primary_key=True, default=0) + value = Column(sqlalchemy.DOUBLE, nullable=False) + + base.metadata.create_all(bind=engine) + + session.add(Entry(value=42.000001)) # type: ignore[call-arg] + session.commit() + + result = session.query(Entry).one() + + assert result.value == 42.000001 + + def test_all_types_reflection(engine: Engine) -> None: importorskip("sqlalchemy", "1.4.0") importorskip("duckdb", "0.5.1")