Skip to content

Commit 224ef23

Browse files
committed
Cleanup close handshake implementation
1 parent a5af693 commit 224ef23

File tree

4 files changed

+92
-124
lines changed

4 files changed

+92
-124
lines changed

conn.go

+73-84
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ import (
1717
"sync/atomic"
1818
"time"
1919

20-
"golang.org/x/xerrors"
21-
2220
"nhooyr.io/websocket/internal/bpool"
2321
)
2422

@@ -66,15 +64,13 @@ type Conn struct {
6664
writeMsgOpcode opcode
6765
writeMsgCtx context.Context
6866
readMsgLeft int64
69-
readCloseFrame CloseError
7067

7168
// Used to ensure the previous reader is read till EOF before allowing
7269
// a new one.
7370
activeReader *messageReader
7471
// readFrameLock is acquired to read from bw.
7572
readFrameLock chan struct{}
7673
isReadClosed *atomicInt64
77-
isCloseHandshake *atomicInt64
7874
readHeaderBuf []byte
7975
controlPayloadBuf []byte
8076

@@ -102,7 +98,6 @@ func (c *Conn) init() {
10298
c.writeFrameLock = make(chan struct{}, 1)
10399

104100
c.readFrameLock = make(chan struct{}, 1)
105-
c.isCloseHandshake = &atomicInt64{}
106101

107102
c.setReadTimeout = make(chan context.Context)
108103
c.setWriteTimeout = make(chan context.Context)
@@ -206,20 +201,20 @@ func (c *Conn) releaseLock(lock chan struct{}) {
206201
}
207202
}
208203

209-
func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
204+
func (c *Conn) readTillMsg(ctx context.Context, lock bool) (header, error) {
210205
for {
211-
h, err := c.readFrameHeader(ctx)
206+
h, err := c.readFrameHeader(ctx, lock)
212207
if err != nil {
213208
return header{}, err
214209
}
215210

216211
if h.rsv1 || h.rsv2 || h.rsv3 {
217-
c.Close(StatusProtocolError, fmt.Sprintf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3))
212+
c.writeClose(StatusProtocolError, fmt.Sprintf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3), false)
218213
return header{}, c.closeErr
219214
}
220215

221216
if h.opcode.controlOp() {
222-
err = c.handleControl(ctx, h)
217+
err = c.handleControl(ctx, h, lock)
223218
if err != nil {
224219
return header{}, fmt.Errorf("failed to handle control frame: %w", err)
225220
}
@@ -230,18 +225,20 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
230225
case opBinary, opText, opContinuation:
231226
return h, nil
232227
default:
233-
c.Close(StatusProtocolError, fmt.Sprintf("received unknown opcode %v", h.opcode))
228+
c.writeClose(StatusProtocolError, fmt.Sprintf("received unknown opcode %v", h.opcode), false)
234229
return header{}, c.closeErr
235230
}
236231
}
237232
}
238233

