From 9b1cc771e51af1bae389386736be3d976f0e782d Mon Sep 17 00:00:00 2001 From: Menooker Date: Wed, 17 Dec 2025 08:25:39 +0800 Subject: [PATCH 01/11] stream state --- KunQuant/Driver.py | 18 +++++++-- KunQuant/Op.py | 15 ++++++++ KunQuant/ops/MiscOp.py | 7 +++- KunQuant/passes/CodegenCpp.py | 11 ++++-- cpp/Kun/Context.hpp | 3 ++ cpp/Kun/MathUtil.hpp | 10 +++++ cpp/Kun/Module.hpp | 7 ++-- cpp/Kun/Ops.hpp | 1 + cpp/Kun/RunGraph.hpp | 2 + cpp/Kun/Runtime.cpp | 41 +++++++++++++++----- cpp/Kun/StateBuffer.hpp | 72 +++++++++++++++++++++++++++++++++++ cpp/Kun/StreamBuffer.hpp | 6 +-- cpp/Python/PyBinding.cpp | 20 ++++++---- tests/cpp/TestRuntime.cpp | 2 +- tests/test_stream.py | 47 +++++++++++++++++++++++ tests/tests.sh | 2 + tests/tests_arm.sh | 2 + 17 files changed, 233 insertions(+), 33 deletions(-) create mode 100644 cpp/Kun/MathUtil.hpp create mode 100644 cpp/Kun/StateBuffer.hpp create mode 100644 tests/test_stream.py diff --git a/KunQuant/Driver.py b/KunQuant/Driver.py index 5c37e83..d0c1966 100644 --- a/KunQuant/Driver.py +++ b/KunQuant/Driver.py @@ -10,7 +10,7 @@ # get the cpu architecture of the machine from KunQuant.jit.env import cpu_arch as _cpu_arch -required_version = "0x64100003" +required_version = "kun::VERSION" @dataclass class KunCompilerConfig: partition_factor : int = 3 @@ -232,7 +232,8 @@ def push_source(is_simple=False): def query_temp_buf_id(tempname: str, window: int) -> int: input_windows[tempname] = window return insert_name_str(tempname, "TEMP").idx - src, decl = codegen_cpp(module_name, func, input_name_to_idx, ins, outs, options, stream_mode, query_temp_buf_id, input_windows, generated_cross_sectional_func, dtype, blocking_len, not allow_unaligned, is_single_source) + stream_state_buffer_init = [] + src, decl = codegen_cpp(module_name, func, input_name_to_idx, ins, outs, options, stream_mode, query_temp_buf_id, input_windows, stream_state_buffer_init, generated_cross_sectional_func, dtype, blocking_len, not allow_unaligned, is_single_source) impl_src.append(src) decl_src.append(decl) newparti = _Partition(func.name, len(partitions), pins, pouts) @@ -309,6 +310,16 @@ def query_temp_buf_id(tempname: str, window: int) -> int: {parti_dep_src2} }} ''') + + if len(stream_state_buffer_init) > 0: + impl_src.append(f'''static std::vector __init_state_buffers(size_t stock_count) {{ + std::vector buffers; + buffers.reserve({len(stream_state_buffer_init)}); + {"\n ".join(stream_state_buffer_init)} + return buffers; +}} +''') + dty = dtype[0].upper() + dtype[1:] impl_src.append(f'''KUN_EXPORT Module {module_name}{{ {required_version}, @@ -320,7 +331,8 @@ def query_temp_buf_id(tempname: str, window: int) -> int: MemoryLayout::{output_layout}, {blocking_len}, Datatype::{dty}, - {"0" if allow_unaligned else "1"} + {"0" if allow_unaligned else "1"}, + {"nullptr" if len(stream_state_buffer_init) == 0 else "__init_state_buffers"} }};''') push_source() if not is_single_source: diff --git a/KunQuant/Op.py b/KunQuant/Op.py index 1c81503..bd96014 100644 --- a/KunQuant/Op.py +++ b/KunQuant/Op.py @@ -474,6 +474,21 @@ def generate_init_code(self, idx: str, elem_type: str, simd_lanes: int, inputs: inputs: the input variables of the op ''' return f"{self.get_func_or_class_full_name(elem_type, simd_lanes)} {self.get_state_variable_name_prefix()}{idx};" + def generate_init_code_stream(self, local_idx: str, buffer_idx: str, elem_type: str, simd_lanes: int, inputs: List[str], aligned: bool) -> Tuple[str, str]: + ''' + generate the code for the initialization of the state variable + local_idx: the output variable name index + buffer_idx: the buffer index in ctx->state_buffers + elem_type: the element type of the state variable + simd_lanes: SIMD lanes + inputs: the input variables of the op + + Returns: + ["the code for the initialization of the state variable in the compute function", "initalizer code for the state buffer"] + ''' + typename = self.get_func_or_class_full_name(elem_type, simd_lanes) + return f"{typename}& {self.get_state_variable_name_prefix()}{local_idx} = __ctx->state_buffers[{buffer_idx}]->get<{typename}>(__stock_idx);",\ + f"buffers.emplace_back(makeStateBuffer<{typename}>(stock_count, {simd_lanes}));" class GloablStatefulOpTrait(StatefulOpTrait): diff --git a/KunQuant/ops/MiscOp.py b/KunQuant/ops/MiscOp.py index de98aa0..218d487 100644 --- a/KunQuant/ops/MiscOp.py +++ b/KunQuant/ops/MiscOp.py @@ -1,6 +1,6 @@ import KunQuant from KunQuant.Op import AcceptSingleValueInputTrait, Input, OpBase, WindowedTrait, SinkOpTrait, CrossSectionalOp, GlobalStatefulProducerTrait, GloablStatefulOpTrait, StateConsumerTrait, UnaryElementwiseOp, BinaryElementwiseOp -from typing import List, Union +from typing import List, Tuple, Union class BackRef(OpBase, WindowedTrait): ''' @@ -100,6 +100,11 @@ def get_single_value_input_id(self) -> int: def get_state_variable_name_prefix(self) -> str: return "ema_" + def generate_init_code_stream(self, local_idx: str, buffer_idx: str, elem_type: str, simd_lanes: int, inputs: List[str], aligned: bool) -> Tuple[str, str]: + if len(self.inputs) == 2: + raise RuntimeError("EMA with init_val is not supported in stream mode") + return super().generate_init_code_stream(local_idx, buffer_idx, elem_type, simd_lanes, inputs, aligned) + def generate_init_code(self, idx: str, elem_type: str, simd_lanes: int, inputs: List[str], aligned: bool) -> str: initv = "NAN" if len(self.inputs) == 2: diff --git a/KunQuant/passes/CodegenCpp.py b/KunQuant/passes/CodegenCpp.py index 05bf184..83c4904 100644 --- a/KunQuant/passes/CodegenCpp.py +++ b/KunQuant/passes/CodegenCpp.py @@ -92,7 +92,7 @@ def _generate_cross_sectional_func_name(op: GenericCrossSectionalOp, inputs: Lis name.append(layout) return f"{op.__class__.__name__}_{'_'.join(name)}" -def codegen_cpp(prefix: str, f: Function, input_name_to_idx: Dict[str, int], inputs: List[Tuple[Input, bool]], outputs: List[Tuple[Output, bool]], options: dict, stream_mode: bool, query_temp_buffer_id, stream_window_size: Dict[str, int], generated_cross_sectional_func: Set[str], elem_type: str, simd_lanes: int, aligned: bool, static: bool) -> Tuple[str, str]: +def codegen_cpp(prefix: str, f: Function, input_name_to_idx: Dict[str, int], inputs: List[Tuple[Input, bool]], outputs: List[Tuple[Output, bool]], options: dict, stream_mode: bool, query_temp_buffer_id, stream_window_size: Dict[str, int], stream_state_buffer_init: List[str], generated_cross_sectional_func: Set[str], elem_type: str, simd_lanes: int, aligned: bool, static: bool) -> Tuple[str, str]: if len(f.ops) == 3 and isinstance(f.ops[1], SimpleCrossSectionalOp): return "", f'''static auto stage_{prefix}__{f.name} = {f.ops[1].__class__.__name__}Stocks, Mapper{f.ops[2].attrs["layout"]}<{elem_type}, {simd_lanes}>>;''' @@ -281,14 +281,19 @@ def codegen_cpp(prefix: str, f: Function, input_name_to_idx: Dict[str, int], inp funcname = "SkipListArgMin" scope.scope.append(_CppSingleLine(scope, f'auto v{idx} = {funcname}<{elem_type}, {simd_lanes}>(v{inp[0]}, i);')) elif isinstance(op, GloablStatefulOpTrait): - if stream_mode: raise RuntimeError(f"Stream Mode does not support {op.__class__.__name__}") assert(op.get_parent() is None) args = {} if isinstance(op, WindowedTrait): buf_name = _get_buffer_name(op.inputs[0], inp[0]) args["buf_name"] = buf_name vargs = [f"v{inpv}" for inpv in inp] - toplevel.scope.insert(-1, _CppSingleLine(toplevel, op.generate_init_code(idx, elem_type, simd_lanes, vargs, aligned))) + if stream_mode: + cur_idx = len(stream_state_buffer_init) + var_init_code, init_buffer_code = op.generate_init_code_stream(idx, cur_idx, elem_type, simd_lanes, vargs, aligned) + toplevel.scope.insert(-1, _CppSingleLine(toplevel, var_init_code)) + stream_state_buffer_init.append(init_buffer_code) + else: + toplevel.scope.insert(-1, _CppSingleLine(toplevel, op.generate_init_code(idx, elem_type, simd_lanes, vargs, aligned))) scope.scope.append(_CppSingleLine(scope, op.generate_step_code(idx, "i", vargs, **args))) elif isinstance(op, Select): scope.scope.append(_CppSingleLine(scope, f"auto v{idx} = Select(v{inp[0]}, v{inp[1]}, v{inp[2]});")) diff --git a/cpp/Kun/Context.hpp b/cpp/Kun/Context.hpp index a33ed37..76271fc 100644 --- a/cpp/Kun/Context.hpp +++ b/cpp/Kun/Context.hpp @@ -2,6 +2,7 @@ #include "Stage.hpp" #include "StreamBuffer.hpp" +#include "StateBuffer.hpp" #include #include #include @@ -9,6 +10,7 @@ namespace kun { +static const uint64_t VERSION = 0x64100004; struct KUN_API RuntimeStage { const Stage *stage; Context *ctx; @@ -127,6 +129,7 @@ struct Context { size_t simd_len; Datatype dtype; bool is_stream; + StateBufferPtr* state_buffers; }; KUN_API std::shared_ptr createSingleThreadExecutor(); diff --git a/cpp/Kun/MathUtil.hpp b/cpp/Kun/MathUtil.hpp new file mode 100644 index 0000000..71211a5 --- /dev/null +++ b/cpp/Kun/MathUtil.hpp @@ -0,0 +1,10 @@ +#pragma once +#include + +namespace kun { +namespace { +size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; } +size_t roundUp(size_t x, size_t y) { return divideAndCeil(x, y) * y; } + +} // namespace +} // namespace kun \ No newline at end of file diff --git a/cpp/Kun/Module.hpp b/cpp/Kun/Module.hpp index 4e36c90..da65d4a 100644 --- a/cpp/Kun/Module.hpp +++ b/cpp/Kun/Module.hpp @@ -1,8 +1,9 @@ #pragma once #include "Stage.hpp" -#include +#include "StateBuffer.hpp" #include +#include namespace kun { @@ -12,7 +13,6 @@ enum class MemoryLayout { STREAM, }; - struct Module { size_t required_version; size_t num_stages; @@ -24,11 +24,12 @@ struct Module { size_t blocking_len; Datatype dtype; size_t aligned; + std::vector (*init_state_buffers)(size_t num_stocks); }; struct Library { void *handle; - std::function dtor; + std::function dtor; KUN_API const Module *getModule(const char *name); KUN_API static std::shared_ptr load(const char *filename); Library(const Library &) = delete; diff --git a/cpp/Kun/Ops.hpp b/cpp/Kun/Ops.hpp index eac8615..4923384 100644 --- a/cpp/Kun/Ops.hpp +++ b/cpp/Kun/Ops.hpp @@ -348,6 +348,7 @@ struct ExpMovingAvg { using simd_int_t = kun_simd::vec::int_t, stride>; simd_t v; + ExpMovingAvg() : v{NAN} {} ExpMovingAvg(const simd_t &init) : v{init} {} static constexpr T weight_latest = T(2.0) / (window + 1); simd_t step(simd_t cur, size_t index) { diff --git a/cpp/Kun/RunGraph.hpp b/cpp/Kun/RunGraph.hpp index 966fafd..3e1ab14 100644 --- a/cpp/Kun/RunGraph.hpp +++ b/cpp/Kun/RunGraph.hpp @@ -2,6 +2,7 @@ #include "Context.hpp" #include "Module.hpp" +#include "StateBuffer.hpp" #include #include #include @@ -34,6 +35,7 @@ struct AlignedPtr { struct KUN_API StreamContext { std::vector buffers; + std::vector state_buffers; Context ctx; const Module *m; StreamContext(std::shared_ptr exec, const Module *m, diff --git a/cpp/Kun/Runtime.cpp b/cpp/Kun/Runtime.cpp index 0dc784b..38c2fa1 100644 --- a/cpp/Kun/Runtime.cpp +++ b/cpp/Kun/Runtime.cpp @@ -18,12 +18,6 @@ #define kunAlignedFree(x) free(x) #endif -#ifdef __AVX__ -#define MALLOC_ALIGNMENT 64 // AVX-512 alignment -#else -#define MALLOC_ALIGNMENT 16 // NEON alignment -#endif - #if CHECKED_PTR #include #include @@ -76,11 +70,10 @@ void checkedDealloc(void *ptr, size_t sz) { #endif namespace kun { -static const uint64_t VERSION = 0x64100003; void Buffer::alloc(size_t count, size_t use_count, size_t elem_size) { if (!ptr) { - ptr = (float *)kunAlignedAlloc(MALLOC_ALIGNMENT, count * elem_size); + ptr = (float *)kunAlignedAlloc(KUN_MALLOC_ALIGNMENT, count * elem_size); refcount = (int)use_count; #if CHECKED_PTR size = count * elem_size; @@ -230,7 +223,7 @@ void corrWith(std::shared_ptr exec, MemoryLayout layout, length, 8, Datatype::Float, - false}; + false, nullptr}; std::vector &stages = ctx.stages; stages.reserve(buffers.size()); for (size_t i = 0; i < buffers.size(); i++) { @@ -284,7 +277,7 @@ void runGraph(std::shared_ptr exec, const Module *m, length, m->blocking_len, m->dtype, - false}; + false, nullptr}; std::vector &stages = ctx.stages; stages.reserve(m->num_stages); for (size_t i = 0; i < m->num_stages; i++) { @@ -368,6 +361,12 @@ StreamContext::StreamContext(std::shared_ptr exec, const Module *m, throw std::runtime_error( "Cannot run batch mode module via StreamContext"); } + if (m->init_state_buffers) { + state_buffers = m->init_state_buffers(num_stocks); + for (auto &buf : state_buffers) { + buf->initialize(); + } + } std::vector rtlbuffers; rtlbuffers.reserve(m->num_buffers); buffers.reserve(m->num_buffers); @@ -404,6 +403,7 @@ StreamContext::StreamContext(std::shared_ptr exec, const Module *m, ctx.dtype = m->dtype; ctx.is_stream = true; ctx.simd_len = m->blocking_len; + ctx.state_buffers = state_buffers.data(); } size_t StreamContext::queryBufferHandle(const char *name) const { @@ -458,4 +458,25 @@ void StreamContext::run() { } StreamContext::~StreamContext() = default; + + +StateBuffer *StateBuffer::make(size_t num_objs, size_t elem_size, + CtorFn_t ctor_fn, DtorFn_t dtor_fn) { + auto ret = kunAlignedAlloc(KUN_MALLOC_ALIGNMENT, + sizeof(StateBuffer) + num_objs * elem_size); + auto buf = (StateBuffer *)ret; + buf->num_objs = num_objs; + buf->elem_size = elem_size; + buf->initialized = 0; + buf->ctor_fn = ctor_fn; + buf->dtor_fn = dtor_fn; + return buf; +} + +void StateBuffer::Deleter::operator()(StateBuffer *buf) { + if (buf->initialized) { + buf->dtor_fn(buf); + } + kunAlignedFree(buf); +} } // namespace kun diff --git a/cpp/Kun/StateBuffer.hpp b/cpp/Kun/StateBuffer.hpp new file mode 100644 index 0000000..ca588cd --- /dev/null +++ b/cpp/Kun/StateBuffer.hpp @@ -0,0 +1,72 @@ +#pragma once + +#include "Base.hpp" +#include "MathUtil.hpp" +#include +#include + +namespace kun { + +#ifdef __AVX__ +#define KUN_MALLOC_ALIGNMENT 64 // AVX-512 alignment +#else +#define KUN_MALLOC_ALIGNMENT 16 // NEON alignment +#endif + +struct StateBuffer { + using DtorFn_t = void (*)(StateBuffer *obj); + using CtorFn_t = void (*)(StateBuffer *obj); + + alignas(KUN_MALLOC_ALIGNMENT) size_t num_objs; + uint32_t elem_size; + uint32_t initialized; + CtorFn_t ctor_fn; + DtorFn_t dtor_fn; + alignas(KUN_MALLOC_ALIGNMENT) char buf[0]; + + KUN_API static StateBuffer *make(size_t num_objs, size_t elem_size, + CtorFn_t ctor_fn, DtorFn_t dtor_fn); + + // for std::unique_ptr + struct Deleter { + KUN_API void operator()(StateBuffer *buf); + }; + + template + T &get(size_t idx) { + return *reinterpret_cast(buf + idx * sizeof(T)); + } + + void initialize() { + initialized = 1; + ctor_fn(this); + } + void destroy() { + if (initialized) { + dtor_fn(this); + } + initialized = 0; + } + + private: + StateBuffer() = default; +}; + +using StateBufferPtr = std::unique_ptr; + +template +StateBufferPtr makeStateBuffer(size_t num_stocks, size_t simd_len) { + return StateBufferPtr(StateBuffer::make( + divideAndCeil(num_stocks, simd_len), sizeof(T), + [](StateBuffer *obj) { + for (size_t i = 0; i < obj->num_objs; i++) { + new (&obj->get(i)) T(); + } + }, + [](StateBuffer *obj) { + for (size_t i = 0; i < obj->num_objs; i++) { + obj->get(i).~T(); + } + })); +} +} // namespace kun \ No newline at end of file diff --git a/cpp/Kun/StreamBuffer.hpp b/cpp/Kun/StreamBuffer.hpp index 57e5e7b..0c67529 100644 --- a/cpp/Kun/StreamBuffer.hpp +++ b/cpp/Kun/StreamBuffer.hpp @@ -1,14 +1,10 @@ #pragma once #include "Base.hpp" +#include "MathUtil.hpp" #include #include namespace kun { -namespace { -size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; } -size_t roundUp(size_t x, size_t y) { return divideAndCeil(x, y) * y; } - -} // namespace template struct StreamBuffer { // [#stock_count of float data] diff --git a/cpp/Python/PyBinding.cpp b/cpp/Python/PyBinding.cpp index 5616771..5edaa94 100644 --- a/cpp/Python/PyBinding.cpp +++ b/cpp/Python/PyBinding.cpp @@ -55,12 +55,13 @@ struct ModuleHandle { const std::shared_ptr &lib) : modu{modu}, lib{lib} {} }; -struct StreamContextWrapper : kun::StreamContext { +struct StreamContextWrapper { std::shared_ptr lib; + kun::StreamContext ctx; StreamContextWrapper(std::shared_ptr exec, const ModuleHandle *m, size_t num_stocks) - : kun::StreamContext{std::move(exec), m->modu, num_stocks}, - lib{m->lib} {} + : lib{m->lib}, ctx{std::move(exec), m->modu, num_stocks} + {} }; } // namespace @@ -402,9 +403,13 @@ PYBIND11_MODULE(KunRunner, m) { py::class_(m, "StreamContext") .def(py::init, const ModuleHandle *, size_t>()) - .def("queryBufferHandle", &StreamContextWrapper::queryBufferHandle) + .def("queryBufferHandle", + [](StreamContextWrapper &t, const char *name) { + return t.ctx.queryBufferHandle(name); + }) .def("getCurrentBuffer", - [](StreamContextWrapper &ths, size_t handle) -> py::buffer { + [](StreamContextWrapper &t, size_t handle) -> py::buffer { + auto &ths = t.ctx; if (ths.m->dtype == kun::Datatype::Double) { auto buf = ths.getCurrentBufferPtrDouble(handle); return py::array_t{ @@ -416,7 +421,8 @@ PYBIND11_MODULE(KunRunner, m) { }) .def( "pushData", - [](StreamContextWrapper &ths, size_t handle, py::array data) { + [](StreamContextWrapper &t, size_t handle, py::array data) { + auto &ths = t.ctx; py::ssize_t ndim; if (ths.m->dtype == kun::Datatype::Float) { if (!py::isinstance>( @@ -442,5 +448,5 @@ PYBIND11_MODULE(KunRunner, m) { ths.pushData(handle, (const double *)data.data()); } }) - .def("run", &StreamContextWrapper::run); + .def("run", [](StreamContextWrapper &t) { t.ctx.run(); }); } \ No newline at end of file diff --git a/tests/cpp/TestRuntime.cpp b/tests/cpp/TestRuntime.cpp index 4d9b570..206793c 100644 --- a/tests/cpp/TestRuntime.cpp +++ b/tests/cpp/TestRuntime.cpp @@ -87,7 +87,7 @@ Stage *stage2_dep[] = {&stages[2]}; } // namespace KUN_EXPORT Module testRuntimeModule{ - 0x64100003, + kun::VERSION, arraySize(stages), stages, arraySize(buffers), diff --git a/tests/test_stream.py b/tests/test_stream.py new file mode 100644 index 0000000..78e19a8 --- /dev/null +++ b/tests/test_stream.py @@ -0,0 +1,47 @@ +import numpy as np +import pandas as pd +from KunQuant.jit import cfake +from KunQuant.Op import Input, Output, Builder +from KunQuant.Stage import Function +from KunQuant.Op import * +from KunQuant.ops import * +from KunQuant.runner import KunRunner as kr + + +def test_stream(): + builder = Builder() + with builder: + inp1 = Input("a") + Output(WindowedQuantile(inp1, 10, 0.49), "quantile") + Output(ExpMovingAvg(inp1, 10), "ema") + Output(WindowedLinearRegressionSlope(inp1, 10), "slope") + f = Function(builder.ops) + lib = cfake.compileit([("stream_test", f, cfake.KunCompilerConfig(dtype="double", input_layout="STREAM", output_layout="STREAM"))], + "stream_test", cfake.CppCompilerConfig()) + + executor = kr.createSingleThreadExecutor() + stream = kr.StreamContext(executor, lib.getModule("stream_test"), 24) + a = np.random.rand(100, 24) + handle_a = stream.queryBufferHandle("a") + handle_quantile = stream.queryBufferHandle("quantile") + handle_ema = stream.queryBufferHandle("ema") + handle_slope = stream.queryBufferHandle("slope") + out = np.empty((100, 24)) + ema = np.empty((100, 24)) + slope = np.empty((100, 24)) + for i in range(100): + stream.pushData(handle_a, a[i]) + stream.run() + out[i] = stream.getCurrentBuffer(handle_quantile) + ema[i] = stream.getCurrentBuffer(handle_ema) + slope[i] = stream.getCurrentBuffer(handle_slope) + df = pd.DataFrame(a) + expected_quantile = df.rolling(10).quantile(0.49, interpolation='linear').to_numpy() + expected_ema = df.ewm(span=10, adjust=False, ignore_na=True).mean().to_numpy() + expected_slope = df.rolling(10).apply(lambda x: np.polyfit(np.arange(len(x)), x, 1)[0]).to_numpy() + np.testing.assert_allclose(out, expected_quantile, atol=1e-6, rtol=1e-4, equal_nan=True) + np.testing.assert_allclose(ema, expected_ema, atol=1e-6, rtol=1e-4, equal_nan=True) + np.testing.assert_allclose(slope[10:], expected_slope[10:], atol=1e-6, rtol=1e-4, equal_nan=True) + +test_stream() +print("test_stream passed") \ No newline at end of file diff --git a/tests/tests.sh b/tests/tests.sh index 5c80b8d..85902b8 100644 --- a/tests/tests.sh +++ b/tests/tests.sh @@ -4,6 +4,8 @@ python tests/test.py python tests/test2.py echo "KunQuant runtime tests" python tests/test_runtime.py +echo "KunQuant stream tests" +python tests/test_stream.py echo "KunQuant runtime tests (AVX)" KUN_TEST_NO_AVX2=1 python tests/test_runtime.py echo "KunQuant alpha101 tests" diff --git a/tests/tests_arm.sh b/tests/tests_arm.sh index 8d61eba..58427a2 100644 --- a/tests/tests_arm.sh +++ b/tests/tests_arm.sh @@ -17,6 +17,8 @@ python tests/test.py python tests/test2.py echo "KunQuant runtime tests" python tests/test_runtime.py +echo "KunQuant stream tests" +python tests/test_stream.py echo "KunQuant alpha101 tests" python tests/test_alpha101.py arm echo "KunQuant alpha158 tests" From 9f8fc9d2d3a6d9f0ab72039b65b5f7bdec49123b Mon Sep 17 00:00:00 2001 From: Menooker Date: Wed, 17 Dec 2025 20:22:27 +0800 Subject: [PATCH 02/11] fix py3.9 --- KunQuant/Driver.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/KunQuant/Driver.py b/KunQuant/Driver.py index d0c1966..b9ec17c 100644 --- a/KunQuant/Driver.py +++ b/KunQuant/Driver.py @@ -312,10 +312,11 @@ def query_temp_buf_id(tempname: str, window: int) -> int: ''') if len(stream_state_buffer_init) > 0: + stream_state_str = "\n ".join(stream_state_buffer_init) impl_src.append(f'''static std::vector __init_state_buffers(size_t stock_count) {{ std::vector buffers; buffers.reserve({len(stream_state_buffer_init)}); - {"\n ".join(stream_state_buffer_init)} + {stream_state_str} return buffers; }} ''') From eaf8762bd3b22dc0befd458c4110fb019d4e4d46 Mon Sep 17 00:00:00 2001 From: Menooker Date: Wed, 17 Dec 2025 20:26:46 +0800 Subject: [PATCH 03/11] add windows stream test --- .github/workflows/ccpp.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ccpp.yml b/.github/workflows/ccpp.yml index 86fe071..8131ea1 100644 --- a/.github/workflows/ccpp.yml +++ b/.github/workflows/ccpp.yml @@ -106,6 +106,7 @@ jobs: python tests/test.py python tests/test2.py python tests/test_runtime.py + python tests/test_stream.py - name: Alpha158 test working-directory: ./ run: | From a8f46a7b4b260ccd9b243c25f49e2d25b17a6ccb Mon Sep 17 00:00:00 2001 From: Menooker Date: Wed, 17 Dec 2025 22:35:08 +0800 Subject: [PATCH 04/11] stream ser deser --- cpp/Kun/IO.cpp | 39 ++++++++++++++++++++++++ cpp/Kun/IO.hpp | 38 +++++++++++++++++++++++ cpp/Kun/Ops.hpp | 31 +++++++++++++++++++ cpp/Kun/Ops/Quantile.hpp | 27 +++++++++++++++++ cpp/Kun/RunGraph.hpp | 3 +- cpp/Kun/Runtime.cpp | 29 +++++++++++++++--- cpp/Kun/SkipList.cpp | 65 ++++++++++++++++++++++++++++++++++++++++ cpp/Kun/SkipList.hpp | 18 +++++++++-- cpp/Kun/StateBuffer.hpp | 43 +++++++++++++++----------- cpp/Python/PyBinding.cpp | 45 ++++++++++++++++++++++++++-- tests/test_stream.py | 37 +++++++++++++++++++++-- 11 files changed, 346 insertions(+), 29 deletions(-) create mode 100644 cpp/Kun/IO.cpp create mode 100644 cpp/Kun/IO.hpp diff --git a/cpp/Kun/IO.cpp b/cpp/Kun/IO.cpp new file mode 100644 index 0000000..e9d679b --- /dev/null +++ b/cpp/Kun/IO.cpp @@ -0,0 +1,39 @@ +#include "IO.hpp" +#include + +namespace kun { +bool MemoryInputStream::read(void *buf, size_t len) { + if (pos + len > size) { + return false; + } + std::memcpy(buf, data + pos, len); + pos += len; + return true; +} + +bool MemoryOutputStream::write(const void *buf, size_t len) { + const char *cbuf = reinterpret_cast(buf); + buffer.insert(buffer.end(), cbuf, cbuf + len); + return true; +} + +FileInputStream::FileInputStream(const std::string &filename) + : file(filename, std::ios::binary) {} + +bool FileInputStream::read(void *buf, size_t len) { + if (!file.read(reinterpret_cast(buf), len)) { + return false; + } + return true; +} + +FileOutputStream::FileOutputStream(const std::string &filename) + : file(filename, std::ios::binary) {} + +bool FileOutputStream::write(const void *buf, size_t len) { + if (!file.write(reinterpret_cast(buf), len)) { + return false; + } + return true; +} +} // namespace kun \ No newline at end of file diff --git a/cpp/Kun/IO.hpp b/cpp/Kun/IO.hpp new file mode 100644 index 0000000..5187338 --- /dev/null +++ b/cpp/Kun/IO.hpp @@ -0,0 +1,38 @@ +#pragma once + +#include "StateBuffer.hpp" +#include +#include + +namespace kun { +struct KUN_API MemoryInputStream final : public InputStreamBase { + const char *data; + size_t size; + size_t pos; + + MemoryInputStream(const char *data, size_t size) + : data(data), size(size), pos(0) {} + + bool read(void *buf, size_t len) override; +}; +struct KUN_API MemoryOutputStream final : public OutputStreamBase { + std::vector buffer; + + bool write(const void *buf, size_t len) override; + + const char *getData() const { return buffer.data(); } + size_t getSize() const { return buffer.size(); } +}; +struct KUN_API FileInputStream final : public InputStreamBase { + std::ifstream file; + + FileInputStream(const std::string &filename); + + bool read(void *buf, size_t len) override; +}; +struct KUN_API FileOutputStream final : public OutputStreamBase { + std::ofstream file; + FileOutputStream(const std::string &filename); + bool write(const void *buf, size_t len) override; +}; +} // namespace kun diff --git a/cpp/Kun/Ops.hpp b/cpp/Kun/Ops.hpp index 4923384..457a2da 100644 --- a/cpp/Kun/Ops.hpp +++ b/cpp/Kun/Ops.hpp @@ -3,6 +3,7 @@ #include "Base.hpp" #include "Math.hpp" #include "StreamBuffer.hpp" +#include "StateBuffer.hpp" #include #include #include @@ -11,6 +12,18 @@ namespace kun { namespace ops { + +template +struct Serializer { + static bool serialize(StateBuffer *obj, OutputStreamBase *stream) { + return stream->write(obj->buf, sizeof(T) * obj->num_objs); + } + + static bool deserialize(StateBuffer *obj, InputStreamBase *stream) { + return stream->read(obj->buf, sizeof(T) * obj->num_objs); + } +}; + template struct DataSource { constexpr static bool containsWindow = vcontainsWindow; @@ -706,5 +719,23 @@ inline DecayVec_t SetInfOrNanToValue(T aa, T2 v) { return sc_select(mask, v, a); } + +template +StateBufferPtr makeStateBuffer(size_t num_stocks, size_t simd_len) { + return StateBufferPtr(StateBuffer::make( + divideAndCeil(num_stocks, simd_len), sizeof(T), + [](StateBuffer *obj) { + for (size_t i = 0; i < obj->num_objs; i++) { + new (&obj->get(i)) T(); + } + }, + [](StateBuffer *obj) { + for (size_t i = 0; i < obj->num_objs; i++) { + obj->get(i).~T(); + } + }, Serializer::serialize, + Serializer::deserialize)); +} + } // namespace ops } // namespace kun \ No newline at end of file diff --git a/cpp/Kun/Ops/Quantile.hpp b/cpp/Kun/Ops/Quantile.hpp index 2511126..72edf42 100644 --- a/cpp/Kun/Ops/Quantile.hpp +++ b/cpp/Kun/Ops/Quantile.hpp @@ -48,6 +48,33 @@ struct SkipListState : SkipListStateImpl { }; } // namespace + +template +struct Serializer> { + static bool serialize(StateBuffer *obj, OutputStreamBase *stream) { + for(size_t i = 0; i < obj->num_objs; i++) { + auto &state = obj->get>(i); + if (!serializeSkipList(state.skipList, state.lastInsertRank, simdLen, + expectedwindow, stream)) { + return false; + } + } + return true; + } + + static bool deserialize(StateBuffer *obj, InputStreamBase *stream) { + for(size_t i = 0; i < obj->num_objs; i++) { + auto &state = obj->get>(i); + new (&state) SkipListState{}; + if (!deserializeSkipList(state.skipList, state.lastInsertRank, simdLen, + expectedwindow, stream)) { + return false; + } + } + return true; + } +}; + // https://github.com/pandas-dev/pandas/blob/main/pandas/_libs/window/aggregations.pyx template diff --git a/cpp/Kun/RunGraph.hpp b/cpp/Kun/RunGraph.hpp index 3e1ab14..52e8804 100644 --- a/cpp/Kun/RunGraph.hpp +++ b/cpp/Kun/RunGraph.hpp @@ -39,7 +39,7 @@ struct KUN_API StreamContext { Context ctx; const Module *m; StreamContext(std::shared_ptr exec, const Module *m, - size_t num_stocks); + size_t num_stocks, InputStreamBase* states = nullptr); // query the buffer handle of a named buffer size_t queryBufferHandle(const char *name) const; // get the current readable position of the named buffer. The returned @@ -54,6 +54,7 @@ struct KUN_API StreamContext { StreamContext(const StreamContext&) = delete; StreamContext& operator=(const StreamContext&) = delete; ~StreamContext(); + bool serializeStates(OutputStreamBase* stream); }; } // namespace kun \ No newline at end of file diff --git a/cpp/Kun/Runtime.cpp b/cpp/Kun/Runtime.cpp index 38c2fa1..255eb59 100644 --- a/cpp/Kun/Runtime.cpp +++ b/cpp/Kun/Runtime.cpp @@ -351,7 +351,7 @@ template struct StreamBuffer; template struct StreamBuffer; StreamContext::StreamContext(std::shared_ptr exec, const Module *m, - size_t num_stocks) + size_t num_stocks, InputStreamBase* states) : m{m} { if (m->required_version != VERSION) { throw std::runtime_error("The required version in the module does not " @@ -363,8 +363,17 @@ StreamContext::StreamContext(std::shared_ptr exec, const Module *m, } if (m->init_state_buffers) { state_buffers = m->init_state_buffers(num_stocks); - for (auto &buf : state_buffers) { - buf->initialize(); + if (states) { + for (auto &buf : state_buffers) { + if (!buf->deserialize(states)) { + throw std::runtime_error( + "Failed to deserialize state buffer"); + } + } + } else { + for (auto &buf : state_buffers) { + buf->initialize(); + } } } std::vector rtlbuffers; @@ -460,8 +469,18 @@ void StreamContext::run() { StreamContext::~StreamContext() = default; +bool StreamContext::serializeStates(OutputStreamBase* stream) { + for (auto& ptr: state_buffers) { + if (!ptr->serialize(stream)) { + return false; + } + } + return true; +} + StateBuffer *StateBuffer::make(size_t num_objs, size_t elem_size, - CtorFn_t ctor_fn, DtorFn_t dtor_fn) { + CtorFn_t ctor_fn, DtorFn_t dtor_fn, SerializeFn_t serialize_fn, + DeserializeFn_t deserialize_fn) { auto ret = kunAlignedAlloc(KUN_MALLOC_ALIGNMENT, sizeof(StateBuffer) + num_objs * elem_size); auto buf = (StateBuffer *)ret; @@ -470,6 +489,8 @@ StateBuffer *StateBuffer::make(size_t num_objs, size_t elem_size, buf->initialized = 0; buf->ctor_fn = ctor_fn; buf->dtor_fn = dtor_fn; + buf->serialize_fn = serialize_fn; + buf->deserialize_fn = deserialize_fn; return buf; } diff --git a/cpp/Kun/SkipList.cpp b/cpp/Kun/SkipList.cpp index f216bf3..cb352b9 100644 --- a/cpp/Kun/SkipList.cpp +++ b/cpp/Kun/SkipList.cpp @@ -17,6 +17,7 @@ Python recipe (https://rhettinger.wordpress.com/2010/02/06/lost-knowledge/) */ #include "SkipList.hpp" +#include "StateBuffer.hpp" #include #include #include @@ -294,4 +295,68 @@ double SkipList::get(int rank, size_t &index, bool &found) { } int SkipList::size() const { return impl->size; } + +bool SkipList::serialize(OutputStreamBase *stream, int expsize) const { + if (!stream->write(&impl->size, sizeof(impl->size))) { + return false; + } + for(int i=0;isize;i++) { + size_t index; + bool found; + double value = impl->get(i, index, found); + if (!found) { + return false; + } + if (!stream->write(&value, sizeof(value))) { + return false; + } + if (!stream->write(&index, sizeof(index))) { + return false; + } + } + return true; +} + +bool SkipList::deserialize(InputStreamBase *stream, int expsize) { + size_t size; + if (!stream->read(&size, sizeof(size))) { + return false; + } + for (size_t i = 0; i < size; i++) { + double value; + size_t index; + if (!stream->read(&value, sizeof(value))) { + return false; + } + if (!stream->read(&index, sizeof(index))) { + return false; + } + impl->insert(value, index); + } + return true; +} + + +bool serializeSkipList(SkipList* skiplist, int* lastInsertRank, size_t size, size_t window, OutputStreamBase* stream) { + if (!stream->write(lastInsertRank, size * sizeof(int))) { + return false; + } + for (size_t i = 0; i < size; i++) { + if (!skiplist->serialize(stream, window)) { + return false; + } + } + return true; +} +bool deserializeSkipList(SkipList* skiplist, int* lastInsertRank, size_t size, size_t window, InputStreamBase* stream) { + if (!stream->read(lastInsertRank, size * sizeof(int))) { + return false; + } + for (size_t i = 0; i < size; i++) { + if (!skiplist[i].deserialize(stream, window)) { + return false; + } + } + return true; +} } // namespace kun \ No newline at end of file diff --git a/cpp/Kun/SkipList.hpp b/cpp/Kun/SkipList.hpp index f58da9a..2bc1ac8 100644 --- a/cpp/Kun/SkipList.hpp +++ b/cpp/Kun/SkipList.hpp @@ -6,9 +6,13 @@ namespace kun { namespace detail { - struct SkipListImpl; +struct SkipListImpl; } +struct OutputStreamBase; +struct InputStreamBase; + + struct KUN_API SkipList { std::unique_ptr impl; SkipList(int size); @@ -22,8 +26,16 @@ struct KUN_API SkipList { // remove the first inserted element with the given value bool remove(double value); int minRank(double value); - double get(int rank, size_t& index, bool& found); + double get(int rank, size_t &index, bool &found); int size() const; + bool serialize(OutputStreamBase *stream, int expsize) const; + bool deserialize(InputStreamBase *stream, int expsize); }; -} \ No newline at end of file +KUN_API bool serializeSkipList(SkipList *skiplist, int *lastInsertRank, size_t size, + size_t window, OutputStreamBase *stream); +KUN_API bool deserializeSkipList(/*uninitialized*/ SkipList *skiplist, + int *lastInsertRank, size_t size, size_t window, + InputStreamBase *stream); + +} // namespace kun \ No newline at end of file diff --git a/cpp/Kun/StateBuffer.hpp b/cpp/Kun/StateBuffer.hpp index ca588cd..da461d3 100644 --- a/cpp/Kun/StateBuffer.hpp +++ b/cpp/Kun/StateBuffer.hpp @@ -13,19 +13,34 @@ namespace kun { #define KUN_MALLOC_ALIGNMENT 16 // NEON alignment #endif +struct InputStreamBase { + virtual bool read(void* buf, size_t len) = 0; + virtual ~InputStreamBase() = default; +}; + +struct OutputStreamBase { + virtual bool write(const void* buf, size_t len) = 0; + virtual ~OutputStreamBase() = default; +}; + struct StateBuffer { using DtorFn_t = void (*)(StateBuffer *obj); using CtorFn_t = void (*)(StateBuffer *obj); + using SerializeFn_t = bool (*)(StateBuffer *obj, OutputStreamBase *stream); + using DeserializeFn_t = bool (*)(StateBuffer *obj, InputStreamBase *stream); alignas(KUN_MALLOC_ALIGNMENT) size_t num_objs; uint32_t elem_size; uint32_t initialized; CtorFn_t ctor_fn; DtorFn_t dtor_fn; + SerializeFn_t serialize_fn; + DeserializeFn_t deserialize_fn; alignas(KUN_MALLOC_ALIGNMENT) char buf[0]; KUN_API static StateBuffer *make(size_t num_objs, size_t elem_size, - CtorFn_t ctor_fn, DtorFn_t dtor_fn); + CtorFn_t ctor_fn, DtorFn_t dtor_fn, SerializeFn_t serialize_fn, + DeserializeFn_t deserialize_fn); // for std::unique_ptr struct Deleter { @@ -47,26 +62,20 @@ struct StateBuffer { } initialized = 0; } - + bool serialize(OutputStreamBase *stream) { + return serialize_fn(this, stream); + } + bool deserialize(InputStreamBase *stream) { + if (deserialize_fn(this, stream)) { + initialized = 1; + return true; + } + return false; + } private: StateBuffer() = default; }; using StateBufferPtr = std::unique_ptr; -template -StateBufferPtr makeStateBuffer(size_t num_stocks, size_t simd_len) { - return StateBufferPtr(StateBuffer::make( - divideAndCeil(num_stocks, simd_len), sizeof(T), - [](StateBuffer *obj) { - for (size_t i = 0; i < obj->num_objs; i++) { - new (&obj->get(i)) T(); - } - }, - [](StateBuffer *obj) { - for (size_t i = 0; i < obj->num_objs; i++) { - obj->get(i).~T(); - } - })); -} } // namespace kun \ No newline at end of file diff --git a/cpp/Python/PyBinding.cpp b/cpp/Python/PyBinding.cpp index 5edaa94..fb7d6e3 100644 --- a/cpp/Python/PyBinding.cpp +++ b/cpp/Python/PyBinding.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #ifdef _WIN32 @@ -59,8 +60,8 @@ struct StreamContextWrapper { std::shared_ptr lib; kun::StreamContext ctx; StreamContextWrapper(std::shared_ptr exec, - const ModuleHandle *m, size_t num_stocks) - : lib{m->lib}, ctx{std::move(exec), m->modu, num_stocks} + const ModuleHandle *m, size_t num_stocks, kun::InputStreamBase* states = nullptr) + : lib{m->lib}, ctx{std::move(exec), m->modu, num_stocks, states} {} }; } // namespace @@ -403,6 +404,27 @@ PYBIND11_MODULE(KunRunner, m) { py::class_(m, "StreamContext") .def(py::init, const ModuleHandle *, size_t>()) + .def(py::init([](std::shared_ptr exec, + const ModuleHandle *mod, size_t stocks, + py::object init) { + if (py::isinstance(init)) { + auto filename = py::cast(init); + kun::FileInputStream stream(filename); + return new StreamContextWrapper(std::move(exec), mod, stocks, + &stream); + } else if (py::isinstance(init)) { + py::bytes b = py::cast(init); + char *data; + py::ssize_t size; + if (PYBIND11_BYTES_AS_STRING_AND_SIZE(b.ptr(), &data, &size)) + throw std::runtime_error("Failed to get bytes data"); + kun::MemoryInputStream stream{data, (size_t)size}; + return new StreamContextWrapper(std::move(exec), mod, stocks, + &stream); + } + throw std::runtime_error( + "Bad type for init, expecting filename or bytes"); + })) .def("queryBufferHandle", [](StreamContextWrapper &t, const char *name) { return t.ctx.queryBufferHandle(name); @@ -448,5 +470,24 @@ PYBIND11_MODULE(KunRunner, m) { ths.pushData(handle, (const double *)data.data()); } }) + .def( + "serializeStates", + [](StreamContextWrapper &t, py::object fileNameOrNone) -> py::object { + if (py::isinstance(fileNameOrNone)) { + auto filename = py::cast(fileNameOrNone); + kun::FileOutputStream stream(filename); + if (!t.ctx.serializeStates(&stream)) { + throw std::runtime_error("Failed to serialize states"); + } + return py::none(); + } + kun::MemoryOutputStream stream; + if (!t.ctx.serializeStates(&stream)) { + throw std::runtime_error("Failed to serialize states"); + } + py::bytes b(stream.getData(), (py::ssize_t)stream.getSize()); + return b; + }, + py::arg("fileNameOrNone") = py::none()) .def("run", [](StreamContextWrapper &t) { t.ctx.run(); }); } \ No newline at end of file diff --git a/tests/test_stream.py b/tests/test_stream.py index 78e19a8..4d3439e 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -1,3 +1,4 @@ +import os import numpy as np import pandas as pd from KunQuant.jit import cfake @@ -6,9 +7,10 @@ from KunQuant.Op import * from KunQuant.ops import * from KunQuant.runner import KunRunner as kr +from tempfile import NamedTemporaryFile +def make_steam(): -def test_stream(): builder = Builder() with builder: inp1 = Input("a") @@ -18,7 +20,9 @@ def test_stream(): f = Function(builder.ops) lib = cfake.compileit([("stream_test", f, cfake.KunCompilerConfig(dtype="double", input_layout="STREAM", output_layout="STREAM"))], "stream_test", cfake.CppCompilerConfig()) + return lib +def test_stream(lib): executor = kr.createSingleThreadExecutor() stream = kr.StreamContext(executor, lib.getModule("stream_test"), 24) a = np.random.rand(100, 24) @@ -35,6 +39,7 @@ def test_stream(): out[i] = stream.getCurrentBuffer(handle_quantile) ema[i] = stream.getCurrentBuffer(handle_ema) slope[i] = stream.getCurrentBuffer(handle_slope) + df = pd.DataFrame(a) expected_quantile = df.rolling(10).quantile(0.49, interpolation='linear').to_numpy() expected_ema = df.ewm(span=10, adjust=False, ignore_na=True).mean().to_numpy() @@ -43,5 +48,33 @@ def test_stream(): np.testing.assert_allclose(ema, expected_ema, atol=1e-6, rtol=1e-4, equal_nan=True) np.testing.assert_allclose(slope[10:], expected_slope[10:], atol=1e-6, rtol=1e-4, equal_nan=True) -test_stream() + + +def test_stream_ser_deser(lib): + executor = kr.createSingleThreadExecutor() + stream = kr.StreamContext(executor, lib.getModule("stream_test"), 24) + a = np.random.rand(100, 24) + handle_a = stream.queryBufferHandle("a") + handle_quantile = stream.queryBufferHandle("quantile") + handle_ema = stream.queryBufferHandle("ema") + handle_slope = stream.queryBufferHandle("slope") + out = np.empty((100, 24)) + ema = np.empty((100, 24)) + slope = np.empty((100, 24)) + + for i in range(50): + stream.pushData(handle_a, a[i]) + stream.run() + out[i] = stream.getCurrentBuffer(handle_quantile) + ema[i] = stream.getCurrentBuffer(handle_ema) + slope[i] = stream.getCurrentBuffer(handle_slope) + stream.serializeStates("kun_test_stream_state.bin") + states = stream.serializeStates() + with open("kun_test_stream_state.bin", "rb") as f: + assert(f.read() == states) + os.remove("kun_test_stream_state.bin") + print("length of serialized states:", len(states)) +lib = make_steam() +test_stream(lib) +test_stream_ser_deser(lib) print("test_stream passed") \ No newline at end of file From 3df6f1cd77abda87e1a3abe0c741fc4c7e499977 Mon Sep 17 00:00:00 2001 From: Menooker Date: Wed, 17 Dec 2025 23:09:54 +0800 Subject: [PATCH 05/11] stash --- cpp/Kun/Runtime.cpp | 48 ++++++++++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/cpp/Kun/Runtime.cpp b/cpp/Kun/Runtime.cpp index 255eb59..e645d8e 100644 --- a/cpp/Kun/Runtime.cpp +++ b/cpp/Kun/Runtime.cpp @@ -361,21 +361,6 @@ StreamContext::StreamContext(std::shared_ptr exec, const Module *m, throw std::runtime_error( "Cannot run batch mode module via StreamContext"); } - if (m->init_state_buffers) { - state_buffers = m->init_state_buffers(num_stocks); - if (states) { - for (auto &buf : state_buffers) { - if (!buf->deserialize(states)) { - throw std::runtime_error( - "Failed to deserialize state buffer"); - } - } - } else { - for (auto &buf : state_buffers) { - buf->initialize(); - } - } - } std::vector rtlbuffers; rtlbuffers.reserve(m->num_buffers); buffers.reserve(m->num_buffers); @@ -384,9 +369,15 @@ StreamContext::StreamContext(std::shared_ptr exec, const Module *m, auto &buf = m->buffers[i]; auto ptr = StreamBuffer::make(num_stocks, buf.window, m->blocking_len); - buffers.emplace_back( - ptr, StreamBuffer::getBufferSize(num_stocks, buf.window, - m->blocking_len)); + size_t buf_size = StreamBuffer::getBufferSize( + num_stocks, buf.window, m->blocking_len); + buffers.emplace_back(ptr, buf_size); + if (m->init_state_buffers) { + if (!states->read(ptr, buf_size)) { + throw std::runtime_error( + "Failed to read initial stream buffer"); + } + } rtlbuffers.emplace_back((float *)ptr, 1); } } else if (m->dtype == Datatype::Double) { @@ -397,11 +388,32 @@ StreamContext::StreamContext(std::shared_ptr exec, const Module *m, buffers.emplace_back( ptr, StreamBuffer::getBufferSize(num_stocks, buf.window, m->blocking_len)); + if (m->init_state_buffers) { + if (!states->read(ptr, buf_size)) { + throw std::runtime_error( + "Failed to read initial stream buffer"); + } + } rtlbuffers.emplace_back((float *)ptr, 1); } } else { throw std::runtime_error("Unknown type"); } + if (m->init_state_buffers) { + state_buffers = m->init_state_buffers(num_stocks); + if (states) { + for (auto &buf : state_buffers) { + if (!buf->deserialize(states)) { + throw std::runtime_error( + "Failed to deserialize state buffer"); + } + } + } else { + for (auto &buf : state_buffers) { + buf->initialize(); + } + } + } ctx.buffers = std::move(rtlbuffers); ctx.executor = exec; ctx.buffer_len = num_stocks * 1; From be17929288d97993f88d87e28323b7d10378e696 Mon Sep 17 00:00:00 2001 From: Menooker Date: Thu, 18 Dec 2025 23:17:28 +0800 Subject: [PATCH 06/11] test and c-api --- Stream.md | 33 ++++++++++++++++ cpp/Kun/CApi.cpp | 81 +++++++++++++++++++++++++++++++++++++++- cpp/Kun/CApi.h | 65 +++++++++++++++++++++++++++++++- cpp/Kun/IO.cpp | 8 ++++ cpp/Kun/IO.hpp | 14 ++++++- cpp/Kun/Ops/Quantile.hpp | 2 + cpp/Kun/RunGraph.hpp | 2 - cpp/Kun/Runtime.cpp | 60 ++++++++++++++--------------- cpp/Kun/SkipList.cpp | 9 +++-- cpp/Python/PyBinding.cpp | 2 +- tests/capi/test_c.cpp | 20 ++++++++++ tests/test_stream.py | 59 +++++++++++++++++++++++++++-- 12 files changed, 309 insertions(+), 46 deletions(-) diff --git a/Stream.md b/Stream.md index 149fec7..904080d 100644 --- a/Stream.md +++ b/Stream.md @@ -74,6 +74,39 @@ Basically, you need to call `pushData` for each input. Then call `run()` to let That's why in the above code, we immediately copy the `ndarray` returned by `getCurrentBuffer` with `[:]`. +## Serialize stream states and resuming from stream states + +You can get the stream states from the context object via `serializeStates` method after `run()` is called. + +```python +stream = kr.StreamContext(executor, modu, num_stock) +stream.pushData(...) +... +stream.run() +states: bytes = stream.serializeStates() +``` + +The method will return a `bytes` object representing a copy of the states of the stream, if no arguments are given to `serializeStates()`. Or you can pass a string as a file path to write to `serializeStates(...)`, to dump the states to a file: + +```python +stream.serializeStates("path/to/outout/file") # None is returned +``` + +To resume from a previous state of the stream, call `kr.StreamContext` with an additional argument to represent the bytes or file path of the dumped states: + +```python +stream = kr.StreamContext(executor, modu, num_stock, state_bytes_or_file_path) +``` + +**Important note** The serialized states are sensitive to the + * KunQuant version + * OS/CPU architecture + * C++ compiler used to compile KunQuant + * number of stocks + * different factor expressions + +If a stream is different from another in the aspects of above, they are incompatible. Do **not** feed a stream with stream states generated by another incompatible stream. + ## C-API for Streaming mode The logic is similar to the Python API above. For details, see `tests/capi/test_c.cpp` and `cpp/Kun/CApi.h`. \ No newline at end of file diff --git a/cpp/Kun/CApi.cpp b/cpp/Kun/CApi.cpp index 41b3e60..36b84b1 100644 --- a/cpp/Kun/CApi.cpp +++ b/cpp/Kun/CApi.cpp @@ -1,4 +1,5 @@ #include "CApi.h" +#include "IO.hpp" #include "Module.hpp" #include "RunGraph.hpp" #include @@ -81,7 +82,85 @@ KUN_API KunStreamContextHandle kunCreateStream(KunExecutorHandle exec, size_t num_stocks) { auto &pexec = *unwrapExecutor(exec); auto modu = reinterpret_cast(m); - return new kun::StreamContext{pexec, modu, num_stocks}; + try { + // Create a StreamContext with no initial states + return new kun::StreamContext{pexec, modu, num_stocks}; + } catch (...) { + // If there is an error, return nullptr + return nullptr; + } +} + + +KUN_API KunStatus kunCreateStreamEx(KunExecutorHandle exec, + KunModuleHandle m, + size_t num_stocks, + const KunStreamExtraArgs *extra_args, + KunStreamContextHandle *out_handle) { + auto &pexec = *unwrapExecutor(exec); + auto modu = reinterpret_cast(m); + FileInputStream file_stream; + MemoryInputStream memory_stream {nullptr, 0}; + InputStreamBase *states = nullptr; + if (extra_args) { + if (extra_args->version != KUN_API_VERSION) { + return KUN_INVALID_ARGUMENT; + } + if (extra_args->init_kind == KUN_INIT_FILE) { + if (!extra_args->init.path) { + return KUN_INVALID_ARGUMENT; + } + file_stream.file.open(extra_args->init.path, std::ios::binary); + states = &file_stream; + } else if (extra_args->init_kind == KUN_INIT_MEMORY) { + if (!extra_args->init.memory.buffer) { + return KUN_INVALID_ARGUMENT; + } + memory_stream.data = extra_args->init.memory.buffer; + memory_stream.size = extra_args->init.memory.size; + states = &memory_stream; + } else if (extra_args->init_kind == KUN_INIT_NONE) { + // do nothing + } else { + return KUN_INVALID_ARGUMENT; + } + } + try { + auto ctx = new kun::StreamContext{pexec, modu, num_stocks, states}; + *out_handle = reinterpret_cast(ctx); + } catch (const std::exception &e) { + *out_handle = nullptr; + // If there is an error, return KUN_INIT_ERROR + return KUN_INIT_ERROR; + } + return KUN_SUCCESS; +} + +KUN_API KunStatus kunStreamSerializeStates(KunStreamContextHandle context, + size_t dump_kind, + char *path_or_buffer, + size_t *size) { + auto ctx = reinterpret_cast(context); + if (dump_kind == KUN_INIT_FILE) { + kun::FileOutputStream stream(path_or_buffer); + if (!ctx->serializeStates(&stream)) { + return KUN_INVALID_ARGUMENT; + } + return KUN_SUCCESS; + } else if (dump_kind == KUN_INIT_MEMORY) { + if (!path_or_buffer || !size) { + return KUN_INVALID_ARGUMENT; + } + size_t in_size = *size; + kun::MemoryRefOutputStream stream {path_or_buffer, in_size}; + if (!ctx->serializeStates(&stream)) { + return KUN_INVALID_ARGUMENT; + } + *size = stream.pos; // Update size to the actual written size + return stream.pos > in_size ? KUN_INIT_ERROR : KUN_SUCCESS; + } else { + return KUN_INVALID_ARGUMENT; + } } KUN_API size_t kunQueryBufferHandle(KunStreamContextHandle context, diff --git a/cpp/Kun/CApi.h b/cpp/Kun/CApi.h index 8fda456..751c7f5 100644 --- a/cpp/Kun/CApi.h +++ b/cpp/Kun/CApi.h @@ -12,6 +12,30 @@ typedef void *KunStreamContextHandle; extern "C" { #endif +#define KUN_API_VERSION 1 + +#define KUN_INIT_NONE 0 +#define KUN_INIT_FILE 1 +#define KUN_INIT_MEMORY 2 + +typedef int KunStatus; + +#define KUN_SUCCESS 0 +#define KUN_INIT_ERROR 1 +#define KUN_INVALID_ARGUMENT 2 + +typedef struct { + size_t version; // version of the KunQuant C API, must be set to KUN_API_VERSION + size_t init_kind; // KUN_INIT_NONE KUN_INIT_FILE KUN_INIT_MEMORY + union { + const char* path; // path to stream state dump file + struct { + const char* buffer; // name of the stream state dump file + size_t size; // size of the stream state dump file + } memory; // memory buffer for stream state dump + } init; +} KunStreamExtraArgs; + /** * @brief Create an single thread executor * @@ -121,11 +145,50 @@ KUN_API void kunRunGraph(KunExecutorHandle exec, KunModuleHandle m, * * @param exec the executor * @param m the module - * @param num_stocks The number of stocks. Must be multiple of 8 + * @param num_stocks The number of stocks. */ KUN_API KunStreamContextHandle kunCreateStream(KunExecutorHandle exec, KunModuleHandle m, size_t num_stocks); +/** + * @brief Create the Stream computing context with extra arguments. + * + * @param exec the executor + * @param m the module + * @param num_stocks The number of stocks. + * @param extra_args the extra arguments for stream context, to specify the + * stream state dump file or memory buffer. The `version` field must be set to + * KUN_API_VERSION. extra_args can be null for default behavior + * @param out_handle the output handle to the stream context. + * @return KUN_SUCCESS on success, KUN_INIT_ERROR if the stream context cannot be + * initialized from the given states, KUN_INVALID_ARGUMENT if the extra_args is + * invalid. + */ +KUN_API KunStatus kunCreateStreamEx(KunExecutorHandle exec, + KunModuleHandle m, + size_t num_stocks, + const KunStreamExtraArgs* extra_args, + KunStreamContextHandle *out_handle); + +/** + * @brief Serialize the states from the stream. + * + * @param context the stream object + * @param dump_kind the kind of dump, KUN_INIT_FILE or KUN_INIT_MEMORY + * @param path_or_buffer if dump_kind is KUN_INIT_FILE, this is the path to the + * stream state dump file. If dump_kind is KUN_INIT_MEMORY, this is the memory + * buffer to store the stream state dump. + * @param size if dump_kind is KUN_INIT_MEMORY, this should point to the size of + * the memory buffer. If the function succeeds or the buffer is too small, + * the size will be set to the size of the stream state dump. + * @return KUN_SUCCESS on success, KUN_INIT_ERROR if the memory buffer is too + * small or there is a file error. If dump_kind is KUN_INIT_MEMORY `size` will + * be overwritten to the size of dumped data. KUN_INVALID_ARGUMENT if dump_kind + * is not KUN_INIT_FILE or KUN_INIT_MEMORY. + */ +KUN_API KunStatus kunStreamSerializeStates( + KunStreamContextHandle context, size_t dump_kind, char *path_or_buffer, size_t* size); + /** * @brief Query the handle of a named buffer (input or output) diff --git a/cpp/Kun/IO.cpp b/cpp/Kun/IO.cpp index e9d679b..f4666ef 100644 --- a/cpp/Kun/IO.cpp +++ b/cpp/Kun/IO.cpp @@ -17,6 +17,14 @@ bool MemoryOutputStream::write(const void *buf, size_t len) { return true; } +bool MemoryRefOutputStream::write(const void *b, size_t len) { + if (pos + len <= size) { + std::memcpy(buf + pos, b, len); + } + pos += len; + return true; +} + FileInputStream::FileInputStream(const std::string &filename) : file(filename, std::ios::binary) {} diff --git a/cpp/Kun/IO.hpp b/cpp/Kun/IO.hpp index 5187338..75de6e1 100644 --- a/cpp/Kun/IO.hpp +++ b/cpp/Kun/IO.hpp @@ -23,11 +23,23 @@ struct KUN_API MemoryOutputStream final : public OutputStreamBase { const char *getData() const { return buffer.data(); } size_t getSize() const { return buffer.size(); } }; + +struct MemoryRefOutputStream final : public OutputStreamBase { + size_t pos; + char* buf; + size_t size; + + MemoryRefOutputStream(char* buf, size_t len) : pos(0), buf(buf), size(len) {} + + bool write(const void *buf, size_t len) override; +}; + + struct KUN_API FileInputStream final : public InputStreamBase { std::ifstream file; FileInputStream(const std::string &filename); - + FileInputStream() = default; bool read(void *buf, size_t len) override; }; struct KUN_API FileOutputStream final : public OutputStreamBase { diff --git a/cpp/Kun/Ops/Quantile.hpp b/cpp/Kun/Ops/Quantile.hpp index 72edf42..e0bc3f7 100644 --- a/cpp/Kun/Ops/Quantile.hpp +++ b/cpp/Kun/Ops/Quantile.hpp @@ -48,6 +48,8 @@ struct SkipListState : SkipListStateImpl { }; } // namespace +template +struct Serializer; template struct Serializer> { diff --git a/cpp/Kun/RunGraph.hpp b/cpp/Kun/RunGraph.hpp index 52e8804..ddb6775 100644 --- a/cpp/Kun/RunGraph.hpp +++ b/cpp/Kun/RunGraph.hpp @@ -20,9 +20,7 @@ KUN_API void corrWith(std::shared_ptr exec, MemoryLayout layout, bool size_t length); struct AlignedPtr { void* ptr; -#if CHECKED_PTR size_t size; -#endif char* get() const noexcept { return (char*)ptr; } diff --git a/cpp/Kun/Runtime.cpp b/cpp/Kun/Runtime.cpp index e645d8e..7e23860 100644 --- a/cpp/Kun/Runtime.cpp +++ b/cpp/Kun/Runtime.cpp @@ -295,16 +295,12 @@ void runGraph(std::shared_ptr exec, const Module *m, AlignedPtr::AlignedPtr(void *ptr, size_t size) noexcept { this->ptr = ptr; -#if CHECKED_PTR this->size = size; -#endif } AlignedPtr::AlignedPtr(AlignedPtr &&other) noexcept { ptr = other.ptr; other.ptr = nullptr; -#if CHECKED_PTR size = other.size; -#endif } void AlignedPtr::release() noexcept { @@ -321,9 +317,7 @@ AlignedPtr &AlignedPtr::operator=(AlignedPtr &&other) noexcept { release(); ptr = other.ptr; other.ptr = nullptr; -#if CHECKED_PTR size = other.size; -#endif return *this; } @@ -350,8 +344,25 @@ char *StreamBuffer::make(size_t stock_count, size_t window_size, template struct StreamBuffer; template struct StreamBuffer; +template +static void pushBuffer(std::vector &rtlbuffers, + std::vector &buffers, size_t num_stocks, + size_t blocking_len, const BufferInfo &buf, + InputStreamBase *states) { + auto ptr = StreamBuffer::make(num_stocks, buf.window, blocking_len); + size_t buf_size = + StreamBuffer::getBufferSize(num_stocks, buf.window, blocking_len); + buffers.emplace_back(ptr, buf_size); + if (states) { + if (!states->read(ptr, buf_size)) { + throw std::runtime_error("Failed to read initial stream buffer"); + } + } + rtlbuffers.emplace_back((float *)ptr, 1); +} + StreamContext::StreamContext(std::shared_ptr exec, const Module *m, - size_t num_stocks, InputStreamBase* states) + size_t num_stocks, InputStreamBase *states) : m{m} { if (m->required_version != VERSION) { throw std::runtime_error("The required version in the module does not " @@ -366,35 +377,13 @@ StreamContext::StreamContext(std::shared_ptr exec, const Module *m, buffers.reserve(m->num_buffers); if (m->dtype == Datatype::Float) { for (size_t i = 0; i < m->num_buffers; i++) { - auto &buf = m->buffers[i]; - auto ptr = StreamBuffer::make(num_stocks, buf.window, - m->blocking_len); - size_t buf_size = StreamBuffer::getBufferSize( - num_stocks, buf.window, m->blocking_len); - buffers.emplace_back(ptr, buf_size); - if (m->init_state_buffers) { - if (!states->read(ptr, buf_size)) { - throw std::runtime_error( - "Failed to read initial stream buffer"); - } - } - rtlbuffers.emplace_back((float *)ptr, 1); + pushBuffer(rtlbuffers, buffers, num_stocks, m->blocking_len, + m->buffers[i], states); } } else if (m->dtype == Datatype::Double) { for (size_t i = 0; i < m->num_buffers; i++) { - auto &buf = m->buffers[i]; - auto ptr = StreamBuffer::make(num_stocks, buf.window, - m->blocking_len); - buffers.emplace_back( - ptr, StreamBuffer::getBufferSize(num_stocks, buf.window, - m->blocking_len)); - if (m->init_state_buffers) { - if (!states->read(ptr, buf_size)) { - throw std::runtime_error( - "Failed to read initial stream buffer"); - } - } - rtlbuffers.emplace_back((float *)ptr, 1); + pushBuffer(rtlbuffers, buffers, num_stocks, m->blocking_len, + m->buffers[i], states); } } else { throw std::runtime_error("Unknown type"); @@ -482,6 +471,11 @@ StreamContext::~StreamContext() = default; bool StreamContext::serializeStates(OutputStreamBase* stream) { + for(auto& buf: buffers) { + if(!stream->write(buf.get(), buf.size)) { + return false; + } + } for (auto& ptr: state_buffers) { if (!ptr->serialize(stream)) { return false; diff --git a/cpp/Kun/SkipList.cpp b/cpp/Kun/SkipList.cpp index cb352b9..eaf6cb3 100644 --- a/cpp/Kun/SkipList.cpp +++ b/cpp/Kun/SkipList.cpp @@ -318,11 +318,14 @@ bool SkipList::serialize(OutputStreamBase *stream, int expsize) const { } bool SkipList::deserialize(InputStreamBase *stream, int expsize) { - size_t size; + int size; if (!stream->read(&size, sizeof(size))) { return false; } - for (size_t i = 0; i < size; i++) { + if (size < 0 || size > 1024 * 1024 * 4) { + return false; // sanity check + } + for (int i = 0; i < size; i++) { double value; size_t index; if (!stream->read(&value, sizeof(value))) { @@ -342,7 +345,7 @@ bool serializeSkipList(SkipList* skiplist, int* lastInsertRank, size_t size, siz return false; } for (size_t i = 0; i < size; i++) { - if (!skiplist->serialize(stream, window)) { + if (!skiplist[i].serialize(stream, window)) { return false; } } diff --git a/cpp/Python/PyBinding.cpp b/cpp/Python/PyBinding.cpp index fb7d6e3..04fe641 100644 --- a/cpp/Python/PyBinding.cpp +++ b/cpp/Python/PyBinding.cpp @@ -407,7 +407,7 @@ PYBIND11_MODULE(KunRunner, m) { .def(py::init([](std::shared_ptr exec, const ModuleHandle *mod, size_t stocks, py::object init) { - if (py::isinstance(init)) { + if (py::isinstance(init)) { auto filename = py::cast(init); kun::FileInputStream stream(filename); return new StreamContextWrapper(std::move(exec), mod, stocks, diff --git a/tests/capi/test_c.cpp b/tests/capi/test_c.cpp index 90d6931..7739e22 100644 --- a/tests/capi/test_c.cpp +++ b/tests/capi/test_c.cpp @@ -80,6 +80,13 @@ static int testStream(const char *libpath) { KunStreamContextHandle ctx = kunCreateStream(exec, modu, num_stocks); CHECK(ctx); + size_t buf_size = 0; + auto status = kunStreamSerializeStates(ctx, KUN_INIT_MEMORY, nullptr, &buf_size); + CHECK(status == KUN_INIT_ERROR); + auto states_buffer = new char[buf_size]; + status = kunStreamSerializeStates(ctx, KUN_INIT_MEMORY, states_buffer, &buf_size); + CHECK(status == KUN_SUCCESS); + size_t handleClose = kunQueryBufferHandle(ctx, "close"); size_t handleOpen = kunQueryBufferHandle(ctx, "open"); size_t handleHigh = kunQueryBufferHandle(ctx, "high"); @@ -108,6 +115,18 @@ static int testStream(const char *libpath) { return 4; } } + + kunDestoryStream(ctx); + ctx = nullptr; + KunStreamExtraArgs extra_args; + extra_args.version = KUN_API_VERSION; + extra_args.init_kind = KUN_INIT_MEMORY; + extra_args.init.memory.buffer = states_buffer; + extra_args.init.memory.size = buf_size; + status = kunCreateStreamEx(exec, modu, num_stocks, &extra_args, &ctx); + CHECK(status == KUN_SUCCESS); + CHECK(ctx); + delete[] dataclose; delete[] dataopen; @@ -115,6 +134,7 @@ static int testStream(const char *libpath) { delete[] datalow; delete[] datavol; delete[] dataamount; + delete[] states_buffer; kunDestoryStream(ctx); kunUnloadLibrary(lib); kunDestoryExecutor(exec); diff --git a/tests/test_stream.py b/tests/test_stream.py index 4d3439e..ba115de 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -7,7 +7,7 @@ from KunQuant.Op import * from KunQuant.ops import * from KunQuant.runner import KunRunner as kr -from tempfile import NamedTemporaryFile +import tempfile def make_steam(): @@ -62,18 +62,69 @@ def test_stream_ser_deser(lib): ema = np.empty((100, 24)) slope = np.empty((100, 24)) + df = pd.DataFrame(a) + expected_quantile = df.rolling(10).quantile(0.49, interpolation='linear').to_numpy() + expected_ema = df.ewm(span=10, adjust=False, ignore_na=True).mean().to_numpy() + expected_slope = df.rolling(10).apply(lambda x: np.polyfit(np.arange(len(x)), x, 1)[0]).to_numpy() + for i in range(50): stream.pushData(handle_a, a[i]) stream.run() out[i] = stream.getCurrentBuffer(handle_quantile) ema[i] = stream.getCurrentBuffer(handle_ema) slope[i] = stream.getCurrentBuffer(handle_slope) - stream.serializeStates("kun_test_stream_state.bin") + temppath = os.path.join(tempfile.gettempdir(), "kun_test_stream_state.bin") + print("serializing states to", temppath) + stream.serializeStates(temppath) states = stream.serializeStates() - with open("kun_test_stream_state.bin", "rb") as f: + with open(temppath, "rb") as f: assert(f.read() == states) - os.remove("kun_test_stream_state.bin") print("length of serialized states:", len(states)) + + # reload stream with in memory buffer + del stream + stream = kr.StreamContext(executor, lib.getModule("stream_test"), 24, states) + for i in range(50, 100): + stream.pushData(handle_a, a[i]) + stream.run() + out[i] = stream.getCurrentBuffer(handle_quantile) + ema[i] = stream.getCurrentBuffer(handle_ema) + slope[i] = stream.getCurrentBuffer(handle_slope) + np.testing.assert_allclose(ema, expected_ema, atol=1e-6, rtol=1e-4, equal_nan=True) + np.testing.assert_allclose(slope[10:], expected_slope[10:], atol=1e-6, rtol=1e-4, equal_nan=True) + np.testing.assert_allclose(out, expected_quantile, atol=1e-6, rtol=1e-4, equal_nan=True) + + # reload stream with file + del stream + stream = kr.StreamContext(executor, lib.getModule("stream_test"), 24, temppath) + for i in range(50, 100): + stream.pushData(handle_a, a[i]) + stream.run() + out[i] = stream.getCurrentBuffer(handle_quantile) + ema[i] = stream.getCurrentBuffer(handle_ema) + slope[i] = stream.getCurrentBuffer(handle_slope) + + np.testing.assert_allclose(out, expected_quantile, atol=1e-6, rtol=1e-4, equal_nan=True) + np.testing.assert_allclose(ema, expected_ema, atol=1e-6, rtol=1e-4, equal_nan=True) + np.testing.assert_allclose(slope[10:], expected_slope[10:], atol=1e-6, rtol=1e-4, equal_nan=True) + + # failure: bad buffer + bad_states = states[:-10] + try: + stream = kr.StreamContext(executor, lib.getModule("stream_test"), 24, bad_states) + assert False, "Expected failure due to bad buffer" + except RuntimeError as e: + assert str(e) == "Failed to deserialize state buffer", f"Unexpected error message: {str(e)}" + + with open(temppath, "wb") as f: + f.truncate(20) + try: + stream = kr.StreamContext(executor, lib.getModule("stream_test"), 24, temppath) + assert False, "Expected failure due to bad buffer" + except RuntimeError as e: + assert str(e) == "Failed to read initial stream buffer", f"Unexpected error message: {str(e)}" + os.remove(temppath) + lib = make_steam() test_stream(lib) test_stream_ser_deser(lib) From 9a37992ffad27abe9ef198ed22ce7581aee40011 Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 19 Dec 2025 08:01:03 +0800 Subject: [PATCH 07/11] fix --- cpp/Kun/CApi.cpp | 4 ++-- cpp/Kun/CApi.h | 40 +++++++++++++++++++++------------------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/cpp/Kun/CApi.cpp b/cpp/Kun/CApi.cpp index 36b84b1..f1b0413 100644 --- a/cpp/Kun/CApi.cpp +++ b/cpp/Kun/CApi.cpp @@ -144,11 +144,11 @@ KUN_API KunStatus kunStreamSerializeStates(KunStreamContextHandle context, if (dump_kind == KUN_INIT_FILE) { kun::FileOutputStream stream(path_or_buffer); if (!ctx->serializeStates(&stream)) { - return KUN_INVALID_ARGUMENT; + return KUN_INIT_ERROR; } return KUN_SUCCESS; } else if (dump_kind == KUN_INIT_MEMORY) { - if (!path_or_buffer || !size) { + if (!size || (*size !=0 && !path_or_buffer)) { return KUN_INVALID_ARGUMENT; } size_t in_size = *size; diff --git a/cpp/Kun/CApi.h b/cpp/Kun/CApi.h index 751c7f5..337e3f1 100644 --- a/cpp/Kun/CApi.h +++ b/cpp/Kun/CApi.h @@ -25,14 +25,15 @@ typedef int KunStatus; #define KUN_INVALID_ARGUMENT 2 typedef struct { - size_t version; // version of the KunQuant C API, must be set to KUN_API_VERSION + size_t version; // version of the KunQuant C API, must be set to + // KUN_API_VERSION size_t init_kind; // KUN_INIT_NONE KUN_INIT_FILE KUN_INIT_MEMORY union { - const char* path; // path to stream state dump file + const char *path; // path to stream state dump file struct { - const char* buffer; // name of the stream state dump file - size_t size; // size of the stream state dump file - } memory; // memory buffer for stream state dump + const char *buffer; // name of the stream state dump file + size_t size; // size of the stream state dump file + } memory; // memory buffer for stream state dump } init; } KunStreamExtraArgs; @@ -160,15 +161,14 @@ KUN_API KunStreamContextHandle kunCreateStream(KunExecutorHandle exec, * stream state dump file or memory buffer. The `version` field must be set to * KUN_API_VERSION. extra_args can be null for default behavior * @param out_handle the output handle to the stream context. - * @return KUN_SUCCESS on success, KUN_INIT_ERROR if the stream context cannot be - * initialized from the given states, KUN_INVALID_ARGUMENT if the extra_args is - * invalid. + * @return KUN_SUCCESS on success, KUN_INIT_ERROR if the stream context cannot + * be initialized from the given states, KUN_INVALID_ARGUMENT if the extra_args + * is invalid. */ -KUN_API KunStatus kunCreateStreamEx(KunExecutorHandle exec, - KunModuleHandle m, - size_t num_stocks, - const KunStreamExtraArgs* extra_args, - KunStreamContextHandle *out_handle); +KUN_API KunStatus kunCreateStreamEx(KunExecutorHandle exec, KunModuleHandle m, + size_t num_stocks, + const KunStreamExtraArgs *extra_args, + KunStreamContextHandle *out_handle); /** * @brief Serialize the states from the stream. @@ -177,18 +177,20 @@ KUN_API KunStatus kunCreateStreamEx(KunExecutorHandle exec, * @param dump_kind the kind of dump, KUN_INIT_FILE or KUN_INIT_MEMORY * @param path_or_buffer if dump_kind is KUN_INIT_FILE, this is the path to the * stream state dump file. If dump_kind is KUN_INIT_MEMORY, this is the memory - * buffer to store the stream state dump. + * buffer to store the stream state dump. nullable if `*size` is 0 and dump_kind + * is KUN_INIT_MEMORY. * @param size if dump_kind is KUN_INIT_MEMORY, this should point to the size of - * the memory buffer. If the function succeeds or the buffer is too small, - * the size will be set to the size of the stream state dump. + * the memory buffer. If the function succeeds or the buffer is too small, the + * size will be set to the size of the stream state dump in bytes. If dump_kind + * is KUN_INIT_FILE, this parameter is unused. * @return KUN_SUCCESS on success, KUN_INIT_ERROR if the memory buffer is too * small or there is a file error. If dump_kind is KUN_INIT_MEMORY `size` will * be overwritten to the size of dumped data. KUN_INVALID_ARGUMENT if dump_kind * is not KUN_INIT_FILE or KUN_INIT_MEMORY. */ -KUN_API KunStatus kunStreamSerializeStates( - KunStreamContextHandle context, size_t dump_kind, char *path_or_buffer, size_t* size); - +KUN_API KunStatus kunStreamSerializeStates(KunStreamContextHandle context, + size_t dump_kind, + char *path_or_buffer, size_t *size); /** * @brief Query the handle of a named buffer (input or output) From 53d3200c37a2adfc7ad1f7857d0ca73b9ffb4a9f Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 19 Dec 2025 08:14:43 +0800 Subject: [PATCH 08/11] add test --- tests/capi/test_c.cpp | 51 +++++++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/tests/capi/test_c.cpp b/tests/capi/test_c.cpp index 7739e22..22131c6 100644 --- a/tests/capi/test_c.cpp +++ b/tests/capi/test_c.cpp @@ -80,13 +80,19 @@ static int testStream(const char *libpath) { KunStreamContextHandle ctx = kunCreateStream(exec, modu, num_stocks); CHECK(ctx); + // now dump the stream states + // use null buffer and zero size to get the real buffer size size_t buf_size = 0; auto status = kunStreamSerializeStates(ctx, KUN_INIT_MEMORY, nullptr, &buf_size); CHECK(status == KUN_INIT_ERROR); + // second try to allocate buffer and get the real data + size_t buf_size2 = buf_size; auto states_buffer = new char[buf_size]; status = kunStreamSerializeStates(ctx, KUN_INIT_MEMORY, states_buffer, &buf_size); CHECK(status == KUN_SUCCESS); + CHECK(buf_size == buf_size2); + // run the stream size_t handleClose = kunQueryBufferHandle(ctx, "close"); size_t handleOpen = kunQueryBufferHandle(ctx, "open"); size_t handleHigh = kunQueryBufferHandle(ctx, "high"); @@ -96,27 +102,32 @@ static int testStream(const char *libpath) { size_t handleAlpha101 = kunQueryBufferHandle(ctx, "alpha101"); // don't need to query the handles everytime when calling kunStreamPushData - kunStreamPushData(ctx, handleClose, dataclose); - kunStreamPushData(ctx, handleOpen, dataopen); - kunStreamPushData(ctx, handleHigh, datahigh); - kunStreamPushData(ctx, handleLow, datalow); - kunStreamPushData(ctx, handleVol, datavol); - kunStreamPushData(ctx, handleAmount, dataamount); - - kunStreamRun(ctx); - memcpy(alpha101, kunStreamGetCurrentBuffer(ctx, handleAlpha101), - sizeof(float) * num_stocks); - - for (size_t i = 0; i < num_stocks; i++) { - float expected = - (dataclose[i] - dataopen[i]) / (datahigh[i] - datalow[i] + 0.001); - if (std::abs(alpha101[i] - expected) > 1e-5) { - printf("Output error at %zu => %f, %f\n", i, alpha101[i], expected); - return 4; + auto run_and_check = [&]() { + kunStreamPushData(ctx, handleClose, dataclose); + kunStreamPushData(ctx, handleOpen, dataopen); + kunStreamPushData(ctx, handleHigh, datahigh); + kunStreamPushData(ctx, handleLow, datalow); + kunStreamPushData(ctx, handleVol, datavol); + kunStreamPushData(ctx, handleAmount, dataamount); + + kunStreamRun(ctx); + memcpy(alpha101, kunStreamGetCurrentBuffer(ctx, handleAlpha101), + sizeof(float) * num_stocks); + + for (size_t i = 0; i < num_stocks; i++) { + float expected = + (dataclose[i] - dataopen[i]) / (datahigh[i] - datalow[i] + 0.001); + if (std::abs(alpha101[i] - expected) > 1e-5) { + printf("Output error at %zu => %f, %f\n", i, alpha101[i], expected); + exit(4); + } } - } - + }; + run_and_check(); kunDestoryStream(ctx); + + // check restore stream from states + // create a new stream context from the states ctx = nullptr; KunStreamExtraArgs extra_args; extra_args.version = KUN_API_VERSION; @@ -126,6 +137,8 @@ static int testStream(const char *libpath) { status = kunCreateStreamEx(exec, modu, num_stocks, &extra_args, &ctx); CHECK(status == KUN_SUCCESS); CHECK(ctx); + // run again to check if the states are restored correctly + run_and_check(); delete[] dataclose; From 061c0759947f10ad3c7770963e043b4d13cce169 Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 19 Dec 2025 08:42:24 +0800 Subject: [PATCH 09/11] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20Context.hpp?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cpp/Kun/Context.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/Kun/Context.hpp b/cpp/Kun/Context.hpp index 76271fc..c98ad8d 100644 --- a/cpp/Kun/Context.hpp +++ b/cpp/Kun/Context.hpp @@ -10,7 +10,7 @@ namespace kun { -static const uint64_t VERSION = 0x64100004; +static const uint64_t VERSION = 0x64100005; struct KUN_API RuntimeStage { const Stage *stage; Context *ctx; From d9cfb72905953f0b336e1f667db27b8314b26d4b Mon Sep 17 00:00:00 2001 From: Menooker Date: Sat, 20 Dec 2025 20:40:44 +0800 Subject: [PATCH 10/11] change to enum --- cpp/Kun/CApi.cpp | 2 +- cpp/Kun/CApi.h | 32 +++++++++++++++++--------------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/cpp/Kun/CApi.cpp b/cpp/Kun/CApi.cpp index f1b0413..653d04c 100644 --- a/cpp/Kun/CApi.cpp +++ b/cpp/Kun/CApi.cpp @@ -137,7 +137,7 @@ KUN_API KunStatus kunCreateStreamEx(KunExecutorHandle exec, } KUN_API KunStatus kunStreamSerializeStates(KunStreamContextHandle context, - size_t dump_kind, + KunStateBufferKind dump_kind, char *path_or_buffer, size_t *size) { auto ctx = reinterpret_cast(context); diff --git a/cpp/Kun/CApi.h b/cpp/Kun/CApi.h index 337e3f1..481b4fe 100644 --- a/cpp/Kun/CApi.h +++ b/cpp/Kun/CApi.h @@ -14,25 +14,27 @@ extern "C" { #define KUN_API_VERSION 1 -#define KUN_INIT_NONE 0 -#define KUN_INIT_FILE 1 -#define KUN_INIT_MEMORY 2 - -typedef int KunStatus; - -#define KUN_SUCCESS 0 -#define KUN_INIT_ERROR 1 -#define KUN_INVALID_ARGUMENT 2 +typedef enum { + KUN_INIT_NONE = 0, + KUN_INIT_FILE, + KUN_INIT_MEMORY, +} KunStateBufferKind; + +typedef enum { + KUN_SUCCESS = 0, + KUN_INIT_ERROR, + KUN_INVALID_ARGUMENT, +} KunStatus; typedef struct { - size_t version; // version of the KunQuant C API, must be set to - // KUN_API_VERSION - size_t init_kind; // KUN_INIT_NONE KUN_INIT_FILE KUN_INIT_MEMORY + size_t version; // version of the KunQuant C API, must be set to + // KUN_API_VERSION + KunStateBufferKind init_kind; // KUN_INIT_NONE KUN_INIT_FILE KUN_INIT_MEMORY union { const char *path; // path to stream state dump file struct { - const char *buffer; // name of the stream state dump file - size_t size; // size of the stream state dump file + const char *buffer; // buffer for stream state dump + size_t size; // size of the stream state dump file in bytes } memory; // memory buffer for stream state dump } init; } KunStreamExtraArgs; @@ -189,7 +191,7 @@ KUN_API KunStatus kunCreateStreamEx(KunExecutorHandle exec, KunModuleHandle m, * is not KUN_INIT_FILE or KUN_INIT_MEMORY. */ KUN_API KunStatus kunStreamSerializeStates(KunStreamContextHandle context, - size_t dump_kind, + KunStateBufferKind dump_kind, char *path_or_buffer, size_t *size); /** From e0cd6d029fedc6cb8f1ce63fed0e5a635e7bef46 Mon Sep 17 00:00:00 2001 From: Menooker Date: Sat, 20 Dec 2025 20:45:28 +0800 Subject: [PATCH 11/11] fix test --- tests/test_runtime.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 022ddb4..24cf58e 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -148,12 +148,12 @@ def test_avg_stddev(lib): #################################### def check_TS(): - return "avg_and_stddev_TS", build_avg_and_stddev(), KunCompilerConfig(input_layout="TS", output_layout="TS", options={"no_fast_stat": 'no_warn'}) + return "avg_and_stddev_TS", build_avg_and_stddev(), KunCompilerConfig(dtype="double", input_layout="TS", output_layout="TS", options={"no_fast_stat": 'no_warn'}) def test_avg_stddev_TS(lib): modu = lib.getModule("avg_and_stddev_TS") assert(modu) - inp = np.random.rand(24, 20).astype("float32") + inp = np.random.rand(24, 20) df = pd.DataFrame(inp.transpose()) expected_mean = df.rolling(10).mean().to_numpy().transpose() expected_stddev = df.rolling(10).std().to_numpy().transpose()