Skip to content

Commit

Permalink
blanketed anyhow
Browse files Browse the repository at this point in the history
  • Loading branch information
anish-palakurthi committed Aug 7, 2024
1 parent 421efc0 commit c436368
Show file tree
Hide file tree
Showing 17 changed files with 169 additions and 108 deletions.
7 changes: 1 addition & 6 deletions engine/baml-lib/baml-core/src/ir/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,12 +535,7 @@ impl WithRepr<Field> for FieldWalker<'_> {
Ok(Field {
name: self.name().to_string(),
r#type: Node {
elem: self
.ast_field()
.expr
.clone()
.ok_or(anyhow!("Field type is None"))?
.repr(db)?,
elem: self.ast_field().expr.clone().ok_or(anyhow!(""))?.repr(db)?,
attributes: self.attributes(db),
},
})
Expand Down
74 changes: 74 additions & 0 deletions engine/baml-lib/parser-database/src/walkers/arg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
use super::{ClassWalker, ClientWalker, ConfigurationWalker, EnumWalker, Walker};
use crate::{
ast::{self, WithName},
printer::{serialize_with_printer, WithSerializeableContent},
types::FunctionType,
ParserDatabase, WithSerialize,
};
use either::Either;
use internal_baml_schema_ast::ast::{ArgumentId, Identifier, ValExpId, WithIdentifier, WithSpan};

pub type ArgWalker<'db> = super::Walker<'db, (ValExpId, bool, ArgumentId)>;

impl<'db> ArgWalker<'db> {
/// The ID of the function in the db
pub fn function_id(self) -> ast::ValExpId {
self.id.0
}

/// The AST node.
pub fn ast_function(self) -> &'db ast::ValueExprBlock {
&self.db.ast[self.id.0]
}

/// The AST node.
pub fn ast_arg(self) -> (Option<&'db Identifier>, &'db ast::BlockArg) {
match self.id.1 {
true => {
let args = self.ast_function().input();
let res = &args.expect("Expected input args")[self.id.2];
(Some(&res.0), &res.1)
}

false => {
let output = self.ast_function().output();
let res = output.expect("Error: Output is undefined for function ID");
(None, res)
}
}
}

/// The name of the type.
pub fn field_type(self) -> &'db ast::FieldType {
&self.ast_arg().1.field_type
}

/// The name of the function.
pub fn is_optional(self) -> bool {
self.field_type().is_nullable()
}

/// The name of the function.
pub fn required_enums(self) -> impl Iterator<Item = EnumWalker<'db>> {
let (input, output) = &self.db.types.function[&self.function_id()].dependencies;
if self.id.1 { input } else { output }
.iter()
.filter_map(|f| match self.db.find_type_by_str(f) {
Some(Either::Left(_cls)) => None,
Some(Either::Right(walker)) => Some(walker),
None => None,
})
}

