Skip to content

Commit fbd323c

Browse files
authored
Merge pull request #188 from nhooyr/fix-negotiations
Fix deflate extension parameter negotiation
2 parents 94f9b71 + 95bfb8f commit fbd323c

File tree

6 files changed

+23
-69
lines changed

6 files changed

+23
-69
lines changed

accept.go

+2-27
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,6 @@ func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionM
209209

210210
func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
211211
copts := mode.opts()
212-
copts.serverMaxWindowBits = 8
213212

214213
for _, p := range ext.params {
215214
switch p {
@@ -222,26 +221,7 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
222221
}
223222

224223
if strings.HasPrefix(p, "client_max_window_bits") {
225-
continue
226-
227-
// bits, ok := parseExtensionParameter(p, 15)
228-
// if !ok || bits < 8 || bits > 16 {
229-
// err := fmt.Errorf("invalid client_max_window_bits: %q", p)
230-
// http.Error(w, err.Error(), http.StatusBadRequest)
231-
// return nil, err
232-
// }
233-
// copts.clientMaxWindowBits = bits
234-
// continue
235-
}
236-
237-
if false && strings.HasPrefix(p, "server_max_window_bits") {
238-
// We always send back 8 but make sure to validate.
239-
bits, ok := parseExtensionParameter(p, 0)
240-
if !ok || bits < 8 || bits > 16 {
241-
err := fmt.Errorf("invalid server_max_window_bits: %q", p)
242-
http.Error(w, err.Error(), http.StatusBadRequest)
243-
return nil, err
244-
}
224+
// We cannot adjust the read sliding window so cannot make use of this.
245225
continue
246226
}
247227

@@ -256,14 +236,9 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
256236
}
257237

258238
// parseExtensionParameter parses the value in the extension parameter p.
259-
// It falls back to defaultVal if there is no value.
260-
// If defaultVal == 0, then ok == false if there is no value.
261-
func parseExtensionParameter(p string, defaultVal int) (int, bool) {
239+
func parseExtensionParameter(p string) (int, bool) {
262240
ps := strings.Split(p, "=")
263241
if len(ps) == 1 {
264-
if defaultVal > 0 {
265-
return defaultVal, true
266-
}
267242
return 0, false
268243
}
269244
i, e := strconv.Atoi(strings.Trim(ps[1], `"`))

accept_test.go

-1
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,6 @@ func Test_acceptCompression(t *testing.T) {
327327
expCopts: &compressionOptions{
328328
clientNoContextTakeover: true,
329329
serverNoContextTakeover: true,
330-
serverMaxWindowBits: 8,
331330
},
332331
},
333332
{

autobahn_test.go

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ var excludedAutobahnCases = []string{
2828

2929
// We skip the tests related to requestMaxWindowBits as that is unimplemented due
3030
// to limitations in compress/flate. See https://github.com/golang/go/issues/3155
31+
// Same with klauspost/compress which doesn't allow adjusting the sliding window size.
3132
"13.3.*", "13.4.*", "13.5.*", "13.6.*",
3233
}
3334

compress_notjs.go

+4-10
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
package websocket
44

55
import (
6-
"fmt"
76
"io"
87
"net/http"
98
"sync"
@@ -20,10 +19,7 @@ func (m CompressionMode) opts() *compressionOptions {
2019

2120
type compressionOptions struct {
2221
clientNoContextTakeover bool
23-
clientMaxWindowBits int
24-
2522
serverNoContextTakeover bool
26-
serverMaxWindowBits int
2723
}
2824

2925
func (copts *compressionOptions) setHeader(h http.Header) {
@@ -34,12 +30,6 @@ func (copts *compressionOptions) setHeader(h http.Header) {
3430
if copts.serverNoContextTakeover {
3531
s += "; server_no_context_takeover"
3632
}
37-
if false && copts.serverMaxWindowBits > 0 {
38-
s += fmt.Sprintf("; server_max_window_bits=%v", copts.serverMaxWindowBits)
39-
}
40-
if false && copts.clientMaxWindowBits > 0 {
41-
s += fmt.Sprintf("; client_max_window_bits=%v", copts.clientMaxWindowBits)
42-
}
4333
h.Set("Sec-WebSocket-Extensions", s)
4434
}
4535

@@ -147,6 +137,10 @@ func (sw *slidingWindow) init(n int) {
147137
return
148138
}
149139

140+
if n == 0 {
141+
n = 32768
142+
}
143+
150144
p := slidingWindowPool(n)
151145
buf, ok := p.Get().([]byte)
152146
if ok {

dial.go

+15-30
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,12 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
8282
return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
8383
}
8484

85-
resp, err := handshakeRequest(ctx, urls, opts, secWebSocketKey)
85+
var copts *compressionOptions
86+
if opts.CompressionMode != CompressionDisabled {
87+
copts = opts.CompressionMode.opts()
88+
}
89+
90+
resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
8691
if err != nil {
8792
return nil, resp, err
8893
}
@@ -104,7 +109,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
104109
}
105110
}()
106111

107-
copts, err := verifyServerResponse(opts, secWebSocketKey, resp)
112+
copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
108113
if err != nil {
109114
return nil, resp, err
110115
}
@@ -125,7 +130,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
125130
}), resp, nil
126131
}
127132

128-
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWebSocketKey string) (*http.Response, error) {
133+
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
129134
if opts.HTTPClient.Timeout > 0 {
130135
return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
131136
}
@@ -153,9 +158,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe
153158
if len(opts.Subprotocols) > 0 {
154159
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
155160
}
156-
if opts.CompressionMode != CompressionDisabled {
157-
copts := opts.CompressionMode.opts()
158-
copts.clientMaxWindowBits = 8
161+
if copts != nil {
159162
copts.setHeader(req.Header)
160163
}
161164

@@ -178,7 +181,7 @@ func secWebSocketKey(rr io.Reader) (string, error) {
178181
return base64.StdEncoding.EncodeToString(b), nil
179182
}
180183

181-
func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
184+
func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
182185
if resp.StatusCode != http.StatusSwitchingProtocols {
183186
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
184187
}
@@ -203,7 +206,7 @@ func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.
203206
return nil, err
204207
}
205208

206-
return verifyServerExtensions(resp.Header)
209+
return verifyServerExtensions(copts, resp.Header)
207210
}
208211

