Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ coverage.xml
.vscode
results.xml
venv
.venv
7 changes: 6 additions & 1 deletion duckdb_engine/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Any, Callable, Dict, Optional, Type

import duckdb
import sqlalchemy
from packaging.version import Version
from sqlalchemy import exc
from sqlalchemy.dialects.postgresql.base import PGIdentifierPreparer, PGTypeCompiler
Expand All @@ -26,8 +27,10 @@
(BigInteger, SmallInteger) # pure reexport

duckdb_version = duckdb.__version__
sqlalchemy_version = sqlalchemy.__version__

IS_GT_1 = Version(duckdb_version) > Version("1.0.0")
IS_SQLA_GT_2 = Version(sqlalchemy_version) > Version("2.0.0")


class UInt64(Integer):
Expand Down Expand Up @@ -206,7 +209,6 @@ def __init__(self, fields: Dict[str, TV]):
"timetz": sqltypes.TIME,
"timestamptz": sqltypes.TIMESTAMP,
"float4": sqltypes.FLOAT,
"float8": sqltypes.FLOAT,
"usmallint": USmallInteger,
"uinteger": UInteger,
"ubigint": UBigInteger,
Expand All @@ -219,6 +221,9 @@ def __init__(self, fields: Dict[str, TV]):
}
if IS_GT_1:
ISCHEMA_NAMES["varint"] = VarInt
if IS_SQLA_GT_2:
ISCHEMA_NAMES["float8"] = sqltypes.DOUBLE # type: ignore[attr-defined]
ISCHEMA_NAMES["double"] = sqltypes.DOUBLE # type: ignore[attr-defined]


def register_extension_types() -> None:
Expand Down
20 changes: 20 additions & 0 deletions duckdb_engine/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,26 @@ def test_double_in_sqla_v2(engine: Engine) -> None:
con.execute(t.select())


def test_double(engine: Engine, session: Session) -> None:
sqlalchemy = importorskip("sqlalchemy", "2.0.0")
base = declarative_base()

class Entry(base):
__tablename__ = "test_double"

id = Column(Integer, primary_key=True, default=0)
value = Column(sqlalchemy.DOUBLE, nullable=False)

base.metadata.create_all(bind=engine)

session.add(Entry(value=42.000001)) # type: ignore[call-arg]
session.commit()

result = session.query(Entry).one()

assert result.value == 42.000001


def test_all_types_reflection(engine: Engine) -> None:
importorskip("sqlalchemy", "1.4.0")
importorskip("duckdb", "0.5.1")
Expand Down