Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file removed const_block_generic
Binary file not shown.
4 changes: 4 additions & 0 deletions crates/bamltype/tests/ui.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
#[test]
#[cfg_attr(
miri,
ignore = "trybuild launches subprocesses and is unsupported under miri"
)]
fn ui_compile_failures() {
let t = trybuild::TestCases::new();
t.compile_fail("tests/ui/*.rs");
Expand Down
1 change: 1 addition & 0 deletions crates/dspy-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ rig-core = { git = "https://github.com/0xPlaygrounds/rig", rev="e7849df" }
enum_dispatch = "0.3.13"
tracing = "0.1.44"
tracing-subscriber = { version = "0.3.22", features = ["env-filter", "fmt"] }
minijinja = { git = "https://github.com/boundaryml/minijinja.git", branch = "main", default-features = false, features = ["builtins", "serde"] }

[package.metadata.cargo-machete]
ignored = ["rig-core"]
Expand Down
230 changes: 214 additions & 16 deletions crates/dspy-rs/src/adapter/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,20 @@ use bamltype::jsonish::BamlValueWithFlags;
use bamltype::jsonish::deserializer::coercer::run_user_checks;
use bamltype::jsonish::deserializer::deserialize_flags::DeserializerConditions;
use indexmap::IndexMap;
use minijinja::UndefinedBehavior;
use minijinja::value::{Kwargs, Value as MiniJinjaValue};
use regex::Regex;
use std::sync::LazyLock;
use serde_json::{Value, json};
use std::collections::HashMap;
use std::sync::{LazyLock, Mutex};
use tracing::{debug, trace};

use super::Adapter;
use crate::CallMetadata;
use crate::{
BamlType, BamlValue, ConstraintLevel, ConstraintResult, FieldMeta, Flag, JsonishError, Message,
OutputFormatContent, ParseError, PredictError, Predicted, RenderOptions, Signature, TypeIR,
BamlType, BamlValue, ConstraintLevel, ConstraintResult, FieldMeta, Flag, InputRenderSpec,
JsonishError, Message, OutputFormatContent, ParseError, PredictError, Predicted, RenderOptions,
Signature, TypeIR,
};

/// Builds prompts and parses responses using the `[[ ## field ## ]]` delimiter protocol.
Expand All @@ -33,6 +38,101 @@ pub struct ChatAdapter;
static FIELD_HEADER_PATTERN: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"^\[\[ ## ([^#]+?) ## \]\]").unwrap());

const INPUT_RENDER_TEMPLATE_NAME: &str = "__input_field__";

#[derive(Clone)]
struct CachedInputRenderTemplate {
env: minijinja::Environment<'static>,
}

static INPUT_RENDER_TEMPLATE_CACHE: LazyLock<
Mutex<HashMap<&'static str, CachedInputRenderTemplate>>,
> = LazyLock::new(|| Mutex::new(HashMap::new()));

fn regex_match(value: String, regex: String) -> bool {
match Regex::new(&regex) {
Ok(re) => re.is_match(&value),
Err(_) => false,
}
}

fn sum_filter(value: Vec<MiniJinjaValue>) -> MiniJinjaValue {
let int_sum: Option<i64> = value
.iter()
.map(|value| <i64>::try_from(value.clone()).ok())
.collect::<Option<Vec<_>>>()
.map(|ints| ints.into_iter().sum());
let float_sum: Option<f64> = value
.into_iter()
.map(|value| <f64>::try_from(value).ok())
.collect::<Option<Vec<_>>>()
.map(|floats| floats.into_iter().sum());
int_sum.map_or(
float_sum.map_or(MiniJinjaValue::from(0), MiniJinjaValue::from),
MiniJinjaValue::from,
)
}

fn truncate_filter(
value: String,
positional_length: Option<usize>,
kwargs: Kwargs,
) -> Result<String, minijinja::Error> {
let kwarg_length: Option<usize> = kwargs.get("length")?;
let length = kwarg_length.or(positional_length).unwrap_or(255);
let killwords: Option<bool> = kwargs.get("killwords")?;
let leeway: Option<usize> = kwargs.get("leeway")?;
let end: Option<String> = kwargs.get("end")?;
kwargs.assert_all_used()?;

let killwords = killwords.unwrap_or(false);
let leeway = leeway.unwrap_or(5);
let end = end.unwrap_or_else(|| "...".to_string());
let value_len = value.chars().count();

if value_len <= length.saturating_add(leeway) {
return Ok(value);
}

let trim_to = length.saturating_sub(end.chars().count());
if trim_to == 0 {
return Ok(end.chars().take(length).collect());
}

let mut truncated: String = value.chars().take(trim_to).collect();
if !killwords {
if let Some(index) = truncated.rfind(char::is_whitespace) {
if index > 0 {
truncated.truncate(index);
}
}
truncated = truncated.trim_end().to_string();
}

Ok(format!("{truncated}{end}"))
}

