|
6 | 6 | "fmt"
|
7 | 7 | "io"
|
8 | 8 | "math"
|
| 9 | + "math/bits" |
9 | 10 | )
|
10 | 11 |
|
11 | 12 | //go:generate stringer -type=opcode,MessageType,StatusCode -output=frame_stringer.go
|
@@ -69,7 +70,7 @@ type header struct {
|
69 | 70 | payloadLength int64
|
70 | 71 |
|
71 | 72 | masked bool
|
72 |
| - maskKey [4]byte |
| 73 | + maskKey uint32 |
73 | 74 | }
|
74 | 75 |
|
75 | 76 | func makeWriteHeaderBuf() []byte {
|
@@ -119,7 +120,7 @@ func writeHeader(b []byte, h header) []byte {
|
119 | 120 | if h.masked {
|
120 | 121 | b[1] |= 1 << 7
|
121 | 122 | b = b[:len(b)+4]
|
122 |
| - copy(b[len(b)-4:], h.maskKey[:]) |
| 123 | + binary.LittleEndian.PutUint32(b[len(b)-4:], h.maskKey) |
123 | 124 | }
|
124 | 125 |
|
125 | 126 | return b
|
@@ -192,7 +193,7 @@ func readHeader(b []byte, r io.Reader) (header, error) {
|
192 | 193 | }
|
193 | 194 |
|
194 | 195 | if h.masked {
|
195 |
| - copy(h.maskKey[:], b) |
| 196 | + h.maskKey = binary.LittleEndian.Uint32(b) |
196 | 197 | }
|
197 | 198 |
|
198 | 199 | return h, nil
|
@@ -321,122 +322,124 @@ func (ce CloseError) bytes() ([]byte, error) {
|
321 | 322 | return buf, nil
|
322 | 323 | }
|
323 | 324 |
|
324 |
| -// xor applies the WebSocket masking algorithm to p |
325 |
| -// with the given key where the first 3 bits of pos |
326 |
| -// are the starting position in the key. |
| 325 | +// mask applies the WebSocket masking algorithm to b |
| 326 | +// with the given key. |
327 | 327 | // See https://tools.ietf.org/html/rfc6455#section-5.3
|
328 | 328 | //
|
329 |
| -// The returned value is the position of the next byte |
330 |
| -// to be used for masking in the key. This is so that |
331 |
| -// unmasking can be performed without the entire frame. |
332 |
| -func fastXOR(key [4]byte, keyPos int, b []byte) int { |
333 |
| - // If the payload is greater than or equal to 16 bytes, then it's worth |
334 |
| - // masking 8 bytes at a time. |
335 |
| - // Optimization from https://github.com/golang/go/issues/31586#issuecomment-485530859 |
336 |
| - if len(b) >= 16 { |
337 |
| - // We first create a key that is 8 bytes long |
338 |
| - // and is aligned on the position correctly. |
339 |
| - var alignedKey [8]byte |
340 |
| - for i := range alignedKey { |
341 |
| - alignedKey[i] = key[(i+keyPos)&3] |
342 |
| - } |
343 |
| - k := binary.LittleEndian.Uint64(alignedKey[:]) |
| 329 | +// The returned value is the correctly rotated key |
| 330 | +// to continue to mask/unmask the message. |
| 331 | +// |
| 332 | +// It is optimized for LittleEndian and expects the key |
| 333 | +// to be in little endian. |
| 334 | +// |
| 335 | +// See https://github.com/golang/go/issues/31586 |
| 336 | +func mask(key uint32, b []byte) uint32 { |
| 337 | + if len(b) >= 8 { |
| 338 | + key64 := uint64(key)<<32 | uint64(key) |
344 | 339 |
|
345 | 340 | // At some point in the future we can clean these unrolled loops up.
|
346 | 341 | // See https://github.com/golang/go/issues/31586#issuecomment-487436401
|
347 | 342 |
|
348 | 343 | // Then we xor until b is less than 128 bytes.
|
349 | 344 | for len(b) >= 128 {
|
350 | 345 | v := binary.LittleEndian.Uint64(b)
|
351 |
| - binary.LittleEndian.PutUint64(b, v^k) |
352 |
| - v = binary.LittleEndian.Uint64(b[8:]) |
353 |
| - binary.LittleEndian.PutUint64(b[8:], v^k) |
354 |
| - v = binary.LittleEndian.Uint64(b[16:]) |
355 |
| - binary.LittleEndian.PutUint64(b[16:], v^k) |
356 |
| - v = binary.LittleEndian.Uint64(b[24:]) |
357 |
| - binary.LittleEndian.PutUint64(b[24:], v^k) |
358 |
| - v = binary.LittleEndian.Uint64(b[32:]) |
359 |
| - binary.LittleEndian.PutUint64(b[32:], v^k) |
360 |
| - v = binary.LittleEndian.Uint64(b[40:]) |
361 |
| - binary.LittleEndian.PutUint64(b[40:], v^k) |
362 |
| - v = binary.LittleEndian.Uint64(b[48:]) |
363 |
| - binary.LittleEndian.PutUint64(b[48:], v^k) |
364 |
| - v = binary.LittleEndian.Uint64(b[56:]) |
365 |
| - binary.LittleEndian.PutUint64(b[56:], v^k) |
366 |
| - v = binary.LittleEndian.Uint64(b[64:]) |
367 |
| - binary.LittleEndian.PutUint64(b[64:], v^k) |
368 |
| - v = binary.LittleEndian.Uint64(b[72:]) |
369 |
| - binary.LittleEndian.PutUint64(b[72:], v^k) |
370 |
| - v = binary.LittleEndian.Uint64(b[80:]) |
371 |
| - binary.LittleEndian.PutUint64(b[80:], v^k) |
372 |
| - v = binary.LittleEndian.Uint64(b[88:]) |
373 |
| - binary.LittleEndian.PutUint64(b[88:], v^k) |
374 |
| - v = binary.LittleEndian.Uint64(b[96:]) |
375 |
| - binary.LittleEndian.PutUint64(b[96:], v^k) |
376 |
| - v = binary.LittleEndian.Uint64(b[104:]) |
377 |
| - binary.LittleEndian.PutUint64(b[104:], v^k) |
378 |
| - v = binary.LittleEndian.Uint64(b[112:]) |
379 |
| - binary.LittleEndian.PutUint64(b[112:], v^k) |
380 |
| - v = binary.LittleEndian.Uint64(b[120:]) |
381 |
| - binary.LittleEndian.PutUint64(b[120:], v^k) |
| 346 | + binary.LittleEndian.PutUint64(b, v^key64) |
| 347 | + v = binary.LittleEndian.Uint64(b[8:16]) |
| 348 | + binary.LittleEndian.PutUint64(b[8:16], v^key64) |
| 349 | + v = binary.LittleEndian.Uint64(b[16:24]) |
| 350 | + binary.LittleEndian.PutUint64(b[16:24], v^key64) |
| 351 | + v = binary.LittleEndian.Uint64(b[24:32]) |
| 352 | + binary.LittleEndian.PutUint64(b[24:32], v^key64) |
| 353 | + v = binary.LittleEndian.Uint64(b[32:40]) |
| 354 | + binary.LittleEndian.PutUint64(b[32:40], v^key64) |
| 355 | + v = binary.LittleEndian.Uint64(b[40:48]) |
| 356 | + binary.LittleEndian.PutUint64(b[40:48], v^key64) |
| 357 | + v = binary.LittleEndian.Uint64(b[48:56]) |
| 358 | + binary.LittleEndian.PutUint64(b[48:56], v^key64) |
| 359 | + v = binary.LittleEndian.Uint64(b[56:64]) |
| 360 | + binary.LittleEndian.PutUint64(b[56:64], v^key64) |
| 361 | + v = binary.LittleEndian.Uint64(b[64:72]) |
| 362 | + binary.LittleEndian.PutUint64(b[64:72], v^key64) |
| 363 | + v = binary.LittleEndian.Uint64(b[72:80]) |
| 364 | + binary.LittleEndian.PutUint64(b[72:80], v^key64) |
| 365 | + v = binary.LittleEndian.Uint64(b[80:88]) |
| 366 | + binary.LittleEndian.PutUint64(b[80:88], v^key64) |
| 367 | + v = binary.LittleEndian.Uint64(b[88:96]) |
| 368 | + binary.LittleEndian.PutUint64(b[88:96], v^key64) |
| 369 | + v = binary.LittleEndian.Uint64(b[96:104]) |
| 370 | + binary.LittleEndian.PutUint64(b[96:104], v^key64) |
| 371 | + v = binary.LittleEndian.Uint64(b[104:112]) |
| 372 | + binary.LittleEndian.PutUint64(b[104:112], v^key64) |
| 373 | + v = binary.LittleEndian.Uint64(b[112:120]) |
| 374 | + binary.LittleEndian.PutUint64(b[112:120], v^key64) |
| 375 | + v = binary.LittleEndian.Uint64(b[120:128]) |
| 376 | + binary.LittleEndian.PutUint64(b[120:128], v^key64) |
382 | 377 | b = b[128:]
|
383 | 378 | }
|
384 | 379 |
|
385 | 380 | // Then we xor until b is less than 64 bytes.
|
386 | 381 | for len(b) >= 64 {
|
387 | 382 | v := binary.LittleEndian.Uint64(b)
|
388 |
| - binary.LittleEndian.PutUint64(b, v^k) |
389 |
| - v = binary.LittleEndian.Uint64(b[8:]) |
390 |
| - binary.LittleEndian.PutUint64(b[8:], v^k) |
391 |
| - v = binary.LittleEndian.Uint64(b[16:]) |
392 |
| - binary.LittleEndian.PutUint64(b[16:], v^k) |
393 |
| - v = binary.LittleEndian.Uint64(b[24:]) |
394 |
| - binary.LittleEndian.PutUint64(b[24:], v^k) |
395 |
| - v = binary.LittleEndian.Uint64(b[32:]) |
396 |
| - binary.LittleEndian.PutUint64(b[32:], v^k) |
397 |
| - v = binary.LittleEndian.Uint64(b[40:]) |
398 |
| - binary.LittleEndian.PutUint64(b[40:], v^k) |
399 |
| - v = binary.LittleEndian.Uint64(b[48:]) |
400 |
| - binary.LittleEndian.PutUint64(b[48:], v^k) |
401 |
| - v = binary.LittleEndian.Uint64(b[56:]) |
402 |
| - binary.LittleEndian.PutUint64(b[56:], v^k) |
| 383 | + binary.LittleEndian.PutUint64(b, v^key64) |
| 384 | + v = binary.LittleEndian.Uint64(b[8:16]) |
| 385 | + binary.LittleEndian.PutUint64(b[8:16], v^key64) |
| 386 | + v = binary.LittleEndian.Uint64(b[16:24]) |
| 387 | + binary.LittleEndian.PutUint64(b[16:24], v^key64) |
| 388 | + v = binary.LittleEndian.Uint64(b[24:32]) |
| 389 | + binary.LittleEndian.PutUint64(b[24:32], v^key64) |
| 390 | + v = binary.LittleEndian.Uint64(b[32:40]) |
| 391 | + binary.LittleEndian.PutUint64(b[32:40], v^key64) |
| 392 | + v = binary.LittleEndian.Uint64(b[40:48]) |
| 393 | + binary.LittleEndian.PutUint64(b[40:48], v^key64) |
| 394 | + v = binary.LittleEndian.Uint64(b[48:56]) |
| 395 | + binary.LittleEndian.PutUint64(b[48:56], v^key64) |
| 396 | + v = binary.LittleEndian.Uint64(b[56:64]) |
| 397 | + binary.LittleEndian.PutUint64(b[56:64], v^key64) |
403 | 398 | b = b[64:]
|
404 | 399 | }
|
405 | 400 |
|
406 | 401 | // Then we xor until b is less than 32 bytes.
|
407 | 402 | for len(b) >= 32 {
|
408 | 403 | v := binary.LittleEndian.Uint64(b)
|
409 |
| - binary.LittleEndian.PutUint64(b, v^k) |
410 |
| - v = binary.LittleEndian.Uint64(b[8:]) |
411 |
| - binary.LittleEndian.PutUint64(b[8:], v^k) |
412 |
| - v = binary.LittleEndian.Uint64(b[16:]) |
413 |
| - binary.LittleEndian.PutUint64(b[16:], v^k) |
414 |
| - v = binary.LittleEndian.Uint64(b[24:]) |
415 |
| - binary.LittleEndian.PutUint64(b[24:], v^k) |
| 404 | + binary.LittleEndian.PutUint64(b, v^key64) |
| 405 | + v = binary.LittleEndian.Uint64(b[8:16]) |
| 406 | + binary.LittleEndian.PutUint64(b[8:16], v^key64) |
| 407 | + v = binary.LittleEndian.Uint64(b[16:24]) |
| 408 | + binary.LittleEndian.PutUint64(b[16:24], v^key64) |
| 409 | + v = binary.LittleEndian.Uint64(b[24:32]) |
| 410 | + binary.LittleEndian.PutUint64(b[24:32], v^key64) |
416 | 411 | b = b[32:]
|
417 | 412 | }
|
418 | 413 |
|
419 | 414 | // Then we xor until b is less than 16 bytes.
|
420 | 415 | for len(b) >= 16 {
|
421 | 416 | v := binary.LittleEndian.Uint64(b)
|
422 |
| - binary.LittleEndian.PutUint64(b, v^k) |
423 |
| - v = binary.LittleEndian.Uint64(b[8:]) |
424 |
| - binary.LittleEndian.PutUint64(b[8:], v^k) |
| 417 | + binary.LittleEndian.PutUint64(b, v^key64) |
| 418 | + v = binary.LittleEndian.Uint64(b[8:16]) |
| 419 | + binary.LittleEndian.PutUint64(b[8:16], v^key64) |
425 | 420 | b = b[16:]
|
426 | 421 | }
|
427 | 422 |
|
428 | 423 | // Then we xor until b is less than 8 bytes.
|
429 | 424 | for len(b) >= 8 {
|
430 | 425 | v := binary.LittleEndian.Uint64(b)
|
431 |
| - binary.LittleEndian.PutUint64(b, v^k) |
| 426 | + binary.LittleEndian.PutUint64(b, v^key64) |
432 | 427 | b = b[8:]
|
433 | 428 | }
|
434 | 429 | }
|
435 | 430 |
|
| 431 | + // Then we xor until b is less than 4 bytes. |
| 432 | + for len(b) >= 4 { |
| 433 | + v := binary.LittleEndian.Uint32(b) |
| 434 | + binary.LittleEndian.PutUint32(b, v^key) |
| 435 | + b = b[4:] |
| 436 | + } |
| 437 | + |
436 | 438 | // xor remaining bytes.
|
437 | 439 | for i := range b {
|
438 |
| - b[i] ^= key[keyPos&3] |
439 |
| - keyPos++ |
| 440 | + b[i] ^= byte(key) |
| 441 | + key = bits.RotateLeft32(key, -8) |
440 | 442 | }
|
441 |
| - return keyPos & 3 |
| 443 | + |
| 444 | + return key |
442 | 445 | }
|
0 commit comments