Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve BAML Parser #785

Merged
merged 2 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ fn coerce_int(
}
}
crate::jsonish::Value::String(s) => {
let s = s.trim();
// Trim trailing commas
let s = s.trim_end_matches(',');
if let Ok(n) = s.parse::<i64>() {
Ok(BamlValueWithFlags::Int(n.into()))
} else if let Ok(n) = s.parse::<u64>() {
Expand All @@ -106,6 +109,10 @@ fn coerce_int(
Ok(BamlValueWithFlags::Int(
((n.round() as i64), Flag::FloatToInt(n)).into(),
))
} else if let Some(frac) = float_from_maybe_fraction(s) {
Ok(BamlValueWithFlags::Int(
((frac.round() as i64), Flag::FloatToInt(frac)).into(),
))
} else {
Err(ctx.error_unexpected_type(target, value))
}
Expand All @@ -122,6 +129,20 @@ fn coerce_int(
}
}

fn float_from_maybe_fraction(value: &str) -> Option<f64> {
if let Some((numerator, denominator)) = value.split_once('/') {
match (
numerator.trim().parse::<f64>(),
denominator.trim().parse::<f64>(),
) {
(Ok(num), Ok(denom)) if denom != 0.0 => Some(num / denom),
_ => None,
}
} else {
None
}
}

fn coerce_float(
ctx: &ParsingContext,
target: &FieldType,
Expand All @@ -141,12 +162,17 @@ fn coerce_float(
}
}
crate::jsonish::Value::String(s) => {
let s = s.trim();
// Trim trailing commas
let s = s.trim_end_matches(',');
if let Ok(n) = s.parse::<f64>() {
Ok(BamlValueWithFlags::Float(n.into()))
} else if let Ok(n) = s.parse::<i64>() {
Ok(BamlValueWithFlags::Float((n as f64).into()))
} else if let Ok(n) = s.parse::<u64>() {
Ok(BamlValueWithFlags::Float((n as f64).into()))
} else if let Some(frac) = float_from_maybe_fraction(s) {
Ok(BamlValueWithFlags::Float(frac.into()))
} else {
Err(ctx.error_unexpected_type(target, value))
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@

use anyhow::Result;
use baml_types::BamlMap;
use internal_baml_core::{
ir::{FieldType},
};
use internal_baml_core::ir::FieldType;
use internal_baml_jinja::types::{Class, Name};

use crate::deserializer::{
Expand Down Expand Up @@ -51,6 +48,8 @@ impl TypeCoercer for Class {
}
Some(crate::jsonish::Value::Object(obj)) => {
// match keys, if that fails, then do something fancy later.
let mut extra_keys = vec![];
let mut found_keys = false;
obj.iter().for_each(|(key, v)| {
if let Some(field) = self
.fields
Expand All @@ -60,10 +59,45 @@ impl TypeCoercer for Class {
let scope = ctx.enter_scope(field.0.real_name());
let parsed = field.1.coerce(&scope, &field.1, Some(v));
update_map(&mut required_values, &mut optional_values, field, parsed);
found_keys = true;
} else {
flags.add_flag(Flag::ExtraKey(key.clone(), v.clone()));
extra_keys.push((key, v));
}
});

if !found_keys && !extra_keys.is_empty() && self.fields.len() == 1 {
// Try to coerce the object into the single field
let field = &self.fields[0];
let scope = ctx.enter_scope(&format!("<implied:{}>", field.0.real_name()));
let parsed = field
.1
.coerce(
&scope,
&field.1,
Some(&crate::jsonish::Value::Object(obj.clone())),
)
.map(|mut v| {
v.add_flag(Flag::ImpliedKey(field.0.real_name().into()));
v
});

if let Ok(parsed_value) = parsed {
update_map(
&mut required_values,
&mut optional_values,
field,
Ok(parsed_value),
);
} else {
extra_keys.into_iter().for_each(|(key, v)| {
flags.add_flag(Flag::ExtraKey(key.to_string(), v.clone()));
});
}
} else {
extra_keys.into_iter().for_each(|(key, v)| {
flags.add_flag(Flag::ExtraKey(key.to_string(), v.clone()));
});
}
}
Some(crate::jsonish::Value::Array(items)) => {
if self.fields.len() == 1 {
Expand Down Expand Up @@ -97,6 +131,7 @@ impl TypeCoercer for Class {
let parsed = match field.1.coerce(&scope, &field.1, Some(x)) {
Ok(mut v) => {
v.add_flag(Flag::ImpliedKey(field.0.real_name().into()));
flags.add_flag(Flag::InferedObject(x.clone()));
Ok(v)
}
Err(e) => Err(e),
Expand Down
8 changes: 8 additions & 0 deletions engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub enum Flag {

JsonToString(crate::jsonish::Value),
ImpliedKey(String),
InferedObject(crate::jsonish::Value),

// Values here are all the possible matches.
FirstMatch(usize, Vec<Result<BamlValueWithFlags, ParsingError>>),
Expand Down Expand Up @@ -68,6 +69,9 @@ impl std::fmt::Display for DeserializerConditions {
impl std::fmt::Display for Flag {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Flag::InferedObject(value) => {
write!(f, "Infered object from: {}", value.r#type())?;
}
Flag::OptionalDefaultFromNoValue => {
write!(f, "Optional Default value")?;
}
Expand Down Expand Up @@ -175,6 +179,10 @@ impl DeserializerConditions {
pub fn new() -> Self {
Self { flags: Vec::new() }
}

pub fn flags(&self) -> &Vec<Flag> {
&self.flags
}
}

impl Default for DeserializerConditions {
Expand Down
2 changes: 1 addition & 1 deletion engine/baml-lib/jsonish/src/deserializer/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pub mod coercer;
mod deserialize_flags;
pub mod deserialize_flags;
// pub mod schema;
mod score;
pub mod types;
1 change: 1 addition & 0 deletions engine/baml-lib/jsonish/src/deserializer/score.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ impl WithScore for BamlValueWithFlags {
impl WithScore for Flag {
fn score(&self) -> i32 {
match self {
Flag::InferedObject(_) => 0, // Dont penalize for this but instead handle it at the top level
Flag::OptionalDefaultFromNoValue => 1,
Flag::DefaultFromNoValue => 100,
Flag::DefaultButHadValue(_) => 110,
Expand Down
15 changes: 15 additions & 0 deletions engine/baml-lib/jsonish/src/deserializer/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@ impl BamlValueWithFlags {
BamlValueWithFlags::Image(f) => f.score(),
}
}

pub fn conditions(&self) -> &DeserializerConditions {
match self {
BamlValueWithFlags::String(v) => &v.flags,
BamlValueWithFlags::Int(v) => &v.flags,
BamlValueWithFlags::Float(v) => &v.flags,
BamlValueWithFlags::Bool(v) => &v.flags,
BamlValueWithFlags::List(v, _) => &v,
BamlValueWithFlags::Map(v, _) => &v,
BamlValueWithFlags::Enum(_, v) => &v.flags,
BamlValueWithFlags::Class(_, v, _) => &v,
BamlValueWithFlags::Null(v) => &v,
BamlValueWithFlags::Image(v) => &v.flags,
}
}
}

#[derive(Debug, Clone)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,22 @@ impl JsonParseState {
log::debug!("Testing for comment after space + comma");
// If after the space we have "//" or "/*" or the beginning of a key, we'll close the string
let mut buffer = ",".to_string();
let mut anything_but_whitespace = false;
while let Some((_, next_next_c)) = next.next() {
anything_but_whitespace = anything_but_whitespace
|| !next_next_c.is_whitespace();
buffer.push(next_next_c);
match next_next_c {
' ' | '\n' => {}
' ' => {}
'\n' => {
if anything_but_whitespace {
} else {
// Likely end of the key as the LLM generated a (', ' token by mistake)
// so drop the comma
log::debug!("Closing due to: newline after comma + space");
return Some(idx);
}
}
'/' => match next.peek() {
Some((_, '/')) => {
// This is likely a comment
Expand Down
14 changes: 13 additions & 1 deletion engine/baml-lib/jsonish/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ pub use deserializer::types::BamlValueWithFlags;
use internal_baml_core::ir::TypeValue;
use internal_baml_jinja::types::OutputFormatContent;

use deserializer::deserialize_flags::Flag;

pub fn from_str(
of: &OutputFormatContent,
target: &FieldType,
Expand Down Expand Up @@ -42,7 +44,17 @@ pub fn from_str(

// Lets try to now coerce the value into the expected schema.
match target.coerce(&ctx, target, Some(&value)) {
Ok(v) => Ok(v),
Ok(v) => {
if v.conditions()
.flags()
.iter()
.any(|f| matches!(f, Flag::InferedObject(jsonish::Value::String(_))))
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spelling: Typo in Flag::InferedObject. Should be Flag::InferredObject.

Suggested change
.any(|f| matches!(f, Flag::InferedObject(jsonish::Value::String(_))))
.any(|f| matches!(f, Flag::InferredObject(jsonish::Value::String(_))))

{
anyhow::bail!("Failed to coerce value: {:?}", v.conditions().flags());
}

Ok(v)
}
Err(e) => anyhow::bail!("Failed to coerce value: {}", e),
}
}
6 changes: 5 additions & 1 deletion engine/baml-lib/jsonish/src/tests/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ macro_rules! test_failing_deserializer {

let result = from_str(&target, &$target_type, $raw_string, false);

assert!(result.is_err(), "Failed to parse: {:?}", result);
assert!(
result.is_err(),
"Failed not to parse: {:?}",
result.unwrap()
);
}
};
}
Expand Down
37 changes: 37 additions & 0 deletions engine/baml-lib/jsonish/src/tests/test_basics.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,30 @@
use super::*;

test_deserializer!(test_null, EMPTY_FILE, "null", FieldType::null(), null);
test_deserializer!(
test_null_1,
EMPTY_FILE,
"null",
FieldType::string().as_optional(),
null
);
test_deserializer!(
test_null_2,
EMPTY_FILE,
"Null",
FieldType::string().as_optional(),
// This is a string, not null
"Null"
);

test_deserializer!(
test_null_3,
EMPTY_FILE,
"None",
FieldType::string().as_optional(),
// This is a string, not null
"None"
);

test_deserializer!(test_number, EMPTY_FILE, "12111", FieldType::int(), 12111);

Expand All @@ -13,6 +37,19 @@ test_deserializer!(
);

test_deserializer!(test_bool, EMPTY_FILE, "true", FieldType::bool(), true);
test_deserializer!(test_bool_2, EMPTY_FILE, "True", FieldType::bool(), true);
test_deserializer!(test_bool_3, EMPTY_FILE, "false", FieldType::bool(), false);
test_deserializer!(test_bool_4, EMPTY_FILE, "False", FieldType::bool(), false);

test_deserializer!(
test_float,
EMPTY_FILE,
"12111.123",
FieldType::float(),
12111.123
);

test_deserializer!(test_float_1, EMPTY_FILE, "1/5", FieldType::float(), 0.2);

test_deserializer!(
test_array,
Expand Down
Loading
Loading