20
20
21
21
import re
22
22
from io import BytesIO
23
- from typing import Iterator , Union , Optional , Tuple , List
23
+ from typing import Generic , Iterator , Type , TypeVar , Union , Optional , Tuple , List
24
24
from urllib .parse import parse_qs
25
25
from wsgiref .headers import Headers
26
26
from collections .abc import MutableMapping as DictMixin
@@ -280,8 +280,10 @@ def parse_options_header(header, options=None, unquote=header_unquote):
280
280
_BODY = "BODY"
281
281
_COMPLETE = "END"
282
282
283
+ t_segment = TypeVar ('SegmentType' , bound = "MultipartSegment" )
284
+
285
+ class PushMultipartParser (Generic [t_segment ]):
283
286
284
- class PushMultipartParser :
285
287
def __init__ (
286
288
self ,
287
289
boundary : Union [str , bytes ],
@@ -292,6 +294,7 @@ def __init__(
292
294
max_segment_count = inf , # unlimited
293
295
header_charset = "utf8" ,
294
296
strict = False ,
297
+ segment_class : Optional [Type [t_segment ]] = None ,
295
298
):
296
299
"""A push-based (incremental, non-blocking) parser for multipart/form-data.
297
300
@@ -311,6 +314,8 @@ def __init__(
311
314
:param max_segment_count: Maximum number of segments.
312
315
:param header_charset: Charset for header names and values.
313
316
:param strict: Enables additional format and sanity checks.
317
+
318
+ :param segment_class: Class for emitted segments, defaults to `MultipartSegment`.
314
319
"""
315
320
self .boundary = to_bytes (boundary )
316
321
self .content_length = content_length
@@ -321,13 +326,17 @@ def __init__(
321
326
self .max_segment_count = max_segment_count
322
327
self .strict = strict
323
328
324
- self ._delimiter = b"\r \n --" + self .boundary
329
+ if segment_class and issubclass (self .segment_class , MultipartSegment ):
330
+ self .segment_class = segment_class
331
+ else :
332
+ self .segment_class = MultipartSegment
325
333
326
334
# Internal parser state
335
+ self ._delimiter = b"\r \n --" + self .boundary
327
336
self ._parsed = 0
328
- self ._fieldcount = 0
329
337
self ._buffer = bytearray ()
330
- self ._current = None
338
+ self ._segment_count = 0
339
+ self ._segment = None
331
340
self ._state = _PREAMBLE
332
341
333
342
#: True if the parser reached the end of the multipart stream, stopped
@@ -344,7 +353,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
344
353
345
354
def parse (
346
355
self , chunk : Union [bytes , bytearray ]
347
- ) -> Iterator [Union ["MultipartSegment" , bytearray , None ]]:
356
+ ) -> Iterator [Union [t_segment , bytearray , None ]]:
348
357
"""Parse a chunk of data and yield as many result objects as possible
349
358
with the data given.
350
359
@@ -406,7 +415,7 @@ def parse(
406
415
tail = buffer [next_start - 2 : next_start ]
407
416
408
417
if tail == b"\r \n " : # Normal delimiter found
409
- self ._current = MultipartSegment ( self )
418
+ self ._segment = self . _new_segment ( )
410
419
self ._state = _HEADER
411
420
offset = next_start
412
421
continue
@@ -433,12 +442,12 @@ def parse(
433
442
nl = buffer .find (b"\r \n " , offset )
434
443
435
444
if nl > offset : # Non-empty header line
436
- self ._current . _add_headerline (buffer [offset :nl ])
445
+ self ._segment . _on_headerline (buffer [offset :nl ])
437
446
offset = nl + 2
438
447
continue
439
448
elif nl == offset : # Empty header line -> End of header section
440
- self ._current . _close_headers ()
441
- yield self ._current
449
+ self ._segment . _on_header_complete ()
450
+ yield self ._segment
442
451
self ._state = _BODY
443
452
offset += 2
444
453
continue
@@ -463,27 +472,25 @@ def parse(
463
472
464
473
if tail == b"\r \n " or tail == b"--" :
465
474
if index > offset :
466
- self ._current ._update_size (index - offset )
467
- yield buffer [offset :index ]
475
+ yield self ._segment ._on_data (buffer [offset :index ])
468
476
469
477
offset = next_start
470
- self ._current . _mark_complete ()
478
+ self ._segment . _on_data_complete ()
471
479
yield None # End of segment
472
480
473
481
if tail == b"--" : # Last delimiter
474
482
self ._state = _COMPLETE
475
483
break
476
484
else : # Normal delimiter
477
- self ._current = MultipartSegment ( self )
485
+ self ._segment = self . _new_segment ( )
478
486
self ._state = _HEADER
479
487
continue
480
488
481
489
# Keep enough in buffer to accout for a partial delimiter at
482
490
# the end, but emiot the rest.
483
491
chunk_end = bufferlen - (d_len + 1 )
484
492
assert chunk_end > offset # Always true
485
- self ._current ._update_size (chunk_end - offset )
486
- yield buffer [offset :chunk_end ]
493
+ yield self ._segment ._on_data (buffer [offset :chunk_end ])
487
494
offset = chunk_end
488
495
break # wait for more data
489
496
@@ -501,6 +508,12 @@ def parse(
501
508
self .close (check_complete = False )
502
509
raise
503
510
511
+ def _new_segment (self ) -> t_segment :
512
+ self ._segment_count += 1
513
+ if self ._segment_count > self .max_segment_count :
514
+ raise ParserLimitReached ("Maximum segment count exceeded" )
515
+ return self .segment_class (self )
516
+
504
517
def close (self , check_complete = True ):
505
518
"""
506
519
Close this parser if not already closed.
@@ -510,7 +523,7 @@ def close(self, check_complete=True):
510
523
"""
511
524
512
525
self .closed = True
513
- self ._current = None
526
+ self ._segment = None
514
527
del self ._buffer [:]
515
528
516
529
if check_complete and self ._state is not _COMPLETE :
@@ -551,39 +564,34 @@ class MultipartSegment:
551
564
def __init__ (self , parser : PushMultipartParser ):
552
565
""" Private constructor, used by :class:`PushMultipartParser` """
553
566
self ._parser = parser
554
-
555
- if parser ._fieldcount + 1 > parser .max_segment_count :
556
- raise ParserLimitReached ("Maximum segment count exceeded" )
557
- parser ._fieldcount += 1
558
-
559
567
self .headerlist = []
560
568
self .size = 0
561
- self .complete = 0
569
+ self .complete = False
562
570
563
- self .name = None
571
+ self .name = ""
564
572
self .filename = None
565
573
self .content_type = None
566
574
self .charset = None
575
+ self ._maxlen = parser .max_segment_size
567
576
self ._clen = - 1
568
- self ._size_limit = parser .max_segment_size
569
577
570
- def _add_headerline (self , line : bytearray ):
571
- assert line and self .name is None
572
- parser = self ._parser
578
+ def _on_headerline (self , line : bytearray ):
579
+ """ Called for each raw header line in a segment. """
573
580
574
- if line [0 ] in b" \t " : # Multi-line header value
575
- if not self .headerlist or parser .strict :
581
+ if line [0 ] in b" \t " : # Continuation of last header line
582
+ if not self .headerlist or self . _parser .strict :
576
583
raise StrictParserError ("Unexpected segment header continuation" )
577
584
prev = ": " .join (self .headerlist .pop ())
578
- line = prev .encode (parser .header_charset ) + b" " + line .strip ()
585
+ line = prev .encode (self . _parser .header_charset ) + b" " + line .strip ()
579
586
580
- if len (line ) > parser .max_header_size :
587
+ if len (line ) > self . _parser .max_header_size :
581
588
raise ParserLimitReached ("Maximum segment header length exceeded" )
582
- if len (self .headerlist ) >= parser .max_header_count :
589
+
590
+ if len (self .headerlist ) >= self ._parser .max_header_count :
583
591
raise ParserLimitReached ("Maximum segment header count exceeded" )
584
592
585
593
try :
586
- name , col , value = line .decode (parser .header_charset ).partition (":" )
594
+ name , col , value = line .decode (self . _parser .header_charset ).partition (":" )
587
595
name = name .strip ()
588
596
if not col or not name :
589
597
raise ParserError ("Malformed segment header" )
@@ -594,9 +602,10 @@ def _add_headerline(self, line: bytearray):
594
602
595
603
self .headerlist .append ((name .title (), value .strip ()))
596
604
597
- def _close_headers (self ):
598
- assert self . name is None
605
+ def _on_header_complete (self ):
606
+ """ Called after the last segment header. """
599
607
608
+ dtype = False
600
609
for h ,v in self .headerlist :
601
610
if h == "Content-Disposition" :
602
611
dtype , args = parse_options_header (v , unquote = content_disposition_unquote )
@@ -611,21 +620,23 @@ def _close_headers(self):
611
620
self .charset = args .get ("charset" )
612
621
elif h == "Content-Length" and v .isdecimal ():
613
622
self ._clen = int (v )
623
+ self ._maxlen = min (self ._clen , self ._maxlen )
614
624
615
- if self . name is None :
625
+ if not dtype :
616
626
raise ParserError ("Missing Content-Disposition segment header" )
617
627
618
- def _update_size (self , bytecount : int ) :
619
- assert self . name is not None and not self . complete
620
- self .size += bytecount
621
- if self ._clen >= 0 and self . size > self ._clen :
622
- raise ParserError ( "Segment Content-Length exceeded" )
623
- if self . size > self . _size_limit :
628
+ def _on_data (self , chunk : bytearray ) -> bytearray :
629
+ """ Called for each chunk of segment data. Must return the chunk. """
630
+ self .size += len ( chunk )
631
+ if self .size > self ._maxlen :
632
+ if self . size > self . _clen > - 1 :
633
+ raise ParserError ( "Segment Content-Length exceeded" )
624
634
raise ParserLimitReached ("Maximum segment size exceeded" )
635
+ return chunk
625
636
626
- def _mark_complete (self ):
627
- assert self . name is not None and not self . complete
628
- if self ._clen >= 0 and self .size != self ._clen :
637
+ def _on_data_complete (self ):
638
+ """ Called after the last chunk of segment data. """
639
+ if self ._clen > - 1 and self .size != self ._clen :
629
640
raise ParserError ("Segment size does not match Content-Length header" )
630
641
self .complete = True
631
642
0 commit comments