Skip to content

Commit

Permalink
feat(adapters): add embedding support for Bedrock (#253)
Browse files Browse the repository at this point in the history
Ref: #176
  • Loading branch information
Tomas2D authored Jan 3, 2025
1 parent e9d5d4f commit c989e27
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 4 deletions.
29 changes: 25 additions & 4 deletions src/adapters/bedrock/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import { Emitter } from "@/emitter/emitter.js";
import type { AwsCredentialIdentity, Provider } from "@aws-sdk/types";
import {
BedrockRuntimeClient as Client,
InvokeModelCommand,
ConverseCommand,
ConverseCommandOutput,
ConverseStreamCommand,
Expand All @@ -42,10 +43,14 @@ import {
} from "@aws-sdk/client-bedrock-runtime";
import { GetRunContext } from "@/context.js";
import { Serializer } from "@/serializer/serializer.js";
import { NotImplementedError } from "@/errors.js";
import { omitUndefined } from "@/internals/helpers/object.js";

type Response = ContentBlockDeltaEvent | ConverseCommandOutput;

export interface BedrockEmbeddingOptions extends EmbeddingOptions {
body?: Record<string, any>;
}

export class ChatBedrockOutput extends ChatLLMOutput {
public readonly responses: Response[];

Expand Down Expand Up @@ -204,9 +209,25 @@ export class BedrockChatLLM extends ChatLLM<ChatBedrockOutput> {
};
}

// eslint-disable-next-line unused-imports/no-unused-vars
async embed(input: BaseMessage[][], options?: EmbeddingOptions): Promise<EmbeddingOutput> {
throw new NotImplementedError();
async embed(
input: BaseMessage[][],
options: BedrockEmbeddingOptions = {},
): Promise<EmbeddingOutput> {
const command = new InvokeModelCommand({
modelId: this.modelId,
contentType: "application/json",
accept: "application/json",
body: JSON.stringify(
omitUndefined({
inputText: input.flat().map((msg) => msg.text),
...options?.body,
}),
),
});

const response = await this.client.send(command, { abortSignal: options?.signal });
const jsonString = new TextDecoder().decode(response.body);
return JSON.parse(jsonString);
}

async tokenize(input: BaseMessage[]): Promise<BaseLLMTokenizeOutput> {
Expand Down
35 changes: 35 additions & 0 deletions tests/e2e/adapters/bedrock/chat.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/**
* Copyright 2025 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { BedrockChatLLM } from "@/adapters/bedrock/chat.js";
import { BaseMessage } from "@/llms/primitives/message.js";

describe.runIf([process.env.AWS_REGION].every((env) => Boolean(env)))("Bedrock Chat LLM", () => {
it("Embeds", async () => {
const llm = new BedrockChatLLM({
region: process.env.AWS_REGION,
modelId: "amazon.titan-embed-text-v1",
});

const response = await llm.embed([
[BaseMessage.of({ role: "user", text: `Hello world!` })],
[BaseMessage.of({ role: "user", text: `Hello family!` })],
]);
expect(response.embeddings.length).toBe(2);
expect(response.embeddings[0].length).toBe(512);
expect(response.embeddings[1].length).toBe(512);
});
});

0 comments on commit c989e27

Please sign in to comment.