Skip to content

Commit

Permalink
Ruby codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
imalsogreg committed Oct 17, 2024
1 parent 528057a commit b36660d
Show file tree
Hide file tree
Showing 18 changed files with 403 additions and 98 deletions.
32 changes: 32 additions & 0 deletions engine/baml-lib/baml-types/src/baml_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,38 @@ impl<T> BamlValueWithMeta<T> {
}
}

pub fn meta_mut(&mut self) -> &mut T {
match self {
BamlValueWithMeta::String(_, m) => m,
BamlValueWithMeta::Int(_, m) => m,
BamlValueWithMeta::Float(_, m) => m,
BamlValueWithMeta::Bool(_, m) => m,
BamlValueWithMeta::Map(_, m) => m,
BamlValueWithMeta::List(_, m) => m,
BamlValueWithMeta::Media(_, m) => m,
BamlValueWithMeta::Enum(_, _, m) => m,
BamlValueWithMeta::Class(_, _, m) => m,
BamlValueWithMeta::Null(m) => m,
}
}

pub fn with_default_meta(value: &BamlValue) -> BamlValueWithMeta<T> where T: Default {
use BamlValueWithMeta::*;
match value {
BamlValue::String(s) => String(s.clone(), T::default()),
BamlValue::Int(i) => Int(*i, T::default()),
BamlValue::Float(f) => Float(*f, T::default()),
BamlValue::Bool(b) => Bool(*b, T::default()),
BamlValue::Map(entries) => BamlValueWithMeta::Map(entries.iter().map(|(k,v)| (k.clone(), Self::with_default_meta(v))).collect(), T::default()),
BamlValue::List(items) => List(items.iter().map(|i| Self::with_default_meta(i)).collect(), T::default()),
BamlValue::Media(m) => Media(m.clone(), T::default()),
BamlValue::Enum(n,v) => Enum(n.clone(), v.clone(), T::default()),
BamlValue::Class(n, items) => Map(items.iter().map(|(k,v)| (k.clone(), Self::with_default_meta(v))).collect(), T::default()),
BamlValue::Null => Null(T::default()),
_ => unimplemented!()
}
}

pub fn map_meta<F, U>(self, f: F) -> BamlValueWithMeta<U>
where
F: Fn(T) -> U + Copy,
Expand Down
15 changes: 14 additions & 1 deletion engine/language_client_codegen/src/ruby/field_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use std::collections::HashSet;

use baml_types::{BamlMediaType, FieldType, LiteralValue, TypeValue};

use crate::{field_type_attributes, ruby::generate_types::type_name_for_checks};

use super::ruby_language_features::ToRuby;

impl ToRuby for FieldType {
Expand Down Expand Up @@ -47,7 +49,18 @@ impl ToRuby for FieldType {
.join(", ")
),
FieldType::Optional(inner) => format!("T.nilable({})", inner.to_ruby()),
FieldType::Constrained{base,..} => base.to_ruby(),
FieldType::Constrained{base,..} => {
match field_type_attributes(self) {
Some(checks) => {
let base_type_ref = base.to_ruby();
let checks_type_ref = type_name_for_checks(&checks);
format!("Baml::Checked[{base_type_ref}, {checks_type_ref}]")
}
None => {
base.to_ruby()
}
}
}
}
}
}
45 changes: 40 additions & 5 deletions engine/language_client_codegen/src/ruby/generate_types.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::borrow::Cow;
use std::collections::HashSet;

use anyhow::Result;

use crate::{field_type_attributes, type_check_attributes, TypeCheckAttributes};

use super::ruby_language_features::ToRuby;
use internal_baml_core::ir::{repr::IntermediateRepr, ClassWalker, EnumWalker, FieldType};

Expand All @@ -10,6 +13,7 @@ use internal_baml_core::ir::{repr::IntermediateRepr, ClassWalker, EnumWalker, Fi
pub(crate) struct RubyTypes<'ir> {
enums: Vec<RubyEnum<'ir>>,
classes: Vec<RubyStruct<'ir>>,
checks_classes: Vec<RubyStruct<'ir>>,
}

