Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions Stream.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,39 @@ Basically, you need to call `pushData` for each input. Then call `run()` to let

That's why in the above code, we immediately copy the `ndarray` returned by `getCurrentBuffer` with `[:]`.

## Serialize stream states and resuming from stream states

You can get the stream states from the context object via the `serializeStates` method after `run()` has been called.

```python
stream = kr.StreamContext(executor, modu, num_stock)
stream.pushData(...)
...
stream.run()
states: bytes = stream.serializeStates()
```

If `serializeStates()` is called with no arguments, it returns a `bytes` object holding a copy of the stream states. Alternatively, pass a file path as a string to `serializeStates(...)` to dump the states to that file:

```python
stream.serializeStates("path/to/output/file") # None is returned
```

To resume from a previous state of the stream, call `kr.StreamContext` with an additional argument to represent the bytes or file path of the dumped states:

```python
stream = kr.StreamContext(executor, modu, num_stock, state_bytes_or_file_path)
```

**Important note:** The serialized states depend on all of the following:
* KunQuant version
* OS/CPU architecture
* C++ compiler used to compile KunQuant
* number of stocks
* factor expressions

If two streams differ in any of the aspects above, they are incompatible. Do **not** feed a stream with states generated by an incompatible stream.

## C-API for Streaming mode

The logic is similar to the Python API above. For details, see `tests/capi/test_c.cpp` and `cpp/Kun/CApi.h`.
81 changes: 80 additions & 1 deletion cpp/Kun/CApi.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "CApi.h"
#include "IO.hpp"
#include "Module.hpp"
#include "RunGraph.hpp"
#include <string>
Expand Down Expand Up @@ -81,7 +82,85 @@ KUN_API KunStreamContextHandle kunCreateStream(KunExecutorHandle exec,
size_t num_stocks) {
auto &pexec = *unwrapExecutor(exec);
auto modu = reinterpret_cast<Module *>(m);
return new kun::StreamContext{pexec, modu, num_stocks};
try {
// Create a StreamContext with no initial states
return new kun::StreamContext{pexec, modu, num_stocks};
} catch (...) {
// If there is an error, return nullptr
return nullptr;
}
}


/**
 * Create a stream context, optionally restoring its states from a file or a
 * memory buffer described by `extra_args`.
 *
 * @param exec the executor handle
 * @param m the module handle
 * @param num_stocks the number of stocks
 * @param extra_args optional extra arguments; may be null, in which case the
 * context is created without initial states
 * @param out_handle receives the new context on success and nullptr on
 * failure; must not be null
 * @return KUN_SUCCESS on success; KUN_INVALID_ARGUMENT if `out_handle` is
 * null, the version does not match KUN_API_VERSION, the init kind is unknown
 * or the selected path/buffer is null; KUN_INIT_ERROR if the state file cannot
 * be opened or the context cannot be constructed.
 */
KUN_API KunStatus kunCreateStreamEx(KunExecutorHandle exec,
                                    KunModuleHandle m,
                                    size_t num_stocks,
                                    const KunStreamExtraArgs *extra_args,
                                    KunStreamContextHandle *out_handle) {
    // The result is written through out_handle, so it must be valid.
    if (!out_handle) {
        return KUN_INVALID_ARGUMENT;
    }
    // Never leave the caller with an indeterminate handle on failure.
    *out_handle = nullptr;

    auto &pexec = *unwrapExecutor(exec);
    auto modu = reinterpret_cast<Module *>(m);
    // Both stream objects live on the stack for the whole call; `states`
    // points at whichever one (if any) is selected by extra_args.
    FileInputStream file_stream;
    MemoryInputStream memory_stream{nullptr, 0};
    InputStreamBase *states = nullptr;
    if (extra_args) {
        if (extra_args->version != KUN_API_VERSION) {
            return KUN_INVALID_ARGUMENT;
        }
        if (extra_args->init_kind == KUN_INIT_FILE) {
            if (!extra_args->init.path) {
                return KUN_INVALID_ARGUMENT;
            }
            file_stream.file.open(extra_args->init.path, std::ios::binary);
            // Report an unreadable state file directly instead of relying on
            // later reads failing.
            if (!file_stream.file.is_open()) {
                return KUN_INIT_ERROR;
            }
            states = &file_stream;
        } else if (extra_args->init_kind == KUN_INIT_MEMORY) {
            if (!extra_args->init.memory.buffer) {
                return KUN_INVALID_ARGUMENT;
            }
            memory_stream.data = extra_args->init.memory.buffer;
            memory_stream.size = extra_args->init.memory.size;
            states = &memory_stream;
        } else if (extra_args->init_kind == KUN_INIT_NONE) {
            // no initial states requested
        } else {
            return KUN_INVALID_ARGUMENT;
        }
    }
    try {
        auto ctx = new kun::StreamContext{pexec, modu, num_stocks, states};
        *out_handle = reinterpret_cast<KunStreamContextHandle>(ctx);
    } catch (...) {
        // No exception of any type may cross the C API boundary.
        return KUN_INIT_ERROR;
    }
    return KUN_SUCCESS;
}

