Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions runtime/python_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,11 @@ def serialize_ndarray(obj):
Why: Arrow IPC gives a compact, lossless binary payload that the JS side can decode as a
Table. If JSON fallback is explicitly requested, honor it even when pyarrow is installed so
callers don't unexpectedly need an Arrow decoder on the TypeScript side.

Note: PyArrow's pa.array() only handles 1D arrays. For multi-dimensional arrays, we flatten
before encoding and include shape metadata for reconstruction on the JS side. This maintains
Arrow's binary efficiency while working with the current arrow-js implementation (which
doesn't yet support FixedShapeTensorArray). See: https://github.com/apache/arrow-js/issues/115
"""
if FALLBACK_JSON:
return serialize_ndarray_json(obj)
Expand All @@ -289,7 +294,11 @@ def serialize_ndarray(obj):
'Arrow encoding unavailable for ndarray; install pyarrow or set TYWRAP_CODEC_FALLBACK=json to enable JSON fallback'
) from exc
try:
arr = pa.array(obj)
# Flatten multi-dimensional arrays for Arrow compatibility
# pa.array() only handles 1D arrays; we preserve shape for JS-side reconstruction
original_shape = list(obj.shape) if hasattr(obj, 'shape') else None
flat = obj.flatten() if hasattr(obj, 'ndim') and obj.ndim > 1 else obj
arr = pa.array(flat)
table = pa.Table.from_arrays([arr], names=['value'])
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
Expand All @@ -301,7 +310,8 @@ def serialize_ndarray(obj):
'codecVersion': CODEC_VERSION,
'encoding': 'arrow',
'b64': b64,
'shape': getattr(obj, 'shape', None),
'shape': original_shape,
'dtype': str(obj.dtype) if hasattr(obj, 'dtype') else None,
}
except Exception as exc:
if FALLBACK_JSON:
Expand Down
71 changes: 70 additions & 1 deletion src/utils/codec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,40 @@ function isPromiseLike(value: unknown): value is PromiseLike<unknown> {
);
}

/**
* Reshape a flat array into a multi-dimensional nested array.
*
* Why: PyArrow's pa.array() only handles 1D arrays, so we flatten multi-dimensional
* arrays before Arrow encoding and reshape after decoding. This maintains Arrow's
* binary efficiency while working with current arrow-js (which doesn't yet support
* FixedShapeTensorArray). See: https://github.com/apache/arrow-js/issues/115
*
* @param flat - Flat array of values
* @param shape - Target shape, e.g., [2, 3] for a 2x3 matrix
* @returns Nested array with the specified shape
*/
function reshapeArray(flat: unknown[], shape: readonly number[]): unknown {
if (shape.length === 0) {
return flat[0];
}
if (shape.length === 1) {
return flat;
}

// shape.length >= 2, so first is always defined
const first = shape[0]!;
const rest = shape.slice(1);
const chunkSize = rest.reduce((a, b) => a * b, 1);
const result: unknown[] = [];

for (let i = 0; i < first; i++) {
const chunk = flat.slice(i * chunkSize, (i + 1) * chunkSize);
result.push(reshapeArray(chunk, rest));
}

return result;
}

// Why: decoding needs to reject incompatible envelopes before we attempt to interpret payloads.
const CODEC_VERSION = 1;

Expand Down Expand Up @@ -326,13 +360,48 @@ function decodeEnvelopeCore<T>(

if (marker === 'ndarray') {
const encoding = (value as { encoding?: unknown }).encoding;
const shapeValue = (value as { shape?: unknown }).shape;
const shape = isNumberArray(shapeValue) ? shapeValue : undefined;

if (encoding === 'arrow') {
const b64 = (value as { b64?: unknown }).b64;
if (typeof b64 !== 'string') {
throw new Error('Invalid ndarray envelope: missing b64');
}
const bytes = fromBase64(b64);
return decodeArrow(bytes);
const decoded = decodeArrow(bytes);

// Reshape if multi-dimensional (Arrow only handles 1D, so we flatten on encode)
if (shape && shape.length > 1) {
if (isPromiseLike(decoded)) {
return decoded.then(data => {
if (Array.isArray(data)) {
return reshapeArray(data, shape);
}
// Arrow table - extract values and reshape
const table = data as ArrowTable & { getChildAt?: (i: number) => { toArray?: () => unknown[] } };
if (typeof table.getChildAt === 'function') {
const column = table.getChildAt(0);
if (column && typeof column.toArray === 'function') {
return reshapeArray(column.toArray(), shape);
}
}
return data;
});
}
if (Array.isArray(decoded)) {
return reshapeArray(decoded, shape);
}
// Arrow table - extract values and reshape
const table = decoded as ArrowTable & { getChildAt?: (i: number) => { toArray?: () => unknown[] } };
if (typeof table.getChildAt === 'function') {
const column = table.getChildAt(0);
if (column && typeof column.toArray === 'function') {
return reshapeArray(column.toArray(), shape);
}
}
}
return decoded;
}
if (encoding === 'json') {
if (!('data' in (value as object))) {
Expand Down
241 changes: 240 additions & 1 deletion test/runtime_codec_scientific.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,13 @@ describeNodeOnly('Scientific Codecs', () => {
if (!pythonAvailable(pythonPath) || !existsSync(scriptPath)) return;
if (!pythonPath || !hasModule(pythonPath, 'torch')) return;
// Torch tensor serialization requires pyarrow for Arrow encoding of ndarrays
// Multi-dimensional arrays are flattened on encode and reshaped on decode
if (!hasModule(pythonPath, 'pyarrow')) return;

const bridge = new NodeBridge({
scriptPath,
pythonPath,
// Use Arrow encoding (pyarrow required), no fallback
// Use Arrow encoding (pyarrow required) - flatten+reshape handles multi-dim arrays
timeoutMs: bridgeTimeoutMs,
});

Expand Down Expand Up @@ -156,3 +157,241 @@ describeNodeOnly('Scientific Codecs', () => {
scientificTimeoutMs
);
});

/**
* ndarray flatten+reshape tests
* Tests the Arrow encoding path where multi-dimensional arrays are flattened
* on Python side and reshaped on JS side.
*/
describeNodeOnly('ndarray Flatten+Reshape', () => {
it(
'handles 1D arrays (no reshape needed)',
async () => {
const pythonPath = await resolvePythonForTests();
if (!pythonAvailable(pythonPath) || !existsSync(scriptPath)) return;
if (!pythonPath || !hasModule(pythonPath, 'numpy')) return;
if (!hasModule(pythonPath, 'pyarrow')) return;

const bridge = new NodeBridge({
scriptPath,
pythonPath,
timeoutMs: bridgeTimeoutMs,
});

try {
const result = await bridge.call<number[]>('numpy', 'array', [[1, 2, 3, 4, 5]]);
expect(result).toEqual([1, 2, 3, 4, 5]);
} finally {
await bridge.dispose();
}
},
scientificTimeoutMs
);

it(
'handles 3D arrays',
async () => {
const pythonPath = await resolvePythonForTests();
if (!pythonAvailable(pythonPath) || !existsSync(scriptPath)) return;
if (!pythonPath || !hasModule(pythonPath, 'numpy')) return;
if (!hasModule(pythonPath, 'pyarrow')) return;

const bridge = new NodeBridge({
scriptPath,
pythonPath,
timeoutMs: bridgeTimeoutMs,
});

try {
// Create a 2x3x4 array via numpy.arange().reshape()
// We'll use builtins.eval to construct this
const result = await bridge.call<number[][][]>(
'builtins',
'eval',
['__import__("numpy").arange(24).reshape(2, 3, 4).tolist()']
);

// Verify shape by checking nested structure
expect(result.length).toBe(2);
expect(result[0].length).toBe(3);
expect(result[0][0].length).toBe(4);
// Verify values
expect(result[0][0]).toEqual([0, 1, 2, 3]);
expect(result[1][2]).toEqual([20, 21, 22, 23]);
} finally {
await bridge.dispose();
}
},
scientificTimeoutMs
);

it(
'handles 3D torch tensors with Arrow encoding',
async () => {
const pythonPath = await resolvePythonForTests();
if (!pythonAvailable(pythonPath) || !existsSync(scriptPath)) return;
if (!pythonPath || !hasModule(pythonPath, 'torch')) return;
if (!hasModule(pythonPath, 'pyarrow')) return;

const bridge = new NodeBridge({
scriptPath,
pythonPath,
timeoutMs: bridgeTimeoutMs,
});

try {
// Create a 2x3x2 tensor
const result = await bridge.call<{
data: number[][][];
shape?: number[];
dtype?: string;
device?: string;
}>('torch', 'tensor', [
[
[[1, 2], [3, 4], [5, 6]],
[[7, 8], [9, 10], [11, 12]],
],
]);

expect(result.shape).toEqual([2, 3, 2]);
expect(result.device).toBe('cpu');
expect(result.data).toEqual([
[[1, 2], [3, 4], [5, 6]],
[[7, 8], [9, 10], [11, 12]],
]);
} finally {
await bridge.dispose();
}
},
scientificTimeoutMs
);

it(
'handles single-element arrays',
async () => {
const pythonPath = await resolvePythonForTests();
if (!pythonAvailable(pythonPath) || !existsSync(scriptPath)) return;
if (!pythonPath || !hasModule(pythonPath, 'numpy')) return;
if (!hasModule(pythonPath, 'pyarrow')) return;

const bridge = new NodeBridge({
scriptPath,
pythonPath,
timeoutMs: bridgeTimeoutMs,
});

try {
const result = await bridge.call<number[]>('numpy', 'array', [[42]]);
expect(result).toEqual([42]);
} finally {
await bridge.dispose();
}
},
scientificTimeoutMs
);

it(
'handles single-element multi-dimensional arrays',
async () => {
const pythonPath = await resolvePythonForTests();
if (!pythonAvailable(pythonPath) || !existsSync(scriptPath)) return;
if (!pythonPath || !hasModule(pythonPath, 'torch')) return;
if (!hasModule(pythonPath, 'pyarrow')) return;

const bridge = new NodeBridge({
scriptPath,
pythonPath,
timeoutMs: bridgeTimeoutMs,
});

try {
// Create a 1x1x1 tensor
const result = await bridge.call<{
data: number[][][];
shape?: number[];
}>('torch', 'tensor', [[[[99]]]]);

expect(result.shape).toEqual([1, 1, 1]);
expect(result.data).toEqual([[[99]]]);
} finally {
await bridge.dispose();
}
},
scientificTimeoutMs
);

it(
'preserves dtype for float arrays',
async () => {
const pythonPath = await resolvePythonForTests();
if (!pythonAvailable(pythonPath) || !existsSync(scriptPath)) return;
if (!pythonPath || !hasModule(pythonPath, 'torch')) return;
if (!hasModule(pythonPath, 'pyarrow')) return;

const bridge = new NodeBridge({
scriptPath,
pythonPath,
timeoutMs: bridgeTimeoutMs,
});

try {
const result = await bridge.call<{
data: number[][];
shape?: number[];
dtype?: string;
}>('torch', 'tensor', [
[[1.5, 2.5], [3.5, 4.5]],
]);

expect(result.shape).toEqual([2, 2]);
expect(result.data).toEqual([[1.5, 2.5], [3.5, 4.5]]);
// dtype should be float32 or float64
expect(result.dtype).toMatch(/float/);
} finally {
await bridge.dispose();
}
},
scientificTimeoutMs
);

it(
'handles 4D tensors (image-like batches)',
async () => {
const pythonPath = await resolvePythonForTests();
if (!pythonAvailable(pythonPath) || !existsSync(scriptPath)) return;
if (!pythonPath || !hasModule(pythonPath, 'torch')) return;
if (!hasModule(pythonPath, 'pyarrow')) return;

const bridge = new NodeBridge({
scriptPath,
pythonPath,
timeoutMs: bridgeTimeoutMs,
});

try {
// Create a 2x2x2x2 tensor (batch x channels x height x width)
const input = [
[
[[1, 2], [3, 4]],
[[5, 6], [7, 8]],
],
[
[[9, 10], [11, 12]],
[[13, 14], [15, 16]],
],
];

const result = await bridge.call<{
data: number[][][][];
shape?: number[];
}>('torch', 'tensor', [input]);

expect(result.shape).toEqual([2, 2, 2, 2]);
expect(result.data).toEqual(input);
} finally {
await bridge.dispose();
}
},
scientificTimeoutMs
);
});
Loading