Skip to content

Commit

Permalink
fix: ensure redirects handled correctly with dispatchFetch() (#5191)
Browse files Browse the repository at this point in the history
  • Loading branch information
mrbbot authored Mar 8, 2024
1 parent 8e9faf2 commit 27fb22b
Show file tree
Hide file tree
Showing 8 changed files with 264 additions and 40 deletions.
11 changes: 11 additions & 0 deletions .changeset/purple-lemons-yell.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
"miniflare": patch
---

fix: ensure redirect responses handled correctly with `dispatchFetch()`

Previously, if your Worker returned a redirect response, calling `dispatchFetch(url)` would send another request to the original `url` rather than the redirect. This change ensures redirects are followed correctly.

- If your Worker returns a relative redirect or an absolute redirect with the same origin as the original `url`, the request will be sent to the Worker.
- If your Worker instead returns an absolute redirect with a different origin, the request will be sent to the Internet.
- If a redirected request to a different origin returns an absolute redirect with the same origin as the original `url`, the request will also be sent to the Worker.
142 changes: 127 additions & 15 deletions packages/miniflare/src/http/fetch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,11 @@ import http from "http";
import { IncomingRequestCfProperties } from "@cloudflare/workers-types/experimental";
import { Dispatcher, Headers, fetch as baseFetch } from "undici";
import NodeWebSocket from "ws";
import { DeferredPromise } from "../workers";
import { CoreHeaders, DeferredPromise } from "../workers";
import { Request, RequestInfo, RequestInit } from "./request";
import { Response } from "./response";
import { WebSocketPair, coupleWebSocket } from "./websocket";

// `Dispatcher`s don't expose whether they had `rejectUnauthorized` set when
// constructed, but we need to know whether to pass this when constructing
// WebSockets. Instead, we add all known `rejectUnauthorized` dispatchers to
// a weak map, and check that before constructing WebSockets.
const allowUnauthorizedDispatchers = new WeakSet<Dispatcher>();
export function registerAllowUnauthorizedDispatcher(dispatcher: Dispatcher) {
allowUnauthorizedDispatchers.add(dispatcher);
}

const ignored = ["transfer-encoding", "connection", "keep-alive", "expect"];
function headersFromIncomingRequest(req: http.IncomingMessage): Headers {
const entries = Object.entries(req.headers).filter(
Expand Down Expand Up @@ -59,11 +50,11 @@ export async function fetch(
}
}

const rejectUnauthorized =
requestInit?.dispatcher !== undefined &&
allowUnauthorizedDispatchers.has(requestInit?.dispatcher)
? { rejectUnauthorized: false }
: {};
let rejectUnauthorized: { rejectUnauthorized: false } | undefined;
if (requestInit.dispatcher instanceof DispatchFetchDispatcher) {
requestInit.dispatcher.addHeaders(headers, url.pathname + url.search);
rejectUnauthorized = { rejectUnauthorized: false };
}

// Establish web socket connection
const ws = new NodeWebSocket(url, protocols, {
Expand Down Expand Up @@ -106,3 +97,124 @@ export type DispatchFetch = (
input: RequestInfo,
init?: RequestInit<Partial<IncomingRequestCfProperties>>
) => Promise<Response>;

export type AnyHeaders = http.IncomingHttpHeaders | string[];
function addHeader(/* mut */ headers: AnyHeaders, key: string, value: string) {
if (Array.isArray(headers)) headers.push(key, value);
else headers[key] = value;
}

/**
* Dispatcher created for each `dispatchFetch()` call. Ensures request origin
* in Worker matches that passed to `dispatchFetch()`, not the address the
* `workerd` server is listening on. Handles cases where `fetch()` redirects to
* same origin and different external origins.
*/
export class DispatchFetchDispatcher extends Dispatcher {
private readonly cfBlobJson?: string;

/**
* @param globalDispatcher Dispatcher to use for all non-runtime requests
* (rejects unauthorised certificates)
* @param runtimeDispatcher Dispatcher to use for runtime requests
* (permits unauthorised certificates)
* @param actualRuntimeOrigin Origin to send all runtime requests to
* @param userRuntimeOrigin Origin to treat as runtime request
* (initial URL passed by user to `dispatchFetch()`)
* @param cfBlob `request.cf` blob override for runtime requests
*/
constructor(
private readonly globalDispatcher: Dispatcher,
private readonly runtimeDispatcher: Dispatcher,
private readonly actualRuntimeOrigin: string,
private readonly userRuntimeOrigin: string,
cfBlob?: IncomingRequestCfProperties
) {
super();
if (cfBlob !== undefined) this.cfBlobJson = JSON.stringify(cfBlob);
}

addHeaders(
/* mut */ headers: AnyHeaders,
path: string // Including query parameters
) {
// Reconstruct URL using runtime origin specified with `dispatchFetch()`
const originalURL = this.userRuntimeOrigin + path;
addHeader(headers, CoreHeaders.ORIGINAL_URL, originalURL);
addHeader(headers, CoreHeaders.DISABLE_PRETTY_ERROR, "true");
if (this.cfBlobJson !== undefined) {
// Only add this header if a `cf` override was set
addHeader(headers, CoreHeaders.CF_BLOB, this.cfBlobJson);
}
}

dispatch(
/* mut */ options: Dispatcher.DispatchOptions,
handler: Dispatcher.DispatchHandlers
): boolean {
let origin = String(options.origin);
// The first request in a redirect chain will always match the user origin
if (origin === this.userRuntimeOrigin) origin = this.actualRuntimeOrigin;
if (origin === this.actualRuntimeOrigin) {
// If this is now a request to the runtime, rewrite dispatching origin to
// the runtime's
options.origin = origin;

let path = options.path;
if (options.query !== undefined) {
// `options.path` may include query parameters, so we need to parse it
const url = new URL(path, "http://placeholder/");
for (const [key, value] of Object.entries(options.query)) {
url.searchParams.append(key, value);
}
path = url.pathname + url.search;
}

// ...and add special Miniflare headers for runtime requests
options.headers ??= {};
this.addHeaders(options.headers, path);

// Dispatch with runtime dispatcher to avoid certificate errors if using
// self-signed certificate
return this.runtimeDispatcher.dispatch(options, handler);
} else {
// If this wasn't a request to the runtime (e.g. redirect to somewhere
// else), use the regular global dispatcher, without special headers
return this.globalDispatcher.dispatch(options, handler);
}
}

close(): Promise<void>;
close(callback: () => void): void;
async close(callback?: () => void): Promise<void> {
await Promise.all([
this.globalDispatcher.close(),
this.runtimeDispatcher.close(),
]);
callback?.();
}

destroy(): Promise<void>;
destroy(err: Error | null): Promise<void>;
destroy(callback: () => void): void;
destroy(err: Error | null, callback: () => void): void;
async destroy(
errCallback?: Error | null | (() => void),
callback?: () => void
): Promise<void> {
let err: Error | null = null;
if (typeof errCallback === "function") callback = errCallback;
if (errCallback instanceof Error) err = errCallback;

await Promise.all([
this.globalDispatcher.destroy(err),
this.runtimeDispatcher.destroy(err),
]);
callback?.();
}

get isMockActive(): boolean {
// @ts-expect-error missing type on `MockAgent`, but exists at runtime
return this.globalDispatcher.isMockActive ?? false;
}
}
6 changes: 3 additions & 3 deletions packages/miniflare/src/http/server.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import fs from "fs/promises";
import { z } from "zod";
import { CORE_PLUGIN, HEADER_CF_BLOB } from "../plugins";
import { CORE_PLUGIN } from "../plugins";
import { HttpOptions, Socket_Https } from "../runtime";
import { Awaitable } from "../workers";
import { Awaitable, CoreHeaders } from "../workers";
import { CERT, KEY } from "./cert";

export const ENTRY_SOCKET_HTTP_OPTIONS: HttpOptions = {
// Even though we inject a `cf` object in the entry worker, allow it to
// be customised via `dispatchFetch`
cfBlobHeader: HEADER_CF_BLOB,
cfBlobHeader: CoreHeaders.CF_BLOB,
};

export async function getEntrySocketHttpOptions(
Expand Down
32 changes: 19 additions & 13 deletions packages/miniflare/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@ import type {
import exitHook from "exit-hook";
import { $ as colors$ } from "kleur/colors";
import stoppable from "stoppable";
import { Dispatcher, Pool } from "undici";
import { Dispatcher, Pool, getGlobalDispatcher } from "undici";
import SCRIPT_MINIFLARE_SHARED from "worker:shared/index";
import SCRIPT_MINIFLARE_ZOD from "worker:shared/zod";
import { WebSocketServer } from "ws";
import { z } from "zod";
import { fallbackCf, setupCf } from "./cf";
import {
DispatchFetch,
DispatchFetchDispatcher,
ENTRY_SOCKET_HTTP_OPTIONS,
Headers,
Request,
Expand All @@ -39,13 +40,11 @@ import {
fetch,
getAccessibleHosts,
getEntrySocketHttpOptions,
registerAllowUnauthorizedDispatcher,
} from "./http";
import {
D1_PLUGIN_NAME,
DURABLE_OBJECTS_PLUGIN_NAME,
DurableObjectClassNames,
HEADER_CF_BLOB,
KV_PLUGIN_NAME,
PLUGIN_ENTRIES,
PluginServicesOptions,
Expand Down Expand Up @@ -816,8 +815,8 @@ export class Miniflare {
}

// Extract cf blob (if any) from headers
const cfBlob = headers.get(HEADER_CF_BLOB);
headers.delete(HEADER_CF_BLOB);
const cfBlob = headers.get(CoreHeaders.CF_BLOB);
headers.delete(CoreHeaders.CF_BLOB);
assert(!Array.isArray(cfBlob)); // Only `Set-Cookie` headers are arrays
const cf = cfBlob ? JSON.parse(cfBlob) : undefined;

Expand Down Expand Up @@ -1336,7 +1335,6 @@ export class Miniflare {
this.#runtimeDispatcher = new Pool(this.#runtimeEntryURL, {
connect: { rejectUnauthorized: false },
});
registerAllowUnauthorizedDispatcher(this.#runtimeDispatcher);
}
if (this.#proxyClient === undefined) {
this.#proxyClient = new ProxyClient(
Expand Down Expand Up @@ -1508,14 +1506,13 @@ export class Miniflare {

const forward = new Request(input, init);
const url = new URL(forward.url);
forward.headers.set(CoreHeaders.ORIGINAL_URL, url.toString());
forward.headers.set(CoreHeaders.DISABLE_PRETTY_ERROR, "true");
const actualRuntimeOrigin = this.#runtimeEntryURL.origin;
const userRuntimeOrigin = url.origin;

// Rewrite URL for WebSocket requests which won't use `DispatchFetchDispatcher`
url.protocol = this.#runtimeEntryURL.protocol;
url.host = this.#runtimeEntryURL.host;
if (forward.cf) {
const cf = { ...fallbackCf, ...forward.cf };
forward.headers.set(HEADER_CF_BLOB, JSON.stringify(cf));
}

// Remove `Content-Length: 0` headers from requests when a body is set to
// avoid `RequestContentLengthMismatch` errors
if (
Expand All @@ -1525,8 +1522,17 @@ export class Miniflare {
forward.headers.delete("Content-Length");
}

const cfBlob = forward.cf ? { ...fallbackCf, ...forward.cf } : undefined;
const dispatcher = new DispatchFetchDispatcher(
getGlobalDispatcher(),
this.#runtimeDispatcher,
actualRuntimeOrigin,
userRuntimeOrigin,
cfBlob
);

const forwardInit = forward as RequestInit;
forwardInit.dispatcher = this.#runtimeDispatcher;
forwardInit.dispatcher = dispatcher;
const response = await fetch(url, forwardInit);

// If the Worker threw an uncaught exception, propagate it to the caller
Expand Down
3 changes: 1 addition & 2 deletions packages/miniflare/src/plugins/core/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ import {
import { getCacheServiceName } from "../cache";
import { DURABLE_OBJECTS_STORAGE_SERVICE_NAME } from "../do";
import {
HEADER_CF_BLOB,
Plugin,
SERVICE_LOOPBACK,
WORKER_BINDING_SERVICE_LOOPBACK,
Expand Down Expand Up @@ -716,7 +715,7 @@ export function getGlobalServices({
return [
{
name: SERVICE_LOOPBACK,
external: { http: { cfBlobHeader: HEADER_CF_BLOB } },
external: { http: { cfBlobHeader: CoreHeaders.CF_BLOB } },
},
{
name: SERVICE_ENTRY,
Expand Down
5 changes: 0 additions & 5 deletions packages/miniflare/src/plugins/shared/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@ export function getDirectSocketName(workerIndex: number) {
// Service looping back to Miniflare's Node.js process (for storage, etc)
export const SERVICE_LOOPBACK = "loopback";

// Even though we inject the `cf` blob in the entry script, we still need to
// specify a header, so we receive things like `cf.cacheKey` in loopback
// requests.
export const HEADER_CF_BLOB = "MF-CF-Blob";

export const WORKER_BINDING_SERVICE_LOOPBACK: Worker_Binding = {
name: CoreBindings.SERVICE_LOOPBACK,
service: { name: SERVICE_LOOPBACK },
Expand Down
1 change: 1 addition & 0 deletions packages/miniflare/src/workers/core/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ export const CoreHeaders = {
DISABLE_PRETTY_ERROR: "MF-Disable-Pretty-Error",
ERROR_STACK: "MF-Experimental-Error-Stack",
ROUTE_OVERRIDE: "MF-Route-Override",
CF_BLOB: "MF-CF-Blob",

// API Proxy
OP_SECRET: "MF-Op-Secret",
Expand Down
Loading

0 comments on commit 27fb22b

Please sign in to comment.