/**
 * Serialize the states of a stream context to a file or a caller buffer.
 *
 * @param context the stream context handle; must not be null
 * @param dump_kind KUN_INIT_FILE to write to a file, KUN_INIT_MEMORY to write
 * into a caller-provided buffer
 * @param path_or_buffer the file path (file mode) or the destination buffer
 * (memory mode); in memory mode it may be null only when `*size` is 0
 * @param size memory mode only: on input the capacity of the buffer, on output
 * the number of bytes the dump needs; unused in file mode
 * @return KUN_SUCCESS on success; KUN_INIT_ERROR on an I/O or serialization
 * failure, or when the buffer is too small (`*size` then holds the required
 * size); KUN_INVALID_ARGUMENT for a null context, an unsupported `dump_kind`
 * or missing path/buffer/size arguments.
 */
KUN_API KunStatus kunStreamSerializeStates(KunStreamContextHandle context,
                                           KunStateBufferKind dump_kind,
                                           char *path_or_buffer,
                                           size_t *size) {
    auto ctx = reinterpret_cast<kun::StreamContext *>(context);
    if (!ctx) {
        return KUN_INVALID_ARGUMENT;
    }
    try {
        if (dump_kind == KUN_INIT_FILE) {
            // Building a std::string from a null pointer is undefined.
            if (!path_or_buffer) {
                return KUN_INVALID_ARGUMENT;
            }
            kun::FileOutputStream stream(path_or_buffer);
            if (!stream.file.is_open()) {
                return KUN_INIT_ERROR;
            }
            if (!ctx->serializeStates(&stream)) {
                return KUN_INIT_ERROR;
            }
            // Buffered data may only hit the disk on flush; surface errors
            // (e.g. disk full) that would otherwise be lost in the destructor.
            if (!stream.file.flush()) {
                return KUN_INIT_ERROR;
            }
            return KUN_SUCCESS;
        }
        if (dump_kind == KUN_INIT_MEMORY) {
            if (!size || (*size != 0 && !path_or_buffer)) {
                return KUN_INVALID_ARGUMENT;
            }
            const size_t in_size = *size;
            kun::MemoryRefOutputStream stream{path_or_buffer, in_size};
            if (!ctx->serializeStates(&stream)) {
                // A serialization failure is not an argument error; report it
                // the same way as the file branch does.
                return KUN_INIT_ERROR;
            }
            // stream.pos counts every byte requested, even past the buffer
            // end, so it is the size the caller needs to provide.
            *size = stream.pos;
            return stream.pos > in_size ? KUN_INIT_ERROR : KUN_SUCCESS;
        }
        return KUN_INVALID_ARGUMENT;
    } catch (...) {
        // No exception may cross the C API boundary.
        return KUN_INIT_ERROR;
    }
}

