diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs index 4274ceb09..3c0462cfb 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs @@ -98,6 +98,9 @@ fn coerce_int( } } crate::jsonish::Value::String(s) => { + let s = s.trim(); + // Trim trailing commas + let s = s.trim_end_matches(','); if let Ok(n) = s.parse::() { Ok(BamlValueWithFlags::Int(n.into())) } else if let Ok(n) = s.parse::() { @@ -106,6 +109,10 @@ fn coerce_int( Ok(BamlValueWithFlags::Int( ((n.round() as i64), Flag::FloatToInt(n)).into(), )) + } else if let Some(frac) = float_from_maybe_fraction(s) { + Ok(BamlValueWithFlags::Int( + ((frac.round() as i64), Flag::FloatToInt(frac)).into(), + )) } else { Err(ctx.error_unexpected_type(target, value)) } @@ -122,6 +129,20 @@ fn coerce_int( } } +fn float_from_maybe_fraction(value: &str) -> Option { + if let Some((numerator, denominator)) = value.split_once('/') { + match ( + numerator.trim().parse::(), + denominator.trim().parse::(), + ) { + (Ok(num), Ok(denom)) if denom != 0.0 => Some(num / denom), + _ => None, + } + } else { + None + } +} + fn coerce_float( ctx: &ParsingContext, target: &FieldType, @@ -141,12 +162,17 @@ fn coerce_float( } } crate::jsonish::Value::String(s) => { + let s = s.trim(); + // Trim trailing commas + let s = s.trim_end_matches(','); if let Ok(n) = s.parse::() { Ok(BamlValueWithFlags::Float(n.into())) } else if let Ok(n) = s.parse::() { Ok(BamlValueWithFlags::Float((n as f64).into())) } else if let Ok(n) = s.parse::() { Ok(BamlValueWithFlags::Float((n as f64).into())) + } else if let Some(frac) = float_from_maybe_fraction(s) { + Ok(BamlValueWithFlags::Float(frac.into())) } else { Err(ctx.error_unexpected_type(target, value)) } diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs index a63957fb6..5364375f5 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs @@ -1,9 +1,6 @@ - use anyhow::Result; use baml_types::BamlMap; -use internal_baml_core::{ - ir::{FieldType}, -}; +use internal_baml_core::ir::FieldType; use internal_baml_jinja::types::{Class, Name}; use crate::deserializer::{ @@ -51,6 +48,8 @@ impl TypeCoercer for Class { } Some(crate::jsonish::Value::Object(obj)) => { // match keys, if that fails, then do something fancy later. + let mut extra_keys = vec![]; + let mut found_keys = false; obj.iter().for_each(|(key, v)| { if let Some(field) = self .fields @@ -60,10 +59,45 @@ impl TypeCoercer for Class { let scope = ctx.enter_scope(field.0.real_name()); let parsed = field.1.coerce(&scope, &field.1, Some(v)); update_map(&mut required_values, &mut optional_values, field, parsed); + found_keys = true; } else { - flags.add_flag(Flag::ExtraKey(key.clone(), v.clone())); + extra_keys.push((key, v)); } }); + + if !found_keys && !extra_keys.is_empty() && self.fields.len() == 1 { + // Try to coerce the object into the single field + let field = &self.fields[0]; + let scope = ctx.enter_scope(&format!("", field.0.real_name())); + let parsed = field + .1 + .coerce( + &scope, + &field.1, + Some(&crate::jsonish::Value::Object(obj.clone())), + ) + .map(|mut v| { + v.add_flag(Flag::ImpliedKey(field.0.real_name().into())); + v + }); + + if let Ok(parsed_value) = parsed { + update_map( + &mut required_values, + &mut optional_values, + field, + Ok(parsed_value), + ); + } else { + extra_keys.into_iter().for_each(|(key, v)| { + flags.add_flag(Flag::ExtraKey(key.to_string(), v.clone())); + }); + } + } else { + extra_keys.into_iter().for_each(|(key, v)| { + flags.add_flag(Flag::ExtraKey(key.to_string(), v.clone())); + }); + } } Some(crate::jsonish::Value::Array(items)) => { if self.fields.len() == 1 { @@ -97,6 +131,7 @@ impl TypeCoercer for Class { let parsed = match field.1.coerce(&scope, &field.1, Some(x)) { Ok(mut v) => { v.add_flag(Flag::ImpliedKey(field.0.real_name().into())); + flags.add_flag(Flag::InferedObject(x.clone())); Ok(v) } Err(e) => Err(e), diff --git a/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs b/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs index 2a8695e19..6d96912a5 100644 --- a/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs +++ b/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs @@ -17,6 +17,7 @@ pub enum Flag { JsonToString(crate::jsonish::Value), ImpliedKey(String), + InferedObject(crate::jsonish::Value), // Values here are all the possible matches. FirstMatch(usize, Vec>), @@ -68,6 +69,9 @@ impl std::fmt::Display for DeserializerConditions { impl std::fmt::Display for Flag { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + Flag::InferedObject(value) => { + write!(f, "Infered object from: {}", value.r#type())?; + } Flag::OptionalDefaultFromNoValue => { write!(f, "Optional Default value")?; } @@ -175,6 +179,10 @@ impl DeserializerConditions { pub fn new() -> Self { Self { flags: Vec::new() } } + + pub fn flags(&self) -> &Vec { + &self.flags + } } impl Default for DeserializerConditions { diff --git a/engine/baml-lib/jsonish/src/deserializer/mod.rs b/engine/baml-lib/jsonish/src/deserializer/mod.rs index 431f9ef7f..fe848bff9 100644 --- a/engine/baml-lib/jsonish/src/deserializer/mod.rs +++ b/engine/baml-lib/jsonish/src/deserializer/mod.rs @@ -1,5 +1,5 @@ pub mod coercer; -mod deserialize_flags; +pub mod deserialize_flags; // pub mod schema; mod score; pub mod types; diff --git a/engine/baml-lib/jsonish/src/deserializer/score.rs b/engine/baml-lib/jsonish/src/deserializer/score.rs index 58ba68c2b..ce0a90c8b 100644 --- a/engine/baml-lib/jsonish/src/deserializer/score.rs +++ b/engine/baml-lib/jsonish/src/deserializer/score.rs @@ -32,6 +32,7 @@ impl WithScore for BamlValueWithFlags { impl WithScore for Flag { fn score(&self) -> i32 { match self { + Flag::InferedObject(_) => 0, // Dont penalize for this but instead handle it at the top level Flag::OptionalDefaultFromNoValue => 1, Flag::DefaultFromNoValue => 100, Flag::DefaultButHadValue(_) => 110, diff --git a/engine/baml-lib/jsonish/src/deserializer/types.rs b/engine/baml-lib/jsonish/src/deserializer/types.rs index b1a85415f..8b5309c8a 100644 --- a/engine/baml-lib/jsonish/src/deserializer/types.rs +++ b/engine/baml-lib/jsonish/src/deserializer/types.rs @@ -53,6 +53,21 @@ impl BamlValueWithFlags { BamlValueWithFlags::Image(f) => f.score(), } } + + pub fn conditions(&self) -> &DeserializerConditions { + match self { + BamlValueWithFlags::String(v) => &v.flags, + BamlValueWithFlags::Int(v) => &v.flags, + BamlValueWithFlags::Float(v) => &v.flags, + BamlValueWithFlags::Bool(v) => &v.flags, + BamlValueWithFlags::List(v, _) => &v, + BamlValueWithFlags::Map(v, _) => &v, + BamlValueWithFlags::Enum(_, v) => &v.flags, + BamlValueWithFlags::Class(_, v, _) => &v, + BamlValueWithFlags::Null(v) => &v, + BamlValueWithFlags::Image(v) => &v.flags, + } + } } #[derive(Debug, Clone)] diff --git a/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_parse_state.rs b/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_parse_state.rs index a052bb6a6..862e09d2a 100644 --- a/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_parse_state.rs +++ b/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_parse_state.rs @@ -175,10 +175,22 @@ impl JsonParseState { log::debug!("Testing for comment after space + comma"); // If after the space we have "//" or "/*" or the beginning of a key, we'll close the string let mut buffer = ",".to_string(); + let mut anything_but_whitespace = false; while let Some((_, next_next_c)) = next.next() { + anything_but_whitespace = anything_but_whitespace + || !next_next_c.is_whitespace(); buffer.push(next_next_c); match next_next_c { - ' ' | '\n' => {} + ' ' => {} + '\n' => { + if anything_but_whitespace { + } else { + // Likely end of the key as the LLM generated a (', ' token by mistake) + // so drop the comma + log::debug!("Closing due to: newline after comma + space"); + return Some(idx); + } + } '/' => match next.peek() { Some((_, '/')) => { // This is likely a comment diff --git a/engine/baml-lib/jsonish/src/lib.rs b/engine/baml-lib/jsonish/src/lib.rs index c9069f821..309a67902 100644 --- a/engine/baml-lib/jsonish/src/lib.rs +++ b/engine/baml-lib/jsonish/src/lib.rs @@ -12,6 +12,8 @@ pub use deserializer::types::BamlValueWithFlags; use internal_baml_core::ir::TypeValue; use internal_baml_jinja::types::OutputFormatContent; +use deserializer::deserialize_flags::Flag; + pub fn from_str( of: &OutputFormatContent, target: &FieldType, @@ -42,7 +44,17 @@ pub fn from_str( // Lets try to now coerce the value into the expected schema. match target.coerce(&ctx, target, Some(&value)) { - Ok(v) => Ok(v), + Ok(v) => { + if v.conditions() + .flags() + .iter() + .any(|f| matches!(f, Flag::InferedObject(jsonish::Value::String(_)))) + { + anyhow::bail!("Failed to coerce value: {:?}", v.conditions().flags()); + } + + Ok(v) + } Err(e) => anyhow::bail!("Failed to coerce value: {}", e), } } diff --git a/engine/baml-lib/jsonish/src/tests/macros.rs b/engine/baml-lib/jsonish/src/tests/macros.rs index 3bb8965f3..9b890c007 100644 --- a/engine/baml-lib/jsonish/src/tests/macros.rs +++ b/engine/baml-lib/jsonish/src/tests/macros.rs @@ -7,7 +7,11 @@ macro_rules! test_failing_deserializer { let result = from_str(&target, &$target_type, $raw_string, false); - assert!(result.is_err(), "Failed to parse: {:?}", result); + assert!( + result.is_err(), + "Failed not to parse: {:?}", + result.unwrap() + ); } }; } diff --git a/engine/baml-lib/jsonish/src/tests/test_basics.rs b/engine/baml-lib/jsonish/src/tests/test_basics.rs index 4413da604..e97a8c9c0 100644 --- a/engine/baml-lib/jsonish/src/tests/test_basics.rs +++ b/engine/baml-lib/jsonish/src/tests/test_basics.rs @@ -1,6 +1,30 @@ use super::*; test_deserializer!(test_null, EMPTY_FILE, "null", FieldType::null(), null); +test_deserializer!( + test_null_1, + EMPTY_FILE, + "null", + FieldType::string().as_optional(), + null +); +test_deserializer!( + test_null_2, + EMPTY_FILE, + "Null", + FieldType::string().as_optional(), + // This is a string, not null + "Null" +); + +test_deserializer!( + test_null_3, + EMPTY_FILE, + "None", + FieldType::string().as_optional(), + // This is a string, not null + "None" +); test_deserializer!(test_number, EMPTY_FILE, "12111", FieldType::int(), 12111); @@ -13,6 +37,19 @@ test_deserializer!( ); test_deserializer!(test_bool, EMPTY_FILE, "true", FieldType::bool(), true); +test_deserializer!(test_bool_2, EMPTY_FILE, "True", FieldType::bool(), true); +test_deserializer!(test_bool_3, EMPTY_FILE, "false", FieldType::bool(), false); +test_deserializer!(test_bool_4, EMPTY_FILE, "False", FieldType::bool(), false); + +test_deserializer!( + test_float, + EMPTY_FILE, + "12111.123", + FieldType::float(), + 12111.123 +); + +test_deserializer!(test_float_1, EMPTY_FILE, "1/5", FieldType::float(), 0.2); test_deserializer!( test_array, diff --git a/engine/baml-lib/jsonish/src/tests/test_class.rs b/engine/baml-lib/jsonish/src/tests/test_class.rs index 4d20a93ae..a242f35d6 100644 --- a/engine/baml-lib/jsonish/src/tests/test_class.rs +++ b/engine/baml-lib/jsonish/src/tests/test_class.rs @@ -435,3 +435,174 @@ test_deserializer!( } ] ); + +const FUNCTION_FILE: &str = r#" +class Function { + selected (Function1 | Function2 | Function3) +} + +class Function1 { + function_name string + radius int +} + +class Function2 { + function_name string + diameter int +} + +class Function3 { + function_name string + length int + breadth int +} +"#; + +test_deserializer!( + test_obj_created_when_not_present, + FUNCTION_FILE, + r#"[ + { + // Calculate the area of a circle based on the radius. + function_name: 'circle.calculate_area', + // The radius of the circle. + radius: 5, + }, + { + // Calculate the circumference of a circle based on the diameter. + function_name: 'circle.calculate_circumference', + // The diameter of the circle. + diameter: 10, + } + ]"#, + FieldType::list(FieldType::class("Function")), + [ + {"selected": { + "function_name": "circle.calculate_area", + "radius": 5 + }, + }, + {"selected": + { + "function_name": "circle.calculate_circumference", + "diameter": 10 + } + } + ] +); + +test_deserializer!( + test_trailing_comma_with_space_last_field, + FUNCTION_FILE, + r#" + { + // Calculate the circumference of a circle based on the diameter. + function_name: 'circle.calculate_circumference', + // The diameter of the circle. (with a ", ") + diameter: 10, + } + "#, + FieldType::class("Function2"), + { + "function_name": "circle.calculate_circumference", + "diameter": 10 + } +); + +test_deserializer!( + test_trailing_comma_with_space_last_field_and_extra_text, + FUNCTION_FILE, + r#" + { + // Calculate the circumference of a circle based on the diameter. + function_name: 'circle.calculate_circumference', + // The diameter of the circle. (with a ", ") + diameter: 10, + Some key: "Some value" + } + and this + "#, + FieldType::class("Function2"), + { + "function_name": "circle.calculate_circumference", + "diameter": 10 + } +); + +test_failing_deserializer!( + test_nested_obj_from_string_fails_0, + r#" + class Foo { + foo Bar + } + + class Bar { + bar string + option int? + } + "#, + r#"My inner string"#, + FieldType::Class("Foo".to_string()) +); + +test_failing_deserializer!( + test_nested_obj_from_string_fails_1, + r#" + class Foo { + foo Bar + } + + class Bar { + bar string + } + "#, + r#"My inner string"#, + FieldType::Class("Foo".to_string()) +); + +test_failing_deserializer!( + test_nested_obj_from_string_fails_2, + r#" + class Foo { + foo string + } + "#, + r#"My inner string"#, + FieldType::Class("Foo".to_string()) +); + +test_deserializer!( + test_nested_obj_from_int, + r#" + class Foo { + foo int + } + "#, + r#"1214"#, + FieldType::Class("Foo".to_string()), + { "foo": 1214 } +); + +test_deserializer!( + test_nested_obj_from_float, + r#" + class Foo { + foo float + } + "#, + r#"1214.123"#, + FieldType::Class("Foo".to_string()), + { "foo": 1214.123 } +); + +test_deserializer!( + test_nested_obj_from_bool, + r#" + class Foo { + foo bool + } + "#, + r#" true "#, + FieldType::Class("Foo".to_string()), + { "foo": true } +);