fn build_input_render_environment() -> minijinja::Environment<'static> {
// Keep this setup aligned with BAML's jinja env defaults, then add contrib filters.
let mut env = minijinja::Environment::new();
env.set_formatter(|output, state, value| {
let value = if value.is_none() {
&MiniJinjaValue::from("null")
} else {
value
};
minijinja::escape_formatter(output, state, value)
});
env.set_debug(true);
env.set_trim_blocks(true);
env.set_lstrip_blocks(true);
env.set_undefined_behavior(UndefinedBehavior::Strict);
env.add_filter("regex_match", regex_match);
env.add_filter("sum", sum_filter);
env.add_filter("truncate", truncate_filter);
env
}

fn render_field_type_schema(
parent_format: &OutputFormatContent,
type_ir: &TypeIR,
Expand Down Expand Up @@ -396,15 +496,19 @@ impl ChatAdapter {
{
let baml_value = input.to_baml_value();
let input_output_format = <I as BamlType>::baml_output_format();
let input_json = build_input_context_value(schema, &baml_value);
let vars = Value::Object(serde_json::Map::new());

let mut result = String::new();
for field_spec in schema.input_fields() {
if let Some(value) = value_for_path_relaxed(&baml_value, field_spec.path()) {
result.push_str(&format!("[[ ## {} ## ]]\n", field_spec.lm_name));
result.push_str(&format_baml_value_for_prompt_typed(
result.push_str(&render_input_field(
field_spec,
value,
&input_json,
input_output_format,
field_spec.format,
&vars,
));
result.push_str("\n\n");
}
Expand Down Expand Up @@ -865,23 +969,117 @@ fn format_baml_value_for_prompt(value: &BamlValue) -> String {
}
}

fn format_baml_value_for_prompt_typed(
fn render_input_field(
field_spec: &crate::FieldSchema,
value: &BamlValue,
input: &Value,
output_format: &OutputFormatContent,
format: Option<&str>,
vars: &Value,
) -> String {
let format = match format {
Some(format) => format,
None => {
if let BamlValue::String(s) = value {
return s.clone();
}
"json"
match field_spec.input_render {
InputRenderSpec::Default => match value {
BamlValue::String(s) => s.clone(),
_ => bamltype::internal_baml_jinja::format_baml_value(value, output_format, "json")
.unwrap_or_else(|_| "<error>".to_string()),
},
InputRenderSpec::Format(format) => {
bamltype::internal_baml_jinja::format_baml_value(value, output_format, format)
.unwrap_or_else(|_| "<error>".to_string())
}
InputRenderSpec::Jinja(template) => {
render_input_field_jinja(template, field_spec, value, input, output_format, vars)
}
}
}

fn build_input_context_value(schema: &crate::SignatureSchema, root: &BamlValue) -> Value {
let mut input_json = baml_value_to_render_json(root);
let Some(root_map) = input_json.as_object_mut() else {
return input_json;
};

// Provide alias lookups for top-level fields so templates can use either
// Rust field names (`input.question`) or prompt aliases (`input.query`).
for field_spec in schema.input_fields() {
if field_spec.rust_name.contains('.') || field_spec.lm_name == field_spec.rust_name {
continue;
}
if field_spec.path().iter().nth(1).is_some() {
continue;
}
if let Some(value) = root_map.get(field_spec.rust_name.as_str()).cloned() {
root_map
.entry(field_spec.lm_name.to_string())
.or_insert(value);
}
}

input_json
}

fn baml_value_to_render_json(value: &BamlValue) -> Value {
serde_json::to_value(value).unwrap_or(Value::Null)
}

fn render_input_field_jinja(
template: &'static str,
field_spec: &crate::FieldSchema,
value: &BamlValue,
input: &Value,
_output_format: &OutputFormatContent,
vars: &Value,
) -> String {
let env = {
let mut cache = INPUT_RENDER_TEMPLATE_CACHE
.lock()
.expect("input render template cache lock poisoned");
cache
.entry(template)
.or_insert_with(|| {
let mut env = build_input_render_environment();
env.add_template(INPUT_RENDER_TEMPLATE_NAME, template)
.unwrap_or_else(|err| {
panic!(
"failed to compile cached input render template for `{}` ({}): {err}",
field_spec.lm_name, field_spec.rust_name
)
});
CachedInputRenderTemplate { env }
})
.env
.clone()
};

bamltype::internal_baml_jinja::format_baml_value(value, output_format, format)
.unwrap_or_else(|_| "<error>".to_string())
let compiled = env
.get_template(INPUT_RENDER_TEMPLATE_NAME)
.unwrap_or_else(|err| {
panic!(
"failed to fetch cached input render template for `{}` ({}): {err}",
field_spec.lm_name, field_spec.rust_name
)
});

let this = baml_value_to_render_json(value);
let field = json!({
"name": field_spec.lm_name,
"rust_name": field_spec.rust_name,
"type": field_spec.type_ir.diagnostic_repr().to_string(),
});
let context = json!({
"this": this,
"input": input,
"field": field,
"vars": vars,
});

compiled
.render(minijinja::Value::from_serialize(context))
.unwrap_or_else(|err| {
panic!(
"failed to render input field `{}` (rust `{}`) with #[render(jinja = ...)] template `{}`: {err}",
field_spec.lm_name, field_spec.rust_name, template
)
})
}

fn collect_flags_recursive(value: &BamlValueWithFlags, flags: &mut Vec<Flag>) {
Expand Down
2 changes: 1 addition & 1 deletion crates/dspy-rs/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub use lm::*;
pub use module::*;
pub use module_ext::*;
pub use predicted::{CallMetadata, ConstraintResult, FieldMeta, Predicted};
pub use schema::{FieldMetadataSpec, FieldPath, FieldSchema, SignatureSchema};
pub use schema::{FieldMetadataSpec, FieldPath, FieldSchema, InputRenderSpec, SignatureSchema};
pub use settings::*;
pub use signature::*;
pub use specials::*;
26 changes: 19 additions & 7 deletions crates/dspy-rs/src/core/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,17 @@ impl FieldPath {
/// Static metadata for a single signature field, emitted by `#[derive(Signature)]`.
///
/// Carries the Rust field name, optional LM-facing alias, constraint specs, and
/// format hints. Fed into [`SignatureSchema`] construction alongside Facet shape data.
/// input render hints. Fed into [`SignatureSchema`] construction alongside Facet shape data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InputRenderSpec {
/// Default behavior: strings are raw, non-strings are rendered as JSON.
Default,
/// Explicit format hint (`#[format("json" | "yaml" | "toon")]`).
Format(&'static str),
/// Custom Jinja template (`#[render(jinja = "...")]`).
Jinja(&'static str),
}

#[derive(Debug, Clone, Copy)]
pub struct FieldMetadataSpec {
/// The Rust field name as written in the signature struct.
Expand All @@ -57,8 +67,8 @@ pub struct FieldMetadataSpec {
pub alias: Option<&'static str>,
/// Constraint specs from `#[check(...)]` and `#[assert(...)]` attributes.
pub constraints: &'static [ConstraintSpec],
/// Optional format hint (e.g. `#[format = "json"]`).
pub format: Option<&'static str>,
/// Input rendering policy for this field.
pub input_render: InputRenderSpec,
}

/// Complete schema for a single field in a signature, combining Facet shape data with metadata.
Expand All @@ -81,8 +91,8 @@ pub struct FieldSchema {
pub path: FieldPath,
/// Constraints declared on this field.
pub constraints: &'static [ConstraintSpec],
/// Optional format hint.
pub format: Option<&'static str>,
/// Input rendering policy.
pub input_render: InputRenderSpec,
}

impl FieldSchema {
Expand Down Expand Up @@ -329,7 +339,9 @@ fn emit_field(
let lm_name = inherited
.and_then(|meta| meta.alias)
.unwrap_or_else(|| field.effective_name());
let format = inherited.and_then(|meta| meta.format);
let input_render = inherited
.map(|meta| meta.input_render)
.unwrap_or(InputRenderSpec::Default);

out.push(FieldSchema {
lm_name,
Expand All @@ -339,7 +351,7 @@ fn emit_field(
shape: field.shape(),
path,
constraints,
format,
input_render,
});

Ok(())
Expand Down
4 changes: 2 additions & 2 deletions crates/dspy-rs/src/core/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ pub trait Signature: Send + Sync + 'static {
/// The Facet shape of the output struct.
fn output_shape() -> &'static Shape;

/// Per-field metadata for input fields (aliases, constraints, format hints).
/// Per-field metadata for input fields (aliases, constraints, input render hints).
fn input_field_metadata() -> &'static [FieldMetadataSpec];
/// Per-field metadata for output fields (aliases, constraints, format hints).
/// Per-field metadata for output fields (aliases, constraints, input render hints).
fn output_field_metadata() -> &'static [FieldMetadataSpec];

/// The output format descriptor used by the adapter for structured output parsing.
Expand Down
Loading
Loading