Skip to content

Commit

Permalink
feat(adapters): extend Ollama embedding options
Browse files Browse the repository at this point in the history
Ref: #176
Signed-off-by: Tomas Dvorak <toomas2d@gmail.com>
  • Loading branch information
Tomas2D committed Dec 13, 2024
1 parent ca17a18 commit d3c9364
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
10 changes: 7 additions & 3 deletions src/adapters/ollama/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import {
AsyncStream,
BaseLLMTokenizeOutput,
EmbeddingOptions,
EmbeddingOutput,
ExecutionOptions,
GenerateOptions,
Expand All @@ -42,6 +41,7 @@ import {
retrieveVersion,
} from "@/adapters/ollama/shared.js";
import { getEnv } from "@/internals/env.js";
import { OllamaEmbeddingOptions } from "@/adapters/ollama/llm.js";

export class OllamaChatLLMOutput extends ChatLLMOutput {
public readonly results: ChatResponse[];
Expand Down Expand Up @@ -160,11 +160,15 @@ export class OllamaChatLLM extends ChatLLM<OllamaChatLLMOutput> {
return extractModelMeta(model);
}

// eslint-disable-next-line unused-imports/no-unused-vars
async embed(input: BaseMessage[][], options?: EmbeddingOptions): Promise<EmbeddingOutput> {
/**
 * Computes embeddings for the given conversations via the Ollama API.
 *
 * @param input - Batches of messages; every message's text is embedded individually.
 * @param options - Ollama-specific embedding options (`options` are forwarded as
 *   model parameters, `truncate` is passed through to the Ollama client).
 * @returns The embeddings returned by the Ollama `embed` endpoint.
 */
async embed(
  input: BaseMessage[][],
  options: OllamaEmbeddingOptions = {},
): Promise<EmbeddingOutput> {
  const response = await this.client.embed({
    model: this.modelId,
    // Flatten the batch of conversations, then embed each message's text.
    // (The original `flatMap((messages) => messages)` was a roundabout `.flat()`,
    // and the second `flatMap` acted as a plain `map` on string values.)
    input: input.flat().map((msg) => msg.text),
    // `options` defaults to {}, so optional chaining is unnecessary here.
    options: options.options,
    truncate: options.truncate,
  });
  return { embeddings: response.embeddings };
}
Expand Down
15 changes: 12 additions & 3 deletions src/adapters/ollama/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ interface Input {
cache?: LLMCache<OllamaLLMOutput>;
}

/**
 * Embedding options specific to the Ollama adapter.
 * Extends the framework-level `EmbeddingOptions` with fields that are forwarded
 * verbatim to the Ollama client's `embed` request (see the `embed` methods in
 * this adapter).
 */
export interface OllamaEmbeddingOptions extends EmbeddingOptions {
/** Model runtime parameters passed through to Ollama's `embed` call. */
options?: Partial<Parameters>;
/** Forwarded to Ollama's `embed` call — NOTE(review): presumably controls truncating inputs that exceed the model's context length; confirm against the Ollama API docs. */
truncate?: boolean;
}

export class OllamaLLMOutput extends BaseLLMOutput {
public readonly results: GenerateResponse[];

Expand Down Expand Up @@ -187,9 +192,13 @@ export class OllamaLLM extends LLM<OllamaLLMOutput> {
return extractModelMeta(model);
}

// eslint-disable-next-line unused-imports/no-unused-vars
async embed(input: LLMInput[], options?: EmbeddingOptions): Promise<EmbeddingOutput> {
const response = await this.client.embed({ model: this.modelId, input: input });
/**
 * Computes embeddings for the given inputs via the Ollama API.
 *
 * @param options - Ollama-specific embedding options; `options` and `truncate`
 *   are forwarded to the Ollama client's `embed` request.
 * @returns The embeddings returned by the Ollama `embed` endpoint.
 */
async embed(input: LLMInput[], options: OllamaEmbeddingOptions = {}): Promise<EmbeddingOutput> {
  const { embeddings } = await this.client.embed({
    model: this.modelId,
    input,
    options: options?.options,
    truncate: options?.truncate,
  });
  return { embeddings };
}

Expand Down

0 comments on commit d3c9364

Please sign in to comment.