From 8b203968fec0c4abfcfc1bf771e6b1740df1b368 Mon Sep 17 00:00:00 2001 From: Vaibhav Gupta Date: Mon, 2 Dec 2024 19:44:49 -0800 Subject: [PATCH] Fix azure client Add new client paramters: allowed_roles, default_role, finish_reason_allow_list, finish_reason_deny_list TODO: add more tests, but screenshots show it working --- engine/Cargo.lock | 10 ++ .../llm-client/src/clients/anthropic.rs | 67 +++----- .../llm-client/src/clients/aws_bedrock.rs | 67 +++----- .../llm-client/src/clients/google_ai.rs | 62 +++---- .../llm-client/src/clients/helpers.rs | 74 ++++++++- .../baml-lib/llm-client/src/clients/openai.rs | 112 +++++++------ .../baml-lib/llm-client/src/clients/vertex.rs | 73 ++++---- engine/baml-lib/llm-client/src/clientspec.rs | 157 ++++++++++++++++++ engine/baml-runtime/Cargo.toml | 3 +- engine/baml-runtime/src/cli/serve/error.rs | 19 +++ engine/baml-runtime/src/cli/serve/mod.rs | 2 +- engine/baml-runtime/src/errors.rs | 21 +++ .../internal/llm_client/orchestrator/call.rs | 35 +++- .../internal/llm_client/orchestrator/mod.rs | 22 +++ .../llm_client/orchestrator/stream.rs | 27 ++- .../primitive/anthropic/anthropic_client.rs | 22 +-- .../llm_client/primitive/aws/aws_client.rs | 22 +-- .../primitive/google/googleai_client.rs | 39 +++-- .../llm_client/primitive/google/types.rs | 2 +- .../src/internal/llm_client/primitive/mod.rs | 9 + .../primitive/openai/openai_client.rs | 66 +++----- .../llm_client/primitive/openai/types.rs | 21 +-- .../primitive/vertex/vertex_client.rs | 22 +-- .../src/internal/llm_client/traits/chat.rs | 19 ++- .../src/internal/llm_client/traits/mod.rs | 8 +- engine/baml-runtime/src/types/response.rs | 14 +- engine/baml-runtime/tests/harness.rs | 2 +- engine/baml-runtime/tests/test_cli.rs | 5 +- engine/baml-runtime/tests/test_runtime.rs | 20 +-- engine/baml-schema-wasm/src/lib.rs | 2 + .../baml-schema-wasm/src/runtime_wasm/mod.rs | 23 ++- .../tests/test_file_manager.rs | 1 + .../baml_py/internal_monkeypatch.py | 16 +- engine/language_client_python/src/errors.rs | 19 +++ .../language_client_typescript/src/errors.rs | 17 ++ .../typescript_src/index.ts | 113 +++++++++---- .../src/baml_wasm_web/test_uis/testHooks.ts | 35 +++- .../baml_wasm_web/test_uis/test_result.tsx | 33 +++- 38 files changed, 870 insertions(+), 411 deletions(-) diff --git a/engine/Cargo.lock b/engine/Cargo.lock index 633e6eccf..c34ae3515 100644 --- a/engine/Cargo.lock +++ b/engine/Cargo.lock @@ -966,6 +966,7 @@ dependencies = [ "jsonish", "jsonwebtoken", "log", + "log-once", "mime", "mime_guess", "minijinja", @@ -2966,6 +2967,15 @@ dependencies = [ "value-bag", ] +[[package]] +name = "log-once" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d8a05e3879b317b1b6dbf353e5bba7062bedcc59815267bb23eaa0c576cebf0" +dependencies = [ + "log", +] + [[package]] name = "magnus" version = "0.7.1" diff --git a/engine/baml-lib/llm-client/src/clients/anthropic.rs b/engine/baml-lib/llm-client/src/clients/anthropic.rs index dcdaa391a..5b85f0560 100644 --- a/engine/baml-lib/llm-client/src/clients/anthropic.rs +++ b/engine/baml-lib/llm-client/src/clients/anthropic.rs @@ -1,6 +1,6 @@ use std::collections::HashSet; -use crate::{AllowedRoleMetadata, SupportedRequestModes, UnresolvedAllowedRoleMetadata}; +use crate::{AllowedRoleMetadata, FinishReasonFilter, RolesSelection, SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedRolesSelection}; use anyhow::Result; use baml_types::{EvaluationContext, StringOr, UnresolvedValue}; @@ -12,12 +12,12 @@ use super::helpers::{Error, PropertyHandler, UnresolvedUrl}; pub struct UnresolvedAnthropic { base_url: UnresolvedUrl, api_key: StringOr, - allowed_roles: Vec, - default_role: Option, + role_selection: UnresolvedRolesSelection, allowed_metadata: UnresolvedAllowedRoleMetadata, supported_request_modes: SupportedRequestModes, headers: IndexMap, properties: IndexMap)>, + finish_reason_filter: UnresolvedFinishReasonFilter, } impl UnresolvedAnthropic { @@ -25,12 +25,12 @@ impl UnresolvedAnthropic { UnresolvedAnthropic { base_url: self.base_url.clone(), api_key: self.api_key.clone(), - allowed_roles: self.allowed_roles.clone(), - default_role: self.default_role.clone(), + role_selection: self.role_selection.clone(), allowed_metadata: self.allowed_metadata.clone(), supported_request_modes: self.supported_request_modes.clone(), headers: self.headers.iter().map(|(k, v)| (k.clone(), v.clone())).collect(), properties: self.properties.iter().map(|(k, (_, v))| (k.clone(), ((), v.without_meta()))).collect(), + finish_reason_filter: self.finish_reason_filter.clone(), } } } @@ -38,22 +38,34 @@ impl UnresolvedAnthropic { pub struct ResolvedAnthropic { pub base_url: String, pub api_key: String, - pub allowed_roles: Vec, - pub default_role: String, + role_selection: RolesSelection, pub allowed_metadata: AllowedRoleMetadata, pub supported_request_modes: SupportedRequestModes, pub headers: IndexMap, pub properties: IndexMap, pub proxy_url: Option, + pub finish_reason_filter: FinishReasonFilter, } +impl ResolvedAnthropic { + pub fn allowed_roles(&self) -> Vec { + self.role_selection.allowed_or_else(|| { + vec!["system".to_string(), "user".to_string(), "assistant".to_string()] + }) + } + + pub fn default_role(&self) -> String { + self.role_selection.default_or_else(|| "user".to_string()) + } +} + + impl UnresolvedAnthropic { pub fn required_env_vars(&self) -> HashSet { let mut env_vars = HashSet::new(); env_vars.extend(self.base_url.required_env_vars()); env_vars.extend(self.api_key.required_env_vars()); - env_vars.extend(self.allowed_roles.iter().map(|r| r.required_env_vars()).flatten()); - self.default_role.as_ref().map(|r| env_vars.extend(r.required_env_vars())); + env_vars.extend(self.role_selection.required_env_vars()); env_vars.extend(self.allowed_metadata.required_env_vars()); env_vars.extend(self.supported_request_modes.required_env_vars()); env_vars.extend(self.headers.values().map(|v| v.required_env_vars()).flatten()); @@ -63,25 +75,6 @@ impl UnresolvedAnthropic { } pub fn resolve(&self, ctx: &EvaluationContext<'_>) -> Result { - let allowed_roles = self - .allowed_roles - .iter() - .map(|role| role.resolve(ctx)) - .collect::>>()?; - - let Some(default_role) = self.default_role.as_ref() else { - return Err(anyhow::anyhow!("default_role must be provided")); - }; - let default_role = default_role.resolve(ctx)?; - - if !allowed_roles.contains(&default_role) { - return Err(anyhow::anyhow!( - "default_role must be in allowed_roles: {} not in {:?}", - default_role, - allowed_roles - )); - } - let base_url = self.base_url.resolve(ctx)?; let mut headers = self @@ -112,13 +105,13 @@ impl UnresolvedAnthropic { Ok(ResolvedAnthropic { base_url, api_key: self.api_key.resolve(ctx)?, - allowed_roles, - default_role, + role_selection: self.role_selection.resolve(ctx)?, allowed_metadata: self.allowed_metadata.resolve(ctx)?, supported_request_modes: self.supported_request_modes.clone(), headers, properties, proxy_url: super::helpers::get_proxy_url(ctx), + finish_reason_filter: self.finish_reason_filter.resolve(ctx)?, }) } @@ -131,17 +124,11 @@ impl UnresolvedAnthropic { .map(|(_, v, _)| v.clone()) .unwrap_or(StringOr::EnvVar("ANTHROPIC_API_KEY".to_string())); - let allowed_roles = properties.ensure_allowed_roles().unwrap_or(vec![ - StringOr::Value("system".to_string()), - StringOr::Value("user".to_string()), - StringOr::Value("assistant".to_string()), - ]); - - let default_role = properties.ensure_default_role(allowed_roles.as_slice(), 1); + let role_selection = properties.ensure_roles_selection(); let allowed_metadata = properties.ensure_allowed_metadata(); let supported_request_modes = properties.ensure_supported_request_modes(); let headers = properties.ensure_headers().unwrap_or_default(); - + let finish_reason_filter = properties.ensure_finish_reason_filter(); let (properties, errors) = properties.finalize(); if !errors.is_empty() { return Err(errors); @@ -150,12 +137,12 @@ impl UnresolvedAnthropic { Ok(Self { base_url, api_key, - allowed_roles, - default_role, + role_selection, allowed_metadata, supported_request_modes, headers, properties, + finish_reason_filter, }) } } diff --git a/engine/baml-lib/llm-client/src/clients/aws_bedrock.rs b/engine/baml-lib/llm-client/src/clients/aws_bedrock.rs index 34deada2a..1ad21a2bb 100644 --- a/engine/baml-lib/llm-client/src/clients/aws_bedrock.rs +++ b/engine/baml-lib/llm-client/src/clients/aws_bedrock.rs @@ -1,6 +1,6 @@ use std::collections::HashSet; -use crate::{AllowedRoleMetadata, SupportedRequestModes, UnresolvedAllowedRoleMetadata}; +use crate::{AllowedRoleMetadata, FinishReasonFilter, RolesSelection, SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedRolesSelection}; use anyhow::Result; use baml_types::{EvaluationContext, StringOr}; @@ -13,11 +13,11 @@ pub struct UnresolvedAwsBedrock { region: StringOr, access_key_id: StringOr, secret_access_key: StringOr, - allowed_roles: Vec, - default_role: Option, + role_selection: UnresolvedRolesSelection, allowed_role_metadata: UnresolvedAllowedRoleMetadata, supported_request_modes: SupportedRequestModes, inference_config: Option, + finish_reason_filter: UnresolvedFinishReasonFilter, } #[derive(Debug, Clone)] @@ -64,10 +64,22 @@ pub struct ResolvedAwsBedrock { pub access_key_id: Option, pub secret_access_key: Option, pub inference_config: Option, - pub allowed_roles: Vec, - pub default_role: String, + role_selection: RolesSelection, pub allowed_role_metadata: AllowedRoleMetadata, pub supported_request_modes: SupportedRequestModes, + pub finish_reason_filter: FinishReasonFilter, +} + +impl ResolvedAwsBedrock { + pub fn allowed_roles(&self) -> Vec { + self.role_selection.allowed_or_else(|| { + vec!["system".to_string(), "user".to_string(), "assistant".to_string()] + }) + } + + pub fn default_role(&self) -> String { + self.role_selection.default_or_else(|| "user".to_string()) + } } impl UnresolvedAwsBedrock { @@ -79,15 +91,7 @@ impl UnresolvedAwsBedrock { env_vars.extend(self.region.required_env_vars()); env_vars.extend(self.access_key_id.required_env_vars()); env_vars.extend(self.secret_access_key.required_env_vars()); - env_vars.extend( - self.allowed_roles - .iter() - .map(|r| r.required_env_vars()) - .flatten(), - ); - self.default_role - .as_ref() - .map(|r| env_vars.extend(r.required_env_vars())); + env_vars.extend(self.role_selection.required_env_vars()); env_vars.extend(self.allowed_role_metadata.required_env_vars()); env_vars.extend(self.supported_request_modes.required_env_vars()); self.inference_config @@ -101,32 +105,14 @@ impl UnresolvedAwsBedrock { return Err(anyhow::anyhow!("model must be provided")); }; - let allowed_roles = self - .allowed_roles - .iter() - .map(|role| role.resolve(ctx)) - .collect::>>()?; - - let Some(default_role) = self.default_role.as_ref() else { - return Err(anyhow::anyhow!("default_role must be provided")); - }; - let default_role = default_role.resolve(ctx)?; - - if !allowed_roles.contains(&default_role) { - return Err(anyhow::anyhow!( - "default_role must be in allowed_roles: {} not in {:?}", - default_role, - allowed_roles - )); - } + let role_selection = self.role_selection.resolve(ctx)?; Ok(ResolvedAwsBedrock { model: model.resolve(ctx)?, region: self.region.resolve(ctx).ok(), access_key_id: self.access_key_id.resolve(ctx).ok(), secret_access_key: self.secret_access_key.resolve(ctx).ok(), - allowed_roles, - default_role, + role_selection, allowed_role_metadata: self.allowed_role_metadata.resolve(ctx)?, supported_request_modes: self.supported_request_modes.clone(), inference_config: self @@ -134,6 +120,7 @@ impl UnresolvedAwsBedrock { .as_ref() .map(|c| c.resolve(ctx)) .transpose()?, + finish_reason_filter: self.finish_reason_filter.resolve(ctx)?, }) } @@ -176,12 +163,7 @@ impl UnresolvedAwsBedrock { .map(|(_, v, _)| v.clone()) .unwrap_or_else(|| baml_types::StringOr::EnvVar("AWS_SECRET_ACCESS_KEY".to_string())); - let allowed_roles = properties.ensure_allowed_roles().unwrap_or(vec![ - StringOr::Value("system".to_string()), - StringOr::Value("user".to_string()), - StringOr::Value("assistant".to_string()), - ]); - let default_role = properties.ensure_default_role(allowed_roles.as_slice(), 1); + let role_selection = properties.ensure_roles_selection(); let allowed_metadata = properties.ensure_allowed_metadata(); let supported_request_modes = properties.ensure_supported_request_modes(); @@ -242,6 +224,7 @@ impl UnresolvedAwsBedrock { } Some(inference_config) }; + let finish_reason_filter = properties.ensure_finish_reason_filter(); // TODO: Handle inference_configuration let errors = properties.finalize_empty(); @@ -254,11 +237,11 @@ impl UnresolvedAwsBedrock { region, access_key_id, secret_access_key, - allowed_roles, - default_role, + role_selection, allowed_role_metadata: allowed_metadata, supported_request_modes, inference_config, + finish_reason_filter, }) } } diff --git a/engine/baml-lib/llm-client/src/clients/google_ai.rs b/engine/baml-lib/llm-client/src/clients/google_ai.rs index ab8b34d42..de6f11ab7 100644 --- a/engine/baml-lib/llm-client/src/clients/google_ai.rs +++ b/engine/baml-lib/llm-client/src/clients/google_ai.rs @@ -2,7 +2,7 @@ use std::collections::HashSet; use anyhow::Result; use crate::{ - AllowedRoleMetadata, SupportedRequestModes, UnresolvedAllowedRoleMetadata, + AllowedRoleMetadata, FinishReasonFilter, RolesSelection, SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedRolesSelection }; use baml_types::{EvaluationContext, StringOr, UnresolvedValue}; @@ -15,11 +15,11 @@ pub struct UnresolvedGoogleAI { api_key: StringOr, base_url: UnresolvedUrl, headers: IndexMap, - allowed_roles: Vec, - default_role: Option, + role_selection: UnresolvedRolesSelection, model: Option, allowed_metadata: UnresolvedAllowedRoleMetadata, supported_request_modes: SupportedRequestModes, + finish_reason_filter: UnresolvedFinishReasonFilter, properties: IndexMap)>, } @@ -27,8 +27,7 @@ pub struct UnresolvedGoogleAI { impl UnresolvedGoogleAI { pub fn without_meta(&self) -> UnresolvedGoogleAI<()> { UnresolvedGoogleAI { - allowed_roles: self.allowed_roles.clone(), - default_role: self.default_role.clone(), + role_selection: self.role_selection.clone(), api_key: self.api_key.clone(), model: self.model.clone(), base_url: self.base_url.clone(), @@ -36,13 +35,13 @@ impl UnresolvedGoogleAI { allowed_metadata: self.allowed_metadata.clone(), supported_request_modes: self.supported_request_modes.clone(), properties: self.properties.iter().map(|(k, (_, v))| (k.clone(), ((), v.without_meta()))).collect::>(), + finish_reason_filter: self.finish_reason_filter.clone(), } } } pub struct ResolvedGoogleAI { - pub allowed_roles: Vec, - pub default_role: String, + role_selection: RolesSelection, pub api_key: String, pub model: String, pub base_url: String, @@ -51,6 +50,19 @@ pub struct ResolvedGoogleAI { pub supported_request_modes: SupportedRequestModes, pub properties: IndexMap, pub proxy_url: Option, + pub finish_reason_filter: FinishReasonFilter, +} + +impl ResolvedGoogleAI { + pub fn allowed_roles(&self) -> Vec { + self.role_selection.allowed_or_else(|| { + vec!["system".to_string(), "user".to_string(), "assistant".to_string()] + }) + } + + pub fn default_role(&self) -> String { + self.role_selection.default_or_else(|| "user".to_string()) + } } impl UnresolvedGoogleAI { @@ -59,8 +71,7 @@ impl UnresolvedGoogleAI { env_vars.extend(self.api_key.required_env_vars()); env_vars.extend(self.base_url.required_env_vars()); env_vars.extend(self.headers.values().map(|v| v.required_env_vars()).flatten()); - env_vars.extend(self.allowed_roles.iter().map(|r| r.required_env_vars()).flatten()); - self.default_role.as_ref().map(|r| env_vars.extend(r.required_env_vars())); + env_vars.extend(self.role_selection.required_env_vars()); self.model.as_ref().map(|m| env_vars.extend(m.required_env_vars())); env_vars.extend(self.allowed_metadata.required_env_vars()); env_vars.extend(self.supported_request_modes.required_env_vars()); @@ -70,20 +81,7 @@ impl UnresolvedGoogleAI { pub fn resolve(&self, ctx: &EvaluationContext<'_>) -> Result { let api_key = self.api_key.resolve(ctx)?; - let Some(default_role) = self.default_role.as_ref() else { - return Err(anyhow::anyhow!("default_role must be provided")); - }; - let default_role = default_role.resolve(ctx)?; - - let allowed_roles = self.allowed_roles.iter().map(|r| r.resolve(ctx)).collect::>>()?; - if !allowed_roles.contains(&default_role) { - return Err(anyhow::anyhow!( - "default_role must be in allowed_roles: {} not in {:?}", - default_role, - allowed_roles - )); - } - + let role_selection = self.role_selection.resolve(ctx)?; let model = self .model @@ -101,12 +99,11 @@ impl UnresolvedGoogleAI { .collect::>>()?; Ok(ResolvedGoogleAI { - default_role, + role_selection, api_key, model, base_url, headers, - allowed_roles, allowed_metadata: self.allowed_metadata.resolve(ctx)?, supported_request_modes: self.supported_request_modes.clone(), properties: self @@ -115,17 +112,12 @@ impl UnresolvedGoogleAI { .map(|(k, (_, v))| Ok((k.clone(), v.resolve_serde::(ctx)?))) .collect::>>()?, proxy_url: super::helpers::get_proxy_url(ctx), + finish_reason_filter: self.finish_reason_filter.resolve(ctx)?, }) } pub fn create_from(mut properties: PropertyHandler) -> Result>> { - let allowed_roles = properties.ensure_allowed_roles().unwrap_or(vec![ - StringOr::Value("system".to_string()), - StringOr::Value("user".to_string()), - StringOr::Value("assistant".to_string()), - ]); - let default_role = properties.ensure_default_role(allowed_roles.as_slice(), 1); - + let role_selection = properties.ensure_roles_selection(); let api_key = properties.ensure_api_key().map(|v| v.clone()).unwrap_or(StringOr::EnvVar("GOOGLE_API_KEY".to_string())); let model = properties.ensure_string("model", false).map(|(_, v, _)| v.clone()); @@ -135,7 +127,7 @@ impl UnresolvedGoogleAI { let allowed_metadata = properties.ensure_allowed_metadata(); let supported_request_modes = properties.ensure_supported_request_modes(); let headers = properties.ensure_headers().unwrap_or_default(); - + let finish_reason_filter = properties.ensure_finish_reason_filter(); let (properties, errors) = properties.finalize(); if !errors.is_empty() { @@ -143,8 +135,7 @@ impl UnresolvedGoogleAI { } Ok(Self { - allowed_roles, - default_role, + role_selection, api_key, model, base_url, @@ -152,6 +143,7 @@ impl UnresolvedGoogleAI { allowed_metadata, supported_request_modes, properties, + finish_reason_filter, }) } } diff --git a/engine/baml-lib/llm-client/src/clients/helpers.rs b/engine/baml-lib/llm-client/src/clients/helpers.rs index c45c5ea51..6e553009d 100644 --- a/engine/baml-lib/llm-client/src/clients/helpers.rs +++ b/engine/baml-lib/llm-client/src/clients/helpers.rs @@ -3,7 +3,7 @@ use std::{borrow::Cow, collections::HashSet}; use baml_types::{GetEnvVar, StringOr, UnresolvedValue}; use indexmap::IndexMap; -use crate::{SupportedRequestModes, UnresolvedAllowedRoleMetadata}; +use crate::{SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedRolesSelection}; #[derive(Debug, Clone)] pub struct UnresolvedUrl(StringOr); @@ -172,9 +172,13 @@ impl PropertyHandler { final_result } - pub fn ensure_allowed_roles(&mut self) -> Option> { + fn ensure_allowed_roles(&mut self) -> Option> { self.ensure_array("allowed_roles", false) - .map(|(_, value, _)| { + .map(|(_, value, value_span)| { + if value.is_empty() { + self.push_error("allowed_roles must not be empty", value_span); + } + value .into_iter() .filter_map(|v| match v.as_str() { @@ -194,10 +198,19 @@ impl PropertyHandler { }) } - pub fn ensure_default_role( + pub fn ensure_roles_selection(&mut self) -> UnresolvedRolesSelection { + let allowed_roles = self.ensure_allowed_roles(); + let default_role = self.ensure_default_role(allowed_roles.as_ref().unwrap_or(&vec![ + StringOr::Value("user".to_string()), + StringOr::Value("assistant".to_string()), + StringOr::Value("system".to_string()), + ])); + UnresolvedRolesSelection::new(allowed_roles, default_role) + } + + fn ensure_default_role( &mut self, allowed_roles: &[StringOr], - default_role_index: usize, ) -> Option { self.ensure_string("default_role", false) .and_then(|(_, value, span)| { @@ -211,7 +224,7 @@ impl PropertyHandler { .join(", "); self.push_error( format!( - "default_role must be one of {}. Got: {}", + "default_role must be one of {}. Got: {}. To support different default roles, add allowed_roles [\"user\", \"assistant\", \"system\", ...]", allowed_roles_str, value ), span, @@ -219,7 +232,6 @@ impl PropertyHandler { None } }) - .or_else(|| allowed_roles.get(default_role_index).cloned()) } pub fn ensure_api_key(&mut self) -> Option { @@ -248,6 +260,54 @@ impl PropertyHandler { } } + pub fn ensure_finish_reason_filter(&mut self) -> UnresolvedFinishReasonFilter { + let allow_list = self.ensure_array("finish_reason_allow_list", false); + let deny_list = self.ensure_array("finish_reason_deny_list", false); + + match (allow_list, deny_list) { + (Some(allow), Some(deny)) => { + self.push_error( + "finish_reason_allow_list and finish_reason_deny_list cannot be used together", + allow.0 + ); + self.push_error( + "finish_reason_allow_list and finish_reason_deny_list cannot be used together", + deny.0, + ); + UnresolvedFinishReasonFilter::All + }, + (Some((_, allow, _)), None) => { + UnresolvedFinishReasonFilter::AllowList(allow.into_iter().filter_map(|v| match v.as_str() { + Some(s) => Some(s.clone()), + None => { + self.push_error( + "values in finish_reason_allow_list must be strings.", + v.meta().clone(), + ); + None + } + }) + .collect() + ) + } + (None, Some((_, deny, _))) => { + UnresolvedFinishReasonFilter::DenyList(deny.into_iter().filter_map(|v| match v.to_str() { + Ok(s) => Some(s.0), + Err(other) => { + self.push_error( + "values in finish_reason_deny_list must be strings.", + other.meta().clone() + ); + None + } + }) + .collect() + ) + } + (None, None) => UnresolvedFinishReasonFilter::All, + } + } + pub fn ensure_any(&mut self, key: &str) -> Option<(Meta, UnresolvedValue)> { self.options.shift_remove(key) } diff --git a/engine/baml-lib/llm-client/src/clients/openai.rs b/engine/baml-lib/llm-client/src/clients/openai.rs index e9609e4cd..6619ea260 100644 --- a/engine/baml-lib/llm-client/src/clients/openai.rs +++ b/engine/baml-lib/llm-client/src/clients/openai.rs @@ -1,6 +1,6 @@ use std::collections::HashSet; -use crate::{AllowedRoleMetadata, SupportedRequestModes, UnresolvedAllowedRoleMetadata}; +use crate::{AllowedRoleMetadata, FinishReasonFilter, RolesSelection, SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedRolesSelection}; use anyhow::Result; use baml_types::{GetEnvVar, StringOr, UnresolvedValue}; @@ -12,13 +12,13 @@ use super::helpers::{Error, PropertyHandler, UnresolvedUrl}; pub struct UnresolvedOpenAI { base_url: Option>, api_key: Option, - allowed_roles: Vec, - default_role: Option, + role_selection: UnresolvedRolesSelection, allowed_role_metadata: UnresolvedAllowedRoleMetadata, supported_request_modes: SupportedRequestModes, headers: IndexMap, properties: IndexMap)>, query_params: IndexMap, + finish_reason_filter: UnresolvedFinishReasonFilter, } impl UnresolvedOpenAI { @@ -26,8 +26,7 @@ impl UnresolvedOpenAI { UnresolvedOpenAI { base_url: self.base_url.clone(), api_key: self.api_key.clone(), - allowed_roles: self.allowed_roles.clone(), - default_role: self.default_role.clone(), + role_selection: self.role_selection.clone(), allowed_role_metadata: self.allowed_role_metadata.clone(), supported_request_modes: self.supported_request_modes.clone(), headers: self @@ -45,6 +44,7 @@ impl UnresolvedOpenAI { .iter() .map(|(k, v)| (k.clone(), v.clone())) .collect(), + finish_reason_filter: self.finish_reason_filter.clone(), } } } @@ -52,14 +52,46 @@ impl UnresolvedOpenAI { pub struct ResolvedOpenAI { pub base_url: String, pub api_key: Option, - pub allowed_roles: Vec, - pub default_role: String, + role_selection: RolesSelection, pub allowed_metadata: AllowedRoleMetadata, - pub supported_request_modes: SupportedRequestModes, + supported_request_modes: SupportedRequestModes, pub headers: IndexMap, pub properties: IndexMap, pub query_params: IndexMap, pub proxy_url: Option, + pub finish_reason_filter: FinishReasonFilter, +} + +impl ResolvedOpenAI { + fn is_o1_model(&self) -> bool { + self.properties.get("model").is_some_and(|model| model.as_str().map(|s| s.starts_with("o1-")).unwrap_or(false)) + } + + pub fn supports_streaming(&self) -> bool { + match self.supported_request_modes.stream { + Some(v) => v, + None => !self.is_o1_model(), + } + } + + pub fn allowed_roles(&self) -> Vec { + self.role_selection.allowed_or_else(|| { + if self.is_o1_model() { + vec!["user".to_string(), "assistant".to_string()] + } else { + vec!["system".to_string(), "user".to_string(), "assistant".to_string()] + } + }) + } + + pub fn default_role(&self) -> String { + self.role_selection.default_or_else(|| { + // TODO: guard against empty allowed_roles + // The compiler should already guarantee that this is non-empty + self.allowed_roles().remove(0) + + }) + } } impl UnresolvedOpenAI { @@ -78,12 +110,7 @@ impl UnresolvedOpenAI { self.api_key .as_ref() .map(|key| env_vars.extend(key.required_env_vars())); - self.allowed_roles - .iter() - .for_each(|role| env_vars.extend(role.required_env_vars())); - self.default_role - .as_ref() - .map(|role| env_vars.extend(role.required_env_vars())); + env_vars.extend(self.role_selection.required_env_vars()); env_vars.extend(self.allowed_role_metadata.required_env_vars()); env_vars.extend(self.supported_request_modes.required_env_vars()); self.headers @@ -126,24 +153,7 @@ impl UnresolvedOpenAI { .map(|key| key.resolve(ctx)) .transpose()?; - let allowed_roles = self - .allowed_roles - .iter() - .map(|role| role.resolve(ctx)) - .collect::>>()?; - - let Some(default_role) = self.default_role.as_ref() else { - return Err(anyhow::anyhow!("default_role must be provided")); - }; - let default_role = default_role.resolve(ctx)?; - - if !allowed_roles.contains(&default_role) { - return Err(anyhow::anyhow!( - "default_role must be in allowed_roles: {} not in {:?}", - default_role, - allowed_roles - )); - } + let role_selection = self.role_selection.resolve(ctx)?; let headers = self .headers @@ -176,14 +186,14 @@ impl UnresolvedOpenAI { Ok(ResolvedOpenAI { base_url, api_key, - allowed_roles, - default_role, + role_selection, allowed_metadata: self.allowed_role_metadata.resolve(ctx)?, supported_request_modes: self.supported_request_modes.clone(), headers, properties, query_params, proxy_url: super::helpers::get_proxy_url(ctx), + finish_reason_filter: self.finish_reason_filter.resolve(ctx)?, }) } @@ -247,20 +257,19 @@ impl UnresolvedOpenAI { } }; - let api_key = Some( - properties - .ensure_api_key() - .map(|v| v.clone()) - .unwrap_or_else(|| StringOr::EnvVar("AZURE_OPENAI_API_KEY".to_string())), - ); + let api_key = properties + .ensure_api_key() + .map(|v| v.clone()) + .unwrap_or_else(|| StringOr::EnvVar("AZURE_OPENAI_API_KEY".to_string())); let mut query_params = IndexMap::new(); if let Some((_, v, _)) = properties.ensure_string("api_version", false) { query_params.insert("api-version".to_string(), v.clone()); } - let mut instance = Self::create_common(properties, base_url, api_key)?; + let mut instance = Self::create_common(properties, base_url, None)?; instance.query_params = query_params; + instance.headers.entry("api-key".to_string()).or_insert(api_key); Ok(instance) } @@ -282,7 +291,13 @@ impl UnresolvedOpenAI { let api_key = properties.ensure_api_key().map(|v| v.clone()); - Self::create_common(properties, Some(either::Either::Left(base_url)), api_key) + let mut instance = Self::create_common(properties, Some(either::Either::Left(base_url)), api_key)?; + // Ollama uses smaller models many of which prefer the user role + if instance.role_selection.default.is_none() { + instance.role_selection.default = Some(StringOr::Value("user".to_string())); + } + + Ok(instance) } fn create_common( @@ -290,16 +305,11 @@ impl UnresolvedOpenAI { base_url: Option>, api_key: Option, ) -> Result>> { - let allowed_roles = properties.ensure_allowed_roles().unwrap_or(vec![ - StringOr::Value("system".to_string()), - StringOr::Value("user".to_string()), - StringOr::Value("assistant".to_string()), - ]); - - let default_role = properties.ensure_default_role(allowed_roles.as_slice(), 1); + let role_selection = properties.ensure_roles_selection(); let allowed_metadata = properties.ensure_allowed_metadata(); let supported_request_modes = properties.ensure_supported_request_modes(); let headers = properties.ensure_headers().unwrap_or_default(); + let finish_reason_filter = properties.ensure_finish_reason_filter(); let (properties, errors) = properties.finalize(); if !errors.is_empty() { @@ -309,13 +319,13 @@ impl UnresolvedOpenAI { Ok(Self { base_url, api_key, - allowed_roles, - default_role, + role_selection, allowed_role_metadata: allowed_metadata, supported_request_modes, headers, properties, query_params: IndexMap::new(), + finish_reason_filter, }) } } diff --git a/engine/baml-lib/llm-client/src/clients/vertex.rs b/engine/baml-lib/llm-client/src/clients/vertex.rs index dc5eaf7e3..977b60301 100644 --- a/engine/baml-lib/llm-client/src/clients/vertex.rs +++ b/engine/baml-lib/llm-client/src/clients/vertex.rs @@ -1,6 +1,6 @@ use std::collections::HashSet; -use crate::{AllowedRoleMetadata, SupportedRequestModes, UnresolvedAllowedRoleMetadata}; +use crate::{AllowedRoleMetadata, FinishReasonFilter, RolesSelection, SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedRolesSelection}; use anyhow::{Context, Result}; use baml_types::{GetEnvVar, StringOr, UnresolvedValue}; @@ -123,10 +123,10 @@ pub struct UnresolvedVertex { authorization: UnresolvedServiceAccountDetails, model: StringOr, headers: IndexMap, - allowed_roles: Vec, - default_role: Option, + role_selection: UnresolvedRolesSelection, allowed_role_metadata: UnresolvedAllowedRoleMetadata, supported_request_modes: SupportedRequestModes, + finish_reason_filter: UnresolvedFinishReasonFilter, properties: IndexMap)>, } @@ -135,12 +135,24 @@ pub struct ResolvedVertex { pub authorization: ResolvedServiceAccountDetails, pub model: String, pub headers: IndexMap, - pub allowed_roles: Vec, - pub default_role: String, + role_selection: RolesSelection, pub allowed_metadata: AllowedRoleMetadata, pub supported_request_modes: SupportedRequestModes, pub properties: IndexMap, pub proxy_url: Option, + pub finish_reason_filter: FinishReasonFilter, +} + +impl ResolvedVertex { + pub fn allowed_roles(&self) -> Vec { + self.role_selection.allowed_or_else(|| { + vec!["system".to_string(), "user".to_string(), "assistant".to_string()] + }) + } + + pub fn default_role(&self) -> String { + self.role_selection.default_or_else(|| "user".to_string()) + } } impl UnresolvedVertex { @@ -161,15 +173,7 @@ impl UnresolvedVertex { .map(|v| v.required_env_vars()) .flatten(), ); - env_vars.extend( - self.allowed_roles - .iter() - .map(|r| r.required_env_vars()) - .flatten(), - ); - self.default_role - .as_ref() - .map(|r| env_vars.extend(r.required_env_vars())); + env_vars.extend(self.role_selection.required_env_vars()); env_vars.extend(self.allowed_role_metadata.required_env_vars()); env_vars.extend(self.supported_request_modes.required_env_vars()); env_vars.extend( @@ -189,8 +193,7 @@ impl UnresolvedVertex { authorization: self.authorization.without_meta(), model: self.model.clone(), headers: self.headers.clone(), - allowed_roles: self.allowed_roles.clone(), - default_role: self.default_role.clone(), + role_selection: self.role_selection.clone(), allowed_role_metadata: self.allowed_role_metadata.clone(), supported_request_modes: self.supported_request_modes.clone(), properties: self @@ -198,6 +201,7 @@ impl UnresolvedVertex { .iter() .map(|(k, (_, v))| (k.clone(), ((), v.without_meta()))) .collect(), + finish_reason_filter: self.finish_reason_filter.clone(), } } @@ -231,24 +235,7 @@ impl UnresolvedVertex { let model = self.model.resolve(ctx)?; - let allowed_roles = self - .allowed_roles - .iter() - .map(|role| role.resolve(ctx)) - .collect::>>()?; - - let Some(default_role) = self.default_role.as_ref() else { - return Err(anyhow::anyhow!("default_role must be provided")); - }; - let default_role = default_role.resolve(ctx)?; - - if !allowed_roles.contains(&default_role) { - return Err(anyhow::anyhow!( - "default_role must be in allowed_roles: {} not in {:?}", - default_role, - allowed_roles - )); - } + let role_selection = self.role_selection.resolve(ctx)?; let headers = self .headers @@ -261,8 +248,7 @@ impl UnresolvedVertex { authorization, model, headers, - allowed_roles, - default_role, + role_selection, allowed_metadata: self.allowed_role_metadata.resolve(ctx)?, supported_request_modes: self.supported_request_modes.clone(), properties: self @@ -271,6 +257,7 @@ impl UnresolvedVertex { .map(|(k, (_, v))| Ok((k.clone(), v.resolve_serde::(ctx)?))) .collect::>>()?, proxy_url: super::helpers::get_proxy_url(ctx), + finish_reason_filter: self.finish_reason_filter.resolve(ctx)?, }) } @@ -358,16 +345,12 @@ impl UnresolvedVertex { .ensure_string("project_id", false) .map(|(_, v, _)| v); - let allowed_roles = properties.ensure_allowed_roles().unwrap_or(vec![ - StringOr::Value("system".to_string()), - StringOr::Value("user".to_string()), - StringOr::Value("assistant".to_string()), - ]); - - let default_role = properties.ensure_default_role(allowed_roles.as_slice(), 1); + let role_selection = properties.ensure_roles_selection(); let allowed_metadata = properties.ensure_allowed_metadata(); let supported_request_modes = properties.ensure_supported_request_modes(); let headers = properties.ensure_headers().unwrap_or_default(); + let finish_reason_filter = properties.ensure_finish_reason_filter(); + let (properties, errors) = properties.finalize(); if !errors.is_empty() { return Err(errors); @@ -383,11 +366,11 @@ impl UnresolvedVertex { authorization, model, headers, - allowed_roles, - default_role, + role_selection, allowed_role_metadata: allowed_metadata, supported_request_modes, properties, + finish_reason_filter, }) } } diff --git a/engine/baml-lib/llm-client/src/clientspec.rs b/engine/baml-lib/llm-client/src/clientspec.rs index 862c6b2cc..d42219358 100644 --- a/engine/baml-lib/llm-client/src/clientspec.rs +++ b/engine/baml-lib/llm-client/src/clientspec.rs @@ -197,6 +197,163 @@ impl SupportedRequestModes { } } +#[derive(Clone, Debug)] +pub enum UnresolvedFinishReasonFilter { + All, + AllowList(HashSet), + DenyList(HashSet), +} + +#[derive(Clone, Debug)] +pub enum FinishReasonFilter { + All, + AllowList(HashSet), + DenyList(HashSet), +} + +impl UnresolvedFinishReasonFilter { + pub fn required_env_vars(&self) -> HashSet { + match self { + Self::AllowList(allow) => allow + .iter() + .map(|s| s.required_env_vars()) + .flatten() + .collect(), + Self::DenyList(deny) => deny + .iter() + .map(|s| s.required_env_vars()) + .flatten() + .collect(), + _ => HashSet::new(), + } + } + + pub fn resolve(&self, ctx: &impl GetEnvVar) -> Result { + match self { + Self::AllowList(allow) => Ok(FinishReasonFilter::AllowList( + allow + .iter() + .map(|s| s.resolve(ctx)) + .collect::>>()?, + )), + Self::DenyList(deny) => Ok(FinishReasonFilter::DenyList( + deny.iter() + .map(|s| s.resolve(ctx)) + .collect::>>()?, + )), + Self::All => Ok(FinishReasonFilter::All), + } + } +} + +impl FinishReasonFilter { + pub fn is_allowed(&self, reason: Option>) -> bool { + log::warn!( + "debug is_allowed: {:?} {}", + self, + reason + .as_ref() + .map(|r| r.as_ref().to_string()) + .unwrap_or("".into()) + ); + match self { + Self::AllowList(allow) => { + let Some(reason) = reason.map(|r| r.as_ref().to_string()) else { + return false; + }; + allow.contains(&reason) + } + Self::DenyList(deny) => { + let Some(reason) = reason.map(|r| r.as_ref().to_string()) else { + return true; + }; + !deny.contains(&reason) + } + Self::All => true, + } + } +} + +#[derive(Clone, Debug)] +pub(crate) struct UnresolvedRolesSelection { + pub allowed: Option>, + pub default: Option, +} + +impl UnresolvedRolesSelection { + pub fn new(allowed: Option>, default: Option) -> Self { + Self { allowed, default } + } + + pub fn required_env_vars(&self) -> HashSet { + let mut env_vars = HashSet::new(); + if let Some(allowed) = &self.allowed { + env_vars.extend(allowed.iter().map(|s| s.required_env_vars()).flatten()); + } + if let Some(default) = &self.default { + env_vars.extend(default.required_env_vars()); + } + env_vars + } + + pub fn resolve(&self, ctx: &impl GetEnvVar) -> Result { + let allowed = self + .allowed + .as_ref() + .map(|allowed| { + allowed + .iter() + .map(|s| s.resolve(ctx)) + .collect::>>() + }) + .transpose()?; + + let default = self + .default + .as_ref() + .map(|default| default.resolve(ctx)) + .transpose()?; + + match (&allowed, &default) { + (Some(allowed), Some(default)) => { + if !allowed.contains(&default) { + return Err(anyhow::anyhow!("default_role must be in allowed_roles: {}. Not found in {:?}", default, allowed)); + } + } + (None, Some(default)) => { + match default.as_str() { + "system" | "user" | "assistant" => {} + _ => return Err(anyhow::anyhow!("default_role must be one of 'system', 'user' or 'assistant': {}. Please specify \"allowed_roles\" if you want to use other custom default role.", default)), + } + } + _ => {} + } + Ok(RolesSelection { allowed, default }) + } +} + +#[derive(Clone, Debug)] +pub(crate) struct RolesSelection { + allowed: Option>, + default: Option, +} + +impl RolesSelection { + pub fn allowed_or_else(&self, f: impl FnOnce() -> Vec) -> Vec { + match self.allowed.as_ref() { + Some(allowed) => allowed.clone(), + None => f(), + } + } + + pub fn default_or_else(&self, f: impl FnOnce() -> String) -> String { + match self.default.as_ref() { + Some(default) => default.clone(), + None => f(), + } + } +} + #[derive(Clone, Debug)] pub enum UnresolvedAllowedRoleMetadata { Value(StringOr), diff --git a/engine/baml-runtime/Cargo.toml b/engine/baml-runtime/Cargo.toml index 5db9034e6..679d0c272 100644 --- a/engine/baml-runtime/Cargo.toml +++ b/engine/baml-runtime/Cargo.toml @@ -91,6 +91,7 @@ valuable = { version = "0.1.0", features = ["derive"] } tracing = { version = "0.1.40", features = ["valuable"] } tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter","valuable"] } thiserror = "2.0.1" +log-once = "0.4.1" [target.'cfg(target_arch = "wasm32")'.dependencies] @@ -141,7 +142,7 @@ which = "6.0.3" [features] -defaults = [] +defaults = ["skip-integ-tests"] internal = [] skip-integ-tests = [] diff --git a/engine/baml-runtime/src/cli/serve/error.rs b/engine/baml-runtime/src/cli/serve/error.rs index 236e4015f..cb30b7709 100644 --- a/engine/baml-runtime/src/cli/serve/error.rs +++ b/engine/baml-runtime/src/cli/serve/error.rs @@ -27,6 +27,13 @@ pub enum BamlError { raw_output: String, message: String, }, + #[serde(rename_all = "snake_case")] + FinishReasonError { + prompt: String, + raw_output: String, + message: String, + finish_reason: Option, + }, /// This is the only variant not documented at the aforementioned link: /// this is the catch-all for unclassified errors. #[serde(rename_all = "snake_case")] @@ -46,6 +53,17 @@ impl BamlError { raw_output: raw_output.to_string(), message: message.to_string(), }, + ExposedError::FinishReasonError { + prompt, + raw_output, + message, + finish_reason, + } => Self::FinishReasonError { + prompt: prompt.to_string(), + raw_output: raw_output.to_string(), + message: message.to_string(), + finish_reason: finish_reason.clone(), + }, } } else if let Some(er) = err.downcast_ref::() { Self::InvalidArgument { @@ -93,6 +111,7 @@ impl IntoResponse for BamlError { match &self { BamlError::InvalidArgument { .. } => StatusCode::BAD_REQUEST, BamlError::ClientError { .. } => StatusCode::BAD_GATEWAY, + BamlError::FinishReasonError { .. } => StatusCode::INTERNAL_SERVER_ERROR, // ??? - FIXME BamlError::ValidationFailure { .. } => StatusCode::INTERNAL_SERVER_ERROR, // ??? - FIXME BamlError::InternalError { .. } => StatusCode::INTERNAL_SERVER_ERROR, }, diff --git a/engine/baml-runtime/src/cli/serve/mod.rs b/engine/baml-runtime/src/cli/serve/mod.rs index 92b63e659..06e7961ea 100644 --- a/engine/baml-runtime/src/cli/serve/mod.rs +++ b/engine/baml-runtime/src/cli/serve/mod.rs @@ -226,7 +226,7 @@ impl Server { baml_api_key: Option<&XBamlApiKey>, ) -> AuthEnforcementMode { let Ok(password) = std::env::var("BAML_PASSWORD") else { - log::warn!("BAML_PASSWORD not set, skipping auth check"); + log_once::warn_once!("BAML_PASSWORD not set, skipping auth check"); return AuthEnforcementMode::NoEnforcement; }; diff --git a/engine/baml-runtime/src/errors.rs b/engine/baml-runtime/src/errors.rs index 324a7d26d..4dffec6d7 100644 --- a/engine/baml-runtime/src/errors.rs +++ b/engine/baml-runtime/src/errors.rs @@ -5,6 +5,12 @@ pub enum ExposedError { raw_output: String, message: String, }, + FinishReasonError { + prompt: String, + raw_output: String, + message: String, + finish_reason: Option, + }, } impl std::error::Error for ExposedError {} @@ -23,6 +29,21 @@ impl std::fmt::Display for ExposedError { message, prompt, raw_output ) } + ExposedError::FinishReasonError { + prompt, + raw_output, + message, + finish_reason, + } => { + write!( + f, + "Finish reason error: {}\nPrompt: {}\nRaw Response: {}\nFinish Reason: {}", + message, + prompt, + raw_output, + finish_reason.as_ref().map_or("", |f| f.as_str()) + ) + } } } } diff --git a/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs b/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs index c8c7dd74c..089833682 100644 --- a/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs +++ b/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs @@ -7,7 +7,9 @@ use web_time::Duration; use crate::{ internal::{ llm_client::{ - parsed_value_to_response, traits::{WithPrompt, WithSingleCallable}, LLMResponse, ResponseBamlValue + parsed_value_to_response, + traits::{WithClientProperties, WithPrompt, WithSingleCallable}, + LLMResponse, ResponseBamlValue, }, prompt_renderer::PromptRenderer, }, @@ -50,17 +52,36 @@ pub async fn orchestrate( }; let response = node.single_call(&ctx, &prompt).await; let parsed_response = match &response { - LLMResponse::Success(s) => Some(parse_fn(&s.content)), + LLMResponse::Success(s) => { + if !node + .finish_reason_filter() + .is_allowed(s.metadata.finish_reason.as_ref()) + { + Some(Err(anyhow::anyhow!(crate::errors::ExposedError::FinishReasonError { + prompt: prompt.to_string(), + raw_output: s.content.clone(), + message: "Finish reason not allowed".to_string(), + finish_reason: s.metadata.finish_reason.clone(), + }))) + } else { + Some(parse_fn(&s.content)) + } + }, _ => None, }; let sleep_duration = node.error_sleep_duration().cloned(); let (parsed_response, response_with_constraints) = match parsed_response { - Some(Ok(v)) => (Some(Ok(v.clone())), Some(Ok(parsed_value_to_response(&v)))), - Some(Err(e)) => (None, Some(Err(e))), - None => (None, None), - }; - results.push((node.scope, response, parsed_response, response_with_constraints)); + Some(Ok(v)) => (Some(Ok(v.clone())), Some(Ok(parsed_value_to_response(&v)))), + Some(Err(e)) => (None, Some(Err(e))), + None => (None, None), + }; + results.push(( + node.scope, + response, + parsed_response, + response_with_constraints, + )); // Currently, we break out of the loop if an LLM responded, even if we couldn't parse the result. if results diff --git a/engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs b/engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs index dc079b7fc..34e59d24e 100644 --- a/engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs +++ b/engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs @@ -222,3 +222,25 @@ impl WithStreamable for OrchestratorNode { self.provider.stream(ctx, prompt).await } } + +impl WithClientProperties for OrchestratorNode { + fn default_role(&self) -> String { + self.provider.default_role() + } + + fn allowed_metadata(&self) -> &internal_llm_client::AllowedRoleMetadata { + self.provider.allowed_metadata() + } + + fn supports_streaming(&self) -> bool { + self.provider.supports_streaming() + } + + fn finish_reason_filter(&self) -> &internal_llm_client::FinishReasonFilter { + self.provider.finish_reason_filter() + } + + fn allowed_roles(&self) -> Vec { + self.provider.allowed_roles() + } +} diff --git a/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs b/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs index 74750a8bd..b7e6fae52 100644 --- a/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs +++ b/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs @@ -8,7 +8,9 @@ use web_time::Duration; use crate::{ internal::{ llm_client::{ - parsed_value_to_response, traits::{WithPrompt, WithStreamable}, LLMErrorResponse, LLMResponse, ResponseBamlValue + parsed_value_to_response, + traits::{WithClientProperties, WithPrompt, WithStreamable}, + LLMErrorResponse, LLMResponse, ResponseBamlValue, }, prompt_renderer::PromptRenderer, }, @@ -66,7 +68,10 @@ where LLMResponse::Success(s) => { let parsed = partial_parse_fn(&s.content); let (parsed, response_value) = match parsed { - Ok(v) => (Some(Ok(v.clone())), Some(Ok(parsed_value_to_response(&v)))), + Ok(v) => ( + Some(Ok(v.clone())), + Some(Ok(parsed_value_to_response(&v))), + ), Err(e) => (None, Some(Err(e))), }; on_event(FunctionResult::new( @@ -99,7 +104,21 @@ where }; let parsed_response = match &final_response { - LLMResponse::Success(s) => Some(parse_fn(&s.content)), + LLMResponse::Success(s) => { + if !node + .finish_reason_filter() + .is_allowed(s.metadata.finish_reason.as_ref()) + { + Some(Err(anyhow::anyhow!(crate::errors::ExposedError::FinishReasonError { + prompt: s.prompt.to_string(), + raw_output: s.content.clone(), + message: "Finish reason not allowed".to_string(), + finish_reason: s.metadata.finish_reason.clone(), + }))) + } else { + Some(parse_fn(&s.content)) + } + }, _ => None, }; let (parsed_response, response_value) = match parsed_response { @@ -107,7 +126,7 @@ where Some(Err(e)) => (None, Some(Err(e))), None => (None, None), }; - // parsed_response.map(|r| r.and_then(|v| parsed_value_to_response(v))); + // parsed_response.map(|r| r.and_then(|v| parsed_value_to_response(v))); let sleep_duration = node.error_sleep_duration().cloned(); results.push((node.scope, final_response, parsed_response, response_value)); diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/anthropic/anthropic_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/anthropic/anthropic_client.rs index 837470a5a..105e320b9 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/anthropic/anthropic_client.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/anthropic/anthropic_client.rs @@ -80,6 +80,15 @@ impl WithClientProperties for AnthropicClient { fn supports_streaming(&self) -> bool { self.properties.supported_request_modes.stream.unwrap_or(true) } + fn finish_reason_filter(&self) -> &internal_llm_client::FinishReasonFilter { + &self.properties.finish_reason_filter + } + fn default_role(&self) -> String { + self.properties.default_role() + } + fn allowed_roles(&self) -> Vec { + self.properties.allowed_roles() + } } impl WithClient for AnthropicClient { @@ -251,13 +260,12 @@ impl WithStreamChat for AnthropicClient { impl AnthropicClient { pub fn dynamic_new(client: &ClientProperty, ctx: &RuntimeContext) -> Result { let properties = resolve_properties(&client.provider, &client.unresolved_options()?, ctx)?; - let default_role = properties.default_role.clone(); Ok(Self { name: client.name.clone(), context: RenderContext_Client { name: client.name.clone(), provider: client.provider.to_string(), - default_role, + default_role: properties.default_role(), }, features: ModelFeatures { chat: true, @@ -274,13 +282,12 @@ impl AnthropicClient { pub fn new(client: &ClientWalker, ctx: &RuntimeContext) -> Result { let properties = resolve_properties(&client.elem().provider, &client.options(), ctx)?; - let default_role = properties.default_role.clone(); Ok(Self { name: client.name().into(), context: RenderContext_Client { name: client.name().into(), provider: client.elem().provider.to_string(), - default_role, + default_role: properties.default_role(), }, features: ModelFeatures { chat: true, @@ -359,13 +366,6 @@ impl RequestBuilder for AnthropicClient { } impl WithChat for AnthropicClient { - fn chat_options(&self, _ctx: &RuntimeContext) -> Result { - Ok(internal_baml_jinja::ChatOptions::new( - self.properties.default_role.clone(), - None, - )) - } - async fn chat(&self, _ctx: &RuntimeContext, prompt: &Vec) -> LLMResponse { let (response, system_now, instant_now) = match make_parsed_request::< AnthropicMessageResponse, diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/aws/aws_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/aws/aws_client.rs index bb6ffd7d0..3a9448889 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/aws/aws_client.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/aws/aws_client.rs @@ -62,14 +62,13 @@ impl AwsClient { pub fn dynamic_new(client: &ClientProperty, ctx: &RuntimeContext) -> Result { let properties = resolve_properties(&client.provider, &client.unresolved_options()?, ctx)?; - let default_role = properties.default_role.clone(); Ok(Self { name: client.name.clone(), context: RenderContext_Client { name: client.name.clone(), provider: client.provider.to_string(), - default_role, + default_role: properties.default_role(), }, features: ModelFeatures { chat: true, @@ -88,14 +87,13 @@ impl AwsClient { pub fn new(client: &ClientWalker, ctx: &RuntimeContext) -> Result { let properties = resolve_properties(&client.elem().provider, &client.options(), ctx)?; - let default_role = properties.default_role.clone(); // clone before moving Ok(Self { name: client.name().into(), context: RenderContext_Client { name: client.name().into(), provider: client.elem().provider.to_string(), - default_role, + default_role: properties.default_role(), }, features: ModelFeatures { chat: true, @@ -272,6 +270,15 @@ impl WithClientProperties for AwsClient { fn supports_streaming(&self) -> bool { self.properties.supported_request_modes.stream.unwrap_or(true) } + fn finish_reason_filter(&self) -> &internal_llm_client::FinishReasonFilter { + &self.properties.finish_reason_filter + } + fn default_role(&self) -> String { + self.properties.default_role() + } + fn allowed_roles(&self) -> Vec { + self.properties.allowed_roles() + } } impl WithClient for AwsClient { @@ -581,13 +588,6 @@ impl AwsClient { } impl WithChat for AwsClient { - fn chat_options(&self, _ctx: &RuntimeContext) -> Result { - Ok(internal_baml_jinja::ChatOptions::new( - self.properties.default_role.clone(), - None, - )) - } - async fn chat( &self, _ctx: &RuntimeContext, diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/google/googleai_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/google/googleai_client.rs index 7f2d78c3d..21a09fce3 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/google/googleai_client.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/google/googleai_client.rs @@ -75,6 +75,15 @@ impl WithClientProperties for GoogleAIClient { .stream .unwrap_or(true) } + fn finish_reason_filter(&self) -> &internal_llm_client::FinishReasonFilter { + &self.properties.finish_reason_filter + } + fn default_role(&self) -> String { + self.properties.default_role() + } + fn allowed_roles(&self) -> Vec { + self.properties.allowed_roles() + } } impl WithClient for GoogleAIClient { @@ -158,7 +167,7 @@ impl SseResponseTrait for GoogleAIClient { }; if let Some(choice) = event.candidates.get(0) { - if let Some(content) = choice.content.parts.get(0) { + if let Some(content) = choice.content.as_ref().and_then(|c| c.parts.get(0)) { inner.content += &content.text; } match choice.finish_reason.as_ref() { @@ -198,13 +207,12 @@ impl WithStreamChat for GoogleAIClient { impl GoogleAIClient { pub fn new(client: &ClientWalker, ctx: &RuntimeContext) -> Result { let properties = resolve_properties(&client.elem().provider, &client.options(), ctx)?; - let default_role = properties.default_role.clone(); Ok(Self { name: client.name().into(), context: RenderContext_Client { name: client.name().into(), provider: client.elem().provider.to_string(), - default_role, + default_role: properties.default_role(), }, features: ModelFeatures { chat: true, @@ -225,14 +233,13 @@ impl GoogleAIClient { pub fn dynamic_new(client: &ClientProperty, ctx: &RuntimeContext) -> Result { let properties = resolve_properties(&client.provider, &client.unresolved_options()?, ctx)?; - let default_role = properties.default_role.clone(); Ok(Self { name: client.name.clone(), context: RenderContext_Client { name: client.name.clone(), provider: client.provider.to_string(), - default_role, + default_role: properties.default_role(), }, features: ModelFeatures { chat: true, @@ -305,13 +312,6 @@ impl RequestBuilder for GoogleAIClient { } impl WithChat for GoogleAIClient { - fn chat_options(&self, _ctx: &RuntimeContext) -> Result { - Ok(internal_baml_jinja::ChatOptions::new( - self.properties.default_role.clone(), - None, - )) - } - async fn chat(&self, _ctx: &RuntimeContext, prompt: &Vec) -> LLMResponse { //non-streaming, complete response is returned let (response, system_now, instant_now) = @@ -338,10 +338,23 @@ impl WithChat for GoogleAIClient { }); } + let Some(content) = response.candidates[0].content.as_ref() else { + return LLMResponse::LLMFailure(LLMErrorResponse { + client: self.context.name.to_string(), + model: None, + prompt: internal_baml_jinja::RenderedPrompt::Chat(prompt.clone()), + start_time: system_now, + request_options: self.properties.properties.clone(), + latency: instant_now.elapsed(), + message: "No content returned".to_string(), + code: ErrorCode::Other(200), + }); + }; + LLMResponse::Success(LLMCompleteResponse { client: self.context.name.to_string(), prompt: internal_baml_jinja::RenderedPrompt::Chat(prompt.clone()), - content: response.candidates[0].content.parts[0].text.clone(), + content: content.parts[0].text.clone(), start_time: system_now, latency: instant_now.elapsed(), request_options: self.properties.properties.clone(), diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/google/types.rs b/engine/baml-runtime/src/internal/llm_client/primitive/google/types.rs index d75858703..0875e1d72 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/google/types.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/google/types.rs @@ -221,7 +221,7 @@ pub enum HarmSeverity { #[serde(rename_all = "camelCase")] pub struct Candidate { pub index: Option, - pub content: Content, + pub content: Option, pub finish_reason: Option, pub safety_ratings: Option>, // pub citation_metadata: Option, diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/mod.rs b/engine/baml-runtime/src/internal/llm_client/primitive/mod.rs index a49bbe4b1..268e2c738 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/mod.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/mod.rs @@ -94,6 +94,15 @@ impl WithClientProperties for LLMPrimitiveProvider { fn supports_streaming(&self) -> bool { match_llm_provider!(self, supports_streaming) } + fn finish_reason_filter(&self) -> &internal_llm_client::FinishReasonFilter { + match_llm_provider!(self, finish_reason_filter) + } + fn default_role(&self) -> String { + match_llm_provider!(self, default_role) + } + fn allowed_roles(&self) -> Vec { + match_llm_provider!(self, allowed_roles) + } } impl TryFrom<(&ClientProperty, &RuntimeContext)> for LLMPrimitiveProvider { diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs index 6736413b5..7180465fb 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs @@ -6,7 +6,7 @@ use baml_types::{BamlMap, BamlMedia, BamlMediaContent, BamlMediaType}; use internal_baml_core::ir::ClientWalker; use internal_baml_jinja::{ChatMessagePart, RenderContext_Client, RenderedChatMessage}; use internal_llm_client::openai::ResolvedOpenAI; -use internal_llm_client::AllowedRoleMetadata; +use internal_llm_client::{AllowedRoleMetadata, FinishReasonFilter}; use serde_json::json; use crate::internal::llm_client::{ @@ -14,7 +14,7 @@ use crate::internal::llm_client::{ }; use super::properties; -use super::types::{ChatCompletionResponse, ChatCompletionResponseDelta, FinishReason}; +use super::types::{ChatCompletionResponse, ChatCompletionResponseDelta}; use crate::client_registry::ClientProperty; use crate::internal::llm_client::primitive::request::{ @@ -56,19 +56,21 @@ impl WithClientProperties for OpenAIClient { fn allowed_metadata(&self) -> &AllowedRoleMetadata { &self.properties.allowed_metadata } + + fn finish_reason_filter(&self) -> &FinishReasonFilter { + &self.properties.finish_reason_filter + } + + fn allowed_roles(&self) -> Vec { + self.properties.allowed_roles() + } + + fn default_role(&self) -> String { + self.properties.default_role() + } + fn supports_streaming(&self) -> bool { - match self.properties.supported_request_modes.stream { - Some(v) => v, - None => { - match self.properties.properties.get("model") { - Some(serde_json::Value::String(model)) => { - // OpenAI's streaming is not available for o1-* models - !model.starts_with("o1-") - } - _ => true, - } - } - } + self.properties.supports_streaming() } } @@ -155,13 +157,6 @@ impl WithNoCompletion for OpenAIClient {} // } impl WithChat for OpenAIClient { - fn chat_options(&self, _ctx: &RuntimeContext) -> Result { - Ok(internal_baml_jinja::ChatOptions::new( - self.properties.default_role.clone(), - None, - )) - } - async fn chat(&self, _ctx: &RuntimeContext, prompt: &Vec) -> LLMResponse { let (response, system_start, instant_start) = match make_parsed_request::( @@ -208,18 +203,11 @@ impl WithChat for OpenAIClient { request_options: self.properties.properties.clone(), metadata: LLMCompleteResponseMetadata { baml_is_complete: match response.choices.get(0) { - Some(c) => match c.finish_reason { - Some(FinishReason::Stop) => true, - _ => false, - }, + Some(c) => c.finish_reason.as_ref().is_some_and(|f| f == "stop"), None => false, }, finish_reason: match response.choices.get(0) { - Some(c) => match c.finish_reason { - Some(FinishReason::Stop) => Some(FinishReason::Stop.to_string()), - Some(other) => Some(other.to_string()), - _ => None, - }, + Some(c) => c.finish_reason.clone(), None => None, }, prompt_tokens: usage.map(|u| u.prompt_tokens), @@ -377,18 +365,8 @@ impl SseResponseTrait for OpenAIClient { inner.content += content.as_str(); } inner.model = event.model; - match choice.finish_reason.as_ref() { - Some(FinishReason::Stop) => { - inner.metadata.baml_is_complete = true; - inner.metadata.finish_reason = - Some(FinishReason::Stop.to_string()); - } - finish_reason => { - inner.metadata.baml_is_complete = false; - inner.metadata.finish_reason = - finish_reason.as_ref().map(|r| r.to_string()); - } - } + inner.metadata.finish_reason = choice.finish_reason.clone(); + inner.metadata.baml_is_complete = choice.finish_reason.as_ref().is_some_and(|s| s == "stop"); } inner.latency = instant_start.elapsed(); if let Some(usage) = event.usage.as_ref() { @@ -427,7 +405,7 @@ macro_rules! make_openai_client { context: RenderContext_Client { name: $client.name.clone(), provider: $client.provider.to_string(), - default_role: $properties.default_role.clone(), + default_role: $properties.default_role(), }, features: ModelFeatures { chat: true, @@ -448,7 +426,7 @@ macro_rules! make_openai_client { context: RenderContext_Client { name: $client.name().into(), provider: $client.elem().provider.to_string(), - default_role: $properties.default_role.clone(), + default_role: $properties.default_role(), }, features: ModelFeatures { chat: true, diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/openai/types.rs b/engine/baml-runtime/src/internal/llm_client/primitive/openai/types.rs index d457b2dea..377ca6402 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/openai/types.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/openai/types.rs @@ -28,7 +28,7 @@ pub struct ChatCompletionGeneric { #[derive(Debug, Deserialize, Clone, PartialEq)] pub struct CompletionChoice { - pub finish_reason: Option, + pub finish_reason: Option, pub index: u32, pub text: String, } @@ -42,7 +42,7 @@ pub struct ChatCompletionChoice { /// `length` if the maximum number of tokens specified in the request was reached, /// `content_filter` if content was omitted due to a flag from our content filters, /// `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function. - pub finish_reason: Option, + pub finish_reason: Option, /// Log probability information for the choice. pub logprobs: Option, } @@ -78,7 +78,7 @@ pub struct ChatCompletionResponseMessage { #[derive(Deserialize, Clone, Debug)] pub struct ChatCompletionChoiceDelta { pub index: u64, - pub finish_reason: Option, + pub finish_reason: Option, pub delta: ChatCompletionMessageDelta, } @@ -99,7 +99,7 @@ pub struct ChatCompletionMessageDelta { // pub function_call: Option, } -#[derive(Debug, Deserialize, Clone, Copy, Default, PartialEq)] +#[derive(Debug, Deserialize, Clone, Default, PartialEq)] #[serde(rename_all = "lowercase")] pub enum ChatCompletionMessageRole { System, @@ -110,19 +110,6 @@ pub enum ChatCompletionMessageRole { Function, } -#[derive(Debug, Deserialize, strum_macros::Display, Clone, Copy, PartialEq, Serialize)] -#[serde(rename_all = "snake_case")] -#[strum(serialize_all = "snake_case")] -pub enum FinishReason { - Stop, - Length, - ToolCalls, - ContentFilter, - FunctionCall, - #[serde(other)] - Unknown, -} - #[derive(Debug, Deserialize, Clone, PartialEq)] pub struct ChatChoiceLogprobs { /// A list of message content tokens with log probability information. diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/vertex/vertex_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/vertex/vertex_client.rs index be64f13d7..9789d0311 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/vertex/vertex_client.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/vertex/vertex_client.rs @@ -113,6 +113,15 @@ impl WithClientProperties for VertexClient { .stream .unwrap_or(true) } + fn finish_reason_filter(&self) -> &internal_llm_client::FinishReasonFilter { + &self.properties.finish_reason_filter + } + fn default_role(&self) -> String { + self.properties.default_role() + } + fn allowed_roles(&self) -> Vec { + self.properties.allowed_roles() + } } impl WithClient for VertexClient { @@ -240,13 +249,12 @@ impl WithStreamChat for VertexClient { impl VertexClient { pub fn new(client: &ClientWalker, ctx: &RuntimeContext) -> Result { let properties = resolve_properties(&client.elem().provider, client.options(), ctx)?; - let default_role = properties.default_role.clone(); Ok(Self { name: client.name().into(), context: RenderContext_Client { name: client.name().into(), provider: client.elem().provider.to_string(), - default_role, + default_role: properties.default_role(), }, features: ModelFeatures { chat: true, @@ -267,14 +275,13 @@ impl VertexClient { pub fn dynamic_new(client: &ClientProperty, ctx: &RuntimeContext) -> Result { let properties = resolve_properties(&client.provider, &client.unresolved_options()?, ctx)?; - let default_role = properties.default_role.clone(); Ok(Self { name: client.name.clone(), context: RenderContext_Client { name: client.name.clone(), provider: client.provider.to_string(), - default_role, + default_role: properties.default_role(), }, features: ModelFeatures { chat: true, @@ -394,13 +401,6 @@ impl RequestBuilder for VertexClient { } impl WithChat for VertexClient { - fn chat_options(&self, _ctx: &RuntimeContext) -> Result { - Ok(internal_baml_jinja::ChatOptions::new( - self.properties.default_role.clone(), - None, - )) - } - async fn chat(&self, _ctx: &RuntimeContext, prompt: &Vec) -> LLMResponse { //non-streaming, complete response is returned let (response, system_now, instant_now) = diff --git a/engine/baml-runtime/src/internal/llm_client/traits/chat.rs b/engine/baml-runtime/src/internal/llm_client/traits/chat.rs index 5c26ae016..7d826ee74 100644 --- a/engine/baml-runtime/src/internal/llm_client/traits/chat.rs +++ b/engine/baml-runtime/src/internal/llm_client/traits/chat.rs @@ -5,9 +5,20 @@ use crate::{internal::llm_client::LLMResponse, RuntimeContext}; use super::StreamResponse; -pub trait WithChat: Sync + Send { +pub trait WithChatOptions { fn chat_options(&self, ctx: &RuntimeContext) -> Result; +} +impl WithChatOptions for T +where + T: super::WithClientProperties, +{ + fn chat_options(&self, ctx: &RuntimeContext) -> Result { + Ok(ChatOptions::new(self.default_role(), Some(self.allowed_roles()))) + } +} + +pub trait WithChat: Sync + Send + WithChatOptions { #[allow(async_fn_in_trait)] async fn chat(&self, ctx: &RuntimeContext, prompt: &Vec) -> LLMResponse; } @@ -25,12 +36,8 @@ pub trait WithNoChat {} impl WithChat for T where - T: WithNoChat + Send + Sync, + T: WithNoChat + Send + Sync + WithChatOptions, { - fn chat_options(&self, _ctx: &RuntimeContext) -> Result { - anyhow::bail!("Chat prompts are not supported by this provider") - } - #[allow(async_fn_in_trait)] async fn chat(&self, _: &RuntimeContext, _: &Vec) -> LLMResponse { LLMResponse::InternalFailure("Chat prompts are not supported by this provider".to_string()) diff --git a/engine/baml-runtime/src/internal/llm_client/traits/mod.rs b/engine/baml-runtime/src/internal/llm_client/traits/mod.rs index 7bc160d44..65240668c 100644 --- a/engine/baml-runtime/src/internal/llm_client/traits/mod.rs +++ b/engine/baml-runtime/src/internal/llm_client/traits/mod.rs @@ -2,7 +2,7 @@ use std::{collections::HashMap, path::PathBuf, pin::Pin}; use anyhow::{Context, Result}; use aws_smithy_types::byte_stream::error::Error; -use internal_llm_client::AllowedRoleMetadata; +use internal_llm_client::{AllowedRoleMetadata, FinishReasonFilter}; use serde_json::{json, Map}; mod chat; @@ -36,6 +36,9 @@ pub trait WithRetryPolicy { pub trait WithClientProperties { fn allowed_metadata(&self) -> &AllowedRoleMetadata; fn supports_streaming(&self) -> bool; + fn finish_reason_filter(&self) -> &FinishReasonFilter; + fn default_role(&self) -> String; + fn allowed_roles(&self) -> Vec; } pub trait WithSingleCallable { @@ -143,10 +146,11 @@ pub trait WithRenderRawCurl { impl WithSingleCallable for T where - T: WithClient + WithChat + WithCompletion, + T: WithClient + WithChat + WithCompletion + WithClientProperties, { #[allow(async_fn_in_trait)] async fn single_call(&self, ctx: &RuntimeContext, prompt: &RenderedPrompt) -> LLMResponse { + log::warn!("debug single_call start: {:?}", prompt); if let RenderedPrompt::Chat(chat) = &prompt { match process_media_urls( self.model_features().resolve_media_urls, diff --git a/engine/baml-runtime/src/types/response.rs b/engine/baml-runtime/src/types/response.rs index f0581091d..b609f0cde 100644 --- a/engine/baml-runtime/src/types/response.rs +++ b/engine/baml-runtime/src/types/response.rs @@ -220,6 +220,7 @@ pub enum TestFailReason<'a> { TestUnspecified(anyhow::Error), TestLLMFailure(&'a LLMResponse), TestParseFailure(&'a anyhow::Error), + TestFinishReasonFailed(&'a anyhow::Error), TestConstraintsFailure { checks: Vec<(String, bool)>, failed_assert: Option, @@ -234,6 +235,9 @@ impl PartialEq for TestFailReason<'_> { (Self::TestParseFailure(a), Self::TestParseFailure(b)) => { a.to_string() == b.to_string() } + (Self::TestFinishReasonFailed(a), Self::TestFinishReasonFailed(b)) => { + a.to_string() == b.to_string() + } _ => false, } } @@ -267,9 +271,13 @@ impl TestResponse { } } } else { - TestStatus::Fail(TestFailReason::TestParseFailure( - parsed.as_ref().unwrap_err(), - )) + let err = parsed.as_ref().unwrap_err(); + match err.downcast_ref::() { + Some(ExposedError::FinishReasonError { .. }) => { + TestStatus::Fail(TestFailReason::TestFinishReasonFailed(&err)) + } + _ => TestStatus::Fail(TestFailReason::TestParseFailure(&err)), + } } } else { TestStatus::Fail(TestFailReason::TestLLMFailure(func_res.llm_response())) diff --git a/engine/baml-runtime/tests/harness.rs b/engine/baml-runtime/tests/harness.rs index d76899547..a1bf91b13 100644 --- a/engine/baml-runtime/tests/harness.rs +++ b/engine/baml-runtime/tests/harness.rs @@ -45,7 +45,7 @@ impl Harness { cmd.args(args.split_ascii_whitespace()); cmd.current_dir(&self.test_dir); // cmd.env("RUST_BACKTRACE", "1"); - cmd.env("BAML_LOG", "debug,jsonish=info"); + // cmd.env("BAML_LOG", "debug,jsonish=info"); Ok(cmd) } diff --git a/engine/baml-runtime/tests/test_cli.rs b/engine/baml-runtime/tests/test_cli.rs index b550a0c20..857e3903e 100644 --- a/engine/baml-runtime/tests/test_cli.rs +++ b/engine/baml-runtime/tests/test_cli.rs @@ -17,10 +17,7 @@ use serde_json::json; // Run this with cargo test --features internal // run the CLI using debug build using: engine/target/debug/baml-runtime dev -#[cfg(all( - not(feature = "skip-integ-tests"), - any(feature = "OPENAI_API_KEY", env = "OPENAI_API_KEY") -))] +#[cfg(not(feature = "skip-integ-tests"))] mod test_cli { use super::*; use pretty_assertions::assert_eq; diff --git a/engine/baml-runtime/tests/test_runtime.rs b/engine/baml-runtime/tests/test_runtime.rs index c3323def9..0fc564829 100644 --- a/engine/baml-runtime/tests/test_runtime.rs +++ b/engine/baml-runtime/tests/test_runtime.rs @@ -9,7 +9,7 @@ mod internal_tests { use baml_runtime::BamlRuntime; use std::sync::Once; - use baml_runtime::internal::llm_client::orchestrator::OrchestrationScope; + // use baml_runtime::internal::llm_client::orchestrator::OrchestrationScope; use baml_runtime::InternalRuntimeInterface; use baml_types::BamlValue; @@ -180,7 +180,7 @@ mod internal_tests { let ctx = runtime .create_ctx_manager(BamlValue::String("test".to_string()), None) - .create_ctx_with_default(missing_env_vars.iter()); + .create_ctx_with_default(); let params = runtime.get_test_params(function_name, test_name, &ctx, true)?; @@ -256,11 +256,9 @@ mod internal_tests { )?; log::info!("Runtime:"); - let missing_env_vars = runtime.internal().ir().required_env_vars(); - let ctx = runtime .create_ctx_manager(BamlValue::String("test".to_string()), None) - .create_ctx_with_default(missing_env_vars.iter()); + .create_ctx_with_default(); let params = runtime.get_test_params(function_name, test_name, &ctx, true)?; @@ -337,11 +335,9 @@ test ImageReceiptTest { "##, )?; - let missing_env_vars = runtime.internal().ir().required_env_vars(); - let ctx = runtime .create_ctx_manager(BamlValue::String("test".to_string()), None) - .create_ctx_with_default(missing_env_vars.iter()); + .create_ctx_with_default(); let function_name = "ExtractReceipt"; let test_name = "ImageReceiptTest"; @@ -420,11 +416,9 @@ test TestName { "##, )?; - let missing_env_vars = runtime.internal().ir().required_env_vars(); - let ctx = runtime .create_ctx_manager(BamlValue::String("test".to_string()), None) - .create_ctx_with_default(missing_env_vars.iter()); + .create_ctx_with_default(); let function_name = "Bot"; let test_name = "TestName"; @@ -490,11 +484,9 @@ test TestTree { "##, )?; - let missing_env_vars = runtime.internal().ir().required_env_vars(); - let ctx = runtime .create_ctx_manager(BamlValue::String("test".to_string()), None) - .create_ctx_with_default(missing_env_vars.iter()); + .create_ctx_with_default(); let function_name = "BuildTree"; let test_name = "TestTree"; diff --git a/engine/baml-schema-wasm/src/lib.rs b/engine/baml-schema-wasm/src/lib.rs index dca44e5a7..243f94848 100644 --- a/engine/baml-schema-wasm/src/lib.rs +++ b/engine/baml-schema-wasm/src/lib.rs @@ -1,4 +1,6 @@ +#[cfg(target_arch = "wasm32")] pub mod runtime_wasm; + use std::env; use wasm_bindgen::prelude::*; diff --git a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs index 3553f296a..aafa61af1 100644 --- a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs +++ b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs @@ -422,6 +422,7 @@ pub enum TestStatus { Passed, LLMFailure, ParseFailure, + FinishReasonFailed, ConstraintsFailed, AssertFailed, UnableToRun, @@ -597,6 +598,9 @@ impl WasmTestResponse { baml_runtime::TestFailReason::TestUnspecified(_) => TestStatus::UnableToRun, baml_runtime::TestFailReason::TestLLMFailure(_) => TestStatus::LLMFailure, baml_runtime::TestFailReason::TestParseFailure(_) => TestStatus::ParseFailure, + baml_runtime::TestFailReason::TestFinishReasonFailed(_) => { + TestStatus::FinishReasonFailed + } baml_runtime::TestFailReason::TestConstraintsFailure { failed_assert, .. } => { @@ -784,7 +788,20 @@ impl WithRenderError for baml_runtime::TestFailReason<'_> { match &self { baml_runtime::TestFailReason::TestUnspecified(e) => Some(format!("{e:#}")), baml_runtime::TestFailReason::TestLLMFailure(f) => f.render_error(), - baml_runtime::TestFailReason::TestParseFailure(e) => Some(format!("{e:#}")), + baml_runtime::TestFailReason::TestParseFailure(e) + | baml_runtime::TestFailReason::TestFinishReasonFailed(e) => { + match e.downcast_ref::() { + Some(exposed_error) => match exposed_error { + baml_runtime::errors::ExposedError::ValidationError { message, .. } => { + Some(message.clone()) + } + baml_runtime::errors::ExposedError::FinishReasonError { + message, .. + } => Some(message.clone()), + }, + None => Some(format!("{e:#}")), + } + } baml_runtime::TestFailReason::TestConstraintsFailure { checks, failed_assert, @@ -847,10 +864,10 @@ fn get_dummy_value( TypeValue::Bool => "true".to_string(), TypeValue::Null => "null".to_string(), TypeValue::Media(BamlMediaType::Image) => { - "{ url \"https://imgs.xkcd.com/comics/standards.png\"}".to_string() + "{ url \"https://imgs.xkcd.com/comics/standards.png\" }".to_string() } TypeValue::Media(BamlMediaType::Audio) => { - "{ url \"https://actions.google.com/sounds/v1/emergency/beeper_emergency_call.ogg\"}".to_string() + "{ url \"https://actions.google.com/sounds/v1/emergency/beeper_emergency_call.ogg\" }".to_string() } }; diff --git a/engine/baml-schema-wasm/tests/test_file_manager.rs b/engine/baml-schema-wasm/tests/test_file_manager.rs index 489a93431..dfed4997a 100644 --- a/engine/baml-schema-wasm/tests/test_file_manager.rs +++ b/engine/baml-schema-wasm/tests/test_file_manager.rs @@ -1,5 +1,6 @@ // Run from the baml-schema-wasm folder with: // wasm-pack test --node +#[cfg(target_arch = "wasm32")] #[cfg(test)] mod tests { use std::collections::HashMap; diff --git a/engine/language_client_python/python_src/baml_py/internal_monkeypatch.py b/engine/language_client_python/python_src/baml_py/internal_monkeypatch.py index 1e6ae231d..2743ab978 100644 --- a/engine/language_client_python/python_src/baml_py/internal_monkeypatch.py +++ b/engine/language_client_python/python_src/baml_py/internal_monkeypatch.py @@ -1,5 +1,5 @@ from .baml_py import BamlError - +from typing import Optional # Define the BamlValidationError exception with additional fields # note on custom exceptions https://github.com/PyO3/pyo3/issues/295 @@ -16,3 +16,17 @@ def __str__(self): def __repr__(self): return f"BamlValidationError(message={self.message}, raw_output={self.raw_output}, prompt={self.prompt})" + +class BamlClientFinishReasonError(BamlError): + def __init__(self, prompt: str, message: str, raw_output: str, finish_reason: Optional[str]): + super().__init__(message) + self.prompt = prompt + self.message = message + self.raw_output = raw_output + self.finish_reason = finish_reason + + def __str__(self): + return f"BamlClientFinishReasonError(message={self.message}, raw_output={self.raw_output}, prompt={self.prompt}, finish_reason={self.finish_reason})" + + def __repr__(self): + return f"BamlClientFinishReasonError(message={self.message}, raw_output={self.raw_output}, prompt={self.prompt}, finish_reason={self.finish_reason})" diff --git a/engine/language_client_python/src/errors.rs b/engine/language_client_python/src/errors.rs index b07a26f7b..de2c33fd2 100644 --- a/engine/language_client_python/src/errors.rs +++ b/engine/language_client_python/src/errors.rs @@ -25,6 +25,17 @@ fn raise_baml_validation_error(prompt: String, message: String, raw_output: Stri }) } +#[allow(non_snake_case)] +fn raise_baml_client_finish_reason_error(prompt: String, raw_output: String, message: String, finish_reason: Option) -> PyErr { + Python::with_gil(|py| { + let internal_monkeypatch = py.import("baml_py.internal_monkeypatch").unwrap(); + let exception = internal_monkeypatch.getattr("BamlClientFinishReasonError").unwrap(); + let args = (prompt, message, raw_output, finish_reason); + let inst = exception.call1(args).unwrap(); + PyErr::from_value(inst) + }) +} + /// Defines the errors module with the BamlValidationError exception. /// IIRC the name of this function is the name of the module that pyo3 generates (errors.py) #[pymodule] @@ -64,6 +75,14 @@ impl BamlError { // If not, you may need to adjust this part based on the actual structure of ValidationError raise_baml_validation_error(prompt.clone(), message.clone(), raw_output.clone()) } + ExposedError::FinishReasonError { + prompt, + raw_output, + message, + finish_reason, + } => { + raise_baml_client_finish_reason_error(prompt.clone(), raw_output.clone(), message.clone(), finish_reason.clone()) + } } } else if let Some(er) = err.downcast_ref::() { PyErr::new::(format!("Invalid argument: {}", er)) diff --git a/engine/language_client_typescript/src/errors.rs b/engine/language_client_typescript/src/errors.rs index 4af13f6ea..ce9ef1821 100644 --- a/engine/language_client_typescript/src/errors.rs +++ b/engine/language_client_typescript/src/errors.rs @@ -20,6 +20,12 @@ pub fn from_anyhow_error(err: anyhow::Error) -> napi::Error { message, raw_output: raw_response, } => throw_baml_validation_error(prompt, raw_response, message), + ExposedError::FinishReasonError { + prompt, + message, + raw_output: raw_response, + finish_reason, + } => throw_baml_client_finish_reason_error(prompt, raw_response, message, finish_reason), } } else if let Some(er) = err.downcast_ref::() { invalid_argument_error(&format!("{}", er)) @@ -79,3 +85,14 @@ pub fn throw_baml_validation_error(prompt: &str, raw_output: &str, message: &str }); napi::Error::new(napi::Status::GenericFailure, error_json.to_string()) } + +pub fn throw_baml_client_finish_reason_error(prompt: &str, raw_output: &str, message: &str, finish_reason: Option<&str>) -> napi::Error { + let error_json = serde_json::json!({ + "type": "BamlClientFinishReasonError", + "prompt": prompt, + "raw_output": raw_output, + "message": format!("BamlClientFinishReasonError: {}", message), + "finish_reason": finish_reason, + }); + napi::Error::new(napi::Status::GenericFailure, error_json.to_string()) +} diff --git a/engine/language_client_typescript/typescript_src/index.ts b/engine/language_client_typescript/typescript_src/index.ts index 27bd94611..d0aff086d 100644 --- a/engine/language_client_typescript/typescript_src/index.ts +++ b/engine/language_client_typescript/typescript_src/index.ts @@ -8,62 +8,119 @@ export { invoke_runtime_cli, ClientRegistry, BamlLogEvent, -} from './native' -export { BamlStream } from './stream' -export { BamlCtxManager } from './async_context_vars' +} from "./native"; +export { BamlStream } from "./stream"; +export { BamlCtxManager } from "./async_context_vars"; -export class BamlValidationError extends Error { - prompt: string - raw_output: string +export class BamlClientFinishReasonError extends Error { + prompt: string; + raw_output: string; constructor(prompt: string, raw_output: string, message: string) { - super(message) - this.name = 'BamlValidationError' - this.prompt = prompt - this.raw_output = raw_output + super(message); + this.name = "BamlClientFinishReasonError"; + this.prompt = prompt; + this.raw_output = raw_output; - Object.setPrototypeOf(this, BamlValidationError.prototype) + Object.setPrototypeOf(this, BamlClientFinishReasonError.prototype); } - static from(error: Error): BamlValidationError | Error { - if (error.message.includes('BamlValidationError')) { + toJSON(): string { + return JSON.stringify( + { + name: this.name, + message: this.message, + raw_output: this.raw_output, + prompt: this.prompt, + }, + null, + 2 + ); + } + + static from(error: Error): BamlClientFinishReasonError | undefined { + if (error.message.includes("BamlClientFinishReasonError")) { try { - const errorData = JSON.parse(error.message) - if (errorData.type === 'BamlValidationError') { - return new BamlValidationError( - errorData.prompt || '', - errorData.raw_output || '', - errorData.message || error.message, - ) + const errorData = JSON.parse(error.message); + if (errorData.type === "BamlClientFinishReasonError") { + return new BamlClientFinishReasonError( + errorData.prompt || "", + errorData.raw_output || "", + errorData.message || error.message + ); } else { - console.warn('Not a BamlValidationError:', error) + console.warn("Not a BamlClientFinishReasonError:", error); } } catch (parseError) { // If JSON parsing fails, fall back to the original error - console.warn('Failed to parse BamlValidationError:', parseError) + console.warn("Failed to parse BamlClientFinishReasonError:", parseError); } } + return undefined; + } +} + +export class BamlValidationError extends Error { + prompt: string; + raw_output: string; + + constructor(prompt: string, raw_output: string, message: string) { + super(message); + this.name = "BamlValidationError"; + this.prompt = prompt; + this.raw_output = raw_output; - // If it's not a BamlValidationError or parsing failed, return the original error - return error + Object.setPrototypeOf(this, BamlValidationError.prototype); } toJSON(): string { return JSON.stringify( { + name: this.name, message: this.message, raw_output: this.raw_output, prompt: this.prompt, }, null, - 2, - ) + 2 + ); + } + + static from(error: Error): BamlValidationError | undefined { + if (error.message.includes("BamlValidationError")) { + try { + const errorData = JSON.parse(error.message); + if (errorData.type === "BamlValidationError") { + return new BamlValidationError( + errorData.prompt || "", + errorData.raw_output || "", + errorData.message || error.message + ); + } + } catch (parseError) { + console.warn("Failed to parse BamlValidationError:", parseError); + } + } + return undefined; } } // Helper function to safely create a BamlValidationError -export function createBamlValidationError(error: Error): BamlValidationError | Error { - return BamlValidationError.from(error) +export function createBamlValidationError( + error: Error +): BamlValidationError | BamlClientFinishReasonError | Error { + const bamlValidationError = BamlValidationError.from(error); + if (bamlValidationError) { + return bamlValidationError; + } + + const bamlClientFinishReasonError = BamlClientFinishReasonError.from(error); + if (bamlClientFinishReasonError) { + return bamlClientFinishReasonError; + } + + // otherwise return the original error + return error; } // No need for a separate throwBamlValidationError function in TypeScript diff --git a/typescript/playground-common/src/baml_wasm_web/test_uis/testHooks.ts b/typescript/playground-common/src/baml_wasm_web/test_uis/testHooks.ts index a534e7888..e9c202661 100644 --- a/typescript/playground-common/src/baml_wasm_web/test_uis/testHooks.ts +++ b/typescript/playground-common/src/baml_wasm_web/test_uis/testHooks.ts @@ -10,7 +10,13 @@ export const showTestsAtom = atom(false) export const showClientGraphAtom = atom(false) export type TestStatusType = 'queued' | 'running' | 'done' | 'error' -export type DoneTestStatusType = 'passed' | 'llm_failed' | 'parse_failed' | 'constraints_failed' | 'error' +export type DoneTestStatusType = + | 'passed' + | 'llm_failed' + | 'finish_reason_failed' + | 'parse_failed' + | 'constraints_failed' + | 'error' export type TestState = | { status: 'queued' @@ -37,12 +43,26 @@ export const testStatusAtom = atomFamily( (a, b) => a === b, ) export const runningTestsAtom = atom([]) + +// Match the Rust enum +// engine/baml-schema-wasm/src/runtime_wasm/mod.rs +enum RustTestStatus { + Passed, + LLMFailure, + ParseFailure, + FinishReasonFailed, + ConstraintsFailed, + AssertFailed, + UnableToRun, +} + export const statusCountAtom = atom({ queued: 0, running: 0, done: { passed: 0, llm_failed: 0, + finish_reason_failed: 0, parse_failed: 0, constraints_failed: 0, error: 0, @@ -135,6 +155,7 @@ export const useRunHooks = () => { done: { passed: 0, llm_failed: 0, + finish_reason_failed: 0, parse_failed: 0, constraints_failed: 0, error: 0, @@ -189,15 +210,17 @@ export const useRunHooks = () => { const { res, elapsed } = result.value // console.log('result', i, result.value.res.llm_response(), 'batch[i]', batch[i]) - let status: Number = res.status() + let status: RustTestStatus = res.status() as unknown as RustTestStatus let response_status: DoneTestStatusType = 'error' - if (status === 0) { + if (status === RustTestStatus.Passed) { response_status = 'passed' - } else if (status === 1) { + } else if (status === RustTestStatus.LLMFailure) { response_status = 'llm_failed' - } else if (status === 2) { + } else if (status === RustTestStatus.ParseFailure) { response_status = 'parse_failed' - } else if (status === 3 || status === 4) { + } else if (status === RustTestStatus.FinishReasonFailed) { + response_status = 'finish_reason_failed' + } else if (status === RustTestStatus.ConstraintsFailed || status === RustTestStatus.AssertFailed) { response_status = 'constraints_failed' } else { response_status = 'error' diff --git a/typescript/playground-common/src/baml_wasm_web/test_uis/test_result.tsx b/typescript/playground-common/src/baml_wasm_web/test_uis/test_result.tsx index bda2d47a8..7570b1ad2 100644 --- a/typescript/playground-common/src/baml_wasm_web/test_uis/test_result.tsx +++ b/typescript/playground-common/src/baml_wasm_web/test_uis/test_result.tsx @@ -58,6 +58,8 @@ const TestStatusMessage: React.FC<{ testStatus: DoneTestStatusType }> = ({ testS return
LLM Failed
case 'parse_failed': return
Parse Failed
+ case 'finish_reason_failed': + return
Finish Reason Failed
case 'constraints_failed': return
Constraints Failed
case 'error': @@ -101,9 +103,25 @@ const TestStatusIcon: React.FC<{ ) } -type FilterValues = 'queued' | 'running' | 'error' | 'llm_failed' | 'parse_failed' | 'constraints_failed' | 'passed' +type FilterValues = + | 'queued' + | 'running' + | 'error' + | 'llm_failed' + | 'parse_failed' + | 'constraints_failed' + | 'passed' + | 'finish_reason_failed' const filterAtom = atom( - new Set(['running', 'error', 'llm_failed', 'parse_failed', 'constraints_failed', 'passed']), + new Set([ + 'running', + 'error', + 'llm_failed', + 'parse_failed', + 'constraints_failed', + 'passed', + 'finish_reason_failed', + ]), ) const checkFilter = (filter: Set, status: TestStatusType, test_status?: DoneTestStatusType) => { @@ -150,6 +168,17 @@ const ParsedTestResult: React.FC<{ doneStatus: string; parsed?: WasmParsedTestRe } }, [parsed, hasClosedIntroToChecksDialog, setShowIntroToChecksDialog]) + if (doneStatus === 'finish_reason_failed') { + return ( +
+
Pre-parse Error
+
+ {failure &&
{failure}
} +
+
+ ) + } + if (doneStatus === 'parse_failed' || parsed !== undefined) { return (