Skip to content

Commit

Permalink
Dont drop extra fields in dynamic classes when passing them as inputs…
Browse files Browse the repository at this point in the history
… to a function
  • Loading branch information
aaronvg committed Jul 19, 2024
1 parent d793603 commit 84a231c
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 16 deletions.
45 changes: 30 additions & 15 deletions engine/language_client_python/src/parse_py_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use baml_types::{BamlMap, BamlMedia, BamlValue};
use pyo3::{
exceptions::{PyRuntimeError, PyTypeError},
prelude::{PyAnyMethods, PyTypeMethods},
types::{PyBool, PyBoolMethods, PyList},
types::{PyBool, PyBoolMethods, PyDict, PyList, PyString},
PyErr, PyObject, PyResult, Python, ToPyObject,
};

Expand Down Expand Up @@ -249,25 +249,40 @@ pub fn parse_py_type(
}
})
.unwrap_or("<UnnamedBaseModel>".to_string());
let fields = match t
let mut fields = HashMap::new();
// Get regular fields
if let Ok(model_fields) = t
.getattr("model_fields")?
.extract::<HashMap<String, PyObject>>()
{
Ok(fields) => fields
.keys()
.filter_map(|k| {
let v = any.getattr(py, k.as_str());
if let Ok(v) = v {
Some((k.clone(), v))
} else {
None
for (key, _) in model_fields {
if let Ok(value) = any.getattr(py, key.as_str()) {
fields.insert(key, value.to_object(py));
}
}
}

// Get extra fields (like if this is a @@dynamic class)
if let Ok(extra) = any.getattr(py, "__pydantic_extra__") {
if let Ok(extra_dict) = extra.downcast::<PyDict>(py) {
for (key, value) in extra_dict.iter() {
if let (Ok(key), value) = (key.extract::<String>(), value) {
fields.insert(key, value.to_object(py));
}
})
.collect::<HashMap<_, _>>(),
Err(_) => {
bail!("model_fields is not a dict")
}
}
};
}

// Log the fields
// log::info!("Fields of {}:", name);
// for (key, value) in &fields {
// let repr = py
// .import_bound("builtins")?
// .getattr("repr")?
// .call1((value,))?;
// let repr_str = repr.extract::<String>()?;
// log::info!(" {}: {}", key, repr_str);
// }
Ok(MappedPyType::Class(name, fields))
// use downcast only
} else if let Ok(list) = any.downcast_bound::<PyList>(py) {
Expand Down
37 changes: 36 additions & 1 deletion integ-tests/python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ..baml_client.globals import (
DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME,
)
from ..baml_client.types import NamedArgsSingleEnumList, NamedArgsSingleClass
from ..baml_client.types import NamedArgsSingleEnumList, NamedArgsSingleClass, DynInputOutput
from ..baml_client.tracing import trace, set_tags, flush, on_log_event
from ..baml_client.type_builder import TypeBuilder
import datetime
Expand Down Expand Up @@ -555,6 +555,41 @@ async def test_stream_dynamic_class_output():
assert final.hair_color == "black"


@pytest.mark.asyncio
async def test_dynamic_inputs_list2():
tb = TypeBuilder()
tb.DynInputOutput.add_property("new_key", tb.string().optional())
custom_class = tb.add_class("MyBlah")
custom_class.add_property("nestedKey1", tb.string())
tb.DynInputOutput.add_property("blah", custom_class.type())

res = await b.DynamicListInputOutput(
[
DynInputOutput(**{
"new_key": "hi1",
"testKey": "myTest",
"blah": {
"nestedKey1": "nestedVal",
},
}),
{
"new_key": "hi",
"testKey": "myTest",
"blah": {
"nestedKey1": "nestedVal",
},
},
],
{"tb": tb},
)
assert res[0].new_key == "hi1"
assert res[0].testKey == "myTest"
assert res[0].blah["nestedKey1"] == "nestedVal"
assert res[1].new_key == "hi"
assert res[1].testKey == "myTest"
assert res[1].blah["nestedKey1"] == "nestedVal"


@pytest.mark.asyncio
async def test_dynamic_inputs_list():
tb = TypeBuilder()
Expand Down

0 comments on commit 84a231c

Please sign in to comment.