diff --git a/bddrest/headerset.py b/bddrest/headerset.py index 2c9960a..df42267 100644 --- a/bddrest/headerset.py +++ b/bddrest/headerset.py @@ -1,4 +1,4 @@ -from typing.re import Pattern +from typing import Pattern class HeaderSet(list): diff --git a/bddrest/helpers.py b/bddrest/helpers.py index bf3b504..b477c16 100644 --- a/bddrest/helpers.py +++ b/bddrest/helpers.py @@ -23,16 +23,7 @@ def encode_multipart_data(fields): for key, value in fields.items(): values = [value] if not isinstance(value, list) else value for value in values: - if not hasattr(value, 'read'): - lines.append('--' + boundary) - lines.append('Content-Disposition: form-data; name="%s"' % key) - lines.append('') - if not isinstance(value, str): - value = str(value) - - lines.append(value) - - else: + if hasattr(value, 'read'): filename = value.name if hasattr(value, 'name') else key lines.append('--' + boundary) lines.append( @@ -46,6 +37,15 @@ def encode_multipart_data(fields): lines.append('') lines.append(value.read()) + else: + lines.append('--' + boundary) + lines.append('Content-Disposition: form-data; name="%s"' % key) + lines.append('') + if not isinstance(value, str): + value = str(value) + + lines.append(value) + lines.append('--' + boundary + '--') lines.append('') diff --git a/requirements-ci.txt b/requirements-ci.txt index 170de4f..74b3e40 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -5,4 +5,3 @@ coveralls pytest-cov bddcli >= 2.4, < 3 flake8 - diff --git a/tests/multipart.py b/tests/multipart.py new file mode 100644 index 0000000..9f897cc --- /dev/null +++ b/tests/multipart.py @@ -0,0 +1,562 @@ +# -*- coding: utf-8 -*- +""" +Parser for multipart/form-data +============================== + +This module provides a parser for the multipart/form-data format. It can read +from a file, a socket or a WSGI environment. The parser can be used to replace +cgi.FieldStorage to work around its limitations. + +This module is modified version of https://github.com/defnull/multipart to +accept multipart and url-encoded forms when verb is other than PUT or POST. +HEAD, OPTION, GET and such verbs are remains prohibited. + + +This module completely excluded from code coverage because the 95% of code +coverage. +""" + + +__all__ = [ + "MultipartError", + "MultipartParser", + "MultipartPart", + "parse_form_data" +] + + +import re +import sys +from io import BytesIO +from tempfile import TemporaryFile +from urllib.parse import parse_qs +from wsgiref.headers import Headers +from collections.abc import MutableMapping as DictMixin + + +VERBS_PROHIBITED = [ + 'HEAD', + 'GET', + 'OPTIONS', + 'CONNECT' + 'DELETE', + 'LOCK', + 'UNLOCK', +] + + +############################################################################## +################################ Helper & Misc ############################### +############################################################################## +# Some of these were copied from bottle: https://bottlepy.org + + +# --------- +# MultiDict +# --------- + + +class MultiDict(DictMixin): + """ A dict that remembers old values for each key. + HTTP headers may repeat with differing values, + such as Set-Cookie. We need to remember all + values. + """ + + def __init__(self, *args, **kwargs): + self.dict = dict() + for k, v in dict(*args, **kwargs).items(): + self[k] = v + + def __len__(self): + return len(self.dict) + + def __iter__(self): + return iter(self.dict) + + def __contains__(self, key): + return key in self.dict + + def __delitem__(self, key): + del self.dict[key] + + def keys(self): + return self.dict.keys() + + def __getitem__(self, key): + return self.get(key, KeyError, -1) + + def __setitem__(self, key, value): + self.append(key, value) + + def append(self, key, value): + self.dict.setdefault(key, []).append(value) + + def replace(self, key, value): + self.dict[key] = [value] + + def getall(self, key): + return self.dict.get(key) or [] + + def get(self, key, default=None, index=-1): + if key not in self.dict and default != KeyError: + return [default][index] + + return self.dict[key][index] + + def iterallitems(self): + for key, values in self.dict.items(): + for value in values: + yield key, value + + +def to_bytes(data, enc="utf8"): + if isinstance(data, str): + data = data.encode(enc) + + return data + + +def copy_file(stream, target, maxread=-1, buffer_size=2 ** 16): + """ Read from :stream and write to :target until :maxread or EOF. """ + size, read = 0, stream.read + + while True: + to_read = buffer_size if maxread < 0 else min(buffer_size, maxread - size) + part = read(to_read) + + if not part: + return size + + target.write(part) + size += len(part) + + +# ------------- +# Header Parser +# ------------- + + +_special = re.escape('()<>@,;:"\\/[]?={} \t') +_re_special = re.compile(r'[%s]' % _special) +_quoted_string = r'"(?:\\.|[^"])*"' # Quoted string +_value = r'(?:[^%s]+|%s)' % (_special, _quoted_string) # Save or quoted string +_option = r'(?:;|^)\s*([^%s]+)\s*=\s*(%s)' % (_special, _value) +_re_option = re.compile(_option) # key=value part of an Content-Type like header + + +def header_quote(val): + if not _re_special.search(val): + return val + + return '"' + val.replace("\\", "\\\\").replace('"', '\\"') + '"' + + +def header_unquote(val, filename=False): + if val[0] == val[-1] == '"': + val = val[1:-1] + + if val[1:3] == ":\\" or val[:2] == "\\\\": + val = val.split("\\")[-1] # fix ie6 bug: full path --> filename + + return val.replace("\\\\", "\\").replace('\\"', '"') + + return val + + +def parse_options_header(header, options=None): + if ";" not in header: + return header.lower().strip(), {} + + content_type, tail = header.split(";", 1) + options = options or {} + + for match in _re_option.finditer(tail): + key = match.group(1).lower() + value = header_unquote(match.group(2), key == "filename") + options[key] = value + + return content_type, options + + +############################################################################## +################################## Multipart ################################# +############################################################################## + + +class MultipartError(ValueError): + pass + + +class MultipartParser(object): + def __init__( + self, + stream, + boundary, + content_length=-1, + disk_limit=2 ** 30, + mem_limit=2 ** 20, + memfile_limit=2 ** 18, + buffer_size=2 ** 16, + charset="latin1", + ): + """ Parse a multipart/form-data byte stream. This object is an iterator + over the parts of the message. + + :param stream: A file-like stream. Must implement ``.read(size)``. + :param boundary: The multipart boundary as a byte string. + :param content_length: The maximum number of bytes to read. + """ + self.stream = stream + self.boundary = boundary + self.content_length = content_length + self.disk_limit = disk_limit + self.memfile_limit = memfile_limit + self.mem_limit = min(mem_limit, self.disk_limit) + self.buffer_size = min(buffer_size, self.mem_limit) + self.charset = charset + + if self.buffer_size - 6 < len(boundary): # "--boundary--\r\n" + raise MultipartError("Boundary does not fit into buffer_size.") + + self._done = [] + self._part_iter = None + + def __iter__(self): + """ Iterate over the parts of the multipart message. """ + if not self._part_iter: + self._part_iter = self._iterparse() + + for part in self._done: + yield part + + for part in self._part_iter: + self._done.append(part) + yield part + + def parts(self): + """ Returns a list with all parts of the multipart message. """ + return list(self) + + def get(self, name, default=None): + """ Return the first part with that name or a default value (None). """ + for part in self: + if name == part.name: + return part + + return default + + def get_all(self, name): + """ Return a list of parts with that name. """ + return [p for p in self if p.name == name] + + def _lineiter(self): + """ Iterate over a binary file-like object line by line. Each line is + returned as a (line, line_ending) tuple. If the line does not fit + into self.buffer_size, line_ending is empty and the rest of the line + is returned with the next iteration. + """ + read = self.stream.read + maxread, maxbuf = self.content_length, self.buffer_size + buffer = b"" # buffer for the last (partial) line + + while True: + data = read(maxbuf if maxread < 0 else min(maxbuf, maxread)) + maxread -= len(data) + lines = (buffer + data).splitlines(True) + len_first_line = len(lines[0]) + + # be sure that the first line does not become too big + if len_first_line > self.buffer_size: + # at the same time don't split a '\r\n' accidentally + if len_first_line == self.buffer_size + 1 and lines[0].endswith(b"\r\n"): + splitpos = self.buffer_size - 1 + else: + splitpos = self.buffer_size + lines[:1] = [lines[0][:splitpos], lines[0][splitpos:]] + + if data: + buffer = lines[-1] + lines = lines[:-1] + + for line in lines: + if line.endswith(b"\r\n"): + yield line[:-2], b"\r\n" + elif line.endswith(b"\n"): + yield line[:-1], b"\n" + elif line.endswith(b"\r"): + yield line[:-1], b"\r" + else: + yield line, b"" + + if not data: + break + + def _iterparse(self): + lines, line = self._lineiter(), "" + separator = b"--" + to_bytes(self.boundary) + terminator = b"--" + to_bytes(self.boundary) + b"--" + + # Consume first boundary. Ignore any preamble, as required by RFC + # 2046, section 5.1.1. + for line, nl in lines: + if line in (separator, terminator): + break + else: + raise MultipartError("Stream does not contain boundary") + + # Check for empty data + if line == terminator: + for _ in lines: + raise MultipartError("Data after end of stream") + return + + # For each part in stream... + mem_used, disk_used = 0, 0 # Track used resources to prevent DoS + is_tail = False # True if the last line was incomplete (cutted) + + opts = { + "buffer_size": self.buffer_size, + "memfile_limit": self.memfile_limit, + "charset": self.charset, + } + + part = MultipartPart(**opts) + + for line, nl in lines: + if line == terminator and not is_tail: + part.file.seek(0) + yield part + break + + elif line == separator and not is_tail: + if part.is_buffered(): + mem_used += part.size + else: + disk_used += part.size + part.file.seek(0) + + yield part + + part = MultipartPart(**opts) + + else: + is_tail = not nl # The next line continues this one + try: + part.feed(line, nl) + + if part.is_buffered(): + if part.size + mem_used > self.mem_limit: + raise MultipartError("Memory limit reached.") + elif part.size + disk_used > self.disk_limit: + raise MultipartError("Disk limit reached.") + except MultipartError: + part.close() + raise + else: + # If we run off the end of the loop, the current MultipartPart + # will not have been yielded, so it's our responsibility to + # close it. + part.close() + + if line != terminator: + raise MultipartError("Unexpected end of multipart stream.") + + +class MultipartPart(object): + def __init__(self, buffer_size=2 ** 16, memfile_limit=2 ** 18, charset="latin1"): + self.headerlist = [] + self.headers = None + self.file = False + self.size = 0 + self._buf = b"" + self.disposition = None + self.name = None + self.filename = None + self.content_type = None + self.charset = charset + self.memfile_limit = memfile_limit + self.buffer_size = buffer_size + + def feed(self, line, nl=""): + if self.file: + return self.write_body(line, nl) + + return self.write_header(line, nl) + + def write_header(self, line, nl): + line = line.decode(self.charset) + + if not nl: + raise MultipartError("Unexpected end of line in header.") + + if not line.strip(): # blank line -> end of header segment + self.finish_header() + elif line[0] in " \t" and self.headerlist: + name, value = self.headerlist.pop() + self.headerlist.append((name, value + line.strip())) + else: + if ":" not in line: + raise MultipartError("Syntax error in header: No colon.") + + name, value = line.split(":", 1) + self.headerlist.append((name.strip(), value.strip())) + + def write_body(self, line, nl): + if not line and not nl: + return # This does not even flush the buffer + + self.size += len(line) + len(self._buf) + self.file.write(self._buf + line) + self._buf = nl + + if self.content_length > 0 and self.size > self.content_length: + raise MultipartError("Size of body exceeds Content-Length header.") + + if self.size > self.memfile_limit and isinstance(self.file, BytesIO): + # TODO: What about non-file uploads that exceed the memfile_limit? + self.file, old = TemporaryFile(mode="w+b"), self.file + old.seek(0) + copy_file(old, self.file, self.size, self.buffer_size) + + def finish_header(self): + self.file = BytesIO() + self.headers = Headers(self.headerlist) + content_disposition = self.headers.get("Content-Disposition", "") + content_type = self.headers.get("Content-Type", "") + + if not content_disposition: + raise MultipartError("Content-Disposition header is missing.") + + self.disposition, self.options = parse_options_header(content_disposition) + self.name = self.options.get("name") + self.filename = self.options.get("filename") + self.content_type, options = parse_options_header(content_type) + self.charset = options.get("charset") or self.charset + self.content_length = int(self.headers.get("Content-Length", "-1")) + + def is_buffered(self): + """ Return true if the data is fully buffered in memory.""" + return isinstance(self.file, BytesIO) + + @property + def value(self): + """ Data decoded with the specified charset """ + + return self.raw.decode(self.charset) + + @property + def raw(self): + """ Data without decoding """ + pos = self.file.tell() + self.file.seek(0) + + try: + val = self.file.read() + except IOError: + raise + finally: + self.file.seek(pos) + + return val + + def save_as(self, path): + with open(path, "wb") as fp: + pos = self.file.tell() + + try: + self.file.seek(0) + size = copy_file(self.file, fp) + finally: + self.file.seek(pos) + + return size + + def close(self): + if self.file: + self.file.close() + self.file = False + + +############################################################################## +#################################### WSGI #################################### +############################################################################## + + +def parse_form_data(environ, charset="utf8", strict=False, **kwargs): + """ Parse form data from an environ dict and return a (forms, files) tuple. + Both tuple values are dictionaries with the form-field name as a key + (unicode) and lists as values (multiple values per key are possible). + The forms-dictionary contains form-field values as unicode strings. + The files-dictionary contains :class:`MultipartPart` instances, either + because the form-field was a file-upload or the value is too big to fit + into memory limits. + + :param environ: An WSGI environment dict. + :param charset: The charset to use if unsure. (default: utf8) + :param strict: If True, raise :exc:`MultipartError` on any parsing + errors. These are silently ignored by default. + """ + + forms, files = MultiDict(), MultiDict() + + try: + if environ.get("REQUEST_METHOD", "GET").upper() in VERBS_PROHIBITED: + raise MultipartError( + f"{', '.join(VERBS_PROHIBITED[:-1])} and VERBS_PROHIBITED[-1]" + "verbs are not allowed to submit the HTTP body.") + content_length = int(environ.get("CONTENT_LENGTH", "-1")) + content_type = environ.get("CONTENT_TYPE", "") + + if not content_type: + raise MultipartError("Missing Content-Type header.") + + content_type, options = parse_options_header(content_type) + stream = environ.get("wsgi.input") or BytesIO() + kwargs["charset"] = charset = options.get("charset", charset) + + if content_type == "multipart/form-data": + boundary = options.get("boundary", "") + + if not boundary: + raise MultipartError("No boundary for multipart/form-data.") + + for part in MultipartParser(stream, boundary, content_length, **kwargs): + if part.filename or not part.is_buffered(): + if part.name in files: + if not isinstance(files[part.name], list): + files[part.name] = [files[part.name]] + files[part.name].append(part) + else: + files[part.name] = part + else: # TODO: Big form-fields are in the files dict. really? + forms[part.name] = part.value + + elif content_type in ( + "application/x-www-form-urlencoded", + "application/x-url-encoded", + ): + mem_limit = kwargs.get("mem_limit", 2 ** 20) + if content_length > mem_limit: + raise MultipartError("Request too big. Increase MAXMEM.") + + data = stream.read(mem_limit).decode(charset) + + if stream.read(1): # These is more that does not fit mem_limit + raise MultipartError("Request too big. Increase MAXMEM.") + + data = parse_qs(data, keep_blank_values=True, encoding=charset) + + for key, values in data.items(): + for value in values: + forms[key] = value + else: + raise MultipartError("Unsupported content type.") + + except MultipartError: + if strict: + for part in files.values(): + part.close() + raise + + return forms, files diff --git a/tests/test_assertion.py b/tests/test_assertion.py index b8b0cc9..dd76c5b 100644 --- a/tests/test_assertion.py +++ b/tests/test_assertion.py @@ -1,19 +1,19 @@ -import cgi import json +from urllib.parse import parse_qs from bddrest import Given, when, response def wsgi_application(environ, start_response): - form = cgi.FieldStorage( - fp=environ['wsgi.input'], - environ=environ, - strict_parsing=False, - keep_blank_values=True + fp = environ['wsgi.input'] + form = parse_qs( + fp.read(int(environ.get('CONTENT_LENGTH'))).decode(), + keep_blank_values=True, + strict_parsing=True ) try: - code = int(form['activationCode'].value) ^ 1234 + code = int(form['activationCode'][0]) ^ 1234 except ValueError: start_response( '400 Bad Request', diff --git a/tests/test_call.py b/tests/test_call.py index 13c370a..5bcff6a 100644 --- a/tests/test_call.py +++ b/tests/test_call.py @@ -1,5 +1,5 @@ -import cgi import json +from urllib.parse import parse_qs import pytest @@ -7,11 +7,11 @@ def wsgi_application(environ, start_response): - form = cgi.FieldStorage( - fp=environ['wsgi.input'], - environ=environ, - strict_parsing=False, - keep_blank_values=True + fp = environ['wsgi.input'] + form = parse_qs( + fp.read(int(environ.get('CONTENT_LENGTH', 0))).decode(), + keep_blank_values=True, + strict_parsing=True ) start_response( diff --git a/tests/test_form_manipulation.py b/tests/test_form_manipulation.py index 87f911a..8157e66 100644 --- a/tests/test_form_manipulation.py +++ b/tests/test_form_manipulation.py @@ -1,5 +1,5 @@ -import cgi import json +from urllib.parse import parse_qs import pytest @@ -7,17 +7,17 @@ def wsgi_application(environ, start_response): - form = cgi.FieldStorage( - fp=environ['wsgi.input'], - environ=environ, - strict_parsing=False, - keep_blank_values=True + fp = environ['wsgi.input'] + form = parse_qs( + fp.read(int(environ.get('CONTENT_LENGTH', 0))).decode(), + keep_blank_values=True, + strict_parsing=True ) start_response('200 OK', [ ('Content-Type', 'application/json;charset=utf-8'), ]) - yield json.dumps({k: form[k].value for k in form.keys()}).encode() + yield json.dumps({k: form[k][0] for k in form.keys()}).encode() def test_append_form_field(): diff --git a/tests/test_multipart_form.py b/tests/test_multipart_form.py index cfb0316..0a38fbf 100644 --- a/tests/test_multipart_form.py +++ b/tests/test_multipart_form.py @@ -1,23 +1,23 @@ import base64 -import cgi import hashlib import io +from . import multipart + from bddrest import Given, response, status, story def wsgi_application(environ, start_response): - form = cgi.FieldStorage( - fp=environ['wsgi.input'], - environ=environ, - strict_parsing=False, - keep_blank_values=True + _, files = multipart.parse_form_data( + environ, + charset="utf8", + strict=True ) start_response('200 OK', [ ('Content-Type', 'text/plain;charset=utf-8'), ]) - binary_file = form['a'].file + binary_file = files['a'].file yield base64.encodebytes(hashlib.md5(binary_file.read()).digest()).decode() diff --git a/tests/test_multipart_list.py b/tests/test_multipart_list.py index 9f23245..b3866f7 100644 --- a/tests/test_multipart_list.py +++ b/tests/test_multipart_list.py @@ -1,25 +1,25 @@ import base64 -import cgi import hashlib import io import json from bddrest import Given, response, status +from . import multipart + def wsgi_application(environ, start_response): - form = cgi.FieldStorage( - fp=environ['wsgi.input'], - environ=environ, - strict_parsing=False, - keep_blank_values=True + _, files = multipart.parse_form_data( + environ, + charset="utf8", + strict=True ) start_response('200 OK', [ ('Content-Type', 'text/plain;charset=utf-8'), ]) - file1 = form['a'][0].file - file2 = form['a'][1].file + file1 = files['a'][0].file + file2 = files['a'][1].file response = dict( file1=base64.encodebytes(hashlib.md5(file1.read()).digest()).decode(), file2=base64.encodebytes(hashlib.md5(file2.read()).digest()).decode() diff --git a/tests/test_none_parameters.py b/tests/test_none_parameters.py index 8a8dd95..beeb95b 100644 --- a/tests/test_none_parameters.py +++ b/tests/test_none_parameters.py @@ -1,6 +1,7 @@ -import cgi import json +from . import multipart + from bddrest import Given, response, status @@ -8,16 +9,14 @@ def wsgi_application(environ, start_response): if environ.get('CONTENT_TYPE', '').startswith('application/json'): result = json.loads(environ['wsgi.input'].read()) else: - form = cgi.FieldStorage( - fp=environ['wsgi.input'], - environ=environ, - strict_parsing=False, - keep_blank_values=True + form, _ = multipart.parse_form_data( + environ, + charset="utf8", + strict=True ) result = {} for k in form.keys(): - v = form[k] - value = v.value + value = form[k] result[k] = value.decode() if isinstance(value, bytes) else value diff --git a/tests/test_querystring.py b/tests/test_querystring.py index 39496b6..7e05972 100644 --- a/tests/test_querystring.py +++ b/tests/test_querystring.py @@ -1,4 +1,3 @@ -import cgi import json import tempfile diff --git a/tests/test_story.py b/tests/test_story.py index d3b1378..283c98f 100644 --- a/tests/test_story.py +++ b/tests/test_story.py @@ -1,4 +1,3 @@ -import cgi import json import tempfile @@ -7,18 +6,19 @@ from bddrest import Given, when, story, response, InvalidUrlParametersError, \ CallVerifyError, Call, AlteredCall, Story, given +from . import multipart + def wsgi_application(environ, start_response): - form = cgi.FieldStorage( - fp=environ['wsgi.input'], - environ=environ, - strict_parsing=False, - keep_blank_values=True + form, _ = multipart.parse_form_data( + environ, + charset="utf8", + strict=True ) try: # FIXME: Why x ^ 1234 - code = int(form['activationCode'].value) ^ 1234 + code = int(form['activationCode']) ^ 1234 except ValueError: start_response( diff --git a/tests/test_wsgi_connector.py b/tests/test_wsgi_connector.py index c65fff4..97b9c04 100644 --- a/tests/test_wsgi_connector.py +++ b/tests/test_wsgi_connector.py @@ -1,10 +1,11 @@ -import cgi import io import json from bddrest.connectors import WSGIConnector from bddrest.response import Response +from . import multipart + def test_wagi_streaming(): def app(environ, start_response): @@ -33,16 +34,20 @@ def app(environ, start_response): if environ.get('CONTENT_TYPE', '').startswith('application/json'): result = json.loads(environ['wsgi.input'].read()) else: - form = cgi.FieldStorage( - fp=environ['wsgi.input'], - environ=environ, - strict_parsing=False, - keep_blank_values=True + form, files = multipart.parse_form_data( + environ, + charset="utf8", + strict=True ) result = {} for k in form.keys(): - v = form[k] - value = v.value + value = form[k] + + result[k] = \ + value.decode() if isinstance(value, bytes) else value + + for k in files.keys(): + value = files[k].file.read() result[k] = \ value.decode() if isinstance(value, bytes) else value