@@ -82,7 +82,12 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
82
82
return nil , nil , fmt .Errorf ("failed to generate Sec-WebSocket-Key: %w" , err )
83
83
}
84
84
85
- resp , err := handshakeRequest (ctx , urls , opts , secWebSocketKey )
85
+ var copts * compressionOptions
86
+ if opts .CompressionMode != CompressionDisabled {
87
+ copts = opts .CompressionMode .opts ()
88
+ }
89
+
90
+ resp , err := handshakeRequest (ctx , urls , opts , copts , secWebSocketKey )
86
91
if err != nil {
87
92
return nil , resp , err
88
93
}
@@ -104,7 +109,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
104
109
}
105
110
}()
106
111
107
- copts , err : = verifyServerResponse (opts , secWebSocketKey , resp )
112
+ copts , err = verifyServerResponse (opts , copts , secWebSocketKey , resp )
108
113
if err != nil {
109
114
return nil , resp , err
110
115
}
@@ -125,7 +130,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
125
130
}), resp , nil
126
131
}
127
132
128
- func handshakeRequest (ctx context.Context , urls string , opts * DialOptions , secWebSocketKey string ) (* http.Response , error ) {
133
+ func handshakeRequest (ctx context.Context , urls string , opts * DialOptions , copts * compressionOptions , secWebSocketKey string ) (* http.Response , error ) {
129
134
if opts .HTTPClient .Timeout > 0 {
130
135
return nil , errors .New ("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67" )
131
136
}
@@ -153,9 +158,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe
153
158
if len (opts .Subprotocols ) > 0 {
154
159
req .Header .Set ("Sec-WebSocket-Protocol" , strings .Join (opts .Subprotocols , "," ))
155
160
}
156
- if opts .CompressionMode != CompressionDisabled {
157
- copts := opts .CompressionMode .opts ()
158
- copts .clientMaxWindowBits = 8
161
+ if copts != nil {
159
162
copts .setHeader (req .Header )
160
163
}
161
164
@@ -178,7 +181,7 @@ func secWebSocketKey(rr io.Reader) (string, error) {
178
181
return base64 .StdEncoding .EncodeToString (b ), nil
179
182
}
180
183
181
- func verifyServerResponse (opts * DialOptions , secWebSocketKey string , resp * http.Response ) (* compressionOptions , error ) {
184
+ func verifyServerResponse (opts * DialOptions , copts * compressionOptions , secWebSocketKey string , resp * http.Response ) (* compressionOptions , error ) {
182
185
if resp .StatusCode != http .StatusSwitchingProtocols {
183
186
return nil , fmt .Errorf ("expected handshake response status code %v but got %v" , http .StatusSwitchingProtocols , resp .StatusCode )
184
187
}
@@ -203,7 +206,7 @@ func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.
203
206
return nil , err
204
207
}
205
208
206
- return verifyServerExtensions (resp .Header )
209
+ return verifyServerExtensions (copts , resp .Header )
207
210
}
208
211
209
212
func verifySubprotocol (subprotos []string , resp * http.Response ) error {
@@ -221,19 +224,19 @@ func verifySubprotocol(subprotos []string, resp *http.Response) error {
221
224
return fmt .Errorf ("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q" , proto )
222
225
}
223
226
224
- func verifyServerExtensions (h http.Header ) (* compressionOptions , error ) {
227
+ func verifyServerExtensions (copts * compressionOptions , h http.Header ) (* compressionOptions , error ) {
225
228
exts := websocketExtensions (h )
226
229
if len (exts ) == 0 {
227
230
return nil , nil
228
231
}
229
232
230
233
ext := exts [0 ]
231
- if ext .name != "permessage-deflate" || len (exts ) > 1 {
234
+ if ext .name != "permessage-deflate" || len (exts ) > 1 || copts == nil {
232
235
return nil , fmt .Errorf ("WebSocket protcol violation: unsupported extensions from server: %+v" , exts [1 :])
233
236
}
234
237
235
- copts : = & compressionOptions {}
236
- copts . clientMaxWindowBits = 8
238
+ copts = & * copts
239
+
237
240
for _ , p := range ext .params {
238
241
switch p {
239
242
case "client_no_context_takeover" :
@@ -244,24 +247,6 @@ func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
244
247
continue
245
248
}
246
249
247
- if false && strings .HasPrefix (p , "server_max_window_bits" ) {
248
- bits , ok := parseExtensionParameter (p , 0 )
249
- if ! ok || bits < 8 || bits > 16 {
250
- return nil , fmt .Errorf ("invalid server_max_window_bits: %q" , p )
251
- }
252
- copts .serverMaxWindowBits = bits
253
- continue
254
- }
255
-
256
- if false && strings .HasPrefix (p , "client_max_window_bits" ) {
257
- bits , ok := parseExtensionParameter (p , 0 )
258
- if ! ok || bits < 8 || bits > 16 {
259
- return nil , fmt .Errorf ("invalid client_max_window_bits: %q" , p )
260
- }
261
- copts .clientMaxWindowBits = 8
262
- continue
263
- }
264
-
265
250
return nil , fmt .Errorf ("unsupported permessage-deflate parameter: %q" , p )
266
251
}
267
252
0 commit comments