Skip to content

Commit c781bdf

Browse files
authored
Merge pull request #171 from nhooyr/fast-xor
Optimize masking with math/bits
2 parents 0fc34f9 + 15d0a18 commit c781bdf

File tree

4 files changed

+129
-111
lines changed

4 files changed

+129
-111
lines changed

conn.go

+9-8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"bufio"
77
"context"
88
"crypto/rand"
9+
"encoding/binary"
910
"errors"
1011
"fmt"
1112
"io"
@@ -81,7 +82,7 @@ type Conn struct {
8182
readerMsgCtx context.Context
8283
readerMsgHeader header
8384
readerFrameEOF bool
84-
readerMaskPos int
85+
readerMaskKey uint32
8586

8687
setReadTimeout chan context.Context
8788
setWriteTimeout chan context.Context
@@ -324,7 +325,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
324325
}
325326

326327
if h.masked {
327-
fastXOR(h.maskKey, 0, b)
328+
mask(h.maskKey, b)
328329
}
329330

330331
switch h.opcode {
@@ -446,7 +447,7 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e
446447
c.readerMsgCtx = ctx
447448
c.readerMsgHeader = h
448449
c.readerFrameEOF = false
449-
c.readerMaskPos = 0
450+
c.readerMaskKey = h.maskKey
450451
c.readMsgLeft = c.msgReadLimit.Load()
451452

452453
r := &messageReader{
@@ -532,7 +533,7 @@ func (r *messageReader) read(p []byte, lock bool) (int, error) {
532533

533534
r.c.readerMsgHeader = h
534535
r.c.readerFrameEOF = false
535-
r.c.readerMaskPos = 0
536+
r.c.readerMaskKey = h.maskKey
536537
}
537538

538539
h := r.c.readerMsgHeader
@@ -545,7 +546,7 @@ func (r *messageReader) read(p []byte, lock bool) (int, error) {
545546
h.payloadLength -= int64(n)
546547
r.c.readMsgLeft -= int64(n)
547548
if h.masked {
548-
r.c.readerMaskPos = fastXOR(h.maskKey, r.c.readerMaskPos, p)
549+
r.c.readerMaskKey = mask(r.c.readerMaskKey, p)
549550
}
550551
r.c.readerMsgHeader = h
551552

@@ -761,7 +762,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
761762
c.writeHeader.payloadLength = int64(len(p))
762763

763764
if c.client {
764-
_, err := io.ReadFull(rand.Reader, c.writeHeader.maskKey[:])
765+
err = binary.Read(rand.Reader, binary.LittleEndian, &c.writeHeader.maskKey)
765766
if err != nil {
766767
return 0, fmt.Errorf("failed to generate masking key: %w", err)
767768
}
@@ -809,7 +810,7 @@ func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, e
809810
}
810811

811812
if c.client {
812-
var keypos int
813+
maskKey := h.maskKey
813814
for len(p) > 0 {
814815
if c.bw.Available() == 0 {
815816
err = c.bw.Flush()
@@ -831,7 +832,7 @@ func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, e
831832
return n, err
832833
}
833834

834-
keypos = fastXOR(h.maskKey, keypos, c.writeBuf[i:i+n2])
835+
maskKey = mask(maskKey, c.writeBuf[i:i+n2])
835836

836837
p = p[n2:]
837838
n += n2

conn_export_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func (c *Conn) ReadFrame(ctx context.Context) (OpCode, []byte, error) {
3737
return 0, nil, err
3838
}
3939
if h.masked {
40-
fastXOR(h.maskKey, 0, b)
40+
mask(h.maskKey, b)
4141
}
4242
return OpCode(h.opcode), b, nil
4343
}

frame.go

+84-81
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"io"
88
"math"
9+
"math/bits"
910
)
1011

1112
//go:generate stringer -type=opcode,MessageType,StatusCode -output=frame_stringer.go
@@ -69,7 +70,7 @@ type header struct {
6970
payloadLength int64
7071

7172
masked bool
72-
maskKey [4]byte
73+
maskKey uint32
7374
}
7475

7576
func makeWriteHeaderBuf() []byte {
@@ -119,7 +120,7 @@ func writeHeader(b []byte, h header) []byte {
119120
if h.masked {
120121
b[1] |= 1 << 7
121122
b = b[:len(b)+4]
122-
copy(b[len(b)-4:], h.maskKey[:])
123+
binary.LittleEndian.PutUint32(b[len(b)-4:], h.maskKey)
123124
}
124125

125126
return b
@@ -192,7 +193,7 @@ func readHeader(b []byte, r io.Reader) (header, error) {
192193
}
193194

194195
if h.masked {
195-
copy(h.maskKey[:], b)
196+
h.maskKey = binary.LittleEndian.Uint32(b)
196197
}
197198

198199
return h, nil
@@ -321,122 +322,124 @@ func (ce CloseError) bytes() ([]byte, error) {
321322
return buf, nil
322323
}
323324

324-
// xor applies the WebSocket masking algorithm to p
325-
// with the given key where the first 3 bits of pos
326-
// are the starting position in the key.
325+
// mask applies the WebSocket masking algorithm to p
326+
// with the given key.
327327
// See https://tools.ietf.org/html/rfc6455#section-5.3
328328
//
329-
// The returned value is the position of the next byte
330-
// to be used for masking in the key. This is so that
331-
// unmasking can be performed without the entire frame.
332-
func fastXOR(key [4]byte, keyPos int, b []byte) int {
333-
// If the payload is greater than or equal to 16 bytes, then it's worth
334-
// masking 8 bytes at a time.
335-
// Optimization from https://github.com/golang/go/issues/31586#issuecomment-485530859
336-
if len(b) >= 16 {
337-
// We first create a key that is 8 bytes long
338-
// and is aligned on the position correctly.
339-
var alignedKey [8]byte
340-
for i := range alignedKey {
341-
alignedKey[i] = key[(i+keyPos)&3]
342-
}
343-
k := binary.LittleEndian.Uint64(alignedKey[:])
329+
// The returned value is the correctly rotated key to
330+
// continue to mask/unmask the message.
331+
//
332+
// It is optimized for little endian and expects the key
333+
// to be in little endian.
334+
//
335+
// See https://github.com/golang/go/issues/31586
336+
func mask(key uint32, b []byte) uint32 {
337+
if len(b) >= 8 {
338+
key64 := uint64(key)<<32 | uint64(key)
344339

345340
// At some point in the future we can clean these unrolled loops up.
346341
// See https://github.com/golang/go/issues/31586#issuecomment-487436401
347342

348343
// Then we xor until b is less than 128 bytes.
349344
for len(b) >= 128 {
350345
v := binary.LittleEndian.Uint64(b)
351-
binary.LittleEndian.PutUint64(b, v^k)
352-
v = binary.LittleEndian.Uint64(b[8:])
353-
binary.LittleEndian.PutUint64(b[8:], v^k)
354-
v = binary.LittleEndian.Uint64(b[16:])
355-
binary.LittleEndian.PutUint64(b[16:], v^k)
356-
v = binary.LittleEndian.Uint64(b[24:])
357-
binary.LittleEndian.PutUint64(b[24:], v^k)
358-
v = binary.LittleEndian.Uint64(b[32:])
359-
binary.LittleEndian.PutUint64(b[32:], v^k)
360-
v = binary.LittleEndian.Uint64(b[40:])
361-
binary.LittleEndian.PutUint64(b[40:], v^k)
362-
v = binary.LittleEndian.Uint64(b[48:])
363-
binary.LittleEndian.PutUint64(b[48:], v^k)
364-
v = binary.LittleEndian.Uint64(b[56:])
365-
binary.LittleEndian.PutUint64(b[56:], v^k)
366-
v = binary.LittleEndian.Uint64(b[64:])
367-
binary.LittleEndian.PutUint64(b[64:], v^k)
368-
v = binary.LittleEndian.Uint64(b[72:])
369-
binary.LittleEndian.PutUint64(b[72:], v^k)
370-
v = binary.LittleEndian.Uint64(b[80:])
371-
binary.LittleEndian.PutUint64(b[80:], v^k)
372-
v = binary.LittleEndian.Uint64(b[88:])
373-
binary.LittleEndian.PutUint64(b[88:], v^k)
374-
v = binary.LittleEndian.Uint64(b[96:])
375-
binary.LittleEndian.PutUint64(b[96:], v^k)
376-
v = binary.LittleEndian.Uint64(b[104:])
377-
binary.LittleEndian.PutUint64(b[104:], v^k)
378-
v = binary.LittleEndian.Uint64(b[112:])
379-
binary.LittleEndian.PutUint64(b[112:], v^k)
380-
v = binary.LittleEndian.Uint64(b[120:])
381-
binary.LittleEndian.PutUint64(b[120:], v^k)
346+
binary.LittleEndian.PutUint64(b, v^key64)
347+
v = binary.LittleEndian.Uint64(b[8:16])
348+
binary.LittleEndian.PutUint64(b[8:16], v^key64)
349+
v = binary.LittleEndian.Uint64(b[16:24])
350+
binary.LittleEndian.PutUint64(b[16:24], v^key64)
351+
v = binary.LittleEndian.Uint64(b[24:32])
352+
binary.LittleEndian.PutUint64(b[24:32], v^key64)
353+
v = binary.LittleEndian.Uint64(b[32:40])
354+
binary.LittleEndian.PutUint64(b[32:40], v^key64)
355+
v = binary.LittleEndian.Uint64(b[40:48])
356+
binary.LittleEndian.PutUint64(b[40:48], v^key64)
357+
v = binary.LittleEndian.Uint64(b[48:56])
358+
binary.LittleEndian.PutUint64(b[48:56], v^key64)
359+
v = binary.LittleEndian.Uint64(b[56:64])
360+
binary.LittleEndian.PutUint64(b[56:64], v^key64)
361+
v = binary.LittleEndian.Uint64(b[64:72])
362+
binary.LittleEndian.PutUint64(b[64:72], v^key64)
363+
v = binary.LittleEndian.Uint64(b[72:80])
364+
binary.LittleEndian.PutUint64(b[72:80], v^key64)
365+
v = binary.LittleEndian.Uint64(b[80:88])
366+
binary.LittleEndian.PutUint64(b[80:88], v^key64)
367+
v = binary.LittleEndian.Uint64(b[88:96])
368+
binary.LittleEndian.PutUint64(b[88:96], v^key64)
369+
v = binary.LittleEndian.Uint64(b[96:104])
370+
binary.LittleEndian.PutUint64(b[96:104], v^key64)
371+
v = binary.LittleEndian.Uint64(b[104:112])
372+
binary.LittleEndian.PutUint64(b[104:112], v^key64)
373+
v = binary.LittleEndian.Uint64(b[112:120])
374+
binary.LittleEndian.PutUint64(b[112:120], v^key64)
375+
v = binary.LittleEndian.Uint64(b[120:128])
376+
binary.LittleEndian.PutUint64(b[120:128], v^key64)
382377
b = b[128:]
383378
}
384379

385380
// Then we xor until b is less than 64 bytes.
386381
for len(b) >= 64 {
387382
v := binary.LittleEndian.Uint64(b)
388-
binary.LittleEndian.PutUint64(b, v^k)
389-
v = binary.LittleEndian.Uint64(b[8:])
390-
binary.LittleEndian.PutUint64(b[8:], v^k)
391-
v = binary.LittleEndian.Uint64(b[16:])
392-
binary.LittleEndian.PutUint64(b[16:], v^k)
393-
v = binary.LittleEndian.Uint64(b[24:])
394-
binary.LittleEndian.PutUint64(b[24:], v^k)
395-
v = binary.LittleEndian.Uint64(b[32:])
396-
binary.LittleEndian.PutUint64(b[32:], v^k)
397-
v = binary.LittleEndian.Uint64(b[40:])
398-
binary.LittleEndian.PutUint64(b[40:], v^k)
399-
v = binary.LittleEndian.Uint64(b[48:])
400-
binary.LittleEndian.PutUint64(b[48:], v^k)
401-
v = binary.LittleEndian.Uint64(b[56:])
402-
binary.LittleEndian.PutUint64(b[56:], v^k)
383+
binary.LittleEndian.PutUint64(b, v^key64)
384+
v = binary.LittleEndian.Uint64(b[8:16])
385+
binary.LittleEndian.PutUint64(b[8:16], v^key64)
386+
v = binary.LittleEndian.Uint64(b[16:24])
387+
binary.LittleEndian.PutUint64(b[16:24], v^key64)
388+
v = binary.LittleEndian.Uint64(b[24:32])
389+
binary.LittleEndian.PutUint64(b[24:32], v^key64)
390+
v = binary.LittleEndian.Uint64(b[32:40])
391+
binary.LittleEndian.PutUint64(b[32:40], v^key64)
392+
v = binary.LittleEndian.Uint64(b[40:48])
393+
binary.LittleEndian.PutUint64(b[40:48], v^key64)
394+
v = binary.LittleEndian.Uint64(b[48:56])
395+
binary.LittleEndian.PutUint64(b[48:56], v^key64)
396+
v = binary.LittleEndian.Uint64(b[56:64])
397+
binary.LittleEndian.PutUint64(b[56:64], v^key64)
403398
b = b[64:]
404399
}
405400

406401
// Then we xor until b is less than 32 bytes.
407402
for len(b) >= 32 {
408403
v := binary.LittleEndian.Uint64(b)
409-
binary.LittleEndian.PutUint64(b, v^k)
410-
v = binary.LittleEndian.Uint64(b[8:])
411-
binary.LittleEndian.PutUint64(b[8:], v^k)
412-
v = binary.LittleEndian.Uint64(b[16:])
413-
binary.LittleEndian.PutUint64(b[16:], v^k)
414-
v = binary.LittleEndian.Uint64(b[24:])
415-
binary.LittleEndian.PutUint64(b[24:], v^k)
404+
binary.LittleEndian.PutUint64(b, v^key64)
405+
v = binary.LittleEndian.Uint64(b[8:16])
406+
binary.LittleEndian.PutUint64(b[8:16], v^key64)
407+
v = binary.LittleEndian.Uint64(b[16:24])
408+
binary.LittleEndian.PutUint64(b[16:24], v^key64)
409+
v = binary.LittleEndian.Uint64(b[24:32])
410+
binary.LittleEndian.PutUint64(b[24:32], v^key64)
416411
b = b[32:]
417412
}
418413

419414
// Then we xor until b is less than 16 bytes.
420415
for len(b) >= 16 {
421416
v := binary.LittleEndian.Uint64(b)
422-
binary.LittleEndian.PutUint64(b, v^k)
423-
v = binary.LittleEndian.Uint64(b[8:])
424-
binary.LittleEndian.PutUint64(b[8:], v^k)
417+
binary.LittleEndian.PutUint64(b, v^key64)
418+
v = binary.LittleEndian.Uint64(b[8:16])
419+
binary.LittleEndian.PutUint64(b[8:16], v^key64)
425420
b = b[16:]
426421
}
427422

428423
// Then we xor until b is less than 8 bytes.
429424
for len(b) >= 8 {
430425
v := binary.LittleEndian.Uint64(b)
431-
binary.LittleEndian.PutUint64(b, v^k)
426+
binary.LittleEndian.PutUint64(b, v^key64)
432427
b = b[8:]
433428
}
434429
}
435430

431+
// Then we xor until b is less than 4 bytes.
432+
for len(b) >= 4 {
433+
v := binary.LittleEndian.Uint32(b)
434+
binary.LittleEndian.PutUint32(b, v^key)
435+
b = b[4:]
436+
}
437+
436438
// xor remaining bytes.
437439
for i := range b {
438-
b[i] ^= key[keyPos&3]
439-
keyPos++
440+
b[i] ^= byte(key)
441+
key = bits.RotateLeft32(key, -8)
440442
}
441-
return keyPos & 3
443+
444+
return key
442445
}

0 commit comments

Comments
 (0)