Skip to content

Commit

Permalink
passes runtime tests
Browse files Browse the repository at this point in the history
  • Loading branch information
anish-palakurthi committed Aug 7, 2024
1 parent c436368 commit da9343a
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 54 deletions.
9 changes: 6 additions & 3 deletions engine/baml-lib/baml-core/src/ir/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,13 @@ fn type_with_arity(t: FieldType, arity: &FieldArity) -> FieldType {
impl WithRepr<FieldType> for ast::FieldType {
fn repr(&self, db: &ParserDatabase) -> Result<FieldType> {
Ok(match self {
ast::FieldType::Primitive(_, typeval, ..) => {
ast::FieldType::Primitive(arity, typeval, ..) => {
let repr = FieldType::Primitive(typeval.clone());

repr
if arity.is_optional() {
FieldType::Optional(Box::new(repr))
} else {
repr
}
}
ast::FieldType::Symbol(arity, idn, ..) => type_with_arity(
match db.find_type(idn) {
Expand Down
16 changes: 13 additions & 3 deletions engine/baml-lib/baml-types/src/field_type/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,19 @@ impl FieldType {
pub fn is_optional(&self) -> bool {
match self {
FieldType::Optional(_) => true,
FieldType::Primitive(TypeValue::Null) => true,
FieldType::Union(types) => types.iter().any(FieldType::is_optional),
_ => false,
FieldType::Primitive(TypeValue::Null) => {
println!("found a null in is_optional");
true
}

FieldType::Union(types) => {
println!("found a union in is_optional");
types.iter().any(FieldType::is_optional)
}
_ => {
// println!("non-optional in is_optional: {:#?}", self);
false
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ impl TypeCoercer for Class {
);
let (optional, required): (Vec<_>, Vec<_>) =
self.fields.iter().partition(|f| f.1.is_optional());

let mut optional_values = optional
.iter()
.map(|(f, ..)| (f.real_name().to_string(), None))
Expand Down
5 changes: 2 additions & 3 deletions engine/baml-lib/parser-database/src/walkers/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,13 @@ impl<'db> FieldWalker<'db> {

/// The field's default attributes.
pub fn get_default_attributes(&self) -> Option<&'db ToStringAttributes> {
println!("Field is triggered");

let result = self
.db
.types
.class_attributes
.get(&self.id.0)
.and_then(|f| f.field_serilizers.get(&self.id.1));
println!("Result: {:?}", result);

result
}
}
Expand Down Expand Up @@ -119,6 +117,7 @@ impl<'db> WithSerializeableContent for (&ParserDatabase, &FieldType) {
"int" => "int",
"float" => "float",
"bool" => "bool",
// "null" => "null",
_ => "unknown",
},
"optional": arity.is_optional(),
Expand Down
40 changes: 26 additions & 14 deletions engine/baml-lib/parser-database/src/walkers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ mod field;
mod function;
mod template_string;

use baml_types::TypeValue;
pub use client::*;
pub use configuration::*;
use either::Either;
Expand Down Expand Up @@ -245,12 +246,19 @@ impl<'db> crate::ParserDatabase {
/// Convert a field type to a `Type`.
pub fn to_jinja_type(&self, ft: &FieldType) -> internal_baml_jinja::Type {
use internal_baml_jinja::Type;
match ft {
FieldType::Symbol(arity, idn, ..) => match self.find_type(idn) {
None => Type::Undefined,
Some(Either::Left(_)) => Type::ClassRef(idn.to_string()),
Some(Either::Right(_)) => Type::String,
},

let r = match ft {
FieldType::Symbol(arity, idn, ..) => {
let mut t = match self.find_type(idn) {
None => Type::Undefined,
Some(Either::Left(_)) => Type::ClassRef(idn.to_string()),
Some(Either::Right(_)) => Type::String,
};
if arity.is_optional() {
t = Type::None | t;
}
t
}
FieldType::List(inner, dims, ..) => {
let mut t = self.to_jinja_type(inner);
for _ in 0..*dims {
Expand All @@ -277,18 +285,22 @@ impl<'db> crate::ParserDatabase {
Box::new(self.to_jinja_type(&kv.1)),
),
FieldType::Primitive(arity, t, ..) => {
let mut t = match t.to_string().as_str() {
"string" => Type::String,
"int" => Type::Int,
"float" => Type::Float,
"bool" => Type::Bool,
_ => Type::Unknown,
let mut t = match &t {
TypeValue::String => Type::String,
TypeValue::Int => Type::Int,
TypeValue::Float => Type::Float,
TypeValue::Bool => Type::Bool,
TypeValue::Null => Type::None,
TypeValue::Image => Type::Unknown,
TypeValue::Audio => Type::Unknown,
};
if arity.is_optional() {
if arity.is_optional() || matches!(t, Type::None) {
t = Type::None | t;
}
t
}
}
};

r
}
}
7 changes: 0 additions & 7 deletions engine/baml-lib/schema-ast/src/ast/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,6 @@ pub enum FieldType {
Map(Box<(FieldType, FieldType)>, Span, Option<Vec<Attribute>>),
}

fn arity_suffix(arity: &FieldArity) -> &'static str {
match arity {
FieldArity::Required => "",
FieldArity::Optional => "?",
}
}

impl FieldType {
pub fn name(&self) -> String {
match self {
Expand Down
5 changes: 1 addition & 4 deletions engine/baml-lib/schema-ast/src/parser/parse_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,7 @@ fn parse_field_type_with_attr(pair: Pair<'_>, diagnostics: &mut Diagnostics) ->

Some(ft) // Return the field type with attributes
}
None => {
log::info!("field_type is None");
None
}
None => None,
}
}
fn combine_field_types(types: Vec<FieldType>) -> Option<FieldType> {
Expand Down
16 changes: 10 additions & 6 deletions engine/baml-lib/schema-ast/src/parser/parse_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,12 @@ pub fn parse_field_type(pair: Pair<'_>, diagnostics: &mut Diagnostics) -> Option
Some(ftype) => {
if arity.is_optional() {
match ftype.to_nullable() {
Ok(ftype) => return Some(ftype),
Err(e) => {
diagnostics.push_error(e);
return None;
}
Ok(ftype) => Some(ftype),
Err(_) => None,
}
} else {
Some(ftype)
}
Some(ftype)
}
None => {
unreachable!("Ftype should always be defined")
Expand Down Expand Up @@ -153,6 +151,12 @@ fn parse_base_type(pair: Pair<'_>, diagnostics: &mut Diagnostics) -> Option<Fiel
None,
)
}
"null" => FieldType::Primitive(
FieldArity::Optional,
TypeValue::Null,
diagnostics.span(current.as_span()),
None,
),
_ => FieldType::Symbol(
FieldArity::Required,
Identifier::Local(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,54 @@
function AudioInput(aud: audio) -> string{
client Gemini
prompt #"
{{ _.role("user") }}
// function AudioInput(aud: audio) -> string{
// client Gemini
// prompt #"
// {{ _.role("user") }}

Does this sound like a roar? Yes or no? One word no other characters.
// Does this sound like a roar? Yes or no? One word no other characters.

{{ aud }}
"#
// {{ aud }}
// "#
// }



// test TestURLAudioInput{
// functions [AudioInput]
// args {
// aud{
// url https://actions.google.com/sounds/v1/emergency/beeper_emergency_call.ogg
// }
// }
// }


enum CatA {
A
}

enum CatB {
C
D
}

class CatAPicker {
cat CatA
}

test TestURLAudioInput{
functions [AudioInput]
args {
aud{
url https://actions.google.com/sounds/v1/emergency/beeper_emergency_call.ogg
}
}
class CatBPicker {
cat CatB
item int
}

enum CatC {
E
F
G
H
I
}

class CatCPicker {
cat CatC
item int | string | null
data int?
}

0 comments on commit da9343a

Please sign in to comment.