Skip to content

Commit

Permalink
Fix parsing for streaming of objects more stable (#1031)
Browse files Browse the repository at this point in the history
We penalize objects that are throwing away data by picking default
values.

For example, during parsing of complex objects, we sometimes prefer an
empty object of default values, but instead we can prefer something
else. This also works for handling bad union disambiguation.

<!-- ELLIPSIS_HIDDEN -->



> [!IMPORTANT]
> Improve object streaming stability by refining default value handling
in deserialization and enhancing logging and tests.
> 
>   - **Behavior**:
> - Modify `pick_best()` in `array_helper.rs` to penalize default values
like null or empty lists when selecting best values.
> - Add handling for `BamlValueWithFlags::Class` to check for default
flags in properties.
>   - **Logging**:
> - Replace `println!` with `log::info!` in `macros.rs` for better
logging control.
>     - Add `log::trace!` for score logging in `macros.rs`.
>   - **Tests**:
> - Add `test_partial_choppy` in `test_partials.rs` to test partial
deserialization with incomplete data.
> - Update `test_partial_deserializer!` macro in `macros.rs` to include
score logging.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for 916851c. It will automatically
update as commits are pushed.</sup>

<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
hellovai authored Oct 12, 2024
1 parent 0c73cab commit 8aa9c00
Show file tree
Hide file tree
Showing 3 changed files with 215 additions and 20 deletions.
107 changes: 88 additions & 19 deletions engine/baml-lib/jsonish/src/deserializer/coercer/array_helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::any::Any;

use crate::deserializer::{deserialize_flags::Flag, types::BamlValueWithFlags};
use anyhow::Result;
use internal_baml_core::ir::FieldType;
use internal_baml_core::{ast::Field, ir::FieldType};

use super::{ParsingContext, ParsingError};

Expand Down Expand Up @@ -59,25 +59,94 @@ pub(super) fn pick_best(
// Sort by (false, score, index)
all_valid_scores.sort_by(
|&(a, a_score, a_default, a_val), &(b, b_score, b_default, b_val)| {
if a_val.r#type() == b_val.r#type() && matches!(a_val, BamlValueWithFlags::List(_, _)) {
let a_is_single = a_val
.conditions()
.flags
.iter()
.any(|f| matches!(f, Flag::SingleToArray));
let b_is_single = b_val
.conditions()
.flags
.iter()
.any(|f| matches!(f, Flag::SingleToArray));

match (a_is_single, b_is_single) {
// Return B
(true, false) => return std::cmp::Ordering::Greater,
// Return A
(false, true) => return std::cmp::Ordering::Less,
_ => {}
if a_val.r#type() == b_val.r#type() {
if matches!(a_val, BamlValueWithFlags::List(_, _)) {
let a_is_single = a_val
.conditions()
.flags
.iter()
.any(|f| matches!(f, Flag::SingleToArray));
let b_is_single = b_val
.conditions()
.flags
.iter()
.any(|f| matches!(f, Flag::SingleToArray));

match (a_is_single, b_is_single) {
// Return B
(true, false) => return std::cmp::Ordering::Greater,
// Return A
(false, true) => return std::cmp::Ordering::Less,
_ => {}
}
}
}

// De-value default values when comparing
match (a_val, b_val) {
(
BamlValueWithFlags::Class(_, a_conds, a_props),
BamlValueWithFlags::Class(_, b_conds, b_props),
) => {
// If matching on a union, and one of the choices is picking an object that only
// had a single string coerced from JSON, prefer the other one
// (since string cost is low, its better to pick the other one if possible)
if matches!(target, FieldType::Union(_)) {
let a_is_coerced_string = a_props.len() == 1
&& a_props.iter().all(|(_, cond)| {
matches!(cond, BamlValueWithFlags::String(..))
&& cond
.conditions()
.flags
.iter()
.any(|f| matches!(f, Flag::ImpliedKey(..)))
});

let b_is_coerced_string = b_props.len() == 1
&& b_props.iter().all(|(_, cond)| {
matches!(cond, BamlValueWithFlags::String(..))
&& cond
.conditions()
.flags
.iter()
.any(|f| matches!(f, Flag::ImpliedKey(..)))
});

match (a_is_coerced_string, b_is_coerced_string) {
// Return B
(true, false) => return std::cmp::Ordering::Greater,
// Return A
(false, true) => return std::cmp::Ordering::Less,
_ => {}
}
}

let a_is_default = a_props.iter().all(|(k, cond)| {
cond.conditions().flags.iter().any(|f| {
matches!(
f,
Flag::OptionalDefaultFromNoValue | Flag::DefaultFromNoValue
)
})
});
let b_is_default = b_props.iter().all(|(k, cond)| {
cond.conditions().flags.iter().any(|f| {
matches!(
f,
Flag::OptionalDefaultFromNoValue | Flag::DefaultFromNoValue
)
})
});

match (a_is_default, b_is_default) {
// Return B
(true, false) => return std::cmp::Ordering::Greater,
// Return A
(false, true) => return std::cmp::Ordering::Less,
_ => {}
}
}
_ => {}
}

match a_default.cmp(&b_default) {
Expand Down
3 changes: 2 additions & 1 deletion engine/baml-lib/jsonish/src/tests/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ macro_rules! test_partial_deserializer {
assert!(result.is_ok(), "Failed to parse: {:?}", result);

let value = result.unwrap();
log::trace!("Score: {}", value.score());
let value: BamlValue = value.into();
println!("{:#?}", value);
log::info!("{}", value);
let json_value = json!(value);

let expected = serde_json::json!($($json)+);
Expand Down
125 changes: 125 additions & 0 deletions engine/baml-lib/jsonish/src/tests/test_partials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,128 @@ test_partial_deserializer!(
"wordCounts": []
}
);

const CHOPPY_BAML_FILE: &str = r##"
class Error {
code int
message string
}
// Technically, everything can cast to this object.
class ErrorBasic {
message string
}
class GraphJson {
vertices Vertex[]
edges Edge[]
}
class Vertex {
id string @description(#"
A unique human-readable identifier for the vertex, like 'peter'
"#)
metadata map<string, string> @description(#"
Arbitrary metadata for the vertex, like 'name' or 'age'
"#)
}
class Edge {
source_id string
target_id string
// note, you could use an enum here if you know what rthe relationships are
relationship string @description(#"
A human-readable label for the edge, like 'knows' or "works_with", etc..
"#)
}
"##;

const TRIMMED_CHOPPY_RESULT: &str = r#"
```json
{
"vertices": [
{
"id": "stephanie_morales",
"metadata": {
"name": "Stephanie Morales",
"affiliation": "Made Space"
}
},
{
"id":
"#;

test_partial_deserializer!(
test_partial_choppy,
CHOPPY_BAML_FILE,
TRIMMED_CHOPPY_RESULT,
FieldType::Class("GraphJson".to_string()),
{
"vertices": [
{
"id": "stephanie_morales",
"metadata": {
"name": "Stephanie Morales",
"affiliation": "Made Space"
}
},
{
"id": null,
"metadata": {
}
}
],
"edges": [
]
}
);

test_partial_deserializer!(
test_partial_choppy_union,
CHOPPY_BAML_FILE,
TRIMMED_CHOPPY_RESULT,
FieldType::union(vec![FieldType::Class("GraphJson".to_string()), FieldType::Class("GraphJson".to_string()).as_list(), FieldType::Class("Error".to_string())]),
{
"vertices": [
{
"id": "stephanie_morales",
"metadata": {
"name": "Stephanie Morales",
"affiliation": "Made Space"
}
},
{
"id": null,
"metadata": {
}
}
],
"edges": [
]
}
);

test_partial_deserializer!(
test_partial_choppy_union_2,
CHOPPY_BAML_FILE,
TRIMMED_CHOPPY_RESULT,
FieldType::union(vec![FieldType::Class("GraphJson".to_string()), FieldType::Class("ErrorBasic".to_string())]),
{
"vertices": [
{
"id": "stephanie_morales",
"metadata": {
"name": "Stephanie Morales",
"affiliation": "Made Space"
}
},
{
"id": null,
"metadata": {
}
}
],
"edges": [
]
}
);

0 comments on commit 8aa9c00

Please sign in to comment.