diff --git a/Cargo.lock b/Cargo.lock index 5b4d727..5d3b49a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1190,6 +1190,7 @@ dependencies = [ "hf-hub", "indexmap", "kdam", + "minijinja", "parquet", "rand 0.8.5", "rayon", @@ -1212,6 +1213,7 @@ name = "dsrs_macros" version = "0.7.2" dependencies = [ "dspy-rs", + "minijinja", "proc-macro-crate", "proc-macro2", "quote", diff --git a/crates/bamltype/tests/ui/non_string_literal_attr.stderr b/crates/bamltype/tests/ui/non_string_literal_attr.stderr index 40feadf..c764186 100644 --- a/crates/bamltype/tests/ui/non_string_literal_attr.stderr +++ b/crates/bamltype/tests/ui/non_string_literal_attr.stderr @@ -2,4 +2,4 @@ error: expected string literal; hint: wrap the value in quotes --> tests/ui/non_string_literal_attr.rs:4:8 | 4 | #[baml(name = 123)] - | ^^^^ + | ^^^^^^^^^^ diff --git a/crates/dspy-rs/Cargo.toml b/crates/dspy-rs/Cargo.toml index 16f99ec..2d7f585 100644 --- a/crates/dspy-rs/Cargo.toml +++ b/crates/dspy-rs/Cargo.toml @@ -43,6 +43,7 @@ rig-core = { git = "https://github.com/0xPlaygrounds/rig", rev="e7849df" } enum_dispatch = "0.3.13" tracing = "0.1.44" tracing-subscriber = { version = "0.3.22", features = ["env-filter", "fmt"] } +minijinja = { git = "https://github.com/boundaryml/minijinja.git", branch = "main", default-features = false, features = ["builtins", "serde"] } [package.metadata.cargo-machete] ignored = ["rig-core"] diff --git a/crates/dspy-rs/src/adapter/chat.rs b/crates/dspy-rs/src/adapter/chat.rs index a05a928..16d6716 100644 --- a/crates/dspy-rs/src/adapter/chat.rs +++ b/crates/dspy-rs/src/adapter/chat.rs @@ -1,5 +1,6 @@ use anyhow::Result; use indexmap::IndexMap; +use minijinja::UndefinedBehavior; use regex::Regex; use rig::tool::ToolDyn; use serde_json::{Value, json}; @@ -17,8 +18,8 @@ use crate::serde_utils::get_iter_from_value; use crate::utils::cache::CacheEntry; use crate::{ BamlValue, Cache, Chat, ConstraintLevel, ConstraintResult, Example, FieldMeta, Flag, - JsonishError, LM, Message, MetaSignature, OutputFormatContent, ParseError, Prediction, - RenderOptions, Signature, TypeIR, + InputRenderSpec, JsonishError, LM, Message, MetaSignature, OutputFormatContent, ParseError, + Prediction, RenderOptions, Signature, TypeIR, }; #[derive(Default, Clone)] @@ -526,15 +527,19 @@ impl ChatAdapter { return String::new(); }; let input_output_format = ::baml_output_format(); + let input_json = build_input_context_value(fields, S::input_fields(), input_output_format); + let vars = Value::Object(serde_json::Map::new()); let mut result = String::new(); for field_spec in S::input_fields() { if let Some(value) = fields.get(field_spec.rust_name) { result.push_str(&format!("[[ ## {} ## ]]\n", field_spec.name)); - result.push_str(&format_baml_value_for_prompt_typed( + result.push_str(&render_input_field( + field_spec, value, + &input_json, input_output_format, - field_spec.format, + &vars, )); result.push_str("\n\n"); } @@ -880,23 +885,92 @@ fn format_baml_value_for_prompt(value: &BamlValue) -> String { } } -fn format_baml_value_for_prompt_typed( +fn render_input_field( + field_spec: &crate::FieldSpec, value: &BamlValue, + input: &Value, output_format: &OutputFormatContent, - format: Option<&str>, + vars: &Value, ) -> String { - let format = match format { - Some(format) => format, - None => { - if let BamlValue::String(s) = value { - return s.clone(); - } - "json" + match field_spec.input_render { + InputRenderSpec::Default => match value { + BamlValue::String(s) => s.clone(), + _ => crate::bamltype::internal_baml_jinja::format_baml_value( + value, + output_format, + "json", + ) + .unwrap_or_else(|_| "".to_string()), + }, + InputRenderSpec::Format(format) => { + crate::bamltype::internal_baml_jinja::format_baml_value(value, output_format, format) + .unwrap_or_else(|_| "".to_string()) + } + InputRenderSpec::Jinja(template) => { + render_input_field_jinja(template, field_spec, value, input, output_format, vars) + .unwrap_or_else(|_| "".to_string()) } + } +} + +fn render_input_field_jinja( + template: &str, + field_spec: &crate::FieldSpec, + value: &BamlValue, + input: &Value, + output_format: &OutputFormatContent, + vars: &Value, +) -> Result { + let mut env = minijinja::Environment::new(); + env.set_undefined_behavior(UndefinedBehavior::Strict); + env.add_template("__input_field__", template)?; + let template = env.get_template("__input_field__")?; + + let this = baml_value_to_render_json(value, output_format); + let field = json!({ + "name": field_spec.name, + "rust_name": field_spec.rust_name, + "type": (field_spec.type_ir)().diagnostic_repr().to_string(), + }); + let context = json!({ + "this": this, + "input": input, + "field": field, + "vars": vars, + }); + + template.render(minijinja::Value::from_serialize(context)) +} + +fn build_input_context_value( + fields: &crate::bamltype::baml_types::BamlMap, + field_specs: &[crate::FieldSpec], + output_format: &OutputFormatContent, +) -> Value { + let mut map = serde_json::Map::new(); + + for field_spec in field_specs { + let Some(value) = fields.get(field_spec.rust_name) else { + continue; + }; + let value_json = baml_value_to_render_json(value, output_format); + map.insert(field_spec.rust_name.to_string(), value_json.clone()); + if field_spec.name != field_spec.rust_name { + map.entry(field_spec.name.to_string()).or_insert(value_json); + } + } + + Value::Object(map) +} + +fn baml_value_to_render_json(value: &BamlValue, output_format: &OutputFormatContent) -> Value { + let Ok(rendered_json) = + crate::bamltype::internal_baml_jinja::format_baml_value(value, output_format, "json") + else { + return serde_json::to_value(value).unwrap_or(Value::Null); }; - crate::bamltype::internal_baml_jinja::format_baml_value(value, output_format, format) - .unwrap_or_else(|_| "".to_string()) + serde_json::from_str(&rendered_json).unwrap_or(Value::Null) } fn collect_flags_recursive(value: &BamlValueWithFlags, flags: &mut Vec) { diff --git a/crates/dspy-rs/src/core/signature.rs b/crates/dspy-rs/src/core/signature.rs index b91d4a1..b4981cd 100644 --- a/crates/dspy-rs/src/core/signature.rs +++ b/crates/dspy-rs/src/core/signature.rs @@ -2,6 +2,13 @@ use crate::{Example, OutputFormatContent, TypeIR}; use anyhow::Result; use serde_json::Value; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum InputRenderSpec { + Default, + Format(&'static str), + Jinja(&'static str), +} + #[derive(Debug, Clone, Copy)] pub struct FieldSpec { pub name: &'static str, @@ -9,7 +16,7 @@ pub struct FieldSpec { pub description: &'static str, pub type_ir: fn() -> TypeIR, pub constraints: &'static [ConstraintSpec], - pub format: Option<&'static str>, + pub input_render: InputRenderSpec, } #[derive(Debug, Clone, Copy)] diff --git a/crates/dspy-rs/src/predictors/predict.rs b/crates/dspy-rs/src/predictors/predict.rs index cb280ca..169f97b 100644 --- a/crates/dspy-rs/src/predictors/predict.rs +++ b/crates/dspy-rs/src/predictors/predict.rs @@ -10,7 +10,7 @@ use tracing::{debug, trace}; use crate::adapter::Adapter; use crate::bamltype::baml_types::BamlMap; use crate::bamltype::compat::{BamlValueConvert, ToBamlValue}; -use crate::core::{FieldSpec, MetaSignature, Module, Optimizable, Signature}; +use crate::core::{FieldSpec, InputRenderSpec, MetaSignature, Module, Optimizable, Signature}; use crate::{ BamlValue, CallResult, Chat, ChatAdapter, Example, GLOBAL_SETTINGS, LM, LmError, LmUsage, PredictError, Prediction, @@ -258,8 +258,14 @@ fn field_specs_to_value(fields: &[FieldSpec], field_type: &'static str) -> Value meta.insert("desc".to_string(), json!(field.description)); meta.insert("schema".to_string(), json!("")); meta.insert("__dsrs_field_type".to_string(), json!(field_type)); - if let Some(format) = field.format { - meta.insert("format".to_string(), json!(format)); + match field.input_render { + InputRenderSpec::Default => {} + InputRenderSpec::Format(format) => { + meta.insert("format".to_string(), json!(format)); + } + InputRenderSpec::Jinja(template) => { + meta.insert("render".to_string(), json!({ "jinja": template })); + } } result.insert(field.rust_name.to_string(), Value::Object(meta)); } diff --git a/crates/dspy-rs/tests/test_input_format.rs b/crates/dspy-rs/tests/test_input_format.rs index 129ee47..e306bfa 100644 --- a/crates/dspy-rs/tests/test_input_format.rs +++ b/crates/dspy-rs/tests/test_input_format.rs @@ -62,6 +62,48 @@ struct DefaultFormatSig { answer: String, } +#[derive(Signature, Clone, Debug)] +/// Render a context field using Jinja. +struct RenderJinjaSig { + #[input] + question: String, + + #[input] + #[alias("ctx")] + #[render( + jinja = "{{ this.text }} | {{ input.question }} | {{ input.ctx.text }} | {{ input.context.text }} | {{ field.name }} | {{ field.rust_name }}" + )] + context: Document, + + #[output] + answer: String, +} + +#[derive(Signature, Clone, Debug)] +/// Render with strict undefined vars. +struct RenderJinjaStrictSig { + #[input] + #[render(jinja = "{{ missing_var }}")] + question: String, + + #[output] + answer: String, +} + +#[derive(Signature, Clone, Debug)] +/// Render using field metadata and vars context. +struct RenderJinjaFieldMetaSig { + #[input] + #[alias("ctx")] + #[render( + jinja = "{{ field.name }}|{{ field.rust_name }}|{{ field.type }}|{{ vars is defined }}" + )] + context: Document, + + #[output] + answer: String, +} + fn extract_field(message: &str, field_name: &str) -> String { let start_marker = format!("[[ ## {field_name} ## ]]"); let start_pos = message @@ -184,3 +226,55 @@ fn typed_input_default_non_string_is_json() { .expect("expected array with object"); assert_eq!(first.get("text").and_then(|v| v.as_str()), Some("Hello")); } + +#[test] +fn typed_input_render_jinja_uses_context_values() { + let adapter = ChatAdapter; + let input = RenderJinjaSigInput { + question: "Question".to_string(), + context: Document { + text: "Hello".to_string(), + }, + }; + + let message = adapter.format_user_message_typed::(&input); + let context_value = extract_field(&message, "ctx"); + + assert_eq!( + context_value, + "Hello | Question | Hello | Hello | ctx | context" + ); +} + +#[test] +fn typed_input_render_jinja_strict_undefined_returns_error_sentinel() { + let adapter = ChatAdapter; + let input = RenderJinjaStrictSigInput { + question: "Question".to_string(), + }; + + let message = adapter.format_user_message_typed::(&input); + let question_value = extract_field(&message, "question"); + + assert_eq!(question_value, ""); +} + +#[test] +fn typed_input_render_jinja_exposes_field_metadata_and_vars() { + let adapter = ChatAdapter; + let input = RenderJinjaFieldMetaSigInput { + context: Document { + text: "Hello".to_string(), + }, + }; + + let message = adapter.format_user_message_typed::(&input); + let context_value = extract_field(&message, "ctx"); + let parts: Vec<&str> = context_value.split('|').collect(); + + assert_eq!(parts.len(), 4); + assert_eq!(parts[0], "ctx"); + assert_eq!(parts[1], "context"); + assert!(parts[2].contains("Document")); + assert_eq!(parts[3].to_ascii_lowercase(), "true"); +} diff --git a/crates/dsrs-macros/Cargo.toml b/crates/dsrs-macros/Cargo.toml index 1c5de3e..4b66663 100644 --- a/crates/dsrs-macros/Cargo.toml +++ b/crates/dsrs-macros/Cargo.toml @@ -19,6 +19,7 @@ quote = "1" proc-macro2 = "1" proc-macro-crate = "3.2" serde_json = { version = "1.0.143", features = ["preserve_order"] } +minijinja = { git = "https://github.com/boundaryml/minijinja.git", branch = "main", default-features = false, features = ["serde"] } [dev-dependencies] dspy-rs = { path = "../dspy-rs" } diff --git a/crates/dsrs-macros/src/lib.rs b/crates/dsrs-macros/src/lib.rs index ae9ff9d..650d7ae 100644 --- a/crates/dsrs-macros/src/lib.rs +++ b/crates/dsrs-macros/src/lib.rs @@ -1,6 +1,7 @@ use proc_macro::TokenStream; use quote::{format_ident, quote}; use serde_json::{Value, json}; +use std::collections::HashMap; use syn::{ Attribute, Data, DeriveInput, Expr, ExprLit, Fields, Ident, Lit, LitStr, Meta, MetaNameValue, Token, Visibility, @@ -19,7 +20,10 @@ pub fn derive_optimizable(input: TokenStream) -> TokenStream { optim::optimizable_impl(input) } -#[proc_macro_derive(Signature, attributes(input, output, check, assert, alias, format))] +#[proc_macro_derive( + Signature, + attributes(input, output, check, assert, alias, format, render) +)] pub fn derive_signature(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); let runtime = match resolve_dspy_rs_path() { @@ -69,10 +73,17 @@ struct ParsedField { is_output: bool, description: String, alias: Option, - format: Option, + input_render: ParsedInputRender, constraints: Vec, } +#[derive(Clone)] +enum ParsedInputRender { + Default, + Format(String), + Jinja(String), +} + #[derive(Clone, Copy, PartialEq, Eq)] enum ParsedConstraintKind { Check, @@ -174,6 +185,8 @@ fn parse_signature_fields( "#[derive(Signature)] requires at least one #[output] field", )); } + validate_unique_llm_names(&input_fields, "input")?; + validate_unique_llm_names(&output_fields, "output")?; Ok(ParsedSignature { input_fields, @@ -183,6 +196,25 @@ fn parse_signature_fields( }) } +fn validate_unique_llm_names(fields: &[ParsedField], kind: &str) -> syn::Result<()> { + let mut seen = HashMap::::new(); + + for field in fields { + let rust_name = field.ident.to_string(); + let llm_name = field.alias.as_deref().unwrap_or(&rust_name).to_string(); + if let Some(previous_rust_name) = seen.insert(llm_name.clone(), rust_name.clone()) { + return Err(syn::Error::new( + proc_macro2::Span::call_site(), + format!( + "duplicate {kind} field name `{llm_name}` after aliasing; conflicts between `{previous_rust_name}` and `{rust_name}`" + ), + )); + } + } + + Ok(()) +} + fn parse_single_field(field: &syn::Field) -> syn::Result { let ident = field.ident.clone().ok_or_else(|| { syn::Error::new_spanned(field, "#[derive(Signature)] requires named fields") @@ -192,6 +224,7 @@ fn parse_single_field(field: &syn::Field) -> syn::Result { let mut is_output = false; let mut alias = None; let mut format = None; + let mut render_jinja = None; let mut constraints = Vec::new(); let mut desc_override = None; @@ -216,6 +249,16 @@ fn parse_single_field(field: &syn::Field) -> syn::Result { )); } format = Some(parse_string_attr(attr, "format")?); + } else if attr.path().is_ident("render") { + if render_jinja.is_some() { + return Err(syn::Error::new_spanned( + attr, + "#[render] can only be specified once per field", + )); + } + let template = parse_render_jinja_attr(attr)?; + validate_jinja_template(&template, attr.span())?; + render_jinja = Some(template); } else if attr.path().is_ident("check") { constraints.push(parse_constraint_attr(attr, ParsedConstraintKind::Check)?); } else if attr.path().is_ident("assert") { @@ -223,16 +266,31 @@ fn parse_single_field(field: &syn::Field) -> syn::Result { } } + if format.is_some() && render_jinja.is_some() { + return Err(syn::Error::new_spanned( + field, + "#[format] and #[render] cannot be combined on the same field", + )); + } + if format.is_some() && !is_input { return Err(syn::Error::new_spanned( field, "#[format] is only supported on #[input] fields", )); } + if render_jinja.is_some() && !is_input { + return Err(syn::Error::new_spanned( + field, + "#[render] is only supported on #[input] fields", + )); + } - if let Some(format_value) = format.as_deref() { + let input_render = if let Some(template) = render_jinja { + ParsedInputRender::Jinja(template) + } else if let Some(format_value) = format { match format_value.to_ascii_lowercase().as_str() { - "json" | "yaml" | "toon" => {} + "json" | "yaml" | "toon" => ParsedInputRender::Format(format_value), _ => { return Err(syn::Error::new_spanned( field, @@ -240,7 +298,9 @@ fn parse_single_field(field: &syn::Field) -> syn::Result { )); } } - } + } else { + ParsedInputRender::Default + }; let doc_comment = collect_doc_comment(&field.attrs); let description = desc_override.unwrap_or(doc_comment); @@ -252,7 +312,7 @@ fn parse_single_field(field: &syn::Field) -> syn::Result { is_output, description, alias, - format, + input_render, constraints, }) } @@ -304,6 +364,44 @@ fn parse_string_attr(attr: &Attribute, attr_name: &str) -> syn::Result { } } +fn parse_render_jinja_attr(attr: &Attribute) -> syn::Result { + match &attr.meta { + Meta::List(list) => { + let metas = list.parse_args_with( + syn::punctuated::Punctuated::::parse_terminated, + )?; + + if metas.len() != 1 { + return Err(syn::Error::new_spanned( + attr, + "expected #[render(jinja = \"...\")]", + )); + } + + match metas.first() { + Some(Meta::NameValue(meta)) if meta.path.is_ident("jinja") => { + parse_string_expr(&meta.value, meta.span()) + } + _ => Err(syn::Error::new_spanned( + attr, + "expected #[render(jinja = \"...\")]", + )), + } + } + _ => Err(syn::Error::new_spanned( + attr, + "expected #[render(jinja = \"...\")]", + )), + } +} + +fn validate_jinja_template(template: &str, span: proc_macro2::Span) -> syn::Result<()> { + let mut env = minijinja::Environment::new(); + env.add_template("__input_field__", template) + .map_err(|_| syn::Error::new(span, "invalid Jinja syntax in #[render(jinja = \"...\")]"))?; + Ok(()) +} + fn parse_constraint_attr( attr: &Attribute, kind: ParsedConstraintKind, @@ -463,12 +561,16 @@ fn generate_field_specs( let llm_name = LitStr::new(llm_name, proc_macro2::Span::call_site()); let rust_name = LitStr::new(&field_name, proc_macro2::Span::call_site()); let description = LitStr::new(&field.description, proc_macro2::Span::call_site()); - let format = match &field.format { - Some(value) => { + let input_render = match &field.input_render { + ParsedInputRender::Default => quote! { #runtime::InputRenderSpec::Default }, + ParsedInputRender::Format(value) => { + let lit = LitStr::new(value, proc_macro2::Span::call_site()); + quote! { #runtime::InputRenderSpec::Format(#lit) } + } + ParsedInputRender::Jinja(value) => { let lit = LitStr::new(value, proc_macro2::Span::call_site()); - quote! { Some(#lit) } + quote! { #runtime::InputRenderSpec::Jinja(#lit) } } - None => quote! { None }, }; let type_ir_fn_name = format_ident!("__{}_{}_type_ir", prefix, field_name_ident); @@ -554,7 +656,7 @@ fn generate_field_specs( description: #description, type_ir: #type_ir_fn_name, constraints: #constraints_name, - format: #format, + input_render: #input_render, } }); } diff --git a/crates/dsrs-macros/tests/signature_derive.rs b/crates/dsrs-macros/tests/signature_derive.rs index 27b8155..8f9dc37 100644 --- a/crates/dsrs-macros/tests/signature_derive.rs +++ b/crates/dsrs-macros/tests/signature_derive.rs @@ -1,4 +1,4 @@ -use dspy_rs::Signature as SignatureTrait; +use dspy_rs::{InputRenderSpec, Signature as SignatureTrait}; /// Test instruction #[derive(dsrs_macros::Signature)] @@ -23,6 +23,20 @@ struct NormalizedConstraintSig { score: f64, } +#[derive(dsrs_macros::Signature)] +struct RenderSpecSig { + #[input] + #[render(jinja = "{{ this }}")] + template_input: String, + + #[input] + #[format("yaml")] + yaml_input: String, + + #[output] + answer: String, +} + #[test] fn test_generates_input_struct() { let input = TestSigInput { @@ -81,3 +95,25 @@ fn test_constraint_operator_normalization() { "this >= 0.0 and this <= 1.0" ); } + +#[test] +fn test_input_render_spec_generation() { + let input_fields = ::input_fields(); + assert_eq!(input_fields.len(), 2); + + assert_eq!( + input_fields + .iter() + .find(|field| field.rust_name == "template_input") + .map(|field| field.input_render), + Some(InputRenderSpec::Jinja("{{ this }}")) + ); + + assert_eq!( + input_fields + .iter() + .find(|field| field.rust_name == "yaml_input") + .map(|field| field.input_render), + Some(InputRenderSpec::Format("yaml")) + ); +} diff --git a/crates/dsrs-macros/tests/ui/alias_conflict.rs b/crates/dsrs-macros/tests/ui/alias_conflict.rs new file mode 100644 index 0000000..f0edb34 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/alias_conflict.rs @@ -0,0 +1,16 @@ +use dsrs_macros::Signature; + +#[derive(Signature)] +struct AliasConflict { + #[input] + first: String, + + #[input] + #[alias("first")] + second: String, + + #[output] + answer: String, +} + +fn main() {} diff --git a/crates/dsrs-macros/tests/ui/alias_conflict.stderr b/crates/dsrs-macros/tests/ui/alias_conflict.stderr new file mode 100644 index 0000000..e7ecf82 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/alias_conflict.stderr @@ -0,0 +1,7 @@ +error: duplicate input field name `first` after aliasing; conflicts between `first` and `second` + --> tests/ui/alias_conflict.rs:3:10 + | +3 | #[derive(Signature)] + | ^^^^^^^^^ + | + = note: this error originates in the derive macro `Signature` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/crates/dsrs-macros/tests/ui/format_render_conflict.rs b/crates/dsrs-macros/tests/ui/format_render_conflict.rs new file mode 100644 index 0000000..118e550 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/format_render_conflict.rs @@ -0,0 +1,14 @@ +use dsrs_macros::Signature; + +#[derive(Signature)] +struct FormatRenderConflict { + #[input] + #[format("json")] + #[render(jinja = "{{ this }}")] + context: String, + + #[output] + answer: String, +} + +fn main() {} diff --git a/crates/dsrs-macros/tests/ui/format_render_conflict.stderr b/crates/dsrs-macros/tests/ui/format_render_conflict.stderr new file mode 100644 index 0000000..4930485 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/format_render_conflict.stderr @@ -0,0 +1,8 @@ +error: #[format] and #[render] cannot be combined on the same field + --> tests/ui/format_render_conflict.rs:5:5 + | +5 | / #[input] +6 | | #[format("json")] +7 | | #[render(jinja = "{{ this }}")] +8 | | context: String, + | |___________________^ diff --git a/crates/dsrs-macros/tests/ui/render_duplicate.rs b/crates/dsrs-macros/tests/ui/render_duplicate.rs new file mode 100644 index 0000000..372f7a5 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_duplicate.rs @@ -0,0 +1,14 @@ +use dsrs_macros::Signature; + +#[derive(Signature)] +struct RenderDuplicate { + #[input] + #[render(jinja = "{{ this }}")] + #[render(jinja = "{{ this }}")] + context: String, + + #[output] + answer: String, +} + +fn main() {} diff --git a/crates/dsrs-macros/tests/ui/render_duplicate.stderr b/crates/dsrs-macros/tests/ui/render_duplicate.stderr new file mode 100644 index 0000000..0d31e5f --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_duplicate.stderr @@ -0,0 +1,5 @@ +error: #[render] can only be specified once per field + --> tests/ui/render_duplicate.rs:7:5 + | +7 | #[render(jinja = "{{ this }}")] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/crates/dsrs-macros/tests/ui/render_invalid_jinja.rs b/crates/dsrs-macros/tests/ui/render_invalid_jinja.rs new file mode 100644 index 0000000..3bf3548 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_invalid_jinja.rs @@ -0,0 +1,13 @@ +use dsrs_macros::Signature; + +#[derive(Signature)] +struct RenderInvalidJinja { + #[input] + #[render(jinja = "{{ this ")] + context: String, + + #[output] + answer: String, +} + +fn main() {} diff --git a/crates/dsrs-macros/tests/ui/render_invalid_jinja.stderr b/crates/dsrs-macros/tests/ui/render_invalid_jinja.stderr new file mode 100644 index 0000000..bf02274 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_invalid_jinja.stderr @@ -0,0 +1,5 @@ +error: invalid Jinja syntax in #[render(jinja = "...")] + --> tests/ui/render_invalid_jinja.rs:6:5 + | +6 | #[render(jinja = "{{ this ")] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/crates/dsrs-macros/tests/ui/render_invalid_key.rs b/crates/dsrs-macros/tests/ui/render_invalid_key.rs new file mode 100644 index 0000000..2c3faf1 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_invalid_key.rs @@ -0,0 +1,13 @@ +use dsrs_macros::Signature; + +#[derive(Signature)] +struct RenderInvalidKey { + #[input] + #[render(template = "{{ this }}")] + context: String, + + #[output] + answer: String, +} + +fn main() {} diff --git a/crates/dsrs-macros/tests/ui/render_invalid_key.stderr b/crates/dsrs-macros/tests/ui/render_invalid_key.stderr new file mode 100644 index 0000000..c599c0a --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_invalid_key.stderr @@ -0,0 +1,5 @@ +error: expected #[render(jinja = "...")] + --> tests/ui/render_invalid_key.rs:6:5 + | +6 | #[render(template = "{{ this }}")] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/crates/dsrs-macros/tests/ui/render_non_literal.rs b/crates/dsrs-macros/tests/ui/render_non_literal.rs new file mode 100644 index 0000000..159bf71 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_non_literal.rs @@ -0,0 +1,15 @@ +use dsrs_macros::Signature; + +const TEMPLATE: &str = "{{ this }}"; + +#[derive(Signature)] +struct RenderNonLiteral { + #[input] + #[render(jinja = TEMPLATE)] + context: String, + + #[output] + answer: String, +} + +fn main() {} diff --git a/crates/dsrs-macros/tests/ui/render_non_literal.stderr b/crates/dsrs-macros/tests/ui/render_non_literal.stderr new file mode 100644 index 0000000..b47f5d1 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_non_literal.stderr @@ -0,0 +1,5 @@ +error: expected string literal; hint: wrap the value in quotes + --> tests/ui/render_non_literal.rs:8:14 + | +8 | #[render(jinja = TEMPLATE)] + | ^^^^^^^^^^^^^^^^ diff --git a/crates/dsrs-macros/tests/ui/render_on_output.rs b/crates/dsrs-macros/tests/ui/render_on_output.rs new file mode 100644 index 0000000..117cb24 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_on_output.rs @@ -0,0 +1,13 @@ +use dsrs_macros::Signature; + +#[derive(Signature)] +struct RenderOnOutput { + #[input] + question: String, + + #[output] + #[render(jinja = "{{ this }}")] + answer: String, +} + +fn main() {} diff --git a/crates/dsrs-macros/tests/ui/render_on_output.stderr b/crates/dsrs-macros/tests/ui/render_on_output.stderr new file mode 100644 index 0000000..02d30b4 --- /dev/null +++ b/crates/dsrs-macros/tests/ui/render_on_output.stderr @@ -0,0 +1,7 @@ +error: #[render] is only supported on #[input] fields + --> tests/ui/render_on_output.rs:8:5 + | + 8 | / #[output] + 9 | | #[render(jinja = "{{ this }}")] +10 | | answer: String, + | |__________________^ diff --git a/docs/docs/building-blocks/adapter.mdx b/docs/docs/building-blocks/adapter.mdx index 7fb62e0..9f75708 100644 --- a/docs/docs/building-blocks/adapter.mdx +++ b/docs/docs/building-blocks/adapter.mdx @@ -148,7 +148,7 @@ This is useful for understanding what the LM sees. ## Input formatting options -The `#[format]` attribute on input fields controls serialization: +Input fields support two rendering paths: ```rust #[derive(Signature, Clone, Debug)] @@ -160,12 +160,27 @@ struct Search { #[format("yaml")] filters: Vec, // serialized as YAML + #[input] + #[render(jinja = "{{ this.text }}\nQuestion: {{ input.query }}")] + context: Context, // custom Jinja rendering + #[output] results: Vec, } ``` -Options: `"json"` (default for complex types), `"yaml"`, `"toon"` +- `#[format("json" | "yaml" | "toon")]`: serialize a field using that format. +- `#[render(jinja = "...")]`: render a field with a MiniJinja template. +- `#[format]` and `#[render]` are mutually exclusive on the same field. + +`#[render]` context in `ChatAdapter`: +- `this`: current field value +- `input`: full input object (rust names and aliases) +- `field`: `{ name, rust_name, type }` +- `vars`: adapter/surface vars (currently `{}` in `ChatAdapter`) + +`ChatAdapter` configures MiniJinja with strict undefined behavior. +If template rendering fails at runtime (for example, missing variables), the field renders as `` so prompt construction still completes. ## Real example: Insurance claim extraction diff --git a/docs/docs/building-blocks/signature.mdx b/docs/docs/building-blocks/signature.mdx index f06d61e..274add1 100644 --- a/docs/docs/building-blocks/signature.mdx +++ b/docs/docs/building-blocks/signature.mdx @@ -172,6 +172,28 @@ context: Vec, Options: `"json"`, `"yaml"`, `"toon"` +### `#[render]` - Custom input rendering (Jinja) + +```rust +#[input] +#[render(jinja = "{{ this.title }}\n{{ this.body | truncate(300) }}")] +ticket: Ticket, +``` + +Use this when you need custom text rendering for an input field. + +- Only valid on `#[input]` fields +- Template must be a string literal +- Jinja syntax is validated at compile time +- Cannot be combined with `#[format]` on the same field +- Runtime uses strict undefined behavior; missing variables render as `` + +Template context: +- `this` - Current field value as JSON-like data +- `input` - Full input object (available via rust field names and aliases) +- `field` - Metadata object with `name`, `rust_name`, and `type` +- `vars` - Adapter/surface variables (currently empty in `ChatAdapter`) + ### `#[check]` - Soft constraint ```rust diff --git a/docs/docs/getting-started/quickstart.mdx b/docs/docs/getting-started/quickstart.mdx index 19e2110..e976268 100644 --- a/docs/docs/getting-started/quickstart.mdx +++ b/docs/docs/getting-started/quickstart.mdx @@ -222,6 +222,11 @@ struct Rating { } ``` +### Input formatting and rendering + +Use `#[format("json" | "yaml" | "toon")]` for serialization, or `#[render(jinja = "...")]` for custom field text. +See the full attribute reference in [Signatures](/docs/building-blocks/signature) and runtime behavior in [Adapter](/docs/building-blocks/adapter). + ### Multi-step pipelines Compose [modules](/docs/building-blocks/module) for complex workflows: