Skip to content

Commit

Permalink
Fix azure client
Browse files Browse the repository at this point in the history
Add new client paramters: allowed_roles, default_role, finish_reason_allow_list, finish_reason_deny_list

TODO: add more tests, but screenshots show it working
  • Loading branch information
hellovai committed Dec 3, 2024
1 parent 40edbc2 commit 8b20396
Show file tree
Hide file tree
Showing 38 changed files with 870 additions and 411 deletions.
10 changes: 10 additions & 0 deletions engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

67 changes: 27 additions & 40 deletions engine/baml-lib/llm-client/src/clients/anthropic.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::HashSet;

use crate::{AllowedRoleMetadata, SupportedRequestModes, UnresolvedAllowedRoleMetadata};
use crate::{AllowedRoleMetadata, FinishReasonFilter, RolesSelection, SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedRolesSelection};
use anyhow::Result;

use baml_types::{EvaluationContext, StringOr, UnresolvedValue};
Expand All @@ -12,48 +12,60 @@ use super::helpers::{Error, PropertyHandler, UnresolvedUrl};
pub struct UnresolvedAnthropic<Meta> {
base_url: UnresolvedUrl,
api_key: StringOr,
allowed_roles: Vec<StringOr>,
default_role: Option<StringOr>,
role_selection: UnresolvedRolesSelection,
allowed_metadata: UnresolvedAllowedRoleMetadata,
supported_request_modes: SupportedRequestModes,
headers: IndexMap<String, StringOr>,
properties: IndexMap<String, (Meta, UnresolvedValue<Meta>)>,
finish_reason_filter: UnresolvedFinishReasonFilter,
}

impl<Meta> UnresolvedAnthropic<Meta> {
pub fn without_meta(&self) -> UnresolvedAnthropic<()> {
UnresolvedAnthropic {
base_url: self.base_url.clone(),
api_key: self.api_key.clone(),
allowed_roles: self.allowed_roles.clone(),
default_role: self.default_role.clone(),
role_selection: self.role_selection.clone(),
allowed_metadata: self.allowed_metadata.clone(),
supported_request_modes: self.supported_request_modes.clone(),
headers: self.headers.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
properties: self.properties.iter().map(|(k, (_, v))| (k.clone(), ((), v.without_meta()))).collect(),
finish_reason_filter: self.finish_reason_filter.clone(),
}
}
}

pub struct ResolvedAnthropic {
pub base_url: String,
pub api_key: String,
pub allowed_roles: Vec<String>,
pub default_role: String,
role_selection: RolesSelection,
pub allowed_metadata: AllowedRoleMetadata,
pub supported_request_modes: SupportedRequestModes,
pub headers: IndexMap<String, String>,
pub properties: IndexMap<String, serde_json::Value>,
pub proxy_url: Option<String>,
pub finish_reason_filter: FinishReasonFilter,
}

impl ResolvedAnthropic {
pub fn allowed_roles(&self) -> Vec<String> {
self.role_selection.allowed_or_else(|| {
vec!["system".to_string(), "user".to_string(), "assistant".to_string()]
})
}

pub fn default_role(&self) -> String {
self.role_selection.default_or_else(|| "user".to_string())
}
}


