diff --git a/runtime/python_bridge.py b/runtime/python_bridge.py index e82cc56..8e43f7e 100644 --- a/runtime/python_bridge.py +++ b/runtime/python_bridge.py @@ -279,6 +279,11 @@ def serialize_ndarray(obj): Why: Arrow IPC gives a compact, lossless binary payload that the JS side can decode as a Table. If JSON fallback is explicitly requested, honor it even when pyarrow is installed so callers don't unexpectedly need an Arrow decoder on the TypeScript side. + + Note: PyArrow's pa.array() only handles 1D arrays. For multi-dimensional arrays, we flatten + before encoding and include shape metadata for reconstruction on the JS side. This maintains + Arrow's binary efficiency while working with the current arrow-js implementation (which + doesn't yet support FixedShapeTensorArray). See: https://github.com/apache/arrow-js/issues/115 """ if FALLBACK_JSON: return serialize_ndarray_json(obj) @@ -289,7 +294,11 @@ def serialize_ndarray(obj): 'Arrow encoding unavailable for ndarray; install pyarrow or set TYWRAP_CODEC_FALLBACK=json to enable JSON fallback' ) from exc try: - arr = pa.array(obj) + # Flatten multi-dimensional arrays for Arrow compatibility + # pa.array() only handles 1D arrays; we preserve shape for JS-side reconstruction + original_shape = list(obj.shape) if hasattr(obj, 'shape') else None + flat = obj.flatten() if hasattr(obj, 'ndim') and obj.ndim > 1 else obj + arr = pa.array(flat) table = pa.Table.from_arrays([arr], names=['value']) sink = pa.BufferOutputStream() with pa.ipc.new_stream(sink, table.schema) as writer: @@ -301,7 +310,8 @@ def serialize_ndarray(obj): 'codecVersion': CODEC_VERSION, 'encoding': 'arrow', 'b64': b64, - 'shape': getattr(obj, 'shape', None), + 'shape': original_shape, + 'dtype': str(obj.dtype) if hasattr(obj, 'dtype') else None, } except Exception as exc: if FALLBACK_JSON: diff --git a/src/utils/codec.ts b/src/utils/codec.ts index cacd758..8f03bce 100644 --- a/src/utils/codec.ts +++ b/src/utils/codec.ts @@ -243,6 +243,40 @@ function 
isPromiseLike(value: unknown): value is PromiseLike {
   );
 }
 
+/**
+ * Reshape a flat, row-major array into nested arrays matching `shape`.
+ *
+ * Why: PyArrow's pa.array() only handles 1D arrays, so ndarrays are flattened before
+ * Arrow encoding and rebuilt here after decoding (current arrow-js doesn't yet support
+ * FixedShapeTensorArray). See: https://github.com/apache/arrow-js/issues/115
+ *
+ * @param flat - Flat array of values in row-major order
+ * @param shape - Target shape, e.g., [2, 3] for a 2x3 matrix
+ * @returns `flat[0]` for shape [], `flat` itself for 1D shapes, nested arrays otherwise
+ * @throws Error if `flat.length` differs from the element count implied by `shape`
+ */
+function reshapeArray(flat: unknown[], shape: readonly number[]): unknown {
+  const expected = shape.reduce((a, b) => a * b, 1);
+  if (flat.length !== expected) {
+    throw new Error(`Cannot reshape ${flat.length} values into shape [${shape.join(', ')}]`);
+  }
+  if (shape.length === 0) {
+    return flat[0];
+  }
+  if (shape.length === 1) {
+    return flat;
+  }
+  // shape.length >= 2 here, so shape[0] is always defined.
+  const first = shape[0]!;
+  const rest = shape.slice(1);
+  const chunkSize = rest.reduce((a, b) => a * b, 1);
+  const result: unknown[] = [];
+  for (let i = 0; i < first; i++) {
+    result.push(reshapeArray(flat.slice(i * chunkSize, (i + 1) * chunkSize), rest));
+  }
+  return result;
+}
+
 // Why: decoding needs to reject incompatible envelopes before we attempt to interpret payloads.
 const CODEC_VERSION = 1;
 
