Skip to content

Commit

Permalink
feat(generate): handle moderation chunks (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomas2D authored Oct 24, 2023
1 parent 666cb2f commit e9d006c
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 3 deletions.
21 changes: 21 additions & 0 deletions src/api-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,26 @@ export const GenerateStopReasonSchema = z.enum([
]);
export type GenerateStopReason = z.infer<typeof GenerateStopReasonSchema>;

// A single HAP (hate/abuse/profanity) moderation result for one span of text.
// `score` is a confidence in [0, 1]; `position` gives the character offsets of
// the flagged span within the input/output text.
const GenerateModerationHapSchema = z
  .object({
    success: z.boolean(),
    flagged: z.boolean(),
    score: z.number().min(0).max(1),
    position: z.object({
      start: z.number().int().min(0),
      stop: z.number().int().min(0),
    }),
  })
  .passthrough();

// Moderation payload attached to a generate response chunk. `hap` may be
// absent when no moderation was requested or nothing was flagged. Unknown
// keys are preserved (`passthrough`) so new moderation types don't break parsing.
const GenerateModerationSchema = z
  .object({
    hap: z.optional(z.array(GenerateModerationHapSchema)),
  })
  .passthrough();

export const GenerateResultSchema = z
.object({
generated_text: z.string(),
Expand All @@ -99,6 +119,7 @@ export const GenerateOutputSchema = z
model_id: z.string(),
created_at: z.coerce.date(),
results: z.array(GenerateResultSchema),
moderation: GenerateModerationSchema.optional(),
})
.passthrough();
export type GenerateOutput = z.infer<typeof GenerateOutputSchema>;
Expand Down
7 changes: 6 additions & 1 deletion src/client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -462,13 +462,18 @@ export class Client {
stop_reason = null,
input_token_count = 0,
generated_token_count = 0,
} = chunk.results[0];
...props
} = (chunk.results || [{}])[0];

callback(null, {
generated_text,
stop_reason,
input_token_count,
generated_token_count,
...(chunk.moderation && {
moderation: chunk.moderation,
}),
...props,
} as GenerateOutput);
} catch (e) {
const err = (chunk || e) as unknown as Error;
Expand Down
4 changes: 3 additions & 1 deletion src/client/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ export const GenerateInputSchema = z.union([
}),
]);
export type GenerateInput = z.infer<typeof GenerateInputSchema>;
export type GenerateOutput = ApiTypes.GenerateOutput['results'][number];
export type GenerateOutput = ApiTypes.GenerateOutput['results'][number] & {
moderation?: ApiTypes.GenerateOutput['moderation'];
};

export const GenerateConfigInputSchema = ApiTypes.GenerateConfigInputSchema;
export type GenerateConfigInput = z.input<typeof GenerateConfigInputSchema>;
Expand Down
28 changes: 27 additions & 1 deletion src/tests/e2e/client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,14 @@ describe('client', () => {
}, 15_000);

describe('streaming', () => {
const makeValidStream = () =>
const makeValidStream = (parameters: Record<string, any> = {}) =>
client.generate(
{
model_id: 'google/ul2',
input: 'Hello, World',
parameters: {
max_new_tokens: 10,
...parameters,
},
},
{
Expand All @@ -77,6 +78,10 @@ describe('client', () => {
expect(chunk.generated_token_count).not.toBeNegative();
expect(chunk.input_token_count).not.toBeNegative();
expect(chunk.stop_reason).toSatisfy(isNumberOrNull);
expect(chunk.moderation).toBeOneOf([
undefined,
expect.objectContaining({ hap: expect.any(Array) }),
]);
};

test('should throw for multiple inputs', () => {
Expand All @@ -99,6 +104,27 @@ describe('client', () => {
).toThrowError('Cannot do streaming for more than one input!');
});

test('should correctly process moderation chunks during streaming', async () => {
  // Request HAP moderation alongside generation so the stream interleaves
  // moderation chunks with regular generated-text chunks. A very low
  // threshold makes it likely at least one chunk carries a moderation payload.
  const stream = makeValidStream({
    moderations: {
      min_new_tokens: 1,
      max_new_tokens: 5,
      hap: {
        input: true,
        threshold: 0.01,
      },
    },
  });

  for await (const chunk of stream) {
    validateStreamChunk(chunk);
    if (chunk.moderation) {
      // At least one moderation chunk was delivered — test passes.
      return;
    }
  }
  // Stream ended without any moderation payload: the feature under test failed.
  throw new Error('No moderation chunks have been retrieved from the API');
});

test('should return valid stream for a single input', async () => {
const stream = makeValidStream();

Expand Down

0 comments on commit e9d006c

Please sign in to comment.