impl<Meta: Clone> UnresolvedAnthropic<Meta> {
pub fn required_env_vars(&self) -> HashSet<String> {
let mut env_vars = HashSet::new();
env_vars.extend(self.base_url.required_env_vars());
env_vars.extend(self.api_key.required_env_vars());
env_vars.extend(self.allowed_roles.iter().map(|r| r.required_env_vars()).flatten());
self.default_role.as_ref().map(|r| env_vars.extend(r.required_env_vars()));
env_vars.extend(self.role_selection.required_env_vars());
env_vars.extend(self.allowed_metadata.required_env_vars());
env_vars.extend(self.supported_request_modes.required_env_vars());
env_vars.extend(self.headers.values().map(|v| v.required_env_vars()).flatten());
Expand All @@ -63,25 +75,6 @@ impl<Meta: Clone> UnresolvedAnthropic<Meta> {
}

pub fn resolve(&self, ctx: &EvaluationContext<'_>) -> Result<ResolvedAnthropic> {
let allowed_roles = self
.allowed_roles
.iter()
.map(|role| role.resolve(ctx))
.collect::<Result<Vec<_>>>()?;

let Some(default_role) = self.default_role.as_ref() else {
return Err(anyhow::anyhow!("default_role must be provided"));
};
let default_role = default_role.resolve(ctx)?;

if !allowed_roles.contains(&default_role) {
return Err(anyhow::anyhow!(
"default_role must be in allowed_roles: {} not in {:?}",
default_role,
allowed_roles
));
}

let base_url = self.base_url.resolve(ctx)?;

let mut headers = self
Expand Down Expand Up @@ -112,13 +105,13 @@ impl<Meta: Clone> UnresolvedAnthropic<Meta> {
Ok(ResolvedAnthropic {
base_url,
api_key: self.api_key.resolve(ctx)?,
allowed_roles,
default_role,
role_selection: self.role_selection.resolve(ctx)?,
allowed_metadata: self.allowed_metadata.resolve(ctx)?,
supported_request_modes: self.supported_request_modes.clone(),
headers,
properties,
proxy_url: super::helpers::get_proxy_url(ctx),
finish_reason_filter: self.finish_reason_filter.resolve(ctx)?,
})
}

Expand All @@ -131,17 +124,11 @@ impl<Meta: Clone> UnresolvedAnthropic<Meta> {
.map(|(_, v, _)| v.clone())
.unwrap_or(StringOr::EnvVar("ANTHROPIC_API_KEY".to_string()));

let allowed_roles = properties.ensure_allowed_roles().unwrap_or(vec![
StringOr::Value("system".to_string()),
StringOr::Value("user".to_string()),
StringOr::Value("assistant".to_string()),
]);

let default_role = properties.ensure_default_role(allowed_roles.as_slice(), 1);
let role_selection = properties.ensure_roles_selection();
let allowed_metadata = properties.ensure_allowed_metadata();
let supported_request_modes = properties.ensure_supported_request_modes();
let headers = properties.ensure_headers().unwrap_or_default();

let finish_reason_filter = properties.ensure_finish_reason_filter();
let (properties, errors) = properties.finalize();
if !errors.is_empty() {
return Err(errors);
Expand All @@ -150,12 +137,12 @@ impl<Meta: Clone> UnresolvedAnthropic<Meta> {
Ok(Self {
base_url,
api_key,
allowed_roles,
default_role,
role_selection,
allowed_metadata,
supported_request_modes,
headers,
properties,
finish_reason_filter,
})
}
}
67 changes: 25 additions & 42 deletions engine/baml-lib/llm-client/src/clients/aws_bedrock.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::HashSet;

use crate::{AllowedRoleMetadata, SupportedRequestModes, UnresolvedAllowedRoleMetadata};
use crate::{AllowedRoleMetadata, FinishReasonFilter, RolesSelection, SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedRolesSelection};
use anyhow::Result;

use baml_types::{EvaluationContext, StringOr};
Expand All @@ -13,11 +13,11 @@ pub struct UnresolvedAwsBedrock {
region: StringOr,
access_key_id: StringOr,
secret_access_key: StringOr,
allowed_roles: Vec<StringOr>,
default_role: Option<StringOr>,
role_selection: UnresolvedRolesSelection,
allowed_role_metadata: UnresolvedAllowedRoleMetadata,
supported_request_modes: SupportedRequestModes,
inference_config: Option<UnresolvedInferenceConfiguration>,
finish_reason_filter: UnresolvedFinishReasonFilter,
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -64,10 +64,22 @@ pub struct ResolvedAwsBedrock {
pub access_key_id: Option<String>,
pub secret_access_key: Option<String>,
pub inference_config: Option<InferenceConfiguration>,
pub allowed_roles: Vec<String>,
pub default_role: String,
role_selection: RolesSelection,
pub allowed_role_metadata: AllowedRoleMetadata,
pub supported_request_modes: SupportedRequestModes,
pub finish_reason_filter: FinishReasonFilter,
}

impl ResolvedAwsBedrock {
pub fn allowed_roles(&self) -> Vec<String> {
self.role_selection.allowed_or_else(|| {
vec!["system".to_string(), "user".to_string(), "assistant".to_string()]
})
}

pub fn default_role(&self) -> String {
self.role_selection.default_or_else(|| "user".to_string())
}
}

impl UnresolvedAwsBedrock {
Expand All @@ -79,15 +91,7 @@ impl UnresolvedAwsBedrock {
env_vars.extend(self.region.required_env_vars());
env_vars.extend(self.access_key_id.required_env_vars());
env_vars.extend(self.secret_access_key.required_env_vars());
env_vars.extend(
self.allowed_roles
.iter()
.map(|r| r.required_env_vars())
.flatten(),
);
self.default_role
.as_ref()
.map(|r| env_vars.extend(r.required_env_vars()));
env_vars.extend(self.role_selection.required_env_vars());
env_vars.extend(self.allowed_role_metadata.required_env_vars());
env_vars.extend(self.supported_request_modes.required_env_vars());
self.inference_config
Expand All @@ -101,39 +105,22 @@ impl UnresolvedAwsBedrock {
return Err(anyhow::anyhow!("model must be provided"));
};

let allowed_roles = self
.allowed_roles
.iter()
.map(|role| role.resolve(ctx))
.collect::<Result<Vec<_>>>()?;

let Some(default_role) = self.default_role.as_ref() else {
return Err(anyhow::anyhow!("default_role must be provided"));
};
let default_role = default_role.resolve(ctx)?;

if !allowed_roles.contains(&default_role) {
return Err(anyhow::anyhow!(
"default_role must be in allowed_roles: {} not in {:?}",
default_role,
allowed_roles
));
}
let role_selection = self.role_selection.resolve(ctx)?;

Ok(ResolvedAwsBedrock {
model: model.resolve(ctx)?,
region: self.region.resolve(ctx).ok(),
access_key_id: self.access_key_id.resolve(ctx).ok(),
secret_access_key: self.secret_access_key.resolve(ctx).ok(),
allowed_roles,
default_role,
role_selection,
allowed_role_metadata: self.allowed_role_metadata.resolve(ctx)?,
supported_request_modes: self.supported_request_modes.clone(),
inference_config: self
.inference_config
.as_ref()
.map(|c| c.resolve(ctx))
.transpose()?,
finish_reason_filter: self.finish_reason_filter.resolve(ctx)?,
})
}

Expand Down Expand Up @@ -176,12 +163,7 @@ impl UnresolvedAwsBedrock {
.map(|(_, v, _)| v.clone())
.unwrap_or_else(|| baml_types::StringOr::EnvVar("AWS_SECRET_ACCESS_KEY".to_string()));

let allowed_roles = properties.ensure_allowed_roles().unwrap_or(vec![
StringOr::Value("system".to_string()),
StringOr::Value("user".to_string()),
StringOr::Value("assistant".to_string()),
]);
let default_role = properties.ensure_default_role(allowed_roles.as_slice(), 1);
let role_selection = properties.ensure_roles_selection();
let allowed_metadata = properties.ensure_allowed_metadata();
let supported_request_modes = properties.ensure_supported_request_modes();

Expand Down Expand Up @@ -242,6 +224,7 @@ impl UnresolvedAwsBedrock {
}
Some(inference_config)
};
let finish_reason_filter = properties.ensure_finish_reason_filter();

// TODO: Handle inference_configuration
let errors = properties.finalize_empty();
Expand All @@ -254,11 +237,11 @@ impl UnresolvedAwsBedrock {
region,
access_key_id,
secret_access_key,
allowed_roles,
default_role,
role_selection,
allowed_role_metadata: allowed_metadata,
supported_request_modes,
inference_config,
finish_reason_filter,
})
}
}
Loading

0 comments on commit 8b20396

Please sign in to comment.