Skip to content

Commit d432e6b

Browse files
committed
Implement complete close handshake
I changed my mind after #103 as browsers include a wasClean event to indicate whether the connection was closed cleanly. From my tests, if a server using this library prior to this commit initiates the close handshake, wasClean will be false for the browser as the connection was closed before it could respond with a close frame. Thus, I believe it's necessary to fully implement the close handshake. @stephenyama You'll enjoy this.
1 parent e795e46 commit d432e6b

7 files changed

+130
-58
lines changed

conn.go

+81-23
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ import (
1616
"sync"
1717
"sync/atomic"
1818
"time"
19+
20+
"golang.org/x/xerrors"
21+
22+
"nhooyr.io/websocket/internal/bpool"
1923
)
2024

2125
// Conn represents a WebSocket connection.
@@ -62,13 +66,15 @@ type Conn struct {
6266
writeMsgOpcode opcode
6367
writeMsgCtx context.Context
6468
readMsgLeft int64
69+
readCloseFrame CloseError
6570

6671
// Used to ensure the previous reader is read till EOF before allowing
6772
// a new one.
6873
activeReader *messageReader
6974
// readFrameLock is acquired to read from bw.
7075
readFrameLock chan struct{}
7176
isReadClosed *atomicInt64
77+
isCloseHandshake *atomicInt64
7278
readHeaderBuf []byte
7379
controlPayloadBuf []byte
7480

@@ -96,6 +102,7 @@ func (c *Conn) init() {
96102
c.writeFrameLock = make(chan struct{}, 1)
97103

98104
c.readFrameLock = make(chan struct{}, 1)
105+
c.isCloseHandshake = &atomicInt64{}
99106

100107
c.setReadTimeout = make(chan context.Context)
101108
c.setWriteTimeout = make(chan context.Context)
@@ -230,7 +237,7 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
230237
}
231238

232239
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
233-
err := c.acquireLock(context.Background(), c.readFrameLock)
240+
err := c.acquireLock(ctx, c.readFrameLock)
234241
if err != nil {
235242
return header{}, err
236243
}
@@ -308,11 +315,22 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
308315
c.Close(StatusProtocolError, err.Error())
309316
return c.closeErr
310317
}
318+
311319
// This ensures the closeErr of the Conn is always the received CloseError
312320
// in case the echo close frame write fails.
313321
// See https://github.com/nhooyr/websocket/issues/109
314322
c.setCloseErr(fmt.Errorf("received close frame: %w", ce))
315-
c.writeClose(b, nil)
323+
324+
c.readCloseFrame = ce
325+
326+
func() {
327+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
328+
defer cancel()
329+
c.writeControl(ctx, opClose, b)
330+
}()
331+
332+
// We close with nil since the error is already set above.
333+
c.close(nil)
316334
return c.closeErr
317335
default:
318336
panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h))
@@ -347,6 +365,15 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
347365
return 0, nil, fmt.Errorf("websocket connection read closed")
348366
}
349367

