diff --git a/engine/Cargo.lock b/engine/Cargo.lock index f257dff47..28c00c32d 100644 --- a/engine/Cargo.lock +++ b/engine/Cargo.lock @@ -2345,6 +2345,7 @@ dependencies = [ "env_logger", "indexmap 2.2.6", "internal-baml-core", + "itertools 0.13.0", "log", "pathdiff", "semver", @@ -2413,6 +2414,7 @@ dependencies = [ "indexmap 2.2.6", "log", "minijinja", + "regex", "serde", "serde_json", "strsim 0.11.1", @@ -2509,6 +2511,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/scope_diagnostics.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/scope_diagnostics.rs index 0f8182874..a88db813a 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/scope_diagnostics.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/scope_diagnostics.rs @@ -148,4 +148,8 @@ impl ScopeStack { pub fn push_error(&mut self, error: String) { self.scopes.last_mut().unwrap().errors.push(error); } + + pub fn push_warning(&mut self, warning: String) { + self.scopes.last_mut().unwrap().warnings.push(warning); + } } diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs index 3192854ba..1505e6f2f 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs @@ -1,10 +1,13 @@ -use baml_types::{BamlMap, BamlMediaType, BamlValue, FieldType, LiteralValue, TypeValue}; +use baml_types::{ + BamlMap, BamlValue, Constraint, ConstraintLevel, FieldType, LiteralValue, TypeValue +}; use core::result::Result; use std::path::PathBuf; use crate::ir::IntermediateRepr; use super::{scope_diagnostics::ScopeStack, IRHelper}; +use 
internal_baml_jinja::evaluate_predicate; #[derive(Default)] pub struct ParameterError { @@ -325,6 +328,36 @@ impl ArgCoercer { } } } + FieldType::Constrained { base, constraints } => { + let val = self.coerce_arg(ir, base, value, scope)?; + for c@Constraint { + level, + expression, + label, + } in constraints.iter() + { + let constraint_ok = + evaluate_predicate(&val, &expression).unwrap_or_else(|err| { + scope.push_error(format!( + "Error while evaluating check {c:?}: {:?}", + err + )); + false + }); + if !constraint_ok { + let msg = label.as_ref().unwrap_or(&expression.0); + match level { + ConstraintLevel::Check => { + scope.push_warning(format!("Failed check: {msg}")); + } + ConstraintLevel::Assert => { + scope.push_error(format!("Failed assert: {msg}")); + } + } + } + } + Ok(val) + } } } } diff --git a/engine/baml-lib/baml-core/src/ir/json_schema.rs b/engine/baml-lib/baml-core/src/ir/json_schema.rs index c2bd4d99b..6d218651b 100644 --- a/engine/baml-lib/baml-core/src/ir/json_schema.rs +++ b/engine/baml-lib/baml-core/src/ir/json_schema.rs @@ -236,6 +236,7 @@ impl<'db> WithJsonSchema for FieldType { } } } + FieldType::Constrained { base, .. 
} => base.json_schema(), } } } diff --git a/engine/baml-lib/baml-core/src/ir/repr.rs b/engine/baml-lib/baml-core/src/ir/repr.rs index d4ad06dac..8ce268942 100644 --- a/engine/baml-lib/baml-core/src/ir/repr.rs +++ b/engine/baml-lib/baml-core/src/ir/repr.rs @@ -1,7 +1,7 @@ use std::collections::HashSet; -use anyhow::{anyhow, Context, Result}; -use baml_types::FieldType; +use anyhow::{anyhow, Result}; +use baml_types::{Constraint, ConstraintLevel, FieldType}; use either::Either; use indexmap::IndexMap; use internal_baml_parser_database::{ @@ -13,6 +13,7 @@ use internal_baml_parser_database::{ }; use internal_baml_schema_ast::ast::SubType; +use baml_types::JinjaExpression; use internal_baml_schema_ast::ast::{self, FieldArity, WithName, WithSpan}; use serde::Serialize; @@ -197,6 +198,8 @@ pub struct NodeAttributes { #[serde(with = "indexmap::map::serde_seq")] meta: IndexMap, + constraints: Vec, + // Spans #[serde(skip)] pub span: Option, @@ -208,39 +211,69 @@ impl NodeAttributes { } } -fn to_ir_attributes( - db: &ParserDatabase, - maybe_ast_attributes: Option<&Attributes>, -) -> IndexMap { - let mut attributes = IndexMap::new(); - - if let Some(Attributes { - description, - alias, - dynamic_type, - skip, - }) = maybe_ast_attributes - { - if let Some(true) = dynamic_type { - attributes.insert("dynamic_type".to_string(), Expression::Bool(true)); - } - if let Some(v) = alias { - attributes.insert("alias".to_string(), Expression::String(db[*v].to_string())); - } - if let Some(d) = description { - let ir_expr = match d { - ast::Expression::StringValue(s, _) => Expression::String(s.clone()), - ast::Expression::RawStringValue(s) => Expression::RawString(s.value().to_string()), - _ => panic!("Couldn't deal with description: {:?}", d), - }; - attributes.insert("description".to_string(), ir_expr); - } - if let Some(true) = skip { - attributes.insert("skip".to_string(), Expression::Bool(true)); +impl Default for NodeAttributes { + fn default() -> Self { + NodeAttributes { + meta: 
IndexMap::new(), + constraints: Vec::new(), + span: None, } } +} - attributes +fn to_ir_attributes( + db: &ParserDatabase, + maybe_ast_attributes: Option<&Attributes>, +) -> (IndexMap, Vec) { + let null_result = (IndexMap::new(), Vec::new()); + maybe_ast_attributes.map_or(null_result, |attributes| { + let Attributes { + description, + alias, + dynamic_type, + skip, + constraints, + } = attributes; + let description = description.as_ref().and_then(|d| { + let name = "description".to_string(); + match d { + ast::Expression::StringValue(s, _) => Some((name, Expression::String(s.clone()))), + ast::Expression::RawStringValue(s) => { + Some((name, Expression::RawString(s.value().to_string()))) + } + ast::Expression::JinjaExpressionValue(j, _) => { + Some((name, Expression::JinjaExpression(j.clone()))) + } + _ => { + eprintln!("Warning, encountered an unexpected description attribute"); + None + } + } + }); + let alias = alias + .as_ref() + .map(|v| ("alias".to_string(), Expression::String(db[*v].to_string()))); + let dynamic_type = dynamic_type.as_ref().and_then(|v| { + if *v { + Some(("dynamic_type".to_string(), Expression::Bool(true))) + } else { + None + } + }); + let skip = skip.as_ref().and_then(|v| { + if *v { + Some(("skip".to_string(), Expression::Bool(true))) + } else { + None + } + }); + + let meta = vec![description, alias, dynamic_type, skip] + .into_iter() + .filter_map(|s| s) + .collect(); + (meta, constraints.clone()) + }) } /// Nodes allow attaching metadata to a given IR entity: attributes, source location, etc @@ -256,6 +289,7 @@ pub trait WithRepr { fn attributes(&self, _: &ParserDatabase) -> NodeAttributes { NodeAttributes { meta: IndexMap::new(), + constraints: Vec::new(), span: None, } } @@ -278,8 +312,46 @@ fn type_with_arity(t: FieldType, arity: &FieldArity) -> FieldType { } impl WithRepr for ast::FieldType { + + // TODO: (Greg) This code only extracts constraints, and ignores any + // other types of attributes attached to the type directly. 
+ fn attributes(&self, _db: &ParserDatabase) -> NodeAttributes { + let constraints = self + .attributes() + .iter() + .filter_map(|attr| { + let level = match attr.name.to_string().as_str() { + "assert" => Some(ConstraintLevel::Assert), + "check" => Some(ConstraintLevel::Check), + _ => None + }?; + let (expression, label) = match attr.arguments.arguments.as_slice() { + [arg1, arg2] => match (arg1.clone().value, arg2.clone().value) { + (ast::Expression::JinjaExpressionValue(j,_), ast::Expression::Identifier(ast::Identifier::Local(s, _))) => Some((j,Some(s))), + _ => None + }, + [arg1] => match arg1.clone().value { + ast::Expression::JinjaExpressionValue(JinjaExpression(j),_) => Some((JinjaExpression(j.clone()),None)), + _ => None + } + _ => None, + }?; + Some(Constraint{ level, expression, label }) + }) + .collect::>(); + let attributes = NodeAttributes { + meta: IndexMap::new(), + constraints, + span: Some(self.span().clone()), + }; + + attributes + } + fn repr(&self, db: &ParserDatabase) -> Result { - Ok(match self { + let constraints = WithRepr::attributes(self, db).constraints; + let has_constraints = constraints.len() > 0; + let base = match self { ast::FieldType::Primitive(arity, typeval, ..) 
=> { let repr = FieldType::Primitive(typeval.clone()); if arity.is_optional() { @@ -347,7 +419,14 @@ impl WithRepr for ast::FieldType { FieldType::Tuple(t.iter().map(|ft| ft.repr(db)).collect::>>()?), arity, ), - }) + }; + + let with_constraints = if has_constraints { + FieldType::Constrained { base: Box::new(base.clone()), constraints } + } else { + base + }; + Ok(with_constraints) } } @@ -384,6 +463,7 @@ pub enum Expression { RawString(String), List(Vec), Map(Vec<(Expression, Expression)>), + JinjaExpression(JinjaExpression), } impl Expression { @@ -411,6 +491,9 @@ impl WithRepr for ast::Expression { ast::Expression::NumericValue(val, _) => Expression::Numeric(val.clone()), ast::Expression::StringValue(val, _) => Expression::String(val.clone()), ast::Expression::RawStringValue(val) => Expression::RawString(val.value().to_string()), + ast::Expression::JinjaExpressionValue(val, _) => { + Expression::JinjaExpression(val.clone()) + } ast::Expression::Identifier(idn) => match idn { ast::Identifier::ENV(k, _) => { Ok(Expression::Identifier(Identifier::ENV(k.clone()))) @@ -459,7 +542,7 @@ impl WithRepr for TemplateStringWalker<'_> { fn attributes(&self, _: &ParserDatabase) -> NodeAttributes { NodeAttributes { meta: Default::default(), - + constraints: Vec::new(), span: Some(self.span().clone()), } } @@ -480,7 +563,6 @@ impl WithRepr for TemplateStringWalker<'_> { .ok() }) .collect::>(), - _ => vec![], }), content: self.template_string().to_string(), }) @@ -499,8 +581,10 @@ pub struct Enum { impl WithRepr for EnumValueWalker<'_> { fn attributes(&self, db: &ParserDatabase) -> NodeAttributes { + let (meta, constraints) = to_ir_attributes(db, self.get_default_attributes()); let attributes = NodeAttributes { - meta: to_ir_attributes(db, self.get_default_attributes()), + meta, + constraints, span: Some(self.span().clone()), }; @@ -514,8 +598,10 @@ impl WithRepr for EnumValueWalker<'_> { impl WithRepr for EnumWalker<'_> { fn attributes(&self, db: &ParserDatabase) -> 
NodeAttributes { + let (meta, constraints) = to_ir_attributes(db, self.get_default_attributes(SubType::Enum)); let attributes = NodeAttributes { - meta: to_ir_attributes(db, self.get_default_attributes(SubType::Enum)), + meta, + constraints, span: Some(self.span().clone()), }; @@ -541,8 +627,10 @@ pub struct Field { impl WithRepr for FieldWalker<'_> { fn attributes(&self, db: &ParserDatabase) -> NodeAttributes { + let (meta, constraints) = to_ir_attributes(db, self.get_default_attributes()); let attributes = NodeAttributes { - meta: to_ir_attributes(db, self.get_default_attributes()), + meta, + constraints, span: Some(self.span().clone()), }; @@ -570,18 +658,26 @@ impl WithRepr for FieldWalker<'_> { type ClassId = String; +/// A BAML Class. #[derive(serde::Serialize, Debug)] pub struct Class { + /// User defined class name. pub name: ClassId, + + /// Fields of the class. pub static_fields: Vec>, + + /// Parameters to the class definition. pub inputs: Vec<(String, FieldType)>, } impl WithRepr for ClassWalker<'_> { fn attributes(&self, db: &ParserDatabase) -> NodeAttributes { let default_attributes = self.get_default_attributes(SubType::Class); + let (meta, constraints) = to_ir_attributes(db, default_attributes); let attributes = NodeAttributes { - meta: to_ir_attributes(db, default_attributes), + meta, + constraints, span: Some(self.span().clone()), }; @@ -799,6 +895,7 @@ impl WithRepr for FunctionWalker<'_> { fn attributes(&self, _: &ParserDatabase) -> NodeAttributes { NodeAttributes { meta: Default::default(), + constraints: Vec::new(), span: Some(self.span().clone()), } } @@ -855,6 +952,7 @@ impl WithRepr for ClientWalker<'_> { fn attributes(&self, _: &ParserDatabase) -> NodeAttributes { NodeAttributes { meta: IndexMap::new(), + constraints: Vec::new(), span: Some(self.span().clone()), } } @@ -895,6 +993,7 @@ impl WithRepr for ConfigurationWalker<'_> { fn attributes(&self, _db: &ParserDatabase) -> NodeAttributes { NodeAttributes { meta: IndexMap::new(), + 
constraints: Vec::new(), span: Some(self.span().clone()), } } @@ -936,12 +1035,12 @@ impl WithRepr for (&ConfigurationWalker<'_>, usize) { let span = self.0.test_case().functions[self.1].1.clone(); NodeAttributes { meta: IndexMap::new(), - + constraints: Vec::new(), span: Some(span), } } - fn repr(&self, db: &ParserDatabase) -> Result { + fn repr(&self, _db: &ParserDatabase) -> Result { Ok(TestCaseFunction( self.0.test_case().functions[self.1].0.clone(), )) @@ -953,6 +1052,7 @@ impl WithRepr for ConfigurationWalker<'_> { NodeAttributes { meta: IndexMap::new(), span: Some(self.span().clone()), + constraints: Vec::new(), } } @@ -1008,3 +1108,22 @@ impl WithRepr for PromptAst<'_> { }) } } + +/// Generate an IntermediateRepr from a single block of BAML source code. +/// This is useful for generating IR test fixtures. +pub fn make_test_ir(source_code: &str) -> anyhow::Result { + use std::path::PathBuf; + use internal_baml_diagnostics::SourceFile; + use crate::ValidatedSchema; + use crate::validate; + + let path: PathBuf = "fake_file.baml".into(); + let source_file: SourceFile = (path.clone(), source_code).into(); + let validated_schema: ValidatedSchema = validate(&path, vec![source_file]); + let diagnostics = &validated_schema.diagnostics; + if diagnostics.has_errors() { + return Err(anyhow::anyhow!("Source code was invalid: \n{:?}", diagnostics.errors())) + } + let ir = IntermediateRepr::from_parser_database(&validated_schema.db, validated_schema.configuration)?; + Ok(ir) +} diff --git a/engine/baml-lib/baml-core/src/ir/walker.rs b/engine/baml-lib/baml-core/src/ir/walker.rs index 9810201c5..709517e7d 100644 --- a/engine/baml-lib/baml-core/src/ir/walker.rs +++ b/engine/baml-lib/baml-core/src/ir/walker.rs @@ -2,6 +2,7 @@ use anyhow::Result; use baml_types::BamlValue; use indexmap::IndexMap; +use internal_baml_jinja::render_expression; use internal_baml_parser_database::RetryPolicyStrategy; use std::collections::HashMap; @@ -214,6 +215,15 @@ impl Expression { 
anyhow::bail!("Invalid numeric value: {}", n) } } + Expression::JinjaExpression(expr) => { + // TODO: do not coerce all context values to strings. + let jinja_context: HashMap = env_values + .iter() + .map(|(k, v)| (k.clone(), BamlValue::String(v.clone()))) + .collect(); + let res_string = render_expression(&expr, &jinja_context)?; + Ok(BamlValue::String(res_string)) + } } } } @@ -407,7 +417,13 @@ impl<'a> Walker<'a, &'a Field> { self.item .attributes .get("description") - .map(|v| v.as_string_value(env_values)) + .map(|v| { + let normalized = v.normalize(env_values)?; + let baml_value = normalized + .as_str() + .ok_or(anyhow::anyhow!("Unexpected: Evaluated to non-string value"))?; + Ok(String::from(baml_value)) + }) .transpose() } @@ -415,3 +431,22 @@ impl<'a> Walker<'a, &'a Field> { self.item.attributes.span.as_ref() } } + +#[cfg(test)] +mod tests { + use super::*; + use baml_types::JinjaExpression; + + #[test] + fn basic_jinja_normalization() { + let expr = Expression::JinjaExpression(JinjaExpression("this == 'hello'".to_string())); + let env = vec![("this".to_string(), "hello".to_string())] + .into_iter() + .collect(); + let normalized = expr.normalize(&env).unwrap(); + match normalized { + BamlValue::String(s) => assert_eq!(&s, "true"), + _ => panic!("Expected String Expression"), + } + } +} diff --git a/engine/baml-lib/baml-core/src/lib.rs b/engine/baml-lib/baml-core/src/lib.rs index d14c4772f..d15eee019 100644 --- a/engine/baml-lib/baml-core/src/lib.rs +++ b/engine/baml-lib/baml-core/src/lib.rs @@ -41,7 +41,7 @@ impl std::fmt::Debug for ValidatedSchema { } } -/// The most general API for dealing with Prisma schemas. It accumulates what analysis and +/// The most general API for dealing with BAML source code. It accumulates what analysis and /// validation information it can, and returns it along with any error and warning diagnostics. 
pub fn validate(root_path: &PathBuf, files: Vec) -> ValidatedSchema { let mut diagnostics = Diagnostics::new(root_path.clone()); diff --git a/engine/baml-lib/baml-core/src/validate/generator_loader/v1.rs b/engine/baml-lib/baml-core/src/validate/generator_loader/v1.rs index 37e8b2f65..a75162253 100644 --- a/engine/baml-lib/baml-core/src/validate/generator_loader/v1.rs +++ b/engine/baml-lib/baml-core/src/validate/generator_loader/v1.rs @@ -138,21 +138,21 @@ pub(crate) fn parse_generator( }; match parse_required_key(&args, "test_command", ast_generator.span()) { - Ok(name) => (), + Ok(_name) => (), Err(err) => { errors.push(err); } }; match parse_required_key(&args, "install_command", ast_generator.span()) { - Ok(name) => (), + Ok(_name) => (), Err(err) => { errors.push(err); } }; match parse_required_key(&args, "package_version_command", ast_generator.span()) { - Ok(name) => (), + Ok(_name) => (), Err(err) => { errors.push(err); } diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/functions.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/functions.rs index 8d95c67ee..c3d262f13 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/functions.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/functions.rs @@ -1,8 +1,9 @@ -use crate::validate::validation_pipeline::context::Context; +use crate::{validate::validation_pipeline::context::Context}; +use either::Either; use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Span}; -use internal_baml_schema_ast::ast::{WithIdentifier, WithName, WithSpan}; +use internal_baml_schema_ast::ast::{FieldType, WithIdentifier, WithName, WithSpan}; use super::types::validate_type; @@ -73,7 +74,18 @@ pub(super) fn validate(ctx: &mut Context<'_>) { for func in ctx.db.walk_functions() { for args in func.walk_input_args().chain(func.walk_output_args()) { let arg = args.ast_arg(); - validate_type(ctx, 
&arg.1.field_type) + validate_type(ctx, &arg.1.field_type); + } + + for args in func.walk_input_args() { + let arg = args.ast_arg(); + let field_type = &arg.1.field_type; + + let span = field_type.span().clone(); + if has_checks_nested(ctx, field_type) { + ctx.push_error(DatamodelError::new_validation_error("Types with checks are not allowed as function parameters.", span)); + } + } // Ensure the client is correct. @@ -158,3 +170,32 @@ pub(super) fn validate(ctx: &mut Context<'_>) { defined_types.errors_mut().clear(); } } + +/// Recusively search for `check` attributes in a field type and all of its +/// composed children. +fn has_checks_nested(ctx: &Context<'_>, field_type: &FieldType) -> bool { + if field_type.has_checks() { + return true; + } + + match field_type { + FieldType::Symbol(_, id, ..) => { + match ctx.db.find_type(id) { + Some(Either::Left(class_walker)) => { + let mut fields = class_walker.static_fields(); + fields.any(|field| field.ast_field().expr.as_ref().map_or(false, |ft| has_checks_nested(ctx, &ft))) + } + , + _ => false, + } + }, + + FieldType::Primitive(..) => false, + FieldType::Union(_, children, ..) => children.iter().any(|ft| has_checks_nested(ctx, ft)), + FieldType::Literal(..) => false, + FieldType::Tuple(_, children, ..) => children.iter().any(|ft| has_checks_nested(ctx, ft)), + FieldType::List(_, child, ..) => has_checks_nested(ctx, child), + FieldType::Map(_, kv, ..) 
=> + has_checks_nested(ctx, &kv.as_ref().0) || has_checks_nested(ctx, &kv.as_ref().1), + } +} diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/types.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/types.rs index cc26b9ed0..17c892664 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/types.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/types.rs @@ -1,6 +1,6 @@ use baml_types::TypeValue; use internal_baml_diagnostics::DatamodelError; -use internal_baml_schema_ast::ast::{FieldArity, FieldType, Identifier, WithName, WithSpan}; +use internal_baml_schema_ast::ast::{Argument, Attribute, Expression, FieldArity, FieldType, Identifier, WithName, WithSpan}; use crate::validate::validation_pipeline::context::Context; @@ -15,13 +15,15 @@ fn errors_with_names<'a>(ctx: &'a mut Context<'_>, idn: &Identifier) { /// Called for each type in the baml_src tree, validates that it is well-formed. /// -/// Operates in two passes: +/// Operates in three passes: /// -/// 1. Verify that the type is resolveable (for REF types) -/// 2. Verify that the type is well-formed/allowed in the language +/// 1. Verify that the type is resolveable (for REF types). +/// 2. Verify that the type is well-formed/allowed in the language. +/// 3. Verify that constraints on the type are well-formed. pub(crate) fn validate_type(ctx: &mut Context<'_>, field_type: &FieldType) { validate_type_exists(ctx, field_type); validate_type_allowed(ctx, field_type); + validate_type_constraints(ctx, field_type); } fn validate_type_exists(ctx: &mut Context<'_>, field_type: &FieldType) -> bool { @@ -46,7 +48,7 @@ fn validate_type_exists(ctx: &mut Context<'_>, field_type: &FieldType) -> bool { fn validate_type_allowed(ctx: &mut Context<'_>, field_type: &FieldType) { match field_type { FieldType::Map(arity, kv_types, ..) 
=> { - if (arity.is_optional()) { + if arity.is_optional() { ctx.push_error(DatamodelError::new_validation_error( format!("Maps are not allowed to be optional").as_str(), field_type.span().clone(), @@ -70,7 +72,7 @@ fn validate_type_allowed(ctx: &mut Context<'_>, field_type: &FieldType) { FieldType::Symbol(..) => {} FieldType::List(arity, field_type, ..) => { - if (arity.is_optional()) { + if arity.is_optional() { ctx.push_error(DatamodelError::new_validation_error( format!("Lists are not allowed to be optional").as_str(), field_type.span().clone(), @@ -85,3 +87,30 @@ fn validate_type_allowed(ctx: &mut Context<'_>, field_type: &FieldType) { } } } + +fn validate_type_constraints(ctx: &mut Context<'_>, field_type: &FieldType) { + let constraint_attrs = field_type.attributes().iter().filter(|attr| ["assert", "check"].contains(&attr.name.name())).collect::>(); + for Attribute { arguments, span, name, .. } in constraint_attrs.iter() { + let arg_expressions = arguments.arguments.iter().map(|Argument{value,..}| value).collect::>(); + + match arg_expressions.as_slice() { + [Expression::JinjaExpressionValue(_, _), Expression::Identifier(Identifier::Local(s,_))] => { + // Ok. 
+ }, + [Expression::JinjaExpressionValue(_, _)] => { + if name.to_string() == "check" { + ctx.push_error(DatamodelError::new_validation_error( + "Check constraints must have a name.", + span.clone() + )) + } + }, + _ => { + ctx.push_error(DatamodelError::new_validation_error( + "A constraint must have one Jinja argument such as {{ expr }}, and optionally one String label", + span.clone() + )); + } + } + } +} diff --git a/engine/baml-lib/baml-types/Cargo.toml b/engine/baml-lib/baml-types/Cargo.toml index 0cd8f0285..6a390a29e 100644 --- a/engine/baml-lib/baml-types/Cargo.toml +++ b/engine/baml-lib/baml-types/Cargo.toml @@ -19,7 +19,6 @@ workspace = true optional = true [dependencies.minijinja] -optional = true version = "1.0.16" default-features = false features = [ @@ -43,4 +42,3 @@ features = [ [features] default = ["stable_sort"] stable_sort = ["indexmap"] -mini-jinja = ["minijinja"] diff --git a/engine/baml-lib/baml-types/src/baml_value.rs b/engine/baml-lib/baml-types/src/baml_value.rs index 07a82ba44..a4f4cae4b 100644 --- a/engine/baml-lib/baml-types/src/baml_value.rs +++ b/engine/baml-lib/baml-types/src/baml_value.rs @@ -1,9 +1,11 @@ -use std::{collections::HashSet, fmt}; +use std::collections::HashMap; +use std::{collections::{HashSet, VecDeque}, fmt}; -use serde::{de::Visitor, Deserialize, Deserializer}; +use serde::ser::{SerializeMap, SerializeSeq}; +use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer}; use crate::media::BamlMediaType; -use crate::{BamlMap, BamlMedia}; +use crate::{BamlMap, BamlMedia, ResponseCheck}; #[derive(Clone, Debug, PartialEq)] pub enum BamlValue { @@ -141,6 +143,13 @@ impl BamlValue { _ => None, } } + + pub fn as_list_owned(self) -> Option> { + match self { + BamlValue::List(vals) => Some(vals), + _ => None, + } + } } impl std::fmt::Display for BamlValue { @@ -336,3 +345,327 @@ impl<'de> Visitor<'de> for BamlValueVisitor { Ok(BamlValue::Map(values)) } } + +/// A BamlValue with associated metadata. 
+/// This type is used to flexibly carry additional information. +/// It is used as a base type for situations where we want to represent +/// a BamlValue with additional information per node, such as a score, +/// or a constraint result. +#[derive(Clone, Debug, PartialEq)] +pub enum BamlValueWithMeta { + String(String, T), + Int(i64, T), + Float(f64, T), + Bool(bool, T), + Map(BamlMap>, T), + List(Vec>, T), + Media(BamlMedia, T), + Enum(String, String, T), + Class(String, BamlMap>, T), + Null(T), +} + +impl BamlValueWithMeta { + + pub fn r#type(&self) -> String { + let plain_value: BamlValue = self.into(); + plain_value.r#type() + } + + /// Iterating over a `BamlValueWithMeta` produces a depth-first traversal + /// of the value and all its children. + pub fn iter<'a>(&'a self) -> BamlValueWithMetaIterator<'a, T> { + BamlValueWithMetaIterator::new(self) + } + + pub fn value(self) -> BamlValue { + match self { + BamlValueWithMeta::String(v, _) => BamlValue::String(v), + BamlValueWithMeta::Int(v, _) => BamlValue::Int(v), + BamlValueWithMeta::Float(v, _) => BamlValue::Float(v), + BamlValueWithMeta::Bool(v, _) => BamlValue::Bool(v), + BamlValueWithMeta::Map(v, _) => { + BamlValue::Map(v.into_iter().map(|(k, v)| (k, v.value())).collect()) + } + BamlValueWithMeta::List(v, _) => { + BamlValue::List(v.into_iter().map(|v| v.value()).collect()) + } + BamlValueWithMeta::Media(v, _) => BamlValue::Media(v), + BamlValueWithMeta::Enum(v, w, _) => BamlValue::Enum(v, w), + BamlValueWithMeta::Class(n, fs, _) => { + BamlValue::Class(n, fs.into_iter().map(|(k, v)| (k, v.value())).collect()) + } + BamlValueWithMeta::Null(_) => BamlValue::Null, + } + } + + pub fn meta(&self) -> &T { + match self { + BamlValueWithMeta::String(_, m) => m, + BamlValueWithMeta::Int(_, m) => m, + BamlValueWithMeta::Float(_, m) => m, + BamlValueWithMeta::Bool(_, m) => m, + BamlValueWithMeta::Map(_, m) => m, + BamlValueWithMeta::List(_, m) => m, + BamlValueWithMeta::Media(_, m) => m, + 
BamlValueWithMeta::Enum(_, _, m) => m, + BamlValueWithMeta::Class(_, _, m) => m, + BamlValueWithMeta::Null(m) => m, + } + } + + pub fn meta_mut(&mut self) -> &mut T { + match self { + BamlValueWithMeta::String(_, m) => m, + BamlValueWithMeta::Int(_, m) => m, + BamlValueWithMeta::Float(_, m) => m, + BamlValueWithMeta::Bool(_, m) => m, + BamlValueWithMeta::Map(_, m) => m, + BamlValueWithMeta::List(_, m) => m, + BamlValueWithMeta::Media(_, m) => m, + BamlValueWithMeta::Enum(_, _, m) => m, + BamlValueWithMeta::Class(_, _, m) => m, + BamlValueWithMeta::Null(m) => m, + } + } + + pub fn with_default_meta(value: &BamlValue) -> BamlValueWithMeta where T: Default { + use BamlValueWithMeta::*; + match value { + BamlValue::String(s) => String(s.clone(), T::default()), + BamlValue::Int(i) => Int(*i, T::default()), + BamlValue::Float(f) => Float(*f, T::default()), + BamlValue::Bool(b) => Bool(*b, T::default()), + BamlValue::Map(entries) => BamlValueWithMeta::Map(entries.iter().map(|(k,v)| (k.clone(), Self::with_default_meta(v))).collect(), T::default()), + BamlValue::List(items) => List(items.iter().map(|i| Self::with_default_meta(i)).collect(), T::default()), + BamlValue::Media(m) => Media(m.clone(), T::default()), + BamlValue::Enum(n,v) => Enum(n.clone(), v.clone(), T::default()), + BamlValue::Class(n, items) => Map(items.iter().map(|(k,v)| (k.clone(), Self::with_default_meta(v))).collect(), T::default()), + BamlValue::Null => Null(T::default()), + _ => unimplemented!() + } + } + + pub fn map_meta(self, f: F) -> BamlValueWithMeta + where + F: Fn(T) -> U + Copy, + { + match self { + BamlValueWithMeta::String(v, m) => BamlValueWithMeta::String(v, f(m)), + BamlValueWithMeta::Int(v, m) => BamlValueWithMeta::Int(v, f(m)), + BamlValueWithMeta::Float(v, m) => BamlValueWithMeta::Float(v, f(m)), + BamlValueWithMeta::Bool(v, m) => BamlValueWithMeta::Bool(v, f(m)), + BamlValueWithMeta::Map(v, m) => BamlValueWithMeta::Map( + v.into_iter().map(|(k, v)| (k, v.map_meta(f))).collect(), + 
f(m), + ), + BamlValueWithMeta::List(v, m) => { + BamlValueWithMeta::List(v.into_iter().map(|v| v.map_meta(f)).collect(), f(m)) + } + BamlValueWithMeta::Media(v, m) => BamlValueWithMeta::Media(v, f(m)), + BamlValueWithMeta::Enum(v, e, m) => BamlValueWithMeta::Enum(v, e, f(m)), + BamlValueWithMeta::Class(n, fs, m) => BamlValueWithMeta::Class( + n, + fs.into_iter().map(|(k, v)| (k, v.map_meta(f))).collect(), + f(m), + ), + BamlValueWithMeta::Null(m) => BamlValueWithMeta::Null(f(m)), + } + } +} + +/// An iterator over a BamlValue and all of its sub-values. +/// It yields entries in depth-first order. +pub struct BamlValueWithMetaIterator<'a,T> { + stack: VecDeque<&'a BamlValueWithMeta>, +} + +impl <'a, T> BamlValueWithMetaIterator<'a, T> { + /// Construct a new iterator. Users should do this via + /// `.iter()` on a `BamlValueWithMeta` value. + fn new(root: &'a BamlValueWithMeta) -> Self { + let mut stack = VecDeque::new(); + stack.push_back(root); + BamlValueWithMetaIterator { stack } + } +} + +impl <'a,T:'a> Iterator for BamlValueWithMetaIterator<'a,T> { + type Item = &'a BamlValueWithMeta; + + fn next(&mut self) -> Option { + if let Some(value) = self.stack.pop_back() { + // Get all the children and push them onto the stack. + match value { + BamlValueWithMeta::List(items,_) => { + self.stack.extend(items); + } + BamlValueWithMeta::Map(fields,_) => { + for (_,v) in fields.iter() { + self.stack.push_back(v); + } + } + BamlValueWithMeta::Class(_, fields, _) => { + for (_,v) in fields.iter() { + self.stack.push_back(v); + } + } + // These items have to children. + BamlValueWithMeta::String(..) | BamlValueWithMeta::Int(..) | + BamlValueWithMeta::Float(..) | BamlValueWithMeta::Bool(..) | + BamlValueWithMeta::Media(..) | BamlValueWithMeta::Enum(..) | + BamlValueWithMeta::Null(..) => {} + } + Some(&value) + } else { + None + } + } +} + +// Boilerplate. 
+impl <'a, T:'a> IntoIterator for &'a BamlValueWithMeta { + type Item = &'a BamlValueWithMeta; + type IntoIter = BamlValueWithMetaIterator<'a,T>; + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl From<&BamlValueWithMeta> for BamlValue { + fn from(baml_value: &BamlValueWithMeta) -> BamlValue { + use BamlValueWithMeta::*; + match baml_value { + String(v, _) => BamlValue::String(v.clone()), + Int(v, _) => BamlValue::Int(v.clone()), + Float(v, _) => BamlValue::Float(v.clone()), + Bool(v, _) => BamlValue::Bool(v.clone()), + Map(v, _) => BamlValue::Map(v.into_iter().map(|(k,v)| (k.clone(), v.into())).collect()), + List(v, _) => BamlValue::List(v.into_iter().map(|v| v.into()).collect()), + Media(v, _) => BamlValue::Media(v.clone()), + Enum(enum_name, v, _) => BamlValue::Enum(enum_name.clone(), v.clone()), + Class(class_name, v, _) => BamlValue::Class(class_name.clone(), v.into_iter().map(|(k,v)| (k.clone(), v.into())).collect()), + Null(_) => BamlValue::Null, + } + } +} + +impl From> for BamlValue { + fn from(baml_value: BamlValueWithMeta) -> BamlValue { + use BamlValueWithMeta::*; + match baml_value { + String(v, _) => BamlValue::String(v), + Int(v, _) => BamlValue::Int(v), + Float(v, _) => BamlValue::Float(v), + Bool(v, _) => BamlValue::Bool(v), + Map(v, _) => BamlValue::Map(v.into_iter().map(|(k,v)| (k, v.into())).collect()), + List(v, _) => BamlValue::List(v.into_iter().map(|v| v.into()).collect()), + Media(v, _) => BamlValue::Media(v), + Enum(enum_name, v, _) => BamlValue::Enum(enum_name, v), + Class(class_name, v, _) => BamlValue::Class(class_name, v.into_iter().map(|(k,v)| (k, v.into())).collect()), + Null(_) => BamlValue::Null, + } + } +} + +/// This special-purpose serializer is used for the public-facing API. 
+/// When we want to extend the orchestrator with BamlValues packing more +/// metadata than just a `Vec`, ` +impl Serialize for BamlValueWithMeta> { + fn serialize(&self, serializer: S) -> Result + where S: Serializer, + { + match self { + BamlValueWithMeta::String(v, cr) => serialize_with_checks(v, cr, serializer), + BamlValueWithMeta::Int(v, cr) => serialize_with_checks(v, cr, serializer), + BamlValueWithMeta::Float(v, cr) => serialize_with_checks(v, cr, serializer), + BamlValueWithMeta::Bool(v, cr) => serialize_with_checks(v, cr, serializer), + BamlValueWithMeta::Map(v, cr) => { + let mut map = serializer.serialize_map(None)?; + for (key, value) in v { + map.serialize_entry(key, value)?; + } + add_checks(&mut map, cr)?; + map.end() + }, + BamlValueWithMeta::List(v, cr) => serialize_with_checks(v, cr, serializer), + BamlValueWithMeta::Media(v, cr) => serialize_with_checks(v, cr, serializer), + BamlValueWithMeta::Enum(_enum_name, v, cr) => serialize_with_checks(v, cr, serializer), + BamlValueWithMeta::Class(_class_name, v, cr) => { + let mut map = serializer.serialize_map(None)?; + for (key, value) in v { + map.serialize_entry(key, value)?; + } + add_checks(&mut map, cr)?; + map.end() + }, + BamlValueWithMeta::Null(cr) => serialize_with_checks(&(), cr, serializer), + } + } +} + +fn serialize_with_checks( + value: &T, + checks: &Vec, + serializer:S, + +) -> Result + where S: Serializer, +{ + if !checks.is_empty() { + let mut map = serializer.serialize_map(Some(2))?; + map.serialize_entry("value", value)?; + add_checks(&mut map, checks)?; + map.end() + } else { + value.serialize(serializer) + } +} + +fn add_checks<'a, S: SerializeMap>( + map: &'a mut S, + checks: &'a Vec, +) -> Result<(), S::Error> { + if !checks.is_empty() { + let checks_map: HashMap<_,_> = checks.iter().map(|check| (check.name.clone(), check)).collect(); + map.serialize_entry("checks", &checks_map)?; + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json; + use 
crate::JinjaExpression; + + #[test] + fn test_baml_value_with_meta_serialization() { + let baml_value: BamlValueWithMeta> = + BamlValueWithMeta::String("hi".to_string(), vec![]); + let baml_value_2: BamlValueWithMeta> = + BamlValueWithMeta::Class( + "ContactInfo".to_string(), + vec![ + ("primary".to_string(), BamlValueWithMeta::Class( + "PhoneNumber".to_string(), + vec![ + ("value".to_string(), BamlValueWithMeta::String( + "123-456-7890".to_string(), + vec![ + ResponseCheck { + name: "foo".to_string(), + expression: "foo".to_string(), + status: "succeeded".to_string(), + } + ] + )) + ].into_iter().collect(), + vec![] + )) + ].into_iter().collect(), + vec![]); + assert!(serde_json::to_value(baml_value).is_ok()); + assert!(serde_json::to_value(baml_value_2).is_ok()); + } +} diff --git a/engine/baml-lib/baml-types/src/constraint.rs b/engine/baml-lib/baml-types/src/constraint.rs new file mode 100644 index 000000000..063c13a0a --- /dev/null +++ b/engine/baml-lib/baml-types/src/constraint.rs @@ -0,0 +1,55 @@ +use crate::JinjaExpression; + +#[derive(Clone, Debug, serde::Serialize, PartialEq)] +pub struct Constraint { + pub level: ConstraintLevel, + pub expression: JinjaExpression, + pub label: Option, +} + +#[derive(Clone, Debug, PartialEq, serde::Serialize)] +pub enum ConstraintLevel { + Check, + Assert, +} + +/// The user-visible schema for a failed check. +#[derive(Clone, Debug, serde::Serialize)] +pub struct ResponseCheck { + pub name: String, + pub expression: String, + pub status: String, +} + +impl ResponseCheck { + /// Convert a Constraint and its status to a ResponseCheck. + /// Returns `None` if the Constraint is not a check (i.e., + /// if it doesn't meet the invariants that level==Check and + /// label==Some). 
+ pub fn from_check_result( + ( + Constraint { + level, + expression, + label, + }, + succeeded, + ): (Constraint, bool), + ) -> Option { + match (level, label) { + (ConstraintLevel::Check, Some(label)) => { + let status = if succeeded { + "succeeded".to_string() + } else { + "failed".to_string() + }; + Some(ResponseCheck { + name: label, + expression: expression.0, + status, + }) + } + _ => None, + } + } +} diff --git a/engine/baml-lib/baml-types/src/field_type/mod.rs b/engine/baml-lib/baml-types/src/field_type/mod.rs index bde7184ab..becc743aa 100644 --- a/engine/baml-lib/baml-types/src/field_type/mod.rs +++ b/engine/baml-lib/baml-types/src/field_type/mod.rs @@ -1,4 +1,6 @@ use crate::BamlMediaType; +use crate::Constraint; +use crate::ConstraintLevel; mod builder; @@ -69,7 +71,7 @@ impl std::fmt::Display for LiteralValue { } /// FieldType represents the type of either a class field or a function arg. -#[derive(serde::Serialize, Debug, Clone)] +#[derive(serde::Serialize, Debug, Clone, PartialEq)] pub enum FieldType { Primitive(TypeValue), Enum(String), @@ -80,6 +82,7 @@ pub enum FieldType { Union(Vec), Tuple(Vec), Optional(Box), + Constrained{ base: Box, constraints: Vec }, } // Impl display for FieldType @@ -116,6 +119,7 @@ impl std::fmt::Display for FieldType { FieldType::Map(k, v) => write!(f, "map<{}, {}>", k.to_string(), v.to_string()), FieldType::List(t) => write!(f, "{}[]", t.to_string()), FieldType::Optional(t) => write!(f, "{}?", t.to_string()), + FieldType::Constrained{base,..} => base.fmt(f), } } } @@ -126,6 +130,7 @@ impl FieldType { FieldType::Primitive(_) => true, FieldType::Optional(t) => t.is_primitive(), FieldType::List(t) => t.is_primitive(), + FieldType::Constrained{base,..} => base.is_primitive(), _ => false, } } @@ -134,8 +139,8 @@ impl FieldType { match self { FieldType::Optional(_) => true, FieldType::Primitive(TypeValue::Null) => true, - FieldType::Union(types) => types.iter().any(FieldType::is_optional), + FieldType::Constrained{base,..} 
=> base.is_optional(), _ => false, } } @@ -144,7 +149,83 @@ impl FieldType { match self { FieldType::Primitive(TypeValue::Null) => true, FieldType::Optional(t) => t.is_null(), + FieldType::Constrained{base,..} => base.is_null(), _ => false, } } + + /// Eliminate the `FieldType::Constrained` variant by searching for it, and stripping + /// it off of its base type, returning a tuple of the base type and any constraints found + /// (if called on an argument that is not Constrained, the returned constraints Vec is + /// empty). + /// + /// If the function encounters directly nested Constrained types, + /// (i.e. `FieldType::Constrained { base: FieldType::Constrained { .. }, .. } `) + /// then the constraints of the two levels will be combined into a single vector. + /// So, we always return a base type that is not FieldType::Constrained. + pub fn distribute_constraints(self: &FieldType) -> (&FieldType, Vec) { + + match self { + // Check the first level to see if it's constrained. + FieldType::Constrained { base, constraints } => { + match base.as_ref() { + // If so, we must check the second level to see if we need to combine + // constraints across levels. + // The recursion here means that arbitrarily nested `FieldType::Constrained`s + // will be collapsed before the function returns. 
+ FieldType::Constrained{..} => { + let (sub_base, sub_constraints) = base.as_ref().distribute_constraints(); + let combined_constraints = vec![constraints.clone(), sub_constraints].into_iter().flatten().collect(); + (sub_base, combined_constraints) + }, + _ => (base, constraints.clone()), + } + }, + _ => (self, Vec::new()), + } + } + + pub fn has_constraints(&self) -> bool { + let (_, constraints) = self.distribute_constraints(); + !constraints.is_empty() + } + + pub fn has_checks(&self) -> bool { + let (_, constraints) = self.distribute_constraints(); + constraints.iter().any(|Constraint{level,..}| level == &ConstraintLevel::Check) + } + +} + +#[cfg(test)] +mod tests { + use crate::{Constraint, ConstraintLevel, JinjaExpression}; + use super::*; + + + #[test] + fn test_nested_constraint_distribution() { + fn mk_constraint(s: &str) -> Constraint { + Constraint { level: ConstraintLevel::Assert, expression: JinjaExpression(s.to_string()), label: Some(s.to_string()) } + } + + let input = FieldType::Constrained { + constraints: vec![mk_constraint("a")], + base: Box::new(FieldType::Constrained { + constraints: vec![mk_constraint("b")], + base: Box::new(FieldType::Constrained { + constraints: vec![mk_constraint("c")], + base: Box::new(FieldType::Primitive(TypeValue::Int)), + }) + }) + }; + + let expected_base = FieldType::Primitive(TypeValue::Int); + let expected_constraints = vec![mk_constraint("a"),mk_constraint("b"), mk_constraint("c")]; + + let (base, constraints) = input.distribute_constraints(); + + assert_eq!(base, &expected_base); + assert_eq!(constraints, expected_constraints); + } } diff --git a/engine/baml-lib/baml-types/src/lib.rs b/engine/baml-lib/baml-types/src/lib.rs index 21e6cf0e8..fb721ff8d 100644 --- a/engine/baml-lib/baml-types/src/lib.rs +++ b/engine/baml-lib/baml-types/src/lib.rs @@ -1,14 +1,16 @@ +mod constraint; mod map; mod media; -#[cfg(feature = "mini-jinja")] mod minijinja; mod baml_value; mod field_type; mod generator; -pub use 
baml_value::BamlValue; +pub use baml_value::{BamlValue, BamlValueWithMeta}; +pub use constraint::*; pub use field_type::{FieldType, LiteralValue, TypeValue}; pub use generator::{GeneratorDefaultClientMode, GeneratorOutputType}; pub use map::Map as BamlMap; pub use media::{BamlMedia, BamlMediaContent, BamlMediaType, MediaBase64, MediaUrl}; +pub use minijinja::JinjaExpression; diff --git a/engine/baml-lib/baml-types/src/minijinja.rs b/engine/baml-lib/baml-types/src/minijinja.rs index e1c3f168d..36aa7a5a4 100644 --- a/engine/baml-lib/baml-types/src/minijinja.rs +++ b/engine/baml-lib/baml-types/src/minijinja.rs @@ -1,5 +1,19 @@ +use std::fmt; use crate::{BamlMedia, BamlValue}; +/// A wrapper around a jinja expression. The inner `String` should not contain +/// the interpolation brackets `{{ }}`; it should be a bare expression like +/// `"this|length < something"`. +#[derive(Clone, Debug, PartialEq, serde::Serialize)] +pub struct JinjaExpression(pub String); + + +impl fmt::Display for JinjaExpression { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + impl From for minijinja::Value { fn from(arg: BamlValue) -> minijinja::Value { match arg { diff --git a/engine/baml-lib/baml/tests/validation_files/constraints/constraints_everywhere.baml b/engine/baml-lib/baml/tests/validation_files/constraints/constraints_everywhere.baml new file mode 100644 index 000000000..5786ef86c --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/constraints/constraints_everywhere.baml @@ -0,0 +1,12 @@ +client Bar { + provider baml-openai-chat +} + +class Foo { + age int @assert({{this > 10}}, old_enough) +} + +function FooToInt(foo: Foo, a: Foo @assert({{this.age > 20}}, really_old)) -> int @check({{ this < 10 }}, small_int) { + client Bar + prompt #"fa"# +} diff --git a/engine/baml-lib/baml/tests/validation_files/constraints/misspelled.baml b/engine/baml-lib/baml/tests/validation_files/constraints/misspelled.baml new file mode 100644 
index 000000000..674fc3cd6 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/constraints/misspelled.baml @@ -0,0 +1,11 @@ +class Foo { + // A constraint that didn't use Jinja Expression syntax. + age int @check("this < 10", still_baby) +} + +// error: Error validating: A constraint must have one Jinja argument such as {{ expr }}, and optionally one String label +// --> constraints/misspelled.baml:3 +// | +// 2 | // A constraint that didn't use Jinja Expression syntax. +// 3 | age int @check("this < 10", still_baby) +// | diff --git a/engine/baml-lib/baml/tests/validation_files/functions_v2/check_in_parameter.baml b/engine/baml-lib/baml/tests/validation_files/functions_v2/check_in_parameter.baml new file mode 100644 index 000000000..bece0ae27 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/functions_v2/check_in_parameter.baml @@ -0,0 +1,24 @@ +client MyClient { + provider baml-openai-chat +} + +class Foo { + bar Bar? + baz int +} + +class Bar { + name string @check({{ this|length > 0 }}, nonempty_name) +} + +function Go(a: Foo) -> int { + client MyClient + prompt #"Hi"# +} + +// error: Error validating: Types with checks are not allowed as function parameters. 
+// --> functions_v2/check_in_parameter.baml:14 +// | +// 13 | +// 14 | function Go(a: Foo) -> int { +// | diff --git a/engine/baml-lib/jinja/Cargo.toml b/engine/baml-lib/jinja/Cargo.toml index 93e6d37fb..0f0bb78fb 100644 --- a/engine/baml-lib/jinja/Cargo.toml +++ b/engine/baml-lib/jinja/Cargo.toml @@ -9,7 +9,7 @@ license-file.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -baml-types = { path = "../baml-types", features = ["mini-jinja"] } +baml-types = { path = "../baml-types" } # TODO: disable imports, etc minijinja = { version = "1.0.16", default-features = false, features = [ "macros", @@ -38,6 +38,7 @@ serde_json.workspace = true strum.workspace = true strsim = "0.11.1" colored = "2.1.0" +regex.workspace = true [dev-dependencies] env_logger = "0.11.3" diff --git a/engine/baml-lib/jinja/src/lib.rs b/engine/baml-lib/jinja/src/lib.rs index 4d49dcc9d..8d914a289 100644 --- a/engine/baml-lib/jinja/src/lib.rs +++ b/engine/baml-lib/jinja/src/lib.rs @@ -1,4 +1,4 @@ -use baml_types::{BamlMedia, BamlValue}; +use baml_types::{BamlMedia, BamlValue, JinjaExpression}; use colored::*; mod chat_message_part; mod evaluate_type; @@ -12,6 +12,7 @@ pub use evaluate_type::{PredefinedTypes, Type, TypeError}; use minijinja::{self, value::Kwargs}; use minijinja::{context, ErrorKind, Value}; use output_format::types::OutputFormatContent; +use regex::Regex; use serde::{Deserialize, Serialize}; use serde_json::json; use std::collections::HashMap; @@ -24,9 +25,17 @@ fn get_env<'a>() -> minijinja::Environment<'a> { env.set_debug(true); env.set_trim_blocks(true); env.set_lstrip_blocks(true); + env.add_filter("regex_match", regex_match); env } +fn regex_match(value: String, regex: String) -> bool { + match Regex::new(&regex) { + Err(_) => false, + Ok(re) => re.is_match(&value) + } +} + #[derive(Debug)] pub struct ValidationError { pub errors: Vec, @@ -80,6 +89,10 @@ pub struct RenderContext_Client { pub 
default_role: String, } +/// A collection of values about the rendering context that will be made +/// available to a prompt via `{{ ctx }}`. For example `{{ ctx.client.name }}` +/// used in a prompt string will resolve to the name of the client, e.g. +/// "openai". #[derive(Debug)] pub struct RenderContext { pub client: RenderContext_Client, @@ -487,12 +500,42 @@ pub fn render_prompt( } } +/// Render a bare minijinja expression with the given context. +/// E.g. `"a|length > 2"` with context `{"a": [1, 2, 3]}` will return `"true"`. +pub fn render_expression( + expression: &JinjaExpression, + ctx: &HashMap, +) -> anyhow::Result { + let env = get_env(); + // In Rust format strings, `{` is escaped as `{{`. + // So producing the string `{{}}` requires writing the literal `"{{{{}}}}"` + let template = format!(r#"{{{{ {} }}}}"#, expression.0); + let args_dict = minijinja::Value::from_serialize(ctx); + eprintln!("{}", &template); + Ok(env.render_str(&template, &args_dict)?) +} + +// TODO: (Greg) better error handling. +// TODO: (Greg) Upstream, typecheck the expression. 
+pub fn evaluate_predicate( + this: &BamlValue, + predicate_expression: &JinjaExpression, +) -> Result { + let ctx: HashMap = + [("this".to_string(), this.clone())].into_iter().collect(); + match render_expression(&predicate_expression, &ctx)?.as_ref() { + "true" => Ok(true), + "false" => Ok(false), + _ => Err(anyhow::anyhow!("TODO")), + } +} + #[cfg(test)] mod render_tests { use super::*; - use baml_types::{BamlMap, BamlMediaType}; + use baml_types::{BamlMap, BamlMediaType, JinjaExpression}; use env_logger; use std::sync::Once; @@ -1107,4 +1150,45 @@ mod render_tests { Ok(()) } + + #[test] + fn test_render_expressions() { + let ctx = vec![( + "a".to_string(), + BamlValue::List(vec![BamlValue::Int(1), BamlValue::Int(2), BamlValue::Int(3)].into()) + ), ("b".to_string(), BamlValue::String("(123)456-7890".to_string()))] + .into_iter() + .collect(); + + assert_eq!( + render_expression(&JinjaExpression("1".to_string()), &ctx).unwrap(), + "1" + ); + assert_eq!( + render_expression(&JinjaExpression("1 + 1".to_string()), &ctx).unwrap(), + "2" + ); + assert_eq!( + render_expression(&JinjaExpression("a|length > 2".to_string()), &ctx).unwrap(), + "true" + ); + } + + #[test] + fn test_render_regex_match() { + let ctx = vec![( + "a".to_string(), + BamlValue::List(vec![BamlValue::Int(1), BamlValue::Int(2), BamlValue::Int(3)].into()) + ), ("b".to_string(), BamlValue::String("(123)456-7890".to_string()))] + .into_iter() + .collect(); + assert_eq!( + render_expression(&JinjaExpression(r##"b|regex_match("123")"##.to_string()), &ctx).unwrap(), + "true" + ); + assert_eq!( + render_expression(&JinjaExpression(r##"b|regex_match("\\(?\\d{3}\\)?[-.\\s]?\\d{3}[-.\\s]?\\d{4}")"##.to_string()), &ctx).unwrap(), + "true" + ) + } } diff --git a/engine/baml-lib/jinja/src/output_format/mod.rs b/engine/baml-lib/jinja/src/output_format/mod.rs index f2abfc716..59efe6f6e 100644 --- a/engine/baml-lib/jinja/src/output_format/mod.rs +++ b/engine/baml-lib/jinja/src/output_format/mod.rs @@ -9,6 +9,7 @@ 
use crate::{types::RenderOptions, RenderContext}; use self::types::OutputFormatContent; +// TODO: Rename the field to `content`. #[derive(Debug)] pub struct OutputFormat { text: OutputFormatContent, diff --git a/engine/baml-lib/jinja/src/output_format/types.rs b/engine/baml-lib/jinja/src/output_format/types.rs index 1125ef58b..fc795a69e 100644 --- a/engine/baml-lib/jinja/src/output_format/types.rs +++ b/engine/baml-lib/jinja/src/output_format/types.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use anyhow::Result; -use baml_types::{FieldType, LiteralValue, TypeValue}; +use baml_types::{FieldType, LiteralValue, TypeValue, Constraint}; use indexmap::{IndexMap, IndexSet}; #[derive(Debug)] @@ -34,18 +34,23 @@ impl Name { } } +// TODO: (Greg) Enum needs to carry its constraints. #[derive(Debug)] pub struct Enum { pub name: Name, // name and description pub values: Vec<(Name, Option)>, + pub constraints: Vec, } +/// The components of a Class needed to render `OutputFormatContent`. +/// This type is also used by `jsonish` to drive flexible parsing. #[derive(Debug)] pub struct Class { pub name: Name, - // type and description + // fields have name, type and description. 
pub fields: Vec<(Name, FieldType, Option)>, + pub constraints: Vec, } #[derive(Debug, Clone)] @@ -227,10 +232,8 @@ impl OutputFormatContent { } fn prefix<'a>(&self, options: &'a RenderOptions) -> Option<&'a str> { - match &options.prefix { - RenderSetting::Always(prefix) => Some(prefix.as_str()), - RenderSetting::Never => None, - RenderSetting::Auto => match &self.target { + fn auto_prefix(ft: &FieldType) -> Option<&'static str> { + match ft { FieldType::Primitive(TypeValue::String) => None, FieldType::Primitive(_) => Some("Answer as a: "), FieldType::Literal(_) => Some("Answer using this specific value:\n"), @@ -241,7 +244,13 @@ impl OutputFormatContent { FieldType::Optional(_) => Some("Answer in JSON using this schema:\n"), FieldType::Map(_, _) => Some("Answer in JSON using this schema:\n"), FieldType::Tuple(_) => None, - }, + FieldType::Constrained { base, .. } => auto_prefix(base), + } + } + match &options.prefix { + RenderSetting::Always(prefix) => Some(prefix.as_str()), + RenderSetting::Never => None, + RenderSetting::Auto => auto_prefix(&self.target), } } @@ -287,6 +296,9 @@ impl OutputFormatContent { LiteralValue::Int(i) => i.to_string(), LiteralValue::Bool(b) => b.to_string(), }, + FieldType::Constrained { base, .. } => { + self.inner_type_render(options, base, render_state, group_hoisted_literals)? 
+ } FieldType::Enum(e) => { let Some(enm) = self.enums.get(e) else { return Err(minijinja::Error::new( @@ -523,6 +535,7 @@ mod tests { (Name::new("Green".to_string()), None), (Name::new("Blue".to_string()), None), ], + constraints: Vec::new(), }); let content = OutputFormatContent::new(enums, vec![], FieldType::Enum("Color".to_string())); @@ -553,6 +566,7 @@ mod tests { Some("The person's age".to_string()), ), ], + constraints: Vec::new(), }); let content = @@ -589,6 +603,7 @@ mod tests { None, ), ], + constraints: Vec::new(), }); let content = diff --git a/engine/baml-lib/jinja/src/render_context.rs b/engine/baml-lib/jinja/src/render_context.rs index a1f8861c4..ba33725b2 100644 --- a/engine/baml-lib/jinja/src/render_context.rs +++ b/engine/baml-lib/jinja/src/render_context.rs @@ -22,6 +22,8 @@ impl std::fmt::Display for RenderContext_Client { } } +// TODO: (Greg) This type is duplicated in `src/lib.rs`. Are they both +// needed? If not, delete one. #[derive(Debug)] pub struct RenderContext { client: RenderContext_Client, diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs index 8dde401fb..bab1d87a3 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs @@ -3,7 +3,7 @@ use baml_types::BamlMap; use internal_baml_core::{ir::FieldType, ir::TypeValue}; use crate::deserializer::{ - coercer::{DefaultValue, TypeCoercer}, + coercer::{run_user_checks, DefaultValue, TypeCoercer}, deserialize_flags::{DeserializerConditions, Flag}, types::BamlValueWithFlags, }; @@ -84,6 +84,19 @@ impl TypeCoercer for FieldType { FieldType::Optional(_) => coerce_optional(ctx, self, value), FieldType::Map(_, _) => coerce_map(ctx, self, value), FieldType::Tuple(_) => Err(ctx.error_internal("Tuple not supported")), + FieldType::Constrained { base, .. 
} => { + let mut coerced_value = base.coerce(ctx, base, value)?; + let constraint_results = + run_user_checks(&coerced_value.clone().into(), &self).map_err( + |e| ParsingError { + reason: format!("Failed to evaluate constraints: {:?}", e), + scope: ctx.scope.clone(), + causes: Vec::new(), + }, + )?; + coerced_value.add_flag(Flag::ConstraintResults(constraint_results)); + Ok(coerced_value) + } }, } } @@ -100,7 +113,7 @@ impl DefaultValue for FieldType { match self { FieldType::Enum(e) => None, FieldType::Literal(_) => None, - FieldType::Class(c) => None, + FieldType::Class(_) => None, FieldType::List(_) => Some(BamlValueWithFlags::List(get_flags(), Vec::new())), FieldType::Union(items) => items.iter().find_map(|i| i.default_value(error)), FieldType::Primitive(TypeValue::Null) | FieldType::Optional(_) => { @@ -119,6 +132,8 @@ impl DefaultValue for FieldType { } } FieldType::Primitive(_) => None, + // If it has constraints, we can't assume our defaults meet them. + FieldType::Constrained { .. } => None, } } } diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs index 4ebf500f5..5b273f72d 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use baml_types::BamlMap; +use baml_types::{BamlMap, Constraint}; use internal_baml_core::ir::FieldType; use internal_baml_jinja::types::{Class, Name}; @@ -11,7 +11,7 @@ use crate::deserializer::{ use super::ParsingContext; -// Name, type, description +// Name, type, description, constraints. 
type FieldValue = (Name, FieldType, Option); impl TypeCoercer for Class { diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs index d07bd2257..9c63089d1 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs @@ -10,7 +10,9 @@ mod ir_ref; mod match_string; use anyhow::Result; -use internal_baml_jinja::types::OutputFormatContent; + +use baml_types::{BamlValue, Constraint}; +use internal_baml_jinja::{evaluate_predicate, types::OutputFormatContent}; use internal_baml_core::ir::FieldType; @@ -125,7 +127,7 @@ impl ParsingContext<'_> { &self, unparsed: Vec<(String, &ParsingError)>, missing: Vec, - item: Option<&crate::jsonish::Value>, + _item: Option<&crate::jsonish::Value>, ) -> ParsingError { ParsingError { reason: format!( @@ -136,7 +138,7 @@ impl ParsingContext<'_> { scope: self.scope.clone(), causes: missing .into_iter() - .map(|(k)| ParsingError { + .map(|k| ParsingError { scope: self.scope.clone(), reason: format!("Missing required field: {}", k), causes: vec![], @@ -219,3 +221,22 @@ pub trait TypeCoercer { pub trait DefaultValue { fn default_value(&self, error: Option<&ParsingError>) -> Option; } + +/// Run all checks and asserts for a value at a given type. +pub fn run_user_checks( + baml_value: &BamlValue, + type_: &FieldType, +) -> Result> { + match type_ { + FieldType::Constrained { constraints, .. 
} => { + constraints.iter().map(|constraint| { + let result = + evaluate_predicate(baml_value, &constraint.expression).map_err(|e| { + anyhow::anyhow!(format!("Error evaluating constraint: {:?}", e)) + })?; + Ok((constraint.clone(), result)) + }).collect::>>() + } + _ => Ok(vec![]), + } +} diff --git a/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs b/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs index a05b85399..04bdebfe4 100644 --- a/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs +++ b/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs @@ -1,4 +1,5 @@ use super::{coercer::ParsingError, types::BamlValueWithFlags}; +use baml_types::Constraint; #[derive(Debug, Clone)] pub enum Flag { @@ -42,6 +43,9 @@ pub enum Flag { // X -> Object convertions. NoFields(Option), + + // Constraint results. + ConstraintResults(Vec<(Constraint, bool)>), } #[derive(Clone)] @@ -90,9 +94,18 @@ impl DeserializerConditions { Flag::NoFields(_) => None, Flag::UnionMatch(_idx, _) => None, Flag::DefaultButHadUnparseableValue(e) => Some(e.clone()), + Flag::ConstraintResults(_) => None, }) .collect::>() } + + pub fn constraint_results(&self) -> Vec<(Constraint, bool)> { + self.flags.iter().filter_map(|flag| match flag { + Flag::ConstraintResults(cs) => Some(cs.clone()), + _ => None, + }).flatten().collect() + } + } impl std::fmt::Debug for DeserializerConditions { @@ -229,6 +242,13 @@ impl std::fmt::Display for Flag { writeln!(f, "")?; } } + Flag::ConstraintResults(cs) => { + for (Constraint{ label, level, expression }, succeeded) in cs.iter() { + let msg = label.as_ref().unwrap_or(&expression.0); + let f_result = if *succeeded { "Succeeded" } else { "Failed" }; + writeln!(f, "{level:?} {msg} {f_result}")?; + } + } } Ok(()) } diff --git a/engine/baml-lib/jsonish/src/deserializer/score.rs b/engine/baml-lib/jsonish/src/deserializer/score.rs index cba62ce0c..bf25dc39b 100644 --- a/engine/baml-lib/jsonish/src/deserializer/score.rs +++ 
b/engine/baml-lib/jsonish/src/deserializer/score.rs @@ -1,3 +1,5 @@ +use baml_types::{Constraint, ConstraintLevel}; + use super::{ deserialize_flags::{DeserializerConditions, Flag}, types::{BamlValueWithFlags, ValueWithFlags}, @@ -62,6 +64,18 @@ impl WithScore for Flag { Flag::StringToChar(_) => 1, Flag::FloatToInt(_) => 1, Flag::NoFields(_) => 1, + Flag::ConstraintResults(cs) => { + cs + .iter() + .map(|(Constraint{ level,.. }, succeeded)| + if *succeeded { 0 } else { + match level { + ConstraintLevel::Check => 5, + ConstraintLevel::Assert => 50, + } + }) + .sum() + } } } } diff --git a/engine/baml-lib/jsonish/src/deserializer/types.rs b/engine/baml-lib/jsonish/src/deserializer/types.rs index e1dc4681f..b82657e0a 100644 --- a/engine/baml-lib/jsonish/src/deserializer/types.rs +++ b/engine/baml-lib/jsonish/src/deserializer/types.rs @@ -1,6 +1,6 @@ use std::collections::HashSet; -use baml_types::{BamlMap, BamlMedia, BamlValue}; +use baml_types::{BamlMap, BamlMedia, BamlValue, BamlValueWithMeta, Constraint}; use serde_json::json; use strsim::jaro; @@ -229,7 +229,7 @@ impl BamlValueWithFlags { #[derive(Debug, Clone)] pub struct ValueWithFlags { - value: T, + pub value: T, pub(super) flags: DeserializerConditions, } @@ -440,3 +440,33 @@ impl std::fmt::Display for BamlValueWithFlags { Ok(()) } } + +impl From for BamlValueWithMeta> { + fn from(baml_value: BamlValueWithFlags) -> Self { + use BamlValueWithFlags::*; + let c = baml_value.conditions().constraint_results(); + match baml_value { + String(ValueWithFlags { value, .. }) => BamlValueWithMeta::String(value, c), + Int(ValueWithFlags { value, .. }) => BamlValueWithMeta::Int(value, c), + Float(ValueWithFlags { value, .. }) => BamlValueWithMeta::Float(value, c), + Bool(ValueWithFlags { value, .. 
}) => BamlValueWithMeta::Bool(value, c), + Map(_, values) => BamlValueWithMeta::Map( + values.into_iter().map(|(k, v)| (k, v.1.into())).collect(), + c, + ), // TODO: (Greg) I discard the DeserializerConditions tupled up with the value of the BamlMap. I'm not sure why BamlMap value is (DeserializerConditions, BamlValueWithFlags) in the first place. + List(_, values) => { + BamlValueWithMeta::List(values.into_iter().map(|v| v.into()).collect(), c) + } + Media(ValueWithFlags { value, .. }) => BamlValueWithMeta::Media(value, c), + Enum(enum_name, ValueWithFlags { value, .. }) => { + BamlValueWithMeta::Enum(enum_name, value, c) + } + Class(class_name, _, fields) => BamlValueWithMeta::Class( + class_name, + fields.into_iter().map(|(k, v)| (k, v.into())).collect(), + c, + ), + Null(_) => BamlValueWithMeta::Null(c), + } + } +} diff --git a/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_parse_state.rs b/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_parse_state.rs index 19a7ed99d..ab7a5058c 100644 --- a/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_parse_state.rs +++ b/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_parse_state.rs @@ -215,7 +215,7 @@ impl JsonParseState { log::debug!("Closing due to: new key after space + comma"); return Some(idx); } - x => { + _x => { break; } } diff --git a/engine/baml-lib/jsonish/src/lib.rs b/engine/baml-lib/jsonish/src/lib.rs index e89408c68..8325a01a5 100644 --- a/engine/baml-lib/jsonish/src/lib.rs +++ b/engine/baml-lib/jsonish/src/lib.rs @@ -2,7 +2,7 @@ mod tests; use anyhow::Result; -mod deserializer; +pub mod deserializer; mod jsonish; use baml_types::FieldType; diff --git a/engine/baml-lib/jsonish/src/tests/macros.rs b/engine/baml-lib/jsonish/src/tests/macros.rs index 2a8703437..7c0dd8281 100644 --- a/engine/baml-lib/jsonish/src/tests/macros.rs +++ b/engine/baml-lib/jsonish/src/tests/macros.rs @@ -16,6 +16,12 @@ macro_rules! 
test_failing_deserializer { }; } +/// Arguments: +/// name: name of test function to generate. +/// file_content: a BAML schema. +/// raw_string: an example payload coming from an LLM to parse. +/// target_type: The type to try to parse raw_string into. +/// json: The expected JSON encoding that the parser should return. macro_rules! test_deserializer { ($name:ident, $file_content:expr, $raw_string:expr, $target_type:expr, $($json:tt)+) => { #[test_log::test] @@ -45,6 +51,25 @@ macro_rules! test_deserializer { }; } +macro_rules! test_deserializer_with_expected_score { + ($name:ident, $file_content:expr, $raw_string:expr, $target_type:expr, $target_score:expr) => { + #[test_log::test] + fn $name() { + let ir = load_test_ir($file_content); + let target = render_output_format(&ir, &$target_type, &Default::default()).unwrap(); + + let result = from_str(&target, &$target_type, $raw_string, false); + + assert!(result.is_ok(), "Failed to parse: {:?}", result); + + let value = result.unwrap(); + dbg!(&value); + log::trace!("Score: {}", value.score()); + assert_eq!(value.score(), $target_score); + } + }; +} + macro_rules! 
test_partial_deserializer { ($name:ident, $file_content:expr, $raw_string:expr, $target_type:expr, $($json:tt)+) => { #[test_log::test] diff --git a/engine/baml-lib/jsonish/src/tests/mod.rs b/engine/baml-lib/jsonish/src/tests/mod.rs index 8b39002c6..e990a9cd4 100644 --- a/engine/baml-lib/jsonish/src/tests/mod.rs +++ b/engine/baml-lib/jsonish/src/tests/mod.rs @@ -6,6 +6,7 @@ pub mod macros; mod test_basics; mod test_class; +mod test_constraints; mod test_enum; mod test_lists; mod test_literals; @@ -18,7 +19,7 @@ use std::{ path::PathBuf, }; -use baml_types::BamlValue; +use baml_types::{BamlValue, Constraint, ConstraintLevel, JinjaExpression}; use internal_baml_core::{ internal_baml_diagnostics::SourceFile, ir::{repr::IntermediateRepr, ClassWalker, EnumWalker, FieldType, IRHelper, TypeValue}, @@ -105,20 +106,53 @@ fn find_enum_value( Ok(Some((name, desc))) } +/// Eliminate the `FieldType::Constrained` variant by searching for it, and stripping +/// it off of its base type, returning a tuple of the base type and any constraints found +/// (if called on an argument that is not Constrained, the returned constraints Vec is +/// empty). +/// +/// If the function encounters directly nested Constrained types, +/// (i.e. `FieldType::Constrained { base: FieldType::Constrained { .. }, .. } `) +/// then the constraints of the two levels will be combined into a single vector. +/// So, we always return a base type that is not FieldType::Constrained. +fn distribute_constraints(field_type: &FieldType) -> (&FieldType, Vec) { + + match field_type { + // Check the first level to see if it's constrained. + FieldType::Constrained { base, constraints } => { + match base.as_ref() { + // If so, we must check the second level to see if we need to combine + // constraints across levels. + // The recursion here means that arbitrarily nested `FieldType::Constrained`s + // will be collapsed before the function returns. 
+ FieldType::Constrained{..} => { + let (sub_base, sub_constraints) = distribute_constraints(base); + let combined_constraints = vec![constraints.clone(), sub_constraints].into_iter().flatten().collect(); + (sub_base, combined_constraints) + }, + _ => (base, constraints.clone()), + } + }, + _ => (field_type, Vec::new()), + } +} + +// TODO: (Greg) Is the use of `String` as a hash key safe? Is there some way to +// get a collision that results in some type not getting put onto the stack? fn relevant_data_models<'a>( ir: &'a IntermediateRepr, output: &'a FieldType, env_values: &HashMap, ) -> Result<(Vec, Vec)> { - let mut checked_types = HashSet::new(); + let mut checked_types: HashSet = HashSet::new(); let mut enums = Vec::new(); - let mut classes = Vec::new(); + let mut classes: Vec = Vec::new(); let mut start: Vec = vec![output.clone()]; while !start.is_empty() { let output = start.pop().unwrap(); - match &output { - FieldType::Enum(enm) => { + match distribute_constraints(&output) { + (FieldType::Enum(enm), constraints) => { if checked_types.insert(output.to_string()) { let walker = ir.find_enum(enm); @@ -140,15 +174,16 @@ fn relevant_data_models<'a>( enums.push(Enum { name: Name::new_with_alias(enm.to_string(), walker?.alias(env_values)?), values, + constraints, }); } } - FieldType::List(inner) | FieldType::Optional(inner) => { + (FieldType::List(inner), _constraints) | (FieldType::Optional(inner), _constraints) => { if !checked_types.contains(&inner.to_string()) { start.push(inner.as_ref().clone()); } } - FieldType::Map(k, v) => { + (FieldType::Map(k, v), _constraints) => { if checked_types.insert(output.to_string()) { if !checked_types.contains(&k.to_string()) { start.push(k.as_ref().clone()); @@ -158,7 +193,7 @@ fn relevant_data_models<'a>( } } } - FieldType::Tuple(options) | FieldType::Union(options) => { + (FieldType::Tuple(options), _constraints) | (FieldType::Union(options), _constraints) => { if checked_types.insert((&output).to_string()) { for inner in 
options { if !checked_types.contains(&inner.to_string()) { @@ -167,7 +202,7 @@ fn relevant_data_models<'a>( } } } - FieldType::Class(cls) => { + (FieldType::Class(cls), constraints) => { if checked_types.insert(output.to_string()) { let walker = ir.find_class(&cls); @@ -192,11 +227,15 @@ fn relevant_data_models<'a>( classes.push(Class { name: Name::new_with_alias(cls.to_string(), walker?.alias(env_values)?), fields, + constraints, }); } } - FieldType::Primitive(_) => {} - FieldType::Literal(_) => {} + (FieldType::Literal(_), _) => {} + (FieldType::Primitive(_), _constraints) => {} + (FieldType::Constrained{..}, _) => { + unreachable!("It is guaranteed that a call to distribute_constraints will not return FieldType::Constrained") + } } } diff --git a/engine/baml-lib/jsonish/src/tests/test_constraints.rs b/engine/baml-lib/jsonish/src/tests/test_constraints.rs new file mode 100644 index 000000000..8c9e52e5b --- /dev/null +++ b/engine/baml-lib/jsonish/src/tests/test_constraints.rs @@ -0,0 +1,129 @@ +use super::*; + +const CLASS_FOO_INT_STRING: &str = r#" +class Foo { + age int + @check({{this < 10}}, "age less than 10") + @check({{this < 20}}, "age less than 20") + @assert({{this >= 0}}, "nonnegative") + name string + @assert({{this|length > 0}}, "Nonempty name") +} +"#; + +test_deserializer_with_expected_score!( + test_class_failing_one_check, + CLASS_FOO_INT_STRING, + r#"{"age": 11, "name": "Greg"}"#, + FieldType::Class("Foo".to_string()), + 5 +); + +test_deserializer_with_expected_score!( + test_class_failing_two_checks, + CLASS_FOO_INT_STRING, + r#"{"age": 21, "name": "Grog"}"#, + FieldType::Class("Foo".to_string()), + 10 +); + +test_deserializer_with_expected_score!( + test_class_failing_assert, + CLASS_FOO_INT_STRING, + r#"{"age": -1, "name": "Sam"}"#, + FieldType::Class("Foo".to_string()), + 50 +); + +test_deserializer_with_expected_score!( + test_class_multiple_failing_asserts, + CLASS_FOO_INT_STRING, + r#"{"age": -1, "name": ""}"#, + 
FieldType::Class("Foo".to_string()), + 100 +); + +const UNION_WITH_CHECKS: &str = r#" +class Thing1 { + bar int @check({{ this < 10 }}, "bar small") +} + +class Thing2 { + bar int @check({{ this > 20 }}, "bar big") +} + +class Either { + bar Thing1 | Thing2 + things (Thing1 | Thing2)[] @assert({{this|length < 4}}, "list not too long") +} +"#; + +test_deserializer_with_expected_score!( + test_union_decision_from_check, + UNION_WITH_CHECKS, + r#"{"bar": 5, "things":[]}"#, + FieldType::Class("Either".to_string()), + 2 +); + +test_deserializer_with_expected_score!( + test_union_decision_from_check_no_good_answer, + UNION_WITH_CHECKS, + r#"{"bar": 15, "things":[]}"#, + FieldType::Class("Either".to_string()), + 7 +); + +test_deserializer_with_expected_score!( + test_union_decision_in_list, + UNION_WITH_CHECKS, + r#"{"bar": 1, "things":[{"bar": 25}, {"bar": 35}, {"bar": 15}, {"bar": 15}]}"#, + FieldType::Class("Either".to_string()), + 62 +); + +const MAP_WITH_CHECKS: &str = r#" +class Foo { + foo map @check({{ this["hello"] == 10 }}, "hello is 10") +} +"#; + +test_deserializer_with_expected_score!( + test_map_with_check, + MAP_WITH_CHECKS, + r#"{"foo": {"hello": 10, "there":13}}"#, + FieldType::Class("Foo".to_string()), + 1 +); + +test_deserializer_with_expected_score!( + test_map_with_check_fails, + MAP_WITH_CHECKS, + r#"{"foo": {"hello": 11, "there":13}}"#, + FieldType::Class("Foo".to_string()), + 6 +); + +const NESTED_CLASS_CONSTRAINTS: &str = r#" +class Outer { + inner Inner +} + +class Inner { + value int @check({{ this < 10 }}) +} +"#; + +test_deserializer_with_expected_score!( + test_nested_class_constraints, + NESTED_CLASS_CONSTRAINTS, + r#"{"inner": {"value": 15}}"#, + FieldType::Class("Outer".to_string()), + 5 +); + +const MISSPELLED_CONSTRAINT: &str = r#" +class Foo { + foo int @description("hi") @check({{this == 1}},"hi") +} +"#; diff --git a/engine/baml-lib/jsonish/src/tests/test_unions.rs b/engine/baml-lib/jsonish/src/tests/test_unions.rs index 
9d2784b3d..40316411d 100644 --- a/engine/baml-lib/jsonish/src/tests/test_unions.rs +++ b/engine/baml-lib/jsonish/src/tests/test_unions.rs @@ -247,3 +247,35 @@ test_deserializer!( ] } ); + +const CONTACT_INFO: &str = r#" +class PhoneNumber { + value string @check({{this|regex_match("\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}")}}, "valid_phone_number") + foo int? // A nullable marker indicating PhoneNumber was chosen. +} + +class EmailAddress { + value string @check({{this|regex_match("^[_]*([a-z0-9]+(\.|_*)?)+@([a-z][a-z0-9-]+(\.|-*\.))+[a-z]{2,6}$")}}, "valid_email") + bar int? // A nullable marker indicating EmailAddress was chosen. +} + +class ContactInfo { + primary PhoneNumber | EmailAddress +} +"#; + +test_deserializer!( + test_check1, + CONTACT_INFO, + r#"{"primary": {"value": "908-797-8281"}}"#, + FieldType::Class("ContactInfo".to_string()), + {"primary": {"value": "908-797-8281", "foo": null}} +); + +test_deserializer!( + test_check2, + CONTACT_INFO, + r#"{"primary": {"value": "help@boundaryml.com"}}"#, + FieldType::Class("ContactInfo".to_string()), + {"primary": {"value": "help@boundaryml.com", "bar": null}} +); diff --git a/engine/baml-lib/parser-database/src/attributes/constraint.rs b/engine/baml-lib/parser-database/src/attributes/constraint.rs new file mode 100644 index 000000000..993944cf7 --- /dev/null +++ b/engine/baml-lib/parser-database/src/attributes/constraint.rs @@ -0,0 +1,38 @@ +use baml_types::{Constraint, ConstraintLevel}; +use internal_baml_schema_ast::ast::Expression; + +use crate::{context::Context, types::Attributes}; + +pub(super) fn visit_constraint_attributes( + attribute_name: String, + attributes: &mut Attributes, + ctx: &mut Context<'_>, +) { + let expression_arg = ctx.visit_default_arg_with_idx("expression").map_err(|err| { + ctx.push_error(err); + }); + let label = ctx.visit_default_arg_with_idx("name"); + let label = match label { + Ok((_, Expression::StringValue(descr, _))) => Some(descr.clone()), + _ => None, + }; + match 
expression_arg { + Ok((_, Expression::JinjaExpressionValue(expression, _))) => { + let level = match attribute_name.as_str() { + "assert" => ConstraintLevel::Assert, + "check" => ConstraintLevel::Check, + _ => { + panic!("Internal error: Only \"assert\" and \"check\" are valid attribute names in this context."); + } + }; + attributes.constraints.push(Constraint { + level, + expression: expression.clone(), + label, + }); + } + _ => panic!( + "The impossible happened: Reached arguments that are ruled out by the tokenizer." + ), + } +} diff --git a/engine/baml-lib/parser-database/src/attributes/mod.rs b/engine/baml-lib/parser-database/src/attributes/mod.rs index 0b0efc708..ec26fec8b 100644 --- a/engine/baml-lib/parser-database/src/attributes/mod.rs +++ b/engine/baml-lib/parser-database/src/attributes/mod.rs @@ -1,10 +1,12 @@ use internal_baml_schema_ast::ast::{Top, TopId, TypeExpId, TypeExpressionBlock}; mod alias; +mod constraint; mod description; mod to_string_attribute; use crate::interner::StringId; use crate::{context::Context, types::ClassAttributes, types::EnumAttributes}; +use baml_types::Constraint; use internal_baml_schema_ast::ast::{Expression, SubType}; /// @@ -21,6 +23,9 @@ pub struct Attributes { /// Whether the node should be skipped during prompt rendering and parsing. pub skip: Option, + + /// @check and @assert attributes attached to the node. 
+ pub constraints: Vec, } impl Attributes { @@ -63,7 +68,6 @@ impl Attributes { pub fn set_skip(&mut self) { self.skip.replace(true); } - } pub(super) fn resolve_attributes(ctx: &mut Context<'_>) { for top in ctx.ast.iter_tops() { @@ -90,7 +94,7 @@ fn resolve_type_exp_block_attributes<'db>( let mut enum_attributes = EnumAttributes::default(); for (value_idx, _value) in ast_typexpr.iter_fields() { - ctx.visit_attributes((type_id, value_idx).into()); + ctx.assert_all_attributes_processed((type_id, value_idx).into()); if let Some(attrs) = to_string_attribute::visit(ctx, false) { enum_attributes.value_serilizers.insert(value_idx, attrs); } @@ -98,7 +102,7 @@ fn resolve_type_exp_block_attributes<'db>( } // Now validate the enum attributes. - ctx.visit_attributes(type_id.into()); + ctx.assert_all_attributes_processed(type_id.into()); enum_attributes.serilizer = to_string_attribute::visit(ctx, true); ctx.validate_visited_attributes(); @@ -108,7 +112,7 @@ fn resolve_type_exp_block_attributes<'db>( let mut class_attributes = ClassAttributes::default(); for (field_idx, _field) in ast_typexpr.iter_fields() { - ctx.visit_attributes((type_id, field_idx).into()); + ctx.assert_all_attributes_processed((type_id, field_idx).into()); if let Some(attrs) = to_string_attribute::visit(ctx, false) { class_attributes.field_serilizers.insert(field_idx, attrs); } @@ -116,7 +120,7 @@ fn resolve_type_exp_block_attributes<'db>( } // Now validate the class attributes. 
- ctx.visit_attributes(type_id.into()); + ctx.assert_all_attributes_processed(type_id.into()); class_attributes.serilizer = to_string_attribute::visit(ctx, true); ctx.validate_visited_attributes(); diff --git a/engine/baml-lib/parser-database/src/attributes/to_string_attribute.rs b/engine/baml-lib/parser-database/src/attributes/to_string_attribute.rs index c9fa3d4b7..70567efa2 100644 --- a/engine/baml-lib/parser-database/src/attributes/to_string_attribute.rs +++ b/engine/baml-lib/parser-database/src/attributes/to_string_attribute.rs @@ -1,6 +1,7 @@ use crate::{context::Context, types::Attributes}; use super::alias::visit_alias_attribute; +use super::constraint::visit_constraint_attributes; use super::description::visit_description_attribute; pub(super) fn visit(ctx: &mut Context<'_>, as_block: bool) -> Option { @@ -26,6 +27,13 @@ pub(super) fn visit(ctx: &mut Context<'_>, as_block: bool) -> Option ctx.validate_visited_arguments(); } + if let Some(attribute_name) = ctx.visit_repeated_attr_from_names(&["assert", "check"]) { + // Record the @assert/@check constraint on `attributes`. + visit_constraint_attributes(attribute_name, &mut attributes, ctx); + modified = true; + ctx.validate_visited_arguments(); + } + if as_block { if ctx.visit_optional_single_attr("dynamic") { attributes.set_dynamic_type(); diff --git a/engine/baml-lib/parser-database/src/context/attributes.rs b/engine/baml-lib/parser-database/src/context/attributes.rs index 4ddecc8d4..3d725f440 100644 --- a/engine/baml-lib/parser-database/src/context/attributes.rs +++ b/engine/baml-lib/parser-database/src/context/attributes.rs @@ -10,7 +10,7 @@ pub(super) struct AttributesValidationState { /// The attribute being validated. 
pub(super) attribute: Option, - pub(super) args: VecDeque, // the _remaining_ arguments of `attribute` + pub(super) args: VecDeque, // the _remaining_ arguments of `attribute` } impl AttributesValidationState { diff --git a/engine/baml-lib/parser-database/src/context/mod.rs b/engine/baml-lib/parser-database/src/context/mod.rs index 04faf7832..5b7a95ae1 100644 --- a/engine/baml-lib/parser-database/src/context/mod.rs +++ b/engine/baml-lib/parser-database/src/context/mod.rs @@ -1,5 +1,5 @@ use internal_baml_diagnostics::DatamodelWarning; -use internal_baml_schema_ast::ast::ArguementId; +use internal_baml_schema_ast::ast::ArgumentId; use crate::{ ast, ast::WithName, interner::StringInterner, names::Names, types::Types, DatamodelError, @@ -83,10 +83,13 @@ impl<'db> Context<'db> { /// /// - When you are done validating an attribute, you must call `discard_arguments()` or /// `validate_visited_arguments()`. Otherwise, Context will helpfully panic. - pub(super) fn visit_attributes(&mut self, ast_attributes: ast::AttributeContainer) { + pub(super) fn assert_all_attributes_processed( + &mut self, + ast_attributes: ast::AttributeContainer, + ) { if self.attributes.attributes.is_some() || !self.attributes.unused_attributes.is_empty() { panic!( - "`ctx.visit_attributes() called with {:?} while the Context is still validating previous attribute set on {:?}`", + "`ctx.assert_all_attributes_processed() called with {:?} while the Context is still validating previous attribute set on {:?}`", ast_attributes, self.attributes.attributes ); @@ -98,7 +101,7 @@ impl<'db> Context<'db> { /// Extract an attribute that can occur zero or more times. Example: @@index on models. /// /// Returns `true` as long as a next attribute is found. 
- pub(crate) fn visit_repeated_attr(&mut self, name: &'static str) -> bool { + pub(crate) fn _visit_repeated_attr(&mut self, name: &'static str) -> bool { let mut has_valid_attribute = false; while !has_valid_attribute { @@ -117,6 +120,37 @@ impl<'db> Context<'db> { has_valid_attribute } + /// Extract an attribute that can occur zero or more times. Example: @assert on types. + /// Argument is a list of names that are all valid for this attribute. + /// + /// Returns Some(name_match) if name_match is the attribute name and is in the + /// `names` argument. + pub(crate) fn visit_repeated_attr_from_names( + &mut self, + names: &'static [&'static str], + ) -> Option<String> { + let mut has_valid_attribute = false; + let mut matching_name: Option<String> = None; + + // Take the first still-unused attribute whose name is in `names`; keep + // going until `set_attribute` accepts one or no candidate remains. + while !has_valid_attribute { + let first_attr = iter_attributes(self.attributes.attributes.as_ref(), self.ast) + .filter(|(_, attr)| names.contains(&attr.name.name())) + .find(|(attr_id, _)| self.attributes.unused_attributes.contains(attr_id)); + let (attr_id, attr) = if let Some(first_attr) = first_attr { + first_attr + } else { + break; + }; + self.attributes.unused_attributes.remove(&attr_id); + has_valid_attribute = self.set_attribute(attr_id, attr); + matching_name = Some(attr.name.name().to_string()); + } + + matching_name + } + /// Validate an _optional_ attribute that should occur only once. Returns whether the attribute /// is defined. #[must_use] @@ -155,7 +189,7 @@ impl<'db> Context<'db> { pub(crate) fn visit_default_arg_with_idx( &mut self, name: &str, - ) -> Result<(ArguementId, &'db ast::Expression), DatamodelError> { + ) -> Result<(ArgumentId, &'db ast::Expression), DatamodelError> { match self.attributes.args.pop_front() { Some(arg_idx) => { let arg = self.arg_at(arg_idx); @@ -186,7 +220,7 @@ impl<'db> Context<'db> { self.discard_arguments(); } - /// Counterpart to visit_attributes(). 
This must be called at the end of the validation of the + /// Counterpart to assert_all_attributes_processed(). This must be called at the end of the validation of the /// attribute set. The Drop impl will helpfully panic otherwise. pub(crate) fn validate_visited_attributes(&mut self) { if !self.attributes.args.is_empty() || self.attributes.attribute.is_some() { @@ -216,7 +250,7 @@ impl<'db> Context<'db> { &self.ast[id] } - fn arg_at(&self, idx: ArguementId) -> &'db ast::Argument { + fn arg_at(&self, idx: ArgumentId) -> &'db ast::Argument { &self.current_attribute().arguments[idx] } diff --git a/engine/baml-lib/parser-database/src/lib.rs b/engine/baml-lib/parser-database/src/lib.rs index 5fe20ef7a..38cdeb663 100644 --- a/engine/baml-lib/parser-database/src/lib.rs +++ b/engine/baml-lib/parser-database/src/lib.rs @@ -47,7 +47,7 @@ pub use types::{ }; use self::{context::Context, interner::StringId, types::Types}; -use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Diagnostics}; +use internal_baml_diagnostics::{DatamodelError, Diagnostics}; use names::Names; /// ParserDatabase is a container for a Schema AST, together with information diff --git a/engine/baml-lib/schema-ast/src/ast.rs b/engine/baml-lib/schema-ast/src/ast.rs index df322df8f..1c31fe8b3 100644 --- a/engine/baml-lib/schema-ast/src/ast.rs +++ b/engine/baml-lib/schema-ast/src/ast.rs @@ -18,7 +18,7 @@ mod type_expression_block; mod value_expression_block; pub(crate) use self::comment::Comment; -pub use argument::{ArguementId, Argument, ArgumentsList}; +pub use argument::{ArgumentId, Argument, ArgumentsList}; pub use attribute::{Attribute, AttributeContainer, AttributeId}; pub use config::ConfigBlockProperty; pub use expression::{Expression, RawString}; @@ -32,7 +32,7 @@ pub use top::Top; pub use traits::{WithAttributes, WithDocumentation, WithIdentifier, WithName, WithSpan}; pub use type_expression_block::{FieldId, SubType, TypeExpressionBlock}; pub use value_expression_block::{ - ArgumentId, 
BlockArg, BlockArgs, ValueExprBlock, ValueExprBlockType, + BlockArg, BlockArgs, ValueExprBlock, ValueExprBlockType, }; /// AST representation of a prisma schema. diff --git a/engine/baml-lib/schema-ast/src/ast/argument.rs b/engine/baml-lib/schema-ast/src/ast/argument.rs index 265ba4985..5fc04d019 100644 --- a/engine/baml-lib/schema-ast/src/ast/argument.rs +++ b/engine/baml-lib/schema-ast/src/ast/argument.rs @@ -4,19 +4,19 @@ use std::fmt::{Display, Formatter}; /// An opaque identifier for a value in an AST enum. Use the /// `r#enum[enum_value_id]` syntax to resolve the id to an `ast::EnumValue`. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct ArguementId(pub u32); +pub struct ArgumentId(pub u32); -impl ArguementId { +impl ArgumentId { /// Used for range bounds when iterating over BTreeMaps. - pub const MIN: ArguementId = ArguementId(0); + pub const MIN: ArgumentId = ArgumentId(0); /// Used for range bounds when iterating over BTreeMaps. - pub const MAX: ArguementId = ArguementId(u32::MAX); + pub const MAX: ArgumentId = ArgumentId(u32::MAX); } -impl std::ops::Index for ArgumentsList { +impl std::ops::Index for ArgumentsList { type Output = Argument; - fn index(&self, index: ArguementId) -> &Self::Output { + fn index(&self, index: ArgumentId) -> &Self::Output { &self.arguments[index.0 as usize] } } @@ -34,11 +34,11 @@ pub struct ArgumentsList { } impl ArgumentsList { - pub fn iter(&self) -> impl ExactSizeIterator { + pub fn iter(&self) -> impl ExactSizeIterator { self.arguments .iter() .enumerate() - .map(|(idx, field)| (ArguementId(idx as u32), field)) + .map(|(idx, field)| (ArgumentId(idx as u32), field)) } } diff --git a/engine/baml-lib/schema-ast/src/ast/attribute.rs b/engine/baml-lib/schema-ast/src/ast/attribute.rs index 896d92ee2..0ccd549d1 100644 --- a/engine/baml-lib/schema-ast/src/ast/attribute.rs +++ b/engine/baml-lib/schema-ast/src/ast/attribute.rs @@ -1,4 +1,4 @@ -use super::{ArguementId, ArgumentsList, Identifier, Span, 
WithIdentifier, WithSpan}; +use super::{ArgumentId, ArgumentsList, Identifier, Span, WithIdentifier, WithSpan}; use std::ops::Index; /// An attribute (following `@` or `@@``) on a model, model field, enum, enum value or composite @@ -29,7 +29,7 @@ pub struct Attribute { impl Attribute { /// Try to find the argument and return its span. - pub fn span_for_argument(&self, argument: ArguementId) -> Span { + pub fn span_for_argument(&self, argument: ArgumentId) -> Span { self.arguments[argument].span.clone() } diff --git a/engine/baml-lib/schema-ast/src/ast/expression.rs b/engine/baml-lib/schema-ast/src/ast/expression.rs index 0c4892e92..b7677beb2 100644 --- a/engine/baml-lib/schema-ast/src/ast/expression.rs +++ b/engine/baml-lib/schema-ast/src/ast/expression.rs @@ -4,6 +4,7 @@ use crate::ast::Span; use std::fmt; use super::{Identifier, WithName, WithSpan}; +use baml_types::JinjaExpression; #[derive(Debug, Clone)] pub struct RawString { @@ -159,6 +160,8 @@ pub enum Expression { Array(Vec, Span), /// A mapping function. Map(Vec<(Expression, Expression)>, Span), + /// A JinjaExpression. e.g. "this|length > 5". + JinjaExpressionValue(JinjaExpression, Span), } impl fmt::Display for Expression { @@ -171,6 +174,7 @@ impl fmt::Display for Expression { Expression::RawStringValue(val, ..) => { write!(f, "{}", crate::string_literal(val.value())) } + Expression::JinjaExpressionValue(val,..) 
=> fmt::Display::fmt(val, f), Expression::Array(vals, _) => { let vals = vals .iter() @@ -293,6 +297,7 @@ impl Expression { Self::NumericValue(_, span) => span, Self::StringValue(_, span) => span, Self::RawStringValue(r) => r.span(), + Self::JinjaExpressionValue(_,span) => span, Self::Identifier(id) => id.span(), Self::Map(_, span) => span, Self::Array(_, span) => span, @@ -310,6 +315,7 @@ impl Expression { Expression::NumericValue(_, _) => "numeric", Expression::StringValue(_, _) => "string", Expression::RawStringValue(_) => "raw_string", + Expression::JinjaExpressionValue(_, _) => "jinja_expression", Expression::Identifier(id) => match id { Identifier::String(_, _) => "string", Identifier::Local(_, _) => "local_type", @@ -354,6 +360,8 @@ impl Expression { (StringValue(_,_), _) => panic!("Types do not match: {:?} and {:?}", self, other), (RawStringValue(s1), RawStringValue(s2)) => s1.assert_eq_up_to_span(s2), (RawStringValue(_), _) => panic!("Types do not match: {:?} and {:?}", self, other), + (JinjaExpressionValue(j1, _), JinjaExpressionValue(j2, _)) => assert_eq!(j1, j2), + (JinjaExpressionValue(_,_), _) => panic!("Types do not match: {:?} and {:?}", self, other), (Array(xs,_), Array(ys,_)) => { assert_eq!(xs.len(), ys.len()); xs.iter().zip(ys).for_each(|(x,y)| { x.assert_eq_up_to_span(y); }) diff --git a/engine/baml-lib/schema-ast/src/ast/field.rs b/engine/baml-lib/schema-ast/src/ast/field.rs index 6ff31fd98..c824d27ef 100644 --- a/engine/baml-lib/schema-ast/src/ast/field.rs +++ b/engine/baml-lib/schema-ast/src/ast/field.rs @@ -2,8 +2,7 @@ use baml_types::{LiteralValue, TypeValue}; use internal_baml_diagnostics::DatamodelError; use super::{ - traits::WithAttributes, Attribute, Comment, Identifier, Span, WithDocumentation, - WithIdentifier, WithName, WithSpan, + traits::WithAttributes, Attribute, Comment, Identifier, SchemaAst, Span, WithDocumentation, WithIdentifier, WithName, WithSpan }; /// A field definition in a model or a composite type. 
@@ -257,6 +256,10 @@ impl FieldType { } } + pub fn has_checks(&self) -> bool { + self.attributes().iter().any(|Attribute{name,..}| name.to_string().as_str() == "check") + } + pub fn assert_eq_up_to_span(&self, other: &Self) { use FieldType::*; diff --git a/engine/baml-lib/schema-ast/src/ast/type_expression_block.rs b/engine/baml-lib/schema-ast/src/ast/type_expression_block.rs index b6314d1b8..87d5147c9 100644 --- a/engine/baml-lib/schema-ast/src/ast/type_expression_block.rs +++ b/engine/baml-lib/schema-ast/src/ast/type_expression_block.rs @@ -30,21 +30,19 @@ pub enum SubType { Other(String), } -/// An enum declaration. Enumeration can either be in the database schema, or completely a Prisma level concept. -/// -/// PostgreSQL stores enums in a schema, while in MySQL the information is in -/// the table definition. On MongoDB the enumerations are handled in the Query -/// Engine. +/// A class or enum declaration. #[derive(Debug, Clone)] pub struct TypeExpressionBlock { - /// The name of the enum. + /// The name of the class or enum. /// /// ```ignore /// enum Foo { ... } /// ^^^ + /// class Bar { ... } + /// ^^^ /// ``` pub name: Identifier, - /// The values of the enum. + /// The values of the enum, or fields of the class. /// /// ```ignore /// enum Foo { diff --git a/engine/baml-lib/schema-ast/src/ast/value_expression_block.rs b/engine/baml-lib/schema-ast/src/ast/value_expression_block.rs index 62ca4b6ae..ccebc25fa 100644 --- a/engine/baml-lib/schema-ast/src/ast/value_expression_block.rs +++ b/engine/baml-lib/schema-ast/src/ast/value_expression_block.rs @@ -2,12 +2,9 @@ use super::{ traits::WithAttributes, Attribute, Comment, Expression, Field, FieldType, Identifier, Span, WithDocumentation, WithIdentifier, WithSpan, }; +use super::argument::ArgumentId; use std::fmt::Display; use std::fmt::Formatter; -/// An opaque identifier for a value in an AST enum. Use the -/// `r#enum[enum_value_id]` syntax to resolve the id to an `ast::EnumValue`. 
-#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct ArgumentId(pub u32); /// An opaque identifier for a field in an AST model. Use the /// `model[field_id]` syntax to resolve the id to an `ast::Field`. @@ -29,13 +26,6 @@ impl std::ops::Index for ValueExprBlock { } } -impl ArgumentId { - /// Used for range bounds when iterating over BTreeMaps. - pub const MIN: ArgumentId = ArgumentId(0); - /// Used for range bounds when iterating over BTreeMaps. - pub const MAX: ArgumentId = ArgumentId(u32::MAX); -} - impl std::ops::Index for BlockArgs { type Output = (Identifier, BlockArg); diff --git a/engine/baml-lib/schema-ast/src/parser/datamodel.pest b/engine/baml-lib/schema-ast/src/parser/datamodel.pest index f615316bc..3a906850b 100644 --- a/engine/baml-lib/schema-ast/src/parser/datamodel.pest +++ b/engine/baml-lib/schema-ast/src/parser/datamodel.pest @@ -94,7 +94,8 @@ map_entry = { (comment_block | empty_lines)* ~ map_key ~ (expression | ENTRY_CAT splitter = _{ ("," ~ NEWLINE?) | NEWLINE } map_expression = { "{" ~ empty_lines? ~ (map_entry ~ (splitter ~ map_entry)*)? ~ (comment_block | empty_lines)* ~ "}" } array_expression = { "[" ~ empty_lines? ~ ((expression | ARRAY_CATCH_ALL) ~ trailing_comment? ~ (splitter ~ (comment_block | empty_lines)* ~ (expression | ARRAY_CATCH_ALL) ~ trailing_comment?)*)? ~ (comment_block | empty_lines)* ~ splitter? 
~ "]" } -expression = { map_expression | array_expression | numeric_literal | string_literal | identifier } +jinja_expression = { "{{" ~ (!("}}" | "{{") ~ ANY)* ~ "}}" } +expression = { jinja_expression | map_expression | array_expression | numeric_literal | string_literal | identifier } ARRAY_CATCH_ALL = { !"]" ~ CATCH_ALL } ENTRY_CATCH_ALL = { field_attribute | BLOCK_LEVEL_CATCH_ALL } // ###################################### diff --git a/engine/baml-lib/schema-ast/src/parser/parse_attribute.rs b/engine/baml-lib/schema-ast/src/parser/parse_attribute.rs index b41793bfa..0f18f0e02 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_attribute.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_attribute.rs @@ -1,3 +1,6 @@ +use baml_types::ConstraintLevel; +use internal_baml_diagnostics::DatamodelError; + use super::{ helpers::{parsing_catch_all, Pair}, parse_identifier::parse_identifier, diff --git a/engine/baml-lib/schema-ast/src/parser/parse_expression.rs b/engine/baml-lib/schema-ast/src/parser/parse_expression.rs index 2a0edb3f0..caac48b55 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_expression.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_expression.rs @@ -3,6 +3,7 @@ use super::{ parse_identifier::parse_identifier, Rule, }; +use baml_types::JinjaExpression; use crate::{assert_correct_parser, ast::*, unreachable_rule}; use internal_baml_diagnostics::Diagnostics; @@ -17,6 +18,7 @@ pub(crate) fn parse_expression( Rule::string_literal => Some(parse_string_literal(first_child, diagnostics)), Rule::map_expression => Some(parse_map(first_child, diagnostics)), Rule::array_expression => Some(parse_array(first_child, diagnostics)), + Rule::jinja_expression => Some(parse_jinja_expression(first_child, diagnostics)), Rule::identifier => Some(Expression::Identifier(parse_identifier( first_child, @@ -245,10 +247,32 @@ fn unescape_string(val: &str) -> String { result } +/// Parse a `JinjaExpression` from raw source. 
Escape backslashes, +/// because we want the user's backslash intent to be preserved in +/// the string backing the `JinjaExpression`. In other words, control +/// sequences like `\n` are intended to be forwarded to the Jinja +/// processing engine, not to break a Jinja Expression into two lines, +/// therefor the backing string should be contain "\\n". +pub fn parse_jinja_expression(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Expression { + assert_correct_parser!(token, Rule::jinja_expression); + let mut inner_text = String::new(); + for c in token.as_str()[2..token.as_str().len() - 2].chars() { + match c { + // When encountering a single backslash, produce two backslashes. + '\\' => inner_text.push_str("\\\\"), + // Otherwise, just copy the character. + _ => inner_text.push(c), + } + } + Expression::JinjaExpressionValue(JinjaExpression(inner_text), diagnostics.span(token.as_span())) +} + #[cfg(test)] mod tests { + use super::*; use super::super::{BAMLParser, Rule}; - use pest::{consumes_to, parses_to}; + use pest::{Parser, parses_to, consumes_to}; + use internal_baml_diagnostics::{Diagnostics, SourceFile}; #[test] fn array_trailing_comma() { @@ -287,4 +311,24 @@ mod tests { ])] }; } + + #[test] + fn test_parse_jinja_expression() { + let input = "{{ 1 + 1 }}"; + let root_path = "test_file.baml"; + let source = SourceFile::new_static(root_path.into(), input); + let mut diagnostics = Diagnostics::new(root_path.into()); + diagnostics.set_source(&source); + + let pair = BAMLParser::parse(Rule::jinja_expression, input) + .unwrap() + .next() + .unwrap(); + let expr = parse_jinja_expression(pair, &mut diagnostics); + match expr { + Expression::JinjaExpressionValue(JinjaExpression(s), _) => assert_eq!(s, " 1 + 1 "), + _ => panic!("Expected JinjaExpression, got {:?}", expr), + } + } + } diff --git a/engine/baml-lib/schema-ast/src/parser/parse_field.rs b/engine/baml-lib/schema-ast/src/parser/parse_field.rs index 2644eb562..eabff4c3e 100644 --- 
a/engine/baml-lib/schema-ast/src/parser/parse_field.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_field.rs @@ -60,6 +60,18 @@ pub(crate) fn parse_value_expr( } } +fn reassociate_type_attributes( + field_attributes: &mut Vec, + field_type: &mut FieldType, +) { + let mut all_attrs = field_type.attributes().to_owned(); + all_attrs.append(field_attributes); + let (attrs_for_type, attrs_for_field): (Vec, Vec) = + all_attrs.into_iter().partition(|attr| ["assert", "check"].contains(&attr.name())); + field_type.set_attributes(attrs_for_type); + *field_attributes = attrs_for_field; +} + pub(crate) fn parse_type_expr( model_name: &Option, container_type: &'static str, @@ -90,11 +102,18 @@ pub(crate) fn parse_type_expr( } } + // Strip certain attributes from the field and attach them to the type. + match field_type.as_mut() { + None => {}, + Some(ft) => reassociate_type_attributes(&mut field_attributes, ft), + } + match (name, &field_type) { + // Class field. (Some(name), Some(field_type)) => Ok(Field { expr: Some(field_type.clone()), name, - attributes: field_type.clone().attributes().to_vec(), + attributes: field_attributes, documentation: comment, span: diagnostics.span(pair_span), }), diff --git a/engine/baml-lib/schema-ast/src/parser/parse_schema.rs b/engine/baml-lib/schema-ast/src/parser/parse_schema.rs index 79fcd479c..c8e00dd6c 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_schema.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_schema.rs @@ -178,6 +178,7 @@ mod tests { let input = r#" class MyClass { myProperty string[] @description("This is a description") @alias("MP") + prop2 string @description({{ "a " + "b" }}) } "#; @@ -192,11 +193,13 @@ mod tests { assert_eq!(schema_ast.tops.len(), 1); match &schema_ast.tops[0] { - Top::Class(model) => { - assert_eq!(model.name.name(), "MyClass"); - assert_eq!(model.fields.len(), 1); - assert_eq!(model.fields[0].name.name(), "myProperty"); - assert_eq!(model.fields[0].attributes.len(), 2) + 
Top::Class(TypeExpressionBlock { name, fields, .. }) => { + assert_eq!(name.name(), "MyClass"); + assert_eq!(fields.len(), 2); + assert_eq!(fields[0].name.name(), "myProperty"); + assert_eq!(fields[1].name.name(), "prop2"); + assert_eq!(fields[0].attributes.len(), 2); + assert_eq!(fields[1].attributes.len(), 1); } _ => panic!("Expected a model declaration"), } diff --git a/engine/baml-runtime/src/cli/mod.rs b/engine/baml-runtime/src/cli/mod.rs index 6cabc917b..458d569ea 100644 --- a/engine/baml-runtime/src/cli/mod.rs +++ b/engine/baml-runtime/src/cli/mod.rs @@ -1,5 +1,5 @@ mod dev; -mod generate; +pub mod generate; mod init; mod serve; diff --git a/engine/baml-runtime/src/cli/serve/mod.rs b/engine/baml-runtime/src/cli/serve/mod.rs index 6cc67af64..90ef713b8 100644 --- a/engine/baml-runtime/src/cli/serve/mod.rs +++ b/engine/baml-runtime/src/cli/serve/mod.rs @@ -33,7 +33,9 @@ use tokio::{net::TcpListener, sync::RwLock}; use tokio_stream::StreamExt; use crate::{ - client_registry::ClientRegistry, errors::ExposedError, internal::llm_client::LLMResponse, + client_registry::ClientRegistry, + errors::ExposedError, + internal::llm_client::{LLMResponse, ResponseBamlValue}, BamlRuntime, FunctionResult, RuntimeContextManager, }; use internal_baml_codegen::openapi::OpenApiSchema; @@ -367,7 +369,7 @@ Tip: test that the server is up using `curl http://localhost:{}/_debug/ping` LLMResponse::Success(_) => match function_result.parsed_content() { // Just because the LLM returned 2xx doesn't mean that it returned parse-able content! 
Ok(parsed) => { - (StatusCode::OK, Json::(parsed.into())).into_response() + (StatusCode::OK, Json::(parsed.clone())).into_response() } Err(e) => { if let Some(ExposedError::ValidationError { @@ -478,8 +480,10 @@ Tip: test that the server is up using `curl http://localhost:{}/_debug/ping` Ok(function_result) => match function_result.llm_response() { LLMResponse::Success(_) => match function_result.parsed_content() { // Just because the LLM returned 2xx doesn't mean that it returned parse-able content! - Ok(parsed) => (StatusCode::OK, Json::(parsed.into())) - .into_response(), + Ok(parsed) => { + (StatusCode::OK, Json::(parsed.clone())) + .into_response() + } Err(e) => { log::debug!("Error parsing content: {:?}", e); diff --git a/engine/baml-runtime/src/internal/llm_client/mod.rs b/engine/baml-runtime/src/internal/llm_client/mod.rs index 934526e17..f5110d772 100644 --- a/engine/baml-runtime/src/internal/llm_client/mod.rs +++ b/engine/baml-runtime/src/internal/llm_client/mod.rs @@ -1,6 +1,5 @@ use std::collections::{HashMap, HashSet}; -use base64::write; use colored::*; pub mod llm_provider; pub mod orchestrator; @@ -12,10 +11,11 @@ pub mod traits; use anyhow::Result; +use baml_types::{BamlValueWithMeta, Constraint, ConstraintLevel, ResponseCheck}; use internal_baml_core::ir::ClientWalker; -use internal_baml_jinja::{ChatMessagePart, RenderedChatMessage, RenderedPrompt}; +use internal_baml_jinja::RenderedPrompt; +use jsonish::BamlValueWithFlags; use serde::{Deserialize, Serialize}; -use serde_json::Map; use std::error::Error; use reqwest::StatusCode; @@ -23,6 +23,33 @@ use reqwest::StatusCode; #[cfg(target_arch = "wasm32")] use wasm_bindgen::JsValue; +pub type ResponseBamlValue = BamlValueWithMeta>; + +/// Validate a parsed value, checking asserts and checks. 
+pub fn parsed_value_to_response(baml_value: BamlValueWithFlags) -> Result { + let baml_value_with_meta: BamlValueWithMeta> = baml_value.into(); + let first_failing_assert: Option = baml_value_with_meta + .iter() + .map(|v| v.meta()) + .flatten() + .filter_map(|(c @ Constraint { level, .. }, succeeded)| { + if !succeeded && level == &ConstraintLevel::Assert { + Some(c.clone()) + } else { + None + } + }) + .next(); + match first_failing_assert { + Some(err) => Err(anyhow::anyhow!("Failed assertion: {:?}", err)), + None => Ok(baml_value_with_meta.map_meta(|cs| { + cs.into_iter() + .filter_map(|res| ResponseCheck::from_check_result(res)) + .collect() + })), + } +} + #[derive(Clone, Copy, PartialEq)] pub enum ResolveMediaUrls { // there are 5 input formats: diff --git a/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs b/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs index f16408694..04817bf92 100644 --- a/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs +++ b/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs @@ -7,8 +7,7 @@ use web_time::Duration; use crate::{ internal::{ llm_client::{ - traits::{WithPrompt, WithSingleCallable}, - LLMResponse, + parsed_value_to_response, traits::{WithPrompt, WithSingleCallable}, LLMResponse, ResponseBamlValue }, prompt_renderer::PromptRenderer, }, @@ -28,7 +27,7 @@ pub async fn orchestrate( Vec<( OrchestrationScope, LLMResponse, - Option>, + Option>, )>, Duration, ) { @@ -50,7 +49,13 @@ pub async fn orchestrate( }; let sleep_duration = node.error_sleep_duration().cloned(); - results.push((node.scope, response, parsed_response)); + let response_with_constraints: Option> = + parsed_response.map( + |r| r.and_then( + |v| parsed_value_to_response(v) + ) + ); + results.push((node.scope, response, response_with_constraints)); // Currently, we break out of the loop if an LLM responded, even if we couldn't parse the result. 
if results diff --git a/engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs b/engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs index c8069a961..81fa7542a 100644 --- a/engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs +++ b/engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs @@ -83,7 +83,7 @@ impl OrchestratorNode { } } -#[derive(Default, Clone, Serialize)] +#[derive(Debug, Default, Clone, Serialize)] pub struct OrchestrationScope { pub scope: Vec, } @@ -138,7 +138,7 @@ impl OrchestrationScope { } } -#[derive(Clone, Serialize)] +#[derive(Clone, Debug, Serialize)] pub enum ExecutionScope { Direct(String), // PolicyName, RetryCount, RetryDelayMs diff --git a/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs b/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs index ecf0ac5fb..ccda0b29d 100644 --- a/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs +++ b/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs @@ -8,8 +8,7 @@ use web_time::Duration; use crate::{ internal::{ llm_client::{ - traits::{WithPrompt, WithStreamable}, - LLMErrorResponse, LLMResponse, + parsed_value_to_response, traits::{WithPrompt, WithStreamable}, LLMErrorResponse, LLMResponse, ResponseBamlValue }, prompt_renderer::PromptRenderer, }, @@ -31,7 +30,7 @@ pub async fn orchestrate_stream( Vec<( OrchestrationScope, LLMResponse, - Option>, + Option>, )>, Duration, ) @@ -60,10 +59,12 @@ where match &stream_part { LLMResponse::Success(s) => { let parsed = partial_parse_fn(&s.content); + let response_value: Result = + parsed.and_then(|v| parsed_value_to_response(v)); on_event(FunctionResult::new( node.scope.clone(), LLMResponse::Success(s.clone()), - Some(parsed), + Some(response_value), )); } _ => {} @@ -92,8 +93,9 @@ where LLMResponse::Success(s) => Some(parse_fn(&s.content)), _ => None, }; + let response_value: Option> = parsed_response.map(|r| r.and_then(|v| 
parsed_value_to_response(v))); let sleep_duration = node.error_sleep_duration().cloned(); - results.push((node.scope, final_response, parsed_response)); + results.push((node.scope, final_response, response_value)); // Currently, we break out of the loop if an LLM responded, even if we couldn't parse the result. if results diff --git a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs index 81b7b21e6..f425e5448 100644 --- a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs +++ b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs @@ -1,7 +1,7 @@ use std::collections::HashSet; use anyhow::Result; -use baml_types::BamlValue; +use baml_types::{BamlValue, Constraint}; use indexmap::IndexSet; use internal_baml_core::ir::{ repr::IntermediateRepr, ClassWalker, EnumWalker, FieldType, IRHelper, @@ -66,7 +66,7 @@ fn find_new_class_field<'a>( field_name: &str, class_walker: &Result>, overrides: &'a RuntimeClassOverride, - ctx: &RuntimeContext, + _ctx: &RuntimeContext, ) -> Result<(Name, FieldType, Option)> { let Some(field_overrides) = overrides.new_fields.get(field_name) else { anyhow::bail!("Class {} does not have a field: {}", class_name, field_name); @@ -205,8 +205,8 @@ fn relevant_data_models<'a>( let mut start: Vec = vec![output.clone()]; while let Some(output) = start.pop() { - match &output { - FieldType::Enum(enm) => { + match output.distribute_constraints() { + (FieldType::Enum(enm), constraints) => { if checked_types.insert(output.to_string()) { let overrides = ctx.enum_overrides.get(enm); let walker = ir.find_enum(enm); @@ -246,15 +246,16 @@ fn relevant_data_models<'a>( enums.push(Enum { name: Name::new_with_alias(enm.to_string(), alias.value()), values, + constraints, }); } } - FieldType::List(inner) | FieldType::Optional(inner) => { + (FieldType::List(inner), _) | (FieldType::Optional(inner), _) => { if 
!checked_types.contains(&inner.to_string()) { start.push(inner.as_ref().clone()); } } - FieldType::Map(k, v) => { + (FieldType::Map(k, v), _) => { if checked_types.insert(output.to_string()) { if !checked_types.contains(&k.to_string()) { start.push(k.as_ref().clone()); @@ -264,7 +265,7 @@ fn relevant_data_models<'a>( } } } - FieldType::Tuple(options) | FieldType::Union(options) => { + (FieldType::Tuple(options), _) | (FieldType::Union(options), _) => { if checked_types.insert((&output).to_string()) { for inner in options { if !checked_types.contains(&inner.to_string()) { @@ -273,7 +274,7 @@ fn relevant_data_models<'a>( } } } - FieldType::Class(cls) => { + (FieldType::Class(cls), constraints) => { if checked_types.insert(output.to_string()) { let overrides = ctx.class_override.get(cls); let walker = ir.find_class(&cls); @@ -330,11 +331,15 @@ fn relevant_data_models<'a>( classes.push(Class { name: Name::new_with_alias(cls.to_string(), alias.value()), fields, + constraints, }); } } - FieldType::Primitive(_) => {} - FieldType::Literal(_) => {} + (FieldType::Literal(_), _) => {} + (FieldType::Primitive(_), _) => {} + (FieldType::Constrained{..}, _)=> { + unreachable!("It is guaranteed that a call to distribute_constraints will not return FieldType::Constrained") + }, } } @@ -343,9 +348,10 @@ fn relevant_data_models<'a>( #[cfg(test)] mod tests { + use std::collections::HashMap; + use super::*; use crate::BamlRuntime; - use std::collections::HashMap; #[test] fn skipped_variants_are_not_rendered() { @@ -372,4 +378,5 @@ mod tests { assert_eq!(foo_enum.values[0].0.real_name(), "Bar".to_string()); assert_eq!(foo_enum.values.len(), 1); } + } diff --git a/engine/baml-runtime/src/types/expression_helper.rs b/engine/baml-runtime/src/types/expression_helper.rs index dd05b724b..16df949e5 100644 --- a/engine/baml-runtime/src/types/expression_helper.rs +++ b/engine/baml-runtime/src/types/expression_helper.rs @@ -51,6 +51,7 @@ pub fn to_value(ctx: &RuntimeContext, expr: &Expression) 
-> Result>>()?; json!(res) - } + }, + Expression::JinjaExpression(_) => anyhow::bail!("Unable to normalize jinja expression to a value without a context."), }) } diff --git a/engine/baml-runtime/src/types/response.rs b/engine/baml-runtime/src/types/response.rs index 9f6ac4017..9a4e03fc7 100644 --- a/engine/baml-runtime/src/types/response.rs +++ b/engine/baml-runtime/src/types/response.rs @@ -1,16 +1,16 @@ pub use crate::internal::llm_client::LLMResponse; -use crate::{errors::ExposedError, internal::llm_client::orchestrator::OrchestrationScope}; +use crate::{errors::ExposedError, internal::llm_client::{orchestrator::OrchestrationScope, ResponseBamlValue}}; use anyhow::Result; use colored::*; use baml_types::BamlValue; -use jsonish::BamlValueWithFlags; +#[derive(Debug)] pub struct FunctionResult { event_chain: Vec<( OrchestrationScope, LLMResponse, - Option>, + Option>, )>, } @@ -27,7 +27,6 @@ impl std::fmt::Display for FunctionResult { writeln!(f, "{}", self.llm_response())?; match &self.parsed() { Some(Ok(val)) => { - let val: BamlValue = val.into(); writeln!( f, "{}", @@ -48,10 +47,10 @@ impl FunctionResult { pub fn new( scope: OrchestrationScope, response: LLMResponse, - parsed: Option>, + baml_value: Option>, ) -> Self { Self { - event_chain: vec![(scope, response, parsed)], + event_chain: vec![(scope, response, baml_value)], } } @@ -60,7 +59,7 @@ impl FunctionResult { ) -> &Vec<( OrchestrationScope, LLMResponse, - Option>, + Option>, )> { &self.event_chain } @@ -69,7 +68,7 @@ impl FunctionResult { chain: Vec<( OrchestrationScope, LLMResponse, - Option>, + Option>, )>, ) -> Result { if chain.is_empty() { @@ -91,11 +90,11 @@ impl FunctionResult { &self.event_chain.last().unwrap().0 } - pub fn parsed(&self) -> &Option> { + pub fn parsed(&self) -> &Option> { &self.event_chain.last().unwrap().2 } - pub fn parsed_content(&self) -> Result<&BamlValueWithFlags> { + pub fn parsed_content(&self) -> Result<&ResponseBamlValue> { self.parsed() .as_ref() .map(|res| { diff 
--git a/engine/language_client_codegen/Cargo.toml b/engine/language_client_codegen/Cargo.toml index d78fd5853..7a90dc4b3 100644 --- a/engine/language_client_codegen/Cargo.toml +++ b/engine/language_client_codegen/Cargo.toml @@ -25,3 +25,4 @@ sugar_path = "1.2.0" walkdir.workspace = true semver = "1.0.23" colored = "2.1.0" +itertools = "0.13.0" diff --git a/engine/language_client_codegen/src/lib.rs b/engine/language_client_codegen/src/lib.rs index 4f6d152e4..c25216d02 100644 --- a/engine/language_client_codegen/src/lib.rs +++ b/engine/language_client_codegen/src/lib.rs @@ -1,10 +1,11 @@ use anyhow::{Context, Result}; +use baml_types::{Constraint, ConstraintLevel, FieldType}; use indexmap::IndexMap; use internal_baml_core::{ configuration::{GeneratorDefaultClientMode, GeneratorOutputType}, ir::repr::IntermediateRepr, }; -use std::{collections::BTreeMap, path::PathBuf}; +use std::{collections::{BTreeMap, HashSet}, path::PathBuf}; use version_check::{check_version, GeneratorType, VersionCheckMode}; mod dir_writer; @@ -219,3 +220,172 @@ impl GenerateClient for GeneratorOutputType { }) } } + +/// A set of names of @check attributes. This set determines the +/// way name of a Python Class or TypeScript Interface that holds +/// the results of running these checks. See TODO (Docs) for details on +/// the support types generated from checks. +#[derive(Clone, Debug, Eq)] +pub struct TypeCheckAttributes(pub HashSet); + +impl PartialEq for TypeCheckAttributes { + fn eq(&self, other: &Self) -> bool { + self.0.len() == other.0.len() && self.0.iter().all(|x| other.0.contains(x)) + } +} + +impl <'a> std::hash::Hash for TypeCheckAttributes { + fn hash(&self, state: &mut H) + where H: std::hash::Hasher + { + let mut strings: Vec<_> = self.0.iter().collect(); + strings.sort(); + strings.into_iter().for_each(|s| s.hash(state)) + } + +} + +impl TypeCheckAttributes { + /// Extend one set of attributes with the contents of another. 
+ pub fn extend(&mut self, other: &TypeCheckAttributes) { + self.0.extend(other.0.clone()) + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +/// Search the IR for all types with checks, combining the checks on each type +/// into a `TypeCheckAttributes` (a HashSet of the check names). Return a HashSet +/// of these HashSets. +/// +/// For example, consider this IR defining two classes: +/// +/// ``` baml +/// class Foo { +/// int @check("a") @check("b") +/// string @check("a") +/// } +/// +/// class Bar { +/// bool @check("a") +/// } +/// ```` +/// +/// It contains two distinct `TypeCheckAttributes`: +/// - ["a"] +/// - ["a", "b"] +/// +/// We will need to construct two district support types: +/// `Classes_a` and `Classes_a_b`. +pub fn type_check_attributes( + ir: &IntermediateRepr +) -> HashSet { + + + let mut all_types_in_ir: Vec<&FieldType> = Vec::new(); + for class in ir.walk_classes() { + for field in class.item.elem.static_fields.iter() { + let field_type = &field.elem.r#type.elem; + all_types_in_ir.push(field_type); + } + } + for function in ir.walk_functions() { + for (_param_name, parameter) in function.item.elem.inputs.iter() { + all_types_in_ir.push(parameter); + } + let return_type = &function.item.elem.output; + all_types_in_ir.push(return_type); + } + + all_types_in_ir.into_iter().filter_map(field_type_attributes).collect() + +} + +/// The set of Check names associated with a type. 
+fn field_type_attributes<'a>(field_type: &FieldType) -> Option { + match field_type { + FieldType::Constrained {base, constraints} => { + let direct_sub_attributes = field_type_attributes(base); + let mut check_names = + TypeCheckAttributes( + constraints + .iter() + .filter_map(|Constraint {label, level, ..}| + if matches!(level, ConstraintLevel::Check) { + Some(label.clone().expect("TODO")) + } else { None } + ).collect::>()); + if let Some(ref sub_attrs) = direct_sub_attributes { + check_names.extend(&sub_attrs); + } + if !check_names.is_empty() { + Some(check_names) + } else { + None + } + }, + _ => None + } +} + +#[cfg(test)] +mod tests { + use internal_baml_core::ir::repr::make_test_ir; + use super::*; + + + /// Utility function for creating test fixtures. + fn mk_tc_attrs(names: &[&str]) -> TypeCheckAttributes { + TypeCheckAttributes(names.into_iter().map(|s| s.to_string()).collect()) + } + + #[test] + fn type_check_attributes_eq() { + assert_eq!(mk_tc_attrs(&["a", "b"]), mk_tc_attrs(&["b", "a"])); + + let attrs: HashSet = vec![mk_tc_attrs(&["a", "b"])].into_iter().collect(); + assert!(attrs.contains( &mk_tc_attrs(&["a", "b"]) )); + assert!(attrs.contains( &mk_tc_attrs(&["b", "a"]) )); + + } + + #[test] + fn find_type_check_attributes() { + let ir = make_test_ir( + r##" +client GPT4 { + provider openai + options { + model gpt-4o + api_key env.OPENAI_API_KEY + } +} + +function Go(a: int @assert({{ this < 0 }}, c)) -> Foo { + client GPT4 + prompt #""# +} + +class Foo { + ab int @check({{this}}, a) @check({{this}}, b) + a int @check({{this}}, a) +} + +class Bar { + cb int @check({{this}}, c) @check({{this}}, b) + nil int @description("no checks") @assert({{this}}, a) @assert({{this}}, d) +} + + "##).expect("Valid source"); + + let attrs = type_check_attributes(&ir); + dbg!(&attrs); + assert_eq!(attrs.len(), 3); + assert!(attrs.contains( &mk_tc_attrs(&["a","b"]) )); + assert!(attrs.contains( &mk_tc_attrs(&["a"]) )); + assert!(attrs.contains( &mk_tc_attrs(&["b", 
"c"]) )); + assert!(!attrs.contains( &mk_tc_attrs(&["a", "d"]) )); + } +} diff --git a/engine/language_client_codegen/src/openapi.rs b/engine/language_client_codegen/src/openapi.rs index 90d4f1bbf..1d7ec2901 100644 --- a/engine/language_client_codegen/src/openapi.rs +++ b/engine/language_client_codegen/src/openapi.rs @@ -1,5 +1,4 @@ -use std::collections::HashMap; -use std::{path::PathBuf, process::Command}; +use std::path::PathBuf; use anyhow::{Context, Result}; use baml_types::{BamlMediaType, FieldType, LiteralValue, TypeValue}; @@ -8,10 +7,10 @@ use internal_baml_core::ir::{ repr::{Function, IntermediateRepr, Node, Walker}, ClassWalker, EnumWalker, }; -use serde::{Deserialize, Serialize}; +use serde::Serialize; use serde_json::json; -use crate::dir_writer::{FileCollector, LanguageFeatures, RemoveDirBehavior}; +use crate::{dir_writer::{FileCollector, LanguageFeatures, RemoveDirBehavior}, field_type_attributes, TypeCheckAttributes}; #[derive(Default)] pub(super) struct OpenApiLanguageFeatures {} @@ -71,46 +70,6 @@ impl Serialize for OpenApiSchema<'_> { &self, serializer: S, ) -> core::result::Result { - let baml_image_schema = TypeSpecWithMeta { - meta: TypeMetadata { - title: Some("BamlImage".to_string()), - r#enum: None, - r#const: None, - nullable: false, - }, - type_spec: TypeSpec::Inline(TypeDef::Class { - properties: vec![ - ( - "base64".to_string(), - TypeSpecWithMeta { - meta: TypeMetadata { - title: None, - r#enum: None, - r#const: None, - nullable: false, - }, - type_spec: TypeSpec::Inline(TypeDef::String), - }, - ), - ( - "media_type".to_string(), - TypeSpecWithMeta { - meta: TypeMetadata { - title: None, - r#enum: None, - r#const: None, - nullable: true, - }, - type_spec: TypeSpec::Inline(TypeDef::String), - }, - ), - ] - .into_iter() - .collect(), - required: vec!["base64".to_string()], - additional_properties: false, - }), - }; let schemas = match self .schemas .iter() @@ -272,6 +231,17 @@ impl Serialize for OpenApiSchema<'_> { }, "required": 
["name", "provider", "options"] }) + ), + ( "Check", + json!({ + "type": "object", + "properties": { + "name": { "type": "string" }, + "expr": { "type": "string" }, + "status": { "type": "string" } + } + + }) ) ] .into_iter() @@ -374,6 +344,40 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'_ crate::GeneratorArgs)> for OpenApi } } +pub fn type_name_for_checks(checks: &TypeCheckAttributes) -> String { + let mut name = "Checks".to_string(); + let mut names: Vec<&String> = checks.0.iter().collect(); + names.sort(); + for check_name in names.iter() { + name.push_str("__"); + name.push_str(check_name); + } + name +} + +fn check() -> TypeSpecWithMeta { + TypeSpecWithMeta { + meta: TypeMetadata::default(), + type_spec: TypeSpec::Ref{ r#ref: "#components/schemas/Check".to_string() }, + } +} + +/// The type definition for a single "Checked_*" type. Note that we don't +/// produce a named type for each of these the way we do for SDK +/// codegeneration. +fn type_def_for_checks(checks: TypeCheckAttributes) -> TypeSpecWithMeta { + TypeSpecWithMeta { + meta: TypeMetadata::default(), + type_spec: TypeSpec::Inline( + TypeDef::Class { + properties: checks.0.iter().map(|check_name| (check_name.clone(), check())).collect(), + required: checks.0.into_iter().collect(), + additional_properties: false, + } + ) + } +} + impl<'ir> TryFrom>> for OpenApiMethodDef<'ir> { type Error = anyhow::Error; @@ -638,6 +642,27 @@ impl<'ir> ToTypeReferenceInTypeDefinition<'ir> for FieldType { // something i saw suggested doing this type_spec } + FieldType::Constrained{base,..} => { + match field_type_attributes(self) { + Some(checks) => { + let base_type_ref = base.to_type_spec(ir)?; + let checks_type_spec = type_def_for_checks(checks); + TypeSpecWithMeta { + meta: TypeMetadata::default(), + type_spec: TypeSpec::Inline( + TypeDef::Class { + properties: vec![("value".to_string(), base_type_ref),("checks".to_string(), checks_type_spec)].into_iter().collect(), + required: vec!["value".to_string(), 
"checks".to_string()], + additional_properties: false, + } + ) + } + } + None => { + base.to_type_spec(ir)? + } + } + }, }) } } @@ -672,6 +697,17 @@ struct TypeMetadata { nullable: bool, } +impl Default for TypeMetadata { + fn default() -> Self { + TypeMetadata { + title: None, + r#enum: None, + r#const: None, + nullable: false, + } + } +} + #[derive(Clone, Debug, Serialize)] #[serde(untagged)] enum TypeSpec { diff --git a/engine/language_client_codegen/src/python/generate_types.rs b/engine/language_client_codegen/src/python/generate_types.rs index 829bbdbb6..7800b80d5 100644 --- a/engine/language_client_codegen/src/python/generate_types.rs +++ b/engine/language_client_codegen/src/python/generate_types.rs @@ -1,4 +1,8 @@ use anyhow::Result; +use itertools::join; +use std::borrow::Cow; + +use crate::{field_type_attributes, type_check_attributes, TypeCheckAttributes}; use super::python_language_features::ToPython; use internal_baml_core::ir::{ @@ -10,6 +14,7 @@ use internal_baml_core::ir::{ pub(crate) struct PythonTypes<'ir> { enums: Vec>, classes: Vec>, + checks_classes: Vec> } #[derive(askama::Template)] @@ -17,6 +22,7 @@ pub(crate) struct PythonTypes<'ir> { pub(crate) struct TypeBuilder<'ir> { enums: Vec>, classes: Vec>, + checks_classes: Vec>, } struct PythonEnum<'ir> { @@ -26,15 +32,16 @@ struct PythonEnum<'ir> { } struct PythonClass<'ir> { - name: &'ir str, + name: Cow<'ir, str>, // the name, and the type of the field - fields: Vec<(&'ir str, String)>, + fields: Vec<(Cow<'ir, str>, String)>, dynamic: bool, } #[derive(askama::Template)] #[template(path = "partial_types.py.j2", escape = "none")] pub(crate) struct PythonStreamTypes<'ir> { + check_type_names: String, partial_classes: Vec>, } @@ -52,9 +59,15 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonT fn try_from( (ir, _): (&'ir IntermediateRepr, &'_ crate::GeneratorArgs), ) -> Result> { + let checks_classes = + type_check_attributes(ir) + .into_iter() + .map(|checks| 
type_def_for_checks(checks)) + .collect::>(); Ok(PythonTypes { enums: ir.walk_enums().map(PythonEnum::from).collect::>(), classes: ir.walk_classes().map(PythonClass::from).collect::>(), + checks_classes, }) } } @@ -65,9 +78,15 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'_ crate::GeneratorArgs)> for TypeBui fn try_from( (ir, _): (&'ir IntermediateRepr, &'_ crate::GeneratorArgs), ) -> Result> { + let checks_classes = + type_check_attributes(ir) + .into_iter() + .map(|checks| type_def_for_checks(checks)) + .collect::>(); Ok(TypeBuilder { enums: ir.walk_enums().map(PythonEnum::from).collect::>(), classes: ir.walk_classes().map(PythonClass::from).collect::>(), + checks_classes, }) } } @@ -91,7 +110,7 @@ impl<'ir> From> for PythonEnum<'ir> { impl<'ir> From> for PythonClass<'ir> { fn from(c: ClassWalker<'ir>) -> Self { PythonClass { - name: c.name(), + name: Cow::Borrowed(c.name()), dynamic: c.item.attributes.get("dynamic_type").is_some(), fields: c .item @@ -100,7 +119,7 @@ impl<'ir> From> for PythonClass<'ir> { .iter() .map(|f| { ( - f.elem.name.as_str(), + Cow::Borrowed(f.elem.name.as_str()), add_default_value( &f.elem.r#type.elem, &f.elem.r#type.elem.to_type_ref(&c.db), @@ -116,7 +135,13 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonS type Error = anyhow::Error; fn try_from((ir, _): (&'ir IntermediateRepr, &'_ crate::GeneratorArgs)) -> Result { + let check_type_names = + join(type_check_attributes(ir) + .into_iter() + .map(|checks| type_name_for_checks(&checks)), + ", "); Ok(Self { + check_type_names, partial_classes: ir .walk_classes() .map(PartialPythonClass::from) @@ -157,6 +182,25 @@ pub fn add_default_value(node: &FieldType, type_str: &String) -> String { } } +pub fn type_name_for_checks(checks: &TypeCheckAttributes) -> String { + let mut name = "Checks".to_string(); + let mut names: Vec<&String> = checks.0.iter().collect(); + names.sort(); + for check_name in names.iter() { + name.push_str("__"); + 
name.push_str(check_name); + } + name +} + +fn type_def_for_checks(checks: TypeCheckAttributes) -> PythonClass<'static> { + PythonClass { + name: Cow::Owned(type_name_for_checks(&checks)), + fields: checks.0.into_iter().map(|check_name| (Cow::Owned(check_name), "baml_py.Check".to_string())).collect(), + dynamic: false + } +} + trait ToTypeReferenceInTypeDefinition { fn to_type_ref(&self, ir: &IntermediateRepr) -> String; fn to_partial_type_ref(&self, ir: &IntermediateRepr, wrapped: bool) -> String; @@ -200,6 +244,18 @@ impl ToTypeReferenceInTypeDefinition for FieldType { .join(", ") ), FieldType::Optional(inner) => format!("Optional[{}]", inner.to_type_ref(ir)), + FieldType::Constrained{base, ..} => { + match field_type_attributes(self) { + Some(checks) => { + let base_type_ref = base.to_type_ref(ir); + let checks_type_ref = type_name_for_checks(&checks); + format!("baml_py.Checked[{base_type_ref},{checks_type_ref}]") + } + None => { + base.to_type_ref(ir) + } + } + }, } } @@ -250,6 +306,17 @@ impl ToTypeReferenceInTypeDefinition for FieldType { .join(", ") ), FieldType::Optional(inner) => inner.to_partial_type_ref(ir, false), + FieldType::Constrained{base,..} => { + let base_type_ref = base.to_partial_type_ref(ir, false); + match field_type_attributes(self) { + Some(checks) => { + let base_type_ref = base.to_partial_type_ref(ir, false); + let checks_type_ref = type_name_for_checks(&checks); + format!("baml_py.Checked[{base_type_ref},{checks_type_ref}]") + } + None => base_type_ref + } + }, } } } diff --git a/engine/language_client_codegen/src/python/mod.rs b/engine/language_client_codegen/src/python/mod.rs index a89e03866..688358cdc 100644 --- a/engine/language_client_codegen/src/python/mod.rs +++ b/engine/language_client_codegen/src/python/mod.rs @@ -4,6 +4,7 @@ mod python_language_features; use std::path::PathBuf; use anyhow::Result; +use generate_types::type_name_for_checks; use indexmap::IndexMap; use internal_baml_core::{ 
configuration::GeneratorDefaultClientMode, @@ -11,7 +12,7 @@ use internal_baml_core::{ }; use self::python_language_features::{PythonLanguageFeatures, ToPython}; -use crate::dir_writer::FileCollector; +use crate::{dir_writer::FileCollector, field_type_attributes}; #[derive(askama::Template)] #[template(path = "async_client.py.j2", escape = "none")] @@ -109,7 +110,7 @@ impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonInit { impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonGlobals { type Error = anyhow::Error; - fn try_from((_, args): (&'_ IntermediateRepr, &'_ crate::GeneratorArgs)) -> Result { + fn try_from((_, _args): (&'_ IntermediateRepr, &'_ crate::GeneratorArgs)) -> Result { Ok(PythonGlobals {}) } } @@ -157,12 +158,12 @@ impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonClient let (_function, _impl_) = c.item; Ok(PythonFunction { name: f.name().to_string(), - partial_return_type: f.elem().output().to_partial_type_ref(ir), - return_type: f.elem().output().to_type_ref(ir), + partial_return_type: f.elem().output().to_partial_type_ref(ir, true), + return_type: f.elem().output().to_type_ref(ir, true), args: f .inputs() .iter() - .map(|(name, r#type)| (name.to_string(), r#type.to_type_ref(ir))) + .map(|(name, r#type)| (name.to_string(), r#type.to_type_ref(ir, false))) .collect(), }) }) @@ -178,13 +179,13 @@ impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonClient } trait ToTypeReferenceInClientDefinition { - fn to_type_ref(&self, ir: &IntermediateRepr) -> String; + fn to_type_ref(&self, ir: &IntermediateRepr, with_checked: bool) -> String; - fn to_partial_type_ref(&self, ir: &IntermediateRepr) -> String; + fn to_partial_type_ref(&self, ir: &IntermediateRepr, with_checked: bool) -> String; } impl ToTypeReferenceInClientDefinition for FieldType { - fn to_type_ref(&self, ir: &IntermediateRepr) -> String { + fn to_type_ref(&self, ir: &IntermediateRepr, with_checked: bool) -> 
String { match self { FieldType::Enum(name) => { if ir @@ -199,16 +200,16 @@ impl ToTypeReferenceInClientDefinition for FieldType { } FieldType::Literal(value) => format!("Literal[{}]", value), FieldType::Class(name) => format!("types.{name}"), - FieldType::List(inner) => format!("List[{}]", inner.to_type_ref(ir)), + FieldType::List(inner) => format!("List[{}]", inner.to_type_ref(ir, with_checked)), FieldType::Map(key, value) => { - format!("Dict[{}, {}]", key.to_type_ref(ir), value.to_type_ref(ir)) + format!("Dict[{}, {}]", key.to_type_ref(ir, with_checked), value.to_type_ref(ir, with_checked)) } FieldType::Primitive(r#type) => r#type.to_python(), FieldType::Union(inner) => format!( "Union[{}]", inner .iter() - .map(|t| t.to_type_ref(ir)) + .map(|t| t.to_type_ref(ir, with_checked)) .collect::>() .join(", ") ), @@ -216,15 +217,27 @@ impl ToTypeReferenceInClientDefinition for FieldType { "Tuple[{}]", inner .iter() - .map(|t| t.to_type_ref(ir)) + .map(|t| t.to_type_ref(ir, with_checked)) .collect::>() .join(", ") ), - FieldType::Optional(inner) => format!("Optional[{}]", inner.to_type_ref(ir)), + FieldType::Optional(inner) => format!("Optional[{}]", inner.to_type_ref(ir, with_checked)), + FieldType::Constrained{base, ..} => { + match field_type_attributes(self) { + Some(checks) => { + let base_type_ref = base.to_type_ref(ir, with_checked); + let checks_type_ref = type_name_for_checks(&checks); + format!("baml_py.Checked[{base_type_ref},types.{checks_type_ref}]") + } + None => { + base.to_type_ref(ir, with_checked) + } + } + }, } } - fn to_partial_type_ref(&self, ir: &IntermediateRepr) -> String { + fn to_partial_type_ref(&self, ir: &IntermediateRepr, with_checked: bool) -> String { match self { FieldType::Enum(name) => { if ir @@ -239,12 +252,12 @@ impl ToTypeReferenceInClientDefinition for FieldType { } FieldType::Class(name) => format!("partial_types.{name}"), FieldType::Literal(value) => format!("Literal[{}]", value), - FieldType::List(inner) => 
format!("List[{}]", inner.to_partial_type_ref(ir)), + FieldType::List(inner) => format!("List[{}]", inner.to_partial_type_ref(ir, with_checked)), FieldType::Map(key, value) => { format!( "Dict[{}, {}]", - key.to_type_ref(ir), - value.to_partial_type_ref(ir) + key.to_type_ref(ir, with_checked), + value.to_partial_type_ref(ir, with_checked) ) } FieldType::Primitive(r#type) => format!("Optional[{}]", r#type.to_python()), @@ -252,7 +265,7 @@ impl ToTypeReferenceInClientDefinition for FieldType { "Optional[Union[{}]]", inner .iter() - .map(|t| t.to_partial_type_ref(ir)) + .map(|t| t.to_partial_type_ref(ir, with_checked)) .collect::>() .join(", ") ), @@ -260,11 +273,23 @@ impl ToTypeReferenceInClientDefinition for FieldType { "Optional[Tuple[{}]]", inner .iter() - .map(|t| t.to_partial_type_ref(ir)) + .map(|t| t.to_partial_type_ref(ir, with_checked)) .collect::>() .join(", ") ), - FieldType::Optional(inner) => inner.to_partial_type_ref(ir), + FieldType::Optional(inner) => inner.to_partial_type_ref(ir, with_checked), + FieldType::Constrained{base, ..} => { + match field_type_attributes(self) { + Some(checks) => { + let base_type_ref = base.to_partial_type_ref(ir, with_checked); + let checks_type_ref = type_name_for_checks(&checks); + format!("baml_py.Checked[{base_type_ref},types.{checks_type_ref}]") + } + None => { + base.to_partial_type_ref(ir, with_checked) + } + } + }, } } } diff --git a/engine/language_client_codegen/src/python/templates/partial_types.py.j2 b/engine/language_client_codegen/src/python/templates/partial_types.py.j2 index 3638f4553..4e6b6a9e3 100644 --- a/engine/language_client_codegen/src/python/templates/partial_types.py.j2 +++ b/engine/language_client_codegen/src/python/templates/partial_types.py.j2 @@ -6,6 +6,10 @@ from typing import Dict, List, Optional, Union, Literal from . 
import types +{% if !check_type_names.is_empty() %} +from .types import {{check_type_names}} +{% endif %} + ############################################################################### # # These types are used for streaming, for when an instance of a type diff --git a/engine/language_client_codegen/src/python/templates/types.py.j2 b/engine/language_client_codegen/src/python/templates/types.py.j2 index cc3973cf3..9d1797eff 100644 --- a/engine/language_client_codegen/src/python/templates/types.py.j2 +++ b/engine/language_client_codegen/src/python/templates/types.py.j2 @@ -13,6 +13,15 @@ class {{enum.name}}(str, Enum): {%- endfor %} {% endfor %} +{#- Checks Classes -#} +{% for cls in checks_classes %} +class {{cls.name}}(BaseModel): + + {%- for (name, type) in cls.fields %} + {{name}}: {{type}} + {%- endfor %} +{% endfor %} + {#- Classes -#} {% for cls in classes %} class {{cls.name}}(BaseModel): diff --git a/engine/language_client_codegen/src/ruby/expression.rs b/engine/language_client_codegen/src/ruby/expression.rs index c23976112..567e58381 100644 --- a/engine/language_client_codegen/src/ruby/expression.rs +++ b/engine/language_client_codegen/src/ruby/expression.rs @@ -36,6 +36,7 @@ impl ToRuby for Expression { Expression::RawString(val) => format!("`{}`", val.replace('`', "\\`")), Expression::Numeric(val) => val.clone(), Expression::Bool(val) => val.to_string(), + Expression::JinjaExpression(val) => val.to_string(), } } } diff --git a/engine/language_client_codegen/src/ruby/field_type.rs b/engine/language_client_codegen/src/ruby/field_type.rs index 43622c6be..908b2ee3b 100644 --- a/engine/language_client_codegen/src/ruby/field_type.rs +++ b/engine/language_client_codegen/src/ruby/field_type.rs @@ -2,6 +2,8 @@ use std::collections::HashSet; use baml_types::{BamlMediaType, FieldType, LiteralValue, TypeValue}; +use crate::{field_type_attributes, ruby::generate_types::type_name_for_checks}; + use super::ruby_language_features::ToRuby; impl ToRuby for FieldType { @@ 
-47,6 +49,18 @@ impl ToRuby for FieldType { .join(", ") ), FieldType::Optional(inner) => format!("T.nilable({})", inner.to_ruby()), + FieldType::Constrained{base,..} => { + match field_type_attributes(self) { + Some(checks) => { + let base_type_ref = base.to_ruby(); + let checks_type_ref = type_name_for_checks(&checks); + format!("Baml::Checked[{base_type_ref}, {checks_type_ref}]") + } + None => { + base.to_ruby() + } + } + } } } } diff --git a/engine/language_client_codegen/src/ruby/generate_types.rs b/engine/language_client_codegen/src/ruby/generate_types.rs index 329a6a58e..b9f134c4e 100644 --- a/engine/language_client_codegen/src/ruby/generate_types.rs +++ b/engine/language_client_codegen/src/ruby/generate_types.rs @@ -1,7 +1,10 @@ +use std::borrow::Cow; use std::collections::HashSet; use anyhow::Result; +use crate::{field_type_attributes, type_check_attributes, TypeCheckAttributes}; + use super::ruby_language_features::ToRuby; use internal_baml_core::ir::{repr::IntermediateRepr, ClassWalker, EnumWalker, FieldType}; @@ -10,6 +13,7 @@ use internal_baml_core::ir::{repr::IntermediateRepr, ClassWalker, EnumWalker, Fi pub(crate) struct RubyTypes<'ir> { enums: Vec>, classes: Vec>, + checks_classes: Vec>, } struct RubyEnum<'ir> { @@ -19,8 +23,8 @@ struct RubyEnum<'ir> { } struct RubyStruct<'ir> { - name: &'ir str, - fields: Vec<(&'ir str, String)>, + name: Cow<'ir, str>, + fields: Vec<(Cow<'ir, str>, String)>, dynamic: bool, } @@ -51,6 +55,7 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'ir crate::GeneratorArgs)> for RubyTy Ok(RubyTypes { enums: ir.walk_enums().map(|e| e.into()).collect(), classes: ir.walk_classes().map(|c| c.into()).collect(), + checks_classes: type_check_attributes(ir).into_iter().map(|checks| type_def_for_checks(checks)).collect::>() }) } } @@ -74,14 +79,14 @@ impl<'ir> From> for RubyEnum<'ir> { impl<'ir> From> for RubyStruct<'ir> { fn from(c: ClassWalker<'ir>) -> RubyStruct<'ir> { RubyStruct { - name: c.name(), + name: Cow::Borrowed(c.name()), 
dynamic: c.item.attributes.get("dynamic_type").is_some(), fields: c .item .elem .static_fields .iter() - .map(|f| (f.elem.name.as_str(), f.elem.r#type.elem.to_type_ref())) + .map(|f| (Cow::Borrowed(f.elem.name.as_str()), f.elem.r#type.elem.to_type_ref())) .collect(), } } @@ -163,6 +168,18 @@ impl ToTypeReferenceInTypeDefinition for FieldType { .join(", ") ), FieldType::Optional(inner) => inner.to_partial_type_ref(), + FieldType::Constrained{base,..} => { + match field_type_attributes(self) { + Some(checks) => { + let base_type_ref = base.to_partial_type_ref(); + let checks_type_ref = type_name_for_checks(&checks); + format!("Baml::Checked[{base_type_ref}, {checks_type_ref}]") + } + None => { + base.to_partial_type_ref() + } + } + }, } } } @@ -179,3 +196,22 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'_ crate::GeneratorArgs)> for TypeReg }) } } + +pub fn type_name_for_checks(checks: &TypeCheckAttributes) -> String { + let mut name = "Checks".to_string(); + let mut names: Vec<&String> = checks.0.iter().collect(); + names.sort(); + for check_name in names.iter() { + name.push_str("__"); + name.push_str(check_name); + } + name +} + +fn type_def_for_checks(checks: TypeCheckAttributes) -> RubyStruct<'static> { + RubyStruct { + name: Cow::Owned(type_name_for_checks(&checks)), + fields: checks.0.into_iter().map(|check_name| (Cow::Owned(check_name), "Baml::Check".to_string())).collect(), + dynamic: false + } +} diff --git a/engine/language_client_codegen/src/ruby/templates/types.rb.j2 b/engine/language_client_codegen/src/ruby/templates/types.rb.j2 index 16da94873..e57ad69b6 100644 --- a/engine/language_client_codegen/src/ruby/templates/types.rb.j2 +++ b/engine/language_client_codegen/src/ruby/templates/types.rb.j2 @@ -22,6 +22,11 @@ module Baml class {{cls.name}} < T::Struct; end {%- endfor %} + {#- Forward declarations for checks classes #} + {%- for cls in checks_classes %} + class {{cls.name}} < T::Struct; end + {%- endfor %} + {#- https://sorbet.org/docs/tstruct #} 
{%- for cls in classes %} class {{cls.name}} < T::Struct @@ -42,5 +47,27 @@ module Baml end end {%- endfor %} + + {#- https://sorbet.org/docs/tstruct #} + {%- for cls in checks_classes %} + class {{cls.name}} < T::Struct + include Baml::Sorbet::Struct + + {%- for (name, type) in cls.fields %} + const :{{name}}, {{type}} + {%- endfor %} + + def initialize(props) + super( + {%- for (name, _) in cls.fields %} + {{name}}: props[:{{name}}], + {%- endfor %} + ) + + @props = props + end + end + {%- endfor %} + end -end \ No newline at end of file +end diff --git a/engine/language_client_codegen/src/typescript/generate_types.rs b/engine/language_client_codegen/src/typescript/generate_types.rs index 8f67d1693..9fd8a09c9 100644 --- a/engine/language_client_codegen/src/typescript/generate_types.rs +++ b/engine/language_client_codegen/src/typescript/generate_types.rs @@ -1,8 +1,10 @@ +use std::borrow::Cow; + use anyhow::Result; use internal_baml_core::ir::{repr::IntermediateRepr, ClassWalker, EnumWalker}; -use crate::GeneratorArgs; +use crate::{type_check_attributes, GeneratorArgs, TypeCheckAttributes}; use super::ToTypeReferenceInClientDefinition; @@ -17,6 +19,7 @@ pub(crate) struct TypeBuilder<'ir> { #[template(path = "types.ts.j2", escape = "none")] pub(crate) struct TypescriptTypes<'ir> { enums: Vec>, + check_classes: Vec>, classes: Vec>, } @@ -26,12 +29,19 @@ struct TypescriptEnum<'ir> { pub dynamic: bool, } -struct TypescriptClass<'ir> { - name: &'ir str, - fields: Vec<(&'ir str, bool, String)>, - dynamic: bool, +pub struct TypescriptClass<'ir> { + pub name: Cow<'ir, str>, + pub fields: Vec<(Cow<'ir, str>, bool, String)>, + pub dynamic: bool, } +// TODO: Use this. 
+// pub struct TypescriptChecksClass { +// pub name: String, +// pub fields: Vec<(String, bool, String)>, +// pub dynamic: bool, +// } + impl<'ir> TryFrom<(&'ir IntermediateRepr, &'ir GeneratorArgs)> for TypescriptTypes<'ir> { type Error = anyhow::Error; @@ -43,6 +53,10 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'ir GeneratorArgs)> for TypescriptTyp .walk_enums() .map(|e| Into::::into(&e)) .collect::>(), + check_classes: type_check_attributes(ir) + .iter() + .map(|checks| type_def_for_checks(checks)) + .collect::>(), classes: ir .walk_classes() .map(|e| Into::::into(&e)) @@ -87,7 +101,7 @@ impl<'ir> From<&EnumWalker<'ir>> for TypescriptEnum<'ir> { impl<'ir> From<&ClassWalker<'ir>> for TypescriptClass<'ir> { fn from(c: &ClassWalker<'ir>) -> TypescriptClass<'ir> { TypescriptClass { - name: c.name(), + name: Cow::Borrowed(c.name()), dynamic: c.item.attributes.get("dynamic_type").is_some(), fields: c .item @@ -96,7 +110,7 @@ impl<'ir> From<&ClassWalker<'ir>> for TypescriptClass<'ir> { .iter() .map(|f| { ( - f.elem.name.as_str(), + Cow::Borrowed(f.elem.name.as_str()), f.elem.r#type.elem.is_optional(), f.elem.r#type.elem.to_type_ref(&c.db), ) @@ -105,3 +119,22 @@ impl<'ir> From<&ClassWalker<'ir>> for TypescriptClass<'ir> { } } } + +pub fn type_def_for_checks(checks: &TypeCheckAttributes) -> TypescriptClass<'static> { + TypescriptClass { + name: Cow::Owned(type_name_for_checks(checks)), + dynamic: false, + fields: checks.0.iter().map(|check_name| (Cow::Owned(check_name.clone()), false, "Check".to_string())).collect(), + } +} + +pub fn type_name_for_checks(checks: &TypeCheckAttributes) -> String { + let mut name = "Checks".to_string(); + let mut names: Vec<&String> = checks.0.iter().collect(); + names.sort(); + for check_name in names.iter() { + name.push_str("__"); + name.push_str(check_name); + } + name +} diff --git a/engine/language_client_codegen/src/typescript/mod.rs b/engine/language_client_codegen/src/typescript/mod.rs index 75e35a08b..78e2c3c8e 100644 --- 
a/engine/language_client_codegen/src/typescript/mod.rs +++ b/engine/language_client_codegen/src/typescript/mod.rs @@ -4,7 +4,7 @@ mod typescript_language_features; use std::path::PathBuf; use anyhow::Result; -use either::Either; +use generate_types::type_name_for_checks; use indexmap::IndexMap; use internal_baml_core::{ configuration::GeneratorDefaultClientMode, @@ -12,12 +12,14 @@ use internal_baml_core::{ }; use self::typescript_language_features::{ToTypescript, TypescriptLanguageFeatures}; -use crate::dir_writer::FileCollector; +use crate::{dir_writer::FileCollector, field_type_attributes, type_check_attributes, TypeCheckAttributes}; +use self::generate_types::{TypescriptClass, type_def_for_checks}; #[derive(askama::Template)] #[template(path = "async_client.ts.j2", escape = "none")] struct AsyncTypescriptClient { funcs: Vec, + check_types: Vec>, types: Vec, } @@ -25,11 +27,13 @@ struct AsyncTypescriptClient { #[template(path = "sync_client.ts.j2", escape = "none")] struct SyncTypescriptClient { funcs: Vec, + check_types: Vec>, types: Vec, } struct TypescriptClient { funcs: Vec, + check_types: Vec>, types: Vec, } @@ -37,6 +41,7 @@ impl From for AsyncTypescriptClient { fn from(value: TypescriptClient) -> Self { Self { funcs: value.funcs, + check_types: value.check_types, types: value.types, } } @@ -46,6 +51,7 @@ impl From for SyncTypescriptClient { fn from(value: TypescriptClient) -> Self { Self { funcs: value.funcs, + check_types: value.check_types, types: value.types, } } @@ -153,6 +159,8 @@ impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for TypescriptCli .flatten() .collect(); + let check_types = type_check_attributes(ir).iter().map(|checks| type_def_for_checks(checks)).collect(); + let types = ir .walk_classes() .map(|c| c.name().to_string()) @@ -160,6 +168,7 @@ impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for TypescriptCli .collect(); Ok(TypescriptClient { funcs: functions, + check_types, types, }) } @@ -295,6 +304,18 @@ 
impl ToTypeReferenceInClientDefinition for FieldType { .join(", ") ), FieldType::Optional(inner) => format!("{} | null", inner.to_type_ref(ir)), + FieldType::Constrained{base,..} => { + match field_type_attributes(self) { + Some(checks) => { + let base_type_ref = base.to_type_ref(ir); + let checks_type_ref = type_name_for_checks(&checks); + format!("Checked<{base_type_ref},{checks_type_ref}>") + } + None => { + base.to_type_ref(ir) + } + } + }, } } } diff --git a/engine/language_client_codegen/src/typescript/templates/index.ts.j2 b/engine/language_client_codegen/src/typescript/templates/index.ts.j2 index d07181474..677c84088 100644 --- a/engine/language_client_codegen/src/typescript/templates/index.ts.j2 +++ b/engine/language_client_codegen/src/typescript/templates/index.ts.j2 @@ -6,4 +6,4 @@ export { b } from "./sync_client" export * from "./types" export * from "./tracing" export { resetBamlEnvVars } from "./globals" -export { BamlValidationError } from "@boundaryml/baml" +export { BamlValidationError, Checked } from "@boundaryml/baml" diff --git a/engine/language_client_codegen/src/typescript/templates/types.ts.j2 b/engine/language_client_codegen/src/typescript/templates/types.ts.j2 index f560952ce..00fe870f3 100644 --- a/engine/language_client_codegen/src/typescript/templates/types.ts.j2 +++ b/engine/language_client_codegen/src/typescript/templates/types.ts.j2 @@ -8,6 +8,14 @@ export enum {{enum.name}} { } {% endfor %} +{%- for cls in check_classes %} +export interface {{cls.name}} { + {%- for (name, optional, type) in cls.fields %} + {{name}}: {{type}} + {%- endfor %} +} +{% endfor %} + {%- for cls in classes %} export interface {{cls.name}} { {%- for (name, optional, type) in cls.fields %} @@ -17,4 +25,4 @@ export interface {{cls.name}} { [key: string]: any; {%- endif %} } -{% endfor %} \ No newline at end of file +{% endfor %} diff --git a/engine/language_client_python/python_src/baml_py/__init__.py 
b/engine/language_client_python/python_src/baml_py/__init__.py index f5b32b514..49735260c 100644 --- a/engine/language_client_python/python_src/baml_py/__init__.py +++ b/engine/language_client_python/python_src/baml_py/__init__.py @@ -18,6 +18,7 @@ ) from .stream import BamlStream, BamlSyncStream from .ctx_manager import CtxManager as BamlCtxManager +from .constraints import Check, Checked __all__ = [ "BamlRuntime", diff --git a/engine/language_client_python/python_src/baml_py/constraints.py b/engine/language_client_python/python_src/baml_py/constraints.py new file mode 100644 index 000000000..45ecf8aeb --- /dev/null +++ b/engine/language_client_python/python_src/baml_py/constraints.py @@ -0,0 +1,14 @@ +from typing import Generic, Optional, TypeVar +from pydantic import BaseModel + +T = TypeVar('T') +K = TypeVar('K') + +class Check(BaseModel): + name: Optional[str] + expression: str + status: str + +class Checked(BaseModel, Generic[T,K]): + value: T + checks: K \ No newline at end of file diff --git a/engine/language_client_python/src/types/function_results.rs b/engine/language_client_python/src/types/function_results.rs index 427538360..f890dd559 100644 --- a/engine/language_client_python/src/types/function_results.rs +++ b/engine/language_client_python/src/types/function_results.rs @@ -23,6 +23,6 @@ impl FunctionResult { .parsed_content() .map_err(BamlError::from_anyhow)?; - Ok(pythonize(py, &BamlValue::from(parsed))?) + Ok(pythonize(py, &parsed)?) } } diff --git a/engine/language_client_ruby/Gemfile.lock b/engine/language_client_ruby/Gemfile.lock index 89d3522bf..62f0a54ee 100644 --- a/engine/language_client_ruby/Gemfile.lock +++ b/engine/language_client_ruby/Gemfile.lock @@ -1,7 +1,7 @@ PATH remote: . 
specs: - baml (0.52.1) + baml (0.60.0) GEM remote: https://rubygems.org/ diff --git a/engine/language_client_ruby/ext/ruby_ffi/src/function_result.rs b/engine/language_client_ruby/ext/ruby_ffi/src/function_result.rs index a3f23c7b1..e11547edd 100644 --- a/engine/language_client_ruby/ext/ruby_ffi/src/function_result.rs +++ b/engine/language_client_ruby/ext/ruby_ffi/src/function_result.rs @@ -1,4 +1,3 @@ -use baml_types::BamlValue; use magnus::{ class, exception::runtime_error, method, prelude::*, value::Value, Error, RModule, Ruby, }; @@ -39,7 +38,7 @@ impl FunctionResult { ) -> Result { match rb_self.inner.parsed_content() { Ok(parsed) => { - ruby_to_json::RubyToJson::serialize_baml(ruby, types, &BamlValue::from(parsed)) + ruby_to_json::RubyToJson::serialize_baml(ruby, types, parsed.clone()) .map_err(|e| { magnus::Error::new( ruby.exception_type_error(), diff --git a/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs b/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs index 310c1f5a6..0f8df7b70 100644 --- a/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs +++ b/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs @@ -1,7 +1,7 @@ -use baml_types::{BamlMap, BamlValue}; +use baml_types::{BamlValue, BamlMap, BamlValueWithMeta, ResponseCheck}; use indexmap::IndexMap; use magnus::{ - prelude::*, typed_data::Obj, value::Value, Error, Float, Integer, IntoValue, RArray, RClass, + prelude::*, typed_data::Obj, value::Value, class, Error, Float, Integer, IntoValue, RArray, RClass, RHash, RModule, RString, Ruby, Symbol, TypedData, }; use std::result::Result; @@ -26,57 +26,112 @@ impl<'rb> RubyToJson<'rb> { serde_magnus::serialize(&json) } - pub fn serialize_baml(ruby: &Ruby, types: RModule, from: &BamlValue) -> crate::Result { - match from { - BamlValue::Class(class_name, class_fields) => { - let hash = ruby.hash_new(); - for (k, v) in class_fields.iter() { - let k = ruby.sym_new(k.as_str()); - let v = RubyToJson::serialize_baml(ruby, 
types, v)?; - hash.aset(k, v)?; - } - match types.const_get::<_, RClass>(class_name.as_str()) { - Ok(class_type) => class_type.funcall("new", (hash,)), - Err(_) => { - let dynamic_class_type = ruby.eval::("Baml::DynamicStruct")?; - dynamic_class_type.funcall("new", (hash,)) + pub fn type_name_for_checks(checks: &Vec) -> String { + let mut name = "Checks".to_string(); + let mut names: Vec<&String> = checks.iter().map(|ResponseCheck{name, ..}| name).collect(); + names.sort(); + for check_name in names.iter() { + name.push_str("__"); + name.push_str(check_name); + } + name + } + + /// Serialize a list of check results into some `Checked__*` instance. + pub fn serialize_response_checks(ruby: &Ruby, checks: &Vec) -> crate::Result { + + let class_name = format!("Types::{}", Self::type_name_for_checks(checks)); + let checks_class = ruby.eval::(&class_name)?; + + // Create a `Check` for each check in the `Checked__*`. + let hash = ruby.hash_new(); + checks.iter().try_for_each(|ResponseCheck{name, expression, status}| { + let check_class = ruby.eval::("Baml::Checks::Check")?; + let check_hash = ruby.hash_new(); + check_hash.aset(ruby.sym_new("name"), name.as_str())?; + check_hash.aset(ruby.sym_new("expr"), expression.as_str())?; + check_hash.aset(ruby.sym_new("status"), status.as_str())?; + + let check: Value = check_class.funcall("new", (check_hash,))?; + hash.aset(ruby.sym_new(name.as_str()), check)?; + crate::Result::Ok(()) + })?; + + checks_class.funcall("new", (hash,)) + } + + pub fn serialize_baml(ruby: &Ruby, types: RModule, mut from: BamlValueWithMeta>) -> crate::Result { + + // If we encounter a BamlValue node with check results, serialize it as + // { value: T, checks: K }. To compute `value`, we strip the metadata + // off the node and pass it back to `serialize_baml`. 
+ if !from.meta().is_empty() { + let meta = from.meta().clone(); + let checks = Self::serialize_response_checks(ruby, &meta)?; + + *from.meta_mut() = vec![]; + let serialized_subvalue = Self::serialize_baml(ruby, types, from)?; + + let checked_class = ruby.eval::("Baml::Checked").expect("SHOWME"); + let hash = ruby.hash_new(); + hash.aset(ruby.sym_new("value"), serialized_subvalue)?; + hash.aset(ruby.sym_new("checks"), checks)?; + Ok(checked_class.funcall("new", (hash,)).expect("problem here")) + } + // Otherwise encode it directly. + else { + match from { + BamlValueWithMeta::Class(class_name, class_fields, _) => { + let hash = ruby.hash_new(); + for (k, v) in class_fields.into_iter() { + let k = ruby.sym_new(k.as_str()); + let v = RubyToJson::serialize_baml(ruby, types, v)?; + hash.aset(k, v)?; } - } - } - BamlValue::Enum(enum_name, enum_value) => { - if let Ok(enum_type) = types.const_get::<_, RClass>(enum_name.as_str()) { - let enum_value = ruby.str_new(enum_value); - if let Ok(enum_instance) = enum_type.funcall("deserialize", (enum_value,)) { - return Ok(enum_instance); + match types.const_get::<_, RClass>(class_name.as_str()) { + Ok(class_type) => class_type.funcall("new", (hash,)), + Err(_) => { + let dynamic_class_type = ruby.eval::("Baml::DynamicStruct")?; + dynamic_class_type.funcall("new", (hash,)) + } } } + BamlValueWithMeta::Enum(enum_name, enum_value, _) => { + if let Ok(enum_type) = types.const_get::<_, RClass>(enum_name.as_str()) { + let enum_value = ruby.str_new(&enum_value); + if let Ok(enum_instance) = enum_type.funcall("deserialize", (enum_value,)) { + return Ok(enum_instance); + } + } - Ok(ruby.str_new(enum_value).into_value_with(ruby)) - } - BamlValue::Map(m) => { - let hash = ruby.hash_new(); - for (k, v) in m.iter() { - let k = ruby.str_new(k); - let v = RubyToJson::serialize_baml(ruby, types, v)?; - hash.aset(k, v)?; + Ok(ruby.str_new(&enum_value).into_value_with(ruby)) } - Ok(hash.into_value_with(ruby)) - } - BamlValue::List(l) => { - let 
arr = ruby.ary_new(); - for v in l.iter() { - let v = RubyToJson::serialize_baml(ruby, types, v)?; - arr.push(v)?; + BamlValueWithMeta::Map(m,_) => { + let hash = ruby.hash_new(); + for (k, v) in m.into_iter() { + let k = ruby.str_new(&k); + let v = RubyToJson::serialize_baml(ruby, types, v)?; + hash.aset(k, v)?; + } + Ok(hash.into_value_with(ruby)) } - Ok(arr.into_value_with(ruby)) + BamlValueWithMeta::List(l, _) => { + let arr = ruby.ary_new(); + for v in l.into_iter() { + let v = RubyToJson::serialize_baml(ruby, types, v)?; + arr.push(v)?; + } + Ok(arr.into_value_with(ruby)) + } + _ => serde_magnus::serialize(&from), } - _ => serde_magnus::serialize(from), + } } pub fn serialize(ruby: &Ruby, types: RModule, from: Value) -> crate::Result { let json = RubyToJson::convert(from)?; - RubyToJson::serialize_baml(ruby, types, &json) + RubyToJson::serialize_baml(ruby, types, BamlValueWithMeta::with_default_meta(&json)) } /// Convert a Ruby object to a JSON object. diff --git a/engine/language_client_ruby/lib/baml.rb b/engine/language_client_ruby/lib/baml.rb index 7c0aa3887..430a5c1ab 100644 --- a/engine/language_client_ruby/lib/baml.rb +++ b/engine/language_client_ruby/lib/baml.rb @@ -7,12 +7,17 @@ # require_relative "baml/ruby_ffi" require_relative "stream" require_relative "struct" +require_relative "checked" module Baml ClientRegistry = Baml::Ffi::ClientRegistry Image = Baml::Ffi::Image Audio = Baml::Ffi::Audio + # Reexport Checked types. 
+ Checked = Baml::Checks::Checked + Check = Baml::Checks::Check + # Dynamically + idempotently define Baml::TypeConverter # NB: this does not respect raise_coercion_error = false def self.convert_to(type) @@ -53,4 +58,4 @@ def _convert(value, type, raise_coercion_error, coerce_empty_to_nil) Baml.const_get(:TypeConverter).new(type) end -end \ No newline at end of file +end diff --git a/engine/language_client_ruby/lib/checked.rb b/engine/language_client_ruby/lib/checked.rb new file mode 100644 index 000000000..5a6d2a0cd --- /dev/null +++ b/engine/language_client_ruby/lib/checked.rb @@ -0,0 +1,36 @@ +require "sorbet-runtime" + +module Baml + module Checks + + class Check < T::Struct + extend T::Sig + + const :name, String + const :expr, String + const :status, String + + + def initialize(props) + super(name: props[:name], expr: props[:expr], status: props[:status]) + end + end + + class Checked < T::Struct + extend T::Sig + extend T::Generic + + Value = type_member + Checks = type_member + + const :value, Value + const :checks, Checks + + def initialize(props) + super(value: props[:value], checks: props[:checks]) + end + + end + + end +end diff --git a/engine/language_client_typescript/checked.d.ts b/engine/language_client_typescript/checked.d.ts new file mode 100644 index 000000000..c5157cb19 --- /dev/null +++ b/engine/language_client_typescript/checked.d.ts @@ -0,0 +1,15 @@ +export interface Checked { + value: T; + checks: K; +} +export interface Check { + name: string; + expr: string; + status: "succeeded" | "failed"; +} +export interface BaseChecks { + [key: string]: Check; +} +export declare function all_succeeded(checks: K): boolean; +export declare function get_checks(checks: K): Check[]; +//# sourceMappingURL=checked.d.ts.map \ No newline at end of file diff --git a/engine/language_client_typescript/checked.d.ts.map b/engine/language_client_typescript/checked.d.ts.map new file mode 100644 index 000000000..5bbc89cf6 --- /dev/null +++ 
b/engine/language_client_typescript/checked.d.ts.map @@ -0,0 +1 @@ +{"version":3,"file":"checked.d.ts","sourceRoot":"","sources":["typescript_src/checked.ts"],"names":[],"mappings":"AAAA,MAAM,WAAW,OAAO,CAAC,CAAC,EAAC,CAAC,SAAS,UAAU;IAC3C,KAAK,EAAE,CAAC,CAAC;IACT,MAAM,EAAE,CAAC,CAAC;CACb;AAED,MAAM,WAAW,KAAK;IAClB,IAAI,EAAE,MAAM,CAAC;IACb,IAAI,EAAE,MAAM,CAAA;IACZ,MAAM,EAAE,WAAW,GAAG,QAAQ,CAAA;CACjC;AAED,MAAM,WAAW,UAAU;IACvB,CAAC,GAAG,EAAE,MAAM,GAAG,KAAK,CAAA;CACvB;AAED,wBAAgB,aAAa,CAAC,CAAC,SAAS,UAAU,EAAE,MAAM,EAAE,CAAC,GAAG,OAAO,CAEtE;AAED,wBAAgB,UAAU,CAAC,CAAC,SAAS,UAAU,EAAE,MAAM,EAAE,CAAC,GAAG,KAAK,EAAE,CAEnE"} \ No newline at end of file diff --git a/engine/language_client_typescript/checked.js b/engine/language_client_typescript/checked.js new file mode 100644 index 000000000..161044e9a --- /dev/null +++ b/engine/language_client_typescript/checked.js @@ -0,0 +1,11 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.get_checks = exports.all_succeeded = void 0; +function all_succeeded(checks) { + return Object.values(checks).every(value => value.status == "succeeded"); +} +exports.all_succeeded = all_succeeded; +function get_checks(checks) { + return Object.values(checks); +} +exports.get_checks = get_checks; diff --git a/engine/language_client_typescript/index.d.ts b/engine/language_client_typescript/index.d.ts index dafbd2ff7..2a991d309 100644 --- a/engine/language_client_typescript/index.d.ts +++ b/engine/language_client_typescript/index.d.ts @@ -1,6 +1,7 @@ export { BamlRuntime, FunctionResult, FunctionResultStream, BamlImage as Image, ClientBuilder, BamlAudio as Audio, invoke_runtime_cli, ClientRegistry, BamlLogEvent, } from './native'; export { BamlStream } from './stream'; export { BamlCtxManager } from './async_context_vars'; +export { Checked } from './checked'; export declare class BamlValidationError extends Error { prompt: string; raw_output: string; diff --git a/engine/language_client_typescript/index.d.ts.map 
b/engine/language_client_typescript/index.d.ts.map index 678922c63..f8fcf9ef8 100644 --- a/engine/language_client_typescript/index.d.ts.map +++ b/engine/language_client_typescript/index.d.ts.map @@ -1 +1 @@ -{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["typescript_src/index.ts"],"names":[],"mappings":"AAAA,OAAO,EACL,WAAW,EACX,cAAc,EACd,oBAAoB,EACpB,SAAS,IAAI,KAAK,EAClB,aAAa,EACb,SAAS,IAAI,KAAK,EAClB,kBAAkB,EAClB,cAAc,EACd,YAAY,GACb,MAAM,UAAU,CAAA;AACjB,OAAO,EAAE,UAAU,EAAE,MAAM,UAAU,CAAA;AACrC,OAAO,EAAE,cAAc,EAAE,MAAM,sBAAsB,CAAA;AAErD,qBAAa,mBAAoB,SAAQ,KAAK;IAC5C,MAAM,EAAE,MAAM,CAAA;IACd,UAAU,EAAE,MAAM,CAAA;gBAEN,MAAM,EAAE,MAAM,EAAE,UAAU,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM;IAS/D,MAAM,CAAC,IAAI,CAAC,KAAK,EAAE,KAAK,GAAG,mBAAmB,GAAG,KAAK;IAuBtD,MAAM,IAAI,MAAM;CAWjB;AAGD,wBAAgB,yBAAyB,CAAC,KAAK,EAAE,KAAK,GAAG,mBAAmB,GAAG,KAAK,CAEnF"} \ No newline at end of file +{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["typescript_src/index.ts"],"names":[],"mappings":"AAAA,OAAO,EACL,WAAW,EACX,cAAc,EACd,oBAAoB,EACpB,SAAS,IAAI,KAAK,EAClB,aAAa,EACb,SAAS,IAAI,KAAK,EAClB,kBAAkB,EAClB,cAAc,EACd,YAAY,GACb,MAAM,UAAU,CAAA;AACjB,OAAO,EAAE,UAAU,EAAE,MAAM,UAAU,CAAA;AACrC,OAAO,EAAE,cAAc,EAAE,MAAM,sBAAsB,CAAA;AACrD,OAAO,EAAE,OAAO,EAAE,MAAM,WAAW,CAAA;AAEnC,qBAAa,mBAAoB,SAAQ,KAAK;IAC5C,MAAM,EAAE,MAAM,CAAA;IACd,UAAU,EAAE,MAAM,CAAA;gBAEN,MAAM,EAAE,MAAM,EAAE,UAAU,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM;IAS/D,MAAM,CAAC,IAAI,CAAC,KAAK,EAAE,KAAK,GAAG,mBAAmB,GAAG,KAAK;IAuBtD,MAAM,IAAI,MAAM;CAWjB;AAGD,wBAAgB,yBAAyB,CAAC,KAAK,EAAE,KAAK,GAAG,mBAAmB,GAAG,KAAK,CAEnF"} \ No newline at end of file diff --git a/engine/language_client_typescript/package.json b/engine/language_client_typescript/package.json index ef44139d3..492887816 100644 --- a/engine/language_client_typescript/package.json +++ b/engine/language_client_typescript/package.json @@ -29,7 +29,10 @@ "./stream.d.ts", "./stream.js", "./type_builder.d.ts", - "./type_builder.js" + "./type_builder.js", + "./checked.js", + "./checked.d.ts", + 
"./checked.d.ts.map" ], "main": "./index.js", "types": "./index.d.ts", diff --git a/engine/language_client_typescript/src/runtime.rs b/engine/language_client_typescript/src/runtime.rs index 570872a4f..76fdc561f 100644 --- a/engine/language_client_typescript/src/runtime.rs +++ b/engine/language_client_typescript/src/runtime.rs @@ -346,7 +346,7 @@ impl BamlRuntime { } #[napi] - pub fn flush(&mut self, env: Env) -> napi::Result<()> { + pub fn flush(&mut self, _env: Env) -> napi::Result<()> { self.inner.flush().map_err(|e| from_anyhow_error(e)) } diff --git a/engine/language_client_typescript/src/types/function_results.rs b/engine/language_client_typescript/src/types/function_results.rs index 866819af8..93b25eead 100644 --- a/engine/language_client_typescript/src/types/function_results.rs +++ b/engine/language_client_typescript/src/types/function_results.rs @@ -1,4 +1,3 @@ -use baml_types::BamlValue; use napi_derive::napi; use crate::errors::from_anyhow_error; @@ -23,6 +22,6 @@ impl FunctionResult { .parsed_content() .map_err(|e| from_anyhow_error(e))?; - Ok(serde_json::json!(BamlValue::from(parsed))) + Ok(serde_json::to_value(parsed)?) 
} } diff --git a/engine/language_client_typescript/typescript_src/checked.ts b/engine/language_client_typescript/typescript_src/checked.ts new file mode 100644 index 000000000..9f5e8032c --- /dev/null +++ b/engine/language_client_typescript/typescript_src/checked.ts @@ -0,0 +1,22 @@ +export interface Checked { + value: T, + checks: K, +} + +export interface Check { + name: string, + expr: string + status: "succeeded" | "failed" +} + +export interface BaseChecks { + [key: string]: Check +} + +export function all_succeeded(checks: K): boolean { + return Object.values(checks).every(value => value.status == "succeeded"); +} + +export function get_checks(checks: K): Check[] { + return Object.values(checks) +} diff --git a/engine/language_client_typescript/typescript_src/index.ts b/engine/language_client_typescript/typescript_src/index.ts index 27bd94611..c4f5a70ff 100644 --- a/engine/language_client_typescript/typescript_src/index.ts +++ b/engine/language_client_typescript/typescript_src/index.ts @@ -11,6 +11,7 @@ export { } from './native' export { BamlStream } from './stream' export { BamlCtxManager } from './async_context_vars' +export { Checked } from './checked' export class BamlValidationError extends Error { prompt: string diff --git a/integ-tests/baml_src/test-files/constraints/constraints.baml b/integ-tests/baml_src/test-files/constraints/constraints.baml new file mode 100644 index 000000000..3d0266ad4 --- /dev/null +++ b/integ-tests/baml_src/test-files/constraints/constraints.baml @@ -0,0 +1,75 @@ +// These classes and functions test several properties of +// constrains: +// +// - The ability for constrains on fields to pass or fail. +// - The ability for constraints on bare args and return types to pass or fail. +// - The ability of constraints to influence which variant of a union is chosen +// by the parser, when the structure is not sufficient to decide. 
+ + +class Martian { + age int @check({{ this < 30 }}, young_enough) +} + +class Earthling { + age int @check({{this < 200 and this > 0}}, earth_aged) @check({{this >1}}, no_infants) +} + + +class FooAny { + planetary_age Martian | Earthling + certainty int @check({{this == 102931}}, unreasonably_certain) + species string @check({{this == "Homo sapiens"}}, trivial) @check({{this|regex_match("Homo")}}, regex_good) @check({{this|regex_match("neanderthalensis")}}, regex_bad) +} + + +function PredictAge(name: string) -> FooAny { + client GPT35 + prompt #" + Using your understanding of the historical popularity + of names, predict the age of a person with the name + {{ name }} in years. Also predict their genus and + species. It's Homo sapiens (with exactly that spelling + and capitalization). I'll give you a hint: If the name + is "Greg", his age is 41. + + {{ctx.output_format}} + "# +} + + +function PredictAgeBare(inp: string @assert({{this|length > 1}}, big_enough)) -> int @check({{this == 10102}}, too_big) { + client GPT35 + prompt #" + Using your understanding of the historical popularity + of names, predict the age of a person with the name + {{ inp.name }} in years. Also predict their genus and + species. It's Homo sapiens (with exactly that spelling). + + {{ctx.output_format}} + "# +} + +function ReturnFailingAssert(inp: int @assert({{this < 10}}, small_int)) -> int @assert({{this > 100}}, big_int) { + client GPT35 + prompt #" + Return the next integer after {{ inp }}. + + {{ctx.output_format}} + "# +} + +class TwoStoriesOneTitle { + title string + story_a string @assert( {{this|length > 1000000}}, too_long_story ) + story_b string @assert( {{this|length > 1000000}}, too_long_story ) +} + +function StreamFailingAssertion(theme: string, length: int) -> TwoStoriesOneTitle { + client GPT35 + prompt #" + Tell me two different stories along the theme of {{ theme }} with the same title. + Please make each about {{ length }} words long. 
+ {{ctx.output_format}} + "# +} diff --git a/integ-tests/baml_src/test-files/constraints/contact-info.baml b/integ-tests/baml_src/test-files/constraints/contact-info.baml new file mode 100644 index 000000000..d08903744 --- /dev/null +++ b/integ-tests/baml_src/test-files/constraints/contact-info.baml @@ -0,0 +1,24 @@ +class PhoneNumber { + value string @check({{this|regex_match("\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}")}}, valid_phone_number) +} + +class EmailAddress { + value string @check({{this|regex_match("^[_]*([a-z0-9]+(\.|_*)?)+@([a-z][a-z0-9-]+(\.|-*\.))+[a-z]{2,6}$")}}, valid_email) +} + +class ContactInfo { + primary PhoneNumber | EmailAddress + secondary (PhoneNumber | EmailAddress)? +} + +function ExtractContactInfo(document: string) -> ContactInfo { + client GPT35 + prompt #" + Extract a primary contact info, and if possible a secondary contact + info, from this document: + + {{ document }} + + {{ ctx.output_format }} + "# +} diff --git a/integ-tests/openapi/baml_client/openapi.yaml b/integ-tests/openapi/baml_client/openapi.yaml index d9f81ccb0..114ed8a70 100644 --- a/integ-tests/openapi/baml_client/openapi.yaml +++ b/integ-tests/openapi/baml_client/openapi.yaml @@ -232,6 +232,19 @@ paths: title: ExpectFailureResponse type: string operationId: ExpectFailure + /call/ExtractContactInfo: + post: + requestBody: + $ref: '#/components/requestBodies/ExtractContactInfo' + responses: + '200': + description: Successful operation + content: + application/json: + schema: + title: ExtractContactInfoResponse + $ref: '#/components/schemas/ContactInfo' + operationId: ExtractContactInfo /call/ExtractNames: post: requestBody: @@ -556,6 +569,47 @@ paths: items: $ref: '#/components/schemas/OptionalTest_ReturnType' operationId: OptionalTest_Function + /call/PredictAge: + post: + requestBody: + $ref: '#/components/requestBodies/PredictAge' + responses: + '200': + description: Successful operation + content: + application/json: + schema: + title: PredictAgeResponse + $ref: 
'#/components/schemas/FooAny' + operationId: PredictAge + /call/PredictAgeBare: + post: + requestBody: + $ref: '#/components/requestBodies/PredictAgeBare' + responses: + '200': + description: Successful operation + content: + application/json: + schema: + title: PredictAgeBareResponse + type: object + properties: + value: + type: integer + checks: + type: object + properties: + too_big: + $ref: '#components/schemas/Check' + required: + - too_big + additionalProperties: false + required: + - value + - checks + additionalProperties: false + operationId: PredictAgeBare /call/PromptTestClaude: post: requestBody: @@ -647,6 +701,19 @@ paths: title: PromptTestStreamingResponse type: string operationId: PromptTestStreaming + /call/ReturnFailingAssert: + post: + requestBody: + $ref: '#/components/requestBodies/ReturnFailingAssert' + responses: + '200': + description: Successful operation + content: + application/json: + schema: + title: ReturnFailingAssertResponse + type: integer + operationId: ReturnFailingAssert /call/SchemaDescriptions: post: requestBody: @@ -673,6 +740,19 @@ paths: title: StreamBigNumbersResponse $ref: '#/components/schemas/BigNumbers' operationId: StreamBigNumbers + /call/StreamFailingAssertion: + post: + requestBody: + $ref: '#/components/requestBodies/StreamFailingAssertion' + responses: + '200': + description: Successful operation + content: + application/json: + schema: + title: StreamFailingAssertionResponse + $ref: '#/components/schemas/TwoStoriesOneTitle' + operationId: StreamFailingAssertion /call/StreamOneBigNumber: post: requestBody: @@ -1380,6 +1460,22 @@ components: $ref: '#/components/schemas/BamlOptions' required: [] additionalProperties: false + ExtractContactInfo: + required: true + content: + application/json: + schema: + title: ExtractContactInfoRequest + type: object + properties: + document: + type: string + __baml_options__: + nullable: true + $ref: '#/components/schemas/BamlOptions' + required: + - document + additionalProperties: 
false ExtractNames: required: true content: @@ -1770,6 +1866,38 @@ components: required: - input additionalProperties: false + PredictAge: + required: true + content: + application/json: + schema: + title: PredictAgeRequest + type: object + properties: + name: + type: string + __baml_options__: + nullable: true + $ref: '#/components/schemas/BamlOptions' + required: + - name + additionalProperties: false + PredictAgeBare: + required: true + content: + application/json: + schema: + title: PredictAgeBareRequest + type: object + properties: + inp: + type: string + __baml_options__: + nullable: true + $ref: '#/components/schemas/BamlOptions' + required: + - inp + additionalProperties: false PromptTestClaude: required: true content: @@ -1882,6 +2010,22 @@ components: required: - input additionalProperties: false + ReturnFailingAssert: + required: true + content: + application/json: + schema: + title: ReturnFailingAssertRequest + type: object + properties: + inp: + type: integer + __baml_options__: + nullable: true + $ref: '#/components/schemas/BamlOptions' + required: + - inp + additionalProperties: false SchemaDescriptions: required: true content: @@ -1914,6 +2058,25 @@ components: required: - digits additionalProperties: false + StreamFailingAssertion: + required: true + content: + application/json: + schema: + title: StreamFailingAssertionRequest + type: object + properties: + theme: + type: string + length: + type: integer + __baml_options__: + nullable: true + $ref: '#/components/schemas/BamlOptions' + required: + - theme + - length + additionalProperties: false StreamOneBigNumber: required: true content: @@ -2532,6 +2695,15 @@ components: - name - provider - options + Check: + type: object + properties: + name: + type: string + expr: + type: string + status: + type: string Category: enum: - Refund @@ -2719,6 +2891,20 @@ components: - big_nums - another additionalProperties: false + ContactInfo: + type: object + properties: + primary: + oneOf: + - $ref: 
'#/components/schemas/PhoneNumber' + - $ref: '#/components/schemas/EmailAddress' + secondary: + oneOf: + - $ref: '#/components/schemas/PhoneNumber' + - $ref: '#/components/schemas/EmailAddress' + required: + - primary + additionalProperties: false CustomTaskResult: type: object properties: @@ -2776,6 +2962,32 @@ components: properties: {} required: [] additionalProperties: false + Earthling: + type: object + properties: + age: + type: object + properties: + value: + type: integer + checks: + type: object + properties: + earth_aged: + $ref: '#components/schemas/Check' + no_infants: + $ref: '#components/schemas/Check' + required: + - earth_aged + - no_infants + additionalProperties: false + required: + - value + - checks + additionalProperties: false + required: + - age + additionalProperties: false Education: type: object properties: @@ -2811,6 +3023,29 @@ components: - body - from_address additionalProperties: false + EmailAddress: + type: object + properties: + value: + type: object + properties: + value: + type: string + checks: + type: object + properties: + valid_email: + $ref: '#components/schemas/Check' + required: + - valid_email + additionalProperties: false + required: + - value + - checks + additionalProperties: false + required: + - value + additionalProperties: false Event: type: object properties: @@ -2856,6 +3091,58 @@ components: - arrivalTime - seatNumber additionalProperties: false + FooAny: + type: object + properties: + planetary_age: + oneOf: + - $ref: '#/components/schemas/Martian' + - $ref: '#/components/schemas/Earthling' + certainty: + type: object + properties: + value: + type: integer + checks: + type: object + properties: + unreasonably_certain: + $ref: '#components/schemas/Check' + required: + - unreasonably_certain + additionalProperties: false + required: + - value + - checks + additionalProperties: false + species: + type: object + properties: + value: + type: string + checks: + type: object + properties: + regex_bad: + $ref: 
'#components/schemas/Check' + trivial: + $ref: '#components/schemas/Check' + regex_good: + $ref: '#components/schemas/Check' + required: + - regex_bad + - trivial + - regex_good + additionalProperties: false + required: + - value + - checks + additionalProperties: false + required: + - planetary_age + - certainty + - species + additionalProperties: false GroceryReceipt: type: object properties: @@ -2903,6 +3190,29 @@ components: - prop2 - prop3 additionalProperties: false + Martian: + type: object + properties: + age: + type: object + properties: + value: + type: integer + checks: + type: object + properties: + young_enough: + $ref: '#components/schemas/Check' + required: + - young_enough + additionalProperties: false + required: + - value + - checks + additionalProperties: false + required: + - age + additionalProperties: false NamedArgsSingleClass: type: object properties: @@ -2988,6 +3298,29 @@ components: $ref: '#/components/schemas/Color' required: [] additionalProperties: false + PhoneNumber: + type: object + properties: + value: + type: object + properties: + value: + type: string + checks: + type: object + properties: + valid_phone_number: + $ref: '#components/schemas/Check' + required: + - valid_phone_number + additionalProperties: false + required: + - value + - checks + additionalProperties: false + required: + - value + additionalProperties: false Quantity: type: object properties: @@ -3230,6 +3563,20 @@ components: - prop1 - prop2 additionalProperties: false + TwoStoriesOneTitle: + type: object + properties: + title: + type: string + story_a: + type: string + story_b: + type: string + required: + - title + - story_a + - story_b + additionalProperties: false UnionTest_ReturnType: type: object properties: diff --git a/integ-tests/python/baml_client/async_client.py b/integ-tests/python/baml_client/async_client.py index f3a0d202a..a59347c31 100644 --- a/integ-tests/python/baml_client/async_client.py +++ b/integ-tests/python/baml_client/async_client.py @@ 
-443,6 +443,30 @@ async def ExpectFailure( mdl = create_model("ExpectFailureReturnType", inner=(str, ...)) return coerce(mdl, raw.parsed()) + async def ExtractContactInfo( + self, + document: str, + baml_options: BamlCallOptions = {}, + ) -> types.ContactInfo: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "ExtractContactInfo", + { + "document": document, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + mdl = create_model("ExtractContactInfoReturnType", inner=(types.ContactInfo, ...)) + return coerce(mdl, raw.parsed()) + async def ExtractNames( self, input: str, @@ -1019,6 +1043,54 @@ async def OptionalTest_Function( mdl = create_model("OptionalTest_FunctionReturnType", inner=(List[Optional[types.OptionalTest_ReturnType]], ...)) return coerce(mdl, raw.parsed()) + async def PredictAge( + self, + name: str, + baml_options: BamlCallOptions = {}, + ) -> types.FooAny: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "PredictAge", + { + "name": name, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + mdl = create_model("PredictAgeReturnType", inner=(types.FooAny, ...)) + return coerce(mdl, raw.parsed()) + + async def PredictAgeBare( + self, + inp: str, + baml_options: BamlCallOptions = {}, + ) -> baml_py.Checked[int,types.Checks__too_big]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "PredictAgeBare", + { + "inp": inp, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + mdl = create_model("PredictAgeBareReturnType", inner=(baml_py.Checked[int,types.Checks__too_big], ...)) + return coerce(mdl, 
raw.parsed()) + async def PromptTestClaude( self, input: str, @@ -1187,6 +1259,30 @@ async def PromptTestStreaming( mdl = create_model("PromptTestStreamingReturnType", inner=(str, ...)) return coerce(mdl, raw.parsed()) + async def ReturnFailingAssert( + self, + inp: int, + baml_options: BamlCallOptions = {}, + ) -> int: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "ReturnFailingAssert", + { + "inp": inp, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + mdl = create_model("ReturnFailingAssertReturnType", inner=(int, ...)) + return coerce(mdl, raw.parsed()) + async def SchemaDescriptions( self, input: str, @@ -1235,6 +1331,30 @@ async def StreamBigNumbers( mdl = create_model("StreamBigNumbersReturnType", inner=(types.BigNumbers, ...)) return coerce(mdl, raw.parsed()) + async def StreamFailingAssertion( + self, + theme: str,length: int, + baml_options: BamlCallOptions = {}, + ) -> types.TwoStoriesOneTitle: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "StreamFailingAssertion", + { + "theme": theme,"length": length, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + mdl = create_model("StreamFailingAssertionReturnType", inner=(types.TwoStoriesOneTitle, ...)) + return coerce(mdl, raw.parsed()) + async def StreamOneBigNumber( self, digits: int, @@ -2568,6 +2688,39 @@ def ExpectFailure( self.__ctx_manager.get(), ) + def ExtractContactInfo( + self, + document: str, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[partial_types.ContactInfo, types.ContactInfo]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = 
self.__runtime.stream_function( + "ExtractContactInfo", + { + "document": document, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + mdl = create_model("ExtractContactInfoReturnType", inner=(types.ContactInfo, ...)) + partial_mdl = create_model("ExtractContactInfoPartialReturnType", inner=(partial_types.ContactInfo, ...)) + + return baml_py.BamlStream[partial_types.ContactInfo, types.ContactInfo]( + raw, + lambda x: coerce(partial_mdl, x), + lambda x: coerce(mdl, x), + self.__ctx_manager.get(), + ) + def ExtractNames( self, input: str, @@ -3362,6 +3515,72 @@ def OptionalTest_Function( self.__ctx_manager.get(), ) + def PredictAge( + self, + name: str, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[partial_types.FooAny, types.FooAny]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "PredictAge", + { + "name": name, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + mdl = create_model("PredictAgeReturnType", inner=(types.FooAny, ...)) + partial_mdl = create_model("PredictAgePartialReturnType", inner=(partial_types.FooAny, ...)) + + return baml_py.BamlStream[partial_types.FooAny, types.FooAny]( + raw, + lambda x: coerce(partial_mdl, x), + lambda x: coerce(mdl, x), + self.__ctx_manager.get(), + ) + + def PredictAgeBare( + self, + inp: str, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[baml_py.Checked[Optional[int],types.Checks__too_big], baml_py.Checked[int,types.Checks__too_big]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "PredictAgeBare", + { + "inp": inp, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + mdl = create_model("PredictAgeBareReturnType", 
inner=(baml_py.Checked[int,types.Checks__too_big], ...)) + partial_mdl = create_model("PredictAgeBarePartialReturnType", inner=(baml_py.Checked[Optional[int],types.Checks__too_big], ...)) + + return baml_py.BamlStream[baml_py.Checked[Optional[int],types.Checks__too_big], baml_py.Checked[int,types.Checks__too_big]]( + raw, + lambda x: coerce(partial_mdl, x), + lambda x: coerce(mdl, x), + self.__ctx_manager.get(), + ) + def PromptTestClaude( self, input: str, @@ -3593,6 +3812,39 @@ def PromptTestStreaming( self.__ctx_manager.get(), ) + def ReturnFailingAssert( + self, + inp: int, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[Optional[int], int]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "ReturnFailingAssert", + { + "inp": inp, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + mdl = create_model("ReturnFailingAssertReturnType", inner=(int, ...)) + partial_mdl = create_model("ReturnFailingAssertPartialReturnType", inner=(Optional[int], ...)) + + return baml_py.BamlStream[Optional[int], int]( + raw, + lambda x: coerce(partial_mdl, x), + lambda x: coerce(mdl, x), + self.__ctx_manager.get(), + ) + def SchemaDescriptions( self, input: str, @@ -3659,6 +3911,40 @@ def StreamBigNumbers( self.__ctx_manager.get(), ) + def StreamFailingAssertion( + self, + theme: str,length: int, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[partial_types.TwoStoriesOneTitle, types.TwoStoriesOneTitle]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "StreamFailingAssertion", + { + "theme": theme, + "length": length, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + mdl = create_model("StreamFailingAssertionReturnType", 
inner=(types.TwoStoriesOneTitle, ...)) + partial_mdl = create_model("StreamFailingAssertionPartialReturnType", inner=(partial_types.TwoStoriesOneTitle, ...)) + + return baml_py.BamlStream[partial_types.TwoStoriesOneTitle, types.TwoStoriesOneTitle]( + raw, + lambda x: coerce(partial_mdl, x), + lambda x: coerce(mdl, x), + self.__ctx_manager.get(), + ) + def StreamOneBigNumber( self, digits: int, diff --git a/integ-tests/python/baml_client/inlinedbaml.py b/integ-tests/python/baml_client/inlinedbaml.py index 0c06b842b..6883c9f4e 100644 --- a/integ-tests/python/baml_client/inlinedbaml.py +++ b/integ-tests/python/baml_client/inlinedbaml.py @@ -29,6 +29,8 @@ "test-files/aliases/classes.baml": "class TestClassAlias {\n key string @alias(\"key-dash\") @description(#\"\n This is a description for key\n af asdf\n \"#)\n key2 string @alias(\"key21\")\n key3 string @alias(\"key with space\")\n key4 string //unaliased\n key5 string @alias(\"key.with.punctuation/123\")\n}\n\nfunction FnTestClassAlias(input: string) -> TestClassAlias {\n client GPT35\n prompt #\"\n {{ctx.output_format}}\n \"#\n}\n\ntest FnTestClassAlias {\n functions [FnTestClassAlias]\n args {\n input \"example input\"\n }\n}\n", "test-files/aliases/enums.baml": "enum TestEnum {\n A @alias(\"k1\") @description(#\"\n User is angry\n \"#)\n B @alias(\"k22\") @description(#\"\n User is happy\n \"#)\n // tests whether k1 doesnt incorrectly get matched with k11\n C @alias(\"k11\") @description(#\"\n User is sad\n \"#)\n D @alias(\"k44\") @description(\n User is confused\n )\n E @description(\n User is excited\n )\n F @alias(\"k5\") // only alias\n \n G @alias(\"k6\") @description(#\"\n User is bored\n With a long description\n \"#)\n \n @@alias(\"Category\")\n}\n\nfunction FnTestAliasedEnumOutput(input: string) -> TestEnum {\n client GPT35\n prompt #\"\n Classify the user input into the following category\n \n {{ ctx.output_format }}\n\n {{ _.role('user') }}\n {{input}}\n\n {{ _.role('assistant') }}\n Category ID:\n 
\"#\n}\n\ntest FnTestAliasedEnumOutput {\n functions [FnTestAliasedEnumOutput]\n args {\n input \"mehhhhh\"\n }\n}", "test-files/comments/comments.baml": "// add some functions, classes, enums etc with comments all over.", + "test-files/constraints/constraints.baml": "// These classes and functions test several properties of\n// constrains:\n//\n// - The ability for constrains on fields to pass or fail.\n// - The ability for constraints on bare args and return types to pass or fail.\n// - The ability of constraints to influence which variant of a union is chosen\n// by the parser, when the structure is not sufficient to decide.\n\n\nclass Martian {\n age int @check({{ this < 30 }}, young_enough)\n}\n\nclass Earthling {\n age int @check({{this < 200 and this > 0}}, earth_aged) @check({{this >1}}, no_infants)\n}\n\n\nclass FooAny {\n planetary_age Martian | Earthling\n certainty int @check({{this == 102931}}, unreasonably_certain)\n species string @check({{this == \"Homo sapiens\"}}, trivial) @check({{this|regex_match(\"Homo\")}}, regex_good) @check({{this|regex_match(\"neanderthalensis\")}}, regex_bad)\n}\n\n\nfunction PredictAge(name: string) -> FooAny {\n client GPT35\n prompt #\"\n Using your understanding of the historical popularity\n of names, predict the age of a person with the name\n {{ name }} in years. Also predict their genus and\n species. It's Homo sapiens (with exactly that spelling\n and capitalization). I'll give you a hint: If the name\n is \"Greg\", his age is 41.\n\n {{ctx.output_format}}\n \"#\n}\n\n\nfunction PredictAgeBare(inp: string @assert({{this|length > 1}}, big_enough)) -> int @check({{this == 10102}}, too_big) {\n client GPT35\n prompt #\"\n Using your understanding of the historical popularity\n of names, predict the age of a person with the name\n {{ inp.name }} in years. Also predict their genus and\n species. 
It's Homo sapiens (with exactly that spelling).\n\n {{ctx.output_format}}\n \"#\n}\n\nfunction ReturnFailingAssert(inp: int @assert({{this < 10}}, small_int)) -> int @assert({{this > 100}}, big_int) {\n client GPT35\n prompt #\"\n Return the next integer after {{ inp }}.\n\n {{ctx.output_format}}\n \"#\n}\n\nclass TwoStoriesOneTitle {\n title string\n story_a string @assert( {{this|length > 1000000}}, too_long_story )\n story_b string @assert( {{this|length > 1000000}}, too_long_story )\n}\n\nfunction StreamFailingAssertion(theme: string, length: int) -> TwoStoriesOneTitle {\n client GPT35\n prompt #\"\n Tell me two different stories along the theme of {{ theme }} with the same title.\n Please make each about {{ length }} words long.\n {{ctx.output_format}}\n \"#\n}\n", + "test-files/constraints/contact-info.baml": "class PhoneNumber {\n value string @check({{this|regex_match(\"\\(?\\d{3}\\)?[-.\\s]?\\d{3}[-.\\s]?\\d{4}\")}}, valid_phone_number)\n}\n\nclass EmailAddress {\n value string @check({{this|regex_match(\"^[_]*([a-z0-9]+(\\.|_*)?)+@([a-z][a-z0-9-]+(\\.|-*\\.))+[a-z]{2,6}$\")}}, valid_email)\n}\n\nclass ContactInfo {\n primary PhoneNumber | EmailAddress\n secondary (PhoneNumber | EmailAddress)?\n}\n\nfunction ExtractContactInfo(document: string) -> ContactInfo {\n client GPT35\n prompt #\"\n Extract a primary contact info, and if possible a secondary contact\n info, from this document:\n\n {{ document }}\n\n {{ ctx.output_format }}\n \"#\n}\n", "test-files/descriptions/descriptions.baml": "\nclass Nested {\n prop3 string | null @description(#\"\n write \"three\"\n \"#)\n prop4 string | null @description(#\"\n write \"four\"\n \"#) @alias(\"blah\")\n prop20 Nested2\n}\n\nclass Nested2 {\n prop11 string | null @description(#\"\n write \"three\"\n \"#)\n prop12 string | null @description(#\"\n write \"four\"\n \"#) @alias(\"blah\")\n}\n\nclass Schema {\n prop1 string | null @description(#\"\n write \"one\"\n \"#)\n prop2 Nested | string @description(#\"\n 
write \"two\"\n \"#)\n prop5 (string | null)[] @description(#\"\n write \"hi\"\n \"#)\n prop6 string | Nested[] @alias(\"blah\") @description(#\"\n write the string \"blah\" regardless of the other types here\n \"#)\n nested_attrs (string | null | Nested)[] @description(#\"\n write the string \"nested\" regardless of other types\n \"#)\n parens (string | null) @description(#\"\n write \"parens1\"\n \"#)\n other_group (string | (int | string)) @description(#\"\n write \"other\"\n \"#) @alias(other)\n}\n\n\nfunction SchemaDescriptions(input: string) -> Schema {\n client GPT4o\n prompt #\"\n Return a schema with this format:\n\n {{ctx.output_format}}\n \"#\n}", "test-files/dynamic/client-registry.baml": "// Intentionally use a bad key\nclient BadClient {\n provider openai\n options {\n model \"gpt-3.5-turbo\"\n api_key \"sk-invalid\"\n }\n}\n\nfunction ExpectFailure() -> string {\n client BadClient\n\n prompt #\"\n What is the capital of England?\n \"#\n}\n", "test-files/dynamic/dynamic.baml": "class DynamicClassOne {\n @@dynamic\n}\n\nenum DynEnumOne {\n @@dynamic\n}\n\nenum DynEnumTwo {\n @@dynamic\n}\n\nclass SomeClassNestedDynamic {\n hi string\n @@dynamic\n\n}\n\nclass DynamicClassTwo {\n hi string\n some_class SomeClassNestedDynamic\n status DynEnumOne\n @@dynamic\n}\n\nfunction DynamicFunc(input: DynamicClassOne) -> DynamicClassTwo {\n client GPT35\n prompt #\"\n Please extract the schema from \n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nclass DynInputOutput {\n testKey string\n @@dynamic\n}\n\nfunction DynamicInputOutput(input: DynInputOutput) -> DynInputOutput {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ ctx.output_format }}\n \"#\n}\n\nfunction DynamicListInputOutput(input: DynInputOutput[]) -> DynInputOutput[] {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ ctx.output_format }}\n \"#\n}\n\n\n\nclass 
DynamicOutput {\n @@dynamic\n}\n \nfunction MyFunc(input: string) -> DynamicOutput {\n client GPT35\n prompt #\"\n Given a string, extract info using the schema:\n\n {{ input}}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ClassifyDynEnumTwo(input: string) -> DynEnumTwo {\n client GPT35\n prompt #\"\n Given a string, extract info using the schema:\n\n {{ input}}\n\n {{ ctx.output_format }}\n \"#\n}", diff --git a/integ-tests/python/baml_client/partial_types.py b/integ-tests/python/baml_client/partial_types.py index 26b0b2f19..73714ecd0 100644 --- a/integ-tests/python/baml_client/partial_types.py +++ b/integ-tests/python/baml_client/partial_types.py @@ -20,6 +20,10 @@ from . import types + +from .types import Checks__young_enough, Checks__earth_aged__no_infants, Checks__too_big, Checks__valid_email, Checks__unreasonably_certain, Checks__regex_bad__regex_good__trivial, Checks__valid_phone_number + + ############################################################################### # # These types are used for streaming, for when an instance of a type @@ -74,6 +78,12 @@ class CompoundBigNumbers(BaseModel): big_nums: List["BigNumbers"] another: Optional["BigNumbers"] = None +class ContactInfo(BaseModel): + + + primary: Optional[Union["PhoneNumber", "EmailAddress"]] = None + secondary: Optional[Union["PhoneNumber", "EmailAddress", Optional[None]]] = None + class CustomTaskResult(BaseModel): @@ -112,6 +122,11 @@ class DynamicOutput(BaseModel): model_config = ConfigDict(extra='allow') +class Earthling(BaseModel): + + + age: baml_py.Checked[Optional[int],Checks__earth_aged__no_infants] + class Education(BaseModel): @@ -128,6 +143,11 @@ class Email(BaseModel): body: Optional[str] = None from_address: Optional[str] = None +class EmailAddress(BaseModel): + + + value: baml_py.Checked[Optional[str],Checks__valid_email] + class Event(BaseModel): @@ -150,6 +170,13 @@ class FlightConfirmation(BaseModel): arrivalTime: Optional[str] = None seatNumber: Optional[str] = None +class 
FooAny(BaseModel): + + + planetary_age: Optional[Union["Martian", "Earthling"]] = None + certainty: baml_py.Checked[Optional[int],Checks__unreasonably_certain] + species: baml_py.Checked[Optional[str],Checks__regex_bad__regex_good__trivial] + class GroceryReceipt(BaseModel): @@ -171,6 +198,11 @@ class InnerClass2(BaseModel): prop2: Optional[int] = None prop3: Optional[float] = None +class Martian(BaseModel): + + + age: baml_py.Checked[Optional[int],Checks__young_enough] + class NamedArgsSingleClass(BaseModel): @@ -218,6 +250,11 @@ class Person(BaseModel): name: Optional[str] = None hair_color: Optional[Union[types.Color, str]] = None +class PhoneNumber(BaseModel): + + + value: baml_py.Checked[Optional[str],Checks__valid_phone_number] + class Quantity(BaseModel): @@ -320,6 +357,13 @@ class TestOutputClass(BaseModel): prop1: Optional[str] = None prop2: Optional[int] = None +class TwoStoriesOneTitle(BaseModel): + + + title: Optional[str] = None + story_a: Optional[str] = None + story_b: Optional[str] = None + class UnionTest_ReturnType(BaseModel): diff --git a/integ-tests/python/baml_client/sync_client.py b/integ-tests/python/baml_client/sync_client.py index 2d127ac97..1b122859e 100644 --- a/integ-tests/python/baml_client/sync_client.py +++ b/integ-tests/python/baml_client/sync_client.py @@ -441,6 +441,30 @@ def ExpectFailure( mdl = create_model("ExpectFailureReturnType", inner=(str, ...)) return coerce(mdl, raw.parsed()) + def ExtractContactInfo( + self, + document: str, + baml_options: BamlCallOptions = {}, + ) -> types.ContactInfo: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "ExtractContactInfo", + { + "document": document, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + mdl = create_model("ExtractContactInfoReturnType", inner=(types.ContactInfo, ...)) + return coerce(mdl, raw.parsed()) + def 
ExtractNames( self, input: str, @@ -1017,6 +1041,54 @@ def OptionalTest_Function( mdl = create_model("OptionalTest_FunctionReturnType", inner=(List[Optional[types.OptionalTest_ReturnType]], ...)) return coerce(mdl, raw.parsed()) + def PredictAge( + self, + name: str, + baml_options: BamlCallOptions = {}, + ) -> types.FooAny: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "PredictAge", + { + "name": name, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + mdl = create_model("PredictAgeReturnType", inner=(types.FooAny, ...)) + return coerce(mdl, raw.parsed()) + + def PredictAgeBare( + self, + inp: str, + baml_options: BamlCallOptions = {}, + ) -> baml_py.Checked[int,types.Checks__too_big]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "PredictAgeBare", + { + "inp": inp, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + mdl = create_model("PredictAgeBareReturnType", inner=(baml_py.Checked[int,types.Checks__too_big], ...)) + return coerce(mdl, raw.parsed()) + def PromptTestClaude( self, input: str, @@ -1185,6 +1257,30 @@ def PromptTestStreaming( mdl = create_model("PromptTestStreamingReturnType", inner=(str, ...)) return coerce(mdl, raw.parsed()) + def ReturnFailingAssert( + self, + inp: int, + baml_options: BamlCallOptions = {}, + ) -> int: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "ReturnFailingAssert", + { + "inp": inp, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + mdl = create_model("ReturnFailingAssertReturnType", inner=(int, ...)) + return coerce(mdl, raw.parsed()) + def 
SchemaDescriptions( self, input: str, @@ -1233,6 +1329,30 @@ def StreamBigNumbers( mdl = create_model("StreamBigNumbersReturnType", inner=(types.BigNumbers, ...)) return coerce(mdl, raw.parsed()) + def StreamFailingAssertion( + self, + theme: str,length: int, + baml_options: BamlCallOptions = {}, + ) -> types.TwoStoriesOneTitle: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "StreamFailingAssertion", + { + "theme": theme,"length": length, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + mdl = create_model("StreamFailingAssertionReturnType", inner=(types.TwoStoriesOneTitle, ...)) + return coerce(mdl, raw.parsed()) + def StreamOneBigNumber( self, digits: int, @@ -2567,6 +2687,39 @@ def ExpectFailure( self.__ctx_manager.get(), ) + def ExtractContactInfo( + self, + document: str, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[partial_types.ContactInfo, types.ContactInfo]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "ExtractContactInfo", + { + "document": document, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + mdl = create_model("ExtractContactInfoReturnType", inner=(types.ContactInfo, ...)) + partial_mdl = create_model("ExtractContactInfoPartialReturnType", inner=(partial_types.ContactInfo, ...)) + + return baml_py.BamlSyncStream[partial_types.ContactInfo, types.ContactInfo]( + raw, + lambda x: coerce(partial_mdl, x), + lambda x: coerce(mdl, x), + self.__ctx_manager.get(), + ) + def ExtractNames( self, input: str, @@ -3361,6 +3514,72 @@ def OptionalTest_Function( self.__ctx_manager.get(), ) + def PredictAge( + self, + name: str, + baml_options: BamlCallOptions = {}, + ) -> 
baml_py.BamlSyncStream[partial_types.FooAny, types.FooAny]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "PredictAge", + { + "name": name, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + mdl = create_model("PredictAgeReturnType", inner=(types.FooAny, ...)) + partial_mdl = create_model("PredictAgePartialReturnType", inner=(partial_types.FooAny, ...)) + + return baml_py.BamlSyncStream[partial_types.FooAny, types.FooAny]( + raw, + lambda x: coerce(partial_mdl, x), + lambda x: coerce(mdl, x), + self.__ctx_manager.get(), + ) + + def PredictAgeBare( + self, + inp: str, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[baml_py.Checked[Optional[int],types.Checks__too_big], baml_py.Checked[int,types.Checks__too_big]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "PredictAgeBare", + { + "inp": inp, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + mdl = create_model("PredictAgeBareReturnType", inner=(baml_py.Checked[int,types.Checks__too_big], ...)) + partial_mdl = create_model("PredictAgeBarePartialReturnType", inner=(baml_py.Checked[Optional[int],types.Checks__too_big], ...)) + + return baml_py.BamlSyncStream[baml_py.Checked[Optional[int],types.Checks__too_big], baml_py.Checked[int,types.Checks__too_big]]( + raw, + lambda x: coerce(partial_mdl, x), + lambda x: coerce(mdl, x), + self.__ctx_manager.get(), + ) + def PromptTestClaude( self, input: str, @@ -3592,6 +3811,39 @@ def PromptTestStreaming( self.__ctx_manager.get(), ) + def ReturnFailingAssert( + self, + inp: int, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[Optional[int], int]: + __tb__ = baml_options.get("tb", None) + if __tb__ is 
not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "ReturnFailingAssert", + { + "inp": inp, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + mdl = create_model("ReturnFailingAssertReturnType", inner=(int, ...)) + partial_mdl = create_model("ReturnFailingAssertPartialReturnType", inner=(Optional[int], ...)) + + return baml_py.BamlSyncStream[Optional[int], int]( + raw, + lambda x: coerce(partial_mdl, x), + lambda x: coerce(mdl, x), + self.__ctx_manager.get(), + ) + def SchemaDescriptions( self, input: str, @@ -3658,6 +3910,40 @@ def StreamBigNumbers( self.__ctx_manager.get(), ) + def StreamFailingAssertion( + self, + theme: str,length: int, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[partial_types.TwoStoriesOneTitle, types.TwoStoriesOneTitle]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "StreamFailingAssertion", + { + "theme": theme, + "length": length, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + mdl = create_model("StreamFailingAssertionReturnType", inner=(types.TwoStoriesOneTitle, ...)) + partial_mdl = create_model("StreamFailingAssertionPartialReturnType", inner=(partial_types.TwoStoriesOneTitle, ...)) + + return baml_py.BamlSyncStream[partial_types.TwoStoriesOneTitle, types.TwoStoriesOneTitle]( + raw, + lambda x: coerce(partial_mdl, x), + lambda x: coerce(mdl, x), + self.__ctx_manager.get(), + ) + def StreamOneBigNumber( self, digits: int, diff --git a/integ-tests/python/baml_client/type_builder.py b/integ-tests/python/baml_client/type_builder.py index e54c2b50a..b09b168ff 100644 --- a/integ-tests/python/baml_client/type_builder.py +++ b/integ-tests/python/baml_client/type_builder.py @@ -19,7 +19,7 @@ class TypeBuilder(_TypeBuilder): def 
__init__(self): super().__init__(classes=set( - ["BigNumbers","Blah","BookOrder","ClassOptionalOutput","ClassOptionalOutput2","ClassWithImage","CompoundBigNumbers","CustomTaskResult","DummyOutput","DynInputOutput","DynamicClassOne","DynamicClassTwo","DynamicOutput","Education","Email","Event","FakeImage","FlightConfirmation","GroceryReceipt","InnerClass","InnerClass2","NamedArgsSingleClass","Nested","Nested2","OptionalTest_Prop1","OptionalTest_ReturnType","OrderInfo","Person","Quantity","RaysData","ReceiptInfo","ReceiptItem","Recipe","Resume","Schema","SearchParams","SomeClassNestedDynamic","StringToClassEntry","TestClassAlias","TestClassNested","TestClassWithEnum","TestOutputClass","UnionTest_ReturnType","WithReasoning",] + ["BigNumbers","Blah","BookOrder","ClassOptionalOutput","ClassOptionalOutput2","ClassWithImage","CompoundBigNumbers","ContactInfo","CustomTaskResult","DummyOutput","DynInputOutput","DynamicClassOne","DynamicClassTwo","DynamicOutput","Earthling","Education","Email","EmailAddress","Event","FakeImage","FlightConfirmation","FooAny","GroceryReceipt","InnerClass","InnerClass2","Martian","NamedArgsSingleClass","Nested","Nested2","OptionalTest_Prop1","OptionalTest_ReturnType","OrderInfo","Person","PhoneNumber","Quantity","RaysData","ReceiptInfo","ReceiptItem","Recipe","Resume","Schema","SearchParams","SomeClassNestedDynamic","StringToClassEntry","TestClassAlias","TestClassNested","TestClassWithEnum","TestOutputClass","TwoStoriesOneTitle","UnionTest_ReturnType","WithReasoning",] ), enums=set( ["Category","Category2","Category3","Color","DataType","DynEnumOne","DynEnumTwo","EnumInClass","EnumOutput","Hobby","NamedArgsSingleEnum","NamedArgsSingleEnumList","OptionalTest_CategoryType","OrderStatus","Tag","TestEnum",] )) diff --git a/integ-tests/python/baml_client/types.py b/integ-tests/python/baml_client/types.py index e470b5127..6e97d37bc 100644 --- a/integ-tests/python/baml_client/types.py +++ b/integ-tests/python/baml_client/types.py @@ -119,6 +119,30 @@ 
class TestEnum(str, Enum): F = "F" G = "G" +class Checks__valid_phone_number(BaseModel): + valid_phone_number: baml_py.Check + +class Checks__too_big(BaseModel): + too_big: baml_py.Check + +class Checks__valid_email(BaseModel): + valid_email: baml_py.Check + +class Checks__unreasonably_certain(BaseModel): + unreasonably_certain: baml_py.Check + +class Checks__earth_aged__no_infants(BaseModel): + no_infants: baml_py.Check + earth_aged: baml_py.Check + +class Checks__regex_bad__regex_good__trivial(BaseModel): + trivial: baml_py.Check + regex_good: baml_py.Check + regex_bad: baml_py.Check + +class Checks__young_enough(BaseModel): + young_enough: baml_py.Check + class BigNumbers(BaseModel): @@ -165,6 +189,12 @@ class CompoundBigNumbers(BaseModel): big_nums: List["BigNumbers"] another: "BigNumbers" +class ContactInfo(BaseModel): + + + primary: Union["PhoneNumber", "EmailAddress"] + secondary: Union["PhoneNumber", "EmailAddress", None] + class CustomTaskResult(BaseModel): @@ -203,6 +233,11 @@ class DynamicOutput(BaseModel): model_config = ConfigDict(extra='allow') +class Earthling(BaseModel): + + + age: baml_py.Checked[int,Checks__earth_aged__no_infants] + class Education(BaseModel): @@ -219,6 +254,11 @@ class Email(BaseModel): body: str from_address: str +class EmailAddress(BaseModel): + + + value: baml_py.Checked[str,Checks__valid_email] + class Event(BaseModel): @@ -241,6 +281,13 @@ class FlightConfirmation(BaseModel): arrivalTime: str seatNumber: str +class FooAny(BaseModel): + + + planetary_age: Union["Martian", "Earthling"] + certainty: baml_py.Checked[int,Checks__unreasonably_certain] + species: baml_py.Checked[str,Checks__regex_bad__regex_good__trivial] + class GroceryReceipt(BaseModel): @@ -262,6 +309,11 @@ class InnerClass2(BaseModel): prop2: int prop3: float +class Martian(BaseModel): + + + age: baml_py.Checked[int,Checks__young_enough] + class NamedArgsSingleClass(BaseModel): @@ -309,6 +361,11 @@ class Person(BaseModel): name: Optional[str] = None hair_color: 
Optional[Union["Color", str]] = None +class PhoneNumber(BaseModel): + + + value: baml_py.Checked[str,Checks__valid_phone_number] + class Quantity(BaseModel): @@ -411,6 +468,13 @@ class TestOutputClass(BaseModel): prop1: str prop2: int +class TwoStoriesOneTitle(BaseModel): + + + title: str + story_a: str + story_b: str + class UnionTest_ReturnType(BaseModel): diff --git a/integ-tests/python/tests/test_functions.py b/integ-tests/python/tests/test_functions.py index d11607f5d..be94b27b2 100644 --- a/integ-tests/python/tests/test_functions.py +++ b/integ-tests/python/tests/test_functions.py @@ -21,11 +21,13 @@ from ..baml_client import partial_types from ..baml_client.types import ( DynInputOutput, + FooAny, NamedArgsSingleEnumList, NamedArgsSingleClass, StringToClassEntry, CompoundBigNumbers, ) +import baml_client.types as types from ..baml_client.tracing import trace, set_tags, flush, on_log_event from ..baml_client.type_builder import TypeBuilder from ..baml_client import reset_baml_env_vars @@ -59,6 +61,27 @@ async def test_single_string_list(self): res = await b.TestFnNamedArgsSingleStringList(["a", "b", "c"]) assert "a" in res and "b" in res and "c" in res + @pytest.mark.asyncio + async def test_constraints(self): + res = await b.PredictAge("Greg") + assert res.certainty.checks.unreasonably_certain.status == "failed" + + @pytest.mark.asyncio + async def test_constraint_union_variant_checking(self): + res = await b.ExtractContactInfo("Reach me at 123-456-7890") + assert res.primary.value is not None + assert res.primary.value.checks.valid_phone_number.status == "succeeded" + + res = await b.ExtractContactInfo("Reach me at help@boundaryml.com") + assert res.primary.value is not None + assert res.primary.value.checks.valid_email.status == "succeeded" + assert res.secondary is None + + res = await b.ExtractContactInfo("Reach me at help@boundaryml.com, or 111-222-3333 if needed.") + assert res.primary.value is not None + assert 
res.primary.value.checks.valid_email.status == "succeeded" + assert res.secondary.value.checks.valid_phone_number.status == "succeeded" + @pytest.mark.asyncio async def test_single_class(self): res = await b.TestFnNamedArgsSingleClass( @@ -1211,3 +1234,25 @@ async def test_no_stream_compound_object_with_yapping(): if msg.another is not None: assert True if msg.another.a is None else msg.another.a == res.another.a assert True if msg.another.b is None else msg.another.b == res.another.b + +@pytest.mark.asyncio +async def test_return_failing_assert(): + with pytest.raises(errors.BamlValidationError): + msg = await b.ReturnFailingAssert(1) + + +@pytest.mark.asyncio +async def test_parameter_failing_assert(): + with pytest.raises(errors.BamlInvalidArgumentError): + msg = await b.ReturnFailingAssert(100) + assert msg == 103 + +@pytest.mark.asyncio +async def test_failing_assert_can_stream(): + stream = b.stream.StreamFailingAssertion("Yoshimi battles the pink robots", 300) + async for msg in stream: + print(msg.story_a) + print(msg.story_b) + with pytest.raises(errors.BamlValidationError): + final = await stream.get_final_response() + assert "Yoshimi" in final.story_a diff --git a/integ-tests/ruby/baml_client/client.rb b/integ-tests/ruby/baml_client/client.rb index 63cee85fe..e64f3e566 100644 --- a/integ-tests/ruby/baml_client/client.rb +++ b/integ-tests/ruby/baml_client/client.rb @@ -562,6 +562,38 @@ def ExpectFailure( (raw.parsed_using_types(Baml::Types)) end + sig { + params( + varargs: T.untyped, + document: String, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::Types::ContactInfo) + } + def ExtractContactInfo( + *varargs, + document:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("ExtractContactInfo may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? 
+ raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "ExtractContactInfo", + { + document: document, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + sig { params( varargs: T.untyped, @@ -1330,6 +1362,70 @@ def OptionalTest_Function( (raw.parsed_using_types(Baml::Types)) end + sig { + params( + varargs: T.untyped, + name: String, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::Types::FooAny) + } + def PredictAge( + *varargs, + name:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("PredictAge may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "PredictAge", + { + name: name, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + + sig { + params( + varargs: T.untyped, + inp: String, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::Checked[Integer, Checks__too_big]) + } + def PredictAgeBare( + *varargs, + inp:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("PredictAgeBare may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? 
+ raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "PredictAgeBare", + { + inp: inp, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + sig { params( varargs: T.untyped, @@ -1554,6 +1650,38 @@ def PromptTestStreaming( (raw.parsed_using_types(Baml::Types)) end + sig { + params( + varargs: T.untyped, + inp: Integer, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Integer) + } + def ReturnFailingAssert( + *varargs, + inp:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("ReturnFailingAssert may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "ReturnFailingAssert", + { + inp: inp, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + sig { params( varargs: T.untyped, @@ -1618,6 +1746,38 @@ def StreamBigNumbers( (raw.parsed_using_types(Baml::Types)) end + sig { + params( + varargs: T.untyped, + theme: String,length: Integer, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::Types::TwoStoriesOneTitle) + } + def StreamFailingAssertion( + *varargs, + theme:,length:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("StreamFailingAssertion may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? 
+ raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "StreamFailingAssertion", + { + theme: theme,length: length, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + sig { params( varargs: T.untyped, @@ -3247,6 +3407,41 @@ def ExpectFailure( ) end + sig { + params( + varargs: T.untyped, + document: String, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[Baml::Types::ContactInfo]) + } + def ExtractContactInfo( + *varargs, + document:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("ExtractContactInfo may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "ExtractContactInfo", + { + document: document, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[Baml::PartialTypes::ContactInfo, Baml::Types::ContactInfo].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + sig { params( varargs: T.untyped, @@ -4087,6 +4282,76 @@ def OptionalTest_Function( ) end + sig { + params( + varargs: T.untyped, + name: String, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[Baml::Types::FooAny]) + } + def PredictAge( + *varargs, + name:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("PredictAge may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? 
+ raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "PredictAge", + { + name: name, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[Baml::PartialTypes::FooAny, Baml::Types::FooAny].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + + sig { + params( + varargs: T.untyped, + inp: String, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[Baml::Checked[Integer, Checks__too_big]]) + } + def PredictAgeBare( + *varargs, + inp:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("PredictAgeBare may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "PredictAgeBare", + { + inp: inp, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[Baml::Checked[T.nilable(Integer), Checks__too_big], Baml::Checked[Integer, Checks__too_big]].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + sig { params( varargs: T.untyped, @@ -4332,6 +4597,41 @@ def PromptTestStreaming( ) end + sig { + params( + varargs: T.untyped, + inp: Integer, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[Integer]) + } + def ReturnFailingAssert( + *varargs, + inp:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("ReturnFailingAssert may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? 
+ raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "ReturnFailingAssert", + { + inp: inp, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[T.nilable(Integer), Integer].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + sig { params( varargs: T.untyped, @@ -4402,6 +4702,41 @@ def StreamBigNumbers( ) end + sig { + params( + varargs: T.untyped, + theme: String,length: Integer, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[Baml::Types::TwoStoriesOneTitle]) + } + def StreamFailingAssertion( + *varargs, + theme:,length:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("StreamFailingAssertion may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? 
+ raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "StreamFailingAssertion", + { + theme: theme,length: length, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[Baml::PartialTypes::TwoStoriesOneTitle, Baml::Types::TwoStoriesOneTitle].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + sig { params( varargs: T.untyped, diff --git a/integ-tests/ruby/baml_client/inlined.rb b/integ-tests/ruby/baml_client/inlined.rb index 5db069735..883858a8d 100644 --- a/integ-tests/ruby/baml_client/inlined.rb +++ b/integ-tests/ruby/baml_client/inlined.rb @@ -29,6 +29,8 @@ module Inlined "test-files/aliases/classes.baml" => "class TestClassAlias {\n key string @alias(\"key-dash\") @description(#\"\n This is a description for key\n af asdf\n \"#)\n key2 string @alias(\"key21\")\n key3 string @alias(\"key with space\")\n key4 string //unaliased\n key5 string @alias(\"key.with.punctuation/123\")\n}\n\nfunction FnTestClassAlias(input: string) -> TestClassAlias {\n client GPT35\n prompt #\"\n {{ctx.output_format}}\n \"#\n}\n\ntest FnTestClassAlias {\n functions [FnTestClassAlias]\n args {\n input \"example input\"\n }\n}\n", "test-files/aliases/enums.baml" => "enum TestEnum {\n A @alias(\"k1\") @description(#\"\n User is angry\n \"#)\n B @alias(\"k22\") @description(#\"\n User is happy\n \"#)\n // tests whether k1 doesnt incorrectly get matched with k11\n C @alias(\"k11\") @description(#\"\n User is sad\n \"#)\n D @alias(\"k44\") @description(\n User is confused\n )\n E @description(\n User is excited\n )\n F @alias(\"k5\") // only alias\n \n G @alias(\"k6\") @description(#\"\n User is bored\n With a long description\n \"#)\n \n @@alias(\"Category\")\n}\n\nfunction FnTestAliasedEnumOutput(input: string) -> TestEnum {\n client GPT35\n prompt #\"\n 
Classify the user input into the following category\n \n {{ ctx.output_format }}\n\n {{ _.role('user') }}\n {{input}}\n\n {{ _.role('assistant') }}\n Category ID:\n \"#\n}\n\ntest FnTestAliasedEnumOutput {\n functions [FnTestAliasedEnumOutput]\n args {\n input \"mehhhhh\"\n }\n}", "test-files/comments/comments.baml" => "// add some functions, classes, enums etc with comments all over.", + "test-files/constraints/constraints.baml" => "// These classes and functions test several properties of\n// constrains:\n//\n// - The ability for constrains on fields to pass or fail.\n// - The ability for constraints on bare args and return types to pass or fail.\n// - The ability of constraints to influence which variant of a union is chosen\n// by the parser, when the structure is not sufficient to decide.\n\n\nclass Martian {\n age int @check({{ this < 30 }}, young_enough)\n}\n\nclass Earthling {\n age int @check({{this < 200 and this > 0}}, earth_aged) @check({{this >1}}, no_infants)\n}\n\n\nclass FooAny {\n planetary_age Martian | Earthling\n certainty int @check({{this == 102931}}, unreasonably_certain)\n species string @check({{this == \"Homo sapiens\"}}, trivial) @check({{this|regex_match(\"Homo\")}}, regex_good) @check({{this|regex_match(\"neanderthalensis\")}}, regex_bad)\n}\n\n\nfunction PredictAge(name: string) -> FooAny {\n client GPT35\n prompt #\"\n Using your understanding of the historical popularity\n of names, predict the age of a person with the name\n {{ name }} in years. Also predict their genus and\n species. It's Homo sapiens (with exactly that spelling\n and capitalization). 
I'll give you a hint: If the name\n is \"Greg\", his age is 41.\n\n {{ctx.output_format}}\n \"#\n}\n\n\nfunction PredictAgeBare(inp: string @assert({{this|length > 1}}, big_enough)) -> int @check({{this == 10102}}, too_big) {\n client GPT35\n prompt #\"\n Using your understanding of the historical popularity\n of names, predict the age of a person with the name\n {{ inp.name }} in years. Also predict their genus and\n species. It's Homo sapiens (with exactly that spelling).\n\n {{ctx.output_format}}\n \"#\n}\n\nfunction ReturnFailingAssert(inp: int @assert({{this < 10}}, small_int)) -> int @assert({{this > 100}}, big_int) {\n client GPT35\n prompt #\"\n Return the next integer after {{ inp }}.\n\n {{ctx.output_format}}\n \"#\n}\n\nclass TwoStoriesOneTitle {\n title string\n story_a string @assert( {{this|length > 1000000}}, too_long_story )\n story_b string @assert( {{this|length > 1000000}}, too_long_story )\n}\n\nfunction StreamFailingAssertion(theme: string, length: int) -> TwoStoriesOneTitle {\n client GPT35\n prompt #\"\n Tell me two different stories along the theme of {{ theme }} with the same title.\n Please make each about {{ length }} words long.\n {{ctx.output_format}}\n \"#\n}\n", + "test-files/constraints/contact-info.baml" => "class PhoneNumber {\n value string @check({{this|regex_match(\"\\(?\\d{3}\\)?[-.\\s]?\\d{3}[-.\\s]?\\d{4}\")}}, valid_phone_number)\n}\n\nclass EmailAddress {\n value string @check({{this|regex_match(\"^[_]*([a-z0-9]+(\\.|_*)?)+@([a-z][a-z0-9-]+(\\.|-*\\.))+[a-z]{2,6}$\")}}, valid_email)\n}\n\nclass ContactInfo {\n primary PhoneNumber | EmailAddress\n secondary (PhoneNumber | EmailAddress)?\n}\n\nfunction ExtractContactInfo(document: string) -> ContactInfo {\n client GPT35\n prompt #\"\n Extract a primary contact info, and if possible a secondary contact\n info, from this document:\n\n {{ document }}\n\n {{ ctx.output_format }}\n \"#\n}\n", "test-files/descriptions/descriptions.baml" => "\nclass Nested {\n prop3 string | null 
@description(#\"\n write \"three\"\n \"#)\n prop4 string | null @description(#\"\n write \"four\"\n \"#) @alias(\"blah\")\n prop20 Nested2\n}\n\nclass Nested2 {\n prop11 string | null @description(#\"\n write \"three\"\n \"#)\n prop12 string | null @description(#\"\n write \"four\"\n \"#) @alias(\"blah\")\n}\n\nclass Schema {\n prop1 string | null @description(#\"\n write \"one\"\n \"#)\n prop2 Nested | string @description(#\"\n write \"two\"\n \"#)\n prop5 (string | null)[] @description(#\"\n write \"hi\"\n \"#)\n prop6 string | Nested[] @alias(\"blah\") @description(#\"\n write the string \"blah\" regardless of the other types here\n \"#)\n nested_attrs (string | null | Nested)[] @description(#\"\n write the string \"nested\" regardless of other types\n \"#)\n parens (string | null) @description(#\"\n write \"parens1\"\n \"#)\n other_group (string | (int | string)) @description(#\"\n write \"other\"\n \"#) @alias(other)\n}\n\n\nfunction SchemaDescriptions(input: string) -> Schema {\n client GPT4o\n prompt #\"\n Return a schema with this format:\n\n {{ctx.output_format}}\n \"#\n}", "test-files/dynamic/client-registry.baml" => "// Intentionally use a bad key\nclient BadClient {\n provider openai\n options {\n model \"gpt-3.5-turbo\"\n api_key \"sk-invalid\"\n }\n}\n\nfunction ExpectFailure() -> string {\n client BadClient\n\n prompt #\"\n What is the capital of England?\n \"#\n}\n", "test-files/dynamic/dynamic.baml" => "class DynamicClassOne {\n @@dynamic\n}\n\nenum DynEnumOne {\n @@dynamic\n}\n\nenum DynEnumTwo {\n @@dynamic\n}\n\nclass SomeClassNestedDynamic {\n hi string\n @@dynamic\n\n}\n\nclass DynamicClassTwo {\n hi string\n some_class SomeClassNestedDynamic\n status DynEnumOne\n @@dynamic\n}\n\nfunction DynamicFunc(input: DynamicClassOne) -> DynamicClassTwo {\n client GPT35\n prompt #\"\n Please extract the schema from \n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nclass DynInputOutput {\n testKey string\n @@dynamic\n}\n\nfunction 
DynamicInputOutput(input: DynInputOutput) -> DynInputOutput {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ ctx.output_format }}\n \"#\n}\n\nfunction DynamicListInputOutput(input: DynInputOutput[]) -> DynInputOutput[] {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ ctx.output_format }}\n \"#\n}\n\n\n\nclass DynamicOutput {\n @@dynamic\n}\n \nfunction MyFunc(input: string) -> DynamicOutput {\n client GPT35\n prompt #\"\n Given a string, extract info using the schema:\n\n {{ input}}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ClassifyDynEnumTwo(input: string) -> DynEnumTwo {\n client GPT35\n prompt #\"\n Given a string, extract info using the schema:\n\n {{ input}}\n\n {{ ctx.output_format }}\n \"#\n}", diff --git a/integ-tests/ruby/baml_client/partial-types.rb b/integ-tests/ruby/baml_client/partial-types.rb index 4bb8df926..2c5bae3db 100644 --- a/integ-tests/ruby/baml_client/partial-types.rb +++ b/integ-tests/ruby/baml_client/partial-types.rb @@ -27,20 +27,25 @@ class ClassOptionalOutput < T::Struct; end class ClassOptionalOutput2 < T::Struct; end class ClassWithImage < T::Struct; end class CompoundBigNumbers < T::Struct; end + class ContactInfo < T::Struct; end class CustomTaskResult < T::Struct; end class DummyOutput < T::Struct; end class DynInputOutput < T::Struct; end class DynamicClassOne < T::Struct; end class DynamicClassTwo < T::Struct; end class DynamicOutput < T::Struct; end + class Earthling < T::Struct; end class Education < T::Struct; end class Email < T::Struct; end + class EmailAddress < T::Struct; end class Event < T::Struct; end class FakeImage < T::Struct; end class FlightConfirmation < T::Struct; end + class FooAny < T::Struct; end class GroceryReceipt < T::Struct; end class InnerClass < T::Struct; end class InnerClass2 < T::Struct; end + class Martian < T::Struct; end class NamedArgsSingleClass < 
T::Struct; end class Nested < T::Struct; end class Nested2 < T::Struct; end @@ -48,6 +53,7 @@ class OptionalTest_Prop1 < T::Struct; end class OptionalTest_ReturnType < T::Struct; end class OrderInfo < T::Struct; end class Person < T::Struct; end + class PhoneNumber < T::Struct; end class Quantity < T::Struct; end class RaysData < T::Struct; end class ReceiptInfo < T::Struct; end @@ -62,6 +68,7 @@ class TestClassAlias < T::Struct; end class TestClassNested < T::Struct; end class TestClassWithEnum < T::Struct; end class TestOutputClass < T::Struct; end + class TwoStoriesOneTitle < T::Struct; end class UnionTest_ReturnType < T::Struct; end class WithReasoning < T::Struct; end class BigNumbers < T::Struct @@ -170,6 +177,20 @@ def initialize(props) @props = props end end + class ContactInfo < T::Struct + include Baml::Sorbet::Struct + const :primary, T.nilable(T.any(Baml::PartialTypes::PhoneNumber, Baml::PartialTypes::EmailAddress)) + const :secondary, T.nilable(T.any(Baml::PartialTypes::PhoneNumber, Baml::PartialTypes::EmailAddress, T.nilable(NilClass))) + + def initialize(props) + super( + primary: props[:primary], + secondary: props[:secondary], + ) + + @props = props + end + end class CustomTaskResult < T::Struct include Baml::Sorbet::Struct const :bookOrder, T.nilable(T.any(Baml::PartialTypes::BookOrder, T.nilable(NilClass))) @@ -248,6 +269,18 @@ def initialize(props) @props = props end end + class Earthling < T::Struct + include Baml::Sorbet::Struct + const :age, Baml::Checked[T.nilable(Integer), Checks__earth_aged__no_infants] + + def initialize(props) + super( + age: props[:age], + ) + + @props = props + end + end class Education < T::Struct include Baml::Sorbet::Struct const :institution, T.nilable(String) @@ -284,6 +317,18 @@ def initialize(props) @props = props end end + class EmailAddress < T::Struct + include Baml::Sorbet::Struct + const :value, Baml::Checked[T.nilable(String), Checks__valid_email] + + def initialize(props) + super( + value: props[:value], 
+ ) + + @props = props + end + end class Event < T::Struct include Baml::Sorbet::Struct const :title, T.nilable(String) @@ -334,6 +379,22 @@ def initialize(props) @props = props end end + class FooAny < T::Struct + include Baml::Sorbet::Struct + const :planetary_age, T.nilable(T.any(Baml::PartialTypes::Martian, Baml::PartialTypes::Earthling)) + const :certainty, Baml::Checked[T.nilable(Integer), Checks__unreasonably_certain] + const :species, Baml::Checked[T.nilable(String), Checks__regex_bad__regex_good__trivial] + + def initialize(props) + super( + planetary_age: props[:planetary_age], + certainty: props[:certainty], + species: props[:species], + ) + + @props = props + end + end class GroceryReceipt < T::Struct include Baml::Sorbet::Struct const :receiptId, T.nilable(String) @@ -382,6 +443,18 @@ def initialize(props) @props = props end end + class Martian < T::Struct + include Baml::Sorbet::Struct + const :age, Baml::Checked[T.nilable(Integer), Checks__young_enough] + + def initialize(props) + super( + age: props[:age], + ) + + @props = props + end + end class NamedArgsSingleClass < T::Struct include Baml::Sorbet::Struct const :key, T.nilable(String) @@ -488,6 +561,18 @@ def initialize(props) @props = props end end + class PhoneNumber < T::Struct + include Baml::Sorbet::Struct + const :value, Baml::Checked[T.nilable(String), Checks__valid_phone_number] + + def initialize(props) + super( + value: props[:value], + ) + + @props = props + end + end class Quantity < T::Struct include Baml::Sorbet::Struct const :amount, T.nilable(T.any(T.nilable(Integer), T.nilable(Float))) @@ -718,6 +803,22 @@ def initialize(props) @props = props end end + class TwoStoriesOneTitle < T::Struct + include Baml::Sorbet::Struct + const :title, T.nilable(String) + const :story_a, T.nilable(String) + const :story_b, T.nilable(String) + + def initialize(props) + super( + title: props[:title], + story_a: props[:story_a], + story_b: props[:story_b], + ) + + @props = props + end + end class 
UnionTest_ReturnType < T::Struct include Baml::Sorbet::Struct const :prop1, T.nilable(T.any(T.nilable(String), T.nilable(T::Boolean))) diff --git a/integ-tests/ruby/baml_client/type-registry.rb b/integ-tests/ruby/baml_client/type-registry.rb index 24d9a20ce..1293cef45 100644 --- a/integ-tests/ruby/baml_client/type-registry.rb +++ b/integ-tests/ruby/baml_client/type-registry.rb @@ -18,7 +18,7 @@ module Baml class TypeBuilder def initialize @registry = Baml::Ffi::TypeBuilder.new - @classes = Set[ "BigNumbers", "Blah", "BookOrder", "ClassOptionalOutput", "ClassOptionalOutput2", "ClassWithImage", "CompoundBigNumbers", "CustomTaskResult", "DummyOutput", "DynInputOutput", "DynamicClassOne", "DynamicClassTwo", "DynamicOutput", "Education", "Email", "Event", "FakeImage", "FlightConfirmation", "GroceryReceipt", "InnerClass", "InnerClass2", "NamedArgsSingleClass", "Nested", "Nested2", "OptionalTest_Prop1", "OptionalTest_ReturnType", "OrderInfo", "Person", "Quantity", "RaysData", "ReceiptInfo", "ReceiptItem", "Recipe", "Resume", "Schema", "SearchParams", "SomeClassNestedDynamic", "StringToClassEntry", "TestClassAlias", "TestClassNested", "TestClassWithEnum", "TestOutputClass", "UnionTest_ReturnType", "WithReasoning", ] + @classes = Set[ "BigNumbers", "Blah", "BookOrder", "ClassOptionalOutput", "ClassOptionalOutput2", "ClassWithImage", "CompoundBigNumbers", "ContactInfo", "CustomTaskResult", "DummyOutput", "DynInputOutput", "DynamicClassOne", "DynamicClassTwo", "DynamicOutput", "Earthling", "Education", "Email", "EmailAddress", "Event", "FakeImage", "FlightConfirmation", "FooAny", "GroceryReceipt", "InnerClass", "InnerClass2", "Martian", "NamedArgsSingleClass", "Nested", "Nested2", "OptionalTest_Prop1", "OptionalTest_ReturnType", "OrderInfo", "Person", "PhoneNumber", "Quantity", "RaysData", "ReceiptInfo", "ReceiptItem", "Recipe", "Resume", "Schema", "SearchParams", "SomeClassNestedDynamic", "StringToClassEntry", "TestClassAlias", "TestClassNested", "TestClassWithEnum", 
"TestOutputClass", "TwoStoriesOneTitle", "UnionTest_ReturnType", "WithReasoning", ] @enums = Set[ "Category", "Category2", "Category3", "Color", "DataType", "DynEnumOne", "DynEnumTwo", "EnumInClass", "EnumOutput", "Hobby", "NamedArgsSingleEnum", "NamedArgsSingleEnumList", "OptionalTest_CategoryType", "OrderStatus", "Tag", "TestEnum", ] end diff --git a/integ-tests/ruby/baml_client/types.rb b/integ-tests/ruby/baml_client/types.rb index 989ec5aae..1e90b9f35 100644 --- a/integ-tests/ruby/baml_client/types.rb +++ b/integ-tests/ruby/baml_client/types.rb @@ -137,20 +137,25 @@ class ClassOptionalOutput < T::Struct; end class ClassOptionalOutput2 < T::Struct; end class ClassWithImage < T::Struct; end class CompoundBigNumbers < T::Struct; end + class ContactInfo < T::Struct; end class CustomTaskResult < T::Struct; end class DummyOutput < T::Struct; end class DynInputOutput < T::Struct; end class DynamicClassOne < T::Struct; end class DynamicClassTwo < T::Struct; end class DynamicOutput < T::Struct; end + class Earthling < T::Struct; end class Education < T::Struct; end class Email < T::Struct; end + class EmailAddress < T::Struct; end class Event < T::Struct; end class FakeImage < T::Struct; end class FlightConfirmation < T::Struct; end + class FooAny < T::Struct; end class GroceryReceipt < T::Struct; end class InnerClass < T::Struct; end class InnerClass2 < T::Struct; end + class Martian < T::Struct; end class NamedArgsSingleClass < T::Struct; end class Nested < T::Struct; end class Nested2 < T::Struct; end @@ -158,6 +163,7 @@ class OptionalTest_Prop1 < T::Struct; end class OptionalTest_ReturnType < T::Struct; end class OrderInfo < T::Struct; end class Person < T::Struct; end + class PhoneNumber < T::Struct; end class Quantity < T::Struct; end class RaysData < T::Struct; end class ReceiptInfo < T::Struct; end @@ -172,8 +178,16 @@ class TestClassAlias < T::Struct; end class TestClassNested < T::Struct; end class TestClassWithEnum < T::Struct; end class TestOutputClass < 
T::Struct; end + class TwoStoriesOneTitle < T::Struct; end class UnionTest_ReturnType < T::Struct; end class WithReasoning < T::Struct; end + class Checks__young_enough < T::Struct; end + class Checks__regex_bad__regex_good__trivial < T::Struct; end + class Checks__unreasonably_certain < T::Struct; end + class Checks__earth_aged__no_infants < T::Struct; end + class Checks__valid_email < T::Struct; end + class Checks__valid_phone_number < T::Struct; end + class Checks__too_big < T::Struct; end class BigNumbers < T::Struct include Baml::Sorbet::Struct const :a, Integer @@ -280,6 +294,20 @@ def initialize(props) @props = props end end + class ContactInfo < T::Struct + include Baml::Sorbet::Struct + const :primary, T.any(Baml::Types::PhoneNumber, Baml::Types::EmailAddress) + const :secondary, T.any(Baml::Types::PhoneNumber, Baml::Types::EmailAddress, NilClass) + + def initialize(props) + super( + primary: props[:primary], + secondary: props[:secondary], + ) + + @props = props + end + end class CustomTaskResult < T::Struct include Baml::Sorbet::Struct const :bookOrder, T.any(Baml::Types::BookOrder, T.nilable(NilClass)) @@ -358,6 +386,18 @@ def initialize(props) @props = props end end + class Earthling < T::Struct + include Baml::Sorbet::Struct + const :age, Baml::Checked[Integer, Checks__earth_aged__no_infants] + + def initialize(props) + super( + age: props[:age], + ) + + @props = props + end + end class Education < T::Struct include Baml::Sorbet::Struct const :institution, String @@ -394,6 +434,18 @@ def initialize(props) @props = props end end + class EmailAddress < T::Struct + include Baml::Sorbet::Struct + const :value, Baml::Checked[String, Checks__valid_email] + + def initialize(props) + super( + value: props[:value], + ) + + @props = props + end + end class Event < T::Struct include Baml::Sorbet::Struct const :title, String @@ -444,6 +496,22 @@ def initialize(props) @props = props end end + class FooAny < T::Struct + include Baml::Sorbet::Struct + const 
:planetary_age, T.any(Baml::Types::Martian, Baml::Types::Earthling) + const :certainty, Baml::Checked[Integer, Checks__unreasonably_certain] + const :species, Baml::Checked[String, Checks__regex_bad__regex_good__trivial] + + def initialize(props) + super( + planetary_age: props[:planetary_age], + certainty: props[:certainty], + species: props[:species], + ) + + @props = props + end + end class GroceryReceipt < T::Struct include Baml::Sorbet::Struct const :receiptId, String @@ -492,6 +560,18 @@ def initialize(props) @props = props end end + class Martian < T::Struct + include Baml::Sorbet::Struct + const :age, Baml::Checked[Integer, Checks__young_enough] + + def initialize(props) + super( + age: props[:age], + ) + + @props = props + end + end class NamedArgsSingleClass < T::Struct include Baml::Sorbet::Struct const :key, String @@ -598,6 +678,18 @@ def initialize(props) @props = props end end + class PhoneNumber < T::Struct + include Baml::Sorbet::Struct + const :value, Baml::Checked[String, Checks__valid_phone_number] + + def initialize(props) + super( + value: props[:value], + ) + + @props = props + end + end class Quantity < T::Struct include Baml::Sorbet::Struct const :amount, T.any(Integer, Float) @@ -828,6 +920,22 @@ def initialize(props) @props = props end end + class TwoStoriesOneTitle < T::Struct + include Baml::Sorbet::Struct + const :title, String + const :story_a, String + const :story_b, String + + def initialize(props) + super( + title: props[:title], + story_a: props[:story_a], + story_b: props[:story_b], + ) + + @props = props + end + end class UnionTest_ReturnType < T::Struct include Baml::Sorbet::Struct const :prop1, T.any(String, T::Boolean) @@ -858,5 +966,96 @@ def initialize(props) @props = props end end + class Checks__young_enough < T::Struct + include Baml::Sorbet::Struct + const :young_enough, Baml::Check + + def initialize(props) + super( + young_enough: props[:young_enough], + ) + + @props = props + end + end + class 
Checks__regex_bad__regex_good__trivial < T::Struct + include Baml::Sorbet::Struct + const :regex_bad, Baml::Check + const :regex_good, Baml::Check + const :trivial, Baml::Check + + def initialize(props) + super( + regex_bad: props[:regex_bad], + regex_good: props[:regex_good], + trivial: props[:trivial], + ) + + @props = props + end + end + class Checks__unreasonably_certain < T::Struct + include Baml::Sorbet::Struct + const :unreasonably_certain, Baml::Check + + def initialize(props) + super( + unreasonably_certain: props[:unreasonably_certain], + ) + + @props = props + end + end + class Checks__earth_aged__no_infants < T::Struct + include Baml::Sorbet::Struct + const :earth_aged, Baml::Check + const :no_infants, Baml::Check + + def initialize(props) + super( + earth_aged: props[:earth_aged], + no_infants: props[:no_infants], + ) + + @props = props + end + end + class Checks__valid_email < T::Struct + include Baml::Sorbet::Struct + const :valid_email, Baml::Check + + def initialize(props) + super( + valid_email: props[:valid_email], + ) + + @props = props + end + end + class Checks__valid_phone_number < T::Struct + include Baml::Sorbet::Struct + const :valid_phone_number, Baml::Check + + def initialize(props) + super( + valid_phone_number: props[:valid_phone_number], + ) + + @props = props + end + end + class Checks__too_big < T::Struct + include Baml::Sorbet::Struct + const :too_big, Baml::Check + + def initialize(props) + super( + too_big: props[:too_big], + ) + + @props = props + end + end + end end \ No newline at end of file diff --git a/integ-tests/ruby/test_functions.rb b/integ-tests/ruby/test_functions.rb index 784c77953..907ec452a 100644 --- a/integ-tests/ruby/test_functions.rb +++ b/integ-tests/ruby/test_functions.rb @@ -287,4 +287,10 @@ ) assert_match(/london/, capitol.downcase) end + + it "uses constraints for unions" do + res = b.ExtractContactInfo(document: "reach me at 888-888-8888, or try to email hello@boundaryml.com") + assert_equal 
res['primary']['value'].value, "888-888-8888" + assert_equal res['primary']['value'].checks.valid_phone_number.status, "succeeded" + end end diff --git a/integ-tests/typescript/README.md b/integ-tests/typescript/README.md index f83dcaa51..1b23ead8f 100644 --- a/integ-tests/typescript/README.md +++ b/integ-tests/typescript/README.md @@ -1,2 +1,13 @@ To run the test with a filter: -pnpm integ-tests -t "works with fallbacks" \ No newline at end of file + +``` bash +pnpm integ-tests -t "works with fallbacks" +``` + + +Note: Before running, you need to build the typescript runtime: + +``` bash +cd engine/language_client_typescript +pnpm build:debug +``` diff --git a/integ-tests/typescript/baml_client/async_client.ts b/integ-tests/typescript/baml_client/async_client.ts index 41ea1ac5e..a73ef000f 100644 --- a/integ-tests/typescript/baml_client/async_client.ts +++ b/integ-tests/typescript/baml_client/async_client.ts @@ -16,7 +16,7 @@ $ pnpm add @boundaryml/baml // @ts-nocheck // biome-ignore format: autogenerated code import { BamlRuntime, FunctionResult, BamlCtxManager, BamlStream, Image, ClientRegistry, BamlValidationError, createBamlValidationError } from "@boundaryml/baml" -import {BigNumbers, Blah, BookOrder, ClassOptionalOutput, ClassOptionalOutput2, ClassWithImage, CompoundBigNumbers, CustomTaskResult, DummyOutput, DynInputOutput, DynamicClassOne, DynamicClassTwo, DynamicOutput, Education, Email, Event, FakeImage, FlightConfirmation, GroceryReceipt, InnerClass, InnerClass2, NamedArgsSingleClass, Nested, Nested2, OptionalTest_Prop1, OptionalTest_ReturnType, OrderInfo, Person, Quantity, RaysData, ReceiptInfo, ReceiptItem, Recipe, Resume, Schema, SearchParams, SomeClassNestedDynamic, StringToClassEntry, TestClassAlias, TestClassNested, TestClassWithEnum, TestOutputClass, UnionTest_ReturnType, WithReasoning, Category, Category2, Category3, Color, DataType, DynEnumOne, DynEnumTwo, EnumInClass, EnumOutput, Hobby, NamedArgsSingleEnum, NamedArgsSingleEnumList, 
OptionalTest_CategoryType, OrderStatus, Tag, TestEnum} from "./types" +import {BigNumbers, Blah, BookOrder, ClassOptionalOutput, ClassOptionalOutput2, ClassWithImage, CompoundBigNumbers, ContactInfo, CustomTaskResult, DummyOutput, DynInputOutput, DynamicClassOne, DynamicClassTwo, DynamicOutput, Earthling, Education, Email, EmailAddress, Event, FakeImage, FlightConfirmation, FooAny, GroceryReceipt, InnerClass, InnerClass2, Martian, NamedArgsSingleClass, Nested, Nested2, OptionalTest_Prop1, OptionalTest_ReturnType, OrderInfo, Person, PhoneNumber, Quantity, RaysData, ReceiptInfo, ReceiptItem, Recipe, Resume, Schema, SearchParams, SomeClassNestedDynamic, StringToClassEntry, TestClassAlias, TestClassNested, TestClassWithEnum, TestOutputClass, TwoStoriesOneTitle, UnionTest_ReturnType, WithReasoning, Category, Category2, Category3, Color, DataType, DynEnumOne, DynEnumTwo, EnumInClass, EnumOutput, Hobby, NamedArgsSingleEnum, NamedArgsSingleEnumList, OptionalTest_CategoryType, OrderStatus, Tag, TestEnum} from "./types" import TypeBuilder from "./type_builder" import { DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX, DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME } from "./globals" @@ -442,6 +442,31 @@ export class BamlAsyncClient { } } + async ExtractContactInfo( + document: string, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "ExtractContactInfo", + { + "document": document + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as ContactInfo + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + async ExtractNames( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1042,6 +1067,56 @@ export class 
BamlAsyncClient { } } + async PredictAge( + name: string, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "PredictAge", + { + "name": name + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as FooAny + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + + async PredictAgeBare( + inp: string, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise> { + try { + const raw = await this.runtime.callFunction( + "PredictAgeBare", + { + "inp": inp + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as Checked + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + async PromptTestClaude( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1217,6 +1292,31 @@ export class BamlAsyncClient { } } + async ReturnFailingAssert( + inp: number, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "ReturnFailingAssert", + { + "inp": inp + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as number + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + async SchemaDescriptions( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1267,6 +1367,31 @@ 
export class BamlAsyncClient { } } + async StreamFailingAssertion( + theme: string,length: number, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "StreamFailingAssertion", + { + "theme": theme,"length": length + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as TwoStoriesOneTitle + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + async StreamOneBigNumber( digits: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -2626,6 +2751,39 @@ class BamlStreamClient { } } + ExtractContactInfo( + document: string, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, ContactInfo> { + try { + const raw = this.runtime.streamFunction( + "ExtractContactInfo", + { + "document": document + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, ContactInfo>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is ContactInfo => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + ExtractNames( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -3418,6 +3576,72 @@ class BamlStreamClient { } } + PredictAge( + name: string, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, FooAny> { + try { + const raw = this.runtime.streamFunction( + "PredictAge", + { + "name": name + }, + undefined, + 
this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, FooAny>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is FooAny => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + + PredictAgeBare( + inp: string, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream>, Checked> { + try { + const raw = this.runtime.streamFunction( + "PredictAgeBare", + { + "inp": inp + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream>, Checked>( + raw, + (a): a is RecursivePartialNull> => a, + (a): a is Checked => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + PromptTestClaude( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -3649,6 +3873,39 @@ class BamlStreamClient { } } + ReturnFailingAssert( + inp: number, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, number> { + try { + const raw = this.runtime.streamFunction( + "ReturnFailingAssert", + { + "inp": inp + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, number>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is number => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = 
createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + SchemaDescriptions( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -3715,6 +3972,39 @@ class BamlStreamClient { } } + StreamFailingAssertion( + theme: string,length: number, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, TwoStoriesOneTitle> { + try { + const raw = this.runtime.streamFunction( + "StreamFailingAssertion", + { + "theme": theme,"length": length + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, TwoStoriesOneTitle>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is TwoStoriesOneTitle => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + StreamOneBigNumber( digits: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } diff --git a/integ-tests/typescript/baml_client/index.ts b/integ-tests/typescript/baml_client/index.ts index 6fc0cb7ab..bdc7d19a1 100644 --- a/integ-tests/typescript/baml_client/index.ts +++ b/integ-tests/typescript/baml_client/index.ts @@ -21,4 +21,4 @@ export { b } from "./async_client" export * from "./types" export * from "./tracing" export { resetBamlEnvVars } from "./globals" -export { BamlValidationError } from "@boundaryml/baml" \ No newline at end of file +export { BamlValidationError, Checked } from "@boundaryml/baml" \ No newline at end of file diff --git a/integ-tests/typescript/baml_client/inlinedbaml.ts b/integ-tests/typescript/baml_client/inlinedbaml.ts index 185c840b8..61df7c6a5 100644 --- a/integ-tests/typescript/baml_client/inlinedbaml.ts +++ 
b/integ-tests/typescript/baml_client/inlinedbaml.ts @@ -30,6 +30,8 @@ const fileMap = { "test-files/aliases/classes.baml": "class TestClassAlias {\n key string @alias(\"key-dash\") @description(#\"\n This is a description for key\n af asdf\n \"#)\n key2 string @alias(\"key21\")\n key3 string @alias(\"key with space\")\n key4 string //unaliased\n key5 string @alias(\"key.with.punctuation/123\")\n}\n\nfunction FnTestClassAlias(input: string) -> TestClassAlias {\n client GPT35\n prompt #\"\n {{ctx.output_format}}\n \"#\n}\n\ntest FnTestClassAlias {\n functions [FnTestClassAlias]\n args {\n input \"example input\"\n }\n}\n", "test-files/aliases/enums.baml": "enum TestEnum {\n A @alias(\"k1\") @description(#\"\n User is angry\n \"#)\n B @alias(\"k22\") @description(#\"\n User is happy\n \"#)\n // tests whether k1 doesnt incorrectly get matched with k11\n C @alias(\"k11\") @description(#\"\n User is sad\n \"#)\n D @alias(\"k44\") @description(\n User is confused\n )\n E @description(\n User is excited\n )\n F @alias(\"k5\") // only alias\n \n G @alias(\"k6\") @description(#\"\n User is bored\n With a long description\n \"#)\n \n @@alias(\"Category\")\n}\n\nfunction FnTestAliasedEnumOutput(input: string) -> TestEnum {\n client GPT35\n prompt #\"\n Classify the user input into the following category\n \n {{ ctx.output_format }}\n\n {{ _.role('user') }}\n {{input}}\n\n {{ _.role('assistant') }}\n Category ID:\n \"#\n}\n\ntest FnTestAliasedEnumOutput {\n functions [FnTestAliasedEnumOutput]\n args {\n input \"mehhhhh\"\n }\n}", "test-files/comments/comments.baml": "// add some functions, classes, enums etc with comments all over.", + "test-files/constraints/constraints.baml": "// These classes and functions test several properties of\n// constrains:\n//\n// - The ability for constrains on fields to pass or fail.\n// - The ability for constraints on bare args and return types to pass or fail.\n// - The ability of constraints to influence which variant of a union is chosen\n// 
by the parser, when the structure is not sufficient to decide.\n\n\nclass Martian {\n age int @check({{ this < 30 }}, young_enough)\n}\n\nclass Earthling {\n age int @check({{this < 200 and this > 0}}, earth_aged) @check({{this >1}}, no_infants)\n}\n\n\nclass FooAny {\n planetary_age Martian | Earthling\n certainty int @check({{this == 102931}}, unreasonably_certain)\n species string @check({{this == \"Homo sapiens\"}}, trivial) @check({{this|regex_match(\"Homo\")}}, regex_good) @check({{this|regex_match(\"neanderthalensis\")}}, regex_bad)\n}\n\n\nfunction PredictAge(name: string) -> FooAny {\n client GPT35\n prompt #\"\n Using your understanding of the historical popularity\n of names, predict the age of a person with the name\n {{ name }} in years. Also predict their genus and\n species. It's Homo sapiens (with exactly that spelling\n and capitalization). I'll give you a hint: If the name\n is \"Greg\", his age is 41.\n\n {{ctx.output_format}}\n \"#\n}\n\n\nfunction PredictAgeBare(inp: string @assert({{this|length > 1}}, big_enough)) -> int @check({{this == 10102}}, too_big) {\n client GPT35\n prompt #\"\n Using your understanding of the historical popularity\n of names, predict the age of a person with the name\n {{ inp.name }} in years. Also predict their genus and\n species. 
It's Homo sapiens (with exactly that spelling).\n\n {{ctx.output_format}}\n \"#\n}\n\nfunction ReturnFailingAssert(inp: int @assert({{this < 10}}, small_int)) -> int @assert({{this > 100}}, big_int) {\n client GPT35\n prompt #\"\n Return the next integer after {{ inp }}.\n\n {{ctx.output_format}}\n \"#\n}\n\nclass TwoStoriesOneTitle {\n title string\n story_a string @assert( {{this|length > 1000000}}, too_long_story )\n story_b string @assert( {{this|length > 1000000}}, too_long_story )\n}\n\nfunction StreamFailingAssertion(theme: string, length: int) -> TwoStoriesOneTitle {\n client GPT35\n prompt #\"\n Tell me two different stories along the theme of {{ theme }} with the same title.\n Please make each about {{ length }} words long.\n {{ctx.output_format}}\n \"#\n}\n", + "test-files/constraints/contact-info.baml": "class PhoneNumber {\n value string @check({{this|regex_match(\"\\(?\\d{3}\\)?[-.\\s]?\\d{3}[-.\\s]?\\d{4}\")}}, valid_phone_number)\n}\n\nclass EmailAddress {\n value string @check({{this|regex_match(\"^[_]*([a-z0-9]+(\\.|_*)?)+@([a-z][a-z0-9-]+(\\.|-*\\.))+[a-z]{2,6}$\")}}, valid_email)\n}\n\nclass ContactInfo {\n primary PhoneNumber | EmailAddress\n secondary (PhoneNumber | EmailAddress)?\n}\n\nfunction ExtractContactInfo(document: string) -> ContactInfo {\n client GPT35\n prompt #\"\n Extract a primary contact info, and if possible a secondary contact\n info, from this document:\n\n {{ document }}\n\n {{ ctx.output_format }}\n \"#\n}\n", "test-files/descriptions/descriptions.baml": "\nclass Nested {\n prop3 string | null @description(#\"\n write \"three\"\n \"#)\n prop4 string | null @description(#\"\n write \"four\"\n \"#) @alias(\"blah\")\n prop20 Nested2\n}\n\nclass Nested2 {\n prop11 string | null @description(#\"\n write \"three\"\n \"#)\n prop12 string | null @description(#\"\n write \"four\"\n \"#) @alias(\"blah\")\n}\n\nclass Schema {\n prop1 string | null @description(#\"\n write \"one\"\n \"#)\n prop2 Nested | string @description(#\"\n 
write \"two\"\n \"#)\n prop5 (string | null)[] @description(#\"\n write \"hi\"\n \"#)\n prop6 string | Nested[] @alias(\"blah\") @description(#\"\n write the string \"blah\" regardless of the other types here\n \"#)\n nested_attrs (string | null | Nested)[] @description(#\"\n write the string \"nested\" regardless of other types\n \"#)\n parens (string | null) @description(#\"\n write \"parens1\"\n \"#)\n other_group (string | (int | string)) @description(#\"\n write \"other\"\n \"#) @alias(other)\n}\n\n\nfunction SchemaDescriptions(input: string) -> Schema {\n client GPT4o\n prompt #\"\n Return a schema with this format:\n\n {{ctx.output_format}}\n \"#\n}", "test-files/dynamic/client-registry.baml": "// Intentionally use a bad key\nclient BadClient {\n provider openai\n options {\n model \"gpt-3.5-turbo\"\n api_key \"sk-invalid\"\n }\n}\n\nfunction ExpectFailure() -> string {\n client BadClient\n\n prompt #\"\n What is the capital of England?\n \"#\n}\n", "test-files/dynamic/dynamic.baml": "class DynamicClassOne {\n @@dynamic\n}\n\nenum DynEnumOne {\n @@dynamic\n}\n\nenum DynEnumTwo {\n @@dynamic\n}\n\nclass SomeClassNestedDynamic {\n hi string\n @@dynamic\n\n}\n\nclass DynamicClassTwo {\n hi string\n some_class SomeClassNestedDynamic\n status DynEnumOne\n @@dynamic\n}\n\nfunction DynamicFunc(input: DynamicClassOne) -> DynamicClassTwo {\n client GPT35\n prompt #\"\n Please extract the schema from \n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nclass DynInputOutput {\n testKey string\n @@dynamic\n}\n\nfunction DynamicInputOutput(input: DynInputOutput) -> DynInputOutput {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ ctx.output_format }}\n \"#\n}\n\nfunction DynamicListInputOutput(input: DynInputOutput[]) -> DynInputOutput[] {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ ctx.output_format }}\n \"#\n}\n\n\n\nclass 
DynamicOutput {\n @@dynamic\n}\n \nfunction MyFunc(input: string) -> DynamicOutput {\n client GPT35\n prompt #\"\n Given a string, extract info using the schema:\n\n {{ input}}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ClassifyDynEnumTwo(input: string) -> DynEnumTwo {\n client GPT35\n prompt #\"\n Given a string, extract info using the schema:\n\n {{ input}}\n\n {{ ctx.output_format }}\n \"#\n}", diff --git a/integ-tests/typescript/baml_client/sync_client.ts b/integ-tests/typescript/baml_client/sync_client.ts index 25e577efa..80bdb9642 100644 --- a/integ-tests/typescript/baml_client/sync_client.ts +++ b/integ-tests/typescript/baml_client/sync_client.ts @@ -16,7 +16,7 @@ $ pnpm add @boundaryml/baml // @ts-nocheck // biome-ignore format: autogenerated code import { BamlRuntime, FunctionResult, BamlCtxManager, BamlSyncStream, Image, ClientRegistry } from "@boundaryml/baml" -import {BigNumbers, Blah, BookOrder, ClassOptionalOutput, ClassOptionalOutput2, ClassWithImage, CompoundBigNumbers, CustomTaskResult, DummyOutput, DynInputOutput, DynamicClassOne, DynamicClassTwo, DynamicOutput, Education, Email, Event, FakeImage, FlightConfirmation, GroceryReceipt, InnerClass, InnerClass2, NamedArgsSingleClass, Nested, Nested2, OptionalTest_Prop1, OptionalTest_ReturnType, OrderInfo, Person, Quantity, RaysData, ReceiptInfo, ReceiptItem, Recipe, Resume, Schema, SearchParams, SomeClassNestedDynamic, StringToClassEntry, TestClassAlias, TestClassNested, TestClassWithEnum, TestOutputClass, UnionTest_ReturnType, WithReasoning, Category, Category2, Category3, Color, DataType, DynEnumOne, DynEnumTwo, EnumInClass, EnumOutput, Hobby, NamedArgsSingleEnum, NamedArgsSingleEnumList, OptionalTest_CategoryType, OrderStatus, Tag, TestEnum} from "./types" +import {BigNumbers, Blah, BookOrder, ClassOptionalOutput, ClassOptionalOutput2, ClassWithImage, CompoundBigNumbers, ContactInfo, CustomTaskResult, DummyOutput, DynInputOutput, DynamicClassOne, DynamicClassTwo, DynamicOutput, Earthling, 
Education, Email, EmailAddress, Event, FakeImage, FlightConfirmation, FooAny, GroceryReceipt, InnerClass, InnerClass2, Martian, NamedArgsSingleClass, Nested, Nested2, OptionalTest_Prop1, OptionalTest_ReturnType, OrderInfo, Person, PhoneNumber, Quantity, RaysData, ReceiptInfo, ReceiptItem, Recipe, Resume, Schema, SearchParams, SomeClassNestedDynamic, StringToClassEntry, TestClassAlias, TestClassNested, TestClassWithEnum, TestOutputClass, TwoStoriesOneTitle, UnionTest_ReturnType, WithReasoning, Category, Category2, Category3, Color, DataType, DynEnumOne, DynEnumTwo, EnumInClass, EnumOutput, Hobby, NamedArgsSingleEnum, NamedArgsSingleEnumList, OptionalTest_CategoryType, OrderStatus, Tag, TestEnum} from "./types" import TypeBuilder from "./type_builder" import { DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX, DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME } from "./globals" @@ -442,6 +442,31 @@ export class BamlSyncClient { } } + ExtractContactInfo( + document: string, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): ContactInfo { + try { + const raw = this.runtime.callFunctionSync( + "ExtractContactInfo", + { + "document": document + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as ContactInfo + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + ExtractNames( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1042,6 +1067,56 @@ export class BamlSyncClient { } } + PredictAge( + name: string, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): FooAny { + try { + const raw = this.runtime.callFunctionSync( + "PredictAge", + { + "name": name + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + 
__baml_options__?.clientRegistry, + ) + return raw.parsed() as FooAny + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + + PredictAgeBare( + inp: string, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Checked { + try { + const raw = this.runtime.callFunctionSync( + "PredictAgeBare", + { + "inp": inp + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as Checked + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + PromptTestClaude( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1217,6 +1292,31 @@ export class BamlSyncClient { } } + ReturnFailingAssert( + inp: number, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): number { + try { + const raw = this.runtime.callFunctionSync( + "ReturnFailingAssert", + { + "inp": inp + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as number + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + SchemaDescriptions( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1267,6 +1367,31 @@ export class BamlSyncClient { } } + StreamFailingAssertion( + theme: string,length: number, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): TwoStoriesOneTitle { + try { + const raw = this.runtime.callFunctionSync( + "StreamFailingAssertion", + { + "theme": theme,"length": length + }, + 
this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as TwoStoriesOneTitle + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + StreamOneBigNumber( digits: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } diff --git a/integ-tests/typescript/baml_client/type_builder.ts b/integ-tests/typescript/baml_client/type_builder.ts index 1c95e7578..095f3af98 100644 --- a/integ-tests/typescript/baml_client/type_builder.ts +++ b/integ-tests/typescript/baml_client/type_builder.ts @@ -48,7 +48,7 @@ export default class TypeBuilder { constructor() { this.tb = new _TypeBuilder({ classes: new Set([ - "BigNumbers","Blah","BookOrder","ClassOptionalOutput","ClassOptionalOutput2","ClassWithImage","CompoundBigNumbers","CustomTaskResult","DummyOutput","DynInputOutput","DynamicClassOne","DynamicClassTwo","DynamicOutput","Education","Email","Event","FakeImage","FlightConfirmation","GroceryReceipt","InnerClass","InnerClass2","NamedArgsSingleClass","Nested","Nested2","OptionalTest_Prop1","OptionalTest_ReturnType","OrderInfo","Person","Quantity","RaysData","ReceiptInfo","ReceiptItem","Recipe","Resume","Schema","SearchParams","SomeClassNestedDynamic","StringToClassEntry","TestClassAlias","TestClassNested","TestClassWithEnum","TestOutputClass","UnionTest_ReturnType","WithReasoning", + 
"BigNumbers","Blah","BookOrder","ClassOptionalOutput","ClassOptionalOutput2","ClassWithImage","CompoundBigNumbers","ContactInfo","CustomTaskResult","DummyOutput","DynInputOutput","DynamicClassOne","DynamicClassTwo","DynamicOutput","Earthling","Education","Email","EmailAddress","Event","FakeImage","FlightConfirmation","FooAny","GroceryReceipt","InnerClass","InnerClass2","Martian","NamedArgsSingleClass","Nested","Nested2","OptionalTest_Prop1","OptionalTest_ReturnType","OrderInfo","Person","PhoneNumber","Quantity","RaysData","ReceiptInfo","ReceiptItem","Recipe","Resume","Schema","SearchParams","SomeClassNestedDynamic","StringToClassEntry","TestClassAlias","TestClassNested","TestClassWithEnum","TestOutputClass","TwoStoriesOneTitle","UnionTest_ReturnType","WithReasoning", ]), enums: new Set([ "Category","Category2","Category3","Color","DataType","DynEnumOne","DynEnumTwo","EnumInClass","EnumOutput","Hobby","NamedArgsSingleEnum","NamedArgsSingleEnumList","OptionalTest_CategoryType","OrderStatus","Tag","TestEnum", diff --git a/integ-tests/typescript/baml_client/types.ts b/integ-tests/typescript/baml_client/types.ts index cc6c2b3d7..0e56d98b6 100644 --- a/integ-tests/typescript/baml_client/types.ts +++ b/integ-tests/typescript/baml_client/types.ts @@ -116,6 +116,37 @@ export enum TestEnum { G = "G", } +export interface Checks__too_big { + too_big: Check +} + +export interface Checks__valid_phone_number { + valid_phone_number: Check +} + +export interface Checks__valid_email { + valid_email: Check +} + +export interface Checks__earth_aged__no_infants { + earth_aged: Check + no_infants: Check +} + +export interface Checks__unreasonably_certain { + unreasonably_certain: Check +} + +export interface Checks__regex_bad__regex_good__trivial { + regex_good: Check + trivial: Check + regex_bad: Check +} + +export interface Checks__young_enough { + young_enough: Check +} + export interface BigNumbers { a: number b: number @@ -162,6 +193,12 @@ export interface CompoundBigNumbers { } 
+export interface ContactInfo { + primary: PhoneNumber | EmailAddress + secondary?: PhoneNumber | EmailAddress | null + +} + export interface CustomTaskResult { bookOrder?: BookOrder | null | null flightConfirmation?: FlightConfirmation | null | null @@ -200,6 +237,11 @@ export interface DynamicOutput { [key: string]: any; } +export interface Earthling { + age: Checked + +} + export interface Education { institution: string location: string @@ -216,6 +258,11 @@ export interface Email { } +export interface EmailAddress { + value: Checked + +} + export interface Event { title: string date: string @@ -238,6 +285,13 @@ export interface FlightConfirmation { } +export interface FooAny { + planetary_age: Martian | Earthling + certainty: Checked + species: Checked + +} + export interface GroceryReceipt { receiptId: string storeName: string @@ -259,6 +313,11 @@ export interface InnerClass2 { } +export interface Martian { + age: Checked + +} + export interface NamedArgsSingleClass { key: string key_two: boolean @@ -306,6 +365,11 @@ export interface Person { [key: string]: any; } +export interface PhoneNumber { + value: Checked + +} + export interface Quantity { amount: number | number unit?: string | null @@ -408,6 +472,13 @@ export interface TestOutputClass { } +export interface TwoStoriesOneTitle { + title: string + story_a: string + story_b: string + +} + export interface UnionTest_ReturnType { prop1: string | boolean prop2: (number | boolean)[] diff --git a/integ-tests/typescript/test-report.html b/integ-tests/typescript/test-report.html index 24bca5c57..cc579d16d 100644 --- a/integ-tests/typescript/test-report.html +++ b/integ-tests/typescript/test-report.html @@ -257,4 +257,8 @@ font-size: 1rem; padding: 0 0.5rem; } -

Test Report

Started: 2024-10-07 14:47:45
Suites (1)
1 passed
0 failed
0 pending
Tests (46)
1 passed
0 failed
45 pending
Integ tests > should work for all inputs
single bool
pending
0s
Integ tests > should work for all inputs
single string list
pending
0s
Integ tests > should work for all inputs
single class
pending
0s
Integ tests > should work for all inputs
multiple classes
pending
0s
Integ tests > should work for all inputs
single enum list
pending
0s
Integ tests > should work for all inputs
single float
pending
0s
Integ tests > should work for all inputs
single int
pending
0s
Integ tests > should work for all inputs
single optional string
pending
0s
Integ tests > should work for all inputs
single map string to string
pending
0s
Integ tests > should work for all inputs
single map string to class
pending
0s
Integ tests > should work for all inputs
single map string to map
pending
0s
Integ tests
should work for all outputs
pending
0s
Integ tests
works with retries1
pending
0s
Integ tests
works with retries2
pending
0s
Integ tests
works with fallbacks
pending
0s
Integ tests
should work with image from url
pending
0s
Integ tests
should work with image from base 64
pending
0s
Integ tests
should work with audio base 64
pending
0s
Integ tests
should work with audio from url
pending
0s
Integ tests
should support streaming in OpenAI
pending
0s
Integ tests
should support streaming in Gemini
pending
0s
Integ tests
should support AWS
pending
0s
Integ tests
should support streaming in AWS
pending
0s
Integ tests
should support OpenAI shorthand
pending
0s
Integ tests
should support OpenAI shorthand streaming
pending
0s
Integ tests
should support anthropic shorthand
pending
0s
Integ tests
should support anthropic shorthand streaming
pending
0s
Integ tests
should support streaming without iterating
pending
0s
Integ tests
should support streaming in Claude
pending
0s
Integ tests
should support vertex
pending
0s
Integ tests
supports tracing sync
pending
0s
Integ tests
supports tracing async
pending
0s
Integ tests
should work with dynamic types single
pending
0s
Integ tests
should work with dynamic types enum
pending
0s
Integ tests
should work with dynamic types class
pending
0s
Integ tests
should work with dynamic inputs class
pending
0s
Integ tests
should work with dynamic inputs list
pending
0s
Integ tests
should work with dynamic output map
pending
0s
Integ tests
should work with dynamic output union
pending
0s
Integ tests
should work with nested classes
pending
0s
Integ tests
should work with dynamic client
pending
0s
Integ tests
should work with 'onLogEvent'
pending
0s
Integ tests
should work with a sync client
pending
0s
Integ tests
should raise an error when appropriate
pending
0s
Integ tests
should raise a BAMLValidationError
passed
0.58s
Integ tests
should reset environment variables correctly
pending
0s
\ No newline at end of file +

Test Report

Started: 2024-10-17 22:02:55
Suites (1)
0 passed
1 failed
0 pending
Tests (47)
0 passed
1 failed
46 pending
Integ tests > should work for all inputs
single bool
pending
0s
Integ tests > should work for all inputs
single string list
pending
0s
Integ tests > should work for all inputs
single class
pending
0s
Integ tests > should work for all inputs
multiple classes
pending
0s
Integ tests > should work for all inputs
single enum list
pending
0s
Integ tests > should work for all inputs
single float
pending
0s
Integ tests > should work for all inputs
single int
pending
0s
Integ tests > should work for all inputs
single optional string
pending
0s
Integ tests > should work for all inputs
single map string to string
pending
0s
Integ tests > should work for all inputs
single map string to class
pending
0s
Integ tests > should work for all inputs
single map string to map
pending
0s
Integ tests
should work for all outputs
pending
0s
Integ tests
works with retries1
pending
0s
Integ tests
works with retries2
pending
0s
Integ tests
works with fallbacks
pending
0s
Integ tests
should work with image from url
pending
0s
Integ tests
should work with image from base 64
pending
0s
Integ tests
should work with audio base 64
pending
0s
Integ tests
should work with audio from url
pending
0s
Integ tests
should support streaming in OpenAI
pending
0s
Integ tests
should support streaming in Gemini
pending
0s
Integ tests
should support AWS
pending
0s
Integ tests
should support streaming in AWS
pending
0s
Integ tests
should support OpenAI shorthand
pending
0s
Integ tests
should support OpenAI shorthand streaming
pending
0s
Integ tests
should support anthropic shorthand
pending
0s
Integ tests
should support anthropic shorthand streaming
pending
0s
Integ tests
should support streaming without iterating
pending
0s
Integ tests
should support streaming in Claude
pending
0s
Integ tests
should support vertex
pending
0s
Integ tests
supports tracing sync
pending
0s
Integ tests
supports tracing async
pending
0s
Integ tests
should work with dynamic types single
pending
0s
Integ tests
should work with dynamic types enum
pending
0s
Integ tests
should work with dynamic types class
pending
0s
Integ tests
should work with dynamic inputs class
pending
0s
Integ tests
should work with dynamic inputs list
pending
0s
Integ tests
should work with dynamic output map
pending
0s
Integ tests
should work with dynamic output union
pending
0s
Integ tests
should work with nested classes
pending
0s
Integ tests
should work with dynamic client
pending
0s
Integ tests
should work with 'onLogEvent'
pending
0s
Integ tests
should work with a sync client
pending
0s
Integ tests
should raise an error when appropriate
pending
0s
Integ tests
should raise a BAMLValidationError
pending
0s
Integ tests
should reset environment variables correctly
pending
0s
Integ tests
should include checks
failed
0.851s
Error: expect(received).toEqual(expected) // deep equality
+
+Expected: 2
+Received: 1
+    at Object.toEqual (/Users/greghale/code/baml/integ-tests/typescript/tests/integ-tests.test.ts:654:57)
\ No newline at end of file diff --git a/integ-tests/typescript/tests/integ-tests.test.ts b/integ-tests/typescript/tests/integ-tests.test.ts index 5c9aba440..3128a8c90 100644 --- a/integ-tests/typescript/tests/integ-tests.test.ts +++ b/integ-tests/typescript/tests/integ-tests.test.ts @@ -18,6 +18,7 @@ import { RecursivePartialNull } from '../baml_client/async_client' import { b as b_sync } from '../baml_client/sync_client' import { config } from 'dotenv' import { BamlLogEvent, BamlRuntime } from '@boundaryml/baml/native' +import { all_succeeded, get_checks } from '@boundaryml/baml/checked' import { AsyncLocalStorage } from 'async_hooks' import { DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME, resetBamlEnvVars } from '../baml_client/globals' config() @@ -641,6 +642,16 @@ describe('Integ tests', () => { ) expect(people.length).toBeGreaterThan(0) }) + + it('should include checks', async() => { + const res = await b.ExtractContactInfo("Reach me at 333-333-4444. If that doesn't work, me@hellovai.com!") + expect(res.primary.value.value).toEqual("333-333-4444"); + expect(res.primary.value.checks.valid_phone_number.status).toEqual("succeeded"); + expect(res.secondary?.value.value).toEqual("me@hellovai.com"); + expect(res.secondary?.value.checks.valid_email.status).toEqual("succeeded"); + expect(all_succeeded(res.primary.value.checks)); + expect(get_checks(res.primary.value.checks).length).toEqual(1); + }) }) interface MyInterface { diff --git a/shell.nix b/shell.nix index f5f42bf49..7af7b0af3 100644 --- a/shell.nix +++ b/shell.nix @@ -1,3 +1,5 @@ +# TODO: Package jest + let pkgs = import { }; @@ -30,12 +32,14 @@ in pkgs.mkShell { rustfmt maturin nodePackages.pnpm + nodePackages.nodejs python3 poetry rust-analyzer fern ruby nixfmt-classic + swc ] ++ (if pkgs.stdenv.isDarwin then appleDeps else [ ]); LIBCLANG_PATH = pkgs.libclang.lib + "/lib/"; @@ -46,6 +50,6 @@ in pkgs.mkShell { shellHook = '' export PROJECT_ROOT=/$(pwd) - export 
PATH=/$PROJECT_ROOT/tools:$PATH + export PATH=/$PROJECT_ROOT/tools:$PROJECT_ROOT/integ-tests/typescript/node_modules/.bin:$PATH ''; }