Skip to content

Commit

Permalink
feat: provide saner semantics around aws_session_token (#1295)
Browse files Browse the repository at this point in the history
Change aws credential loading logic:

- if any of `access_key_id`, `secret_access_key`, or `session_token` are
set, all 3 are loaded explicitly (either from the `.baml` client
definition or the dynamic client properties)
- if, and only if, none of the 3 are set, all 3 are loaded from,
respectively, `AWS_ACCESS_KEY_ID` `AWS_SECRET_ACCESS_KEY`
`AWS_SESSION_TOKEN`

This most closely matches the behavior of the AWS SDKs (Python, TS, and
Rust). See [slack
thread](https://gloo-global.slack.com/archives/C03KV1PJ6EM/p1736215459393779)
which is copied below:

> OK, so chris and i figured out what happened with bedrock/ethan:
> 
> in #1266, chris correctly added support for aws session token so that
if a user set it in aws.baml as properties { session_token
env.AWS_SESSION_TOKEN }, baml would respect that (prior to #1266 baml
would not)
> 
> - however, 1266 also introduced an implicit default: if
AWS_SESSION_TOKEN is set in the process' environment, but the user only
set properties { access_key_id ... ; secret_access_key ... ; } then baml
would construct the aws creds using access_key_id secret_access_key and
session_token
> - in ethan's case this is problematic, because he uses custom values
for the access_key_id secret_access_key pair from his lambda secrets,
but the aws lambda environment also sets AWS_ACCESS_KEY_ID
AWS_SECRET_ACCESS_KEY AWS_SESSION_TOKEN
> - as a result, after updating past #1266, his (access_key_id,
secret_access_key) did not agree with his AWS_SESSION_TOKEN and caused a
runtime failure
> - he also has no way to opt out of this behavior, because we do not
currently provide a way to force session_token to null: it is always
inferred from the environment by default
> 
> to solve this, we're going to use the following logic for aws creds:
> - if any of access_key_id, secret_access_key, or session_token are set
in baml client properties, we will never magically infer a value for any
of the 3 from the environment
>     -  but if none of the 3 are set, we will read all 3 from the env
> - this behavior feels most in line with how credential init in the aws
sdk normally works
> - in ts, an AwsCredentialIdentityProvider is any function that returns
{ accessKeyId: string, secretAccessKey: string, sessionToken?: string,
expiration?: Date} (docs)
> - in python, if you set any of the 3, a creds object is constructed
using the explicitly provided values of all 3 (impl callsite, impl
source code)
> - in rust, this is what constructing Credentials::new does when you
override the credentials_loader
> 
> stepping back, this is a mix of the two approaches: (1) session_token:
unset defaults to reading AWS_SESSION_TOKEN from the env, and user is
allowed to explicitly set session_token: null or (2) session_token:
unset never reads from the env, and the user must always set it
> 
> NB: this does not explain why multiple other customers are complaining
about not being able to figure out how to use aws bedrock. so @Vaibhav
Gupta we still need to see what those other complaints are
<!-- ELLIPSIS_HIDDEN -->

----

> [!IMPORTANT]
> Updates AWS credential loading logic to use explicit credentials if
set, otherwise defaults to environment variables, aligning with AWS SDK
behavior.
> 
>   - **Behavior**:
> - Updates AWS credential loading logic in `aws_bedrock.rs` and
`aws_client.rs`.
> - If any of `access_key_id`, `secret_access_key`, or `session_token`
are set, all are loaded explicitly.
> - If none are set, all are loaded from environment variables
`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_SESSION_TOKEN`.
>   - **Error Handling**:
> - Adds error checks for environment variable placeholders in
`aws_client.rs`.
>   - **Misc**:
> - Adjusts credential provider logic in `aws_client.rs` to use
`DefaultCredentialsChain` when no credentials are provided.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for 8199507. It will automatically
update as commits are pushed.</sup>

<!-- ELLIPSIS_HIDDEN -->

---------

Co-authored-by: Chris Watts <chris.watts.t@gmail.com>
  • Loading branch information
sxlijin and seawatts authored Jan 7, 2025
1 parent 43a0007 commit 98c6b99
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 56 deletions.
39 changes: 27 additions & 12 deletions engine/baml-lib/llm-client/src/clients/aws_bedrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,18 +182,12 @@ impl UnresolvedAwsBedrock {

let access_key_id = match self.access_key_id.as_ref() {
Some(access_key_id) => Some(access_key_id.resolve(ctx)?),
None => match ctx.get_env_var("AWS_ACCESS_KEY_ID") {
Ok(access_key_id) if !access_key_id.is_empty() => Some(access_key_id),
_ => None,
},
None => None,
};

let secret_access_key = match self.secret_access_key.as_ref() {
Some(secret_access_key) => Some(secret_access_key.resolve(ctx)?),
None => match ctx.get_env_var("AWS_SECRET_ACCESS_KEY") {
Ok(secret_access_key) if !secret_access_key.is_empty() => Some(secret_access_key),
_ => None,
},
None => None,
};

let session_token = match self.session_token.as_ref() {
Expand All @@ -205,12 +199,33 @@ impl UnresolvedAwsBedrock {
None
}
}
None => match ctx.get_env_var("AWS_SESSION_TOKEN") {
Ok(session_token) if !session_token.is_empty() => Some(session_token),
_ => None,
},
None => None,
};

let (access_key_id, secret_access_key, session_token) =
match (access_key_id, secret_access_key, session_token) {
(None, None, None) => {
// If no credentials provided, get them all from env vars
let access_key_id = match ctx.get_env_var("AWS_ACCESS_KEY_ID") {
Ok(key) if !key.is_empty() => Some(key),
_ => None,
};
let secret_access_key = match ctx.get_env_var("AWS_SECRET_ACCESS_KEY") {
Ok(key) if !key.is_empty() => Some(key),
_ => None,
};
let session_token = match ctx.get_env_var("AWS_SESSION_TOKEN") {
Ok(token) if !token.is_empty() => Some(token),
_ => None,
};
(access_key_id, secret_access_key, session_token)
}
// If any credentials are explicitly provided, use those
(access_key_id, secret_access_key, session_token) => {
(access_key_id, secret_access_key, session_token)
}
};

#[cfg(not(target_arch = "wasm32"))]
let profile = match self.profile.as_ref() {
Some(profile) => Some(profile.resolve(ctx)?),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,60 +144,69 @@ impl AwsClient {
}

// Set region if specified
if let Some(aws_region) = self.properties.region.as_ref() {
if aws_region.starts_with("$") {
return Err(anyhow::anyhow!(
"AWS region expected, please set: env.{}",
&aws_region[1..]
));
}

loader = loader.region(Region::new(aws_region.clone()));
}

// Set credentials provider
let loader = if let (Some(aws_access_key_id), Some(aws_secret_access_key)) = (
let mut loader = match (
self.properties.access_key_id.as_ref(),
self.properties.secret_access_key.as_ref(),
self.properties.session_token.as_ref(),
) {
let aws_session_token = self.properties.session_token.clone();

if aws_access_key_id.starts_with("$") {
return Err(anyhow::anyhow!(
"AWS access key id expected, please set: env.{}",
&aws_access_key_id[1..]
));
(None, None, None) => {
// If no credentials provided, get them all from env vars
loader.credentials_provider(
aws_config::default_provider::credentials::DefaultCredentialsChain::builder()
.build()
.await,
)
}
_ => {
if let Some(aws_access_key_id) = self.properties.access_key_id.as_ref() {
if aws_access_key_id.starts_with("$") {
return Err(anyhow::anyhow!(
"AWS access key id expected, please set: env.{}",
&aws_access_key_id[1..]
));
}
}
if let Some(aws_secret_access_key) = self.properties.secret_access_key.as_ref() {
if aws_secret_access_key.starts_with("$") {
return Err(anyhow::anyhow!(
"AWS secret access key expected, please set: env.{}",
&aws_secret_access_key[1..]
));
}
}
if let Some(aws_session_token) = self.properties.session_token.as_ref() {
if aws_session_token.starts_with("$") {
return Err(anyhow::anyhow!(
"AWS session token expected, please set: env.{}",
&aws_session_token[1..]
));
}
}
loader.credentials_provider(Credentials::new(
self.properties.access_key_id.clone().unwrap_or("".into()),
self.properties
.secret_access_key
.clone()
.unwrap_or("".into()),
self.properties.session_token.clone(),
None,
"baml-runtime",
))
}
if aws_secret_access_key.starts_with("$") {
};

if let Some(aws_region) = self.properties.region.as_ref() {
if aws_region.starts_with("$") {
return Err(anyhow::anyhow!(
"AWS secret access key expected, please set: env.{}",
&aws_secret_access_key[1..]
"AWS region expected, please set: env.{}",
&aws_region[1..]
));
}
if let Some(aws_session_token) = aws_session_token.as_ref() {
if aws_session_token.starts_with("$") {
return Err(anyhow::anyhow!(
"AWS session token expected, please set: env.{}",
&aws_session_token[1..]
));
}
}

loader.credentials_provider(Credentials::new(
aws_access_key_id.clone(),
aws_secret_access_key.clone(),
aws_session_token,
None,
"baml-runtime",
))
} else {
// Use default provider chain which includes SSO, profile, environment variables, etc.
loader.credentials_provider(
aws_config::default_provider::credentials::DefaultCredentialsChain::builder()
.build()
.await,
)
};
loader = loader.region(Region::new(aws_region.clone()));
}

let config = loader.load().await;
Ok(bedrock::Client::new(&config))
Expand Down

0 comments on commit 98c6b99

Please sign in to comment.