/// The name of the function.
pub fn required_classes(self) -> impl Iterator<Item = ClassWalker<'db>> {
let (input, output) = &self.db.types.function[&self.function_id()].dependencies;
if self.id.1 { input } else { output }
.iter()
.filter_map(|f| match self.db.find_type_by_str(f) {
Some(Either::Left(walker)) => Some(walker),
Some(Either::Right(_enm)) => None,
None => None,
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ function AudioInput(aud: audio) -> string{
}



test TestURLAudioInput{
functions [AudioInput]
args {
Expand Down

This file was deleted.

12 changes: 6 additions & 6 deletions integ-tests/python/baml_client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ async def ExtractReceiptInfo(

async def ExtractResume(
self,
resume: str,img: Optional[baml_py.Image],
resume: str,img: baml_py.Image,
baml_options: BamlCallOptions = {},
) -> types.Resume:
__tb__ = baml_options.get("tb", None)
Expand Down Expand Up @@ -613,7 +613,7 @@ async def FnEnumOutput(

async def FnNamedArgsSingleStringOptional(
self,
myString: Optional[str],
myString: str,
baml_options: BamlCallOptions = {},
) -> str:
__tb__ = baml_options.get("tb", None)
Expand Down Expand Up @@ -1717,7 +1717,7 @@ async def TestVertex(

async def UnionTest_Function(
self,
input: Union[str, bool],
input: Union[Union[str], Union[bool]],
baml_options: BamlCallOptions = {},
) -> types.UnionTest_ReturnType:
__tb__ = baml_options.get("tb", None)
Expand Down Expand Up @@ -2315,7 +2315,7 @@ def ExtractReceiptInfo(

def ExtractResume(
self,
resume: str,img: Optional[baml_py.Image],
resume: str,img: baml_py.Image,
baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[partial_types.Resume, types.Resume]:
__tb__ = baml_options.get("tb", None)
Expand Down Expand Up @@ -2514,7 +2514,7 @@ def FnEnumOutput(

def FnNamedArgsSingleStringOptional(
self,
myString: Optional[str],
myString: str,
baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[Optional[str], str]:
__tb__ = baml_options.get("tb", None)
Expand Down Expand Up @@ -4030,7 +4030,7 @@ def TestVertex(

def UnionTest_Function(
self,
input: Union[str, bool],
input: Union[Union[str], Union[bool]],
baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[partial_types.UnionTest_ReturnType, types.UnionTest_ReturnType]:
__tb__ = baml_options.get("tb", None)
Expand Down
3 changes: 1 addition & 2 deletions integ-tests/python/baml_client/inlinedbaml.py

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions integ-tests/python/baml_client/partial_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,14 @@ class Person(BaseModel):
class Quantity(BaseModel):


amount: Optional[Union[Optional[int], Optional[float]]] = None
amount: Optional[Union[Optional[Union[Optional[int]]], Optional[Union[Optional[float]]]]] = None
unit: Optional[str] = None

class RaysData(BaseModel):


dataType: Optional[types.DataType] = None
value: Optional[Union["Resume", "Event"]] = None
value: Optional[Union[Optional[Union["Resume"]], Optional[Union["Event"]]]] = None

class ReceiptInfo(BaseModel):

Expand Down Expand Up @@ -209,7 +209,7 @@ class SearchParams(BaseModel):
jobTitle: Optional["WithReasoning"] = None
company: Optional["WithReasoning"] = None
description: List["WithReasoning"]
tags: List[Optional[Union[Optional[types.Tag], Optional[str]]]]
tags: List[Optional[Union[Optional[Union[Optional[types.Tag]]], Optional[Union[Optional[str]]]]]]

class SomeClassNestedDynamic(BaseModel):

Expand Down Expand Up @@ -252,9 +252,9 @@ class TestOutputClass(BaseModel):
class UnionTest_ReturnType(BaseModel):


prop1: Optional[Union[Optional[str], Optional[bool]]] = None
prop2: List[Optional[Union[Optional[float], Optional[bool]]]]
prop3: Optional[Union[List[Optional[bool]], List[Optional[int]]]] = None
prop1: Optional[Union[Optional[Union[Optional[str]]], Optional[Union[Optional[bool]]]]] = None
prop2: List[Optional[Union[Optional[Union[Optional[float]]], Optional[Union[Optional[bool]]]]]]
prop3: Optional[Union[Optional[Union[List[Optional[bool]]]], Optional[Union[List[Optional[int]]]]]] = None

class WithReasoning(BaseModel):

Expand Down
12 changes: 6 additions & 6 deletions integ-tests/python/baml_client/sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def ExtractReceiptInfo(

def ExtractResume(
self,
resume: str,img: Optional[baml_py.Image],
resume: str,img: baml_py.Image,
baml_options: BamlCallOptions = {},
) -> types.Resume:
__tb__ = baml_options.get("tb", None)
Expand Down Expand Up @@ -611,7 +611,7 @@ def FnEnumOutput(

def FnNamedArgsSingleStringOptional(
self,
myString: Optional[str],
myString: str,
baml_options: BamlCallOptions = {},
) -> str:
__tb__ = baml_options.get("tb", None)
Expand Down Expand Up @@ -1715,7 +1715,7 @@ def TestVertex(

def UnionTest_Function(
self,
input: Union[str, bool],
input: Union[Union[str], Union[bool]],
baml_options: BamlCallOptions = {},
) -> types.UnionTest_ReturnType:
__tb__ = baml_options.get("tb", None)
Expand Down Expand Up @@ -2314,7 +2314,7 @@ def ExtractReceiptInfo(

def ExtractResume(
self,
resume: str,img: Optional[baml_py.Image],
resume: str,img: baml_py.Image,
baml_options: BamlCallOptions = {},
) -> baml_py.BamlSyncStream[partial_types.Resume, types.Resume]:
__tb__ = baml_options.get("tb", None)
Expand Down Expand Up @@ -2513,7 +2513,7 @@ def FnEnumOutput(

def FnNamedArgsSingleStringOptional(
self,
myString: Optional[str],
myString: str,
baml_options: BamlCallOptions = {},
) -> baml_py.BamlSyncStream[Optional[str], str]:
__tb__ = baml_options.get("tb", None)
Expand Down Expand Up @@ -4029,7 +4029,7 @@ def TestVertex(

def UnionTest_Function(
self,
input: Union[str, bool],
input: Union[Union[str], Union[bool]],
baml_options: BamlCallOptions = {},
) -> baml_py.BamlSyncStream[partial_types.UnionTest_ReturnType, types.UnionTest_ReturnType]:
__tb__ = baml_options.get("tb", None)
Expand Down
36 changes: 18 additions & 18 deletions integ-tests/python/baml_client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class TestEnum(str, Enum):
class Blah(BaseModel):


prop4: Optional[str] = None
prop4: str

class ClassOptionalOutput(BaseModel):

Expand All @@ -133,8 +133,8 @@ class ClassOptionalOutput(BaseModel):
class ClassOptionalOutput2(BaseModel):


prop1: Optional[str] = None
prop2: Optional[str] = None
prop1: str
prop2: str
prop3: Optional["Blah"] = None

class ClassWithImage(BaseModel):
Expand Down Expand Up @@ -182,7 +182,7 @@ class Education(BaseModel):
location: str
degree: str
major: List[str]
graduation_date: Optional[str] = None
graduation_date: str

class Email(BaseModel):

Expand Down Expand Up @@ -234,46 +234,46 @@ class OptionalTest_ReturnType(BaseModel):


omega_1: Optional["OptionalTest_Prop1"] = None
omega_2: Optional[str] = None
omega_2: str
omega_3: List[Optional["OptionalTest_CategoryType"]]

class OrderInfo(BaseModel):


order_status: "OrderStatus"
tracking_number: Optional[str] = None
estimated_arrival_date: Optional[str] = None
tracking_number: str
estimated_arrival_date: str

class Person(BaseModel):

model_config = ConfigDict(extra='allow')

name: Optional[str] = None
name: str
hair_color: Optional[Union["Color", str]] = None

class Quantity(BaseModel):


amount: Union[int, float]
unit: Optional[str] = None
amount: Union[Union[int], Union[float]]
unit: str

class RaysData(BaseModel):


dataType: "DataType"
value: Union["Resume", "Event"]
value: Union[Union["Resume"], Union["Event"]]

class ReceiptInfo(BaseModel):


items: List["ReceiptItem"]
total_cost: Optional[float] = None
total_cost: float

class ReceiptItem(BaseModel):


name: str
description: Optional[str] = None
description: str
quantity: int
price: float

Expand All @@ -295,12 +295,12 @@ class Resume(BaseModel):
class SearchParams(BaseModel):


dateRange: Optional[int] = None
dateRange: int
location: List[str]
jobTitle: Optional["WithReasoning"] = None
company: Optional["WithReasoning"] = None
description: List["WithReasoning"]
tags: List[Union["Tag", str]]
tags: List[Union[Union["Tag"], Union[str]]]

class SomeClassNestedDynamic(BaseModel):

Expand Down Expand Up @@ -343,9 +343,9 @@ class TestOutputClass(BaseModel):
class UnionTest_ReturnType(BaseModel):


prop1: Union[str, bool]
prop2: List[Union[float, bool]]
prop3: Union[List[bool], List[int]]
prop1: Union[Union[str], Union[bool]]
prop2: List[Union[Union[float], Union[bool]]]
prop3: Union[Union[List[bool]], Union[List[int]]]

class WithReasoning(BaseModel):

Expand Down
Loading

0 comments on commit c436368

Please sign in to comment.