diff --git a/java/src/org/openqa/selenium/remote/Network.java b/java/src/org/openqa/selenium/remote/Network.java index 6407d3de3745b..0119238a4e754 100644 --- a/java/src/org/openqa/selenium/remote/Network.java +++ b/java/src/org/openqa/selenium/remote/Network.java @@ -19,10 +19,12 @@ import java.net.URI; import java.util.function.Predicate; +import java.util.function.Supplier; import java.util.function.UnaryOperator; import org.openqa.selenium.Beta; import org.openqa.selenium.UsernameAndPassword; import org.openqa.selenium.remote.http.HttpRequest; +import org.openqa.selenium.remote.http.HttpResponse; @Beta public interface Network { @@ -35,9 +37,15 @@ public interface Network { void clearAuthenticationHandlers(); - long addRequestHandler(Predicate filter, UnaryOperator handler); + long addRequestHandler(Predicate filter, UnaryOperator handler); void removeRequestHandler(long id); void clearRequestHandlers(); + + long addResponseHandler(Predicate filter, Supplier handler); + + void removeResponseHandler(long id); + + void clearResponseHandlers(); } diff --git a/java/src/org/openqa/selenium/remote/RemoteNetwork.java b/java/src/org/openqa/selenium/remote/RemoteNetwork.java index 2848895667f1f..c8203c568617e 100644 --- a/java/src/org/openqa/selenium/remote/RemoteNetwork.java +++ b/java/src/org/openqa/selenium/remote/RemoteNetwork.java @@ -25,6 +25,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Predicate; +import java.util.function.Supplier; import java.util.function.UnaryOperator; import org.openqa.selenium.Beta; import org.openqa.selenium.UsernameAndPassword; @@ -36,10 +37,12 @@ import org.openqa.selenium.bidi.network.ContinueRequestParameters; import org.openqa.selenium.bidi.network.Header; import org.openqa.selenium.bidi.network.InterceptPhase; +import org.openqa.selenium.bidi.network.ProvideResponseParameters; import org.openqa.selenium.bidi.network.RequestData; import org.openqa.selenium.remote.http.Contents; import org.openqa.selenium.remote.http.HttpMethod; import org.openqa.selenium.remote.http.HttpRequest; +import org.openqa.selenium.remote.http.HttpResponse; @Beta class RemoteNetwork implements Network { @@ -51,6 +54,8 @@ class RemoteNetwork implements Network { private final Map requestHandlers = new ConcurrentHashMap<>(); + private final Map responseHandlers = new ConcurrentHashMap<>(); + private final AtomicLong callBackId = new AtomicLong(1); public RemoteNetwork(WebDriver driver) { @@ -97,24 +102,60 @@ private void interceptRequest() { ContinueRequestParameters continueRequestParameters = new ContinueRequestParameters(requestId); - Optional> requestHandler = getRequestHandler(uri); + RequestData interceptedRequest = beforeRequestSent.getRequest(); + + // Build the originalRequest object from the intercepted request details. + HttpRequest originalRequest = + new HttpRequest( + HttpMethod.getHttpMethod(interceptedRequest.getMethod()), + interceptedRequest.getUrl()); + + // Populate the headers of the original request. + // Body is not available as part of WebDriver Spec, hence we cannot add that or use that. + interceptedRequest + .getHeaders() + .forEach( + header -> + originalRequest.addHeader(header.getName(), header.getValue().getValue())); + + Optional> requestHandler = getRequestHandler(originalRequest); + + Optional> responseHandler = getResponseHandler(originalRequest); + + // If request and response handler both are registered for same uri, + // then the response will be mocked instead of modifying the outgoing request. + // This can be altered in the future to let the modified request go through and have a + // response mock for that modified request. + // If in future the Browsers support intercepting at "response started" phase and allow + // using provide response command in that phase. + // Currently, only intercepting in "before request sent" phase is permitted. + if (responseHandler.isPresent()) { + ProvideResponseParameters provideResponseParameters = + new ProvideResponseParameters(requestId); - if (requestHandler.isPresent()) { - RequestData interceptedRequest = beforeRequestSent.getRequest(); + HttpResponse modifiedResponse = responseHandler.get().get(); + + provideResponseParameters.statusCode(modifiedResponse.getStatus()); + + List
headerList = new ArrayList<>(); + modifiedResponse.forEachHeader( + (name, value) -> + headerList.add( + new Header(name, new BytesValue(BytesValue.Type.STRING, value)))); - // Build the originalRequest object from the intercepted request details. - HttpRequest originalRequest = - new HttpRequest( - HttpMethod.getHttpMethod(interceptedRequest.getMethod()), - interceptedRequest.getUrl()); + if (!headerList.isEmpty()) { + provideResponseParameters.headers(headerList); + } - // Populate the headers of the original request. - interceptedRequest - .getHeaders() - .forEach( - header -> - originalRequest.addHeader(header.getName(), header.getValue().getValue())); + Contents.Supplier content = modifiedResponse.getContent(); + if (content.length() > 0) { + provideResponseParameters.body( + new BytesValue(BytesValue.Type.STRING, Contents.utf8String(content))); + } + network.provideResponse(provideResponseParameters); + return; + } else if (requestHandler.isPresent()) { HttpRequest modifiedRequest = requestHandler.get().apply(originalRequest); continueRequestParameters.method(modifiedRequest.getMethod()); @@ -144,13 +185,20 @@ private void interceptRequest() { }); } - private Optional> getRequestHandler(URI uri) { + private Optional> getRequestHandler(HttpRequest request) { return requestHandlers.values().stream() - .filter(requestDetails -> requestDetails.getFilter().test(uri)) + .filter(requestDetails -> requestDetails.getFilter().test(request)) .map(RequestDetails::getHandler) .findFirst(); } + private Optional> getResponseHandler(HttpRequest request) { + return responseHandlers.values().stream() + .filter(responseDetails -> responseDetails.getFilter().test(request)) + .map(ResponseDetails::getHandler) + .findFirst(); + } + @Override public long addAuthenticationHandler(UsernameAndPassword usernameAndPassword) { return addAuthenticationHandler(url -> true, usernameAndPassword); @@ -176,7 +224,7 @@ public void clearAuthenticationHandlers() { } @Override - public long addRequestHandler(Predicate filter, UnaryOperator handler) { + public long addRequestHandler(Predicate filter, UnaryOperator handler) { long id = this.callBackId.incrementAndGet(); requestHandlers.put(id, new RequestDetails(filter, handler)); @@ -193,6 +241,25 @@ public void clearRequestHandlers() { requestHandlers.clear(); } + // Allows mocking the response body + @Override + public long addResponseHandler(Predicate filter, Supplier handler) { + long id = this.callBackId.incrementAndGet(); + + responseHandlers.put(id, new ResponseDetails(filter, handler)); + return id; + } + + @Override + public void removeResponseHandler(long id) { + responseHandlers.remove(id); + } + + @Override + public void clearResponseHandlers() { + responseHandlers.clear(); + } + private class AuthDetails { private final Predicate filter; private final UsernameAndPassword usernameAndPassword; @@ -212,15 +279,15 @@ public UsernameAndPassword getUsernameAndPassword() { } private class RequestDetails { - private final Predicate filter; + private final Predicate filter; private final UnaryOperator handler; - public RequestDetails(Predicate filter, UnaryOperator handler) { + public RequestDetails(Predicate filter, UnaryOperator handler) { this.filter = filter; this.handler = handler; } - public Predicate getFilter() { + public Predicate getFilter() { return this.filter; } @@ -228,4 +295,22 @@ public UnaryOperator getHandler() { return this.handler; } } + + private class ResponseDetails { + private final Predicate filter; + private final Supplier handler; + + public ResponseDetails(Predicate filter, Supplier handler) { + this.filter = filter; + this.handler = handler; + } + + public Predicate getFilter() { + return this.filter; + } + + public Supplier getHandler() { + return this.handler; + } + } } diff --git a/java/test/org/openqa/selenium/WebNetworkTest.java b/java/test/org/openqa/selenium/WebNetworkTest.java index fb7f3023784c6..b3012ccf47b13 100644 --- a/java/test/org/openqa/selenium/WebNetworkTest.java +++ b/java/test/org/openqa/selenium/WebNetworkTest.java @@ -17,6 +17,7 @@ package org.openqa.selenium; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; @@ -26,12 +27,14 @@ import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Predicate; import org.junit.jupiter.api.Test; import org.openqa.selenium.bidi.module.Network; import org.openqa.selenium.bidi.network.Header; import org.openqa.selenium.environment.webserver.NettyAppServer; import org.openqa.selenium.remote.RemoteWebDriver; +import org.openqa.selenium.remote.http.Contents; import org.openqa.selenium.remote.http.HttpMethod; import org.openqa.selenium.remote.http.HttpRequest; import org.openqa.selenium.remote.http.HttpResponse; @@ -181,7 +184,7 @@ void canClearAuthenticationHandlers() { @Ignore(Browser.CHROME) @Ignore(Browser.EDGE) void canAddRequestHandler() { - Predicate filter = uri -> uri.getPath().contains("logEntry"); + Predicate filter = httpRequest -> httpRequest.getUri().contains("logEntry"); page = appServer.whereIs("/bidi/logEntryAdded.html"); @@ -189,7 +192,7 @@ void canAddRequestHandler() { driver.get(page); - assertThat(driver.findElement(By.tagName("h1")).getText()).isEqualTo("Long entry added events"); + assertThat(driver.findElement(By.tagName("h1")).getText()).isEqualTo("Log entry added events"); } @Test @@ -197,7 +200,7 @@ void canAddRequestHandler() { @Ignore(Browser.CHROME) @Ignore(Browser.EDGE) void canAddRequestHandlerToModifyMethod() { - Predicate filter = uri -> uri.getPath().contains("logEntry"); + Predicate filter = httpRequest -> httpRequest.getUri().contains("logEntry"); page = appServer.whereIs("/bidi/logEntryAdded.html"); @@ -235,7 +238,7 @@ void canAddRequestHandlerToModifyHeaders() throws InterruptedException { appServer = new NettyAppServer(route); appServer.start(); - Predicate filter = uri -> uri.getPath().contains("network"); + Predicate filter = httpRequest -> httpRequest.getUri().contains("network"); CountDownLatch latch = new CountDownLatch(1); @@ -285,7 +288,7 @@ void canAddRequestHandlerToModifyBody() throws InterruptedException { appServer = new NettyAppServer(route); appServer.start(); - Predicate filter = uri -> uri.getPath().contains("network"); + Predicate filter = httpRequest -> httpRequest.getUri().contains("network"); page = appServer.whereIs("network.html"); @@ -311,17 +314,18 @@ void canAddMultipleRequestHandlers() { ((RemoteWebDriver) driver) .network() - .addRequestHandler(uri -> uri.getPath().contains("logEntry"), httpRequest -> httpRequest); + .addRequestHandler( + httpRequest -> httpRequest.getUri().contains("logEntry"), httpRequest -> httpRequest); ((RemoteWebDriver) driver) .network() .addRequestHandler( - uri -> uri.getPath().contains("hello"), + httpRequest -> httpRequest.getUri().contains("hello"), httpRequest -> new HttpRequest(HttpMethod.HEAD, page)); driver.get(page); - assertThat(driver.findElement(By.tagName("h1")).getText()).isEqualTo("Long entry added events"); + assertThat(driver.findElement(By.tagName("h1")).getText()).isEqualTo("Log entry added events"); } @Test @@ -331,17 +335,19 @@ void canAddMultipleRequestHandlers() { void canAddMultipleRequestHandlersWithTheSameFilter() { ((RemoteWebDriver) driver) .network() - .addRequestHandler(uri -> uri.getPath().contains("logEntry"), httpRequest -> httpRequest); + .addRequestHandler( + httpRequest -> httpRequest.getUri().contains("logEntry"), httpRequest -> httpRequest); ((RemoteWebDriver) driver) .network() - .addRequestHandler(uri -> uri.getPath().contains("logEntry"), httpRequest -> httpRequest); + .addRequestHandler( + httpRequest -> httpRequest.getUri().contains("logEntry"), httpRequest -> httpRequest); page = appServer.whereIs("/bidi/logEntryAdded.html"); driver.get(page); - assertThat(driver.findElement(By.tagName("h1")).getText()).isEqualTo("Long entry added events"); + assertThat(driver.findElement(By.tagName("h1")).getText()).isEqualTo("Log entry added events"); } @Test @@ -368,7 +374,7 @@ void canRemoveRequestHandler() throws InterruptedException { appServer = new NettyAppServer(route); appServer.start(); - Predicate filter = uri -> uri.getPath().contains("network"); + Predicate filter = httpRequest -> httpRequest.getUri().contains("network"); CountDownLatch latch = new CountDownLatch(1); @@ -409,11 +415,11 @@ void canRemoveRequestHandler() throws InterruptedException { @Ignore(Browser.CHROME) @Ignore(Browser.EDGE) void canRemoveRequestHandlerThatDoesNotExist() { - ((RemoteWebDriver) driver).network().removeAuthenticationHandler(5); + ((RemoteWebDriver) driver).network().removeRequestHandler(5); page = appServer.whereIs("/bidi/logEntryAdded.html"); driver.get(page); - assertThat(driver.findElement(By.tagName("h1")).getText()).isEqualTo("Long entry added events"); + assertThat(driver.findElement(By.tagName("h1")).getText()).isEqualTo("Log entry added events"); } @Test @@ -426,19 +432,177 @@ void canClearRequestHandlers() { ((RemoteWebDriver) driver) .network() .addRequestHandler( - uri -> uri.getPath().contains("logEntryAdded"), + httpRequest -> httpRequest.getUri().contains("logEntryAdded"), httpRequest -> new HttpRequest(HttpMethod.DELETE, page)); ((RemoteWebDriver) driver) .network() .addRequestHandler( - uri -> uri.getPath().contains("hello"), + httpRequest -> httpRequest.getUri().contains("hello"), httpRequest -> new HttpRequest(HttpMethod.HEAD, page)); ((RemoteWebDriver) driver).network().clearRequestHandlers(); driver.get(page); - assertThat(driver.findElement(By.tagName("h1")).getText()).isEqualTo("Long entry added events"); + assertThat(driver.findElement(By.tagName("h1")).getText()).isEqualTo("Log entry added events"); + } + + @Test + @NeedsFreshDriver + @Ignore(Browser.CHROME) + @Ignore(Browser.EDGE) + void canAddResponseHandlerToModifyStatusCode() { + AtomicBoolean seen = new AtomicBoolean(false); + Predicate filter = httpRequest -> httpRequest.getUri().contains("logEntry"); + + page = appServer.whereIs("/bidi/logEntryAdded.html"); + + try (Network network = new Network(driver)) { + + ((RemoteWebDriver) driver) + .network() + .addResponseHandler( + filter, + () -> { + HttpResponse res = new HttpResponse(); + return res.setStatus(500); + }); + + network.onResponseCompleted( + responseDetails -> { + if (responseDetails.getResponseData().getUrl().contains("logEntryAdded")) { + assertThat(responseDetails.getResponseData().getStatus()).isEqualTo(500); + seen.set(true); + } + }); + + driver.get(page); + + assertThat(seen.get()).isTrue(); + } + } + + @Test + @NeedsFreshDriver + @Ignore(Browser.CHROME) + @Ignore(Browser.EDGE) + void canAddResponseHandlerToModifyBody() { + Predicate filter = httpRequest -> httpRequest.getUri().contains("logEntry"); + + page = appServer.whereIs("/bidi/logEntryAdded.html"); + + ((RemoteWebDriver) driver) + .network() + .addResponseHandler( + filter, + () -> + new HttpResponse().setContent(Contents.string("

Mocked response

", UTF_8))); + + driver.get(page); + + assertThat(driver.getPageSource()).contains("Mocked response"); + } + + @Test + @NeedsFreshDriver + @Ignore(Browser.CHROME) + @Ignore(Browser.EDGE) + void canAddResponseHandlerToModifyHeaders() { + AtomicBoolean seen = new AtomicBoolean(false); + Predicate filter = httpRequest -> httpRequest.getUri().contains("logEntry"); + + page = appServer.whereIs("/bidi/logEntryAdded.html"); + + try (Network network = new Network(driver)) { + + ((RemoteWebDriver) driver) + .network() + .addResponseHandler( + filter, + () -> { + HttpResponse res = new HttpResponse(); + return res.setHeader("hello", "world"); + }); + + network.onResponseCompleted( + responseDetails -> { + if (responseDetails.getResponseData().getUrl().contains("logEntryAdded")) { + assertThat(responseDetails.getResponseData().getStatus()).isEqualTo(200); + Header header = responseDetails.getResponseData().getHeaders().get(0); + assertThat(header.getName()).isEqualTo("hello"); + assertThat(header.getValue().getValue()).isEqualTo("world"); + seen.set(true); + } + }); + + driver.get(page); + + assertThat(seen.get()).isTrue(); + } + } + + @Test + @NeedsFreshDriver + @Ignore(Browser.CHROME) + @Ignore(Browser.EDGE) + void canAddMultipleResponseHandlersWithTheSameFilter() { + ((RemoteWebDriver) driver) + .network() + .addResponseHandler( + httpRequest -> httpRequest.getUri().contains("logEntryAdded"), + () -> + new HttpResponse().setContent(Contents.string("

Mocked response

", UTF_8))); + + ((RemoteWebDriver) driver) + .network() + .addResponseHandler( + httpRequest -> httpRequest.getUri().contains("logEntryAdded"), + () -> + new HttpResponse().setContent(Contents.string("

Mocked response

", UTF_8))); + + page = appServer.whereIs("/bidi/logEntryAdded.html"); + + driver.get(page); + + assertThat(driver.getPageSource()).contains("Mocked response"); + } + + @Test + @NeedsFreshDriver + @Ignore(Browser.CHROME) + @Ignore(Browser.EDGE) + void canRemoveResponseHandlerThatDoesNotExist() { + ((RemoteWebDriver) driver).network().removeResponseHandler(5); + page = appServer.whereIs("/bidi/logEntryAdded.html"); + driver.get(page); + + assertThat(driver.findElement(By.tagName("h1")).getText()).isEqualTo("Log entry added events"); + } + + @Test + @NeedsFreshDriver + @Ignore(Browser.CHROME) + @Ignore(Browser.EDGE) + void canClearResponseHandlers() { + page = appServer.whereIs("/bidi/logEntryAdded.html"); + + ((RemoteWebDriver) driver) + .network() + .addResponseHandler( + httpRequest -> httpRequest.getUri().contains("logEntryAdded"), HttpResponse::new); + + ((RemoteWebDriver) driver) + .network() + .addResponseHandler( + httpRequest -> httpRequest.getUri().contains("logEntryAdded"), + () -> + new HttpResponse().setContent(Contents.string("

Mocked response

", UTF_8))); + + ((RemoteWebDriver) driver).network().clearResponseHandlers(); + + driver.get(page); + + assertThat(driver.findElement(By.tagName("h1")).getText()).isEqualTo("Log entry added events"); } }