209212
func verifySubprotocol(subprotos []string, resp *http.Response) error {
@@ -221,19 +224,19 @@ func verifySubprotocol(subprotos []string, resp *http.Response) error {
221224
return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
222225
}
223226

224-
func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
227+
func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
225228
exts := websocketExtensions(h)
226229
if len(exts) == 0 {
227230
return nil, nil
228231
}
229232

230233
ext := exts[0]
231-
if ext.name != "permessage-deflate" || len(exts) > 1 {
234+
if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
232235
return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
233236
}
234237

235-
copts := &compressionOptions{}
236-
copts.clientMaxWindowBits = 8
238+
copts = &*copts
239+
237240
for _, p := range ext.params {
238241
switch p {
239242
case "client_no_context_takeover":
@@ -244,24 +247,6 @@ func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
244247
continue
245248
}
246249

247-
if false && strings.HasPrefix(p, "server_max_window_bits") {
248-
bits, ok := parseExtensionParameter(p, 0)
249-
if !ok || bits < 8 || bits > 16 {
250-
return nil, fmt.Errorf("invalid server_max_window_bits: %q", p)
251-
}
252-
copts.serverMaxWindowBits = bits
253-
continue
254-
}
255-
256-
if false && strings.HasPrefix(p, "client_max_window_bits") {
257-
bits, ok := parseExtensionParameter(p, 0)
258-
if !ok || bits < 8 || bits > 16 {
259-
return nil, fmt.Errorf("invalid client_max_window_bits: %q", p)
260-
}
261-
copts.clientMaxWindowBits = 8
262-
continue
263-
}
264-
265250
return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
266251
}
267252

dial_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ func Test_verifyServerHandshake(t *testing.T) {
221221
opts := &DialOptions{
222222
Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
223223
}
224-
_, err = verifyServerResponse(opts, key, resp)
224+
_, err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp)
225225
if tc.success {
226226
assert.Success(t, err)
227227
} else {

0 commit comments

Comments
 (0)