diff --git a/index.js b/index.js index 14f439a2334..58ab2e7bf9c 100644 --- a/index.js +++ b/index.js @@ -50,7 +50,8 @@ module.exports.interceptors = { dns: require('./lib/interceptor/dns'), cache: require('./lib/interceptor/cache'), decompress: require('./lib/interceptor/decompress'), - deduplicate: require('./lib/interceptor/deduplicate') + deduplicate: require('./lib/interceptor/deduplicate'), + signal: require('./lib/interceptor/signal') } module.exports.cacheStores = { diff --git a/lib/interceptor/signal.js b/lib/interceptor/signal.js new file mode 100644 index 00000000000..d8e586bb350 --- /dev/null +++ b/lib/interceptor/signal.js @@ -0,0 +1,93 @@ +'use strict' + +const { RequestAbortedError } = require('../core/errors') +const DecoratorHandler = require('../handler/decorator-handler') + +class SignalHandler extends DecoratorHandler { + #signal + #listener + #controller + #aborted = false + + constructor ({ signal }, { handler }) { + super(handler) + this.#signal = signal + this.#listener = null + this.#controller = null + } + + onRequestStart (controller, context) { + this.#controller = controller + + if (!this.#signal) { + return super.onRequestStart(controller, context) + } + + if (this.#signal.aborted) { + this.#abort() + return + } + + this.#listener = () => { + this.#abort() + } + + if ('addEventListener' in this.#signal) { + this.#signal.addEventListener('abort', this.#listener) + } else { + this.#signal.on('abort', this.#listener) + } + + return super.onRequestStart(controller, context) + } + + #abort () { + if (this.#aborted) return + this.#aborted = true + + const reason = this.#signal?.reason ?? new RequestAbortedError() + if (this.#controller) { + this.#controller.abort(reason) + } + this.#removeSignal() + } + + #removeSignal () { + if (!this.#signal || !this.#listener) { + return + } + + if ('removeEventListener' in this.#signal) { + this.#signal.removeEventListener('abort', this.#listener) + } else { + this.#signal.removeListener('abort', this.#listener) + } + + this.#signal = null + this.#listener = null + } + + onResponseEnd (controller, trailers) { + this.#removeSignal() + return super.onResponseEnd(controller, trailers) + } + + onResponseError (controller, err) { + this.#removeSignal() + return super.onResponseError(controller, err) + } +} + +module.exports = () => { + return (dispatch) => { + return function signalInterceptor (opts, handler) { + const { signal } = opts + + if (!signal) { + return dispatch(opts, handler) + } + + return dispatch(opts, new SignalHandler({ signal }, { handler })) + } + } +} diff --git a/test/interceptors/signal.js b/test/interceptors/signal.js new file mode 100644 index 00000000000..4001e7a3b69 --- /dev/null +++ b/test/interceptors/signal.js @@ -0,0 +1,207 @@ +'use strict' + +const { createServer } = require('node:http') +const { test, after } = require('node:test') +const { once } = require('node:events') +const { tspl } = require('@matteo.collina/tspl') +const { Client, interceptors } = require('../..') +const { signal } = interceptors + +test('should abort request when signal is already aborted', async (t) => { + t = tspl(t, { plan: 1 }) + + const server = createServer({ joinDuplicateHeaders: true }, (req, res) => { + res.end('asd') + }) + + server.listen(0) + await once(server, 'listening') + + const ac = new AbortController() + const _err = new Error('Custom abort reason') + ac.abort(_err) + + const client = new Client( + `http://localhost:${server.address().port}` + ).compose(signal()) + + after(async () => { + await client.close() + server.close() + await once(server, 'close') + }) + + try { + await client.request({ method: 'GET', path: '/', signal: ac.signal }) + } catch (err) { + t.equal(err, _err) + } + + await t.completed +}) + +test('should abort request when signal is aborted after request starts', async (t) => { + t = tspl(t, { plan: 1 }) + + const server = createServer({ joinDuplicateHeaders: true }, (req, res) => { + res.end('asd') + }) + + server.listen(0) + await once(server, 'listening') + + const ac = new AbortController() + + const client = new Client( + `http://localhost:${server.address().port}` + ).compose(signal()) + + after(async () => { + await client.close() + server.close() + await once(server, 'close') + }) + + const ures = await client.request({ method: 'GET', path: '/', signal: ac.signal }) + ac.abort() + + try { + /* eslint-disable-next-line no-unused-vars */ + for await (const chunk of ures.body) { + // Do nothing... + } + } catch (err) { + t.equal(err.name, 'AbortError') + } + + await t.completed +}) + +test('should abort request with custom reason when signal is aborted', async (t) => { + t = tspl(t, { plan: 1 }) + + const server = createServer({ joinDuplicateHeaders: true }, (req, res) => { + res.end('asd') + }) + + server.listen(0) + await once(server, 'listening') + + const ac = new AbortController() + const _err = new Error('Custom abort reason') + + const client = new Client( + `http://localhost:${server.address().port}` + ).compose(signal()) + + after(async () => { + await client.close() + server.close() + await once(server, 'close') + }) + + const ures = await client.request({ method: 'GET', path: '/', signal: ac.signal }) + ac.abort(_err) + try { + /* eslint-disable-next-line no-unused-vars */ + for await (const chunk of ures.body) { + // Do nothing... + } + } catch (err) { + t.equal(err, _err) + } + + await t.completed +}) + +test('should not interfere when signal is not provided', async (t) => { + t = tspl(t, { plan: 1 }) + + const server = createServer({ joinDuplicateHeaders: true }, (req, res) => { + res.end('hello') + }) + + server.listen(0) + await once(server, 'listening') + + const client = new Client( + `http://localhost:${server.address().port}` + ).compose(signal()) + + after(async () => { + await client.close() + server.close() + await once(server, 'close') + }) + + const response = await client.request({ method: 'GET', path: '/' }) + const body = await response.body.text() + t.equal(body, 'hello') + + await t.completed +}) + +test('should cleanup abort listener on response end', async (t) => { + t = tspl(t, { plan: 1 }) + + const server = createServer({ joinDuplicateHeaders: true }, (req, res) => { + res.end('hello') + }) + + server.listen(0) + await once(server, 'listening') + + const ac = new AbortController() + + const client = new Client( + `http://localhost:${server.address().port}` + ).compose(signal()) + + after(async () => { + await client.close() + server.close() + await once(server, 'close') + }) + + const response = await client.request({ method: 'GET', path: '/', signal: ac.signal }) + await response.body.text() + + ac.abort() + t.ok(true, 'Cleanup successful') + + await t.completed +}) + +test('should cleanup abort listener on response error', async (t) => { + t = tspl(t, { plan: 2 }) + + const server = createServer({ joinDuplicateHeaders: true }, (req, res) => { + res.destroy() + }) + + server.listen(0) + await once(server, 'listening') + + const ac = new AbortController() + + const client = new Client( + `http://localhost:${server.address().port}` + ).compose(signal()) + + after(async () => { + await client.close() + server.close() + await once(server, 'close') + }) + + try { + await client.request({ method: 'GET', path: '/', signal: ac.signal }) + } catch (err) { + t.ok(err) + } + + ac.abort() + t.ok(true, 'Cleanup successful') + + await t.completed +})