368+
if c.isCloseHandshake.Load() == 1 {
369+
select {
370+
case <-ctx.Done():
371+
return 0, nil, fmt.Errorf("failed to get reader: %w", ctx.Err())
372+
case <-c.closed:
373+
return 0, nil, fmt.Errorf("failed to get reader: %w", c.closeErr)
374+
}
375+
}
376+
350377
typ, r, err := c.reader(ctx)
351378
if err != nil {
352379
return 0, nil, fmt.Errorf("failed to get reader: %w", err)
@@ -772,27 +799,28 @@ func (c *Conn) writePong(p []byte) error {
772799

773800
// Close closes the WebSocket connection with the given status code and reason.
774801
//
775-
// It will write a WebSocket close frame with a timeout of 5 seconds.
802+
// It will write a WebSocket close frame and then wait for the peer to respond
803+
// with its own close frame. The entire process must complete within 10 seconds.
804+
// Thus, it implements the full WebSocket close handshake.
805+
//
776806
// The connection can only be closed once. Additional calls to Close
777807
// are no-ops.
778808
//
779-
// This does not perform a WebSocket close handshake.
780-
// See https://github.com/nhooyr/websocket/issues/103 for details on why.
781-
//
782809
// The maximum length of reason must be 125 bytes otherwise an internal
783810
// error will be sent to the peer. For this reason, you should avoid
784811
// sending a dynamic reason.
785812
//
786-
// Close will unblock all goroutines interacting with the connection.
813+
// Close will unblock all goroutines interacting with the connection once
814+
// complete.
787815
func (c *Conn) Close(code StatusCode, reason string) error {
788-
err := c.exportedClose(code, reason)
816+
err := c.closeHandshake(code, reason)
789817
if err != nil {
790818
return fmt.Errorf("failed to close websocket connection: %w", err)
791819
}
792820
return nil
793821
}
794822

795-
func (c *Conn) exportedClose(code StatusCode, reason string) error {
823+
func (c *Conn) closeHandshake(code StatusCode, reason string) error {
796824
ce := CloseError{
797825
Code: code,
798826
Reason: reason,
@@ -810,34 +838,64 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error {
810838
p, _ = ce.bytes()
811839
}
812840

841+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
842+
defer cancel()
843+
844+
// Ensures the connection is closed if everything below succeeds.
845+
// Up here because we must release the read lock first.
846+
// nil because of the setCloseErr call below.
847+
defer c.close(nil)
848+
813849
// CloseErrors sent are made opaque to prevent applications from thinking
814850
// they received a given status.
815851
sentErr := fmt.Errorf("sent close frame: %v", ce)
816-
err = c.writeClose(p, sentErr)
852+
// Other connections should only see this error.
853+
c.setCloseErr(sentErr)
854+
855+
err = c.writeControl(ctx, opClose, p)
817856
if err != nil {
818857
return err
819858
}
820859

821-
if !errors.Is(c.closeErr, sentErr) {
822-
return c.closeErr
860+
// Wait for close frame from peer.
861+
err = c.waitClose(ctx)
862+
// We didn't read a close frame.
863+
if c.readCloseFrame == (CloseError{}) {
864+
if ctx.Err() != nil {
865+
return xerrors.Errorf("failed to wait for peer close frame: %w", ctx.Err())
866+
}
867+
// We need to make the err returned from c.waitClose accurate.
868+
return xerrors.Errorf("failed to read peer close frame for unknown reason")
823869
}
824-
825870
return nil
826871
}
827872

828-
func (c *Conn) writeClose(p []byte, cerr error) error {
829-
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
830-
defer cancel()
873+
func (c *Conn) waitClose(ctx context.Context) error {
874+
b := bpool.Get()
875+
buf := b.Bytes()
876+
buf = buf[:cap(buf)]
877+
defer bpool.Put(b)
831878

832-
// If this fails, the connection had to have died.
833-
err := c.writeControl(ctx, opClose, p)
834-
if err != nil {
835-
return err
836-
}
879+
// Prevent reads from user code as we are going to be
880+
// discarding all messages so they cannot rely on any ordering.
881+
c.isCloseHandshake.Store(1)
837882

838-
c.close(cerr)
883+
// From this point forward, any reader we receive means we are
884+
// now the sole readers of the connection and so it is safe
885+
// to discard all payloads.
839886

840-
return nil
887+
for {
888+
_, r, err := c.reader(ctx)
889+
if err != nil {
890+
return err
891+
}
892+
893+
// Discard all payloads.
894+
_, err = io.CopyBuffer(ioutil.Discard, r, buf)
895+
if err != nil {
896+
return err
897+
}
898+
}
841899
}
842900

843901
// Ping sends a ping to the peer and waits for a pong.

conn_common.go

+4
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,7 @@ func (v *atomicInt64) String() string {
230230
func (v *atomicInt64) Increment(delta int64) int64 {
231231
return atomic.AddInt64(&v.v, delta)
232232
}
233+
234+
func (v *atomicInt64) CAS(old, new int64) (swapped bool) {
235+
return atomic.CompareAndSwapInt64(&v.v, old, new)
236+
}

conn_test.go

+9
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,15 @@ func TestConn(t *testing.T) {
856856
return nil
857857
},
858858
},
859+
{
860+
name: "closeHandshake",
861+
server: func(ctx context.Context, c *websocket.Conn) error {
862+
return c.Close(websocket.StatusNormalClosure, "")
863+
},
864+
client: func(ctx context.Context, c *websocket.Conn) error {
865+
return c.Close(websocket.StatusNormalClosure, "")
866+
},
867+
},
859868
}
860869
for _, tc := range testCases {
861870
tc := tc

go.mod

+1
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,6 @@ require (
2121
golang.org/x/sys v0.0.0-20190927073244-c990c680b611 // indirect
2222
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
2323
golang.org/x/tools v0.0.0-20190920225731-5eefd052ad72
24+
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7
2425
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
2526
)

websocket_js.go

+35-13
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ type Conn struct {
3636
readBufMu sync.Mutex
3737
readBuf []wsjs.MessageEvent
3838

39-
// Only used by tests
40-
receivedCloseFrame chan struct{}
39+
closeEventCh chan wsjs.CloseEvent
4140
}
4241

4342
func (c *Conn) close(err error) {
@@ -58,10 +57,11 @@ func (c *Conn) init() {
5857

5958
c.isReadClosed = &atomicInt64{}
6059

61-
c.receivedCloseFrame = make(chan struct{})
60+
c.closeEventCh = make(chan wsjs.CloseEvent, 1)
6261

6362
c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) {
64-
close(c.receivedCloseFrame)
63+
c.closeEventCh <- e
64+
close(c.closeEventCh)
6565

6666
cerr := CloseError{
6767
Code: StatusCode(e.Code),
@@ -193,24 +193,46 @@ func (c *Conn) isClosed() bool {
193193
}
194194

195195
// Close closes the websocket with the given code and reason.
196+
// It will wait until the peer responds with a close frame
197+
// or the connection is closed.
198+
// It thus performs the full WebSocket close handshake.
196199
func (c *Conn) Close(code StatusCode, reason string) error {
200+
err := c.exportedClose(code, reason)
201+
if err != nil {
202+
return fmt.Errorf("failed to close websocket: %w", err)
203+
}
204+
return nil
205+
}
206+
207+
func (c *Conn) exportedClose(code StatusCode, reason string) error {
197208
if c.isClosed() {
198209
return fmt.Errorf("already closed: %w", c.closeErr)
199210
}
200211

201-
err := fmt.Errorf("sent close frame: %v", CloseError{
212+
cerr := CloseError{
202213
Code: code,
203214
Reason: reason,
204-
})
205-
206-
err2 := c.ws.Close(int(code), reason)
207-
if err2 != nil {
208-
err = err2
209215
}
210-
c.close(err)
216+
closeErr := fmt.Errorf("sent close frame: %v", cerr)
217+
c.close(closeErr)
218+
if !errors.Is(c.closeErr, closeErr) {
219+
return c.closeErr
220+
}
211221

212-
if !errors.Is(c.closeErr, err) {
213-
return fmt.Errorf("failed to close websocket: %w", err)
222+
// We're the only goroutine allowed to get this far.
223+
// The only possible error from closing the connection here
224+
// is that the connection is already closed in which case,
225+
// we do not really care.
226+
c.ws.Close(int(code), reason)
227+
228+
// Guaranteed for this channel receive to succeed since the above
229+
// if statement means we are the goroutine that closed this connection.
230+
ev := <-c.closeEventCh
231+
if !ev.WasClean {
232+
return fmt.Errorf("unclean connection close: %v", CloseError{
233+
Code: StatusCode(ev.Code),
234+
Reason: ev.Reason,
235+
})
214236
}
215237

216238
return nil

websocket_js_export_test.go

-17
This file was deleted.

websocket_js_test.go

-5
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,4 @@ func TestConn(t *testing.T) {
4949
if err != nil {
5050
t.Fatal(err)
5151
}
52-
53-
err = c.WaitCloseFrame(ctx)
54-
if err != nil {
55-
t.Fatal(err)
56-
}
5752
}

0 commit comments

Comments
 (0)