Skip to content

Commit 4d34a85

Browse files
committed
more tests
1 parent 5f04fe1 commit 4d34a85

File tree

2 files changed

+233
-0
lines changed

2 files changed

+233
-0
lines changed

quaddtype/numpy_quaddtype/src/scalar.c

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "dtype.h"
1818
#include "lock.h"
1919
#include "utilities.h"
20+
#include "constants.hpp"
2021

2122

2223
QuadPrecisionObject *
@@ -624,6 +625,119 @@ static PyGetSetDef QuadPrecision_getset[] = {
624625
{NULL} /* Sentinel */
625626
};
626627

628+
/*
629+
* Hash function for QuadPrecision scalars.
630+
*
631+
* This implements the same algorithm as CPython's _Py_HashDouble, adapted for
632+
* quad precision (128-bit) floating point. The algorithm computes a hash based
633+
* on the reduction of the value modulo the prime P = 2**PYHASH_BITS - 1.
634+
*
635+
* Key invariant: hash(x) == hash(y) whenever x and y are numerically equal,
636+
* even if x and y have different types. This ensures that:
637+
* hash(QuadPrecision(1.0)) == hash(1.0) == hash(1)
638+
*
639+
* The algorithm:
640+
* 1. Handle special cases: inf returns PYHASH_INF, nan uses pointer hash
641+
* 2. Extract mantissa m in [0.5, 1.0) and exponent e via frexp(v) = m * 2^e
642+
* 3. Process mantissa 28 bits at a time, accumulating into hash value x
643+
* 4. Adjust for exponent using bit rotation (since 2^PYHASH_BITS ≡ 1 mod P)
644+
* 5. Apply sign and handle the special case of -1 -> -2
645+
*/
646+
647+
#if SIZEOF_VOID_P >= 8
648+
# define PYHASH_BITS 61
649+
#else
650+
# define PYHASH_BITS 31
651+
#endif
652+
#define PYHASH_MODULUS (((Py_uhash_t)1 << PYHASH_BITS) - 1)
653+
#define PYHASH_INF 314159
654+
655+
static Py_hash_t
656+
QuadPrecision_hash(QuadPrecisionObject *self)
657+
{
658+
Sleef_quad value;
659+
int sign = 1;
660+
661+
if (self->backend == BACKEND_SLEEF) {
662+
value = self->value.sleef_value;
663+
}
664+
else {
665+
value = Sleef_cast_from_doubleq1((double)self->value.longdouble_value);
666+
}
667+
668+
// Check for NaN - use pointer hash (each NaN instance gets unique hash)
669+
// This prevents hash table catastrophic pileups from NaN instances
670+
if (Sleef_iunordq1(value, value)) {
671+
return _Py_HashPointer((void *)self);
672+
}
673+
674+
if (Sleef_icmpeqq1(value, QUAD_PRECISION_INF)) {
675+
return PYHASH_INF;
676+
}
677+
if (Sleef_icmpeqq1(value, QUAD_PRECISION_NINF)) {
678+
return -PYHASH_INF;
679+
}
680+
681+
// Handle sign
682+
Sleef_quad zero = Sleef_cast_from_int64q1(0);
683+
if (Sleef_icmpltq1(value, zero)) {
684+
sign = -1;
685+
value = Sleef_negq1(value);
686+
}
687+
688+
// Get mantissa and exponent: value = m * 2^e, where 0.5 <= m < 1.0
689+
int exponent;
690+
Sleef_quad mantissa = Sleef_frexpq1(value, &exponent);
691+
692+
// Process 28 bits at a time (same as CPython's _Py_HashDouble)
693+
// This works well for both binary and hexadecimal floating point
694+
Py_uhash_t x = 0;
695+
// 2^28 = 268435456 - exactly representable in double, so cast is safe
696+
Sleef_quad multiplier = Sleef_cast_from_int64q1(1LL << 28);
697+
698+
// Continue until mantissa becomes zero (all bits processed)
699+
while (Sleef_icmpneq1(mantissa, zero)) {
700+
// Rotate x left by 28 bits within PYHASH_MODULUS
701+
x = ((x << 28) & PYHASH_MODULUS) | (x >> (PYHASH_BITS - 28));
702+
703+
// Scale mantissa by 2^28
704+
mantissa = Sleef_mulq1_u05(mantissa, multiplier);
705+
exponent -= 28;
706+
707+
// Extract integer part
708+
Sleef_quad int_part = Sleef_truncq1(mantissa);
709+
Py_uhash_t y = (Py_uhash_t)Sleef_cast_to_int64q1(int_part);
710+
711+
// Remove integer part from mantissa (keep fractional part)
712+
mantissa = Sleef_subq1_u05(mantissa, int_part);
713+
714+
// Accumulate
715+
x += y;
716+
if (x >= PYHASH_MODULUS) {
717+
x -= PYHASH_MODULUS;
718+
}
719+
}
720+
721+
// Adjust for exponent: reduce e modulo PYHASH_BITS
722+
// For negative exponents: PYHASH_BITS - 1 - ((-1 - e) % PYHASH_BITS)
723+
int e = exponent >= 0
724+
? exponent % PYHASH_BITS
725+
: PYHASH_BITS - 1 - ((-1 - exponent) % PYHASH_BITS);
726+
727+
// Rotate x left by e bits
728+
x = ((x << e) & PYHASH_MODULUS) | (x >> (PYHASH_BITS - e));
729+
730+
// Apply sign
731+
x = x * sign;
732+
733+
// -1 is reserved for errors, so use -2 instead
734+
if (x == (Py_uhash_t)-1) {
735+
x = (Py_uhash_t)-2;
736+
}
737+
738+
return (Py_hash_t)x;
739+
}
740+
627741
PyTypeObject QuadPrecision_Type = {
628742
PyVarObject_HEAD_INIT(NULL, 0).tp_name = "numpy_quaddtype.QuadPrecision",
629743
.tp_basicsize = sizeof(QuadPrecisionObject),
@@ -632,6 +746,7 @@ PyTypeObject QuadPrecision_Type = {
632746
.tp_dealloc = (destructor)QuadPrecision_dealloc,
633747
.tp_repr = (reprfunc)QuadPrecision_repr_dragon4,
634748
.tp_str = (reprfunc)QuadPrecision_str_dragon4,
749+
.tp_hash = (hashfunc)QuadPrecision_hash,
635750
.tp_as_number = &quad_as_scalar,
636751
.tp_as_buffer = &QuadPrecision_as_buffer,
637752
.tp_richcompare = (richcmpfunc)quad_richcompare,

quaddtype/tests/test_quaddtype.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5229,3 +5229,121 @@ def test_add_regression_zero_plus_small(self):
52295229

52305230
assert result_yx == result_xy, f"0 + x = {result_yx}, but x + 0 = {result_xy}"
52315231
assert result_yx == x, f"0 + x = {result_yx}, expected {x}"
5232+
5233+
5234+
class TestQuadPrecisionHash:
5235+
"""Test suite for QuadPrecision hash function.
5236+
5237+
The hash implementation follows CPython's _Py_HashDouble algorithm to ensure
5238+
the invariant: hash(x) == hash(y) when x and y are numerically equal,
5239+
even across different types.
5240+
"""
5241+
5242+
@pytest.mark.parametrize("value", [
5243+
# Values that are exactly representable in binary floating point
5244+
"0.0", "1.0", "-1.0", "2.0", "-2.0",
5245+
"0.5", "0.25", "1.5", "-0.5",
5246+
"100.0", "-100.0",
5247+
# Powers of 2 are exactly representable
5248+
"0.125", "0.0625", "4.0", "8.0",
5249+
])
5250+
def test_hash_matches_float(self, value):
5251+
"""Test that hash(QuadPrecision) == hash(float) for exactly representable values.
5252+
5253+
Note: Only values that are exactly representable in both float64 and float128
5254+
should match. Values like 0.1, 0.3 will have different hashes because they
5255+
have different binary representations at different precisions.
5256+
"""
5257+
quad_val = QuadPrecision(value)
5258+
float_val = float(value)
5259+
assert hash(quad_val) == hash(float_val)
5260+
5261+
@pytest.mark.parametrize("value", [0.1, 0.3, 0.7, 1.1, 2.3])
5262+
def test_hash_matches_float_from_float(self, value):
5263+
"""Test that QuadPrecision created from float has same hash as that float.
5264+
5265+
When creating QuadPrecision from a Python float, the value is converted
5266+
from the float's double precision representation, so they should be
5267+
numerically equal and have the same hash.
5268+
"""
5269+
quad_val = QuadPrecision(value) # Created from float, not string
5270+
assert hash(quad_val) == hash(value)
5271+
5272+
@pytest.mark.parametrize("value", [0, 1, -1, 2, -2, 100, -100, 1000, -1000])
5273+
def test_hash_matches_int(self, value):
5274+
"""Test that hash(QuadPrecision) == hash(int) for integer values."""
5275+
quad_val = QuadPrecision(value)
5276+
assert hash(quad_val) == hash(value)
5277+
5278+
def test_hash_matches_large_int(self):
5279+
"""Test that hash(QuadPrecision) == hash(int) for large integers."""
5280+
big_int = 10**20
5281+
quad_val = QuadPrecision(str(big_int))
5282+
assert hash(quad_val) == hash(big_int)
5283+
5284+
def test_hash_infinity(self):
5285+
"""Test that infinity hash matches Python's float infinity hash."""
5286+
assert hash(QuadPrecision("inf")) == hash(float("inf"))
5287+
assert hash(QuadPrecision("-inf")) == hash(float("-inf"))
5288+
# Standard PyHASH_INF values
5289+
assert hash(QuadPrecision("inf")) == 314159
5290+
assert hash(QuadPrecision("-inf")) == -314159
5291+
5292+
def test_hash_nan_unique(self):
5293+
"""Test that each NaN instance gets a unique hash (pointer-based)."""
5294+
nan1 = QuadPrecision("nan")
5295+
nan2 = QuadPrecision("nan")
5296+
# NaN instances should have different hashes (based on object identity)
5297+
assert hash(nan1) != hash(nan2)
5298+
5299+
def test_hash_nan_same_instance(self):
5300+
"""Test that the same NaN instance has consistent hash."""
5301+
nan = QuadPrecision("nan")
5302+
assert hash(nan) == hash(nan)
5303+
5304+
def test_hash_negative_one(self):
5305+
"""Test that hash(-1) returns -2 (Python's hash convention)."""
5306+
# In Python, hash(-1) returns -2 because -1 is reserved for errors
5307+
assert hash(QuadPrecision(-1.0)) == -2
5308+
assert hash(QuadPrecision("-1.0")) == -2
5309+
5310+
def test_hash_set_membership(self):
5311+
"""Test that QuadPrecision values work correctly in sets."""
5312+
vals = [QuadPrecision(1.0), QuadPrecision(2.0), QuadPrecision(1.0)]
5313+
unique_set = set(vals)
5314+
assert len(unique_set) == 2
5315+
5316+
def test_hash_set_cross_type(self):
5317+
"""Test that QuadPrecision and float with same value are in same set bucket."""
5318+
s = {QuadPrecision(1.0)}
5319+
s.add(1.0)
5320+
assert len(s) == 1
5321+
5322+
def test_hash_dict_key(self):
5323+
"""Test that QuadPrecision values work as dict keys."""
5324+
d = {QuadPrecision(1.0): "one", QuadPrecision(2.0): "two"}
5325+
assert d[QuadPrecision(1.0)] == "one"
5326+
assert d[QuadPrecision(2.0)] == "two"
5327+
5328+
def test_hash_dict_cross_type_lookup(self):
5329+
"""Test that dict lookup works with float keys when hash matches."""
5330+
d = {QuadPrecision(1.0): "one"}
5331+
# Float lookup should work if hash and eq both work
5332+
assert d.get(1.0) == "one"
5333+
5334+
@pytest.mark.parametrize("value", [
5335+
"1e-100", "-1e-100",
5336+
"1e100", "-1e100",
5337+
"1e-300", "-1e-300",
5338+
])
5339+
def test_hash_extreme_values(self, value):
5340+
"""Test hash works for extreme values without errors."""
5341+
quad_val = QuadPrecision(value)
5342+
h = hash(quad_val)
5343+
assert isinstance(h, int)
5344+
5345+
@pytest.mark.parametrize("backend", ["sleef", "longdouble"])
5346+
def test_hash_backends(self, backend):
5347+
"""Test hash works for both backends."""
5348+
quad_val = QuadPrecision(1.5, backend=backend)
5349+
assert hash(quad_val) == hash(1.5)

0 commit comments

Comments
 (0)