Skip to content

Commit

Permalink
chore(langchain): updates (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomas2D authored Nov 10, 2023
1 parent 912c01c commit 7ba8a42
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 20 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ await model.call('Tell me a joke.', undefined, [

```typescript
import { GenAIChatModel } from '@ibm-generative-ai/node-sdk/langchain';
import { SystemMessage, HumanMessage } from 'langchain/schema';

const client = new GenAIChatModel({
modelId: 'eleutherai/gpt-neox-20b',
Expand All @@ -268,13 +269,13 @@ const client = new GenAIChatModel({
});

const response = await client.call([
new SystemChatMessage(
new SystemMessage(
'You are a helpful assistant that translates English to Spanish.',
),
new HumanChatMessage('I love programming.'),
new HumanMessage('I love programming.'),
]);

console.info(response.text); // "Me encanta la programación."
console.info(response.content); // "Me encanta la programación."
```

#### Prompt Templates (GenAI x LangChain)
Expand Down
6 changes: 3 additions & 3 deletions examples/langchain/llm-chat.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { HumanChatMessage } from 'langchain/schema';
import { HumanMessage } from 'langchain/schema';

import { GenAIChatModel } from '../../src/langchain/llm-chat.js';

Expand Down Expand Up @@ -31,7 +31,7 @@ const makeClient = (stream?: boolean) =>
const chat = makeClient();

const response = await chat.call([
new HumanChatMessage(
new HumanMessage(
'What is a good name for a company that makes colorful socks?',
),
]);
Expand All @@ -43,7 +43,7 @@ const makeClient = (stream?: boolean) =>
// Streaming
const chat = makeClient(true);

await chat.call([new HumanChatMessage('Tell me a joke.')], undefined, [
await chat.call([new HumanMessage('Tell me a joke.')], undefined, [
{
handleLLMNewToken(token) {
console.log(token);
Expand Down
49 changes: 49 additions & 0 deletions examples/langchain/llm.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import { GenAIModel } from '../../src/langchain/index.js';

const makeClient = (stream?: boolean) =>
new GenAIModel({
modelId: 'google/flan-t5-xl',
stream,
configuration: {
endpoint: process.env.ENDPOINT,
apiKey: process.env.API_KEY,
},
parameters: {
decoding_method: 'greedy',
min_new_tokens: 5,
max_new_tokens: 25,
repetition_penalty: 1.5,
},
});

{
// Basic
console.info('---Single Input Example---');
const model = makeClient();

const prompt = 'What is a good name for a company that makes colorful socks?';
console.info(`Request: ${prompt}`);
const response = await model.call(prompt);
console.log(`Response: ${response}`);
}

{
console.info('---Multiple Inputs Example---');
const model = makeClient();

const prompts = ['What is IBM?', 'What is WatsonX?'];
console.info('Request prompts:', prompts);
const response = await model.generate(prompts);
console.info('Response:', response);
}

{
console.info('---Streaming Example---');
const chat = makeClient(true);

const prompt = 'What is a molecule?';
console.info(`Request: ${prompt}`);
for await (const token of await chat.stream(prompt)) {
console.info(`Received token: ${token}`);
}
}
2 changes: 1 addition & 1 deletion src/langchain/llm-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ export class GenAIChatModel extends BaseChatModel {
`Unsupported message type "${msg._getType()}"`,
);
}
return `${type.stopSequence}${msg.text}`;
return `${type.stopSequence}${msg.content}`;
})
.join('\n')
.concat(this.#rolesMapping.system.stopSequence);
Expand Down
26 changes: 13 additions & 13 deletions src/tests/e2e/langchain/llm-chat.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { HumanChatMessage, SystemChatMessage } from 'langchain/schema';
import { HumanMessage, SystemMessage } from 'langchain/schema';

import { GenAIChatModel } from '../../../langchain/index.js';
import { describeIf } from '../../utils.js';
Expand Down Expand Up @@ -47,34 +47,34 @@ describeIf(process.env.RUN_LANGCHAIN_CHAT_TESTS === 'true')(
const chat = makeClient();

const response = await chat.call([
new HumanChatMessage(
new HumanMessage(
'What is a good name for a company that makes colorful socks?',
),
]);
expectIsNonEmptyString(response.text);
expectIsNonEmptyString(response.content);
});

test('should handle question with additional hint', async () => {
const chat = makeClient();

const response = await chat.call([
new SystemChatMessage(SYSTEM_MESSAGE),
new HumanChatMessage('I love programming.'),
new SystemMessage(SYSTEM_MESSAGE),
new HumanMessage('I love programming.'),
]);
expectIsNonEmptyString(response.text);
expectIsNonEmptyString(response.content);
});

test('should handle multiple questions', async () => {
const chat = makeClient();

const response = await chat.generate([
[
new SystemChatMessage(SYSTEM_MESSAGE),
new HumanChatMessage('I love programming.'),
new SystemMessage(SYSTEM_MESSAGE),
new HumanMessage('I love programming.'),
],
[
new SystemChatMessage(SYSTEM_MESSAGE),
new HumanChatMessage('I love artificial intelligence.'),
new SystemMessage(SYSTEM_MESSAGE),
new HumanMessage('I love artificial intelligence.'),
],
]);

Expand All @@ -95,7 +95,7 @@ describeIf(process.env.RUN_LANGCHAIN_CHAT_TESTS === 'true')(
});

const output = await chat.call(
[new HumanChatMessage('Tell me a joke.')],
[new HumanMessage('Tell me a joke.')],
undefined,
[
{
Expand All @@ -105,8 +105,8 @@ describeIf(process.env.RUN_LANGCHAIN_CHAT_TESTS === 'true')(
);

expect(handleNewToken).toHaveBeenCalled();
expectIsNonEmptyString(output.text);
expect(tokens.join('')).toStrictEqual(output.text);
expectIsNonEmptyString(output.content);
expect(tokens.join('')).toStrictEqual(output.content);
});
});
},
Expand Down

0 comments on commit 7ba8a42

Please sign in to comment.