Skip to content

Commit

Permalink
Add API key to gRPC server and client (#394)
Browse files Browse the repository at this point in the history
* Add API key to gRPC server and client

* Address feedback

* Address more comments
  • Loading branch information
plameniv authored May 30, 2024
1 parent 5f8e2ca commit 9e54d65
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 34 deletions.
2 changes: 1 addition & 1 deletion packages/grpc/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@walmartlabs/cookie-cutter-grpc",
"version": "1.6.0-beta.2",
"version": "1.6.0-beta.3",
"license": "Apache-2.0",
"main": "dist/index.js",
"types": "dist/index.d.ts",
Expand Down
86 changes: 73 additions & 13 deletions packages/grpc/src/__test__/grpc.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@ import {
GrpcMetadata,
grpcSource,
IGrpcClientConfiguration,
IGrpcClientOptions,
IGrpcConfiguration,
IGrpcServerOptions,
IResponseStream,
} from "..";
import { sample } from "./Sample";

const apiKey = "token";
let nextPort = 56011;

export interface ISampleService {
Expand Down Expand Up @@ -78,16 +81,23 @@ export const SampleServiceDefinition = {
},
};

function testApp(handler: any, host?: string): CancelablePromise<void> {
function testApp(
handler: any,
host?: string,
options?: IGrpcServerOptions
): CancelablePromise<void> {
return Application.create()
.input()
.add(
grpcSource({
port: nextPort,
host,
definitions: [SampleServiceDefinition],
skipNoStreamingValidation: true,
})
grpcSource(
{
port: nextPort,
host,
definitions: [SampleServiceDefinition],
skipNoStreamingValidation: true,
},
options
)
)
.done()
.dispatch(handler)
Expand All @@ -96,13 +106,17 @@ function testApp(handler: any, host?: string): CancelablePromise<void> {

async function createClient(
host?: string,
config?: Partial<IGrpcClientConfiguration & IGrpcConfiguration>
config?: Partial<IGrpcClientConfiguration & IGrpcConfiguration>,
options?: string | IGrpcClientOptions
): Promise<ISampleService & IRequireInitialization & IDisposable> {
const client = grpcClient<ISampleService & IRequireInitialization & IDisposable>({
endpoint: `${host || "localhost"}:${nextPort++}`,
definition: SampleServiceDefinition,
...config,
});
const client = grpcClient<ISampleService & IRequireInitialization & IDisposable>(
{
endpoint: `${host || "localhost"}:${nextPort++}`,
definition: SampleServiceDefinition,
...config,
},
options
);
return client;
}

Expand All @@ -126,6 +140,29 @@ describe("gRPC source", () => {
}
});

it("serves requests with api key validation", async () => {
const app = testApp(
{
onNoStreaming: async (
request: sample.ISampleRequest,
_: IDispatchContext
): Promise<sample.ISampleResponse> => {
return { name: request.id.toString() };
},
},
undefined,
{ apiKey }
);
try {
const client = await createClient(undefined, undefined, { apiKey });
const response = await client.NoStreaming({ id: 15 });
expect(response).toMatchObject({ name: "15" });
} finally {
app.cancel();
await app;
}
});

it("serves response streams", async () => {
const app = testApp({
onStreamingOut: async (
Expand Down Expand Up @@ -235,6 +272,29 @@ describe("gRPC source", () => {
}
});

it("throws error for missing/invalid api key", async () => {
const app = testApp(
{
onNoStreaming: async (
request: sample.ISampleRequest,
_: IDispatchContext
): Promise<sample.ISampleResponse> => {
return { name: request.id.toString() };
},
},
undefined,
{ apiKey }
);
try {
const client = await createClient();
const response = client.NoStreaming({ id: 15 });
await expect(response).rejects.toThrowError(/Invalid API Key/);
} finally {
app.cancel();
await app;
}
});

it("validates that no streaming operations are exposed", () => {
const a = () =>
grpcSource({
Expand Down
21 changes: 17 additions & 4 deletions packages/grpc/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ export interface IGrpcServerConfiguration {
readonly skipNoStreamingValidation?: boolean;
}

export interface IGrpcServerOptions {
readonly apiKey?: string;
}

export interface IGrpcClientConfiguration {
readonly endpoint: string;
readonly definition: IGrpcServiceDefinition;
Expand All @@ -66,6 +70,11 @@ export interface IGrpcClientConfiguration {
readonly behavior?: Required<IComponentRuntimeBehavior>;
}

export interface IGrpcClientOptions {
readonly certPath?: string;
readonly apiKey?: string;
}

export enum GrpcMetadata {
OperationPath = "grpc.OperationPath",
ResponseStream = "grpc.ResponseStream",
Expand All @@ -80,7 +89,8 @@ export interface IResponseStream<TResponse> {
}

export function grpcSource(
configuration: IGrpcServerConfiguration & IGrpcConfiguration
configuration: IGrpcServerConfiguration & IGrpcConfiguration,
options?: IGrpcServerOptions
): IInputSource & IRequireInitialization {
configuration = config.parse<IGrpcServerConfiguration & IGrpcConfiguration>(
GrpcSourceConfiguration,
Expand All @@ -90,7 +100,7 @@ export function grpcSource(
allocator: Buffer,
}
);
return new GrpcInputSource(configuration);
return new GrpcInputSource(configuration, options);
}

export function grpcMsg(operation: IGrpcServiceMethod, request: any): IMessage {
Expand All @@ -102,7 +112,7 @@ export function grpcMsg(operation: IGrpcServiceMethod, request: any): IMessage {

export function grpcClient<T>(
configuration: IGrpcClientConfiguration & IGrpcConfiguration,
certPath?: string
certPathOrOptions?: string | IGrpcClientOptions
): T & IRequireInitialization & IDisposable {
configuration = config.parse<IGrpcClientConfiguration & IGrpcConfiguration>(
GrpcClientConfiguration,
Expand All @@ -122,5 +132,8 @@ export function grpcClient<T>(
},
}
);
return createGrpcClient<T>(configuration, certPath);
if (typeof certPathOrOptions === "string") {
certPathOrOptions = { certPath: certPathOrOptions };
}
return createGrpcClient<T>(configuration, certPathOrOptions);
}
29 changes: 25 additions & 4 deletions packages/grpc/src/internal/GrpcInputSource.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
OpenTracingTagKeys,
} from "@walmartlabs/cookie-cutter-core";
import {
Metadata,
sendUnaryData,
Server,
ServerCredentials,
Expand All @@ -38,7 +39,7 @@ import {
GrpcResponseStream,
GrpcStreamHandler,
} from ".";
import { GrpcMetadata, IGrpcConfiguration, IGrpcServerConfiguration } from "..";
import { GrpcMetadata, IGrpcConfiguration, IGrpcServerConfiguration, IGrpcServerOptions } from "..";
import { GrpcOpenTracingTagKeys } from "./helper";

enum GrpcMetrics {
Expand All @@ -59,7 +60,10 @@ export class GrpcInputSource implements IInputSource, IRequireInitialization {
private tracer: Tracer;
private metrics: IMetrics;

constructor(private readonly config: IGrpcServerConfiguration & IGrpcConfiguration) {
constructor(
private readonly config: IGrpcServerConfiguration & IGrpcConfiguration,
private readonly options?: IGrpcServerOptions
) {
if (!config.skipNoStreamingValidation) {
for (const def of config.definitions) {
for (const key of Object.keys(def)) {
Expand Down Expand Up @@ -180,7 +184,11 @@ export class GrpcInputSource implements IInputSource, IRequireInitialization {
if (value !== undefined) {
callback(undefined, value);
} else if (error !== undefined) {
callback(this.createError(error), null);
if ((error as ServerErrorResponse).code !== undefined) {
callback(error, null);
} else {
callback(this.createError(error), null);
}
} else {
callback(
this.createError("not implemented", status.UNIMPLEMENTED),
Expand All @@ -198,7 +206,15 @@ export class GrpcInputSource implements IInputSource, IRequireInitialization {
path: method.path,
});
});

if (this.options?.apiKey) {
if (!this.isApiKeyValid(call.metadata)) {
await msgRef.release(
undefined,
this.createError("Invalid API Key", status.UNAUTHENTICATED)
);
return;
}
}
if (!(await this.queue.enqueue(msgRef))) {
await msgRef.release(undefined, new Error("service unavailable"));
}
Expand Down Expand Up @@ -239,4 +255,9 @@ export class GrpcInputSource implements IInputSource, IRequireInitialization {
message: error.toString(),
};
}

private isApiKeyValid(meta: Metadata) {
const headerValue = meta.get("authorization");
return headerValue?.[0]?.toString() === this.options.apiKey;
}
}
36 changes: 24 additions & 12 deletions packages/grpc/src/internal/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import {
import { FORMAT_HTTP_HEADERS, Span, SpanContext, Tags, Tracer } from "opentracing";
import { performance } from "perf_hooks";
import { createGrpcConfiguration, createServiceDefinition } from ".";
import { IGrpcClientConfiguration, IGrpcConfiguration } from "..";
import { IGrpcClientConfiguration, IGrpcClientOptions, IGrpcConfiguration } from "..";

enum GrpcMetrics {
RequestSent = "cookie_cutter.grpc_client.request_sent",
Expand Down Expand Up @@ -75,23 +75,27 @@ class ClientBase implements IRequireInitialization, IDisposable {

export function createGrpcClient<T>(
config: IGrpcClientConfiguration & IGrpcConfiguration,
certPath?: string
options?: IGrpcClientOptions
): T & IDisposable & IRequireInitialization {
const serviceDef = createServiceDefinition(config.definition);
let client: Client;
const ClientType = makeGenericClientConstructor(serviceDef, undefined, undefined);
const certPath = options?.certPath;
const apiKey = options?.apiKey;
if (certPath) {
const rootCert = readFileSync(certPath);
const channelCreds = credentials.createSsl(rootCert);

const metaCallback = (_params: any, callback: (arg0: null, arg1: Metadata) => void) => {
const meta = new Metadata();
meta.add("custom-auth-header", "token");
callback(null, meta);
};

const callCreds = credentials.createFromMetadataGenerator(metaCallback);
const combCreds = credentials.combineChannelCredentials(channelCreds, callCreds);
let combCreds = channelCreds;
if (apiKey) {
const metaCallback = (_params: any, callback: (arg0: null, arg1: Metadata) => void) => {
const meta = new Metadata();
meta.add("authorization", apiKey);
callback(null, meta);
};
const callCreds = credentials.createFromMetadataGenerator(metaCallback);
combCreds = credentials.combineChannelCredentials(channelCreds, callCreds);
}
client = new ClientType(config.endpoint, combCreds, createGrpcConfiguration(config));
} else {
client = new ClientType(
Expand Down Expand Up @@ -165,12 +169,16 @@ export function createGrpcClient<T>(

const stream = await retrier.retry((bail) => {
try {
const meta = createTracingMetadata(wrapper.tracer, span);
if (!certPath && apiKey) {
meta.set("authorization", apiKey);
}
return client.makeServerStreamRequest(
method.path,
method.requestSerialize,
method.responseDeserialize,
request,
createTracingMetadata(wrapper.tracer, span),
meta,
callOptions()
);
} catch (e) {
Expand Down Expand Up @@ -239,12 +247,16 @@ export function createGrpcClient<T>(
return await retrier.retry(async (bail) => {
try {
return await new Promise((resolve, reject) => {
const meta = createTracingMetadata(wrapper.tracer, span);
if (!certPath && apiKey) {
meta.set("authorization", apiKey);
}
client.makeUnaryRequest(
method.path,
method.requestSerialize,
method.responseDeserialize,
request,
createTracingMetadata(wrapper.tracer, span),
meta,
callOptions(),
(error, value) => {
this.metrics.increment(GrpcMetrics.RequestProcessed, {
Expand Down

0 comments on commit 9e54d65

Please sign in to comment.