struct RubyEnum<'ir> {
Expand All @@ -19,8 +23,8 @@ struct RubyEnum<'ir> {
}

struct RubyStruct<'ir> {
name: &'ir str,
fields: Vec<(&'ir str, String)>,
name: Cow<'ir, str>,
fields: Vec<(Cow<'ir, str>, String)>,
dynamic: bool,
}

Expand Down Expand Up @@ -51,6 +55,7 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'ir crate::GeneratorArgs)> for RubyTy
Ok(RubyTypes {
enums: ir.walk_enums().map(|e| e.into()).collect(),
classes: ir.walk_classes().map(|c| c.into()).collect(),
checks_classes: type_check_attributes(ir).into_iter().map(|checks| type_def_for_checks(checks)).collect::<Vec<_>>()
})
}
}
Expand All @@ -74,14 +79,14 @@ impl<'ir> From<EnumWalker<'ir>> for RubyEnum<'ir> {
impl<'ir> From<ClassWalker<'ir>> for RubyStruct<'ir> {
fn from(c: ClassWalker<'ir>) -> RubyStruct<'ir> {
RubyStruct {
name: c.name(),
name: Cow::Borrowed(c.name()),
dynamic: c.item.attributes.get("dynamic_type").is_some(),
fields: c
.item
.elem
.static_fields
.iter()
.map(|f| (f.elem.name.as_str(), f.elem.r#type.elem.to_type_ref()))
.map(|f| (Cow::Borrowed(f.elem.name.as_str()), f.elem.r#type.elem.to_type_ref()))
.collect(),
}
}
Expand Down Expand Up @@ -163,7 +168,18 @@ impl ToTypeReferenceInTypeDefinition for FieldType {
.join(", ")
),
FieldType::Optional(inner) => inner.to_partial_type_ref(),
FieldType::Constrained{base,..} => base.to_partial_type_ref(),
FieldType::Constrained{base,..} => {
match field_type_attributes(self) {
Some(checks) => {
let base_type_ref = base.to_partial_type_ref();
let checks_type_ref = type_name_for_checks(&checks);
format!("Baml::Checked[{base_type_ref}, {checks_type_ref}]")
}
None => {
base.to_partial_type_ref()
}
}
},
}
}
}
Expand All @@ -180,3 +196,22 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'_ crate::GeneratorArgs)> for TypeReg
})
}
}

pub fn type_name_for_checks(checks: &TypeCheckAttributes) -> String {
let mut name = "Checks".to_string();
let mut names: Vec<&String> = checks.0.iter().collect();
names.sort();
for check_name in names.iter() {
name.push_str("__");
name.push_str(check_name);
}
name
}

fn type_def_for_checks(checks: TypeCheckAttributes) -> RubyStruct<'static> {
RubyStruct {
name: Cow::Owned(type_name_for_checks(&checks)),
fields: checks.0.into_iter().map(|check_name| (Cow::Owned(check_name), "Baml::Check".to_string())).collect(),
dynamic: false
}
}
29 changes: 28 additions & 1 deletion engine/language_client_codegen/src/ruby/templates/types.rb.j2
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ module Baml
class {{cls.name}} < T::Struct; end
{%- endfor %}

