From 7ba8a42424eeccfd02804afa08913e00245d68b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Dvo=C5=99=C3=A1k?= Date: Fri, 10 Nov 2023 12:37:19 +0100 Subject: [PATCH] chore(langchain): updates (#63) --- README.md | 7 ++-- examples/langchain/llm-chat.ts | 6 +-- examples/langchain/llm.ts | 49 ++++++++++++++++++++++++ src/langchain/llm-chat.ts | 2 +- src/tests/e2e/langchain/llm-chat.test.ts | 26 ++++++------- 5 files changed, 70 insertions(+), 20 deletions(-) create mode 100644 examples/langchain/llm.ts diff --git a/README.md b/README.md index d5def39..da4058c 100644 --- a/README.md +++ b/README.md @@ -243,6 +243,7 @@ await model.call('Tell me a joke.', undefined, [ ```typescript import { GenAIChatModel } from '@ibm-generative-ai/node-sdk/langchain'; +import { SystemMessage, HumanMessage } from 'langchain/schema'; const client = new GenAIChatModel({ modelId: 'eleutherai/gpt-neox-20b', @@ -268,13 +269,13 @@ const client = new GenAIChatModel({ }); const response = await client.call([ - new SystemChatMessage( + new SystemMessage( 'You are a helpful assistant that translates English to Spanish.', ), - new HumanChatMessage('I love programming.'), + new HumanMessage('I love programming.'), ]); -console.info(response.text); // "Me encanta la programaciĆ³n." +console.info(response.content); // "Me encanta la programaciĆ³n." ``` #### Prompt Templates (GenAI x LangChain) diff --git a/examples/langchain/llm-chat.ts b/examples/langchain/llm-chat.ts index 1e5ebe5..b7b68e3 100644 --- a/examples/langchain/llm-chat.ts +++ b/examples/langchain/llm-chat.ts @@ -1,4 +1,4 @@ -import { HumanChatMessage } from 'langchain/schema'; +import { HumanMessage } from 'langchain/schema'; import { GenAIChatModel } from '../../src/langchain/llm-chat.js'; @@ -31,7 +31,7 @@ const makeClient = (stream?: boolean) => const chat = makeClient(); const response = await chat.call([ - new HumanChatMessage( + new HumanMessage( 'What is a good name for a company that makes colorful socks?', ), ]); @@ -43,7 +43,7 @@ const makeClient = (stream?: boolean) => // Streaming const chat = makeClient(true); - await chat.call([new HumanChatMessage('Tell me a joke.')], undefined, [ + await chat.call([new HumanMessage('Tell me a joke.')], undefined, [ { handleLLMNewToken(token) { console.log(token); diff --git a/examples/langchain/llm.ts b/examples/langchain/llm.ts new file mode 100644 index 0000000..cb46ef8 --- /dev/null +++ b/examples/langchain/llm.ts @@ -0,0 +1,49 @@ +import { GenAIModel } from '../../src/langchain/index.js'; + +const makeClient = (stream?: boolean) => + new GenAIModel({ + modelId: 'google/flan-t5-xl', + stream, + configuration: { + endpoint: process.env.ENDPOINT, + apiKey: process.env.API_KEY, + }, + parameters: { + decoding_method: 'greedy', + min_new_tokens: 5, + max_new_tokens: 25, + repetition_penalty: 1.5, + }, + }); + +{ + // Basic + console.info('---Single Input Example---'); + const model = makeClient(); + + const prompt = 'What is a good name for a company that makes colorful socks?'; + console.info(`Request: ${prompt}`); + const response = await model.call(prompt); + console.log(`Response: ${response}`); +} + +{ + console.info('---Multiple Inputs Example---'); + const model = makeClient(); + + const prompts = ['What is IBM?', 'What is WatsonX?']; + console.info('Request prompts:', prompts); + const response = await model.generate(prompts); + console.info('Response:', response); +} + +{ + console.info('---Streaming Example---'); + const chat = makeClient(true); + + const prompt = 'What is a molecule?'; + console.info(`Request: ${prompt}`); + for await (const token of await chat.stream(prompt)) { + console.info(`Received token: ${token}`); + } +} diff --git a/src/langchain/llm-chat.ts b/src/langchain/llm-chat.ts index 9fb0758..8073fcb 100644 --- a/src/langchain/llm-chat.ts +++ b/src/langchain/llm-chat.ts @@ -74,7 +74,7 @@ export class GenAIChatModel extends BaseChatModel { `Unsupported message type "${msg._getType()}"`, ); } - return `${type.stopSequence}${msg.text}`; + return `${type.stopSequence}${msg.content}`; }) .join('\n') .concat(this.#rolesMapping.system.stopSequence); diff --git a/src/tests/e2e/langchain/llm-chat.test.ts b/src/tests/e2e/langchain/llm-chat.test.ts index 2309f4c..48e74b3 100644 --- a/src/tests/e2e/langchain/llm-chat.test.ts +++ b/src/tests/e2e/langchain/llm-chat.test.ts @@ -1,4 +1,4 @@ -import { HumanChatMessage, SystemChatMessage } from 'langchain/schema'; +import { HumanMessage, SystemMessage } from 'langchain/schema'; import { GenAIChatModel } from '../../../langchain/index.js'; import { describeIf } from '../../utils.js'; @@ -47,21 +47,21 @@ describeIf(process.env.RUN_LANGCHAIN_CHAT_TESTS === 'true')( const chat = makeClient(); const response = await chat.call([ - new HumanChatMessage( + new HumanMessage( 'What is a good name for a company that makes colorful socks?', ), ]); - expectIsNonEmptyString(response.text); + expectIsNonEmptyString(response.content); }); test('should handle question with additional hint', async () => { const chat = makeClient(); const response = await chat.call([ - new SystemChatMessage(SYSTEM_MESSAGE), - new HumanChatMessage('I love programming.'), + new SystemMessage(SYSTEM_MESSAGE), + new HumanMessage('I love programming.'), ]); - expectIsNonEmptyString(response.text); + expectIsNonEmptyString(response.content); }); test('should handle multiple questions', async () => { @@ -69,12 +69,12 @@ describeIf(process.env.RUN_LANGCHAIN_CHAT_TESTS === 'true')( const response = await chat.generate([ [ - new SystemChatMessage(SYSTEM_MESSAGE), - new HumanChatMessage('I love programming.'), + new SystemMessage(SYSTEM_MESSAGE), + new HumanMessage('I love programming.'), ], [ - new SystemChatMessage(SYSTEM_MESSAGE), - new HumanChatMessage('I love artificial intelligence.'), + new SystemMessage(SYSTEM_MESSAGE), + new HumanMessage('I love artificial intelligence.'), ], ]); @@ -95,7 +95,7 @@ describeIf(process.env.RUN_LANGCHAIN_CHAT_TESTS === 'true')( }); const output = await chat.call( - [new HumanChatMessage('Tell me a joke.')], + [new HumanMessage('Tell me a joke.')], undefined, [ { @@ -105,8 +105,8 @@ describeIf(process.env.RUN_LANGCHAIN_CHAT_TESTS === 'true')( ); expect(handleNewToken).toHaveBeenCalled(); - expectIsNonEmptyString(output.text); - expect(tokens.join('')).toStrictEqual(output.text); + expectIsNonEmptyString(output.content); + expect(tokens.join('')).toStrictEqual(output.content); }); }); },