fixup! feat: support langchain serialization
Signed-off-by: Tomas Dvorak <toomas2d@gmail.com>
Tomas2D committed Jun 3, 2024
1 parent 047f82f commit c5985f1
Showing 4 changed files with 37 additions and 18 deletions.
14 changes: 11 additions & 3 deletions src/langchain/llm-chat.ts
@@ -31,7 +31,7 @@ export type GenAIChatModelOptions = BaseChatModelCallOptions &
   Partial<Omit<GenAIChatModelParams, 'client' | 'configuration'>>;
 
 export class GenAIChatModel extends BaseChatModel<GenAIChatModelOptions> {
-  protected readonly client: Client;
+  public readonly client: Client;
 
   public readonly modelId: GenAIChatModelParams['model_id'];
   public readonly promptId: GenAIChatModelParams['prompt_id'];
@@ -236,16 +236,24 @@ export class GenAIChatModel extends BaseChatModel<GenAIChatModelOptions> {
     useConversationParameters: undefined,
     parentId: undefined,
     trimMethod: undefined,
+    client: undefined,
   };
 
-  static async fromJSON(value: string | Serialized) {
-    const input = typeof value === 'string' ? value : JSON.stringify(value);
+  get lc_secrets() {
+    return { ...super.lc_secrets, client: 'client' };
+  }
+
+  static async fromJSON(value: string | Serialized, client?: Client) {
+    const input = typeof value !== 'string' ? JSON.stringify(value) : value;
     return await load(input, {
       optionalImportsMap: {
         '@ibm-generative-ai/node-sdk/langchain/llm-chat': {
           GenAIModel: GenAIChatModel,
         },
       },
+      secretsMap: {
+        client,
+      },
     });
   }
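
Taken together, these llm-chat.ts changes make the model's Client a first-class LangChain secret: toJSON() strips it from the serialized payload via lc_secrets, and the reworked fromJSON() re-injects it through secretsMap. A minimal round-trip sketch; the deep import path is assumed from the optionalImportsMap key above, and the Client options and model id are placeholders:

    import { Client } from '@ibm-generative-ai/node-sdk';
    import { GenAIChatModel } from '@ibm-generative-ai/node-sdk/langchain/llm-chat';

    // Placeholder configuration -- the real Client options may differ.
    const client = new Client({ apiKey: process.env.GENAI_API_KEY });

    // model_id is a placeholder; any model available to the account works.
    const model = new GenAIChatModel({ model_id: 'google/flan-ul2', client });

    // Because 'client' is listed in lc_secrets, toJSON() emits a secret
    // reference in its place -- no API key lands in the payload.
    const serialized = model.toJSON();

    // fromJSON() feeds the client back in through secretsMap, so the
    // caller supplies a live Client when reviving the model.
    const restored = await GenAIChatModel.fromJSON(serialized, client);
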
12 changes: 10 additions & 2 deletions src/langchain/llm.ts
@@ -34,7 +34,7 @@ export type GenAIModelOptions = BaseLLMCallOptions &
   Partial<Omit<GenAIModelParams, 'client' | 'configuration'>>;
 
 export class GenAIModel extends BaseLLM<GenAIModelOptions> {
-  protected readonly client: Client;
+  public readonly client: Client;
 
   public readonly modelId: GenAIModelParams['model_id'];
   public readonly promptId: GenAIModelParams['prompt_id'];
@@ -185,14 +185,17 @@ export class GenAIModel extends BaseLLM<GenAIModelOptions> {
     return result.results.at(0)?.token_count ?? 0;
   }
 
-  static async fromJSON(value: string | Serialized) {
+  static async fromJSON(value: string | Serialized, client?: Client) {
     const input = typeof value === 'string' ? value : JSON.stringify(value);
     return await load(input, {
       optionalImportsMap: {
         '@ibm-generative-ai/node-sdk/langchain/llm': {
           GenAIModel: GenAIModel,
         },
       },
+      secretsMap: {
+        client,
+      },
     });
   }
@@ -216,5 +219,10 @@ export class GenAIModel extends BaseLLM<GenAIModelOptions> {
     promptId: undefined,
     parameters: undefined,
     moderations: undefined,
+    client: undefined,
   };
+
+  get lc_secrets() {
+    return { ...super.lc_secrets, client: 'client' };
+  }
 }
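
The llm.ts changes mirror the chat model. For context on the mechanism: LangChain's serializer swaps any attribute named in lc_secrets for a typed placeholder, and load() later resolves that placeholder against the secretsMap that fromJSON() now passes in. The relevant fragment of the serialized payload looks roughly like this (an illustrative sketch of langchain-core's secret encoding, not captured output from this SDK):

    "client": { "lc": 1, "type": "secret", "id": ["client"] }

Since credentials never round-trip as plain data, deserialization requires a live Client, which is why both fromJSON() signatures gained the optional client parameter.
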
5 changes: 4 additions & 1 deletion tests/e2e/langchain/llm-chat.test.ts
@@ -126,7 +126,10 @@ describe('LangChain Chat', () => {
   it('Serializes', async () => {
     const model = makeModel();
     const serialized = model.toJSON();
-    const deserialized = await GenAIChatModel.fromJSON(serialized);
+    const deserialized = await GenAIChatModel.fromJSON(
+      serialized,
+      model.client,
+    );
     expect(deserialized).toBeInstanceOf(GenAIChatModel);
   });
 });
24 changes: 12 additions & 12 deletions tests/e2e/langchain/llm.test.ts
@@ -5,7 +5,7 @@ import { GenAIModel } from '../../../src/langchain/llm.js';
 import { Client } from '../../../src/client.js';
 
 describe('Langchain', () => {
-  const makeClient = (modelId: string) =>
+  const makeModel = (modelId: string) =>
     new GenAIModel({
       model_id: modelId,
       client: new Client({
@@ -26,7 +26,7 @@
 
   describe('tokenization', () => {
     it('should correctly calculate tokens', async () => {
-      const client = makeClient('google/flan-ul2');
+      const client = makeModel('google/flan-ul2');
       const tokensCount = await client.getNumTokens(
         'What is the biggest building on this planet?',
       );
@@ -35,30 +35,30 @@
   });
 
   it('Serializes', async () => {
-    const client = makeClient('google/flan-ul2');
-    const serialized = client.toJSON();
-    const deserialized = await GenAIModel.fromJSON(serialized);
+    const model = makeModel('google/flan-ul2');
+    const serialized = model.toJSON();
+    const deserialized = await GenAIModel.fromJSON(serialized, model.client);
     expect(deserialized).toBeInstanceOf(GenAIModel);
   });
 
   describe('generate', () => {
     // TODO: enable once we will set default model for the test account
     test.skip('should handle empty modelId', async () => {
-      const client = makeClient('google/flan-ul2');
+      const client = makeModel('google/flan-ul2');
 
       const data = await client.invoke('Who are you?');
       expectIsString(data);
     }, 15_000);
 
     test('should return correct response for a single input', async () => {
-      const client = makeClient('google/flan-ul2');
+      const client = makeModel('google/flan-ul2');
 
       const data = await client.invoke('Hello, World');
       expectIsString(data);
     }, 15_000);
 
     test('should return correct response for each input', async () => {
-      const client = makeClient('google/flan-ul2');
+      const client = makeModel('google/flan-ul2');
 
       const inputs = ['Hello, World', 'Hello again'];
 
@@ -81,7 +81,7 @@
     }, 20_000);
 
     test('should reject with ERR_CANCELED when aborted', async () => {
-      const model = makeClient('google/flan-ul2');
+      const model = makeModel('google/flan-ul2');
 
       const controller = new AbortController();
       const generatePromise = model.generate(['Hello, World'], {
@@ -99,15 +99,15 @@
     });
 
     test('should reject with ETIMEDOUT when timed out', async () => {
-      const model = makeClient('google/flan-ul2');
+      const model = makeModel('google/flan-ul2');
 
       await expect(
         model.invoke('Hello, World', { timeout: 10 }),
       ).rejects.toThrow();
     });
 
     test('streaming', async () => {
-      const client = makeClient('google/flan-t5-xl');
+      const client = makeModel('google/flan-t5-xl');
 
       const tokens: string[] = [];
       const handleText = vi.fn((token: string) => {
@@ -132,7 +132,7 @@
   });
 
   describe('chaining', () => {
-    const model = makeClient('google/flan-t5-xl');
+    const model = makeModel('google/flan-t5-xl');
 
     test('chaining', async () => {
       const prompt = new PromptTemplate({
