diff --git a/Stream.md b/Stream.md index 149fec7..904080d 100644 --- a/Stream.md +++ b/Stream.md @@ -74,6 +74,39 @@ Basically, you need to call `pushData` for each input. Then call `run()` to let That's why in the above code, we immediately copy the `ndarray` returned by `getCurrentBuffer` with `[:]`. +## Serialize stream states and resuming from stream states + +You can get the stream states from the context object via the `serializeStates` method after `run()` is called. + +```python +stream = kr.StreamContext(executor, modu, num_stock) +stream.pushData(...) +... +stream.run() +states: bytes = stream.serializeStates() +``` + +If no arguments are given, `serializeStates()` returns a `bytes` object holding a copy of the states of the stream. Alternatively, you can pass a file path string to `serializeStates(...)` to dump the states to that file: + +```python +stream.serializeStates("path/to/output/file") # None is returned +``` + +To resume from a previous state of the stream, call `kr.StreamContext` with an additional argument holding either the bytes or the file path of the dumped states: + +```python +stream = kr.StreamContext(executor, modu, num_stock, state_bytes_or_file_path) +``` + +**Important note** The serialized states are sensitive to the + * KunQuant version + * OS/CPU architecture + * C++ compiler used to compile KunQuant + * number of stocks + * factor expressions + +If two streams differ in any of the aspects above, they are incompatible. Do **not** feed a stream with stream states generated by another incompatible stream. + ## C-API for Streaming mode The logic is similar to the Python API above. For details, see `tests/capi/test_c.cpp` and `cpp/Kun/CApi.h`. 
\ No newline at end of file diff --git a/cpp/Kun/CApi.cpp b/cpp/Kun/CApi.cpp index 41b3e60..653d04c 100644 --- a/cpp/Kun/CApi.cpp +++ b/cpp/Kun/CApi.cpp @@ -1,4 +1,5 @@ #include "CApi.h" +#include "IO.hpp" #include "Module.hpp" #include "RunGraph.hpp" #include @@ -81,7 +82,85 @@ KUN_API KunStreamContextHandle kunCreateStream(KunExecutorHandle exec, size_t num_stocks) { auto &pexec = *unwrapExecutor(exec); auto modu = reinterpret_cast(m); - return new kun::StreamContext{pexec, modu, num_stocks}; + try { + // Create a StreamContext with no initial states + return new kun::StreamContext{pexec, modu, num_stocks}; + } catch (...) { + // If there is an error, return nullptr + return nullptr; + } +} + + +KUN_API KunStatus kunCreateStreamEx(KunExecutorHandle exec, + KunModuleHandle m, + size_t num_stocks, + const KunStreamExtraArgs *extra_args, + KunStreamContextHandle *out_handle) { + auto &pexec = *unwrapExecutor(exec); + auto modu = reinterpret_cast(m); + FileInputStream file_stream; + MemoryInputStream memory_stream {nullptr, 0}; + InputStreamBase *states = nullptr; + if (extra_args) { + if (extra_args->version != KUN_API_VERSION) { + return KUN_INVALID_ARGUMENT; + } + if (extra_args->init_kind == KUN_INIT_FILE) { + if (!extra_args->init.path) { + return KUN_INVALID_ARGUMENT; + } + file_stream.file.open(extra_args->init.path, std::ios::binary); + states = &file_stream; + } else if (extra_args->init_kind == KUN_INIT_MEMORY) { + if (!extra_args->init.memory.buffer) { + return KUN_INVALID_ARGUMENT; + } + memory_stream.data = extra_args->init.memory.buffer; + memory_stream.size = extra_args->init.memory.size; + states = &memory_stream; + } else if (extra_args->init_kind == KUN_INIT_NONE) { + // do nothing + } else { + return KUN_INVALID_ARGUMENT; + } + } + try { + auto ctx = new kun::StreamContext{pexec, modu, num_stocks, states}; + *out_handle = reinterpret_cast(ctx); + } catch (const std::exception &e) { + *out_handle = nullptr; + // If there is an error, return 
KUN_INIT_ERROR + return KUN_INIT_ERROR; + } + return KUN_SUCCESS; +} + +KUN_API KunStatus kunStreamSerializeStates(KunStreamContextHandle context, + KunStateBufferKind dump_kind, + char *path_or_buffer, + size_t *size) { + auto ctx = reinterpret_cast(context); + if (dump_kind == KUN_INIT_FILE) { + kun::FileOutputStream stream(path_or_buffer); + if (!ctx->serializeStates(&stream)) { + return KUN_INIT_ERROR; + } + return KUN_SUCCESS; + } else if (dump_kind == KUN_INIT_MEMORY) { + if (!size || (*size !=0 && !path_or_buffer)) { + return KUN_INVALID_ARGUMENT; + } + size_t in_size = *size; + kun::MemoryRefOutputStream stream {path_or_buffer, in_size}; + if (!ctx->serializeStates(&stream)) { + return KUN_INVALID_ARGUMENT; + } + *size = stream.pos; // Update size to the actual written size + return stream.pos > in_size ? KUN_INIT_ERROR : KUN_SUCCESS; + } else { + return KUN_INVALID_ARGUMENT; + } } KUN_API size_t kunQueryBufferHandle(KunStreamContextHandle context, diff --git a/cpp/Kun/CApi.h b/cpp/Kun/CApi.h index 8fda456..481b4fe 100644 --- a/cpp/Kun/CApi.h +++ b/cpp/Kun/CApi.h @@ -12,6 +12,33 @@ typedef void *KunStreamContextHandle; extern "C" { #endif +#define KUN_API_VERSION 1 + +typedef enum { + KUN_INIT_NONE = 0, + KUN_INIT_FILE, + KUN_INIT_MEMORY, +} KunStateBufferKind; + +typedef enum { + KUN_SUCCESS = 0, + KUN_INIT_ERROR, + KUN_INVALID_ARGUMENT, +} KunStatus; + +typedef struct { + size_t version; // version of the KunQuant C API, must be set to + // KUN_API_VERSION + KunStateBufferKind init_kind; // KUN_INIT_NONE KUN_INIT_FILE KUN_INIT_MEMORY + union { + const char *path; // path to stream state dump file + struct { + const char *buffer; // buffer for stream state dump + size_t size; // size of the stream state dump file in bytes + } memory; // memory buffer for stream state dump + } init; +} KunStreamExtraArgs; + /** * @brief Create an single thread executor * @@ -121,11 +148,51 @@ KUN_API void kunRunGraph(KunExecutorHandle exec, KunModuleHandle m, * * @param 
exec the executor * @param m the module - * @param num_stocks The number of stocks. Must be multiple of 8 + * @param num_stocks The number of stocks. */ KUN_API KunStreamContextHandle kunCreateStream(KunExecutorHandle exec, KunModuleHandle m, size_t num_stocks); +/** + * @brief Create the Stream computing context with extra arguments. + * + * @param exec the executor + * @param m the module + * @param num_stocks The number of stocks. + * @param extra_args the extra arguments for stream context, to specify the + * stream state dump file or memory buffer. The `version` field must be set to + * KUN_API_VERSION. extra_args can be null for default behavior + * @param out_handle the output handle to the stream context. + * @return KUN_SUCCESS on success, KUN_INIT_ERROR if the stream context cannot + * be initialized from the given states, KUN_INVALID_ARGUMENT if the extra_args + * is invalid. + */ +KUN_API KunStatus kunCreateStreamEx(KunExecutorHandle exec, KunModuleHandle m, + size_t num_stocks, + const KunStreamExtraArgs *extra_args, + KunStreamContextHandle *out_handle); + +/** + * @brief Serialize the states from the stream. + * + * @param context the stream object + * @param dump_kind the kind of dump, KUN_INIT_FILE or KUN_INIT_MEMORY + * @param path_or_buffer if dump_kind is KUN_INIT_FILE, this is the path to the + * stream state dump file. If dump_kind is KUN_INIT_MEMORY, this is the memory + * buffer to store the stream state dump. nullable if `*size` is 0 and dump_kind + * is KUN_INIT_MEMORY. + * @param size if dump_kind is KUN_INIT_MEMORY, this should point to the size of + * the memory buffer. If the function succeeds or the buffer is too small, the + * size will be set to the size of the stream state dump in bytes. If dump_kind + * is KUN_INIT_FILE, this parameter is unused. + * @return KUN_SUCCESS on success, KUN_INIT_ERROR if the memory buffer is too + * small or there is a file error. 
If dump_kind is KUN_INIT_MEMORY `size` will + * be overwritten to the size of dumped data. KUN_INVALID_ARGUMENT if dump_kind + * is not KUN_INIT_FILE or KUN_INIT_MEMORY. + */ +KUN_API KunStatus kunStreamSerializeStates(KunStreamContextHandle context, + KunStateBufferKind dump_kind, + char *path_or_buffer, size_t *size); /** * @brief Query the handle of a named buffer (input or output) diff --git a/cpp/Kun/Context.hpp b/cpp/Kun/Context.hpp index 76271fc..c98ad8d 100644 --- a/cpp/Kun/Context.hpp +++ b/cpp/Kun/Context.hpp @@ -10,7 +10,7 @@ namespace kun { -static const uint64_t VERSION = 0x64100004; +static const uint64_t VERSION = 0x64100005; struct KUN_API RuntimeStage { const Stage *stage; Context *ctx; diff --git a/cpp/Kun/IO.cpp b/cpp/Kun/IO.cpp new file mode 100644 index 0000000..f4666ef --- /dev/null +++ b/cpp/Kun/IO.cpp @@ -0,0 +1,47 @@ +#include "IO.hpp" +#include + +namespace kun { +bool MemoryInputStream::read(void *buf, size_t len) { + if (pos + len > size) { + return false; + } + std::memcpy(buf, data + pos, len); + pos += len; + return true; +} + +bool MemoryOutputStream::write(const void *buf, size_t len) { + const char *cbuf = reinterpret_cast(buf); + buffer.insert(buffer.end(), cbuf, cbuf + len); + return true; +} + +bool MemoryRefOutputStream::write(const void *b, size_t len) { + if (pos + len <= size) { + std::memcpy(buf + pos, b, len); + } + pos += len; + return true; +} + +FileInputStream::FileInputStream(const std::string &filename) + : file(filename, std::ios::binary) {} + +bool FileInputStream::read(void *buf, size_t len) { + if (!file.read(reinterpret_cast(buf), len)) { + return false; + } + return true; +} + +FileOutputStream::FileOutputStream(const std::string &filename) + : file(filename, std::ios::binary) {} + +bool FileOutputStream::write(const void *buf, size_t len) { + if (!file.write(reinterpret_cast(buf), len)) { + return false; + } + return true; +} +} // namespace kun \ No newline at end of file diff --git a/cpp/Kun/IO.hpp 
b/cpp/Kun/IO.hpp new file mode 100644 index 0000000..75de6e1 --- /dev/null +++ b/cpp/Kun/IO.hpp @@ -0,0 +1,50 @@ +#pragma once + +#include "StateBuffer.hpp" +#include +#include + +namespace kun { +struct KUN_API MemoryInputStream final : public InputStreamBase { + const char *data; + size_t size; + size_t pos; + + MemoryInputStream(const char *data, size_t size) + : data(data), size(size), pos(0) {} + + bool read(void *buf, size_t len) override; +}; +struct KUN_API MemoryOutputStream final : public OutputStreamBase { + std::vector buffer; + + bool write(const void *buf, size_t len) override; + + const char *getData() const { return buffer.data(); } + size_t getSize() const { return buffer.size(); } +}; + +struct MemoryRefOutputStream final : public OutputStreamBase { + size_t pos; + char* buf; + size_t size; + + MemoryRefOutputStream(char* buf, size_t len) : pos(0), buf(buf), size(len) {} + + bool write(const void *buf, size_t len) override; +}; + + +struct KUN_API FileInputStream final : public InputStreamBase { + std::ifstream file; + + FileInputStream(const std::string &filename); + FileInputStream() = default; + bool read(void *buf, size_t len) override; +}; +struct KUN_API FileOutputStream final : public OutputStreamBase { + std::ofstream file; + FileOutputStream(const std::string &filename); + bool write(const void *buf, size_t len) override; +}; +} // namespace kun diff --git a/cpp/Kun/Ops.hpp b/cpp/Kun/Ops.hpp index 4923384..457a2da 100644 --- a/cpp/Kun/Ops.hpp +++ b/cpp/Kun/Ops.hpp @@ -3,6 +3,7 @@ #include "Base.hpp" #include "Math.hpp" #include "StreamBuffer.hpp" +#include "StateBuffer.hpp" #include #include #include @@ -11,6 +12,18 @@ namespace kun { namespace ops { + +template +struct Serializer { + static bool serialize(StateBuffer *obj, OutputStreamBase *stream) { + return stream->write(obj->buf, sizeof(T) * obj->num_objs); + } + + static bool deserialize(StateBuffer *obj, InputStreamBase *stream) { + return stream->read(obj->buf, sizeof(T) * 
obj->num_objs); + } +}; + template struct DataSource { constexpr static bool containsWindow = vcontainsWindow; @@ -706,5 +719,23 @@ inline DecayVec_t SetInfOrNanToValue(T aa, T2 v) { return sc_select(mask, v, a); } + +template +StateBufferPtr makeStateBuffer(size_t num_stocks, size_t simd_len) { + return StateBufferPtr(StateBuffer::make( + divideAndCeil(num_stocks, simd_len), sizeof(T), + [](StateBuffer *obj) { + for (size_t i = 0; i < obj->num_objs; i++) { + new (&obj->get(i)) T(); + } + }, + [](StateBuffer *obj) { + for (size_t i = 0; i < obj->num_objs; i++) { + obj->get(i).~T(); + } + }, Serializer::serialize, + Serializer::deserialize)); +} + } // namespace ops } // namespace kun \ No newline at end of file diff --git a/cpp/Kun/Ops/Quantile.hpp b/cpp/Kun/Ops/Quantile.hpp index 2511126..e0bc3f7 100644 --- a/cpp/Kun/Ops/Quantile.hpp +++ b/cpp/Kun/Ops/Quantile.hpp @@ -48,6 +48,35 @@ struct SkipListState : SkipListStateImpl { }; } // namespace +template +struct Serializer; + +template +struct Serializer> { + static bool serialize(StateBuffer *obj, OutputStreamBase *stream) { + for(size_t i = 0; i < obj->num_objs; i++) { + auto &state = obj->get>(i); + if (!serializeSkipList(state.skipList, state.lastInsertRank, simdLen, + expectedwindow, stream)) { + return false; + } + } + return true; + } + + static bool deserialize(StateBuffer *obj, InputStreamBase *stream) { + for(size_t i = 0; i < obj->num_objs; i++) { + auto &state = obj->get>(i); + new (&state) SkipListState{}; + if (!deserializeSkipList(state.skipList, state.lastInsertRank, simdLen, + expectedwindow, stream)) { + return false; + } + } + return true; + } +}; + // https://github.com/pandas-dev/pandas/blob/main/pandas/_libs/window/aggregations.pyx template diff --git a/cpp/Kun/RunGraph.hpp b/cpp/Kun/RunGraph.hpp index 3e1ab14..ddb6775 100644 --- a/cpp/Kun/RunGraph.hpp +++ b/cpp/Kun/RunGraph.hpp @@ -20,9 +20,7 @@ KUN_API void corrWith(std::shared_ptr exec, MemoryLayout layout, bool size_t length); struct 
AlignedPtr { void* ptr; -#if CHECKED_PTR size_t size; -#endif char* get() const noexcept { return (char*)ptr; } @@ -39,7 +37,7 @@ struct KUN_API StreamContext { Context ctx; const Module *m; StreamContext(std::shared_ptr exec, const Module *m, - size_t num_stocks); + size_t num_stocks, InputStreamBase* states = nullptr); // query the buffer handle of a named buffer size_t queryBufferHandle(const char *name) const; // get the current readable position of the named buffer. The returned @@ -54,6 +52,7 @@ struct KUN_API StreamContext { StreamContext(const StreamContext&) = delete; StreamContext& operator=(const StreamContext&) = delete; ~StreamContext(); + bool serializeStates(OutputStreamBase* stream); }; } // namespace kun \ No newline at end of file diff --git a/cpp/Kun/Runtime.cpp b/cpp/Kun/Runtime.cpp index 38c2fa1..7e23860 100644 --- a/cpp/Kun/Runtime.cpp +++ b/cpp/Kun/Runtime.cpp @@ -295,16 +295,12 @@ void runGraph(std::shared_ptr exec, const Module *m, AlignedPtr::AlignedPtr(void *ptr, size_t size) noexcept { this->ptr = ptr; -#if CHECKED_PTR this->size = size; -#endif } AlignedPtr::AlignedPtr(AlignedPtr &&other) noexcept { ptr = other.ptr; other.ptr = nullptr; -#if CHECKED_PTR size = other.size; -#endif } void AlignedPtr::release() noexcept { @@ -321,9 +317,7 @@ AlignedPtr &AlignedPtr::operator=(AlignedPtr &&other) noexcept { release(); ptr = other.ptr; other.ptr = nullptr; -#if CHECKED_PTR size = other.size; -#endif return *this; } @@ -350,8 +344,25 @@ char *StreamBuffer::make(size_t stock_count, size_t window_size, template struct StreamBuffer; template struct StreamBuffer; +template +static void pushBuffer(std::vector &rtlbuffers, + std::vector &buffers, size_t num_stocks, + size_t blocking_len, const BufferInfo &buf, + InputStreamBase *states) { + auto ptr = StreamBuffer::make(num_stocks, buf.window, blocking_len); + size_t buf_size = + StreamBuffer::getBufferSize(num_stocks, buf.window, blocking_len); + buffers.emplace_back(ptr, buf_size); + if (states) { 
+ if (!states->read(ptr, buf_size)) { + throw std::runtime_error("Failed to read initial stream buffer"); + } + } + rtlbuffers.emplace_back((float *)ptr, 1); +} + StreamContext::StreamContext(std::shared_ptr exec, const Module *m, - size_t num_stocks) + size_t num_stocks, InputStreamBase *states) : m{m} { if (m->required_version != VERSION) { throw std::runtime_error("The required version in the module does not " @@ -361,38 +372,37 @@ StreamContext::StreamContext(std::shared_ptr exec, const Module *m, throw std::runtime_error( "Cannot run batch mode module via StreamContext"); } - if (m->init_state_buffers) { - state_buffers = m->init_state_buffers(num_stocks); - for (auto &buf : state_buffers) { - buf->initialize(); - } - } std::vector rtlbuffers; rtlbuffers.reserve(m->num_buffers); buffers.reserve(m->num_buffers); if (m->dtype == Datatype::Float) { for (size_t i = 0; i < m->num_buffers; i++) { - auto &buf = m->buffers[i]; - auto ptr = StreamBuffer::make(num_stocks, buf.window, - m->blocking_len); - buffers.emplace_back( - ptr, StreamBuffer::getBufferSize(num_stocks, buf.window, - m->blocking_len)); - rtlbuffers.emplace_back((float *)ptr, 1); + pushBuffer(rtlbuffers, buffers, num_stocks, m->blocking_len, + m->buffers[i], states); } } else if (m->dtype == Datatype::Double) { for (size_t i = 0; i < m->num_buffers; i++) { - auto &buf = m->buffers[i]; - auto ptr = StreamBuffer::make(num_stocks, buf.window, - m->blocking_len); - buffers.emplace_back( - ptr, StreamBuffer::getBufferSize(num_stocks, buf.window, - m->blocking_len)); - rtlbuffers.emplace_back((float *)ptr, 1); + pushBuffer(rtlbuffers, buffers, num_stocks, m->blocking_len, + m->buffers[i], states); } } else { throw std::runtime_error("Unknown type"); } + if (m->init_state_buffers) { + state_buffers = m->init_state_buffers(num_stocks); + if (states) { + for (auto &buf : state_buffers) { + if (!buf->deserialize(states)) { + throw std::runtime_error( + "Failed to deserialize state buffer"); + } + } + } else { + 
for (auto &buf : state_buffers) { + buf->initialize(); + } + } + } ctx.buffers = std::move(rtlbuffers); ctx.executor = exec; ctx.buffer_len = num_stocks * 1; @@ -460,8 +470,23 @@ void StreamContext::run() { StreamContext::~StreamContext() = default; +bool StreamContext::serializeStates(OutputStreamBase* stream) { + for(auto& buf: buffers) { + if(!stream->write(buf.get(), buf.size)) { + return false; + } + } + for (auto& ptr: state_buffers) { + if (!ptr->serialize(stream)) { + return false; + } + } + return true; +} + StateBuffer *StateBuffer::make(size_t num_objs, size_t elem_size, - CtorFn_t ctor_fn, DtorFn_t dtor_fn) { + CtorFn_t ctor_fn, DtorFn_t dtor_fn, SerializeFn_t serialize_fn, + DeserializeFn_t deserialize_fn) { auto ret = kunAlignedAlloc(KUN_MALLOC_ALIGNMENT, sizeof(StateBuffer) + num_objs * elem_size); auto buf = (StateBuffer *)ret; @@ -470,6 +495,8 @@ StateBuffer *StateBuffer::make(size_t num_objs, size_t elem_size, buf->initialized = 0; buf->ctor_fn = ctor_fn; buf->dtor_fn = dtor_fn; + buf->serialize_fn = serialize_fn; + buf->deserialize_fn = deserialize_fn; return buf; } diff --git a/cpp/Kun/SkipList.cpp b/cpp/Kun/SkipList.cpp index f216bf3..eaf6cb3 100644 --- a/cpp/Kun/SkipList.cpp +++ b/cpp/Kun/SkipList.cpp @@ -17,6 +17,7 @@ Python recipe (https://rhettinger.wordpress.com/2010/02/06/lost-knowledge/) */ #include "SkipList.hpp" +#include "StateBuffer.hpp" #include #include #include @@ -294,4 +295,71 @@ double SkipList::get(int rank, size_t &index, bool &found) { } int SkipList::size() const { return impl->size; } + +bool SkipList::serialize(OutputStreamBase *stream, int expsize) const { + if (!stream->write(&impl->size, sizeof(impl->size))) { + return false; + } + for(int i=0;isize;i++) { + size_t index; + bool found; + double value = impl->get(i, index, found); + if (!found) { + return false; + } + if (!stream->write(&value, sizeof(value))) { + return false; + } + if (!stream->write(&index, sizeof(index))) { + return false; + } + } + return true; +} 
+ +bool SkipList::deserialize(InputStreamBase *stream, int expsize) { + int size; + if (!stream->read(&size, sizeof(size))) { + return false; + } + if (size < 0 || size > 1024 * 1024 * 4) { + return false; // sanity check + } + for (int i = 0; i < size; i++) { + double value; + size_t index; + if (!stream->read(&value, sizeof(value))) { + return false; + } + if (!stream->read(&index, sizeof(index))) { + return false; + } + impl->insert(value, index); + } + return true; +} + + +bool serializeSkipList(SkipList* skiplist, int* lastInsertRank, size_t size, size_t window, OutputStreamBase* stream) { + if (!stream->write(lastInsertRank, size * sizeof(int))) { + return false; + } + for (size_t i = 0; i < size; i++) { + if (!skiplist[i].serialize(stream, window)) { + return false; + } + } + return true; +} +bool deserializeSkipList(SkipList* skiplist, int* lastInsertRank, size_t size, size_t window, InputStreamBase* stream) { + if (!stream->read(lastInsertRank, size * sizeof(int))) { + return false; + } + for (size_t i = 0; i < size; i++) { + if (!skiplist[i].deserialize(stream, window)) { + return false; + } + } + return true; +} } // namespace kun \ No newline at end of file diff --git a/cpp/Kun/SkipList.hpp b/cpp/Kun/SkipList.hpp index f58da9a..2bc1ac8 100644 --- a/cpp/Kun/SkipList.hpp +++ b/cpp/Kun/SkipList.hpp @@ -6,9 +6,13 @@ namespace kun { namespace detail { - struct SkipListImpl; +struct SkipListImpl; } +struct OutputStreamBase; +struct InputStreamBase; + + struct KUN_API SkipList { std::unique_ptr impl; SkipList(int size); @@ -22,8 +26,16 @@ struct KUN_API SkipList { // remove the first inserted element with the given value bool remove(double value); int minRank(double value); - double get(int rank, size_t& index, bool& found); + double get(int rank, size_t &index, bool &found); int size() const; + bool serialize(OutputStreamBase *stream, int expsize) const; + bool deserialize(InputStreamBase *stream, int expsize); }; -} \ No newline at end of file +KUN_API bool 
serializeSkipList(SkipList *skiplist, int *lastInsertRank, size_t size, + size_t window, OutputStreamBase *stream); +KUN_API bool deserializeSkipList(/*uninitialized*/ SkipList *skiplist, + int *lastInsertRank, size_t size, size_t window, + InputStreamBase *stream); + +} // namespace kun \ No newline at end of file diff --git a/cpp/Kun/StateBuffer.hpp b/cpp/Kun/StateBuffer.hpp index ca588cd..da461d3 100644 --- a/cpp/Kun/StateBuffer.hpp +++ b/cpp/Kun/StateBuffer.hpp @@ -13,19 +13,34 @@ namespace kun { #define KUN_MALLOC_ALIGNMENT 16 // NEON alignment #endif +struct InputStreamBase { + virtual bool read(void* buf, size_t len) = 0; + virtual ~InputStreamBase() = default; +}; + +struct OutputStreamBase { + virtual bool write(const void* buf, size_t len) = 0; + virtual ~OutputStreamBase() = default; +}; + struct StateBuffer { using DtorFn_t = void (*)(StateBuffer *obj); using CtorFn_t = void (*)(StateBuffer *obj); + using SerializeFn_t = bool (*)(StateBuffer *obj, OutputStreamBase *stream); + using DeserializeFn_t = bool (*)(StateBuffer *obj, InputStreamBase *stream); alignas(KUN_MALLOC_ALIGNMENT) size_t num_objs; uint32_t elem_size; uint32_t initialized; CtorFn_t ctor_fn; DtorFn_t dtor_fn; + SerializeFn_t serialize_fn; + DeserializeFn_t deserialize_fn; alignas(KUN_MALLOC_ALIGNMENT) char buf[0]; KUN_API static StateBuffer *make(size_t num_objs, size_t elem_size, - CtorFn_t ctor_fn, DtorFn_t dtor_fn); + CtorFn_t ctor_fn, DtorFn_t dtor_fn, SerializeFn_t serialize_fn, + DeserializeFn_t deserialize_fn); // for std::unique_ptr struct Deleter { @@ -47,26 +62,20 @@ struct StateBuffer { } initialized = 0; } - + bool serialize(OutputStreamBase *stream) { + return serialize_fn(this, stream); + } + bool deserialize(InputStreamBase *stream) { + if (deserialize_fn(this, stream)) { + initialized = 1; + return true; + } + return false; + } private: StateBuffer() = default; }; using StateBufferPtr = std::unique_ptr; -template -StateBufferPtr makeStateBuffer(size_t num_stocks, size_t 
simd_len) { - return StateBufferPtr(StateBuffer::make( - divideAndCeil(num_stocks, simd_len), sizeof(T), - [](StateBuffer *obj) { - for (size_t i = 0; i < obj->num_objs; i++) { - new (&obj->get(i)) T(); - } - }, - [](StateBuffer *obj) { - for (size_t i = 0; i < obj->num_objs; i++) { - obj->get(i).~T(); - } - })); -} } // namespace kun \ No newline at end of file diff --git a/cpp/Python/PyBinding.cpp b/cpp/Python/PyBinding.cpp index 5edaa94..04fe641 100644 --- a/cpp/Python/PyBinding.cpp +++ b/cpp/Python/PyBinding.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #ifdef _WIN32 @@ -59,8 +60,8 @@ struct StreamContextWrapper { std::shared_ptr lib; kun::StreamContext ctx; StreamContextWrapper(std::shared_ptr exec, - const ModuleHandle *m, size_t num_stocks) - : lib{m->lib}, ctx{std::move(exec), m->modu, num_stocks} + const ModuleHandle *m, size_t num_stocks, kun::InputStreamBase* states = nullptr) + : lib{m->lib}, ctx{std::move(exec), m->modu, num_stocks, states} {} }; } // namespace @@ -403,6 +404,27 @@ PYBIND11_MODULE(KunRunner, m) { py::class_(m, "StreamContext") .def(py::init, const ModuleHandle *, size_t>()) + .def(py::init([](std::shared_ptr exec, + const ModuleHandle *mod, size_t stocks, + py::object init) { + if (py::isinstance(init)) { + auto filename = py::cast(init); + kun::FileInputStream stream(filename); + return new StreamContextWrapper(std::move(exec), mod, stocks, + &stream); + } else if (py::isinstance(init)) { + py::bytes b = py::cast(init); + char *data; + py::ssize_t size; + if (PYBIND11_BYTES_AS_STRING_AND_SIZE(b.ptr(), &data, &size)) + throw std::runtime_error("Failed to get bytes data"); + kun::MemoryInputStream stream{data, (size_t)size}; + return new StreamContextWrapper(std::move(exec), mod, stocks, + &stream); + } + throw std::runtime_error( + "Bad type for init, expecting filename or bytes"); + })) .def("queryBufferHandle", [](StreamContextWrapper &t, const char *name) { return t.ctx.queryBufferHandle(name); @@ -448,5 +470,24 
@@ PYBIND11_MODULE(KunRunner, m) { ths.pushData(handle, (const double *)data.data()); } }) + .def( + "serializeStates", + [](StreamContextWrapper &t, py::object fileNameOrNone) -> py::object { + if (py::isinstance(fileNameOrNone)) { + auto filename = py::cast(fileNameOrNone); + kun::FileOutputStream stream(filename); + if (!t.ctx.serializeStates(&stream)) { + throw std::runtime_error("Failed to serialize states"); + } + return py::none(); + } + kun::MemoryOutputStream stream; + if (!t.ctx.serializeStates(&stream)) { + throw std::runtime_error("Failed to serialize states"); + } + py::bytes b(stream.getData(), (py::ssize_t)stream.getSize()); + return b; + }, + py::arg("fileNameOrNone") = py::none()) .def("run", [](StreamContextWrapper &t) { t.ctx.run(); }); } \ No newline at end of file diff --git a/tests/capi/test_c.cpp b/tests/capi/test_c.cpp index 90d6931..22131c6 100644 --- a/tests/capi/test_c.cpp +++ b/tests/capi/test_c.cpp @@ -80,6 +80,19 @@ static int testStream(const char *libpath) { KunStreamContextHandle ctx = kunCreateStream(exec, modu, num_stocks); CHECK(ctx); + // now dump the stream states + // use null buffer and zero size to get the real buffer size + size_t buf_size = 0; + auto status = kunStreamSerializeStates(ctx, KUN_INIT_MEMORY, nullptr, &buf_size); + CHECK(status == KUN_INIT_ERROR); + // second try to allocate buffer and get the real data + size_t buf_size2 = buf_size; + auto states_buffer = new char[buf_size]; + status = kunStreamSerializeStates(ctx, KUN_INIT_MEMORY, states_buffer, &buf_size); + CHECK(status == KUN_SUCCESS); + CHECK(buf_size == buf_size2); + + // run the stream size_t handleClose = kunQueryBufferHandle(ctx, "close"); size_t handleOpen = kunQueryBufferHandle(ctx, "open"); size_t handleHigh = kunQueryBufferHandle(ctx, "high"); @@ -89,25 +102,44 @@ static int testStream(const char *libpath) { size_t handleAlpha101 = kunQueryBufferHandle(ctx, "alpha101"); // don't need to query the handles everytime when calling kunStreamPushData - 
kunStreamPushData(ctx, handleClose, dataclose); - kunStreamPushData(ctx, handleOpen, dataopen); - kunStreamPushData(ctx, handleHigh, datahigh); - kunStreamPushData(ctx, handleLow, datalow); - kunStreamPushData(ctx, handleVol, datavol); - kunStreamPushData(ctx, handleAmount, dataamount); + auto run_and_check = [&]() { + kunStreamPushData(ctx, handleClose, dataclose); + kunStreamPushData(ctx, handleOpen, dataopen); + kunStreamPushData(ctx, handleHigh, datahigh); + kunStreamPushData(ctx, handleLow, datalow); + kunStreamPushData(ctx, handleVol, datavol); + kunStreamPushData(ctx, handleAmount, dataamount); - kunStreamRun(ctx); - memcpy(alpha101, kunStreamGetCurrentBuffer(ctx, handleAlpha101), - sizeof(float) * num_stocks); + kunStreamRun(ctx); + memcpy(alpha101, kunStreamGetCurrentBuffer(ctx, handleAlpha101), + sizeof(float) * num_stocks); - for (size_t i = 0; i < num_stocks; i++) { - float expected = - (dataclose[i] - dataopen[i]) / (datahigh[i] - datalow[i] + 0.001); - if (std::abs(alpha101[i] - expected) > 1e-5) { - printf("Output error at %zu => %f, %f\n", i, alpha101[i], expected); - return 4; + for (size_t i = 0; i < num_stocks; i++) { + float expected = + (dataclose[i] - dataopen[i]) / (datahigh[i] - datalow[i] + 0.001); + if (std::abs(alpha101[i] - expected) > 1e-5) { + printf("Output error at %zu => %f, %f\n", i, alpha101[i], expected); + exit(4); + } } - } + }; + run_and_check(); + kunDestoryStream(ctx); + + // check restore stream from states + // create a new stream context from the states + ctx = nullptr; + KunStreamExtraArgs extra_args; + extra_args.version = KUN_API_VERSION; + extra_args.init_kind = KUN_INIT_MEMORY; + extra_args.init.memory.buffer = states_buffer; + extra_args.init.memory.size = buf_size; + status = kunCreateStreamEx(exec, modu, num_stocks, &extra_args, &ctx); + CHECK(status == KUN_SUCCESS); + CHECK(ctx); + // run again to check if the states are restored correctly + run_and_check(); + delete[] dataclose; delete[] dataopen; @@ -115,6 
+147,7 @@ static int testStream(const char *libpath) { delete[] datalow; delete[] datavol; delete[] dataamount; + delete[] states_buffer; kunDestoryStream(ctx); kunUnloadLibrary(lib); kunDestoryExecutor(exec); diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 022ddb4..24cf58e 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -148,12 +148,12 @@ def test_avg_stddev(lib): #################################### def check_TS(): - return "avg_and_stddev_TS", build_avg_and_stddev(), KunCompilerConfig(input_layout="TS", output_layout="TS", options={"no_fast_stat": 'no_warn'}) + return "avg_and_stddev_TS", build_avg_and_stddev(), KunCompilerConfig(dtype="double", input_layout="TS", output_layout="TS", options={"no_fast_stat": 'no_warn'}) def test_avg_stddev_TS(lib): modu = lib.getModule("avg_and_stddev_TS") assert(modu) - inp = np.random.rand(24, 20).astype("float32") + inp = np.random.rand(24, 20) df = pd.DataFrame(inp.transpose()) expected_mean = df.rolling(10).mean().to_numpy().transpose() expected_stddev = df.rolling(10).std().to_numpy().transpose() diff --git a/tests/test_stream.py b/tests/test_stream.py index 78e19a8..ba115de 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -1,3 +1,4 @@ +import os import numpy as np import pandas as pd from KunQuant.jit import cfake @@ -6,9 +7,10 @@ from KunQuant.Op import * from KunQuant.ops import * from KunQuant.runner import KunRunner as kr +import tempfile +def make_steam(): -def test_stream(): builder = Builder() with builder: inp1 = Input("a") @@ -18,7 +20,9 @@ def test_stream(): f = Function(builder.ops) lib = cfake.compileit([("stream_test", f, cfake.KunCompilerConfig(dtype="double", input_layout="STREAM", output_layout="STREAM"))], "stream_test", cfake.CppCompilerConfig()) + return lib +def test_stream(lib): executor = kr.createSingleThreadExecutor() stream = kr.StreamContext(executor, lib.getModule("stream_test"), 24) a = np.random.rand(100, 24) @@ -35,6 +39,7 @@ def 
test_stream(): out[i] = stream.getCurrentBuffer(handle_quantile) ema[i] = stream.getCurrentBuffer(handle_ema) slope[i] = stream.getCurrentBuffer(handle_slope) + df = pd.DataFrame(a) expected_quantile = df.rolling(10).quantile(0.49, interpolation='linear').to_numpy() expected_ema = df.ewm(span=10, adjust=False, ignore_na=True).mean().to_numpy() @@ -43,5 +48,84 @@ def test_stream(): np.testing.assert_allclose(ema, expected_ema, atol=1e-6, rtol=1e-4, equal_nan=True) np.testing.assert_allclose(slope[10:], expected_slope[10:], atol=1e-6, rtol=1e-4, equal_nan=True) -test_stream() + + +def test_stream_ser_deser(lib): + executor = kr.createSingleThreadExecutor() + stream = kr.StreamContext(executor, lib.getModule("stream_test"), 24) + a = np.random.rand(100, 24) + handle_a = stream.queryBufferHandle("a") + handle_quantile = stream.queryBufferHandle("quantile") + handle_ema = stream.queryBufferHandle("ema") + handle_slope = stream.queryBufferHandle("slope") + out = np.empty((100, 24)) + ema = np.empty((100, 24)) + slope = np.empty((100, 24)) + + df = pd.DataFrame(a) + expected_quantile = df.rolling(10).quantile(0.49, interpolation='linear').to_numpy() + expected_ema = df.ewm(span=10, adjust=False, ignore_na=True).mean().to_numpy() + expected_slope = df.rolling(10).apply(lambda x: np.polyfit(np.arange(len(x)), x, 1)[0]).to_numpy() + + for i in range(50): + stream.pushData(handle_a, a[i]) + stream.run() + out[i] = stream.getCurrentBuffer(handle_quantile) + ema[i] = stream.getCurrentBuffer(handle_ema) + slope[i] = stream.getCurrentBuffer(handle_slope) + temppath = os.path.join(tempfile.gettempdir(), "kun_test_stream_state.bin") + print("serializing states to", temppath) + stream.serializeStates(temppath) + states = stream.serializeStates() + with open(temppath, "rb") as f: + assert(f.read() == states) + print("length of serialized states:", len(states)) + + # reload stream with in memory buffer + del stream + stream = kr.StreamContext(executor, lib.getModule("stream_test"), 
24, states) + for i in range(50, 100): + stream.pushData(handle_a, a[i]) + stream.run() + out[i] = stream.getCurrentBuffer(handle_quantile) + ema[i] = stream.getCurrentBuffer(handle_ema) + slope[i] = stream.getCurrentBuffer(handle_slope) + np.testing.assert_allclose(ema, expected_ema, atol=1e-6, rtol=1e-4, equal_nan=True) + np.testing.assert_allclose(slope[10:], expected_slope[10:], atol=1e-6, rtol=1e-4, equal_nan=True) + np.testing.assert_allclose(out, expected_quantile, atol=1e-6, rtol=1e-4, equal_nan=True) + + # reload stream with file + del stream + stream = kr.StreamContext(executor, lib.getModule("stream_test"), 24, temppath) + for i in range(50, 100): + stream.pushData(handle_a, a[i]) + stream.run() + out[i] = stream.getCurrentBuffer(handle_quantile) + ema[i] = stream.getCurrentBuffer(handle_ema) + slope[i] = stream.getCurrentBuffer(handle_slope) + + np.testing.assert_allclose(out, expected_quantile, atol=1e-6, rtol=1e-4, equal_nan=True) + np.testing.assert_allclose(ema, expected_ema, atol=1e-6, rtol=1e-4, equal_nan=True) + np.testing.assert_allclose(slope[10:], expected_slope[10:], atol=1e-6, rtol=1e-4, equal_nan=True) + + # failure: bad buffer + bad_states = states[:-10] + try: + stream = kr.StreamContext(executor, lib.getModule("stream_test"), 24, bad_states) + assert False, "Expected failure due to bad buffer" + except RuntimeError as e: + assert str(e) == "Failed to deserialize state buffer", f"Unexpected error message: {str(e)}" + + with open(temppath, "wb") as f: + f.truncate(20) + try: + stream = kr.StreamContext(executor, lib.getModule("stream_test"), 24, temppath) + assert False, "Expected failure due to bad buffer" + except RuntimeError as e: + assert str(e) == "Failed to read initial stream buffer", f"Unexpected error message: {str(e)}" + os.remove(temppath) + +lib = make_steam() +test_stream(lib) +test_stream_ser_deser(lib) print("test_stream passed") \ No newline at end of file