diff --git a/javascript/selenium-webdriver/bidi/network.js b/javascript/selenium-webdriver/bidi/network.js index 8e860b33f2bdd..4f40eed6604c6 100644 --- a/javascript/selenium-webdriver/bidi/network.js +++ b/javascript/selenium-webdriver/bidi/network.js @@ -154,26 +154,26 @@ class Network { this.ws = await this.bidi.socket this.ws.on('message', (event) => { - const { params } = JSON.parse(Buffer.from(event.toString())) + const { method, params } = JSON.parse(Buffer.from(event.toString())) if (params) { let response = null - if ('initiator' in params) { - response = new BeforeRequestSent( + if ('request' in params && 'response' in params) { + response = new ResponseStarted( params.context, params.navigation, params.redirectCount, params.request, params.timestamp, - params.initiator, + params.response, ) - } else if ('response' in params) { - response = new ResponseStarted( + } else if ('initiator' in params && !('response' in params)) { + response = new BeforeRequestSent( params.context, params.navigation, params.redirectCount, params.request, params.timestamp, - params.response, + params.initiator, ) } else if ('errorText' in params) { response = new FetchError( @@ -185,7 +185,7 @@ class Network { params.errorText, ) } - this.invokeCallbacks(eventType, response) + this.invokeCallbacks(method, response) } }) return id diff --git a/javascript/selenium-webdriver/bidi/networkTypes.js b/javascript/selenium-webdriver/bidi/networkTypes.js index 8a62aaa5d3964..5a2d4ea35a736 100644 --- a/javascript/selenium-webdriver/bidi/networkTypes.js +++ b/javascript/selenium-webdriver/bidi/networkTypes.js @@ -118,6 +118,17 @@ class Header { get value() { return this._value } + + /** + * Converts the Header to a map. + * @returns {Map} A map representation of the Header. + */ + asMap() { + const map = new Map() + map.set('name', this._name) + map.set('value', Object.fromEntries(this._value.asMap())) + return map + } } /** diff --git a/javascript/selenium-webdriver/lib/http.js b/javascript/selenium-webdriver/lib/http.js index d09e4fef8bbf9..18d16f7caed19 100644 --- a/javascript/selenium-webdriver/lib/http.js +++ b/javascript/selenium-webdriver/lib/http.js @@ -88,12 +88,18 @@ class Request { * @param {string} method The HTTP method to use for the request. * @param {string} path The path on the server to send the request to. * @param {Object=} opt_data This request's non-serialized JSON payload data. + * @param {Map} [headers=new Map()] - The optional headers as a Map. */ - constructor(method, path, opt_data) { + constructor(method, path, opt_data, headers = new Map()) { this.method = /** string */ method this.path = /** string */ path this.data = /** Object */ opt_data - this.headers = /** !Map */ new Map([['Accept', 'application/json; charset=utf-8']]) + + if (headers.size > 0) { + this.headers = headers + } else { + this.headers = /** !Map */ new Map([['Accept', 'application/json; charset=utf-8']]) + } } /** @override */ diff --git a/javascript/selenium-webdriver/lib/network.js b/javascript/selenium-webdriver/lib/network.js index cfc5873804d53..bdca20c998cdc 100644 --- a/javascript/selenium-webdriver/lib/network.js +++ b/javascript/selenium-webdriver/lib/network.js @@ -18,12 +18,18 @@ const { Network: getNetwork } = require('../bidi/network') const { InterceptPhase } = require('../bidi/interceptPhase') const { AddInterceptParameters } = require('../bidi/addInterceptParameters') +const { ContinueRequestParameters } = require('../bidi/continueRequestParameters') +const { ProvideResponseParameters } = require('../bidi/provideResponseParameters') +const { Request } = require('./http') +const { BytesValue, Header } = require('../bidi/networkTypes') class Network { #callbackId = 0 #driver #network #authHandlers = new Map() + #requestHandlers = new Map() + #responseHandlers = new Map() constructor(driver) { this.#driver = driver @@ -43,6 +49,8 @@ class Network { await this.#network.addIntercept(new AddInterceptParameters(InterceptPhase.AUTH_REQUIRED)) + await this.#network.addIntercept(new AddInterceptParameters(InterceptPhase.BEFORE_REQUEST_SENT)) + await this.#network.authRequired(async (event) => { const requestId = event.request.request const uri = event.request.url @@ -54,6 +62,76 @@ class Network { await this.#network.continueWithAuthNoCredentials(requestId) }) + + await this.#network.beforeRequestSent(async (event) => { + const requestId = event.request.request + const requestData = event.request + + // Build the original request from the intercepted request details. + const originalRequest = new Request(requestData.method, requestData.url, null, new Map(requestData.headers)) + + let requestHandler = this.getRequestHandler(originalRequest) + let responseHandler = this.getResponseHandler(originalRequest) + + // Populate the headers of the original request. + // Body is not available as part of WebDriver Spec, hence we cannot add that or use that. + + const continueRequestParams = new ContinueRequestParameters(requestId) + + // If a response handler exists, we mock the response instead of modifying the outgoing request + if (responseHandler !== null) { + const modifiedResponse = await responseHandler() + + const provideResponseParams = new ProvideResponseParameters(requestId) + provideResponseParams.statusCode(modifiedResponse.status) + + // Convert headers + if (modifiedResponse.headers.size > 0) { + const headers = [] + + modifiedResponse.headers.forEach((value, key) => { + headers.push(new Header(key, new BytesValue('string', value))) + }) + provideResponseParams.headers(headers) + } + + // Convert body if available + if (modifiedResponse.body && modifiedResponse.body.length > 0) { + provideResponseParams.body(new BytesValue('string', modifiedResponse.body)) + } + + await this.#network.provideResponse(provideResponseParams) + return + } + + // If request handler exists, modify the request + if (requestHandler !== null) { + const modifiedRequest = requestHandler(originalRequest) + + continueRequestParams.method(modifiedRequest.method) + + if (originalRequest.path !== modifiedRequest.path) { + continueRequestParams.url(modifiedRequest.path) + } + + // Convert headers + if (modifiedRequest.headers.size > 0) { + const headers = [] + + modifiedRequest.headers.forEach((value, key) => { + headers.push(new Header(key, new BytesValue('string', value))) + }) + continueRequestParams.headers(headers) + } + + if (modifiedRequest.data && modifiedRequest.data.length > 0) { + continueRequestParams.body(new BytesValue('string', modifiedRequest.data)) + } + } + + // Continue with the modified or original request + await this.#network.continueRequest(continueRequestParams) + }) } getAuthCredentials(uri) { @@ -64,6 +142,27 @@ class Network { } return null } + + getRequestHandler(req) { + for (let [, value] of this.#requestHandlers) { + const filter = value.filter + if (filter(req)) { + return value.handler + } + } + return null + } + + getResponseHandler(req) { + for (let [, value] of this.#responseHandlers) { + const filter = value.filter + if (filter(req)) { + return value.handler + } + } + return null + } + async addAuthenticationHandler(username, password, uri = '//') { await this.#init() @@ -86,6 +185,82 @@ class Network { async clearAuthenticationHandlers() { this.#authHandlers.clear() } + + /** + * Adds a request handler that filters requests based on a predicate function. + * @param {Function} filter - A function that takes an HTTP request and returns true or false. + * @param {Function} handler - A function that takes an HTTP request and returns a modified request. + * @returns {number} - A unique handler ID. + * @throws {Error} - If filter is not a function or handler does not return a request. + */ + async addRequestHandler(filter, handler) { + if (typeof filter !== 'function') { + throw new Error('Filter must be a function.') + } + + if (typeof handler !== 'function') { + throw new Error('Handler must be a function.') + } + + await this.#init() + + const id = this.#callbackId++ + + this.#requestHandlers.set(id, { filter, handler }) + return id + } + + async removeRequestHandler(id) { + await this.#init() + + if (this.#requestHandlers.has(id)) { + this.#requestHandlers.delete(id) + } else { + throw Error(`Callback with id ${id} not found`) + } + } + + async clearRequestHandlers() { + this.#requestHandlers.clear() + } + + /** + * Adds a response handler that mocks responses. + * @param {Function} filter - A function that takes an HTTP request, returning a boolean. + * @param {Function} handler - A function that returns a mocked HTTP response. + * @returns {number} - A unique handler ID. + * @throws {Error} - If filter is not a function or handler is not an async function. + */ + async addResponseHandler(filter, handler) { + if (typeof filter !== 'function') { + throw new Error('Filter must be a function.') + } + + if (typeof handler !== 'function') { + throw new Error('Handler must be a function.') + } + + await this.#init() + + const id = this.#callbackId++ + + this.#responseHandlers.set(id, { filter, handler }) + return id + } + + async removeResponseHandler(id) { + await this.#init() + + if (this.#responseHandlers.has(id)) { + this.#responseHandlers.delete(id) + } else { + throw Error(`Callback with id ${id} not found`) + } + } + + async clearResponseHandlers() { + this.#responseHandlers.clear() + } } module.exports = Network diff --git a/javascript/selenium-webdriver/lib/test/fileserver.js b/javascript/selenium-webdriver/lib/test/fileserver.js index 7023cd8d6fe3b..ce280043d0a5e 100644 --- a/javascript/selenium-webdriver/lib/test/fileserver.js +++ b/javascript/selenium-webdriver/lib/test/fileserver.js @@ -45,6 +45,7 @@ const Pages = (function () { }) } + addPage('addRequestBody', 'addRequestBody') addPage('ajaxyPage', 'ajaxy_page.html') addPage('alertsPage', 'alerts.html') addPage('basicAuth', 'basicAuth') @@ -131,6 +132,7 @@ const Path = { PAGE: WEB_ROOT + '/page', SLEEP: WEB_ROOT + '/sleep', UPLOAD: WEB_ROOT + '/upload', + ADD_REQUEST_BODY: WEB_ROOT + '/addRequestBody', } var app = express() @@ -143,6 +145,7 @@ app }) .use(JS_ROOT, serveIndex(jsDirectory), express.static(jsDirectory)) .post(Path.UPLOAD, handleUpload) + .post(Path.ADD_REQUEST_BODY, addRequestBody) .use(WEB_ROOT, serveIndex(baseDirectory), express.static(baseDirectory)) .use(DATA_ROOT, serveIndex(dataDirectory), express.static(dataDirectory)) .get(Path.ECHO, sendEcho) @@ -187,6 +190,32 @@ function sendInifinitePage(request, response) { response.end(body) } +function addRequestBody(request, response) { + let requestBody = '' + + request.on('data', (chunk) => { + requestBody += chunk + }) + + request.on('end', () => { + let body = [ + '', + '', + 'Page', + '', + `

Request Body:

${requestBody}
`, + '', + '', + ].join('') + + response.writeHead(200, { + 'Content-Length': Buffer.byteLength(body, 'utf8'), + 'Content-Type': 'text/html; charset=utf-8', + }) + response.end(body) + }) +} + function sendBasicAuth(request, response) { const denyAccess = function () { response.writeHead(401, { 'WWW-Authenticate': 'Basic realm="test"' }) diff --git a/javascript/selenium-webdriver/test/lib/webdriver_network_test.js b/javascript/selenium-webdriver/test/lib/webdriver_network_test.js index ff1ce496bf038..2d9a9ad682bf3 100644 --- a/javascript/selenium-webdriver/test/lib/webdriver_network_test.js +++ b/javascript/selenium-webdriver/test/lib/webdriver_network_test.js @@ -22,6 +22,8 @@ const { Browser } = require('selenium-webdriver') const { Pages, suite } = require('../../lib/test') const until = require('selenium-webdriver/lib/until') const { By } = require('selenium-webdriver') +const { Request, Response } = require('selenium-webdriver/http') +const { Network } = require('selenium-webdriver/bidi/network') suite( function (env) { @@ -112,6 +114,206 @@ suite( assert.strictEqual(e.name, 'UnexpectedAlertOpenError') } }) + + it('can add request handler to modify method', async function () { + const filter = (req) => req.path.includes('bidi/logEntryAdded.html') + const handler = () => new Request('HEAD', Pages.logEntryAdded, null) + + await driver.network().addRequestHandler(filter, handler) + + await driver.get(Pages.logEntryAdded) + + const pageSource = await driver.getPageSource() + + assert.strictEqual(pageSource.includes('log entry added events'), false) + }) + + it('can add request handler to modify uri', async function () { + const filter = (req) => req.path.includes('bidi/logEntryAdded.html') + const handler = () => new Request('GET', Pages.blankPage, null) + + await driver.network().addRequestHandler(filter, handler) + + await driver.get(Pages.logEntryAdded) + + const pageSource = await driver.getPageSource() + + assert.strictEqual(pageSource.includes('blank'), true) + }) + + it('can add request handler to modify body', async function () { + const filter = (req) => req.path.includes('bidi/logEntryAdded.html') + const handler = () => new Request('POST', Pages.addRequestBody, 'hello world!') + + await driver.network().addRequestHandler(filter, handler) + + await driver.get(Pages.logEntryAdded) + + const pageSource = await driver.getPageSource() + + assert.strictEqual(pageSource.includes('hello world'), true) + }) + + it('can add multiple request handlers', async function () { + const filter = (req) => req.path.includes('bidi/logEntryAdded.html') + const handler = () => new Request('GET', Pages.blankPage, null) + + await driver.network().addRequestHandler(filter, handler) + + await driver.network().addRequestHandler( + (req) => req.path.includes('hello.html'), + () => new Request('GET', Pages.logEntryAdded, null), + ) + + await driver.get(Pages.logEntryAdded) + + const pageSource = await driver.getPageSource() + + assert.strictEqual(pageSource.includes('blank'), true) + }) + + it('can add multiple request handlers with same filter', async function () { + const filter = (req) => req.path.includes('bidi/logEntryAdded.html') + const handler = () => new Request('GET', Pages.blankPage, null) + + await driver.network().addRequestHandler(filter, handler) + + await driver.network().addRequestHandler(filter, handler) + + await driver.get(Pages.logEntryAdded) + + const pageSource = await driver.getPageSource() + + assert.strictEqual(pageSource.includes('blank'), true) + }) + + it('can remove request handler', async function () { + const filter = (req) => req.path.includes('bidi/logEntryAdded.html') + const handler = () => new Request('GET', Pages.blankPage, null) + + const id = await driver.network().addRequestHandler(filter, handler) + + await driver.network().removeRequestHandler(id) + + await driver.get(Pages.logEntryAdded) + + const pageSource = await driver.getPageSource() + + assert.strictEqual(pageSource.includes('entry added'), true) + }) + + it('throws an error when removing request handler that does not exist', async function () { + try { + await driver.network().removeRequestHandler(10) + assert.fail('Expected error not thrown. Non-existent handler cannot be removed') + } catch (e) { + assert.strictEqual(e.message, 'Callback with id 10 not found') + } + }) + + it('can clear request handlers', async function () { + const filter = (req) => req.path.includes('bidi/logEntryAdded.html') + const handler = () => new Request('GET', Pages.blankPage, null) + + await driver.network().addRequestHandler(filter, handler) + + await driver.network().addRequestHandler( + (req) => req.path.includes('hello.html'), + () => new Request('GET', Pages.logEntryAdded, null), + ) + + await driver.network().clearRequestHandlers() + + await driver.get(Pages.logEntryAdded) + + const pageSource = await driver.getPageSource() + + assert.strictEqual(pageSource.includes('entry added'), true) + }) + + it('can add response handler to mock complete response', async function () { + const filter = (req) => req.path.includes('bidi/logEntryAdded.html') + const handler = () => new Response(500, { test: 'header-value' }, 'Internal server error') + + const network = await Network(driver) + + let onResponseCompleted = null + + await network.responseStarted(function (event) { + if (event.response.url.includes('logEntryAdded')) { + onResponseCompleted = event.response + } + }) + + await driver.network().addResponseHandler(filter, handler) + + await driver.get(Pages.logEntryAdded) + + const pageSource = await driver.getPageSource() + + assert.strictEqual(pageSource.includes('Internal server error'), true) + + assert.strictEqual(onResponseCompleted.status, 500) + assert.strictEqual(onResponseCompleted.headers[0].name, 'test') + assert.strictEqual(onResponseCompleted.headers[0].value.value, 'header-value') + }) + + it('can add multiple response handler with same filter', async function () { + const filter = (req) => req.path.includes('bidi/logEntryAdded.html') + const handler = () => new Response(500, { test: 'header-value' }, 'Internal server error') + + const network = await Network(driver) + + let onResponseCompleted = null + + await network.responseStarted(function (event) { + if (event.response.url.includes('logEntryAdded')) { + onResponseCompleted = event.response + } + }) + + await driver.network().addResponseHandler(filter, handler) + await driver.network().addResponseHandler(filter, handler) + + await driver.get(Pages.logEntryAdded) + + const pageSource = await driver.getPageSource() + + assert.strictEqual(pageSource.includes('Internal server error'), true) + + assert.strictEqual(onResponseCompleted.status, 500) + assert.strictEqual(onResponseCompleted.headers[0].name, 'test') + assert.strictEqual(onResponseCompleted.headers[0].value.value, 'header-value') + }) + + it('throws an error when removing response handler that does not exist', async function () { + try { + await driver.network().removeResponseHandler(10) + assert.fail('Expected error not thrown. Non-existent handler cannot be removed') + } catch (e) { + assert.strictEqual(e.message, 'Callback with id 10 not found') + } + }) + + it('can clear response handlers', async function () { + const filter = (req) => req.path.includes('bidi/logEntryAdded.html') + const handler = () => new Response(200, { test: 'header' }, 'Hello!') + + await driver.network().addResponseHandler(filter, handler) + + await driver.network().addResponseHandler( + (req) => req.path.includes('hello.html'), + () => new Response(401, { test: 'header' }, 'Not found!'), + ) + + await driver.network().clearResponseHandlers() + + await driver.get(Pages.logEntryAdded) + + const pageSource = await driver.getPageSource() + + assert.strictEqual(pageSource.includes('entry added'), true) + }) }) }, { browsers: [Browser.FIREFOX] },