KUN_API size_t kunQueryBufferHandle(KunStreamContextHandle context,
Expand Down
69 changes: 68 additions & 1 deletion cpp/Kun/CApi.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,33 @@ typedef void *KunStreamContextHandle;
extern "C" {
#endif

#define KUN_API_VERSION 1

// Selects where stream states are loaded from (kunCreateStreamEx) or dumped
// to (kunStreamSerializeStates).
typedef enum {
    KUN_INIT_NONE = 0,   // no state buffer
    KUN_INIT_FILE = 1,   // states live in a file identified by a path
    KUN_INIT_MEMORY = 2, // states live in a caller-provided memory buffer
} KunStateBufferKind;

// Status codes returned by the stream-state C API functions.
typedef enum {
    KUN_SUCCESS = 0,          // the call completed
    KUN_INIT_ERROR = 1,       // initialization or serialization failed
    KUN_INVALID_ARGUMENT = 2, // an argument was rejected
} KunStatus;

// Optional arguments for kunCreateStreamEx, describing where the initial
// stream states are read from.
typedef struct {
// Version of the KunQuant C API; must be set to KUN_API_VERSION, otherwise
// kunCreateStreamEx rejects the arguments.
size_t version;
// Selects which member of `init` is read: KUN_INIT_NONE (none),
// KUN_INIT_FILE (`path`) or KUN_INIT_MEMORY (`memory`).
KunStateBufferKind init_kind;
union {
// KUN_INIT_FILE: path to the stream state dump file; must not be null.
const char *path;
struct {
// KUN_INIT_MEMORY: start of the stream state dump; must not be null.
const char *buffer;
// KUN_INIT_MEMORY: size of the memory buffer in bytes.
size_t size;
} memory; // memory buffer holding a stream state dump
} init;
} KunStreamExtraArgs;

/**
* @brief Create an single thread executor
*
Expand Down Expand Up @@ -121,11 +148,51 @@ KUN_API void kunRunGraph(KunExecutorHandle exec, KunModuleHandle m,
*
* @param exec the executor
* @param m the module
* @param num_stocks The number of stocks. Must be multiple of 8
* @param num_stocks The number of stocks.
*/
KUN_API KunStreamContextHandle kunCreateStream(KunExecutorHandle exec,
KunModuleHandle m,
size_t num_stocks);
/**
 * @brief Create the Stream computing context with extra arguments, optionally
 * resuming from previously dumped stream states.
 *
 * @param exec the executor
 * @param m the module
 * @param num_stocks The number of stocks.
 * @param extra_args the extra arguments for the stream context, specifying the
 * stream state dump file or memory buffer to load. The `version` field must be
 * set to KUN_API_VERSION. May be null, in which case the context is created
 * without initial states.
 * @param out_handle receives the handle of the created stream context on
 * success. Must not be null.
 * @return KUN_SUCCESS on success; KUN_INIT_ERROR if the stream context cannot
 * be initialized (for example from the given states); KUN_INVALID_ARGUMENT if
 * `extra_args` is invalid: wrong `version`, unknown `init_kind`, or a null
 * path/buffer for the selected `init_kind`.
 */
KUN_API KunStatus kunCreateStreamEx(KunExecutorHandle exec, KunModuleHandle m,
size_t num_stocks,
const KunStreamExtraArgs *extra_args,
KunStreamContextHandle *out_handle);

/**
 * @brief Serialize the states of the stream to a file or a memory buffer.
 *
 * @param context the stream object
 * @param dump_kind the kind of dump: KUN_INIT_FILE or KUN_INIT_MEMORY
 * @param path_or_buffer if dump_kind is KUN_INIT_FILE, the path of the stream
 * state dump file to write. If dump_kind is KUN_INIT_MEMORY, the memory buffer
 * that receives the stream state dump; it may be null only when `*size` is 0.
 * @param size if dump_kind is KUN_INIT_MEMORY, must point to the size of the
 * memory buffer in bytes. When serialization completes, including the case
 * where the buffer is too small, `*size` is overwritten with the size of the
 * stream state dump in bytes, so calling with `*size == 0` can be used to
 * query the required size. If dump_kind is KUN_INIT_FILE, this parameter is
 * unused.
 * @return KUN_SUCCESS on success. KUN_INIT_ERROR if the memory buffer is too
 * small or there is a file error. KUN_INVALID_ARGUMENT if dump_kind is neither
 * KUN_INIT_FILE nor KUN_INIT_MEMORY, or if, in memory mode, `size` is null or
 * `path_or_buffer` is null while `*size` is not 0.
 */
KUN_API KunStatus kunStreamSerializeStates(KunStreamContextHandle context,
KunStateBufferKind dump_kind,
char *path_or_buffer, size_t *size);

/**
* @brief Query the handle of a named buffer (input or output)
Expand Down
2 changes: 1 addition & 1 deletion cpp/Kun/Context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

namespace kun {

static const uint64_t VERSION = 0x64100004;
static const uint64_t VERSION = 0x64100005;
struct KUN_API RuntimeStage {
const Stage *stage;
Context *ctx;
Expand Down
47 changes: 47 additions & 0 deletions cpp/Kun/IO.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#include "IO.hpp"
#include <cstring>

namespace kun {
// Copies `len` bytes from the current read position into `buf` and advances
// the position. Returns false, leaving the position untouched, when fewer
// than `len` bytes remain.
bool MemoryInputStream::read(void *buf, size_t len) {
    // Compare against the remaining byte count instead of computing
    // `pos + len`, which could wrap around for a huge `len`.
    if (pos > size || len > size - pos) {
        return false;
    }
    // memcpy requires valid pointers even for a zero length, and `data` may be
    // null for an empty stream, so skip the call when nothing is requested.
    if (len != 0) {
        std::memcpy(buf, data + pos, len);
        pos += len;
    }
    return true;
}

// Appends `len` bytes starting at `buf` to the end of the owned buffer.
// Always reports success.
bool MemoryOutputStream::write(const void *buf, size_t len) {
    const auto *first = static_cast<const char *>(buf);
    const auto *last = first + len;
    buffer.insert(buffer.end(), first, last);
    return true;
}

// Copies `len` bytes into the referenced buffer when they fit entirely, and
// always advances `pos` by `len`. `pos` therefore ends up as the total number
// of bytes requested, which lets callers detect a too-small buffer and learn
// the required size. Always returns true.
bool MemoryRefOutputStream::write(const void *b, size_t len) {
    // Check the remaining capacity without computing `pos + len`, which could
    // wrap around; once `pos` has passed `size` nothing is copied any more.
    const bool fits = pos <= size && len <= size - pos;
    // Skip zero-length copies: `buf` may legitimately be null when size is 0,
    // and memcpy requires valid pointers even for a zero length.
    if (fits && len != 0) {
        std::memcpy(buf + pos, b, len);
    }
    pos += len;
    return true;
}

// Opens `filename` for reading in binary mode.
FileInputStream::FileInputStream(const std::string &filename)
    : file(filename, std::ios_base::binary) {}

// Reads exactly `len` bytes from the file into `buf`; returns false if the
// stream is in a failed state afterwards.
bool FileInputStream::read(void *buf, size_t len) {
    file.read(static_cast<char *>(buf), len);
    return !file.fail();
}

// Opens `filename` for writing in binary mode.
FileOutputStream::FileOutputStream(const std::string &filename)
    : file(filename, std::ios_base::binary) {}

// Writes `len` bytes from `buf` to the file; returns false if the stream is
// in a failed state afterwards.
bool FileOutputStream::write(const void *buf, size_t len) {
    file.write(static_cast<const char *>(buf), len);
    return !file.fail();
}
} // namespace kun
50 changes: 50 additions & 0 deletions cpp/Kun/IO.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#pragma once

#include "StateBuffer.hpp"
#include <fstream>
#include <vector>

namespace kun {
// Input stream that reads from a byte range in memory. Only the pointer is
// stored here; the bytes themselves are not copied by this struct.
struct KUN_API MemoryInputStream final : public InputStreamBase {
// first byte of the range to read from
const char *data;
// number of bytes available at `data`
size_t size;
// current read offset from `data`
size_t pos;

// Starts reading at offset 0 of the given range.
MemoryInputStream(const char *data, size_t size)
: data(data), size(size), pos(0) {}

// Copies `len` bytes at the current offset into `buf` and advances the
// offset; returns false when fewer than `len` bytes remain.
bool read(void *buf, size_t len) override;
};
// Output stream that accumulates everything written into an owned,
// growable byte vector.
struct KUN_API MemoryOutputStream final : public OutputStreamBase {
// all bytes written so far, in write order
std::vector<char> buffer;

// Appends `len` bytes from `buf` to `buffer`; always returns true.
bool write(const void *buf, size_t len) override;

// Pointer to the accumulated bytes.
const char *getData() const { return buffer.data(); }
// Number of accumulated bytes.
size_t getSize() const { return buffer.size(); }
};

// Output stream that writes into a fixed-size buffer supplied by the caller.
// Writes that do not fit are not copied but are still counted in `pos`, so
// `pos > size` after serialization means the buffer was too small and `pos`
// is the size that was needed.
struct MemoryRefOutputStream final : public OutputStreamBase {
// total number of bytes requested to be written so far
size_t pos;
// destination buffer
char* buf;
// capacity of `buf` in bytes
size_t size;

// Wraps `buf` with capacity `len`; writing starts at offset 0.
MemoryRefOutputStream(char* buf, size_t len) : pos(0), buf(buf), size(len) {}

// Copies the bytes if they fit entirely, advances `pos` by `len` either way,
// and always returns true.
bool write(const void *buf, size_t len) override;
};


// Input stream backed by a std::ifstream.
struct KUN_API FileInputStream final : public InputStreamBase {
// underlying file stream; public so callers can open it after default
// construction
std::ifstream file;

// Opens `filename` in binary mode.
FileInputStream(const std::string &filename);
// Creates the stream without opening any file.
FileInputStream() = default;
// Reads `len` bytes into `buf`; returns false if the file read fails.
bool read(void *buf, size_t len) override;
};
// Output stream backed by a std::ofstream.
struct KUN_API FileOutputStream final : public OutputStreamBase {
// underlying file stream
std::ofstream file;
// Opens `filename` in binary mode.
FileOutputStream(const std::string &filename);
// Writes `len` bytes from `buf`; returns false if the file write fails.
bool write(const void *buf, size_t len) override;
};
} // namespace kun
31 changes: 31 additions & 0 deletions cpp/Kun/Ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "Base.hpp"
#include "Math.hpp"
#include "StreamBuffer.hpp"
#include "StateBuffer.hpp"
#include <cmath>
#include <limits>
#include <stdint.h>
Expand All @@ -11,6 +12,18 @@
namespace kun {
namespace ops {


/**
 * Raw-byte serialization helpers for a StateBuffer holding objects of type T.
 * The buffer content is transferred as one run of
 * `sizeof(T) * obj->num_objs` bytes starting at `obj->buf`.
 */
template <typename T>
struct Serializer {
    /// Writes the whole buffer to `stream`; returns the stream's write result.
    static bool serialize(StateBuffer *obj, OutputStreamBase *stream) {
        const auto num_bytes = sizeof(T) * obj->num_objs;
        return stream->write(obj->buf, num_bytes);
    }

    /// Fills the whole buffer from `stream`; returns the stream's read result.
    static bool deserialize(StateBuffer *obj, InputStreamBase *stream) {
        const auto num_bytes = sizeof(T) * obj->num_objs;
        return stream->read(obj->buf, num_bytes);
    }
};

template <bool vcontainsWindow>
struct DataSource {
constexpr static bool containsWindow = vcontainsWindow;
Expand Down Expand Up @@ -706,5 +719,23 @@ inline DecayVec_t<T> SetInfOrNanToValue(T aa, T2 v) {
return sc_select(mask, v, a);
}


/**
 * Creates a StateBuffer holding `divideAndCeil(num_stocks, simd_len)` objects
 * of type T, and registers callbacks to default-construct, destroy, serialize
 * and deserialize them.
 */
template <typename T>
StateBufferPtr makeStateBuffer(size_t num_stocks, size_t simd_len) {
    return StateBufferPtr(StateBuffer::make(
        divideAndCeil(num_stocks, simd_len), sizeof(T),
        // constructor callback: value-initialize every slot in place
        [](StateBuffer *state) {
            const size_t count = state->num_objs;
            for (size_t idx = 0; idx != count; ++idx) {
                new (&state->get<T>(idx)) T();
            }
        },
        // destructor callback: run ~T() on every slot, in index order
        [](StateBuffer *state) {
            const size_t count = state->num_objs;
            for (size_t idx = 0; idx != count; ++idx) {
                state->get<T>(idx).~T();
            }
        },
        Serializer<T>::serialize, Serializer<T>::deserialize));
}

} // namespace ops
} // namespace kun
Loading