@@ -16,6 +16,10 @@ import (
16
16
"sync"
17
17
"sync/atomic"
18
18
"time"
19
+
20
+ "golang.org/x/xerrors"
21
+
22
+ "nhooyr.io/websocket/internal/bpool"
19
23
)
20
24
21
25
// Conn represents a WebSocket connection.
@@ -62,13 +66,15 @@ type Conn struct {
62
66
writeMsgOpcode opcode
63
67
writeMsgCtx context.Context
64
68
readMsgLeft int64
69
+ readCloseFrame CloseError
65
70
66
71
// Used to ensure the previous reader is read till EOF before allowing
67
72
// a new one.
68
73
activeReader * messageReader
69
74
// readFrameLock is acquired to read from bw.
70
75
readFrameLock chan struct {}
71
76
isReadClosed * atomicInt64
77
+ isCloseHandshake * atomicInt64
72
78
readHeaderBuf []byte
73
79
controlPayloadBuf []byte
74
80
@@ -96,6 +102,7 @@ func (c *Conn) init() {
96
102
c .writeFrameLock = make (chan struct {}, 1 )
97
103
98
104
c .readFrameLock = make (chan struct {}, 1 )
105
+ c .isCloseHandshake = & atomicInt64 {}
99
106
100
107
c .setReadTimeout = make (chan context.Context )
101
108
c .setWriteTimeout = make (chan context.Context )
@@ -230,7 +237,7 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
230
237
}
231
238
232
239
func (c * Conn ) readFrameHeader (ctx context.Context ) (header , error ) {
233
- err := c .acquireLock (context . Background () , c .readFrameLock )
240
+ err := c .acquireLock (ctx , c .readFrameLock )
234
241
if err != nil {
235
242
return header {}, err
236
243
}
@@ -308,11 +315,22 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
308
315
c .Close (StatusProtocolError , err .Error ())
309
316
return c .closeErr
310
317
}
318
+
311
319
// This ensures the closeErr of the Conn is always the received CloseError
312
320
// in case the echo close frame write fails.
313
321
// See https://github.com/nhooyr/websocket/issues/109
314
322
c .setCloseErr (fmt .Errorf ("received close frame: %w" , ce ))
315
- c .writeClose (b , nil )
323
+
324
+ c .readCloseFrame = ce
325
+
326
+ func () {
327
+ ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
328
+ defer cancel ()
329
+ c .writeControl (ctx , opClose , b )
330
+ }()
331
+
332
+ // We close with nil since the error is already set above.
333
+ c .close (nil )
316
334
return c .closeErr
317
335
default :
318
336
panic (fmt .Sprintf ("websocket: unexpected control opcode: %#v" , h ))
@@ -347,6 +365,15 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
347
365
return 0 , nil , fmt .Errorf ("websocket connection read closed" )
348
366
}
349
367
368
+ if c .isCloseHandshake .Load () == 1 {
369
+ select {
370
+ case <- ctx .Done ():
371
+ return 0 , nil , fmt .Errorf ("failed to get reader: %w" , ctx .Err ())
372
+ case <- c .closed :
373
+ return 0 , nil , fmt .Errorf ("failed to get reader: %w" , c .closeErr )
374
+ }
375
+ }
376
+
350
377
typ , r , err := c .reader (ctx )
351
378
if err != nil {
352
379
return 0 , nil , fmt .Errorf ("failed to get reader: %w" , err )
@@ -772,27 +799,28 @@ func (c *Conn) writePong(p []byte) error {
772
799
773
800
// Close closes the WebSocket connection with the given status code and reason.
774
801
//
775
- // It will write a WebSocket close frame with a timeout of 5 seconds.
802
+ // It will write a WebSocket close frame and then wait for the peer to respond
803
+ // with its own close frame. The entire process must complete within 10 seconds.
804
+ // Thus, it implements the full WebSocket close handshake.
805
+ //
776
806
// The connection can only be closed once. Additional calls to Close
777
807
// are no-ops.
778
808
//
779
- // This does not perform a WebSocket close handshake.
780
- // See https://github.com/nhooyr/websocket/issues/103 for details on why.
781
- //
782
809
// The maximum length of reason must be 125 bytes otherwise an internal
783
810
// error will be sent to the peer. For this reason, you should avoid
784
811
// sending a dynamic reason.
785
812
//
786
- // Close will unblock all goroutines interacting with the connection.
813
+ // Close will unblock all goroutines interacting with the connection once
814
+ // complete.
787
815
func (c * Conn ) Close (code StatusCode , reason string ) error {
788
- err := c .exportedClose (code , reason )
816
+ err := c .closeHandshake (code , reason )
789
817
if err != nil {
790
818
return fmt .Errorf ("failed to close websocket connection: %w" , err )
791
819
}
792
820
return nil
793
821
}
794
822
795
- func (c * Conn ) exportedClose (code StatusCode , reason string ) error {
823
+ func (c * Conn ) closeHandshake (code StatusCode , reason string ) error {
796
824
ce := CloseError {
797
825
Code : code ,
798
826
Reason : reason ,
@@ -810,34 +838,64 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error {
810
838
p , _ = ce .bytes ()
811
839
}
812
840
841
+ ctx , cancel := context .WithTimeout (context .Background (), time .Second * 10 )
842
+ defer cancel ()
843
+
844
+ // Ensures the connection is closed if everything below succeeds.
845
+ // Up here because we must release the read lock first.
846
+ // nil because of the setCloseErr call below.
847
+ defer c .close (nil )
848
+
813
849
// CloseErrors sent are made opaque to prevent applications from thinking
814
850
// they received a given status.
815
851
sentErr := fmt .Errorf ("sent close frame: %v" , ce )
816
- err = c .writeClose (p , sentErr )
852
+ // Other connections should only see this error.
853
+ c .setCloseErr (sentErr )
854
+
855
+ err = c .writeControl (ctx , opClose , p )
817
856
if err != nil {
818
857
return err
819
858
}
820
859
821
- if ! errors .Is (c .closeErr , sentErr ) {
822
- return c .closeErr
860
+ // Wait for close frame from peer.
861
+ err = c .waitClose (ctx )
862
+ // We didn't read a close frame.
863
+ if c .readCloseFrame == (CloseError {}) {
864
+ if ctx .Err () != nil {
865
+ return xerrors .Errorf ("failed to wait for peer close frame: %w" , ctx .Err ())
866
+ }
867
+ // We need to make the err returned from c.waitClose accurate.
868
+ return xerrors .Errorf ("failed to read peer close frame for unknown reason" )
823
869
}
824
-
825
870
return nil
826
871
}
827
872
828
- func (c * Conn ) writeClose (p []byte , cerr error ) error {
829
- ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
830
- defer cancel ()
873
+ func (c * Conn ) waitClose (ctx context.Context ) error {
874
+ b := bpool .Get ()
875
+ buf := b .Bytes ()
876
+ buf = buf [:cap (buf )]
877
+ defer bpool .Put (b )
831
878
832
- // If this fails, the connection had to have died.
833
- err := c .writeControl (ctx , opClose , p )
834
- if err != nil {
835
- return err
836
- }
879
+ // Prevent reads from user code as we are going to be
880
+ // discarding all messages so they cannot rely on any ordering.
881
+ c .isCloseHandshake .Store (1 )
837
882
838
- c .close (cerr )
883
+ // From this point forward, any reader we receive means we are
884
+ // now the sole readers of the connection and so it is safe
885
+ // to discard all payloads.
839
886
840
- return nil
887
+ for {
888
+ _ , r , err := c .reader (ctx )
889
+ if err != nil {
890
+ return err
891
+ }
892
+
893
+ // Discard all payloads.
894
+ _ , err = io .CopyBuffer (ioutil .Discard , r , buf )
895
+ if err != nil {
896
+ return err
897
+ }
898
+ }
841
899
}
842
900
843
901
// Ping sends a ping to the peer and waits for a pong.
0 commit comments