Skip to content

Commit 8bf8877

Browse files
authored
Merge pull request #165 from nhooyr/164
Fix concurrent read with close
2 parents 50dd426 + 4f014d2 commit 8bf8877

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

conn.go

+14-5
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,12 @@ type Conn struct {
4242
closer io.Closer
4343
client bool
4444

45-
closeOnce sync.Once
46-
closeErrOnce sync.Once
47-
closeErr error
48-
closed chan struct{}
49-
closing *atomicInt64
45+
closeOnce sync.Once
46+
closeErrOnce sync.Once
47+
closeErr error
48+
closed chan struct{}
49+
closing *atomicInt64
50+
closeReceived error
5051

5152
// messageWriter state.
5253
// writeMsgLock is acquired to write a data message.
@@ -339,10 +340,12 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
339340
if err != nil {
340341
err = fmt.Errorf("received invalid close payload: %w", err)
341342
c.exportedClose(StatusProtocolError, err.Error(), false)
343+
c.closeReceived = err
342344
return err
343345
}
344346

345347
err = fmt.Errorf("received close: %w", ce)
348+
c.closeReceived = err
346349
c.writeClose(b, err, false)
347350

348351
if ctx.Err() != nil {
@@ -941,6 +944,12 @@ func (c *Conn) waitClose() error {
941944
return err
942945
}
943946
defer c.releaseLock(c.readLock)
947+
948+
if c.closeReceived != nil {
949+
// goroutine reading just received the close.
950+
return c.closeReceived
951+
}
952+
944953
c.readerShouldLock = false
945954

946955
b := bpool.Get()

conn_test.go

+23
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,29 @@ func TestConn(t *testing.T) {
868868
return c.Close(websocket.StatusNormalClosure, "")
869869
},
870870
},
871+
{
872+
// Issue #164
873+
name: "closeHandshake_concurrentRead",
874+
server: func(ctx context.Context, c *websocket.Conn) error {
875+
_, _, err := c.Read(ctx)
876+
return assertCloseStatus(err, websocket.StatusNormalClosure)
877+
},
878+
client: func(ctx context.Context, c *websocket.Conn) error {
879+
errc := make(chan error, 1)
880+
go func() {
881+
_, _, err := c.Read(ctx)
882+
errc <- err
883+
}()
884+
885+
err := c.Close(websocket.StatusNormalClosure, "")
886+
if err != nil {
887+
return err
888+
}
889+
890+
err = <-errc
891+
return assertCloseStatus(err, websocket.StatusNormalClosure)
892+
},
893+
},
871894
}
872895
for _, tc := range testCases {
873896
tc := tc

0 commit comments

Comments
 (0)