From c5985f140e680b3d36afb7bab820b81058ccd26a Mon Sep 17 00:00:00 2001
From: Tomas Dvorak
Date: Mon, 3 Jun 2024 10:26:36 +0200
Subject: [PATCH] fixup! feat: support langchain serialization

Signed-off-by: Tomas Dvorak
---
 src/langchain/llm-chat.ts            | 14 +++++++++++---
 src/langchain/llm.ts                 | 12 ++++++++++--
 tests/e2e/langchain/llm-chat.test.ts |  5 ++++-
 tests/e2e/langchain/llm.test.ts      | 24 ++++++++++++------------
 4 files changed, 37 insertions(+), 18 deletions(-)

diff --git a/src/langchain/llm-chat.ts b/src/langchain/llm-chat.ts
index 00b9cf7..7073eda 100644
--- a/src/langchain/llm-chat.ts
+++ b/src/langchain/llm-chat.ts
@@ -31,7 +31,7 @@ export type GenAIChatModelOptions = BaseChatModelCallOptions &
   Partial>;
 
 export class GenAIChatModel extends BaseChatModel {
-  protected readonly client: Client;
+  public readonly client: Client;
 
   public readonly modelId: GenAIChatModelParams['model_id'];
   public readonly promptId: GenAIChatModelParams['prompt_id'];
@@ -236,16 +236,24 @@ export class GenAIChatModel extends BaseChatModel {
     useConversationParameters: undefined,
     parentId: undefined,
     trimMethod: undefined,
+    client: undefined,
   };
 
-  static async fromJSON(value: string | Serialized) {
-    const input = typeof value === 'string' ? value : JSON.stringify(value);
+  get lc_secrets() {
+    return { ...super.lc_secrets, client: 'client' };
+  }
+
+  static async fromJSON(value: string | Serialized, client?: Client) {
+    const input = typeof value !== 'string' ? JSON.stringify(value) : value;
     return await load(input, {
       optionalImportsMap: {
         '@ibm-generative-ai/node-sdk/langchain/llm-chat': {
           GenAIModel: GenAIChatModel,
         },
       },
+      secretsMap: {
+        client,
+      },
     });
   }
 
diff --git a/src/langchain/llm.ts b/src/langchain/llm.ts
index e935b83..3eb4d57 100644
--- a/src/langchain/llm.ts
+++ b/src/langchain/llm.ts
@@ -34,7 +34,7 @@ export type GenAIModelOptions = BaseLLMCallOptions &
   Partial>;
 
 export class GenAIModel extends BaseLLM {
-  protected readonly client: Client;
+  public readonly client: Client;
 
   public readonly modelId: GenAIModelParams['model_id'];
   public readonly promptId: GenAIModelParams['prompt_id'];
@@ -185,7 +185,7 @@ export class GenAIModel extends BaseLLM {
     return result.results.at(0)?.token_count ?? 0;
   }
 
-  static async fromJSON(value: string | Serialized) {
+  static async fromJSON(value: string | Serialized, client?: Client) {
     const input = typeof value === 'string' ? value : JSON.stringify(value);
     return await load(input, {
       optionalImportsMap: {
@@ -193,6 +193,9 @@ export class GenAIModel extends BaseLLM {
         GenAIModel: GenAIModel,
       },
     },
+      secretsMap: {
+        client,
+      },
     });
   }
 
@@ -216,5 +219,10 @@ export class GenAIModel extends BaseLLM {
     promptId: undefined,
     parameters: undefined,
     moderations: undefined,
+    client: undefined,
   };
+
+  get lc_secrets() {
+    return { ...super.lc_secrets, client: 'client' };
+  }
 }
diff --git a/tests/e2e/langchain/llm-chat.test.ts b/tests/e2e/langchain/llm-chat.test.ts
index 4f030bf..26a0417 100644
--- a/tests/e2e/langchain/llm-chat.test.ts
+++ b/tests/e2e/langchain/llm-chat.test.ts
@@ -126,7 +126,10 @@ describe('LangChain Chat', () => {
   it('Serializes', async () => {
     const model = makeModel();
     const serialized = model.toJSON();
-    const deserialized = await GenAIChatModel.fromJSON(serialized);
+    const deserialized = await GenAIChatModel.fromJSON(
+      serialized,
+      model.client,
+    );
     expect(deserialized).toBeInstanceOf(GenAIChatModel);
   });
 });
diff --git a/tests/e2e/langchain/llm.test.ts b/tests/e2e/langchain/llm.test.ts
index a37d466..24194d5 100644
--- a/tests/e2e/langchain/llm.test.ts
+++ b/tests/e2e/langchain/llm.test.ts
@@ -5,7 +5,7 @@ import { GenAIModel } from '../../../src/langchain/llm.js';
 import { Client } from '../../../src/client.js';
 
 describe('Langchain', () => {
-  const makeClient = (modelId: string) =>
+  const makeModel = (modelId: string) =>
     new GenAIModel({
       model_id: modelId,
       client: new Client({
@@ -26,7 +26,7 @@
 
   describe('tokenization', () => {
     it('should correctly calculate tokens', async () => {
-      const client = makeClient('google/flan-ul2');
+      const client = makeModel('google/flan-ul2');
       const tokensCount = await client.getNumTokens(
         'What is the biggest building on this planet?',
       );
@@ -35,30 +35,30 @@
   });
 
   it('Serializes', async () => {
-    const client = makeClient('google/flan-ul2');
-    const serialized = client.toJSON();
-    const deserialized = await GenAIModel.fromJSON(serialized);
+    const model = makeModel('google/flan-ul2');
+    const serialized = model.toJSON();
+    const deserialized = await GenAIModel.fromJSON(serialized, model.client);
     expect(deserialized).toBeInstanceOf(GenAIModel);
   });
 
   describe('generate', () => {
     // TODO: enable once we will set default model for the test account
     test.skip('should handle empty modelId', async () => {
-      const client = makeClient('google/flan-ul2');
+      const client = makeModel('google/flan-ul2');
       const data = await client.invoke('Who are you?');
 
       expectIsString(data);
     }, 15_000);
 
     test('should return correct response for a single input', async () => {
-      const client = makeClient('google/flan-ul2');
+      const client = makeModel('google/flan-ul2');
       const data = await client.invoke('Hello, World');
 
       expectIsString(data);
     }, 15_000);
 
     test('should return correct response for each input', async () => {
-      const client = makeClient('google/flan-ul2');
+      const client = makeModel('google/flan-ul2');
 
       const inputs = ['Hello, World', 'Hello again'];
 
@@ -81,7 +81,7 @@
     }, 20_000);
 
     test('should reject with ERR_CANCELED when aborted', async () => {
-      const model = makeClient('google/flan-ul2');
+      const model = makeModel('google/flan-ul2');
       const controller = new AbortController();
 
       const generatePromise = model.generate(['Hello, World'], {
@@ -99,7 +99,7 @@
     });
 
     test('should reject with ETIMEDOUT when timed out', async () => {
-      const model = makeClient('google/flan-ul2');
+      const model = makeModel('google/flan-ul2');
 
       await expect(
         model.invoke('Hello, World', { timeout: 10 }),
@@ -107,7 +107,7 @@
     });
 
     test('streaming', async () => {
-      const client = makeClient('google/flan-t5-xl');
+      const client = makeModel('google/flan-t5-xl');
       const tokens: string[] = [];
 
       const handleText = vi.fn((token: string) => {
@@ -132,7 +132,7 @@
   });
 
   describe('chaining', () => {
-    const model = makeClient('google/flan-t5-xl');
+    const model = makeModel('google/flan-t5-xl');
 
     test('chaining', async () => {
       const prompt = new PromptTemplate({
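
Usage sketch (not part of the patch): the change above registers "client" as a
LangChain secret via "lc_secrets", so toJSON() serializes a placeholder rather
than the live Client, and the new optional second parameter of fromJSON()
re-injects a Client through load()'s secretsMap. A minimal round-trip could
look like the following; the Client import path and its "apiKey" option are
assumptions, not details taken from the diff.

  import { Client } from '@ibm-generative-ai/node-sdk';
  import { GenAIModel } from '@ibm-generative-ai/node-sdk/langchain/llm';

  // Assumed constructor options; the diff does not show the Client config.
  const client = new Client({ apiKey: process.env.GENAI_API_KEY });

  const model = new GenAIModel({ model_id: 'google/flan-ul2', client });

  // Because 'client' is listed in lc_secrets, the serialized form carries a
  // secret placeholder instead of credentials or a live instance.
  const serialized = model.toJSON();

  // Passing the client back in restores a fully usable model.
  const restored = await GenAIModel.fromJSON(serialized, client);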