diff --git a/engine/baml-runtime/Cargo.toml b/engine/baml-runtime/Cargo.toml index 679d0c272..fdd11b554 100644 --- a/engine/baml-runtime/Cargo.toml +++ b/engine/baml-runtime/Cargo.toml @@ -63,7 +63,7 @@ tokio-stream = "0.1.15" uuid = { version = "1.8.0", features = ["v4", "serde"] } web-time.workspace = true static_assertions.workspace = true -mime_guess = "2.0.4" +mime_guess = "=2.0.5" mime = "0.3.17" # For tracing diff --git a/engine/baml-runtime/src/internal/llm_client/traits/mod.rs b/engine/baml-runtime/src/internal/llm_client/traits/mod.rs index ef72f9b40..13945309c 100644 --- a/engine/baml-runtime/src/internal/llm_client/traits/mod.rs +++ b/engine/baml-runtime/src/internal/llm_client/traits/mod.rs @@ -621,18 +621,26 @@ async fn to_base64_with_inferred_mime_type( Ok(response) => response, Err(e) => return Err(anyhow::anyhow!("Failed to fetch media: {e:?}")), }; - let bytes = match response.bytes().await { - Ok(bytes) => bytes, - Err(e) => return Err(anyhow::anyhow!("Failed to fetch media bytes: {e:?}")), - }; - let base64 = BASE64_STANDARD.encode(&bytes); - // TODO: infer based on file extension? - let mime_type = match infer::get(&bytes) { - Some(t) => t.mime_type(), - None => "application/octet-stream", + if response.status().is_success() { + let bytes = match response.bytes().await { + Ok(bytes) => bytes, + Err(e) => return Err(anyhow::anyhow!("Failed to fetch media bytes: {e:?}")), + }; + let base64 = BASE64_STANDARD.encode(&bytes); + // TODO: infer based on file extension? + let mime_type = match infer::get(&bytes) { + Some(t) => t.mime_type(), + None => "application/octet-stream", + } + .to_string(); + Ok((base64, mime_type)) + } else { + Err(anyhow::anyhow!( + "Failed to fetch media: {}, {}", + response.status(), + response.text().await.unwrap_or_default() + )) } - .to_string(); - Ok((base64, mime_type)) } /// A naive implementation of the data URL parser, returning the (mime_type, base64) @@ -658,7 +666,13 @@ async fn fetch_with_proxy( let request = if let Some(proxy) = proxy_url { client - .get(format!("{}/{}", proxy, url)) + .get(format!( + "{}{}", + proxy, + url.parse::() + .map_err(|e| anyhow::anyhow!("Failed to parse URL: {}", e))? + .path() + )) .header("baml-original-url", url) } else { client.get(url) diff --git a/typescript/vscode-ext/packages/vscode/src/extension.ts b/typescript/vscode-ext/packages/vscode/src/extension.ts index 64b59c11e..022fa1116 100644 --- a/typescript/vscode-ext/packages/vscode/src/extension.ts +++ b/typescript/vscode-ext/packages/vscode/src/extension.ts @@ -197,6 +197,7 @@ export function activate(context: vscode.ExtensionContext) { if (path.endsWith('/')) { return path.slice(0, -1) } + console.log('pathRewrite', path, req) return path }, router: (req) => { @@ -211,15 +212,44 @@ export function activate(context: vscode.ExtensionContext) { if (originalUrl.endsWith('/')) { originalUrl = originalUrl.slice(0, -1) } - return originalUrl + console.log('returning original url', originalUrl) + return new URL(originalUrl).origin } else { + console.log('baml-original-url header is missing or invalid') throw new Error('baml-original-url header is missing or invalid') } }, logger: console, on: { proxyReq: (proxyReq, req, res) => { - console.debug('Proxying an LLM request (to bypass CORS)', { proxyReq, req, res }) + console.log('proxying request') + + try { + const bamlOriginalUrl = req.headers['baml-original-url'] + if (bamlOriginalUrl === undefined) { + return + } + const targetUrl = new URL(bamlOriginalUrl) + // proxyReq.path = targetUrl.pathname + // proxyReq.p + // It is very important that we ONLY resolve against API_KEY_INJECTION_ALLOWED + // by using the URL origin! (i.e. NOT using str.startsWith - the latter can still + // leak API keys to malicious subdomains e.g. https://api.openai.com.evil.com) + // const headers = API_KEY_INJECTION_ALLOWED[proxyOrigin] + // if (headers === undefined) { + // return + // } + // for (const [header, value] of Object.entries(headers)) { + // proxyReq.setHeader(header, value) + // } + // proxyReq.removeHeader('origin') + // proxyReq.setHeader('Origin', targetUrl.origin) + console.info('Proxying an LLM request (to bypass CORS)', { proxyReq, req, res }) + + } catch (err) { + // This is not console.warn because it's not important + console.log('baml-original-url is not parsable', err) + } }, proxyRes: (proxyRes, req, res) => { proxyRes.headers['Access-Control-Allow-Origin'] = '*'