# feat: baml-cli serve support client registry (#1000)
## Copied from PR #986

This adds the option to pass client registry configuration when serving BAML over HTTP, by adding an optional `__baml_options__: BamlOptions` field to the request body of the `POST /call/:msg` endpoint. For now, `BamlOptions` has a single field, `client_registry`.

TODO:

- [x] Support sync endpoint
- [x] Support streaming endpoint
- [x] Make `__baml_options__` take a list of clients instead of a HashMap with a redundant name
- [x] `__baml_options` -> `__baml_options__`
- Add to generated OpenAPI spec (openapi.rs:364)
  - [x] Define schemas for `BamlOptions`
  - [x] Add field to request body schemas
- [x] Document the feature

To test:

1. `cd engine`
2. Create an example `baml_src` using `baml-cli init`
3. Start the HTTP server: `cargo run --bin baml-runtime -- serve --preview`
4. Save the body payload below to `body.json`, adding a valid OpenAI API key
5. In a separate terminal, run a completion: `curl -X POST http://localhost:2024/call/ExtractResume -H 'Content-Type: application/json' -d @body.json`
```json
{
    "resume": "Vaibhav Gupta",
    "__baml_options__": {
        "client_registry": {
            "clients": [
                {
                    "name": "OpenAI",
                    "provider": "openai",
                    "retry_policy": null,
                    "options": {
                        "model": "gpt-4o-mini",
                        "api_key": "sk-FILL-IN-VALID-KEY-HERE"
                    }
                }
            ],
            "primary": "OpenAI"
        }
    }
}
```
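
If you prefer to drive the test from code instead of curl, the same request can be sent from a small Rust program. This is an illustrative sketch, not part of this PR: it assumes `reqwest` (with the `blocking` and `json` features) and `serde_json` as dependencies, and the server running locally on port 2024 as in the steps above. The streaming endpoint accepts the same `__baml_options__` field; only the route differs.

```rust
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Same payload as body.json above; the API key is a placeholder.
    let body = json!({
        "resume": "Vaibhav Gupta",
        "__baml_options__": {
            "client_registry": {
                "clients": [{
                    "name": "OpenAI",
                    "provider": "openai",
                    "retry_policy": null,
                    "options": { "model": "gpt-4o-mini", "api_key": "sk-FILL-IN-VALID-KEY-HERE" }
                }],
                "primary": "OpenAI"
            }
        }
    });

    // POST to the sync call endpoint exposed by `baml-cli serve --preview`.
    let response = reqwest::blocking::Client::new()
        .post("http://localhost:2024/call/ExtractResume")
        .json(&body)
        .send()?;

    println!("status: {}", response.status());
    println!("body: {}", response.text()?);
    Ok(())
}
```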


----

> [!IMPORTANT]
> Adds support for client registry configuration in BAML HTTP serving, with updates to server logic, the OpenAPI spec, and documentation.
>
> - **Behavior**:
>   - Adds `BamlOptions` struct in `serve/mod.rs` to support client registry configuration in HTTP requests.
>   - Updates `baml_call` and `baml_stream` functions to handle `BamlOptions`.
>   - Supports both sync and streaming endpoints for client registry.
> - **OpenAPI**:
>   - Updates `openapi.rs` to include `BamlOptions` and `ClientProperty` schemas.
>   - Adds `__baml_options__` to request body in the OpenAPI spec.
> - **Client Registry**:
>   - Modifies `ClientRegistry` in `client_registry/mod.rs` to deserialize from a list of clients.
>   - Adds `deserialize_clients` function for custom deserialization logic.
> - **Documentation**:
>   - Updates `client-registry.mdx` to document the new client registry feature in HTTP requests.



---------

Co-authored-by: Lorenz Ohly <lorenz.ohly@gmail.com>
imalsogreg and lorenzoh authored Oct 1, 2024
1 parent da1a5e8 commit abe70bf
Showing 6 changed files with 452 additions and 39 deletions.
34 changes: 30 additions & 4 deletions docs/docs/calling-baml/client-registry.mdx
@@ -84,12 +84,38 @@ run
</Tab>

<Tab title="OpenAPI">
Dynamic types are not yet supported when used via OpenAPI.

Please let us know if you want this feature, either via [Discord] or [GitHub][openapi-feedback-github-issue].
The API supports passing a client registry as a field on `__baml_options__` in the request body.

Example request body:

```json
{
  "resume": "Vaibhav Gupta",
  "__baml_options__": {
    "client_registry": {
      "clients": [
        {
          "name": "OpenAI",
          "provider": "openai",
          "retry_policy": null,
          "options": {
            "model": "gpt-4o-mini",
            "api_key": "sk-..."
          }
        }
      ],
      "primary": "OpenAI"
    }
  }
}
```

```sh
curl -X POST http://localhost:2024/call/ExtractResume \
  -H 'Content-Type: application/json' -d @body.json
```

[Discord]: https://discord.gg/BTNBeXGuaS
[openapi-feedback-github-issue]: https://github.com/BoundaryML/baml/issues/892
</Tab>

</Tabs>
68 changes: 57 additions & 11 deletions engine/baml-runtime/src/cli/serve/mod.rs
@@ -25,13 +25,15 @@ use axum_extra::{
use baml_types::BamlValue;
use core::pin::Pin;
use futures::Stream;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::{path::PathBuf, sync::Arc, task::Poll};
use tokio::{net::TcpListener, sync::RwLock};
use tokio_stream::StreamExt;

use crate::{
    internal::llm_client::LLMResponse, BamlRuntime, FunctionResult, RuntimeContextManager,
    client_registry::ClientRegistry, internal::llm_client::LLMResponse, BamlRuntime,
    FunctionResult, RuntimeContextManager,
};

#[derive(clap::Args, Clone, Debug)]
@@ -50,6 +52,11 @@ pub struct ServeArgs {
    no_version_check: bool,
}

#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct BamlOptions {
    pub client_registry: Option<ClientRegistry>,
}

impl ServeArgs {
    pub fn run(&self) -> Result<()> {
        if !self.preview {
@@ -326,17 +333,23 @@ Tip: test that the server is up using `curl http://localhost:{}/_debug/ping`
        Ok(())
    }

    async fn baml_call(self: Arc<Self>, b_fn: String, b_args: serde_json::Value) -> Response {
    async fn baml_call(
        self: Arc<Self>,
        b_fn: String,
        b_args: serde_json::Value,
        b_options: Option<BamlOptions>,
    ) -> Response {
        let args = match parse_args(&b_fn, b_args) {
            Ok(args) => args,
            Err(e) => return e.into_response(),
        };

        let ctx_mgr = RuntimeContextManager::new_from_env_vars(std::env::vars().collect(), None);
        let client_registry = b_options.and_then(|options| options.client_registry);

        let locked = self.b.read().await;
        let (result, _trace_id) = locked
            .call_function(b_fn, &args, &ctx_mgr, None, None)
            .call_function(b_fn, &args, &ctx_mgr, None, client_registry.as_ref())
            .await;

        match result {
@@ -367,26 +380,47 @@ Tip: test that the server is up using `curl http://localhost:{}/_debug/ping`
        extract::Path(b_fn): extract::Path<String>,
        extract::Json(b_args): extract::Json<serde_json::Value>,
    ) -> Response {
        self.baml_call(b_fn, b_args).await
        let mut b_options = None;
        if let Some(options_value) = b_args.get("__baml_options__") {
            match serde_json::from_value::<BamlOptions>(options_value.clone()) {
                Ok(opts) => b_options = Some(opts),
                Err(_) => {
                    return BamlError::InvalidArgument(
                        "Failed to parse __baml_options__".to_string(),
                    )
                    .into_response()
                }
            }
        }
        self.baml_call(b_fn, b_args, b_options).await
    }

    fn baml_stream(self: Arc<Self>, b_fn: String, b_args: serde_json::Value) -> Response {
    fn baml_stream(
        self: Arc<Self>,
        b_fn: String,
        b_args: serde_json::Value,
        b_options: Option<BamlOptions>,
    ) -> Response {
        let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();

        let args = match parse_args(&b_fn, b_args) {
            Ok(args) => args,
            Err(e) => return e.into_response(),
        };

        let client_registry = b_options.and_then(|options| options.client_registry);

        tokio::spawn(async move {
            let ctx_mgr =
                RuntimeContextManager::new_from_env_vars(std::env::vars().collect(), None);

            let result_stream = self
                .b
                .read()
                .await
                .stream_function(b_fn, &args, &ctx_mgr, None, None);
            let result_stream = self.b.read().await.stream_function(
                b_fn,
                &args,
                &ctx_mgr,
                None,
                client_registry.as_ref(),
            );

            match result_stream {
                Ok(mut result_stream) => {
@@ -457,7 +491,19 @@ Tip: test that the server is up using `curl http://localhost:{}/_debug/ping`
        extract::Path(path): extract::Path<String>,
        extract::Json(body): extract::Json<serde_json::Value>,
    ) -> Response {
        self.baml_stream(path, body)
        let mut b_options = None;
        if let Some(options_value) = body.get("__baml_options__") {
            match serde_json::from_value::<BamlOptions>(options_value.clone()) {
                Ok(opts) => b_options = Some(opts),
                Err(_) => {
                    return BamlError::InvalidArgument(
                        "Failed to parse __baml_options__".to_string(),
                    )
                    .into_response()
                }
            }
        }
        self.baml_stream(path, body, b_options)
    }
}

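Both the call and stream handlers share the same option-extraction step: look for an optional `__baml_options__` key in the request JSON, deserialize it when present, and reject the request when it is malformed. The following standalone sketch distills that pattern with simplified types; it is not the server code (`serde` and `serde_json` are assumed as dependencies, and `client_registry` is kept as raw JSON for brevity):

```rust
use serde::Deserialize;
use serde_json::json;

// Simplified stand-in for the server's BamlOptions.
#[derive(Deserialize, Debug)]
struct BamlOptions {
    client_registry: Option<serde_json::Value>,
}

/// Extract an optional `__baml_options__` field from a request body.
/// `Ok(None)` means the field was absent; `Err` means it was present but malformed.
fn extract_options(body: &serde_json::Value) -> Result<Option<BamlOptions>, String> {
    match body.get("__baml_options__") {
        None => Ok(None),
        Some(value) => serde_json::from_value::<BamlOptions>(value.clone())
            .map(Some)
            .map_err(|_| "Failed to parse __baml_options__".to_string()),
    }
}

fn main() {
    let body = json!({
        "resume": "Vaibhav Gupta",
        "__baml_options__": { "client_registry": { "clients": [], "primary": null } }
    });
    match extract_options(&body) {
        Ok(Some(opts)) => println!("has client_registry: {}", opts.client_registry.is_some()),
        Ok(None) => println!("no options supplied"),
        Err(e) => eprintln!("bad request: {e}"),
    }
}
```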
17 changes: 14 additions & 3 deletions engine/baml-runtime/src/client_registry/mod.rs
@@ -4,7 +4,7 @@ use std::collections::HashMap;
use std::sync::Arc;

use baml_types::{BamlMap, BamlValue};
use serde::Serialize;
use serde::{Deserialize, Deserializer, Serialize};

use crate::{internal::llm_client::llm_provider::LLMProvider, RuntimeContext};

@@ -16,16 +16,17 @@ pub enum PrimitiveClient {
    Vertex,
}

#[derive(Serialize, Clone)]
#[derive(Serialize, Clone, Deserialize, Debug)]
pub struct ClientProperty {
    pub name: String,
    pub provider: String,
    pub retry_policy: Option<String>,
    pub options: BamlMap<String, BamlValue>,
}

#[derive(Clone)]
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct ClientRegistry {
    #[serde(deserialize_with = "deserialize_clients")]
    clients: HashMap<String, ClientProperty>,
    primary: Option<String>,
}
@@ -60,3 +61,13 @@ impl ClientRegistry {
        Ok((self.primary.clone(), clients))
    }
}

fn deserialize_clients<'de, D>(deserializer: D) -> Result<HashMap<String, ClientProperty>, D::Error>
where
    D: Deserializer<'de>,
{
    Ok(Vec::deserialize(deserializer)?
        .into_iter()
        .map(|client: ClientProperty| (client.name.clone(), client))
        .collect())
}
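The custom deserializer is what lets the wire format stay a flat list of clients while `ClientRegistry` keeps a by-name map internally. A minimal standalone sketch of the same pattern, with a simplified `ClientProperty` (the real one also carries `retry_policy` and `options`) and `serde`/`serde_json` assumed as dependencies:

```rust
use std::collections::HashMap;

use serde::{Deserialize, Deserializer};

// Simplified stand-in for the real ClientProperty.
#[derive(Deserialize, Debug)]
struct ClientProperty {
    name: String,
    provider: String,
}

#[derive(Deserialize, Debug)]
struct ClientRegistry {
    // The wire format is a flat list of clients; internally they are keyed by name.
    #[serde(deserialize_with = "deserialize_clients")]
    clients: HashMap<String, ClientProperty>,
    primary: Option<String>,
}

fn deserialize_clients<'de, D>(deserializer: D) -> Result<HashMap<String, ClientProperty>, D::Error>
where
    D: Deserializer<'de>,
{
    Ok(Vec::<ClientProperty>::deserialize(deserializer)?
        .into_iter()
        .map(|client| (client.name.clone(), client))
        .collect())
}

fn main() {
    let payload = r#"{
        "clients": [
            { "name": "OpenAI", "provider": "openai" }
        ],
        "primary": "OpenAI"
    }"#;
    let registry: ClientRegistry = serde_json::from_str(payload).expect("valid registry JSON");
    assert_eq!(registry.primary.as_deref(), Some("OpenAI"));
    println!("primary provider: {}", registry.clients["OpenAI"].provider);
}
```

Note that in this sketch, collecting into a `HashMap` keyed by `name` means a later client with a duplicate name silently overwrites an earlier one.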
92 changes: 77 additions & 15 deletions engine/language_client_codegen/src/openapi.rs
@@ -21,7 +21,7 @@ impl LanguageFeatures for OpenApiLanguageFeatures {
#
# Welcome to Baml! To use this generated code, please run the following:
#
# $ openapi-generator generate -c openapi.yaml -g <language> -o <output_dir>
# $ openapi-generator generate -i openapi.yaml -g <language> -o <output_dir>
#
###############################################################################
@@ -223,6 +223,55 @@ impl Serialize for OpenApiSchema<'_> {
                    ],
                }),
            ),
            (
                "BamlOptions",
                json!({
                    "type": "object",
                    "nullable": false,
                    "properties": {
                        "client_registry": {
                            "type": "object",
                            "nullable": false,
                            "properties": {
                                "clients": {
                                    "type": "array",
                                    "items": {
                                        "$ref": "#/components/schemas/ClientProperty"
                                    }
                                },
                                "primary": {
                                    "type": "string",
                                    "nullable": false
                                }
                            },
                            "required": ["clients"]
                        }
                    }
                })
            ),
            (
                "ClientProperty",
                json!({
                    "type": "object",
                    "properties": {
                        "name": {
                            "type": "string"
                        },
                        "provider": {
                            "type": "string"
                        },
                        "retry_policy": {
                            "type": "string",
                            "nullable": false
                        },
                        "options": {
                            "type": "object",
                            "additionalProperties": true
                        }
                    },
                    "required": ["name", "provider", "options"]
                })
            )
        ]
        .into_iter()
        .chain(schemas.into_iter())
@@ -329,6 +378,32 @@ impl<'ir> TryFrom<Walker<'ir, &'ir Node<Function>>> for OpenApiMethodDef<'ir> {

    fn try_from(value: Walker<'ir, &'ir Node<Function>>) -> Result<Self> {
        let function_name = value.item.elem.name();
        let mut properties: IndexMap<String, TypeSpecWithMeta> = value
            .item
            .elem
            .inputs()
            .iter()
            .map(|(name, t)| {
                Ok((
                    name.to_string(),
                    t.to_type_spec(value.db).context(format!(
                        "Failed to convert arg {name} (for function {function_name}) to OpenAPI type",
                    ))?,
                ))
            })
            .collect::<Result<_>>()?;
        properties.insert(
            "__baml_options__".to_string(),
            TypeSpecWithMeta {
                meta: TypeMetadata {
                    title: None,
                    r#enum: None,
                    r#const: None,
                    nullable: true,
                },
                type_spec: TypeSpec::Ref { r#ref: "#/components/schemas/BamlOptions".into() }
            }
        );
        Ok(Self {
            function_name,
            request_body: TypeSpecWithMeta {
@@ -348,20 +423,7 @@ impl<'ir> TryFrom<Walker<'ir, &'ir Node<Function>>> for OpenApiMethodDef<'ir> {
                    nullable: false,
                },
                type_spec: TypeSpec::Inline(TypeDef::Class {
                    properties: value
                        .item
                        .elem
                        .inputs()
                        .iter()
                        .map(|(name, t)| {
                            Ok((
                                name.to_string(),
                                t.to_type_spec(value.db).context(format!(
                                    "Failed to convert arg {name} (for function {function_name}) to OpenAPI type",
                                ))?,
                            ))
                        })
                        .collect::<Result<_>>()?,
                    properties,
                    required: value
                        .item
                        .elem
2 changes: 1 addition & 1 deletion integ-tests/openapi/baml_client/.openapi-generator-ignore
@@ -2,7 +2,7 @@
#
# Welcome to Baml! To use this generated code, please run the following:
#
# $ openapi-generator generate -c openapi.yaml -g <language> -o <output_dir>
# $ openapi-generator generate -i openapi.yaml -g <language> -o <output_dir>
#
###############################################################################