239-
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
240-
err := c.acquireLock(ctx, c.readFrameLock)
241-
if err != nil {
242-
return header{}, err
234+
func (c *Conn) readFrameHeader(ctx context.Context, lock bool) (header, error) {
235+
if lock {
236+
err := c.acquireLock(ctx, c.readFrameLock)
237+
if err != nil {
238+
return header{}, err
239+
}
240+
defer c.releaseLock(c.readFrameLock)
243241
}
244-
defer c.releaseLock(c.readFrameLock)
245242

246243
select {
247244
case <-c.closed:
@@ -273,22 +270,22 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
273270
return h, nil
274271
}
275272

276-
func (c *Conn) handleControl(ctx context.Context, h header) error {
273+
func (c *Conn) handleControl(ctx context.Context, h header, lock bool) error {
277274
if h.payloadLength > maxControlFramePayload {
278-
c.Close(StatusProtocolError, fmt.Sprintf("control frame too large at %v bytes", h.payloadLength))
275+
c.writeClose(StatusProtocolError, fmt.Sprintf("control frame too large at %v bytes", h.payloadLength), false)
279276
return c.closeErr
280277
}
281278

282279
if !h.fin {
283-
c.Close(StatusProtocolError, "received fragmented control frame")
280+
c.writeClose(StatusProtocolError, "received fragmented control frame", false)
284281
return c.closeErr
285282
}
286283

287284
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
288285
defer cancel()
289286

290287
b := c.controlPayloadBuf[:h.payloadLength]
291-
_, err := c.readFramePayload(ctx, b)
288+
_, err := c.readFramePayload(ctx, b, lock)
292289
if err != nil {
293290
return err
294291
}
@@ -312,23 +309,24 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
312309
ce, err := parseClosePayload(b)
313310
if err != nil {
314311
err = fmt.Errorf("received invalid close payload: %w", err)
315-
c.Close(StatusProtocolError, err.Error())
312+
c.writeClose(StatusProtocolError, err.Error(), false)
316313
return c.closeErr
317314
}
318315

319316
// This ensures the closeErr of the Conn is always the received CloseError
320317
// in case the echo close frame write fails.
321318
// See https://github.com/nhooyr/websocket/issues/109
322-
c.setCloseErr(fmt.Errorf("received close frame: %w", ce))
323-
324-
c.readCloseFrame = ce
319+
c.setCloseErr(ce)
325320

326321
func() {
327322
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
328323
defer cancel()
329324
c.writeControl(ctx, opClose, b)
330325
}()
331326

327+
if !lock {
328+
c.releaseLock(c.readFrameLock)
329+
}
332330
// We close with nil since the error is already set above.
333331
c.close(nil)
334332
return c.closeErr
@@ -362,16 +360,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
362360
// Most users should not need this.
363361
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
364362
if c.isReadClosed.Load() == 1 {
365-
return 0, nil, fmt.Errorf("websocket connection read closed")
366-
}
367-
368-
if c.isCloseHandshake.Load() == 1 {
369-
select {
370-
case <-ctx.Done():
371-
return 0, nil, fmt.Errorf("failed to get reader: %w", ctx.Err())
372-
case <-c.closed:
373-
return 0, nil, fmt.Errorf("failed to get reader: %w", c.closeErr)
374-
}
363+
return 0, nil, errors.New("websocket connection read closed")
375364
}
376365

377366
typ, r, err := c.reader(ctx)
@@ -381,23 +370,23 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
381370
return typ, r, nil
382371
}
383372

384-
func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
373+
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
385374
if c.activeReader != nil && !c.readerFrameEOF {
386375
// The only way we know for sure the previous reader is not yet complete is
387376
// if there is an active frame not yet fully read.
388377
// Otherwise, a user may have read the last byte but not the EOF if the EOF
389378
// is in the next frame so we check for that below.
390-
return 0, nil, fmt.Errorf("previous message not read to completion")
379+
return 0, nil, errors.New("previous message not read to completion")
391380
}
392381

393-
h, err := c.readTillMsg(ctx)
382+
h, err := c.readTillMsg(ctx, true)
394383
if err != nil {
395384
return 0, nil, err
396385
}
397386

398387
if c.activeReader != nil && !c.activeReader.eof() {
399388
if h.opcode != opContinuation {
400-
c.Close(StatusProtocolError, "received new data message without finishing the previous message")
389+
c.writeClose(StatusProtocolError, "received new data message without finishing the previous message", false)
401390
return 0, nil, c.closeErr
402391
}
403392

@@ -407,12 +396,12 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
407396

408397
c.activeReader = nil
409398

410-
h, err = c.readTillMsg(ctx)
399+
h, err = c.readTillMsg(ctx, true)
411400
if err != nil {
412401
return 0, nil, err
413402
}
414403
} else if h.opcode == opContinuation {
415-
c.Close(StatusProtocolError, "received continuation frame not after data or text frame")
404+
c.writeClose(StatusProtocolError, "received continuation frame not after data or text frame", false)
416405
return 0, nil, c.closeErr
417406
}
418407

@@ -458,7 +447,7 @@ func (r *messageReader) read(p []byte) (int, error) {
458447
}
459448

460449
if r.c.readMsgLeft <= 0 {
461-
r.c.Close(StatusMessageTooBig, fmt.Sprintf("read limited at %v bytes", r.c.msgReadLimit))
450+
r.c.writeClose(StatusMessageTooBig, fmt.Sprintf("read limited at %v bytes", r.c.msgReadLimit), false)
462451
return 0, r.c.closeErr
463452
}
464453

@@ -467,13 +456,13 @@ func (r *messageReader) read(p []byte) (int, error) {
467456
}
468457

469458
if r.c.readerFrameEOF {
470-
h, err := r.c.readTillMsg(r.c.readerMsgCtx)
459+
h, err := r.c.readTillMsg(r.c.readerMsgCtx, true)
471460
if err != nil {
472461
return 0, err
473462
}
474463

475464
if h.opcode != opContinuation {
476-
r.c.Close(StatusProtocolError, "received new data message without finishing the previous message")
465+
r.c.writeClose(StatusProtocolError, "received new data message without finishing the previous message", false)
477466
return 0, r.c.closeErr
478467
}
479468

@@ -487,7 +476,7 @@ func (r *messageReader) read(p []byte) (int, error) {
487476
p = p[:h.payloadLength]
488477
}
489478

490-
n, err := r.c.readFramePayload(r.c.readerMsgCtx, p)
479+
n, err := r.c.readFramePayload(r.c.readerMsgCtx, p, true)
491480

492481
h.payloadLength -= int64(n)
493482
r.c.readMsgLeft -= int64(n)
@@ -512,12 +501,14 @@ func (r *messageReader) read(p []byte) (int, error) {
512501
return n, nil
513502
}
514503

515-
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
516-
err := c.acquireLock(ctx, c.readFrameLock)
517-
if err != nil {
518-
return 0, err
504+
func (c *Conn) readFramePayload(ctx context.Context, p []byte, lock bool) (int, error) {
505+
if lock {
506+
err := c.acquireLock(ctx, c.readFrameLock)
507+
if err != nil {
508+
return 0, err
509+
}
510+
defer c.releaseLock(c.readFrameLock)
519511
}
520-
defer c.releaseLock(c.readFrameLock)
521512

522513
select {
523514
case <-c.closed:
@@ -813,14 +804,14 @@ func (c *Conn) writePong(p []byte) error {
813804
// Close will unblock all goroutines interacting with the connection once
814805
// complete.
815806
func (c *Conn) Close(code StatusCode, reason string) error {
816-
err := c.closeHandshake(code, reason)
807+
err := c.writeClose(code, reason, true)
817808
if err != nil {
818809
return fmt.Errorf("failed to close websocket connection: %w", err)
819810
}
820811
return nil
821812
}
822813

823-
func (c *Conn) closeHandshake(code StatusCode, reason string) error {
814+
func (c *Conn) writeClose(code StatusCode, reason string, handshake bool) error {
824815
ce := CloseError{
825816
Code: code,
826817
Reason: reason,
@@ -838,60 +829,58 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) error {
838829
p, _ = ce.bytes()
839830
}
840831

832+
// Give the handshake 10 seconds.
841833
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
842834
defer cancel()
843835

844-
// Ensures the connection is closed if everything below succeeds.
845-
// Up here because we must release the read lock first.
846-
// nil because of the setCloseErr call below.
847-
defer c.close(nil)
848-
849-
// CloseErrors sent are made opaque to prevent applications from thinking
850-
// they received a given status.
851-
sentErr := fmt.Errorf("sent close frame: %v", ce)
852-
// Other connections should only see this error.
853-
c.setCloseErr(sentErr)
854-
855836
err = c.writeControl(ctx, opClose, p)
856837
if err != nil {
857838
return err
858839
}
840+
c.setCloseErr(ce)
841+
defer c.close(nil)
859842

860-
// Wait for close frame from peer.
861-
err = c.waitClose(ctx)
862-
// We didn't read a close frame.
863-
if c.readCloseFrame == (CloseError{}) {
864-
if ctx.Err() != nil {
865-
return xerrors.Errorf("failed to wait for peer close frame: %w", ctx.Err())
866-
}
867-
// We need to make the err returned from c.waitClose accurate.
868-
return xerrors.Errorf("failed to read peer close frame for unknown reason")
843+
if handshake {
844+
// Try to wait for close frame peer but don't complain
845+
// if one is not received since we already decided the
846+
// close status of the connection above.
847+
c.waitClose(ctx)
869848
}
849+
870850
return nil
871851
}
872852

873853
func (c *Conn) waitClose(ctx context.Context) error {
854+
err := c.acquireLock(ctx, c.readFrameLock)
855+
if err != nil {
856+
return err
857+
}
858+
defer c.releaseLock(c.readFrameLock)
859+
874860
b := bpool.Get()
875-
buf := b.Bytes()
876-
buf = buf[:cap(buf)]
877861
defer bpool.Put(b)
878862

879-
// Prevent reads from user code as we are going to be
880-
// discarding all messages so they cannot rely on any ordering.
881-
c.isCloseHandshake.Store(1)
882-
883-
// From this point forward, any reader we receive means we are
884-
// now the sole readers of the connection and so it is safe
885-
// to discard all payloads.
863+
var h header
864+
if c.activeReader != nil && !c.readerFrameEOF {
865+
h = c.readerMsgHeader
866+
}
886867

887868
for {
888-
_, r, err := c.reader(ctx)
889-
if err != nil {
890-
return err
869+
for h.payloadLength > 0 {
870+
buf := b.Bytes()
871+
if int64(cap(buf)) > h.payloadLength {
872+
buf = buf[:h.payloadLength]
873+
} else {
874+
buf = buf[:cap(buf)]
875+
}
876+
n, err := c.readFramePayload(ctx, buf, false)
877+
if err != nil {
878+
return err
879+
}
880+
h.payloadLength -= int64(n)
891881
}
892882

893-
// Discard all payloads.
894-
_, err = io.CopyBuffer(ioutil.Discard, r, buf)
883+
h, err = c.readTillMsg(ctx, false)
895884
if err != nil {
896885
return err
897886
}

conn_export_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ const (
2323
)
2424

2525
func (c *Conn) ReadFrame(ctx context.Context) (OpCode, []byte, error) {
26-
h, err := c.readFrameHeader(ctx)
26+
h, err := c.readFrameHeader(ctx, true)
2727
if err != nil {
2828
return 0, nil, err
2929
}
3030
b := make([]byte, h.payloadLength)
31-
_, err = c.readFramePayload(ctx, b)
31+
_, err = c.readFramePayload(ctx, b, true)
3232
if err != nil {
3333
return 0, nil, err
3434
}

go.mod

-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,5 @@ require (
2121
golang.org/x/sys v0.0.0-20190927073244-c990c680b611 // indirect
2222
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
2323
golang.org/x/tools v0.0.0-20190920225731-5eefd052ad72
24-
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7
2524
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
2625
)

0 commit comments

Comments
 (0)