@@ -326,13 +360,48 @@ function decodeEnvelopeCore(
 
   if (marker === 'ndarray') {
     const encoding = (value as { encoding?: unknown }).encoding;
+    const shapeValue = (value as { shape?: unknown }).shape;
+    const shape = isNumberArray(shapeValue) ?
shapeValue : undefined; + if (encoding === 'arrow') { const b64 = (value as { b64?: unknown }).b64; if (typeof b64 !== 'string') { throw new Error('Invalid ndarray envelope: missing b64'); } const bytes = fromBase64(b64); - return decodeArrow(bytes); + const decoded = decodeArrow(bytes); + + // Reshape if multi-dimensional (Arrow only handles 1D, so we flatten on encode) + if (shape && shape.length > 1) { + if (isPromiseLike(decoded)) { + return decoded.then(data => { + if (Array.isArray(data)) { + return reshapeArray(data, shape); + } + // Arrow table - extract values and reshape + const table = data as ArrowTable & { getChildAt?: (i: number) => { toArray?: () => unknown[] } }; + if (typeof table.getChildAt === 'function') { + const column = table.getChildAt(0); + if (column && typeof column.toArray === 'function') { + return reshapeArray(column.toArray(), shape); + } + } + return data; + }); + } + if (Array.isArray(decoded)) { + return reshapeArray(decoded, shape); + } + // Arrow table - extract values and reshape + const table = decoded as ArrowTable & { getChildAt?: (i: number) => { toArray?: () => unknown[] } }; + if (typeof table.getChildAt === 'function') { + const column = table.getChildAt(0); + if (column && typeof column.toArray === 'function') { + return reshapeArray(column.toArray(), shape); + } + } + } + return decoded; } if (encoding === 'json') { if (!('data' in (value as object))) { diff --git a/test/runtime_codec_scientific.test.ts b/test/runtime_codec_scientific.test.ts index da1471b..f921c1b 100644 --- a/test/runtime_codec_scientific.test.ts +++ b/test/runtime_codec_scientific.test.ts @@ -90,12 +90,13 @@ describeNodeOnly('Scientific Codecs', () => { if (!pythonAvailable(pythonPath) || !existsSync(scriptPath)) return; if (!pythonPath || !hasModule(pythonPath, 'torch')) return; // Torch tensor serialization requires pyarrow for Arrow encoding of ndarrays + // Multi-dimensional arrays are flattened on encode and reshaped on decode if 
(!hasModule(pythonPath, 'pyarrow')) return;
 
       const bridge = new NodeBridge({
         scriptPath,
         pythonPath,
-        // Use Arrow encoding (pyarrow required), no fallback
+        // Use Arrow encoding (pyarrow required) - flatten+reshape handles multi-dim arrays
         timeoutMs: bridgeTimeoutMs,
       });
 
@@ -156,3 +157,241 @@ describeNodeOnly('Scientific Codecs', () => {
     scientificTimeoutMs
   );
 });
+
+/**
+ * ndarray flatten+reshape tests: exercise the Arrow path where multi-dimensional arrays
+ * are flattened on the Python side and reshaped on the JS side. Tests return early
+ * (passing vacuously) when Python, the required module, or pyarrow is unavailable.
+ */
+describeNodeOnly('ndarray Flatten+Reshape', () => {
+  it(
+    'handles 1D arrays (no reshape needed)',
+    async () => {
+      const pythonPath = await resolvePythonForTests();
+      if (!pythonAvailable(pythonPath) || !existsSync(scriptPath)) return;
+      if (!pythonPath || !hasModule(pythonPath, 'numpy')) return;
+      if (!hasModule(pythonPath, 'pyarrow')) return;
+
+      const bridge = new NodeBridge({
+        scriptPath,
+        pythonPath,
+        timeoutMs: bridgeTimeoutMs,
+      });
+
+      try {
+        const result = await bridge.call<number[]>('numpy', 'array', [[1, 2, 3, 4, 5]]);
+        expect(result).toEqual([1, 2, 3, 4, 5]);
+      } finally {
+        await bridge.dispose();
+      }
+    },
+    scientificTimeoutMs
+  );
+
+  it(
+    'handles 3D arrays',
+    async () => {
+      const pythonPath = await resolvePythonForTests();
+      if (!pythonAvailable(pythonPath) || !existsSync(scriptPath)) return;
+      if (!pythonPath || !hasModule(pythonPath, 'numpy')) return;
+      if (!hasModule(pythonPath, 'pyarrow')) return;
+
+      const bridge = new NodeBridge({
+        scriptPath,
+        pythonPath,
+        timeoutMs: bridgeTimeoutMs,
+      });
+
+      try {
+        // Build a 2x3x4 ndarray via builtins.eval and return it as an ndarray (no
+        // .tolist()), so the result actually goes through the flatten+reshape codec.
+        const result = await bridge.call<number[][][]>(
+          'builtins',
+          'eval',
+          ['__import__("numpy").arange(24).reshape(2, 3, 4)']
+        );
+
+        // Verify shape by checking nested structure
+        expect(result.length).toBe(2);
+        expect(result[0].length).toBe(3);
+        expect(result[0][0].length).toBe(4);
+        // Verify values at the first and last rows (row-major order)
+        expect(result[0][0]).toEqual([0, 1, 2, 3]);
+        expect(result[1][2]).toEqual([20, 21, 22, 23]);
+      } finally {
+        await bridge.dispose();
+      }
+    },
+    scientificTimeoutMs
+  );
+
+  it(
+    'handles 3D torch tensors with Arrow encoding',
+    async () => {
+      const pythonPath = await resolvePythonForTests();
+      if (!pythonAvailable(pythonPath) || !existsSync(scriptPath)) return;
+      if (!pythonPath || !hasModule(pythonPath, 'torch')) return;
+      if (!hasModule(pythonPath, 'pyarrow')) return;
+
+      const bridge = new NodeBridge({
+        scriptPath,
+        pythonPath,
+        timeoutMs: bridgeTimeoutMs,
+      });
+
+      try {
+        // Create a 2x3x2 tensor
+        const result = await bridge.call<{
+          data: number[][][];
+          shape?: number[];
+          dtype?: string;
+          device?: string;
+        }>('torch', 'tensor', [
+          [
+            [[1, 2], [3, 4], [5, 6]],
+            [[7, 8], [9, 10], [11, 12]],
+          ],
+        ]);
+
+        expect(result.shape).toEqual([2, 3, 2]);
+        expect(result.device).toBe('cpu');
+        expect(result.data).toEqual([
+          [[1, 2], [3, 4], [5, 6]],
+          [[7, 8], [9, 10], [11, 12]],
+        ]);
+      } finally {
+        await bridge.dispose();
+      }
+    },
+    scientificTimeoutMs
+  );
+
+  it(
+    'handles single-element arrays',
+    async () => {
+      const pythonPath = await resolvePythonForTests();
+      if (!pythonAvailable(pythonPath) || !existsSync(scriptPath)) return;
+      if (!pythonPath || !hasModule(pythonPath, 'numpy')) return;
+      if (!hasModule(pythonPath, 'pyarrow')) return;
+
+      const bridge = new NodeBridge({
+        scriptPath,
+        pythonPath,
+        timeoutMs: bridgeTimeoutMs,
+      });
+
+      try {
+        const result = await bridge.call<number[]>('numpy', 'array', [[42]]);
+        expect(result).toEqual([42]);
+      } finally {
+        await bridge.dispose();
+      }
+    },
+    scientificTimeoutMs
+  );
+
+  it(
+    'handles single-element multi-dimensional arrays',
+    async () => {
+      const pythonPath = await resolvePythonForTests();
+      if (!pythonAvailable(pythonPath) || !existsSync(scriptPath)) return;
+      if (!pythonPath || !hasModule(pythonPath, 'torch')) return;
+      if (!hasModule(pythonPath, 'pyarrow')) return;
+
+      const bridge = new NodeBridge({
+        scriptPath,
+        pythonPath,
+        timeoutMs: bridgeTimeoutMs,
+      });
+
+      try {
+        // Create a 1x1x1 tensor
+        const result = await bridge.call<{
+          data: number[][][];
+          shape?: number[];
+        }>('torch', 'tensor', [[[[99]]]]);
+
+        expect(result.shape).toEqual([1, 1, 1]);
+        expect(result.data).toEqual([[[99]]]);
+      } finally {
+        await bridge.dispose();
+      }
+    },
+    scientificTimeoutMs
+  );
+
+  it(
+    'preserves dtype for float arrays',
+    async () => {
+      const pythonPath = await resolvePythonForTests();
+      if (!pythonAvailable(pythonPath) || !existsSync(scriptPath)) return;
+      if (!pythonPath || !hasModule(pythonPath, 'torch')) return;
+      if (!hasModule(pythonPath, 'pyarrow')) return;
+
+      const bridge = new NodeBridge({
+        scriptPath,
+        pythonPath,
+        timeoutMs: bridgeTimeoutMs,
+      });
+
+      try {
+        const result = await bridge.call<{
+          data: number[][];
+          shape?: number[];
+          dtype?: string;
+        }>('torch', 'tensor', [
+          [[1.5, 2.5], [3.5, 4.5]],
+        ]);
+
+        expect(result.shape).toEqual([2, 2]);
+        expect(result.data).toEqual([[1.5, 2.5], [3.5, 4.5]]);
+        // dtype should be float32 or float64
+        expect(result.dtype).toMatch(/float/);
+      } finally {
+        await bridge.dispose();
+      }
+    },
+    scientificTimeoutMs
+  );
+
+  it(
+    'handles 4D tensors (image-like batches)',
+    async () => {
+      const pythonPath = await resolvePythonForTests();
+      if (!pythonAvailable(pythonPath) || !existsSync(scriptPath)) return;
+      if (!pythonPath || !hasModule(pythonPath, 'torch')) return;
+      if (!hasModule(pythonPath, 'pyarrow')) return;
+
+      const bridge = new NodeBridge({
+        scriptPath,
+        pythonPath,
+        timeoutMs: bridgeTimeoutMs,
+      });
+
+      try {
+        // Create a 2x2x2x2 tensor (batch x channels x height x width)
+        const input = [
+          [
+            [[1, 2], [3, 4]],
+            [[5, 6], [7, 8]],
+          ],
+          [
+            [[9, 10], [11, 12]],
+            [[13, 14], [15, 16]],
+          ],
+        ];
+
+        const result = await bridge.call<{
+          data: number[][][][];
+          shape?: number[];
+        }>('torch', 'tensor', [input]);
+
+        expect(result.shape).toEqual([2, 2, 2, 2]);
+        expect(result.data).toEqual(input);
+      } finally {
+        await bridge.dispose();
+      }
+    },
+    scientificTimeoutMs
+  );
+});