Skip to content

Commit

Permalink
Fix more typecheck issues
Browse files Browse the repository at this point in the history
  • Loading branch information
imalsogreg committed Jan 6, 2025
1 parent 6c02bed commit 2466014
Show file tree
Hide file tree
Showing 19 changed files with 126 additions and 72 deletions.
10 changes: 9 additions & 1 deletion engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,15 @@ where
FieldType::Literal(_) => None,
FieldType::Optional(base) => map_types(ir, base.as_ref()),
FieldType::Tuple(_) => None,
FieldType::Union(_) => None,
FieldType::Union(variants) => {
// When encountering a union, we return the key/value types of the
// first map we find inside the union.
// TODO: Give more thought to what `map_types` should return for
// unions, because the current logic is faulty for unions containing
// multiple maps.
let mut variant_map_types = variants.into_iter().filter_map(|variant| map_types(ir, variant));
variant_map_types.next()
},
FieldType::Class(_) => None,
FieldType::WithMetadata { .. } => {
unreachable!("distribute_metadata never returns this variant")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl TypeCoercer for Class {
}
Some(crate::jsonish::Value::Object(obj, completion)) => {
// match keys, if that fails, then do something fancy later.
dbg!(&obj);
// dbg!(&obj);
let mut extra_keys = vec![];
let mut found_keys = false;
obj.iter().for_each(|(key, v)| {
Expand Down
4 changes: 2 additions & 2 deletions engine/baml-lib/jsonish/src/jsonish/parser/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ pub fn parse(str: &str, mut options: ParseOptions) -> Result<Value> {
Ok(items) => match items.len() {
0 => {}
1 => {
eprintln!("MULTI_JSON: {items:?}");
// eprintln!("MULTI_JSON: {items:?}");
let ret =
Value::AnyOf(
vec![Value::FixedJson(
Expand All @@ -154,7 +154,7 @@ pub fn parse(str: &str, mut options: ParseOptions) -> Result<Value> {
)],
str.to_string(),
);
eprintln!("ret: {ret:?}");
// eprintln!("ret: {ret:?}");
return Ok(ret);
}
n => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub fn parse(str: &str, options: &ParseOptions) -> Result<Vec<MarkdownResult>> {

match res {
Ok(v) => {
eprintln!("Pushing value {v:?}");
// eprintln!("Pushing value {v:?}");
// TODO: Add any more additional strings here.
values.push(MarkdownResult::CodeBlock(
if tag.len() > 3 {
Expand Down
2 changes: 1 addition & 1 deletion engine/baml-lib/jsonish/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub fn from_str(
// When the schema is just a string, i should really just return the raw_string w/o parsing it.
let value = jsonish::parse(raw_string, jsonish::ParseOptions::default())?;
// let schema = deserializer::schema::from_jsonish_value(&value, None);
eprintln!("value: {value:?}");
// eprintln!("value: {value:?}");

// See Note [Streaming Number Invalidation]
if allow_partials {
Expand Down
16 changes: 10 additions & 6 deletions engine/baml-runtime/src/internal/llm_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ pub fn parsed_value_to_response(
field_type: &FieldType,
allow_partials: bool,
) -> Result<ResponseBamlValue> {
dbg!(&baml_value);
// dbg!(&baml_value);
let meta_flags: BamlValueWithMeta<Vec<Flag>> = baml_value.clone().into();
let baml_value_with_meta: BamlValueWithMeta<Vec<(String, JinjaExpression, bool)>> =
baml_value.clone().into();
dbg!(&baml_value_with_meta);
// dbg!(&baml_value_with_meta);

let value_with_response_checks: BamlValueWithMeta<Vec<ResponseCheck>> = baml_value_with_meta
.map_meta(|cs| {
Expand All @@ -54,15 +54,19 @@ pub fn parsed_value_to_response(
.collect()
});

dbg!(&value_with_response_checks);
// dbg!(&value_with_response_checks);
let baml_value_with_streaming = validate_streaming_state(ir, &baml_value, field_type, allow_partials)
.map_err(|s| anyhow::anyhow!("{s:?}"))?;
dbg!(&baml_value_with_streaming);
// dbg!(&baml_value_with_streaming);

// Combine the baml_value, its types, the parser flags, and the streaming state
// into a final value.
// Node that we set the StreamState to `None` unless `allow_partials`.
let response_value = baml_value_with_streaming
.zip_meta(value_with_response_checks)?
.zip_meta(meta_flags)?
.map_meta(|((x, y), z)| (z.clone(), y.clone(), x.clone()));
dbg!(&response_value);
.map_meta(|((x, y), z)| (z.clone(), y.clone(), if allow_partials { x.clone() } else { None } ));
// dbg!(&response_value);
Ok(ResponseBamlValue(response_value))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ from .types import Checked, Check
T = TypeVar('T')
class StreamState(BaseModel, Generic[T]):
value: T
completion_state: Literal["Pending", "Incomplete", "Complete"]
state: Literal["Pending", "Incomplete", "Complete"]

{# Partial classes (used for streaming) -#}
{% for cls in partial_classes %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def get_checks(checks: Dict[CheckName, Check]) -> List[Check]:
def all_succeeded(checks: Dict[CheckName, Check]) -> bool:
return all(check.status == "succeeded" for check in get_checks(checks))


{# Enums -#}
{% for enum in enums %}
class {{enum.name}}(str, Enum):
Expand Down
109 changes: 74 additions & 35 deletions engine/language_client_python/src/types/function_results.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use baml_types::{BamlValueWithMeta, ResponseCheck};
use jsonish::ResponseBamlValue;
use pyo3::prelude::{pymethods, PyResult};
use pyo3::types::{PyAnyMethods, PyDict, PyModule, PyTuple, PyType};
use pyo3::{Bound, IntoPyObject, IntoPyObjectExt, PyAny, PyObject, Python};
use pyo3::{Bound, IntoPyObject, IntoPyObjectExt, Py, PyAny, PyErr, PyObject, Python};

use crate::errors::BamlError;

Expand Down Expand Up @@ -39,9 +39,10 @@ impl FunctionResult {
.result_with_constraints_content()
.map_err(BamlError::from_anyhow)?;

let parsed = pythonize_strict(py, parsed.clone(), &enum_module, &cls_module)?;
let parsed = pythonize_strict(py, parsed.clone(), &enum_module, &cls_module);
eprintln!("parsed result: {:?}", parsed);

Ok(parsed)
Ok(parsed?)
}
}

Expand Down Expand Up @@ -79,6 +80,7 @@ fn pythonize_strict(
enum_module: &Bound<'_, PyModule>,
cls_module: &Bound<'_, PyModule>,
) -> PyResult<PyObject> {
eprintln!("pythonize_strict parsed: {:?}", parsed);
let meta = parsed.0.meta().clone();
let py_value_without_constraints = match parsed.0 {
BamlValueWithMeta::String(val, _) => val.into_py_any(py),
Expand Down Expand Up @@ -190,49 +192,86 @@ fn pythonize_strict(

let (_, checks, completion_state) = meta;
if checks.is_empty() && completion_state.is_none() {
eprintln!("ret1: {:?}", py_value_without_constraints);
Ok(py_value_without_constraints)
} else {
// Generate the Python checks
let python_checks = pythonize_checks(py, cls_module, &checks)?;

// Get the type of the original value
let value_type = py_value_without_constraints.bind(py).get_type();

// Import the necessary modules and objects
let typing = py.import("typing")?;
let literal = typing.getattr("Literal")?;
let typing = py.import("typing").expect("typing");
let literal = typing.getattr("Literal").expect("Literal");
let value_with_possible_checks = if !checks.is_empty() {

// Collect check names as &str and turn them into a Python tuple
let check_names: Vec<&str> = checks.iter().map(|check| check.name.as_str()).collect();
let literal_args = PyTuple::new_bound(py, check_names);
// Generate the Python checks
let python_checks = pythonize_checks(py, cls_module, &checks).expect("pythonize_checks");

// Call Literal[...] dynamically
let literal_check_names = literal.get_item(literal_args)?;
// Get the type of the original value
let value_type = py_value_without_constraints.bind(py).get_type();

// Prepare the properties dictionary
let properties_dict = pyo3::types::PyDict::new(py);
properties_dict.set_item("value", py_value_without_constraints)?;
if !checks.is_empty() {
properties_dict.set_item("checks", python_checks)?;
}
if let Some(completion_state) = completion_state {
let completion_json = serde_json::to_string(&completion_state).expect("Serializing CompletionState is safe.");
properties_dict.set_item("completion_state", completion_json)?;
}

let class_checked_type_constructor = cls_module.getattr("Checked")?;
// Collect check names as &str and turn them into a Python tuple
let check_names: Vec<&str> = checks.iter().map(|check| check.name.as_str()).collect();
let literal_args = PyTuple::new_bound(py, check_names);

// Call Literal[...] dynamically
let literal_check_names = literal.get_item(literal_args).expect("get_item");


let class_checked_type_constructor = cls_module.getattr("Checked").expect("getattr(Checked)");

// Prepare type parameters for Checked[...]
let type_parameters_tuple = PyTuple::new(py, [value_type.as_ref(), &literal_check_names]).expect("PyTuple::new");

// Create the Checked type using __class_getitem__
let class_checked_type: Bound<'_, PyAny> = class_checked_type_constructor
.call_method1("__class_getitem__", (type_parameters_tuple,)).expect("__class_getitem__");

// Prepare the properties dictionary
let properties_dict = pyo3::types::PyDict::new(py);
properties_dict.set_item("value", py_value_without_constraints.clone_ref(py))?;
if !checks.is_empty() {
properties_dict.set_item("checks", python_checks)?;
}

// Validate the model with the constructed type
let checked_instance =
class_checked_type.call_method("model_validate", (properties_dict.clone(),), None).expect("model_validate");

eprintln!("ret2: {:?}", checked_instance);

Ok::<Py<PyAny>, PyErr>(checked_instance.into())
} else {
Ok(py_value_without_constraints.clone_ref(py))
}?;

let value_with_possible_completion_state = if let Some(state) = completion_state {
let value_type = py_value_without_constraints.bind(py).get_type();
eprintln!("value_type: {:?}", value_type);

// Prepare the properties dictionary
let properties_dict = pyo3::types::PyDict::new(py);
properties_dict.set_item("value", value_with_possible_checks)?;
properties_dict.set_item("state", format!("{:?}", state))?;

// Prepare type parameters for StreamingState[...]
let type_parameters_tuple = PyTuple::new(py, [value_type.as_ref()]).expect("PyTuple::new");
dbg!(&type_parameters_tuple);

// Prepare type parameters for Checked[...]
let type_parameters_tuple = PyTuple::new(py, [value_type.as_ref(), &literal_check_names])?;
let class_streaming_state_type_constructor = cls_module.getattr("StreamState").expect("getattr(StreamState)");
let class_completion_state_type: Bound<'_, PyAny> = class_streaming_state_type_constructor
.call_method1("__class_getitem__", (type_parameters_tuple,))
.expect("__class_getitem__ for streaming");
dbg!(&class_completion_state_type);

// Create the Checked type using __class_getitem__
let class_checked_type: Bound<'_, PyAny> = class_checked_type_constructor
.call_method1("__class_getitem__", (type_parameters_tuple,))?;
eprintln!("properties dict: {:?}", properties_dict);
let streaming_state_instance = class_completion_state_type
.call_method("model_validate", (properties_dict.clone(),), None)
.expect("model_validate for streaming");

// Validate the model with the constructed type
let checked_instance =
class_checked_type.call_method("model_validate", (properties_dict.clone(),), None)?;
Ok::<Py<PyAny>, PyErr>(streaming_state_instance.into())
} else {
Ok(value_with_possible_checks)
}?;

Ok(checked_instance.into())
Ok(value_with_possible_completion_state)
}
}
20 changes: 10 additions & 10 deletions fern/01-guide/04-baml-basics/streaming.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -366,28 +366,28 @@ class Message {
</Tab>
<Tab title="Python">
```python
class StreamingState(BaseModel, Generic[T]):
class StreamState(BaseModel, Generic[T]):
value: T,
state: "incomplete" | "complete"

class Message(BaseModel):
message_type: Union["greeting", "within-convo", "farewell"]
gesture: Option[Union["gesticulate", "wave", "shake-hands", "hug"]]
message: StreamingState[String]
message: StreamState[String]
```
</Tab>

<Tab title="Typescript">
```typescript
interface StreamingState<T> {
interface StreamState<T> {
value: T,
state: "incomplete" | "complete"
}

interface Message {
message_type: "greeting" | "within-convo" | "farewell",
gesture: ("gesticulate" | "wave" | "shake-hands" | "hug")?,
message: StreamingState<string>,
message: StreamState<string>,
}
```
</Tab>
Expand Down Expand Up @@ -436,14 +436,14 @@ the generated code:

- `Recommendation` does not have any partial field because it was marked
`@stream.done`.
- The `Message.message` `string` is wrapped in `StreamingState`, allowing
- The `Message.message` `string` is wrapped in `StreamState`, allowing
runtime checking of its completion status. This status could be used
to render a spinner as the message streams in.
- The `Message.message_type` field may not be `null`, because it was marked
as `@stream.not_null`.

```python
class StreamingState(BaseModel, Generic[T]):
class StreamState(BaseModel, Generic[T]):
value: T,
state: Union[Literal["incomplete"] | Literal[]]

Expand All @@ -460,7 +460,7 @@ class Recommendation(BaseClass):

class Message(BaseClass):
message_type: Union[Literal["gretting"], Literal["conversation"], Literal["farewell"]]
message: StreamingState[string]
message: StreamState[string]
```
</Tab>

Expand All @@ -471,14 +471,14 @@ the generated code:

- `Recommendation` does not have any partial field because it was marked
`@stream.done`.
- The `Message.message` `string` is wrapped in `StreamingState`, allowing
- The `Message.message` `string` is wrapped in `StreamState`, allowing
runtime checking of its completion status. This status could be used
to render a spinner as the message streams in.
- The `Message.message_type` field may not be `null`, because it was marked
as `@stream.not_null`.

```typescript
export interface StreamingState<T> {
export interface StreamState<T> {
value: T,
state: "incomplete" | "complete"
}
Expand All @@ -498,7 +498,7 @@ export interface Recommendation {

export interface Message {
message_type: "gretting" | "conversation" | "farewell"
message: StreamingState<string>
message: StreamState<string>
}
```
</Tab>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
class TestOutputClass {
prop1 string @description("A long string with about 200 words") @stream.done @stream.with_state
prop2 int @stream.with_state
prop2 int
}

function FnOutputClass(input: string) -> TestOutputClass {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class ClassWithDone {

class ClassWithoutDone {
i_16_digits int
s_20_words string
s_20_words string @stream.with_state
}

class ClassWithBlockDone {
Expand Down
Loading

0 comments on commit 2466014

Please sign in to comment.