Skip to content

Commit

Permalink
Client orchestration graph (#801)
Browse files Browse the repository at this point in the history
- Added dynamic Client Orchestration Graph visualizer
- allows users to see the execution path of their fallback, retry, and
round robin strategies, including the render prompt and raw cURL per
node
- uses a DP-based node identifier algorithm to ID the groups that nodes
belong to
- recursively generates sizes and positioning of nodes for React-flow
component

<img width="1265" alt="Screenshot 2024-07-18 at 6 02 37 PM"
src="https://github.com/user-attachments/assets/8123ace9-88eb-4cda-b7f3-5beaaca35c8d">

---------

Co-authored-by: hellovai <vbv@boundaryml.com>
  • Loading branch information
anish-palakurthi and hellovai authored Aug 8, 2024
1 parent ec9b66c commit 24b5895
Show file tree
Hide file tree
Showing 24 changed files with 1,086 additions and 97 deletions.
2 changes: 1 addition & 1 deletion docs/docs/snippets/prompt-syntax/what-is-jinja.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ BAML Prompt strings are essentially [Jinja](https://jinja.palletsprojects.com/en

When in doubt -- use the BAML VSCode Playground preview. It will show you the fully rendered prompt, even when it has complex logic.

### Basic Syntax
### Basic Syntax

- `{% ... %}`: Use for executing statements such as for-loops or conditionals.
- `{{ ... }}`: Use for outputting expressions or variables.
Expand Down
2 changes: 1 addition & 1 deletion engine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ indexmap = { version = "2.1.0", features = ["serde"] }
indoc = "2.0.1"
regex = "1.10.4"
serde_json = { version = "1", features = ["float_roundtrip", "preserve_order"] }
serde = { version = "1", features = ["derive"] }
serde = { version = "1", features = ["derive", "rc"] }
static_assertions = "1.1.0"
strum = { version = "0.26.2", features = ["derive"] }
strum_macros = "0.26.2"
Expand Down
7 changes: 7 additions & 0 deletions engine/baml-lib/baml-types/src/baml_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ impl BamlValue {
}
}

/// Returns the contained integer when this value is a `BamlValue::Int`,
/// and `None` for every other variant.
pub fn as_int(&self) -> Option<i64> {
    if let BamlValue::Int(value) = self {
        Some(*value)
    } else {
        None
    }
}

pub fn as_str(&self) -> Option<&str> {
match self {
BamlValue::String(s) => Some(s),
Expand Down
33 changes: 11 additions & 22 deletions engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
mod call;
mod stream;

use anyhow::Result;
use baml_types::BamlValue;

use internal_baml_core::ir::repr::IntermediateRepr;
use internal_baml_jinja::RenderedChatMessage;
use internal_baml_jinja::RenderedPrompt;
use std::{collections::HashMap, sync::Arc};
use web_time::Duration;
use web_time::Duration;

use crate::{
internal::prompt_renderer::PromptRenderer, runtime_interface::InternalClientLookup,
Expand All @@ -26,6 +19,13 @@ pub use super::primitive::LLMPrimitiveProvider;
pub use call::orchestrate as orchestrate_call;
pub use stream::orchestrate_stream;

use anyhow::Result;
use baml_types::BamlValue;
use internal_baml_core::ir::repr::IntermediateRepr;
use internal_baml_jinja::RenderedChatMessage;
use internal_baml_jinja::RenderedPrompt;
use serde::Serialize;
use std::{collections::HashMap, sync::Arc};
pub struct OrchestratorNode {
pub scope: OrchestrationScope,
pub provider: Arc<LLMPrimitiveProvider>,
Expand Down Expand Up @@ -82,9 +82,9 @@ impl OrchestratorNode {
}
}

#[derive(Default, Clone)]
#[derive(Default, Clone, Serialize)]
pub struct OrchestrationScope {
scope: Vec<ExecutionScope>,
pub scope: Vec<ExecutionScope>,
}

impl From<ExecutionScope> for OrchestrationScope {
Expand Down Expand Up @@ -128,20 +128,9 @@ impl OrchestrationScope {
.collect(),
}
}

// pub fn extend_scopes(&self, scope: Vec<ExecutionScope>) -> OrchestrationScope {
// OrchestrationScope {
// scope: self
// .scope
// .clone()
// .into_iter()
// .chain(scope.into_iter())
// .collect(),
// }
// }
}

#[derive(Clone)]
#[derive(Clone, Serialize)]
pub enum ExecutionScope {
Direct(String),
// PolicyName, RetryCount, RetryDelayMs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ use crate::internal::llm_client::{
};

use super::properties::{
self, resolve_azure_properties, resolve_ollama_properties, resolve_openai_properties,
resolve_azure_properties, resolve_ollama_properties, resolve_openai_properties,
PostRequestProperities,
};
use super::types::{ChatCompletionResponse, ChatCompletionResponseDelta, FinishReason};
Expand All @@ -216,7 +216,6 @@ impl RequestBuilder for OpenAIClient {
&self,
prompt: either::Either<&String, &Vec<RenderedChatMessage>>,
allow_proxy: bool,

stream: bool,
) -> Result<reqwest::RequestBuilder> {
let destination_url = if allow_proxy {
Expand All @@ -227,6 +226,7 @@ impl RequestBuilder for OpenAIClient {
} else {
&self.properties.base_url
};

let mut req = self.client.post(if prompt.is_left() {
format!("{}/completions", destination_url)
} else {
Expand All @@ -241,14 +241,15 @@ impl RequestBuilder for OpenAIClient {
req = req.header(key, value);
}
if let Some(key) = &self.properties.api_key {
req = req.bearer_auth(key)
req = req.bearer_auth(key);
}

if allow_proxy {
req = req.header("baml-original-url", self.properties.base_url.as_str());
}

let mut body = json!(self.properties.properties);

let body_obj = body.as_object_mut().unwrap();
match prompt {
either::Either::Left(prompt) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ pub async fn make_request(
stream: bool,
) -> Result<(Response, web_time::SystemTime, web_time::Instant), LLMResponse> {
let (system_now, instant_now) = (web_time::SystemTime::now(), web_time::Instant::now());
log::debug!("Making request using client {}", client.context().name);

let req = match client
.build_request(prompt, true, stream)
Expand All @@ -52,7 +51,7 @@ pub async fn make_request(
start_time: system_now,
request_options: client.request_options().clone(),
latency: instant_now.elapsed(),
message: format!("{:?}", e),
message: format!("{:#?}", e),
code: ErrorCode::Other(2),
}));
}
Expand All @@ -68,14 +67,12 @@ pub async fn make_request(
start_time: system_now,
request_options: client.request_options().clone(),
latency: instant_now.elapsed(),
message: format!("{:?}", e),
message: format!("{:#?}", e),
code: ErrorCode::Other(2),
}));
}
};

log::debug!("LLM request: {:?} body: {:?}", req, req.body());

let response = match client.http_client().execute(req).await {
Ok(response) => response,
Err(e) => {
Expand Down
24 changes: 15 additions & 9 deletions engine/baml-runtime/src/internal/llm_client/strategy/fallback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,21 @@ impl IterOrchestrator for FallbackStrategy {
.clients
.iter()
.enumerate()
.flat_map(|(idx, client)| {
let client = client_lookup.get_llm_provider(client, ctx).unwrap().clone();
client.iter_orchestrator(
state,
ExecutionScope::Fallback(self.name.clone(), idx).into(),
ctx,
client_lookup,
)
})
.filter_map(
|(idx, client)| match client_lookup.get_llm_provider(client, ctx) {
Ok(client) => {
let client = client.clone();
Some(client.iter_orchestrator(
state,
ExecutionScope::Fallback(self.name.clone(), idx).into(),
ctx,
client_lookup,
))
}
Err(_) => None,
},
)
.flatten()
.collect::<Vec<_>>();

items
Expand Down
21 changes: 17 additions & 4 deletions engine/baml-runtime/src/internal/llm_client/strategy/roundrobin.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use anyhow::{Context, Result};
use std::{
collections::HashMap,
sync::{atomic::AtomicUsize, Arc},
fmt::Debug,
{
collections::HashMap,
sync::{atomic::AtomicUsize, Arc},
},
};

use internal_baml_core::ir::ClientWalker;

use crate::{
client_registry::ClientProperty,
internal::llm_client::orchestrator::{
Expand All @@ -15,15 +16,27 @@ use crate::{
runtime_interface::InternalClientLookup,
RuntimeContext,
};
use internal_baml_core::ir::ClientWalker;
use serde::Serialize;
use serde::Serializer;

/// Round-robin client selection strategy: requests rotate through the
/// configured clients in order, tracked by an atomic cursor.
#[derive(Serialize, Debug)]
pub struct RoundRobinStrategy {
    // Name of this strategy instance (as declared in the BAML config).
    pub name: String,
    // Optional retry-policy name applied to each attempted client.
    pub(super) retry_policy: Option<String>,
    // TODO: We can add conditions to each client
    // Ordered list of client names to rotate through.
    clients: Vec<String>,
    // Cursor into `clients`; atomic so concurrent callers can advance it
    // without a lock. Serialized as a plain integer via `serialize_atomic`.
    #[serde(serialize_with = "serialize_atomic")]
    current_index: AtomicUsize,
}

/// Serde helper for `RoundRobinStrategy::current_index`: snapshots the
/// atomic counter and emits it as a plain `u64`.
fn serialize_atomic<S>(value: &AtomicUsize, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    // Relaxed is sufficient: this is a best-effort snapshot for display,
    // not a synchronization point.
    let snapshot = value.load(std::sync::atomic::Ordering::Relaxed);
    serializer.serialize_u64(snapshot as u64)
}

impl RoundRobinStrategy {
pub fn current_index(&self) -> usize {
self.current_index
Expand Down
Loading

0 comments on commit 24b5895

Please sign in to comment.