Skip to content

Commit ef99d40

Browse files
committed
feat: Allow subclassing MultipartSegment
This patch includes refactoring and changes of internal APIs to allow stabilizing those APIs in the future.
1 parent 45fb97c commit ef99d40

File tree

2 files changed

+60
-46
lines changed

2 files changed

+60
-46
lines changed

CHANGELOG.rst

+3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ Release 1.3
1616
**Not released yet**
1717

1818
* feat: Nicer error messages when reading from a closed ``MultipartPart``.
19+
* feat: Support custom ``MultipartSegment`` subclasses to be used and emitted by
20+
``PushMultipartParser``. However, the API between parser and segment is not
21+
stable yet. Overriding any of the ``_on_*`` methods may break between releases.
1922

2023
Release 1.2
2124
===========

multipart.py

+57-46
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import re
2222
from io import BytesIO
23-
from typing import Iterator, Union, Optional, Tuple, List
23+
from typing import Generic, Iterator, Type, TypeVar, Union, Optional, Tuple, List
2424
from urllib.parse import parse_qs
2525
from wsgiref.headers import Headers
2626
from collections.abc import MutableMapping as DictMixin
@@ -280,8 +280,10 @@ def parse_options_header(header, options=None, unquote=header_unquote):
280280
_BODY = "BODY"
281281
_COMPLETE = "END"
282282

283+
# Type variable for the segment class a parser is parameterized with; it is
# bound to MultipartSegment. The name passed to TypeVar() must match the
# variable it is assigned to, otherwise static type checkers reject it.
t_segment = TypeVar("t_segment", bound="MultipartSegment")
284+
285+
class PushMultipartParser(Generic[t_segment]):
283286

