diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/array_helper.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/array_helper.rs index 811f95bf6..67628e803 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/array_helper.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/array_helper.rs @@ -2,7 +2,7 @@ use std::any::Any; use crate::deserializer::{deserialize_flags::Flag, types::BamlValueWithFlags}; use anyhow::Result; -use internal_baml_core::ir::FieldType; +use internal_baml_core::{ast::Field, ir::FieldType}; use super::{ParsingContext, ParsingError}; @@ -59,25 +59,94 @@ pub(super) fn pick_best( // Sort by (false, score, index) all_valid_scores.sort_by( |&(a, a_score, a_default, a_val), &(b, b_score, b_default, b_val)| { - if a_val.r#type() == b_val.r#type() && matches!(a_val, BamlValueWithFlags::List(_, _)) { - let a_is_single = a_val - .conditions() - .flags - .iter() - .any(|f| matches!(f, Flag::SingleToArray)); - let b_is_single = b_val - .conditions() - .flags - .iter() - .any(|f| matches!(f, Flag::SingleToArray)); - - match (a_is_single, b_is_single) { - // Return B - (true, false) => return std::cmp::Ordering::Greater, - // Return A - (false, true) => return std::cmp::Ordering::Less, - _ => {} + if a_val.r#type() == b_val.r#type() { + if matches!(a_val, BamlValueWithFlags::List(_, _)) { + let a_is_single = a_val + .conditions() + .flags + .iter() + .any(|f| matches!(f, Flag::SingleToArray)); + let b_is_single = b_val + .conditions() + .flags + .iter() + .any(|f| matches!(f, Flag::SingleToArray)); + + match (a_is_single, b_is_single) { + // Return B + (true, false) => return std::cmp::Ordering::Greater, + // Return A + (false, true) => return std::cmp::Ordering::Less, + _ => {} + } + } + } + + // De-value default values when comparing + match (a_val, b_val) { + ( + BamlValueWithFlags::Class(_, a_conds, a_props), + BamlValueWithFlags::Class(_, b_conds, b_props), + ) => { + // If matching on a union, and one of the choices is picking an object that only + // had a single string coerced from JSON, prefer the other one + // (since string cost is low, its better to pick the other one if possible) + if matches!(target, FieldType::Union(_)) { + let a_is_coerced_string = a_props.len() == 1 + && a_props.iter().all(|(_, cond)| { + matches!(cond, BamlValueWithFlags::String(..)) + && cond + .conditions() + .flags + .iter() + .any(|f| matches!(f, Flag::ImpliedKey(..))) + }); + + let b_is_coerced_string = b_props.len() == 1 + && b_props.iter().all(|(_, cond)| { + matches!(cond, BamlValueWithFlags::String(..)) + && cond + .conditions() + .flags + .iter() + .any(|f| matches!(f, Flag::ImpliedKey(..))) + }); + + match (a_is_coerced_string, b_is_coerced_string) { + // Return B + (true, false) => return std::cmp::Ordering::Greater, + // Return A + (false, true) => return std::cmp::Ordering::Less, + _ => {} + } + } + + let a_is_default = a_props.iter().all(|(k, cond)| { + cond.conditions().flags.iter().any(|f| { + matches!( + f, + Flag::OptionalDefaultFromNoValue | Flag::DefaultFromNoValue + ) + }) + }); + let b_is_default = b_props.iter().all(|(k, cond)| { + cond.conditions().flags.iter().any(|f| { + matches!( + f, + Flag::OptionalDefaultFromNoValue | Flag::DefaultFromNoValue + ) + }) + }); + + match (a_is_default, b_is_default) { + // Return B + (true, false) => return std::cmp::Ordering::Greater, + // Return A + (false, true) => return std::cmp::Ordering::Less, + _ => {} + } } + _ => {} } match a_default.cmp(&b_default) { diff --git a/engine/baml-lib/jsonish/src/tests/macros.rs b/engine/baml-lib/jsonish/src/tests/macros.rs index 71ebceae8..2a8703437 100644 --- a/engine/baml-lib/jsonish/src/tests/macros.rs +++ b/engine/baml-lib/jsonish/src/tests/macros.rs @@ -62,8 +62,9 @@ macro_rules! test_partial_deserializer { assert!(result.is_ok(), "Failed to parse: {:?}", result); let value = result.unwrap(); + log::trace!("Score: {}", value.score()); let value: BamlValue = value.into(); - println!("{:#?}", value); + log::info!("{}", value); let json_value = json!(value); let expected = serde_json::json!($($json)+); diff --git a/engine/baml-lib/jsonish/src/tests/test_partials.rs b/engine/baml-lib/jsonish/src/tests/test_partials.rs index 63082e05f..897086f19 100644 --- a/engine/baml-lib/jsonish/src/tests/test_partials.rs +++ b/engine/baml-lib/jsonish/src/tests/test_partials.rs @@ -232,3 +232,128 @@ test_partial_deserializer!( "wordCounts": [] } ); + +const CHOPPY_BAML_FILE: &str = r##" +class Error { + code int + message string +} + +// Technically, everything can cast to this object. +class ErrorBasic { + message string +} + +class GraphJson { + vertices Vertex[] + edges Edge[] +} + +class Vertex { + id string @description(#" + A unique human-readable identifier for the vertex, like 'peter' + "#) + metadata map @description(#" + Arbitrary metadata for the vertex, like 'name' or 'age' + "#) +} + +class Edge { + source_id string + target_id string + // note, you could use an enum here if you know what rthe relationships are + relationship string @description(#" + A human-readable label for the edge, like 'knows' or "works_with", etc.. + "#) +} + "##; + +const TRIMMED_CHOPPY_RESULT: &str = r#" +```json +{ + "vertices": [ + { + "id": "stephanie_morales", + "metadata": { + "name": "Stephanie Morales", + "affiliation": "Made Space" + } + }, + { + "id": + "#; + +test_partial_deserializer!( + test_partial_choppy, + CHOPPY_BAML_FILE, + TRIMMED_CHOPPY_RESULT, + FieldType::Class("GraphJson".to_string()), + { + "vertices": [ + { + "id": "stephanie_morales", + "metadata": { + "name": "Stephanie Morales", + "affiliation": "Made Space" + } + }, + { + "id": null, + "metadata": { + } + } + ], + "edges": [ + ] + } +); + +test_partial_deserializer!( + test_partial_choppy_union, + CHOPPY_BAML_FILE, + TRIMMED_CHOPPY_RESULT, + FieldType::union(vec![FieldType::Class("GraphJson".to_string()), FieldType::Class("GraphJson".to_string()).as_list(), FieldType::Class("Error".to_string())]), + { + "vertices": [ + { + "id": "stephanie_morales", + "metadata": { + "name": "Stephanie Morales", + "affiliation": "Made Space" + } + }, + { + "id": null, + "metadata": { + } + } + ], + "edges": [ + ] + } +); + +test_partial_deserializer!( + test_partial_choppy_union_2, + CHOPPY_BAML_FILE, + TRIMMED_CHOPPY_RESULT, + FieldType::union(vec![FieldType::Class("GraphJson".to_string()), FieldType::Class("ErrorBasic".to_string())]), + { + "vertices": [ + { + "id": "stephanie_morales", + "metadata": { + "name": "Stephanie Morales", + "affiliation": "Made Space" + } + }, + { + "id": null, + "metadata": { + } + } + ], + "edges": [ + ] + } +);