Skip to content

Commit

Permalink
Fix proxy not returning errors if an image couldnt be fetched
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronvg committed Dec 11, 2024
1 parent 7384ba8 commit f117ad9
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 15 deletions.
2 changes: 1 addition & 1 deletion engine/baml-runtime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ tokio-stream = "0.1.15"
uuid = { version = "1.8.0", features = ["v4", "serde"] }
web-time.workspace = true
static_assertions.workspace = true
mime_guess = "2.0.4"
mime_guess = "=2.0.5"
mime = "0.3.17"

# For tracing
Expand Down
38 changes: 26 additions & 12 deletions engine/baml-runtime/src/internal/llm_client/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -621,18 +621,26 @@ async fn to_base64_with_inferred_mime_type(
Ok(response) => response,
Err(e) => return Err(anyhow::anyhow!("Failed to fetch media: {e:?}")),
};
let bytes = match response.bytes().await {
Ok(bytes) => bytes,
Err(e) => return Err(anyhow::anyhow!("Failed to fetch media bytes: {e:?}")),
};
let base64 = BASE64_STANDARD.encode(&bytes);
// TODO: infer based on file extension?
let mime_type = match infer::get(&bytes) {
Some(t) => t.mime_type(),
None => "application/octet-stream",
if response.status().is_success() {
let bytes = match response.bytes().await {
Ok(bytes) => bytes,
Err(e) => return Err(anyhow::anyhow!("Failed to fetch media bytes: {e:?}")),
};
let base64 = BASE64_STANDARD.encode(&bytes);
// TODO: infer based on file extension?
let mime_type = match infer::get(&bytes) {
Some(t) => t.mime_type(),
None => "application/octet-stream",
}
.to_string();
Ok((base64, mime_type))
} else {
Err(anyhow::anyhow!(
"Failed to fetch media: {}, {}",
response.status(),
response.text().await.unwrap_or_default()
))
}
.to_string();
Ok((base64, mime_type))
}

/// A naive implementation of the data URL parser, returning the (mime_type, base64)
Expand All @@ -658,7 +666,13 @@ async fn fetch_with_proxy(

let request = if let Some(proxy) = proxy_url {
client
.get(format!("{}/{}", proxy, url))
.get(format!(
"{}{}",
proxy,
url.parse::<url::Url>()
.map_err(|e| anyhow::anyhow!("Failed to parse URL: {}", e))?
.path()
))
.header("baml-original-url", url)
} else {
client.get(url)
Expand Down
34 changes: 32 additions & 2 deletions typescript/vscode-ext/packages/vscode/src/extension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ export function activate(context: vscode.ExtensionContext) {
if (path.endsWith('/')) {
return path.slice(0, -1)
}
console.log('pathRewrite', path, req)
return path
},
router: (req) => {
Expand All @@ -211,15 +212,44 @@ export function activate(context: vscode.ExtensionContext) {
if (originalUrl.endsWith('/')) {
originalUrl = originalUrl.slice(0, -1)
}
return originalUrl
console.log('returning original url', originalUrl)
return new URL(originalUrl).origin
} else {
console.log('baml-original-url header is missing or invalid')
throw new Error('baml-original-url header is missing or invalid')
}
},
logger: console,
on: {
proxyReq: (proxyReq, req, res) => {
console.debug('Proxying an LLM request (to bypass CORS)', { proxyReq, req, res })
console.log('proxying request')

try {
const bamlOriginalUrl = req.headers['baml-original-url']
if (bamlOriginalUrl === undefined) {
return
}
const targetUrl = new URL(bamlOriginalUrl)
// proxyReq.path = targetUrl.pathname
// proxyReq.p
// It is very important that we ONLY resolve against API_KEY_INJECTION_ALLOWED
// by using the URL origin! (i.e. NOT using str.startsWith - the latter can still
// leak API keys to malicious subdomains e.g. https://api.openai.com.evil.com)
// const headers = API_KEY_INJECTION_ALLOWED[proxyOrigin]
// if (headers === undefined) {
// return
// }
// for (const [header, value] of Object.entries(headers)) {
// proxyReq.setHeader(header, value)
// }
// proxyReq.removeHeader('origin')
// proxyReq.setHeader('Origin', targetUrl.origin)
console.info('Proxying an LLM request (to bypass CORS)', { proxyReq, req, res })

} catch (err) {
// This is not console.warn because it's not important
console.log('baml-original-url is not parsable', err)
}
},
proxyRes: (proxyRes, req, res) => {
proxyRes.headers['Access-Control-Allow-Origin'] = '*'
Expand Down

0 comments on commit f117ad9

Please sign in to comment.