284-
class PushMultipartParser:
285287
def __init__(
286288
self,
287289
boundary: Union[str, bytes],
@@ -292,6 +294,7 @@ def __init__(
292294
max_segment_count=inf, # unlimited
293295
header_charset="utf8",
294296
strict=False,
297+
segment_class: Optional[Type[t_segment]] = None,
295298
):
296299
"""A push-based (incremental, non-blocking) parser for multipart/form-data.
297300
@@ -311,6 +314,8 @@ def __init__(
311314
:param max_segment_count: Maximum number of segments.
312315
:param header_charset: Charset for header names and values.
313316
:param strict: Enables additional format and sanity checks.
317+
318+
:param segment_class: Class for emitted segments, defaults to `MultipartSegment`.
314319
"""
315320
self.boundary = to_bytes(boundary)
316321
self.content_length = content_length
@@ -321,13 +326,17 @@ def __init__(
321326
self.max_segment_count = max_segment_count
322327
self.strict = strict
323328

324-
self._delimiter = b"\r\n--" + self.boundary
329+
# Use the caller-supplied class only if it is a MultipartSegment subclass;
# otherwise fall back to the default. The check must inspect the
# *argument*: self.segment_class has not been assigned yet at this point,
# so reading it here would raise AttributeError.
if segment_class and issubclass(segment_class, MultipartSegment):
    self.segment_class = segment_class
else:
    self.segment_class = MultipartSegment
325333

326334
# Internal parser state
335+
self._delimiter = b"\r\n--" + self.boundary
327336
self._parsed = 0
328-
self._fieldcount = 0
329337
self._buffer = bytearray()
330-
self._current = None
338+
self._segment_count = 0
339+
self._segment = None
331340
self._state = _PREAMBLE
332341

333342
#: True if the parser reached the end of the multipart stream, stopped
@@ -344,7 +353,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
344353

345354
def parse(
346355
self, chunk: Union[bytes, bytearray]
347-
) -> Iterator[Union["MultipartSegment", bytearray, None]]:
356+
) -> Iterator[Union[t_segment, bytearray, None]]:
348357
"""Parse a chunk of data and yield as many result objects as possible
349358
with the data given.
350359
@@ -406,7 +415,7 @@ def parse(
406415
tail = buffer[next_start-2 : next_start]
407416

408417
if tail == b"\r\n": # Normal delimiter found
409-
self._current = MultipartSegment(self)
418+
self._segment = self._new_segment()
410419
self._state = _HEADER
411420
offset = next_start
412421
continue
@@ -433,12 +442,12 @@ def parse(
433442
nl = buffer.find(b"\r\n", offset)
434443

435444
if nl > offset: # Non-empty header line
436-
self._current._add_headerline(buffer[offset:nl])
445+
self._segment._on_headerline(buffer[offset:nl])
437446
offset = nl + 2
438447
continue
439448
elif nl == offset: # Empty header line -> End of header section
440-
self._current._close_headers()
441-
yield self._current
449+
self._segment._on_header_complete()
450+
yield self._segment
442451
self._state = _BODY
443452
offset += 2
444453
continue
@@ -463,27 +472,25 @@ def parse(
463472

464473
if tail == b"\r\n" or tail == b"--":
465474
if index > offset:
466-
self._current._update_size(index - offset)
467-
yield buffer[offset:index]
475+
yield self._segment._on_data(buffer[offset:index])
468476

469477
offset = next_start
470-
self._current._mark_complete()
478+
self._segment._on_data_complete()
471479
yield None # End of segment
472480

473481
if tail == b"--": # Last delimiter
474482
self._state = _COMPLETE
475483
break
476484
else: # Normal delimiter
477-
self._current = MultipartSegment(self)
485+
self._segment = self._new_segment()
478486
self._state = _HEADER
479487
continue
480488

481489
# Keep enough in buffer to account for a partial delimiter at
482490
# the end, but emit the rest.
483491
chunk_end = bufferlen - (d_len + 1)
484492
assert chunk_end > offset # Always true
485-
self._current._update_size(chunk_end - offset)
486-
yield buffer[offset:chunk_end]
493+
yield self._segment._on_data(buffer[offset:chunk_end])
487494
offset = chunk_end
488495
break # wait for more data
489496

@@ -501,6 +508,12 @@ def parse(
501508
self.close(check_complete=False)
502509
raise
503510

511+
def _new_segment(self) -> t_segment:
    """ Count and create the next segment using ``self.segment_class``.

    Raises ParserLimitReached once more than ``max_segment_count``
    segments have been started.
    """
    started = self._segment_count + 1
    self._segment_count = started  # counted even if the limit check fails
    if started > self.max_segment_count:
        raise ParserLimitReached("Maximum segment count exceeded")
    factory = self.segment_class
    return factory(self)
516+
504517
def close(self, check_complete=True):
505518
"""
506519
Close this parser if not already closed.
@@ -510,7 +523,7 @@ def close(self, check_complete=True):
510523
"""
511524

512525
self.closed = True
513-
self._current = None
526+
self._segment = None
514527
del self._buffer[:]
515528

516529
if check_complete and self._state is not _COMPLETE:
@@ -551,39 +564,34 @@ class MultipartSegment:
551564
def __init__(self, parser: PushMultipartParser):
552565
""" Private constructor, used by :class:`PushMultipartParser` """
553566
self._parser = parser
554-
555-
if parser._fieldcount+1 > parser.max_segment_count:
556-
raise ParserLimitReached("Maximum segment count exceeded")
557-
parser._fieldcount += 1
558-
559567
self.headerlist = []
560568
self.size = 0
561-
self.complete = 0
569+
self.complete = False
562570

563-
self.name = None
571+
self.name = ""
564572
self.filename = None
565573
self.content_type = None
566574
self.charset = None
575+
self._maxlen = parser.max_segment_size
567576
self._clen = -1
568-
self._size_limit = parser.max_segment_size
569577

570-
def _add_headerline(self, line: bytearray):
571-
assert line and self.name is None
572-
parser = self._parser
578+
def _on_headerline(self, line: bytearray):
579+
""" Called for each raw header line in a segment. """
573580

574-
if line[0] in b" \t": # Multi-line header value
575-
if not self.headerlist or parser.strict:
581+
if line[0] in b" \t": # Continuation of last header line
582+
if not self.headerlist or self._parser.strict:
576583
raise StrictParserError("Unexpected segment header continuation")
577584
prev = ": ".join(self.headerlist.pop())
578-
line = prev.encode(parser.header_charset) + b" " + line.strip()
585+
line = prev.encode(self._parser.header_charset) + b" " + line.strip()
579586

580-
if len(line) > parser.max_header_size:
587+
if len(line) > self._parser.max_header_size:
581588
raise ParserLimitReached("Maximum segment header length exceeded")
582-
if len(self.headerlist) >= parser.max_header_count:
589+
590+
if len(self.headerlist) >= self._parser.max_header_count:
583591
raise ParserLimitReached("Maximum segment header count exceeded")
584592

585593
try:
586-
name, col, value = line.decode(parser.header_charset).partition(":")
594+
name, col, value = line.decode(self._parser.header_charset).partition(":")
587595
name = name.strip()
588596
if not col or not name:
589597
raise ParserError("Malformed segment header")
@@ -594,9 +602,10 @@ def _add_headerline(self, line: bytearray):
594602

595603
self.headerlist.append((name.title(), value.strip()))
596604

597-
def _close_headers(self):
598-
assert self.name is None
605+
def _on_header_complete(self):
606+
""" Called after the last segment header. """
599607

608+
dtype = False
600609
for h,v in self.headerlist:
601610
if h == "Content-Disposition":
602611
dtype, args = parse_options_header(v, unquote=content_disposition_unquote)
@@ -611,21 +620,23 @@ def _close_headers(self):
611620
self.charset = args.get("charset")
612621
elif h == "Content-Length" and v.isdecimal():
613622
self._clen = int(v)
623+
self._maxlen = min(self._clen, self._maxlen)
614624

615-
if self.name is None:
625+
if not dtype:
616626
raise ParserError("Missing Content-Disposition segment header")
617627

618-
def _update_size(self, bytecount: int):
619-
assert self.name is not None and not self.complete
620-
self.size += bytecount
621-
if self._clen >= 0 and self.size > self._clen:
622-
raise ParserError("Segment Content-Length exceeded")
623-
if self.size > self._size_limit:
628+
def _on_data(self, chunk: bytearray) -> bytearray:
629+
""" Called for each chunk of segment data. Must return the chunk. """
630+
self.size += len(chunk)
631+
if self.size > self._maxlen:
632+
if self.size > self._clen > -1:
633+
raise ParserError("Segment Content-Length exceeded")
624634
raise ParserLimitReached("Maximum segment size exceeded")
635+
return chunk
625636

626-
def _mark_complete(self):
627-
assert self.name is not None and not self.complete
628-
if self._clen >= 0 and self.size != self._clen:
637+
def _on_data_complete(self):
638+
""" Called after the last chunk of segment data. """
639+
if self._clen > -1 and self.size != self._clen:
629640
raise ParserError("Segment size does not match Content-Length header")
630641
self.complete = True
631642

0 commit comments

Comments
 (0)