@@ -17,8 +17,6 @@ import (
17
17
"sync/atomic"
18
18
"time"
19
19
20
- "golang.org/x/xerrors"
21
-
22
20
"nhooyr.io/websocket/internal/bpool"
23
21
)
24
22
@@ -66,15 +64,13 @@ type Conn struct {
66
64
writeMsgOpcode opcode
67
65
writeMsgCtx context.Context
68
66
readMsgLeft int64
69
- readCloseFrame CloseError
70
67
71
68
// Used to ensure the previous reader is read till EOF before allowing
72
69
// a new one.
73
70
activeReader * messageReader
74
71
// readFrameLock is acquired to read from bw.
75
72
readFrameLock chan struct {}
76
73
isReadClosed * atomicInt64
77
- isCloseHandshake * atomicInt64
78
74
readHeaderBuf []byte
79
75
controlPayloadBuf []byte
80
76
@@ -102,7 +98,6 @@ func (c *Conn) init() {
102
98
c .writeFrameLock = make (chan struct {}, 1 )
103
99
104
100
c .readFrameLock = make (chan struct {}, 1 )
105
- c .isCloseHandshake = & atomicInt64 {}
106
101
107
102
c .setReadTimeout = make (chan context.Context )
108
103
c .setWriteTimeout = make (chan context.Context )
@@ -206,20 +201,20 @@ func (c *Conn) releaseLock(lock chan struct{}) {
206
201
}
207
202
}
208
203
209
- func (c * Conn ) readTillMsg (ctx context.Context ) (header , error ) {
204
+ func (c * Conn ) readTillMsg (ctx context.Context , lock bool ) (header , error ) {
210
205
for {
211
- h , err := c .readFrameHeader (ctx )
206
+ h , err := c .readFrameHeader (ctx , lock )
212
207
if err != nil {
213
208
return header {}, err
214
209
}
215
210
216
211
if h .rsv1 || h .rsv2 || h .rsv3 {
217
- c .Close (StatusProtocolError , fmt .Sprintf ("received header with rsv bits set: %v:%v:%v" , h .rsv1 , h .rsv2 , h .rsv3 ))
212
+ c .writeClose (StatusProtocolError , fmt .Sprintf ("received header with rsv bits set: %v:%v:%v" , h .rsv1 , h .rsv2 , h .rsv3 ), false )
218
213
return header {}, c .closeErr
219
214
}
220
215
221
216
if h .opcode .controlOp () {
222
- err = c .handleControl (ctx , h )
217
+ err = c .handleControl (ctx , h , lock )
223
218
if err != nil {
224
219
return header {}, fmt .Errorf ("failed to handle control frame: %w" , err )
225
220
}
@@ -230,18 +225,20 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
230
225
case opBinary , opText , opContinuation :
231
226
return h , nil
232
227
default :
233
- c .Close (StatusProtocolError , fmt .Sprintf ("received unknown opcode %v" , h .opcode ))
228
+ c .writeClose (StatusProtocolError , fmt .Sprintf ("received unknown opcode %v" , h .opcode ), false )
234
229
return header {}, c .closeErr
235
230
}
236
231
}
237
232
}
238
233
239
- func (c * Conn ) readFrameHeader (ctx context.Context ) (header , error ) {
240
- err := c .acquireLock (ctx , c .readFrameLock )
241
- if err != nil {
242
- return header {}, err
234
+ func (c * Conn ) readFrameHeader (ctx context.Context , lock bool ) (header , error ) {
235
+ if lock {
236
+ err := c .acquireLock (ctx , c .readFrameLock )
237
+ if err != nil {
238
+ return header {}, err
239
+ }
240
+ defer c .releaseLock (c .readFrameLock )
243
241
}
244
- defer c .releaseLock (c .readFrameLock )
245
242
246
243
select {
247
244
case <- c .closed :
@@ -273,22 +270,22 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
273
270
return h , nil
274
271
}
275
272
276
- func (c * Conn ) handleControl (ctx context.Context , h header ) error {
273
+ func (c * Conn ) handleControl (ctx context.Context , h header , lock bool ) error {
277
274
if h .payloadLength > maxControlFramePayload {
278
- c .Close (StatusProtocolError , fmt .Sprintf ("control frame too large at %v bytes" , h .payloadLength ))
275
+ c .writeClose (StatusProtocolError , fmt .Sprintf ("control frame too large at %v bytes" , h .payloadLength ), false )
279
276
return c .closeErr
280
277
}
281
278
282
279
if ! h .fin {
283
- c .Close (StatusProtocolError , "received fragmented control frame" )
280
+ c .writeClose (StatusProtocolError , "received fragmented control frame" , false )
284
281
return c .closeErr
285
282
}
286
283
287
284
ctx , cancel := context .WithTimeout (ctx , time .Second * 5 )
288
285
defer cancel ()
289
286
290
287
b := c .controlPayloadBuf [:h .payloadLength ]
291
- _ , err := c .readFramePayload (ctx , b )
288
+ _ , err := c .readFramePayload (ctx , b , lock )
292
289
if err != nil {
293
290
return err
294
291
}
@@ -312,23 +309,24 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
312
309
ce , err := parseClosePayload (b )
313
310
if err != nil {
314
311
err = fmt .Errorf ("received invalid close payload: %w" , err )
315
- c .Close (StatusProtocolError , err .Error ())
312
+ c .writeClose (StatusProtocolError , err .Error (), false )
316
313
return c .closeErr
317
314
}
318
315
319
316
// This ensures the closeErr of the Conn is always the received CloseError
320
317
// in case the echo close frame write fails.
321
318
// See https://github.com/nhooyr/websocket/issues/109
322
- c .setCloseErr (fmt .Errorf ("received close frame: %w" , ce ))
323
-
324
- c .readCloseFrame = ce
319
+ c .setCloseErr (ce )
325
320
326
321
func () {
327
322
ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
328
323
defer cancel ()
329
324
c .writeControl (ctx , opClose , b )
330
325
}()
331
326
327
+ if ! lock {
328
+ c .releaseLock (c .readFrameLock )
329
+ }
332
330
// We close with nil since the error is already set above.
333
331
c .close (nil )
334
332
return c .closeErr
@@ -362,16 +360,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
362
360
// Most users should not need this.
363
361
func (c * Conn ) Reader (ctx context.Context ) (MessageType , io.Reader , error ) {
364
362
if c .isReadClosed .Load () == 1 {
365
- return 0 , nil , fmt .Errorf ("websocket connection read closed" )
366
- }
367
-
368
- if c .isCloseHandshake .Load () == 1 {
369
- select {
370
- case <- ctx .Done ():
371
- return 0 , nil , fmt .Errorf ("failed to get reader: %w" , ctx .Err ())
372
- case <- c .closed :
373
- return 0 , nil , fmt .Errorf ("failed to get reader: %w" , c .closeErr )
374
- }
363
+ return 0 , nil , errors .New ("websocket connection read closed" )
375
364
}
376
365
377
366
typ , r , err := c .reader (ctx )
@@ -381,23 +370,23 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
381
370
return typ , r , nil
382
371
}
383
372
384
- func (c * Conn ) reader (ctx context.Context ) (MessageType , io.Reader , error ) {
373
+ func (c * Conn ) reader (ctx context.Context ) (_ MessageType , _ io.Reader , err error ) {
385
374
if c .activeReader != nil && ! c .readerFrameEOF {
386
375
// The only way we know for sure the previous reader is not yet complete is
387
376
// if there is an active frame not yet fully read.
388
377
// Otherwise, a user may have read the last byte but not the EOF if the EOF
389
378
// is in the next frame so we check for that below.
390
- return 0 , nil , fmt . Errorf ("previous message not read to completion" )
379
+ return 0 , nil , errors . New ("previous message not read to completion" )
391
380
}
392
381
393
- h , err := c .readTillMsg (ctx )
382
+ h , err := c .readTillMsg (ctx , true )
394
383
if err != nil {
395
384
return 0 , nil , err
396
385
}
397
386
398
387
if c .activeReader != nil && ! c .activeReader .eof () {
399
388
if h .opcode != opContinuation {
400
- c .Close (StatusProtocolError , "received new data message without finishing the previous message" )
389
+ c .writeClose (StatusProtocolError , "received new data message without finishing the previous message" , false )
401
390
return 0 , nil , c .closeErr
402
391
}
403
392
@@ -407,12 +396,12 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
407
396
408
397
c .activeReader = nil
409
398
410
- h , err = c .readTillMsg (ctx )
399
+ h , err = c .readTillMsg (ctx , true )
411
400
if err != nil {
412
401
return 0 , nil , err
413
402
}
414
403
} else if h .opcode == opContinuation {
415
- c .Close (StatusProtocolError , "received continuation frame not after data or text frame" )
404
+ c .writeClose (StatusProtocolError , "received continuation frame not after data or text frame" , false )
416
405
return 0 , nil , c .closeErr
417
406
}
418
407
@@ -458,7 +447,7 @@ func (r *messageReader) read(p []byte) (int, error) {
458
447
}
459
448
460
449
if r .c .readMsgLeft <= 0 {
461
- r .c .Close (StatusMessageTooBig , fmt .Sprintf ("read limited at %v bytes" , r .c .msgReadLimit ))
450
+ r .c .writeClose (StatusMessageTooBig , fmt .Sprintf ("read limited at %v bytes" , r .c .msgReadLimit ), false )
462
451
return 0 , r .c .closeErr
463
452
}
464
453
@@ -467,13 +456,13 @@ func (r *messageReader) read(p []byte) (int, error) {
467
456
}
468
457
469
458
if r .c .readerFrameEOF {
470
- h , err := r .c .readTillMsg (r .c .readerMsgCtx )
459
+ h , err := r .c .readTillMsg (r .c .readerMsgCtx , true )
471
460
if err != nil {
472
461
return 0 , err
473
462
}
474
463
475
464
if h .opcode != opContinuation {
476
- r .c .Close (StatusProtocolError , "received new data message without finishing the previous message" )
465
+ r .c .writeClose (StatusProtocolError , "received new data message without finishing the previous message" , false )
477
466
return 0 , r .c .closeErr
478
467
}
479
468
@@ -487,7 +476,7 @@ func (r *messageReader) read(p []byte) (int, error) {
487
476
p = p [:h .payloadLength ]
488
477
}
489
478
490
- n , err := r .c .readFramePayload (r .c .readerMsgCtx , p )
479
+ n , err := r .c .readFramePayload (r .c .readerMsgCtx , p , true )
491
480
492
481
h .payloadLength -= int64 (n )
493
482
r .c .readMsgLeft -= int64 (n )
@@ -512,12 +501,14 @@ func (r *messageReader) read(p []byte) (int, error) {
512
501
return n , nil
513
502
}
514
503
515
- func (c * Conn ) readFramePayload (ctx context.Context , p []byte ) (int , error ) {
516
- err := c .acquireLock (ctx , c .readFrameLock )
517
- if err != nil {
518
- return 0 , err
504
+ func (c * Conn ) readFramePayload (ctx context.Context , p []byte , lock bool ) (int , error ) {
505
+ if lock {
506
+ err := c .acquireLock (ctx , c .readFrameLock )
507
+ if err != nil {
508
+ return 0 , err
509
+ }
510
+ defer c .releaseLock (c .readFrameLock )
519
511
}
520
- defer c .releaseLock (c .readFrameLock )
521
512
522
513
select {
523
514
case <- c .closed :
@@ -813,14 +804,14 @@ func (c *Conn) writePong(p []byte) error {
813
804
// Close will unblock all goroutines interacting with the connection once
814
805
// complete.
815
806
func (c * Conn ) Close (code StatusCode , reason string ) error {
816
- err := c .closeHandshake (code , reason )
807
+ err := c .writeClose (code , reason , true )
817
808
if err != nil {
818
809
return fmt .Errorf ("failed to close websocket connection: %w" , err )
819
810
}
820
811
return nil
821
812
}
822
813
823
- func (c * Conn ) closeHandshake (code StatusCode , reason string ) error {
814
+ func (c * Conn ) writeClose (code StatusCode , reason string , handshake bool ) error {
824
815
ce := CloseError {
825
816
Code : code ,
826
817
Reason : reason ,
@@ -838,60 +829,58 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) error {
838
829
p , _ = ce .bytes ()
839
830
}
840
831
832
+ // Give the handshake 10 seconds.
841
833
ctx , cancel := context .WithTimeout (context .Background (), time .Second * 10 )
842
834
defer cancel ()
843
835
844
- // Ensures the connection is closed if everything below succeeds.
845
- // Up here because we must release the read lock first.
846
- // nil because of the setCloseErr call below.
847
- defer c .close (nil )
848
-
849
- // CloseErrors sent are made opaque to prevent applications from thinking
850
- // they received a given status.
851
- sentErr := fmt .Errorf ("sent close frame: %v" , ce )
852
- // Other connections should only see this error.
853
- c .setCloseErr (sentErr )
854
-
855
836
err = c .writeControl (ctx , opClose , p )
856
837
if err != nil {
857
838
return err
858
839
}
840
+ c .setCloseErr (ce )
841
+ defer c .close (nil )
859
842
860
- // Wait for close frame from peer.
861
- err = c .waitClose (ctx )
862
- // We didn't read a close frame.
863
- if c .readCloseFrame == (CloseError {}) {
864
- if ctx .Err () != nil {
865
- return xerrors .Errorf ("failed to wait for peer close frame: %w" , ctx .Err ())
866
- }
867
- // We need to make the err returned from c.waitClose accurate.
868
- return xerrors .Errorf ("failed to read peer close frame for unknown reason" )
843
+ if handshake {
844
+ // Try to wait for close frame peer but don't complain
845
+ // if one is not received since we already decided the
846
+ // close status of the connection above.
847
+ c .waitClose (ctx )
869
848
}
849
+
870
850
return nil
871
851
}
872
852
873
853
func (c * Conn ) waitClose (ctx context.Context ) error {
854
+ err := c .acquireLock (ctx , c .readFrameLock )
855
+ if err != nil {
856
+ return err
857
+ }
858
+ defer c .releaseLock (c .readFrameLock )
859
+
874
860
b := bpool .Get ()
875
- buf := b .Bytes ()
876
- buf = buf [:cap (buf )]
877
861
defer bpool .Put (b )
878
862
879
- // Prevent reads from user code as we are going to be
880
- // discarding all messages so they cannot rely on any ordering.
881
- c .isCloseHandshake .Store (1 )
882
-
883
- // From this point forward, any reader we receive means we are
884
- // now the sole readers of the connection and so it is safe
885
- // to discard all payloads.
863
+ var h header
864
+ if c .activeReader != nil && ! c .readerFrameEOF {
865
+ h = c .readerMsgHeader
866
+ }
886
867
887
868
for {
888
- _ , r , err := c .reader (ctx )
889
- if err != nil {
890
- return err
869
+ for h .payloadLength > 0 {
870
+ buf := b .Bytes ()
871
+ if int64 (cap (buf )) > h .payloadLength {
872
+ buf = buf [:h .payloadLength ]
873
+ } else {
874
+ buf = buf [:cap (buf )]
875
+ }
876
+ n , err := c .readFramePayload (ctx , buf , false )
877
+ if err != nil {
878
+ return err
879
+ }
880
+ h .payloadLength -= int64 (n )
891
881
}
892
882
893
- // Discard all payloads.
894
- _ , err = io .CopyBuffer (ioutil .Discard , r , buf )
883
+ h , err = c .readTillMsg (ctx , false )
895
884
if err != nil {
896
885
return err
897
886
}
0 commit comments