{#- Forward declarations for checks classes #}
{%- for cls in checks_classes %}
class {{cls.name}} < T::Struct; end
{%- endfor %}

{#- https://sorbet.org/docs/tstruct #}
{%- for cls in classes %}
class {{cls.name}} < T::Struct
Expand All @@ -42,5 +47,27 @@ module Baml
end
end
{%- endfor %}

{#- https://sorbet.org/docs/tstruct #}
{%- for cls in checks_classes %}
class {{cls.name}} < T::Struct
include Baml::Sorbet::Struct

{%- for (name, type) in cls.fields %}
const :{{name}}, {{type}}
{%- endfor %}

def initialize(props)
super(
{%- for (name, _) in cls.fields %}
{{name}}: props[:{{name}}],
{%- endfor %}
)

@props = props
end
end
{%- endfor %}

end
end
end
2 changes: 1 addition & 1 deletion engine/language_client_ruby/Gemfile.lock
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
PATH
remote: .
specs:
baml (0.52.1)
baml (0.60.0)

GEM
remote: https://rubygems.org/
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use baml_types::BamlValue;
use magnus::{
class, exception::runtime_error, method, prelude::*, value::Value, Error, RModule, Ruby,
};
Expand Down Expand Up @@ -39,7 +38,7 @@ impl FunctionResult {
) -> Result<Value> {
match rb_self.inner.parsed_content() {
Ok(parsed) => {
ruby_to_json::RubyToJson::serialize_baml(ruby, types, &BamlValue::from(parsed))
ruby_to_json::RubyToJson::serialize_baml(ruby, types, parsed.clone())
.map_err(|e| {
magnus::Error::new(
ruby.exception_type_error(),
Expand Down
137 changes: 96 additions & 41 deletions engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use baml_types::{BamlMap, BamlValue};
use baml_types::{BamlValue, BamlMap, BamlValueWithMeta, ResponseCheck};
use indexmap::IndexMap;
use magnus::{
prelude::*, typed_data::Obj, value::Value, Error, Float, Integer, IntoValue, RArray, RClass,
prelude::*, typed_data::Obj, value::Value, class, Error, Float, Integer, IntoValue, RArray, RClass,
RHash, RModule, RString, Ruby, Symbol, TypedData,
};
use std::result::Result;
Expand All @@ -26,57 +26,112 @@ impl<'rb> RubyToJson<'rb> {
serde_magnus::serialize(&json)
}

pub fn serialize_baml(ruby: &Ruby, types: RModule, from: &BamlValue) -> crate::Result<Value> {
match from {
BamlValue::Class(class_name, class_fields) => {
let hash = ruby.hash_new();
for (k, v) in class_fields.iter() {
let k = ruby.sym_new(k.as_str());
let v = RubyToJson::serialize_baml(ruby, types, v)?;
hash.aset(k, v)?;
}
match types.const_get::<_, RClass>(class_name.as_str()) {
Ok(class_type) => class_type.funcall("new", (hash,)),
Err(_) => {
let dynamic_class_type = ruby.eval::<RClass>("Baml::DynamicStruct")?;
dynamic_class_type.funcall("new", (hash,))
pub fn type_name_for_checks(checks: &Vec<ResponseCheck>) -> String {
let mut name = "Checks".to_string();
let mut names: Vec<&String> = checks.iter().map(|ResponseCheck{name, ..}| name).collect();
names.sort();
for check_name in names.iter() {
name.push_str("__");
name.push_str(check_name);
}
name
}

/// Serialize a list of check results into some `Checked__*` instance.
pub fn serialize_response_checks(ruby: &Ruby, checks: &Vec<ResponseCheck>) -> crate::Result<Value> {

let class_name = format!("Types::{}", Self::type_name_for_checks(checks));
let checks_class = ruby.eval::<RClass>(&class_name)?;

// Create a `Check` for each check in the `Checked__*`.
let hash = ruby.hash_new();
checks.iter().try_for_each(|ResponseCheck{name, expression, status}| {
let check_class = ruby.eval::<RClass>("Baml::Checks::Check")?;
let check_hash = ruby.hash_new();
check_hash.aset(ruby.sym_new("name"), name.as_str())?;
check_hash.aset(ruby.sym_new("expr"), expression.as_str())?;
check_hash.aset(ruby.sym_new("status"), status.as_str())?;

let check: Value = check_class.funcall("new", (check_hash,))?;
hash.aset(ruby.sym_new(name.as_str()), check)?;
crate::Result::Ok(())
})?;

checks_class.funcall("new", (hash,))
}

pub fn serialize_baml(ruby: &Ruby, types: RModule, mut from: BamlValueWithMeta<Vec<ResponseCheck>>) -> crate::Result<Value> {

// If we encounter a BamlValue node with check results, serialize it as
// { value: T, checks: K }. To compute `value`, we strip the metadata
// off the node and pass it back to `serialize_baml`.
if !from.meta().is_empty() {
let meta = from.meta().clone();
let checks = Self::serialize_response_checks(ruby, &meta)?;

*from.meta_mut() = vec![];
let serialized_subvalue = Self::serialize_baml(ruby, types, from)?;

let checked_class = ruby.eval::<RClass>("Baml::Checked").expect("SHOWME");
let hash = ruby.hash_new();
hash.aset(ruby.sym_new("value"), serialized_subvalue)?;
hash.aset(ruby.sym_new("checks"), checks)?;
Ok(checked_class.funcall("new", (hash,)).expect("problem here"))
}
// Otherwise encode it directly.
else {
match from {
BamlValueWithMeta::Class(class_name, class_fields, _) => {
let hash = ruby.hash_new();
for (k, v) in class_fields.into_iter() {
let k = ruby.sym_new(k.as_str());
let v = RubyToJson::serialize_baml(ruby, types, v)?;
hash.aset(k, v)?;
}
}
}
BamlValue::Enum(enum_name, enum_value) => {
if let Ok(enum_type) = types.const_get::<_, RClass>(enum_name.as_str()) {
let enum_value = ruby.str_new(enum_value);
if let Ok(enum_instance) = enum_type.funcall("deserialize", (enum_value,)) {
return Ok(enum_instance);
match types.const_get::<_, RClass>(class_name.as_str()) {
Ok(class_type) => class_type.funcall("new", (hash,)),
Err(_) => {
let dynamic_class_type = ruby.eval::<RClass>("Baml::DynamicStruct")?;
dynamic_class_type.funcall("new", (hash,))
}
}
}
BamlValueWithMeta::Enum(enum_name, enum_value, _) => {
if let Ok(enum_type) = types.const_get::<_, RClass>(enum_name.as_str()) {
let enum_value = ruby.str_new(&enum_value);
if let Ok(enum_instance) = enum_type.funcall("deserialize", (enum_value,)) {
return Ok(enum_instance);
}
}

Ok(ruby.str_new(enum_value).into_value_with(ruby))
}
BamlValue::Map(m) => {
let hash = ruby.hash_new();
for (k, v) in m.iter() {
let k = ruby.str_new(k);
let v = RubyToJson::serialize_baml(ruby, types, v)?;
hash.aset(k, v)?;
Ok(ruby.str_new(&enum_value).into_value_with(ruby))
}
Ok(hash.into_value_with(ruby))
}
BamlValue::List(l) => {
let arr = ruby.ary_new();
for v in l.iter() {
let v = RubyToJson::serialize_baml(ruby, types, v)?;
arr.push(v)?;
BamlValueWithMeta::Map(m,_) => {
let hash = ruby.hash_new();
for (k, v) in m.into_iter() {
let k = ruby.str_new(&k);
let v = RubyToJson::serialize_baml(ruby, types, v)?;
hash.aset(k, v)?;
}
Ok(hash.into_value_with(ruby))
}
Ok(arr.into_value_with(ruby))
BamlValueWithMeta::List(l, _) => {
let arr = ruby.ary_new();
for v in l.into_iter() {
let v = RubyToJson::serialize_baml(ruby, types, v)?;
arr.push(v)?;
}
Ok(arr.into_value_with(ruby))
}
_ => serde_magnus::serialize(&from),
}
_ => serde_magnus::serialize(from),

}
}

pub fn serialize(ruby: &Ruby, types: RModule, from: Value) -> crate::Result<Value> {
let json = RubyToJson::convert(from)?;
RubyToJson::serialize_baml(ruby, types, &json)
RubyToJson::serialize_baml(ruby, types, BamlValueWithMeta::with_default_meta(&json))
}

/// Convert a Ruby object to a JSON object.
Expand Down
Loading

0 comments on commit b36660d

Please sign in to comment.