Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion sigllm/primitives/forecasting/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# Literal pad-token string.
DEFAULT_PAD_TOKEN = '<pad>'

# The ten decimal digits as single-character strings: ['0', '1', ..., '9'].
VALID_NUMBERS = list('0123456789')
# Empty by default.
# NOTE(review): not referenced in the visible hunks -- confirm where (or whether) it is consumed.
VALID_MULTIVARIATE_SYMBOLS = []

# NOTE(review): presumably a Hugging Face model identifier used as the default model name -- confirm.
DEFAULT_MODEL = 'mistralai/Mistral-7B-Instruct-v0.2'

Expand Down Expand Up @@ -53,6 +54,7 @@ def __init__(
raw=False,
samples=1,
padding=0,
multivariate_allowed_symbols = [],
):
self.name = name
self.sep = sep
Expand All @@ -62,6 +64,7 @@ def __init__(
self.raw = raw
self.samples = samples
self.padding = padding
self.multivariate_allowed_symbols = multivariate_allowed_symbols

self.tokenizer = AutoTokenizer.from_pretrained(self.name, use_fast=False)

Expand All @@ -85,6 +88,9 @@ def __init__(
token = self.tokenizer.convert_tokens_to_ids(number)
valid_tokens.append(token)

for symbol in self.multivariate_allowed_symbols:
valid_tokens.append(self.tokenizer.convert_tokens_to_ids(symbol))

valid_tokens.append(self.tokenizer.convert_tokens_to_ids(self.sep))
self.invalid_tokens = [
[i] for i in range(len(self.tokenizer) - 1) if i not in valid_tokens
Expand Down Expand Up @@ -116,7 +122,7 @@ def forecast(self, X, **kwargs):
tokenized_input = self.tokenizer([text], return_tensors='pt').to('cuda')

input_length = tokenized_input['input_ids'].shape[1]
average_length = input_length / len(text.split(','))
average_length = input_length / len(text.split(self.sep))
max_tokens = (average_length + self.padding) * self.steps

generate_ids = self.model.generate(
Expand Down
21 changes: 21 additions & 0 deletions sigllm/primitives/formatting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Multivariate formatting methods for time series data."""

from sigllm.primitives.formatting.multivariate_formatting import MultivariateFormattingMethod
from sigllm.primitives.formatting.json_format import JSONFormat
from sigllm.primitives.formatting.univariate_control import UnivariateControl
from sigllm.primitives.formatting.persistence_control import PersistenceControl
from sigllm.primitives.formatting.value_concatenation import ValueConcatenation
from sigllm.primitives.formatting.value_interleave import ValueInterleave
from sigllm.primitives.formatting.digit_interleave import DigitInterleave

# Names exported by ``from sigllm.primitives.formatting import *``; one entry
# per class imported above.
__all__ = [
    'MultivariateFormattingMethod',
    'JSONFormat',
    'UnivariateControl',
    'PersistenceControl',
    'ValueConcatenation',
    'ValueInterleave',
    'DigitInterleave',
]


72 changes: 72 additions & 0 deletions sigllm/primitives/formatting/digit_interleave.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from .multivariate_formatting import MultivariateFormattingMethod
import numpy as np


class DigitInterleave(MultivariateFormattingMethod):
    """Encode each timestamp by interleaving the digits of its values.

    Every value of a timestamp is zero-padded to a common width and the padded
    values are then read column by column: the first digit of every value,
    then the second digit of every value, and so on. With a width of 3 the
    timestamp ``(12, 345)`` is padded to ``"012"`` / ``"345"`` and encoded as
    ``"031425"``.
    """

    def __init__(self, verbose: bool = False, **kwargs):
        """Register the method under the name ``"digit_interleave"``.

        Args:
            verbose: Forwarded to the base-class constructor.
            **kwargs: Forwarded unchanged to the base-class constructor.
        """
        super().__init__("digit_interleave", verbose=verbose, **kwargs)

    def format_as_string(
        self, data: np.ndarray, digits_per_timestamp=3, separator=","
    ) -> list[str]:
        """Return one digit-interleaved string per window.

        Args:
            data: Iterable of windows; each window is an iterable of
                timestamps; each timestamp is an iterable of values that are
                converted with ``int()``.
            digits_per_timestamp: Minimum number of digits each value is
                zero-padded to. The width actually used grows to fit the
                longest value found in ``data``.
            separator: Placed between encoded timestamps and also appended
                after the last timestamp of every window.

        Returns:
            A list with one string per window.

        Side effects:
            Stores the padding width in ``self.metadata['width_used']`` so
            that ``format_as_integer`` can undo the encoding.
        """
        # NOTE(review): the width is measured on abs(value) but the encoding
        # below uses str(int(value)), so a minus sign would be interleaved as
        # if it were a digit. Inputs are presumably non-negative -- confirm.
        max_digits = max(
            (len(str(abs(int(v)))) for window in data for ts in window for v in ts),
            # Empty input used to raise ValueError here. Any integer has at
            # least one digit, so 1 is the natural floor.
            default=1,
        )
        width_used = max(digits_per_timestamp, max_digits)
        self.metadata['width_used'] = width_used

        def interleave_digits(timestamp):
            """Zero-pad every value, then emit the digits column by column."""
            padded_values = [str(int(val)).zfill(width_used) for val in timestamp]
            return ''.join(
                padded_val[digit_pos]
                for digit_pos in range(width_used)
                for padded_val in padded_values
            )

        return [
            separator.join(interleave_digits(timestamp) for timestamp in window) + separator
            for window in data
        ]

    def format_as_integer(
        self, data: list[str], separator=",", trunc=None, digits_per_timestamp=3
    ) -> np.ndarray:
        """Decode digit-interleaved strings back into integer values.

        Args:
            data: Iterable of entries. Each entry is iterated for sample
                strings, and each sample is split on ``separator`` into
                encoded timestamps.
                NOTE(review): the ``list[str]`` annotation does not match
                this nested iteration -- confirm the real type with callers.
            separator: Delimiter between encoded timestamps. Leading and
                trailing separators are stripped before splitting.
            trunc: When truthy, keeps at most ``trunc`` timestamps per sample
                and at most ``trunc`` values per timestamp.
            digits_per_timestamp: Unused; the width stored by
                ``format_as_string`` is used instead.

        Returns:
            An object-dtype array with one row per entry. Each row holds the
            decoded timestamps of all samples of that entry, concatenated in
            order.

        Raises:
            KeyError: If ``self.metadata`` has no ``'width_used'`` entry,
                i.e. ``format_as_string`` has not run on this instance
                (assuming ``metadata`` is a plain dict -- confirm).
            ValueError: If an encoded timestamp contains characters that
                ``int()`` cannot parse.
        """
        width_used = self.metadata['width_used']

        def deinterleave_timestamp(interleaved_str):
            """Convert interleaved digits back to original values."""
            total_digits = len(interleaved_str)
            num_values = total_digits // width_used

            # The digits of value k sit at positions k, k + num_values,
            # k + 2 * num_values, ...; at most width_used of them are used.
            values = [
                int(interleaved_str[value_idx::num_values][:width_used])
                for value_idx in range(num_values)
            ]
            decoded = np.array(values)
            return decoded[:trunc] if trunc else decoded

        return np.array(
            [
                [
                    deinterleave_timestamp(timestamp)
                    for sample in entry
                    for timestamp in sample.strip(separator).split(separator)[:trunc]
                    if timestamp.strip()
                ]
                for entry in data
            ],
            dtype=object,
        )



if __name__ == "__main__":
    # Manual smoke run: check the formatter, run the pipeline, and print the
    # errors, predictions and targets it returns.
    formatter = DigitInterleave(digits_per_timestamp=3)
    formatter.test_multivariate_formatting_validity(verbose=False)
    pipeline_output = formatter.run_pipeline(return_y_hat=True)
    errs, y_hat, y = pipeline_output
    for item in (errs, y_hat, y):
        print(item)
99 changes: 99 additions & 0 deletions sigllm/primitives/formatting/json_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from .multivariate_formatting import MultivariateFormattingMethod
import numpy as np
import re

class JSONFormat(MultivariateFormattingMethod):
    """Encode windows as comma-separated ``d<dimension>:<value>`` tokens.

    A window whose rows are ``(1, 2)`` and ``(3, 4)`` is written as
    ``"d0:1,d1:2,d0:3,d1:4"``. Despite the name, the output is not valid JSON.
    """

    def __init__(self, verbose: bool = False, **kwargs):
        """Register the method under the name ``"json_format"``.

        Args:
            verbose: Forwarded to the base-class constructor.
            **kwargs: Forwarded unchanged to the base-class constructor.
        """
        super().__init__("json_format", verbose=verbose, **kwargs)

    def format_as_string(self, data: np.ndarray, separator=",") -> list[str]:
        """Return one ``d<i>:<val>`` token string per window.

        Args:
            data: Iterable of windows; each window is an iterable of rows and
                each row an iterable of values.
            separator: Accepted but unused; tokens are always joined with a
                literal comma.

        Returns:
            A list with one string per window.
        """

        def window_to_json(window):
            """Join the tokens of each row, then join the rows."""
            return ",".join(
                ",".join(f"d{i}:{val}" for i, val in enumerate(row))
                for row in window
            )

        return [window_to_json(window) for window in data]

    def format_as_integer(self, data, trunc=None, steps_ahead=None):
        """
        Parse model output and extract d0 values for specified steps ahead.

        Args:
            data: Model output containing tokens like "d0:1,d1:2,d0:3,d1:4...",
                iterated as windows, each holding sample strings.
            trunc: Legacy parameter for truncation (used when steps_ahead is None)
            steps_ahead: List of 1-based step indices to extract (e.g., [1,3,5,10]).
                If None, uses legacy behavior with trunc parameter.

        Returns:
            If steps_ahead is None: object-dtype np.array with, per window and
                sample, the flat (optionally truncated) list of parsed values.
            If steps_ahead is provided: dict mapping step -> object-dtype np.array
                holding, per window and sample, the d0 value at that step, or
                None when the sample has no d0 value at that step.
        """
        if steps_ahead is None:
            return self._format_as_integer_legacy(data, trunc)

        results_by_step = {step: [] for step in steps_ahead}

        for window in data:
            step_samples = {step: [] for step in steps_ahead}
            for sample in window:
                d0_values = self._extract_d0_values(sample)
                for step in steps_ahead:
                    # Steps are 1-based. Without the lower bound, step <= 0
                    # would index from the end of the list and silently
                    # return an unrelated value.
                    if 1 <= step <= len(d0_values):
                        step_samples[step].append(d0_values[step - 1])
                    else:
                        step_samples[step].append(None)
            for step in steps_ahead:
                results_by_step[step].append(step_samples[step])

        for step in steps_ahead:
            results_by_step[step] = np.array(results_by_step[step], dtype=object)

        return results_by_step

    def _extract_d0_values(self, sample):
        """
        Extract all d0 values from a sample string in order.
        For "d0:1,d1:2,d0:3,d1:4", returns [1, 3].
        """
        return [
            int(val_str)
            for dim_str, val_str in re.findall(r'd(\d+):(\d+)', sample)
            if dim_str == "0"
        ]

    def _format_as_integer_legacy(self, data, trunc=None):
        """
        Legacy format_as_integer behavior for backward compatibility.

        Every ``d<i>:<val>`` token of a sample contributes its value, in token
        order and regardless of dimension; the list is cut to ``trunc`` items
        when ``trunc`` is truthy.
        """
        batch_rows = []
        for window in data:
            samples = []
            for sample in window:
                # The original buffered values per "d0" group before
                # flushing, which always produced plain token order.
                flat = [
                    int(token.partition(":")[2])
                    for token in re.findall(r'd\d+:\d+', sample)
                ]
                if trunc:
                    flat = flat[:trunc]
                samples.append(flat)
            batch_rows.append(samples)
        return np.array(batch_rows, dtype=object)




if __name__ == "__main__":
    # Manual smoke run: check the formatter, then run the pipeline with the
    # extra symbols that appear in the token format.
    formatter = JSONFormat()
    formatter.test_multivariate_formatting_validity(verbose=False)
    allowed_symbols = ["d", ":", ","]
    formatter.run_pipeline(multivariate_allowed_symbols=allowed_symbols)
Loading