diff --git a/Cargo.lock b/Cargo.lock index 4e7849b..872e88e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1162,6 +1162,7 @@ dependencies = [ "hf-hub", "indexmap", "kdam", + "minijinja", "parquet", "rand 0.8.5", "rayon", @@ -1184,6 +1185,7 @@ name = "dsrs_macros" version = "0.7.2" dependencies = [ "dspy-rs", + "minijinja", "proc-macro-crate", "proc-macro2", "quote", diff --git a/const_block_generic b/const_block_generic deleted file mode 100755 index 4a286b0..0000000 Binary files a/const_block_generic and /dev/null differ diff --git a/crates/bamltype/tests/ui.rs b/crates/bamltype/tests/ui.rs index 365834b..01abce0 100644 --- a/crates/bamltype/tests/ui.rs +++ b/crates/bamltype/tests/ui.rs @@ -1,4 +1,8 @@ #[test] +#[cfg_attr( + miri, + ignore = "trybuild launches subprocesses and is unsupported under miri" +)] fn ui_compile_failures() { let t = trybuild::TestCases::new(); t.compile_fail("tests/ui/*.rs"); diff --git a/crates/dspy-rs/Cargo.toml b/crates/dspy-rs/Cargo.toml index 0ea4584..c128afe 100644 --- a/crates/dspy-rs/Cargo.toml +++ b/crates/dspy-rs/Cargo.toml @@ -44,6 +44,7 @@ rig-core = { git = "https://github.com/0xPlaygrounds/rig", rev="e7849df" } enum_dispatch = "0.3.13" tracing = "0.1.44" tracing-subscriber = { version = "0.3.22", features = ["env-filter", "fmt"] } +minijinja = { git = "https://github.com/boundaryml/minijinja.git", branch = "main", default-features = false, features = ["builtins", "serde"] } [package.metadata.cargo-machete] ignored = ["rig-core"] diff --git a/crates/dspy-rs/src/adapter/chat.rs b/crates/dspy-rs/src/adapter/chat.rs index bdac90d..8566602 100644 --- a/crates/dspy-rs/src/adapter/chat.rs +++ b/crates/dspy-rs/src/adapter/chat.rs @@ -4,15 +4,20 @@ use bamltype::jsonish::BamlValueWithFlags; use bamltype::jsonish::deserializer::coercer::run_user_checks; use bamltype::jsonish::deserializer::deserialize_flags::DeserializerConditions; use indexmap::IndexMap; +use minijinja::UndefinedBehavior; +use minijinja::value::{Kwargs, Value as MiniJinjaValue}; use regex::Regex; -use std::sync::LazyLock; +use serde_json::{Value, json}; +use std::collections::HashMap; +use std::sync::{LazyLock, Mutex}; use tracing::{debug, trace}; use super::Adapter; use crate::CallMetadata; use crate::{ - BamlType, BamlValue, ConstraintLevel, ConstraintResult, FieldMeta, Flag, JsonishError, Message, - OutputFormatContent, ParseError, PredictError, Predicted, RenderOptions, Signature, TypeIR, + BamlType, BamlValue, ConstraintLevel, ConstraintResult, FieldMeta, Flag, InputRenderSpec, + JsonishError, Message, OutputFormatContent, ParseError, PredictError, Predicted, RenderOptions, + Signature, TypeIR, }; /// Builds prompts and parses responses using the `[[ ## field ## ]]` delimiter protocol. @@ -33,6 +38,101 @@ pub struct ChatAdapter; static FIELD_HEADER_PATTERN: LazyLock = LazyLock::new(|| Regex::new(r"^\[\[ ## ([^#]+?) ## \]\]").unwrap()); +const INPUT_RENDER_TEMPLATE_NAME: &str = "__input_field__"; + +#[derive(Clone)] +struct CachedInputRenderTemplate { + env: minijinja::Environment<'static>, +} + +static INPUT_RENDER_TEMPLATE_CACHE: LazyLock< + Mutex>, +> = LazyLock::new(|| Mutex::new(HashMap::new())); + +fn regex_match(value: String, regex: String) -> bool { + match Regex::new(®ex) { + Ok(re) => re.is_match(&value), + Err(_) => false, + } +} + +fn sum_filter(value: Vec) -> MiniJinjaValue { + let int_sum: Option = value + .iter() + .map(|value| ::try_from(value.clone()).ok()) + .collect::>>() + .map(|ints| ints.into_iter().sum()); + let float_sum: Option = value + .into_iter() + .map(|value| ::try_from(value).ok()) + .collect::>>() + .map(|floats| floats.into_iter().sum()); + int_sum.map_or( + float_sum.map_or(MiniJinjaValue::from(0), MiniJinjaValue::from), + MiniJinjaValue::from, + ) +} + +fn truncate_filter( + value: String, + positional_length: Option, + kwargs: Kwargs, +) -> Result { + let kwarg_length: Option = kwargs.get("length")?; + let length = kwarg_length.or(positional_length).unwrap_or(255); + let killwords: Option = kwargs.get("killwords")?; + let leeway: Option = kwargs.get("leeway")?; + let end: Option = kwargs.get("end")?; + kwargs.assert_all_used()?; + + let killwords = killwords.unwrap_or(false); + let leeway = leeway.unwrap_or(5); + let end = end.unwrap_or_else(|| "...".to_string()); + let value_len = value.chars().count(); + + if value_len <= length.saturating_add(leeway) { + return Ok(value); + } + + let trim_to = length.saturating_sub(end.chars().count()); + if trim_to == 0 { + return Ok(end.chars().take(length).collect()); + } + + let mut truncated: String = value.chars().take(trim_to).collect(); + if !killwords { + if let Some(index) = truncated.rfind(char::is_whitespace) { + if index > 0 { + truncated.truncate(index); + } + } + truncated = truncated.trim_end().to_string(); + } + + Ok(format!("{truncated}{end}")) +} + +fn build_input_render_environment() -> minijinja::Environment<'static> { + // Keep this setup aligned with BAML's jinja env defaults, then add contrib filters. + let mut env = minijinja::Environment::new(); + env.set_formatter(|output, state, value| { + let value = if value.is_none() { + &MiniJinjaValue::from("null") + } else { + value + }; + minijinja::escape_formatter(output, state, value) + }); + env.set_debug(true); + env.set_trim_blocks(true); + env.set_lstrip_blocks(true); + env.set_undefined_behavior(UndefinedBehavior::Strict); + env.add_filter("regex_match", regex_match); + env.add_filter("sum", sum_filter); + env.add_filter("truncate", truncate_filter); + env +} + fn render_field_type_schema( parent_format: &OutputFormatContent, type_ir: &TypeIR, @@ -396,15 +496,19 @@ impl ChatAdapter { { let baml_value = input.to_baml_value(); let input_output_format = ::baml_output_format(); + let input_json = build_input_context_value(schema, &baml_value); + let vars = Value::Object(serde_json::Map::new()); let mut result = String::new(); for field_spec in schema.input_fields() { if let Some(value) = value_for_path_relaxed(&baml_value, field_spec.path()) { result.push_str(&format!("[[ ## {} ## ]]\n", field_spec.lm_name)); - result.push_str(&format_baml_value_for_prompt_typed( + result.push_str(&render_input_field( + field_spec, value, + &input_json, input_output_format, - field_spec.format, + &vars, )); result.push_str("\n\n"); } @@ -865,23 +969,117 @@ fn format_baml_value_for_prompt(value: &BamlValue) -> String { } } -fn format_baml_value_for_prompt_typed( +fn render_input_field( + field_spec: &crate::FieldSchema, value: &BamlValue, + input: &Value, output_format: &OutputFormatContent, - format: Option<&str>, + vars: &Value, ) -> String { - let format = match format { - Some(format) => format, - None => { - if let BamlValue::String(s) = value { - return s.clone(); - } - "json" + match field_spec.input_render { + InputRenderSpec::Default => match value { + BamlValue::String(s) => s.clone(), + _ => bamltype::internal_baml_jinja::format_baml_value(value, output_format, "json") + .unwrap_or_else(|_| "".to_string()), + }, + InputRenderSpec::Format(format) => { + bamltype::internal_baml_jinja::format_baml_value(value, output_format, format) + .unwrap_or_else(|_| "".to_string()) + } + InputRenderSpec::Jinja(template) => { + render_input_field_jinja(template, field_spec, value, input, output_format, vars) + } + } +} + +fn build_input_context_value(schema: &crate::SignatureSchema, root: &BamlValue) -> Value { + let mut input_json = baml_value_to_render_json(root); + let Some(root_map) = input_json.as_object_mut() else { + return input_json; + }; + + // Provide alias lookups for top-level fields so templates can use either + // Rust field names (`input.question`) or prompt aliases (`input.query`). + for field_spec in schema.input_fields() { + if field_spec.rust_name.contains('.') || field_spec.lm_name == field_spec.rust_name { + continue; + } + if field_spec.path().iter().nth(1).is_some() { + continue; + } + if let Some(value) = root_map.get(field_spec.rust_name.as_str()).cloned() { + root_map + .entry(field_spec.lm_name.to_string()) + .or_insert(value); } + } + + input_json +} + +fn baml_value_to_render_json(value: &BamlValue) -> Value { + serde_json::to_value(value).unwrap_or(Value::Null) +} + +fn render_input_field_jinja( + template: &'static str, + field_spec: &crate::FieldSchema, + value: &BamlValue, + input: &Value, + _output_format: &OutputFormatContent, + vars: &Value, +) -> String { + let env = { + let mut cache = INPUT_RENDER_TEMPLATE_CACHE + .lock() + .expect("input render template cache lock poisoned"); + cache + .entry(template) + .or_insert_with(|| { + let mut env = build_input_render_environment(); + env.add_template(INPUT_RENDER_TEMPLATE_NAME, template) + .unwrap_or_else(|err| { + panic!( + "failed to compile cached input render template for `{}` ({}): {err}", + field_spec.lm_name, field_spec.rust_name + ) + }); + CachedInputRenderTemplate { env } + }) + .env + .clone() }; - bamltype::internal_baml_jinja::format_baml_value(value, output_format, format) - .unwrap_or_else(|_| "".to_string()) + let compiled = env + .get_template(INPUT_RENDER_TEMPLATE_NAME) + .unwrap_or_else(|err| { + panic!( + "failed to fetch cached input render template for `{}` ({}): {err}", + field_spec.lm_name, field_spec.rust_name + ) + }); + + let this = baml_value_to_render_json(value); + let field = json!({ + "name": field_spec.lm_name, + "rust_name": field_spec.rust_name, + "type": field_spec.type_ir.diagnostic_repr().to_string(), + }); + let context = json!({ + "this": this, + "input": input, + "field": field, + "vars": vars, + }); + + compiled + .render(minijinja::Value::from_serialize(context)) + .unwrap_or_else(|err| { + panic!( + "failed to render input field `{}` (rust `{}`) with #[render(jinja = ...)] template `{}`: {err}", + field_spec.lm_name, field_spec.rust_name, template + ) + }) } fn collect_flags_recursive(value: &BamlValueWithFlags, flags: &mut Vec) { diff --git a/crates/dspy-rs/src/core/mod.rs b/crates/dspy-rs/src/core/mod.rs index a16896b..64bf2cb 100644 --- a/crates/dspy-rs/src/core/mod.rs +++ b/crates/dspy-rs/src/core/mod.rs @@ -37,7 +37,7 @@ pub use lm::*; pub use module::*; pub use module_ext::*; pub use predicted::{CallMetadata, ConstraintResult, FieldMeta, Predicted}; -pub use schema::{FieldMetadataSpec, FieldPath, FieldSchema, SignatureSchema}; +pub use schema::{FieldMetadataSpec, FieldPath, FieldSchema, InputRenderSpec, SignatureSchema}; pub use settings::*; pub use signature::*; pub use specials::*; diff --git a/crates/dspy-rs/src/core/schema.rs b/crates/dspy-rs/src/core/schema.rs index 2afbd7b..0e870e4 100644 --- a/crates/dspy-rs/src/core/schema.rs +++ b/crates/dspy-rs/src/core/schema.rs @@ -48,7 +48,17 @@ impl FieldPath { /// Static metadata for a single signature field, emitted by `#[derive(Signature)]`. /// /// Carries the Rust field name, optional LM-facing alias, constraint specs, and -/// format hints. Fed into [`SignatureSchema`] construction alongside Facet shape data. +/// input render hints. Fed into [`SignatureSchema`] construction alongside Facet shape data. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum InputRenderSpec { + /// Default behavior: strings are raw, non-strings are rendered as JSON. + Default, + /// Explicit format hint (`#[format("json" | "yaml" | "toon")]`). + Format(&'static str), + /// Custom Jinja template (`#[render(jinja = "...")]`). + Jinja(&'static str), +} + #[derive(Debug, Clone, Copy)] pub struct FieldMetadataSpec { /// The Rust field name as written in the signature struct. @@ -57,8 +67,8 @@ pub struct FieldMetadataSpec { pub alias: Option<&'static str>, /// Constraint specs from `#[check(...)]` and `#[assert(...)]` attributes. pub constraints: &'static [ConstraintSpec], - /// Optional format hint (e.g. `#[format = "json"]`). - pub format: Option<&'static str>, + /// Input rendering policy for this field. + pub input_render: InputRenderSpec, } /// Complete schema for a single field in a signature, combining Facet shape data with metadata. @@ -81,8 +91,8 @@ pub struct FieldSchema { pub path: FieldPath, /// Constraints declared on this field. pub constraints: &'static [ConstraintSpec], - /// Optional format hint. - pub format: Option<&'static str>, + /// Input rendering policy. + pub input_render: InputRenderSpec, } impl FieldSchema { @@ -329,7 +339,9 @@ fn emit_field( let lm_name = inherited .and_then(|meta| meta.alias) .unwrap_or_else(|| field.effective_name()); - let format = inherited.and_then(|meta| meta.format); + let input_render = inherited + .map(|meta| meta.input_render) + .unwrap_or(InputRenderSpec::Default); out.push(FieldSchema { lm_name, @@ -339,7 +351,7 @@ fn emit_field( shape: field.shape(), path, constraints, - format, + input_render, }); Ok(()) diff --git a/crates/dspy-rs/src/core/signature.rs b/crates/dspy-rs/src/core/signature.rs index 5610740..616917a 100644 --- a/crates/dspy-rs/src/core/signature.rs +++ b/crates/dspy-rs/src/core/signature.rs @@ -73,9 +73,9 @@ pub trait Signature: Send + Sync + 'static { /// The Facet shape of the output struct. fn output_shape() -> &'static Shape; - /// Per-field metadata for input fields (aliases, constraints, format hints). + /// Per-field metadata for input fields (aliases, constraints, input render hints). fn input_field_metadata() -> &'static [FieldMetadataSpec]; - /// Per-field metadata for output fields (aliases, constraints, format hints). + /// Per-field metadata for output fields (aliases, constraints, input render hints). fn output_field_metadata() -> &'static [FieldMetadataSpec]; /// The output format descriptor used by the adapter for structured output parsing. diff --git a/crates/dspy-rs/tests/test_input_format.rs b/crates/dspy-rs/tests/test_input_format.rs index c26d696..8170ae9 100644 --- a/crates/dspy-rs/tests/test_input_format.rs +++ b/crates/dspy-rs/tests/test_input_format.rs @@ -61,6 +61,74 @@ struct DefaultFormatSig { answer: String, } +#[derive(Signature, Clone, Debug)] +/// Render a context field using Jinja. +struct RenderJinjaSig { + #[input] + question: String, + + #[input] + #[alias("ctx")] + #[render( + jinja = "{{ this.text }} | {{ input.question }} | {{ input.ctx.text }} | {{ input.context.text }} | {{ field.name }} | {{ field.rust_name }}" + )] + context: Document, + + #[output] + answer: String, +} + +#[derive(Signature, Clone, Debug)] +/// Render with strict undefined vars. +struct RenderJinjaStrictSig { + #[input] + #[render(jinja = "{{ missing_var }}")] + question: String, + + #[output] + answer: String, +} + +#[derive(Signature, Clone, Debug)] +/// Render using field metadata and vars context. +struct RenderJinjaFieldMetaSig { + #[input] + #[alias("ctx")] + #[render( + jinja = "{{ field.name }}|{{ field.rust_name }}|{{ field.type }}|{{ vars is defined }}" + )] + context: Document, + + #[output] + answer: String, +} + +#[derive(Signature, Clone, Debug)] +/// Render non-string primitive fields. +struct RenderPrimitiveSig { + #[input] + #[render(jinja = "{{ this }}")] + count: i64, + + #[input] + #[render(jinja = "{{ this }}")] + is_ready: bool, + + #[output] + answer: String, +} + +#[derive(Signature, Clone, Debug)] +/// Render using contrib filters registered in the adapter Jinja environment. +struct RenderContribFilterSig { + #[input] + #[render(jinja = "{{ this.text | truncate(length=5, killwords=true, leeway=0, end='') }}")] + context: Document, + + #[output] + answer: String, +} + fn extract_field(message: &str, field_name: &str) -> String { let start_marker = format!("[[ ## {field_name} ## ]]"); let start_pos = message @@ -206,3 +274,86 @@ fn typed_input_appends_response_instruction_reminder() { assert!(message.contains("[[ ## answer ## ]]")); assert!(message.contains("[[ ## completed ## ]]")); } + +#[test] +fn typed_input_render_jinja_uses_context_values() { + let adapter = ChatAdapter; + let input = RenderJinjaSigInput { + question: "Question".to_string(), + context: Document { + text: "Hello".to_string(), + }, + }; + + let message = adapter.format_user_message_typed::(&input); + let context_value = extract_field(&message, "ctx"); + + assert_eq!( + context_value, + "Hello | Question | Hello | Hello | ctx | context" + ); +} + +#[test] +fn typed_input_render_jinja_missing_var_panics() { + let adapter = ChatAdapter; + let input = RenderJinjaStrictSigInput { + question: "Question".to_string(), + }; + + let result = std::panic::catch_unwind(|| { + adapter.format_user_message_typed::(&input) + }); + assert!(result.is_err(), "missing Jinja variables should panic"); +} + +#[test] +fn typed_input_render_jinja_exposes_field_metadata_and_vars() { + let adapter = ChatAdapter; + let input = RenderJinjaFieldMetaSigInput { + context: Document { + text: "Hello".to_string(), + }, + }; + + let message = adapter.format_user_message_typed::(&input); + let context_value = extract_field(&message, "ctx"); + let parts: Vec<&str> = context_value.split('|').collect(); + + assert_eq!(parts.len(), 4); + assert_eq!(parts[0], "ctx"); + assert_eq!(parts[1], "context"); + assert!(parts[2].contains("Document")); + assert_eq!(parts[3].to_ascii_lowercase(), "true"); +} + +#[test] +fn typed_input_render_jinja_non_string_primitives() { + let adapter = ChatAdapter; + let input = RenderPrimitiveSigInput { + count: 42, + is_ready: true, + }; + + let message = adapter.format_user_message_typed::(&input); + let count_value = extract_field(&message, "count"); + let ready_value = extract_field(&message, "is_ready"); + + assert_eq!(count_value, "42"); + assert_eq!(ready_value.to_ascii_lowercase(), "true"); +} + +#[test] +fn typed_input_render_jinja_supports_contrib_filters() { + let adapter = ChatAdapter; + let input = RenderContribFilterSigInput { + context: Document { + text: "abcdefg".to_string(), + }, + }; + + let message = adapter.format_user_message_typed::(&input); + let context_value = extract_field(&message, "context"); + + assert_eq!(context_value, "abcde"); +} diff --git a/crates/dspy-rs/tests/test_signature_macro.rs b/crates/dspy-rs/tests/test_signature_macro.rs index e54ffc9..01f527d 100644 --- a/crates/dspy-rs/tests/test_signature_macro.rs +++ b/crates/dspy-rs/tests/test_signature_macro.rs @@ -1,4 +1,4 @@ -use dspy_rs::Signature; +use dspy_rs::{InputRenderSpec, Signature}; #[derive(Signature, Clone, Debug)] struct AliasAndFormatSignature { @@ -24,19 +24,46 @@ fn signature_macro_emits_alias_and_format_metadata() { let input = &schema.input_fields()[0]; assert_eq!(input.rust_name, "request_body"); assert_eq!(input.lm_name, "payload"); - assert_eq!(input.format, Some("json")); + assert_eq!(input.input_render, InputRenderSpec::Format("json")); let output = &schema.output_fields()[0]; assert_eq!(output.rust_name, "answer"); assert_eq!(output.lm_name, "result"); - assert_eq!(output.format, None); + assert_eq!(output.input_render, InputRenderSpec::Default); let input_meta = AliasAndFormatSignature::input_field_metadata(); assert_eq!(input_meta[0].alias, Some("payload")); - assert_eq!(input_meta[0].format, Some("json")); + assert_eq!(input_meta[0].input_render, InputRenderSpec::Format("json")); let output_meta = AliasAndFormatSignature::output_field_metadata(); assert_eq!(output_meta[0].alias, Some("result")); + assert_eq!(output_meta[0].input_render, InputRenderSpec::Default); +} + +#[derive(Signature, Clone, Debug)] +struct RenderSignature { + #[input] + #[render(jinja = "{{ this }}")] + question: String, + + #[output] + answer: String, +} + +#[test] +fn signature_macro_emits_render_metadata() { + let schema = RenderSignature::schema(); + assert_eq!(schema.input_fields().len(), 1); + assert_eq!( + schema.input_fields()[0].input_render, + InputRenderSpec::Jinja("{{ this }}") + ); + + let input_meta = RenderSignature::input_field_metadata(); + assert_eq!( + input_meta[0].input_render, + InputRenderSpec::Jinja("{{ this }}") + ); } #[derive(Signature, Clone, Debug)] diff --git a/crates/dsrs-macros/Cargo.toml b/crates/dsrs-macros/Cargo.toml index 1c5de3e..4b66663 100644 --- a/crates/dsrs-macros/Cargo.toml +++ b/crates/dsrs-macros/Cargo.toml @@ -19,6 +19,7 @@ quote = "1" proc-macro2 = "1" proc-macro-crate = "3.2" serde_json = { version = "1.0.143", features = ["preserve_order"] } +minijinja = { git = "https://github.com/boundaryml/minijinja.git", branch = "main", default-features = false, features = ["serde"] } [dev-dependencies] dspy-rs = { path = "../dspy-rs" } diff --git a/crates/dsrs-macros/src/lib.rs b/crates/dsrs-macros/src/lib.rs index 2afcf05..320d521 100644 --- a/crates/dsrs-macros/src/lib.rs +++ b/crates/dsrs-macros/src/lib.rs @@ -1,6 +1,6 @@ use proc_macro::TokenStream; use quote::{format_ident, quote}; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use syn::{ Attribute, Data, DeriveInput, Expr, ExprLit, Fields, Ident, Lit, LitStr, Meta, MetaNameValue, Token, Visibility, @@ -16,7 +16,7 @@ use runtime_path::resolve_dspy_rs_path; #[proc_macro_derive( Signature, - attributes(input, output, check, assert, alias, format, flatten) + attributes(input, output, check, assert, alias, format, render, flatten) )] pub fn derive_signature(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); @@ -82,10 +82,17 @@ struct ParsedField { is_flatten: bool, description: String, alias: Option, - format: Option, + input_render: ParsedInputRender, constraints: Vec, } +#[derive(Clone)] +enum ParsedInputRender { + Default, + Format(String), + Jinja(String), +} + #[derive(Clone, Copy, PartialEq, Eq)] enum ParsedConstraintKind { Check, @@ -187,6 +194,8 @@ fn parse_signature_fields( "#[derive(Signature)] requires at least one #[output] field", )); } + validate_unique_lm_names(&input_fields, "input")?; + validate_unique_lm_names(&output_fields, "output")?; Ok(ParsedSignature { input_fields, @@ -196,6 +205,25 @@ fn parse_signature_fields( }) } +fn validate_unique_lm_names(fields: &[ParsedField], kind: &str) -> syn::Result<()> { + let mut seen = HashMap::::new(); + + for field in fields { + let rust_name = field.ident.to_string(); + let lm_name = field.alias.as_deref().unwrap_or(&rust_name).to_string(); + if let Some(previous_rust_name) = seen.insert(lm_name.clone(), rust_name.clone()) { + return Err(syn::Error::new( + proc_macro2::Span::call_site(), + format!( + "duplicate {kind} field name `{lm_name}` after aliasing; conflicts between `{previous_rust_name}` and `{rust_name}`" + ), + )); + } + } + + Ok(()) +} + fn parse_single_field(field: &syn::Field) -> syn::Result { let ident = field.ident.clone().ok_or_else(|| { syn::Error::new_spanned(field, "#[derive(Signature)] requires named fields") @@ -207,6 +235,7 @@ fn parse_single_field(field: &syn::Field) -> syn::Result { let mut saw_flatten = false; let mut alias = None; let mut format = None; + let mut render_jinja = None; let mut constraints = Vec::new(); let mut desc_override = None; @@ -231,6 +260,16 @@ fn parse_single_field(field: &syn::Field) -> syn::Result { )); } format = Some(parse_string_attr(attr, "format")?); + } else if attr.path().is_ident("render") { + if render_jinja.is_some() { + return Err(syn::Error::new_spanned( + attr, + "#[render] can only be specified once per field", + )); + } + let template = parse_render_jinja_attr(attr)?; + validate_jinja_template(&template, attr.span())?; + render_jinja = Some(template); } else if attr.path().is_ident("flatten") { if saw_flatten { return Err(syn::Error::new_spanned( @@ -247,16 +286,31 @@ fn parse_single_field(field: &syn::Field) -> syn::Result { } } + if format.is_some() && render_jinja.is_some() { + return Err(syn::Error::new_spanned( + field, + "#[format] and #[render] cannot be combined on the same field", + )); + } + if format.is_some() && !is_input { return Err(syn::Error::new_spanned( field, "#[format] is only supported on #[input] fields", )); } + if render_jinja.is_some() && !is_input { + return Err(syn::Error::new_spanned( + field, + "#[render] is only supported on #[input] fields", + )); + } - if let Some(format_value) = format.as_deref() { + let input_render = if let Some(template) = render_jinja { + ParsedInputRender::Jinja(template) + } else if let Some(format_value) = format { match format_value.to_ascii_lowercase().as_str() { - "json" | "yaml" | "toon" => {} + "json" | "yaml" | "toon" => ParsedInputRender::Format(format_value), _ => { return Err(syn::Error::new_spanned( field, @@ -264,12 +318,18 @@ fn parse_single_field(field: &syn::Field) -> syn::Result { )); } } - } + } else { + ParsedInputRender::Default + }; - if is_flatten && (alias.is_some() || format.is_some() || !constraints.is_empty()) { + if is_flatten + && (alias.is_some() + || !matches!(input_render, ParsedInputRender::Default) + || !constraints.is_empty()) + { return Err(syn::Error::new_spanned( field, - "#[flatten] cannot be combined with #[alias], #[format], #[check], or #[assert]", + "#[flatten] cannot be combined with #[alias], #[format], #[render], #[check], or #[assert]", )); } @@ -286,7 +346,7 @@ fn parse_single_field(field: &syn::Field) -> syn::Result { is_flatten, description, alias, - format, + input_render, constraints, }) } @@ -338,6 +398,44 @@ fn parse_string_attr(attr: &Attribute, attr_name: &str) -> syn::Result { } } +fn parse_render_jinja_attr(attr: &Attribute) -> syn::Result { + match &attr.meta { + Meta::List(list) => { + let metas = list.parse_args_with( + syn::punctuated::Punctuated::::parse_terminated, + )?; + + if metas.len() != 1 { + return Err(syn::Error::new_spanned( + attr, + "expected #[render(jinja = \"...\")]", + )); + } + + match metas.first() { + Some(Meta::NameValue(meta)) if meta.path.is_ident("jinja") => { + parse_string_expr(&meta.value, meta.span()) + } + _ => Err(syn::Error::new_spanned( + attr, + "expected #[render(jinja = \"...\")]", + )), + } + } + _ => Err(syn::Error::new_spanned( + attr, + "expected #[render(jinja = \"...\")]", + )), + } +} + +fn validate_jinja_template(template: &str, span: proc_macro2::Span) -> syn::Result<()> { + let mut env = minijinja::Environment::new(); + env.add_template("__input_field__", template) + .map_err(|_| syn::Error::new(span, "invalid Jinja syntax in #[render(jinja = \"...\")]"))?; + Ok(()) +} + fn parse_constraint_attr( attr: &Attribute, kind: ParsedConstraintKind, @@ -877,7 +975,7 @@ fn field_tokens(field: &ParsedField) -> proc_macro2::TokenStream { attrs.push(quote! { #[facet(flatten)] }); } - // Note: aliases, formats, and constraints are emitted in + // Note: aliases, input render hints, and constraints are emitted in // generate_field_metadata(), not as struct attributes. quote! { @@ -919,12 +1017,16 @@ fn generate_field_metadata( } None => quote! { None }, }; - let format = match &field.format { - Some(value) => { + let input_render = match &field.input_render { + ParsedInputRender::Default => quote! { #runtime::InputRenderSpec::Default }, + ParsedInputRender::Format(value) => { let lit = LitStr::new(value, proc_macro2::Span::call_site()); - quote! { Some(#lit) } + quote! { #runtime::InputRenderSpec::Format(#lit) } + } + ParsedInputRender::Jinja(value) => { + let lit = LitStr::new(value, proc_macro2::Span::call_site()); + quote! { #runtime::InputRenderSpec::Jinja(#lit) } } - None => quote! { None }, }; let constraints_name = format_ident!( @@ -973,7 +1075,7 @@ fn generate_field_metadata( rust_name: #rust_name, alias: #alias, constraints: #constraints_name, - format: #format, + input_render: #input_render, } }); } diff --git a/crates/dsrs-macros/tests/signature_derive.rs b/crates/dsrs-macros/tests/signature_derive.rs index 9368db1..abe825f 100644 --- a/crates/dsrs-macros/tests/signature_derive.rs +++ b/crates/dsrs-macros/tests/signature_derive.rs @@ -1,4 +1,4 @@ -use dspy_rs::{BamlType, Facet, Signature as SignatureTrait, SignatureSchema}; +use dspy_rs::{BamlType, Facet, InputRenderSpec, Signature as SignatureTrait, SignatureSchema}; /// Test instruction #[derive(dsrs_macros::Signature, Clone, Debug)] @@ -36,6 +36,20 @@ struct LiteralConstraintSig { answer: String, } +#[derive(dsrs_macros::Signature, Clone, Debug)] +struct RenderSpecSig { + #[input] + #[render(jinja = "{{ this }}")] + template_input: String, + + #[input] + #[format("yaml")] + yaml_input: String, + + #[output] + answer: String, +} + #[derive(Clone, Debug)] #[BamlType] struct GenericCtx { @@ -121,3 +135,25 @@ fn derives_generic_helpers_and_flatten_paths() { let output_names: Vec<&str> = schema.output_fields().iter().map(|f| f.lm_name).collect(); assert_eq!(output_names, vec!["answer"]); } + +#[test] +fn emits_input_render_metadata() { + let input_meta = ::input_field_metadata(); + assert_eq!(input_meta.len(), 2); + + assert_eq!( + input_meta + .iter() + .find(|field| field.rust_name == "template_input") + .map(|field| field.input_render), + Some(InputRenderSpec::Jinja("{{ this }}")) + ); + + assert_eq!( + input_meta + .iter() + .find(|field| field.rust_name == "yaml_input") + .map(|field| field.input_render), + Some(InputRenderSpec::Format("yaml")) + ); +} diff --git a/crates/dsrs-macros/tests/ui/alias_conflict.rs b/crates/dsrs-macros/tests/ui/alias_conflict.rs new file mode 100644 index 0000000..f0edb34 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/alias_conflict.rs @@ -0,0 +1,16 @@ +use dsrs_macros::Signature; + +#[derive(Signature)] +struct AliasConflict { + #[input] + first: String, + + #[input] + #[alias("first")] + second: String, + + #[output] + answer: String, +} + +fn main() {} diff --git a/crates/dsrs-macros/tests/ui/alias_conflict.stderr b/crates/dsrs-macros/tests/ui/alias_conflict.stderr new file mode 100644 index 0000000..e7ecf82 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/alias_conflict.stderr @@ -0,0 +1,7 @@ +error: duplicate input field name `first` after aliasing; conflicts between `first` and `second` + --> tests/ui/alias_conflict.rs:3:10 + | +3 | #[derive(Signature)] + | ^^^^^^^^^ + | + = note: this error originates in the derive macro `Signature` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/crates/dsrs-macros/tests/ui/format_render_conflict.rs b/crates/dsrs-macros/tests/ui/format_render_conflict.rs new file mode 100644 index 0000000..118e550 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/format_render_conflict.rs @@ -0,0 +1,14 @@ +use dsrs_macros::Signature; + +#[derive(Signature)] +struct FormatRenderConflict { + #[input] + #[format("json")] + #[render(jinja = "{{ this }}")] + context: String, + + #[output] + answer: String, +} + +fn main() {} diff --git a/crates/dsrs-macros/tests/ui/format_render_conflict.stderr b/crates/dsrs-macros/tests/ui/format_render_conflict.stderr new file mode 100644 index 0000000..4930485 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/format_render_conflict.stderr @@ -0,0 +1,8 @@ +error: #[format] and #[render] cannot be combined on the same field + --> tests/ui/format_render_conflict.rs:5:5 + | +5 | / #[input] +6 | | #[format("json")] +7 | | #[render(jinja = "{{ this }}")] +8 | | context: String, + | |___________________^ diff --git a/crates/dsrs-macros/tests/ui/render_duplicate.rs b/crates/dsrs-macros/tests/ui/render_duplicate.rs new file mode 100644 index 0000000..372f7a5 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_duplicate.rs @@ -0,0 +1,14 @@ +use dsrs_macros::Signature; + +#[derive(Signature)] +struct RenderDuplicate { + #[input] + #[render(jinja = "{{ this }}")] + #[render(jinja = "{{ this }}")] + context: String, + + #[output] + answer: String, +} + +fn main() {} diff --git a/crates/dsrs-macros/tests/ui/render_duplicate.stderr b/crates/dsrs-macros/tests/ui/render_duplicate.stderr new file mode 100644 index 0000000..0d31e5f --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_duplicate.stderr @@ -0,0 +1,5 @@ +error: #[render] can only be specified once per field + --> tests/ui/render_duplicate.rs:7:5 + | +7 | #[render(jinja = "{{ this }}")] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/crates/dsrs-macros/tests/ui/render_invalid_jinja.rs b/crates/dsrs-macros/tests/ui/render_invalid_jinja.rs new file mode 100644 index 0000000..3bf3548 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_invalid_jinja.rs @@ -0,0 +1,13 @@ +use dsrs_macros::Signature; + +#[derive(Signature)] +struct RenderInvalidJinja { + #[input] + #[render(jinja = "{{ this ")] + context: String, + + #[output] + answer: String, +} + +fn main() {} diff --git a/crates/dsrs-macros/tests/ui/render_invalid_jinja.stderr b/crates/dsrs-macros/tests/ui/render_invalid_jinja.stderr new file mode 100644 index 0000000..239c181 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_invalid_jinja.stderr @@ -0,0 +1,5 @@ +error: invalid Jinja syntax in #[render(jinja = "...")] + --> tests/ui/render_invalid_jinja.rs:6:5 + | +6 | #[render(jinja = "{{ this ")] + | ^ diff --git a/crates/dsrs-macros/tests/ui/render_invalid_key.rs b/crates/dsrs-macros/tests/ui/render_invalid_key.rs new file mode 100644 index 0000000..2c3faf1 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_invalid_key.rs @@ -0,0 +1,13 @@ +use dsrs_macros::Signature; + +#[derive(Signature)] +struct RenderInvalidKey { + #[input] + #[render(template = "{{ this }}")] + context: String, + + #[output] + answer: String, +} + +fn main() {} diff --git a/crates/dsrs-macros/tests/ui/render_invalid_key.stderr b/crates/dsrs-macros/tests/ui/render_invalid_key.stderr new file mode 100644 index 0000000..c599c0a --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_invalid_key.stderr @@ -0,0 +1,5 @@ +error: expected #[render(jinja = "...")] + --> tests/ui/render_invalid_key.rs:6:5 + | +6 | #[render(template = "{{ this }}")] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/crates/dsrs-macros/tests/ui/render_non_literal.rs b/crates/dsrs-macros/tests/ui/render_non_literal.rs new file mode 100644 index 0000000..159bf71 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_non_literal.rs @@ -0,0 +1,15 @@ +use dsrs_macros::Signature; + +const TEMPLATE: &str = "{{ this }}"; + +#[derive(Signature)] +struct RenderNonLiteral { + #[input] + #[render(jinja = TEMPLATE)] + context: String, + + #[output] + answer: String, +} + +fn main() {} diff --git a/crates/dsrs-macros/tests/ui/render_non_literal.stderr b/crates/dsrs-macros/tests/ui/render_non_literal.stderr new file mode 100644 index 0000000..be37869 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_non_literal.stderr @@ -0,0 +1,5 @@ +error: expected string literal; hint: wrap the value in quotes + --> tests/ui/render_non_literal.rs:8:14 + | +8 | #[render(jinja = TEMPLATE)] + | ^^^^^ diff --git a/crates/dsrs-macros/tests/ui/render_on_output.rs b/crates/dsrs-macros/tests/ui/render_on_output.rs new file mode 100644 index 0000000..117cb24 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_on_output.rs @@ -0,0 +1,13 @@ +use dsrs_macros::Signature; + +#[derive(Signature)] +struct RenderOnOutput { + #[input] + question: String, + + #[output] + #[render(jinja = "{{ this }}")] + answer: String, +} + +fn main() {} diff --git a/crates/dsrs-macros/tests/ui/render_on_output.stderr b/crates/dsrs-macros/tests/ui/render_on_output.stderr new file mode 100644 index 0000000..02d30b4 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_on_output.stderr @@ -0,0 +1,7 @@ +error: #[render] is only supported on #[input] fields + --> tests/ui/render_on_output.rs:8:5 + | + 8 | / #[output] + 9 | | #[render(jinja = "{{ this }}")] +10 | | answer: String, + | |__________________^ diff --git a/docs/docs/building-blocks/adapter.mdx b/docs/docs/building-blocks/adapter.mdx index 7fb62e0..3e2f515 100644 --- a/docs/docs/building-blocks/adapter.mdx +++ b/docs/docs/building-blocks/adapter.mdx @@ -148,7 +148,7 @@ This is useful for understanding what the LM sees. ## Input formatting options -The `#[format]` attribute on input fields controls serialization: +Input fields support two rendering paths: ```rust #[derive(Signature, Clone, Debug)] @@ -160,12 +160,33 @@ struct Search { #[format("yaml")] filters: Vec, // serialized as YAML + #[input] + #[render(jinja = "{{ this.text }}\nQuestion: {{ input.query }}")] + context: Context, // custom Jinja rendering + #[output] results: Vec, } ``` -Options: `"json"` (default for complex types), `"yaml"`, `"toon"` +- `#[format("json" | "yaml" | "toon")]`: serialize a field using that format. +- `#[render(jinja = "...")]`: render a field with a MiniJinja template. +- `#[format]` and `#[render]` are mutually exclusive on the same field. + +`#[render]` context in `ChatAdapter`: +- `this`: current field value +- `input`: full input object (plus top-level alias overlays) +- `field`: `{ name, rust_name, type }` +- `vars`: adapter/surface vars (currently `{}`) + +`ChatAdapter` configures MiniJinja with strict undefined behavior. +Available filters include: +- MiniJinja built-ins +- BAML parity helpers (`regex_match`, `sum`) +- `truncate` for length-limited string rendering + +If template rendering fails at runtime (for example, missing variables), `ChatAdapter` panics. +Compiled templates are cached process-wide by template string. ## Real example: Insurance claim extraction diff --git a/docs/docs/building-blocks/signature.mdx b/docs/docs/building-blocks/signature.mdx index f06d61e..21a8a60 100644 --- a/docs/docs/building-blocks/signature.mdx +++ b/docs/docs/building-blocks/signature.mdx @@ -172,6 +172,28 @@ context: Vec, Options: `"json"`, `"yaml"`, `"toon"` +### `#[render]` - Custom input rendering (Jinja) + +```rust +#[input] +#[render(jinja = "{{ this.title }}\n{{ this.body | truncate(300) }}")] +ticket: Ticket, +``` + +Use this when you need custom text rendering for an input field. + +- Only valid on `#[input]` fields +- Template must be a string literal +- Jinja syntax is validated at compile time +- Cannot be combined with `#[format]` on the same field +- In `ChatAdapter`, built-ins + BAML parity helpers (`regex_match`, `sum`) + `truncate` are available + +Template context: +- `this` - Current field value as JSON-like data +- `input` - Full input object (with top-level alias overlays) +- `field` - Metadata object with `name`, `rust_name`, and `type` +- `vars` - Adapter/surface variables (currently empty in `ChatAdapter`) + ### `#[check]` - Soft constraint ```rust diff --git a/docs/docs/getting-started/quickstart.mdx b/docs/docs/getting-started/quickstart.mdx index dbb2317..85b18dc 100644 --- a/docs/docs/getting-started/quickstart.mdx +++ b/docs/docs/getting-started/quickstart.mdx @@ -164,6 +164,11 @@ Compose multi-step pipelines ## Adding complexity +### Input formatting and rendering + +Use `#[format("json" | "yaml" | "toon")]` for serialization, or `#[render(jinja = "...")]` for custom field text. +See the full attribute reference in [Signatures](/docs/building-blocks/signature) and runtime behavior in [Adapter](/docs/building-blocks/adapter). + ### Custom types When you need more than primitives, add [`#[BamlType]`](/docs/building-blocks/types): diff --git a/promote_attr_like b/promote_attr_like deleted file mode 100755 index 3426b4f..0000000 Binary files a/promote_attr_like and /dev/null differ diff --git a/promote_generic b/promote_generic deleted file mode 100755 index 1f079c6..0000000 Binary files a/promote_generic and /dev/null differ diff --git a/promote_generic_fnptr b/promote_generic_fnptr deleted file mode 100755 index cd3d727..0000000 Binary files a/promote_generic_fnptr and /dev/null differ