From 1e692a4a88c5ce26b610e97c6610fde8744ab4ef Mon Sep 17 00:00:00 2001 From: Amir <0204.amir@gmail.com> Date: Thu, 19 Feb 2026 02:08:53 +0300 Subject: [PATCH] Port to Python 3 and harden unsafe input handling --- requirements.txt | 3 +- setup.py | 28 +++-- simpleapi/__init__.py | 10 +- simpleapi/client/__init__.py | 4 +- simpleapi/client/client.py | 25 ++-- simpleapi/client/dummy.py | 6 +- simpleapi/message/__init__.py | 8 +- simpleapi/message/common.py | 9 +- simpleapi/message/formatter.py | 53 ++++++--- simpleapi/message/py2xml.py | 22 ++-- simpleapi/message/sajson.py | 6 +- simpleapi/message/wrapper.py | 22 ++-- simpleapi/server/__init__.py | 12 +- simpleapi/server/feature.py | 18 +-- simpleapi/server/namespace.py | 15 ++- simpleapi/server/preformat.py | 4 +- simpleapi/server/request.py | 53 +++++---- simpleapi/server/response.py | 20 ++-- simpleapi/server/route.py | 200 +++++++++++++++++++------------- simpleapi/server/routemgr.py | 2 +- simpleapi/server/sapirequest.py | 16 ++- simpleapi/server/serializer.py | 2 +- simpleapi/server/session.py | 2 +- tests/namespace.py | 12 +- tests/route.py | 113 +++++++++--------- tests/test_security.py | 120 +++++++++++++++++++ 26 files changed, 502 insertions(+), 283 deletions(-) create mode 100644 tests/test_security.py diff --git a/requirements.txt b/requirements.txt index e3ad467..a18f532 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ python-dateutil -simplejson \ No newline at end of file +simplejson +defusedxml diff --git a/setup.py b/setup.py index 822f010..1fe0b89 100644 --- a/setup.py +++ b/setup.py @@ -1,23 +1,37 @@ -from setuptools import setup, find_packages +from pathlib import Path + +from setuptools import find_packages, setup + +ROOT = Path(__file__).parent +README = (ROOT / "README.rst").read_text(encoding="utf-8") +REQUIREMENTS = (ROOT / "requirements.txt").read_text(encoding="utf-8").split() setup( name='simpleapi', - version='0.0.9', + version='0.1.0', description='A simple API-framework to provide an easy to use, consistent and portable client/server-architecture (for django, flask and a lot more).', - long_description=open('README.rst').read(), + long_description=README, + long_description_content_type='text/x-rst', author='Florian Schlachter', author_email='flori@n-schlachter.de', - url='http://github.com/flosch/simpleapi/tree/', + url='https://github.com/flosch/simpleapi', packages=find_packages(), + python_requires='>=3.9', classifiers=[ - 'Development Status :: 3 - Alpha', + 'Development Status :: 4 - Beta', 'Environment :: Web Environment', 'Intended Audience :: Developers', 'License :: OSI Approved :: MIT License', 'Operating System :: OS Independent', - 'Programming Language :: Python' + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3 :: Only', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', + 'Topic :: Internet :: WWW/HTTP', ], zip_safe=False, test_suite='tests', - install_requires=open("requirements.txt", "r").read().split() + install_requires=REQUIREMENTS, ) diff --git a/simpleapi/__init__.py b/simpleapi/__init__.py index 0fd020a..add32e8 100644 --- a/simpleapi/__init__.py +++ b/simpleapi/__init__.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- -from client import * -from server import * -from message import * +from .client import * +from .server import * +from .message import * __author__ = 'Florian Schlachter' -VERSION = (0, 0, 9) +VERSION = (0, 1, 0) def get_version(): version = '%s.%s' % (VERSION[0], VERSION[1]) @@ -14,4 +14,4 @@ def get_version(): version = '%s.%s' % (version, VERSION[2]) return version -__version__ = get_version() \ No newline at end of file +__version__ = get_version() diff --git a/simpleapi/client/__init__.py b/simpleapi/client/__init__.py index 6fbf229..e59883a 100644 --- a/simpleapi/client/__init__.py +++ b/simpleapi/client/__init__.py @@ -1,4 +1,4 @@ # -*- coding: utf-8 -*- -from client import * -from dummy import * \ No newline at end of file +from .client import * +from .dummy import * \ No newline at end of file diff --git a/simpleapi/client/client.py b/simpleapi/client/client.py index 928dcfd..3f8f2f2 100644 --- a/simpleapi/client/client.py +++ b/simpleapi/client/client.py @@ -3,8 +3,8 @@ __all__ = ('Client', 'ClientException', 'ConnectionException', 'RemoteException', ) import socket -import urllib -import cPickle +import urllib.request, urllib.parse, urllib.error +import pickle from simpleapi.message import formatters, wrappers class ClientException(Exception): pass @@ -51,33 +51,34 @@ def do_call(**kwargs): formatter = formatters[self.transport_type](None, None) - for key, value in kwargs.iteritems(): + for key, value in kwargs.items(): kwargs[key] = formatter.kwargs(value) data.update(kwargs) try: - response = urllib.urlopen(self.ns, - urllib.urlencode(data)) + payload = urllib.parse.urlencode(data).encode('utf-8') + response = urllib.request.urlopen(self.ns, + payload) assert response.getcode() in [200,], \ - u'HTTP-Server returned http code %s (expected: 200) ' % \ + 'HTTP-Server returned http code %s (expected: 200) ' % \ response.getcode() response_buffer = response.read() - except IOError, e: + except IOError as e: raise ConnectionException(e) try: response = formatter.parse(response_buffer) - except (cPickle.UnpicklingError, EOFError), _: + except (pickle.UnpicklingError, EOFError) as _: raise ClientException( - u'Couldn\'t unpickle response ' \ + 'Couldn\'t unpickle response ' \ 'data. Did you added "pickle" to the namespace\'s' \ ' __features__ list?' ) - except ValueError, e: - raise ConnectionException, e + except ValueError as e: + raise ConnectionException(e) if response.get('success'): return response.get('result') @@ -95,4 +96,4 @@ def set_version(self, version): def set_ns(self, ns): """changes the URL for the Route's endpoint""" - self.ns = ns \ No newline at end of file + self.ns = ns diff --git a/simpleapi/client/dummy.py b/simpleapi/client/dummy.py index 767d917..0775d5a 100644 --- a/simpleapi/client/dummy.py +++ b/simpleapi/client/dummy.py @@ -28,7 +28,7 @@ def do_call(**kwargs): formatter = formatters[TRANSPORT_TYPE](None, None) - for key, value in kwargs.iteritems(): + for key, value in kwargs.items(): kwargs[key] = formatter.kwargs(value) data.update(kwargs) @@ -44,8 +44,8 @@ def do_call(**kwargs): try: response = formatter.parse(response_buffer['result']) - except ValueError, e: - raise ConnectionException, e + except ValueError as e: + raise ConnectionException(e) if response.get('success'): return response.get('result') diff --git a/simpleapi/message/__init__.py b/simpleapi/message/__init__.py index 34b6c5e..1717e3e 100644 --- a/simpleapi/message/__init__.py +++ b/simpleapi/message/__init__.py @@ -1,4 +1,4 @@ -from formatter import * -from wrapper import * -from py2xml import * -from extjs import * \ No newline at end of file +from .formatter import * +from .wrapper import * +from .py2xml import * +from .extjs import * \ No newline at end of file diff --git a/simpleapi/message/common.py b/simpleapi/message/common.py index 908cbb5..b2fc896 100644 --- a/simpleapi/message/common.py +++ b/simpleapi/message/common.py @@ -5,14 +5,14 @@ except ImportError: try: from django.utils import simplejson as json - except Exception, e: + except Exception as e: import simplejson as json __all__ = ('json', 'SAException') class SAException(Exception): def __init__(self, msg=None): - super(Exception, self).__init__() + super(SAException, self).__init__(msg) self._message = msg def _get_message(self): @@ -23,5 +23,8 @@ def _set_message(self, message): message = property(_get_message, _set_message) + def __str__(self): + return str(self.message) + def __repr__(self): - return self.message \ No newline at end of file + return str(self.message) diff --git a/simpleapi/message/formatter.py b/simpleapi/message/formatter.py index 7e2a245..ee23e3b 100644 --- a/simpleapi/message/formatter.py +++ b/simpleapi/message/formatter.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- -import cPickle -from common import json +import pickle +import re +from .common import json try: import yaml @@ -9,12 +10,25 @@ except ImportError: has_yaml = False -from py2xml import PythonToXML +from .py2xml import PythonToXML -from sajson import SimpleAPIEncoder, SimpleAPIDecoder +from .sajson import SimpleAPIEncoder, SimpleAPIDecoder __all__ = ('formatters', 'Formatter') + +CALLBACK_RE = re.compile(r'^[A-Za-z_$][0-9A-Za-z_$]*(?:\.[A-Za-z_$][0-9A-Za-z_$]*)*$') + + +def _validate_jsonp_callback(callback): + if callback is None: + return 'simpleapiCallback' + if isinstance(callback, bytes): + callback = callback.decode('utf-8', errors='strict') + if not isinstance(callback, str) or not CALLBACK_RE.match(callback): + raise ValueError('Invalid JSONP callback name.') + return callback + class FormattersSingleton(object): """This singleton takes care of all registered formatters. You can easily register your own formatter for use in both the Namespace and python client. @@ -33,17 +47,16 @@ def register(self, name, formatter, override=False): the given `name`, you can override by setting `override` to ``True``. """ if not isinstance(formatter(None, None), Formatter): - raise TypeError(u"You can only register a Formatter not a %s" % formatter) + raise TypeError("You can only register a Formatter not a %s" % formatter) if name in self._formatters and not override: - raise AttributeError(u"%s is already a valid format type, try a new name" % name) + raise AttributeError("%s is already a valid format type, try a new name" % name) self._formatters[name] = formatter def get_defaults(self): - result = filter(lambda item: getattr(item[1], '__active_by_default__', True), - self._formatters.items()) - return dict(result).keys() + result = [item for item in list(self._formatters.items()) if getattr(item[1], '__active_by_default__', True)] + return list(dict(result).keys()) def copy(self): return dict(**self._formatters) @@ -106,13 +119,15 @@ class JSONPFormatter(Formatter): __mime__ = "application/javascript" def build(self, value): - func = self.callback or 'simpleapiCallback' - result = u'%(func)s(%(data)s)' % {'func': func.decode("utf-8"), 'data': json.dumps(value)} - return result.encode("utf-8") + func = _validate_jsonp_callback(self.callback) + return '%(func)s(%(data)s)' % { + 'func': func, + 'data': json.dumps(value, cls=SimpleAPIEncoder), + } - def kwargs(self, value): + def kwargs(self, value, action='build'): if action == 'build': - return json.dumps(value, cls=SimpleAPIEncoder) + return self.build(value) elif action == 'parse': return self.parse(value) @@ -135,7 +150,7 @@ def kwargs(self, value, action='build'): return self.parse(value) def parse(self, value): - return unicode(value) + return str(value) class PickleFormatter(Formatter): """Formatter for use the cPickle python module which supports python object @@ -150,7 +165,7 @@ class PickleFormatter(Formatter): __active_by_default__ = False def build(self, value): - return cPickle.dumps(value) + return pickle.dumps(value) def kwargs(self, value, action='build'): if action == 'build': @@ -159,9 +174,11 @@ def kwargs(self, value, action='build'): return self.parse(value) def parse(self, value): - if isinstance(value, unicode): + if not getattr(self, 'allow_unsafe_pickle_input', False): + raise ValueError('Pickle input parsing is disabled by default for security reasons.') + if isinstance(value, str): value = value.encode("utf-8") - return cPickle.loads(value) + return pickle.loads(value) class XMLFormatter(Formatter): __mime__ = "text/xml" diff --git a/simpleapi/message/py2xml.py b/simpleapi/message/py2xml.py index ff06f4a..284cfbe 100644 --- a/simpleapi/message/py2xml.py +++ b/simpleapi/message/py2xml.py @@ -1,6 +1,10 @@ # -*- coding: utf-8 -*- -from xml.etree import cElementTree as ET +try: + # Prefer hardened parser to reduce XML-based DoS attack surface. + from defusedxml import ElementTree as ET +except ImportError: # pragma: no cover - fallback for minimal environments + from xml.etree import cElementTree as ET from dateutil.parser import parse __all__ = ('PythonToXML',) @@ -56,7 +60,7 @@ def build_int(self, value): def build_long(self, value): element = self.create_item('long') - element.text = str(long(value)) + element.text = str(int(value)) return element def build_float(self, value): @@ -83,7 +87,7 @@ def build_tuple(self, value): def build_dict(self, value): root = self.create_item('dict') - for key, value in value.iteritems(): + for key, value in value.items(): element = self.handle(value) element.set('name', key) root.append(element) @@ -108,25 +112,25 @@ def parse_time(self, element): def parse_dict(self, element): tmp = {} - for item in element.getchildren(): + for item in list(element): tmp[item.get('name')] = self.handle(item, 'parse') return tmp def parse_list(self, element): tmp = [] - for item in element.getchildren(): + for item in list(element): tmp.append(self.handle(item, 'parse')) return tmp def parse_set(self, element): tmp = [] - for item in element.getchildren(): + for item in list(element): tmp.append(self.handle(item, 'parse')) return set(tmp) def parse_tuple(self, element): tmp = [] - for item in element.getchildren(): + for item in list(element): tmp.append(self.handle(item, 'parse')) return tuple(tmp) @@ -140,7 +144,7 @@ def parse_int(self, element): return int(element.text) def parse_long(self, element): - return long(element.text) + return int(element.text) def parse_float(self, element): return float(element.text) @@ -159,4 +163,4 @@ def build(self, value): def parse(self, value): root = ET.fromstring(value) - return self.handle(root, op='parse') \ No newline at end of file + return self.handle(root, op='parse') diff --git a/simpleapi/message/sajson.py b/simpleapi/message/sajson.py index a6f3cbd..0933969 100644 --- a/simpleapi/message/sajson.py +++ b/simpleapi/message/sajson.py @@ -3,7 +3,7 @@ import re import datetime from dateutil.parser import parse -from common import json +from .common import json __all__ = ('SimpleAPIEncoder', 'SimpleAPIDecoder') @@ -27,8 +27,8 @@ def __init__(self, *args, **kwargs): self.object_hook = self.hook def hook(self, obj): - for key, val in obj.iteritems(): - if isinstance(val, basestring) and (date_re.match(val) \ + for key, val in obj.items(): + if isinstance(val, str) and (date_re.match(val) \ or time_re.match(val)): try: obj[key] = parse(val) diff --git a/simpleapi/message/wrapper.py b/simpleapi/message/wrapper.py index 0f44d2a..6578499 100644 --- a/simpleapi/message/wrapper.py +++ b/simpleapi/message/wrapper.py @@ -23,10 +23,10 @@ def register(self, name, wrapper, override=False): Register the given wrapper """ if not isinstance(wrapper(None, ), Wrapper): - raise TypeError(u"You can only register a Wrapper not a %s" % wrapper) + raise TypeError("You can only register a Wrapper not a %s" % wrapper) if name in self._wrappers and not override: - raise AttributeError(u"%s is already a valid wrapper type, try a new name" % name) + raise AttributeError("%s is already a valid wrapper type, try a new name" % name) self._wrappers[name] = wrapper @@ -52,7 +52,7 @@ def __init__(self, sapi_request): self.session = getattr(sapi_request, 'session', None) def _build(self, errors, result): - if isinstance(errors, basestring): + if isinstance(errors, str): errors = [errors,] if errors: @@ -88,9 +88,9 @@ def build(self, errors, result): class ExtJSWrapper(Wrapper): @staticmethod def build_errors(errors): - assert isinstance(errors, (basestring, tuple, list)) + assert isinstance(errors, (str, tuple, list)) - if isinstance(errors, basestring) or \ + if isinstance(errors, str) or \ (isinstance(errors, (tuple, list)) and \ len(errors) == 1): return { @@ -99,7 +99,7 @@ def build_errors(errors): elif isinstance(errors, (tuple, list)) and \ len(errors) > 0: errmsg, errors = errors[0], errors[1] - assert isinstance(errmsg, basestring) + assert isinstance(errmsg, str) assert isinstance(errors, dict) return { @@ -179,7 +179,7 @@ def build(self, errors, result): def parse(self, items): if len(items) == 1: # check for a batch request - key, value = items.items()[0] + key, value = list(items.items())[0] if value == '': data = json.loads(key) if isinstance(data, dict): @@ -188,15 +188,15 @@ def parse(self, items): for item in data: yield self.parse_item(item) else: - raise ValueError(u'Unsupported input format.') + raise ValueError('Unsupported input format.') else: - raise ValueError(u'Unsupported input format.') + raise ValueError('Unsupported input format.') else: s = self.parse_item(items) yield s def parse_item(self, data): - if data.has_key('extMethod'): + if 'extMethod' in data: # formHandler true d = { '_call': data.pop('extMethod', ''), @@ -228,7 +228,7 @@ def parse_item(self, data): if data.get('data') and len(data['data']) > 0 and \ not isinstance(data['data'][0], dict): - raise ValueError(u'data must be a hashable/an array of key/value arguments') + raise ValueError('data must be a hashable/an array of key/value arguments') tid = data.pop('tid', '') action = data.pop('action', '') diff --git a/simpleapi/server/__init__.py b/simpleapi/server/__init__.py index 6ea498d..a228b0b 100644 --- a/simpleapi/server/__init__.py +++ b/simpleapi/server/__init__.py @@ -3,9 +3,9 @@ __all__ = ('Route', 'Namespace', 'Feature', 'FeatureContentResponse', 'serialize', 'UnformattedResponse', 'RouteMgr') -from route import Route -from routemgr import RouteMgr -from namespace import Namespace -from feature import Feature, FeatureContentResponse -from serializer import serialize -from response import UnformattedResponse \ No newline at end of file +from .route import Route +from .routemgr import RouteMgr +from .namespace import Namespace +from .feature import Feature, FeatureContentResponse +from .serializer import serialize +from .response import UnformattedResponse \ No newline at end of file diff --git a/simpleapi/server/feature.py b/simpleapi/server/feature.py index 408a1b2..62ab3d3 100644 --- a/simpleapi/server/feature.py +++ b/simpleapi/server/feature.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -import cPickle +import pickle import hashlib from simpleapi.message.common import SAException @@ -38,11 +38,11 @@ def get_config_scope(self, request): function_location = hasattr(request.session.function['method'], conf_name) if function_location: - return u'local:%s' % request.session.function['name'] + return 'local:%s' % request.session.function['name'] else: return 'global' elif hasattr(self, '__function_config__'): - return u'local:%s' % request.session.function['name'] + return 'local:%s' % request.session.function['name'] elif hasattr(self, '__class_config__'): return 'global' @@ -118,7 +118,7 @@ def handle_request(self, request): caching_config = self.get_config(request) if caching_config: - arg_signature = hashlib.md5(cPickle.dumps( + arg_signature = hashlib.md5(pickle.dumps( request.session.arguments)).hexdigest() timeout = 60 * 60 @@ -138,7 +138,7 @@ def handle_request(self, request): content = cache.get(key) if content: - raise FeatureContentResponse(cPickle.loads(content)) + raise FeatureContentResponse(pickle.loads(content)) else: request.session.cache_timeout = timeout request.session.cache_key = key @@ -149,7 +149,7 @@ def handle_response(self, response): if hasattr(response.session, 'want_cached') and not response.errors: cache.set( response.session.cache_key, - cPickle.dumps(response.result), + pickle.dumps(response.result), response.session.cache_timeout ) @@ -182,7 +182,7 @@ def handle_request(self, request): if rps > 0: no = cache.get(rps_key, 1) if no >= rps: - self.error(u'Throttling active (exceeded %s #/sec.)' % no) + self.error('Throttling active (exceeded %s #/sec.)' % no) else: try: cache.incr(rps_key) # FIXME: using incr() eliminates the timeout! @@ -192,7 +192,7 @@ def handle_request(self, request): if rpm > 0: no = cache.get(rpm_key, 1) if no >= rpm: - self.error(u'Throttling active (exceeded %s #/min.)' % no) + self.error('Throttling active (exceeded %s #/min.)' % no) else: try: cache.incr(rpm_key) @@ -202,7 +202,7 @@ def handle_request(self, request): if rph > 0: no = cache.get(rph_key, 1) if no >= rph: - self.error(u'Throttling active (exceeded %s #/hour)' % no) + self.error('Throttling active (exceeded %s #/hour)' % no) else: try: cache.incr(rph_key) diff --git a/simpleapi/server/namespace.py b/simpleapi/server/namespace.py index 3be4cf6..5c51bef 100644 --- a/simpleapi/server/namespace.py +++ b/simpleapi/server/namespace.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from simpleapi.message.common import json, SAException -from response import UnformattedResponse +from .response import UnformattedResponse __all__ = ('Namespace', 'NamespaceException') @@ -30,7 +30,7 @@ def introspect(self, framework='default', provider='Ext.app', functions = {} for cls in ('forms', 'direct'): functions[cls] = [] - for fn in function_map.iterkeys(): + for fn in function_map.keys(): if len(function_map[fn]['args']['all']) > 0: fnlen = 1 else: @@ -45,25 +45,24 @@ def introspect(self, framework='default', provider='Ext.app', result = { 'actions': { self.request.route.name: functions['direct'], - u'%s_forms' % self.request.route.name: functions['forms'], + '%s_forms' % self.request.route.name: functions['forms'], } } result['type'] = 'remoting' - result['url'] = u'%s?_wrapper=extjsdirect' % \ + result['url'] = '%s?_wrapper=extjsdirect' % \ self.session.request.path_info if namespace: result['namespace'] = namespace return UnformattedResponse( - content=u'%s.%s_REMOTING_API = %s;' %\ + content='%s.%s_REMOTING_API = %s;' %\ (provider, self.request.route.name.upper(), json.dumps(result)), mimetype='text/javascript' ) else: functions = [] - for fn in function_map.iterkeys(): - print function_map[fn]['args'] + for fn in function_map.keys(): optionals = list(set(function_map[fn]['args']['all']) - \ set(function_map[fn]['args']['obligatory'])) functions.append({ @@ -88,4 +87,4 @@ def introspect(self, framework='default', provider='Ext.app', } } - return result \ No newline at end of file + return result diff --git a/simpleapi/server/preformat.py b/simpleapi/server/preformat.py index fdd410e..d715e39 100644 --- a/simpleapi/server/preformat.py +++ b/simpleapi/server/preformat.py @@ -13,7 +13,7 @@ except ImportError: has_mongoengine = False -from serializer import SerializedObject +from .serializer import SerializedObject __all__ = () @@ -48,7 +48,7 @@ def handle_list(self, old_list): def handle_dict(self, old_dict): new_dict = {} - for key, value in old_dict.iteritems(): + for key, value in old_dict.items(): new_dict[key] = self.handle_value(value) return new_dict diff --git a/simpleapi/server/request.py b/simpleapi/server/request.py index 7c95137..f59eecb 100644 --- a/simpleapi/server/request.py +++ b/simpleapi/server/request.py @@ -9,8 +9,8 @@ except ImportError: has_debug = False -from response import Response -from feature import FeatureContentResponse +from .response import Response +from .feature import FeatureContentResponse from simpleapi.message import formatters from simpleapi.message.common import SAException @@ -76,48 +76,48 @@ def process_request(self, request_items): # check the method if not method: - raise RequestException(u'Method must be provided.') + raise RequestException('Method must be provided.') # check whether method exists - if not self.namespace['functions'].has_key(method): - raise RequestException(u'Method %s does not exist.' % method) + if method not in self.namespace['functions']: + raise RequestException('Method %s does not exist.' % method) # check authentication if not self.namespace['authentication'](local_namespace, access_key): - raise RequestException(u'Authentication failed.') + raise RequestException('Authentication failed.') # check ip address if not self.namespace['ip_restriction'](local_namespace, \ self.sapi_request.remote_addr): - raise RequestException(u'You are not allowed to access.') + raise RequestException('You are not allowed to access.') function = self.namespace['functions'][method] self.session.function = function # check allowed HTTP methods if not function['methods']['function'](self.sapi_request.method, function['methods']['allowed_methods']): - raise RequestException(u'Method not allowed: %s' % self.sapi_request.method) + raise RequestException('Method not allowed: %s' % self.sapi_request.method) # if data is set, make sure input formatter is not ValueFormatter if data: if isinstance(self.input_formatter, formatters['value']): - raise RequestException(u'If you\'re using _data please make ' \ + raise RequestException('If you\'re using _data please make ' \ 'sure you set _input and _input is not ' \ '\'value\'.') try: request_items = self.input_formatter.kwargs(data, 'parse') - except ValueError, _: - raise RequestException(u'Data couldn\'t be decoded. ' \ + except ValueError as _: + raise RequestException('Data couldn\'t be decoded. ' \ 'Please check _input and your _data') else: if not isinstance(request_items, dict): - raise RequestException(u'_data must be an array/dictionary') + raise RequestException('_data must be an array/dictionary') # check whether all obligatory arguments are given ungiven_obligatory_args = list(set(function['args']['obligatory']) - \ set(request_items.keys())) if ungiven_obligatory_args: - raise RequestException(u'Obligatory argument(s) missing: %s' % \ + raise RequestException('Obligatory argument(s) missing: %s' % \ ", ".join(ungiven_obligatory_args)) # check whether there are more arguments than needed @@ -127,7 +127,7 @@ def process_request(self, request_items): if unused_arguments: if not self.ignore_unused_args: - raise RequestException(u'Unused arguments: %s' % \ + raise RequestException('Unused arguments: %s' % \ ", ".join(unused_arguments)) else: for key in unused_arguments: @@ -136,29 +136,28 @@ def process_request(self, request_items): # decode incoming variables (only if _data is not set!) if not data: new_request_items = {} - for key, value in request_items.iteritems(): + for key, value in request_items.items(): try: new_request_items[str(key)] = self.input_formatter.kwargs(value, 'parse') - except ValueError, _: - raise - raise RequestException(u'Value for %s couldn\'t be decoded.' % \ + except ValueError as _: + raise RequestException('Value for %s couldn\'t be decoded.' % \ key) request_items = new_request_items else: # make sure all keys are strings, not unicodes (for compatibility # issues: Python < 2.6.5) new_request_items = {} - for key, value in request_items.iteritems(): + for key, value in request_items.items(): new_request_items[str(key)] = value request_items = new_request_items # check constraints - for key, value in request_items.iteritems(): + for key, value in request_items.items(): try: request_items[key] = function['constraints']['function']( local_namespace, key, value) except (ValueError,): - raise RequestException(u'Constraint failed for argument: %s' % key) + raise RequestException('Constraint failed for argument: %s' % key) # we're done working on arguments, pass it to the session self.session.arguments = request_items @@ -167,7 +166,7 @@ def process_request(self, request_items): try: for feature in self.namespace['features']: feature._handle_request(self) - except FeatureContentResponse, e: + except FeatureContentResponse as e: result = e else: # call before_request @@ -178,10 +177,10 @@ def process_request(self, request_items): try: if self.debug: _, fname = tempfile.mkstemp() - self.route.logger.debug(u"Profiling call '%s': %s" % \ + self.route.logger.debug("Profiling call '%s': %s" % \ (method, fname)) - self.route.logger.debug(u"Calling parameters: %s" % \ + self.route.logger.debug("Calling parameters: %s" % \ pprint.pformat(request_items)) profile = cProfile.Profile() @@ -189,13 +188,13 @@ def process_request(self, request_items): **request_items) profile.dump_stats(fname) - self.route.logger.debug(u"Loading stats...") + self.route.logger.debug("Loading stats...") stats = pstats.Stats(fname) stats.strip_dirs().sort_stats('time', 'calls') \ .print_stats(25) else: result = getattr(local_namespace, method)(**request_items) - except Exception, e: + except Exception as e: if has_django and isinstance(e, django_notexist): raise RequestException(e) elif has_mongoengine and isinstance(e, mongoengine_notexist): @@ -217,4 +216,4 @@ def process_request(self, request_items): else: response = result - return response \ No newline at end of file + return response diff --git a/simpleapi/server/response.py b/simpleapi/server/response.py index 4e9e0a8..3c3d659 100644 --- a/simpleapi/server/response.py +++ b/simpleapi/server/response.py @@ -15,7 +15,7 @@ has_flask = False from simpleapi.message import formatters, wrappers -from preformat import Preformatter +from .preformat import Preformatter __all__ = ('Response', 'ResponseMerger', 'ResponseException', 'UnformattedResponse') @@ -48,7 +48,13 @@ def build(self): )) # TODO FIXME XXX: only JSON is supported - result = u'[%s]' % u','.join(map(lambda x: x['result'], results)) + serialized = [] + for built in results: + item = built['result'] + if isinstance(item, bytes): + item = item.decode('utf-8') + serialized.append(item) + result = '[%s]' % ','.join(serialized) return Response._build_response_obj( sapi_request=self.sapi_request, @@ -59,21 +65,21 @@ def build(self): ) -class ResponseException(object): pass +class ResponseException(Exception): pass class Response(object): def __init__(self, sapi_request, namespace=None, output_formatter=None, wrapper=None, errors=None, result=None, mimetype=None, callback=None, function=None): - assert isinstance(errors, (basestring, list)) or errors is None + assert isinstance(errors, (str, list)) or errors is None self.sapi_request = sapi_request self.namespace = namespace self.errors = errors self.result = self._preformat(result) self.mimetype = mimetype - self.callback = None + self.callback = callback self.function = function self.output_formatter = output_formatter or formatters['json'] @@ -91,7 +97,7 @@ def add_error(self, errmsg): else: if isinstance(self.errors, list): self.errors.append(errmsg) - elif isinstance(self.errors, basestring): + elif isinstance(self.errors, str): self.errors = [self.errors, errmsg] def _preformat(self, value): @@ -161,4 +167,4 @@ def _build_response_obj(sapi_request, response): content_type=response['mimetype'] ) else: - return response \ No newline at end of file + return response diff --git a/simpleapi/server/route.py b/simpleapi/server/route.py index 791e7bc..37bba95 100644 --- a/simpleapi/server/route.py +++ b/simpleapi/server/route.py @@ -18,17 +18,17 @@ except ImportError: has_debug = False -import urlparse -import cgi +import urllib.parse from wsgiref.simple_server import make_server -from wsgiref.handlers import SimpleHandler + +from collections import namedtuple SIMPLEAPI_DEBUG = bool(int(os.environ.get('SIMPLEAPI_DEBUG', 0))) SIMPLEAPI_DEBUG_FILENAME = os.environ.get('SIMPLEAPI_DEBUG_FILENAME', 'simpleapi.profile') SIMPLEAPI_DEBUG_LEVEL = os.environ.get('SIMPLEAPI_DEBUG_LEVEL', 'all') assert SIMPLEAPI_DEBUG_LEVEL in ['all', 'call'], \ - u'SIMPLEAPI_DEBUG_LEVEL must be one of these: all, call' + 'SIMPLEAPI_DEBUG_LEVEL must be one of these: all, call' if SIMPLEAPI_DEBUG and not has_debug: SIMPLEAPI_DEBUG = False @@ -37,6 +37,7 @@ TRIGGERED_METHODS = ['get', 'post', 'put', 'delete'] FRAMEWORKS = ['flask', 'django', 'appengine', 'dummy', 'standalone', 'wsgi'] MAX_CONTENT_LENGTH = 1024 * 1024 * 16 # 16 megabytes +JSONP_CALLBACK_RE = re.compile(r'^[A-Za-z_$][0-9A-Za-z_$]*(?:\.[A-Za-z_$][0-9A-Za-z_$]*)*$') restricted_functions = [ 'before_request', @@ -50,16 +51,26 @@ has_appengine = False from simpleapi.message.common import SAException -from sapirequest import SAPIRequest -from request import Request, RequestException -from response import Response, ResponseMerger, ResponseException -from namespace import NamespaceException -from feature import __features__, Feature, FeatureException +from .sapirequest import SAPIRequest +from .request import Request, RequestException +from .response import Response, ResponseMerger, ResponseException +from .namespace import NamespaceException +from .feature import __features__, Feature, FeatureException from simpleapi.message import formatters, wrappers -from utils import glob_list +from .utils import glob_list __all__ = ('Route', ) + +if not hasattr(inspect, 'getargspec'): + ArgSpec = namedtuple('ArgSpec', 'args varargs keywords defaults') + + def _getargspec(func): + full = inspect.getfullargspec(func) + return ArgSpec(full.args, full.varargs, full.varkw, full.defaults) + + inspect.getargspec = _getargspec + class Route(object): def __new__(cls, *args, **kwargs): @@ -106,6 +117,7 @@ def __init__(self, *namespaces, **kwargs): self.nmap = {} self.debug = kwargs.pop('debug', False) self.ignore_unused_args = kwargs.pop('ignore_unused_args', False) + self.allow_custom_mimetype = kwargs.pop('allow_custom_mimetype', False) if self.debug and not has_debug: self.debug = False @@ -115,8 +127,8 @@ def __init__(self, *namespaces, **kwargs): self.framework = kwargs.pop('framework', 'django') self.path = re.compile(kwargs.pop('path', r'^/')) - assert len(kwargs) == 0, u'Unknown Route configuration(s) (%s)' % \ - ", ".join(kwargs.keys()) + assert len(kwargs) == 0, 'Unknown Route configuration(s) (%s)' % \ + ", ".join(list(kwargs.keys())) # make shortcut self._caller = self.__call__ @@ -124,7 +136,7 @@ def __init__(self, *namespaces, **kwargs): assert self.framework in FRAMEWORKS assert (self.debug ^ SIMPLEAPI_DEBUG) or \ not (self.debug and SIMPLEAPI_DEBUG), \ - u'You can either activate Route-debug or simpleapi-debug, not both.' + 'You can either activate Route-debug or simpleapi-debug, not both.' if self.debug or SIMPLEAPI_DEBUG: self.logger.setLevel(logging.DEBUG) @@ -146,9 +158,9 @@ def handle_request(self, environ, start_response): if not self.path.match(environ.get('PATH_INFO')): status = '404 Not found' start_response(status, []) - return ["Entry point not found"] + return [b"Entry point not found"] else: - content_type = environ.get('CONTENT_TYPE') + content_type = (environ.get('CONTENT_TYPE') or '').split(';', 1)[0].strip().lower() try: content_length = int(environ['CONTENT_LENGTH']) except (KeyError, ValueError): @@ -159,7 +171,7 @@ def handle_request(self, environ, start_response): if content_length > MAX_CONTENT_LENGTH: status = '413 Request entity too large' start_response(status, []) - return ["Request entity too large"] + return [b"Request entity too large"] request_method = environ.get('REQUEST_METHOD', '').lower() @@ -167,28 +179,28 @@ def handle_request(self, environ, start_response): if not request_method in TRIGGERED_METHODS: status = '501 Not Implemented' start_response(status, []) - return ["Not Implemented"] + return [b"Not Implemented"] - query_get = urlparse.parse_qs(environ.get('QUERY_STRING')) - for key, value in query_get.iteritems(): + query_get = urllib.parse.parse_qs( + environ.get('QUERY_STRING', ''), + keep_blank_values=True, + ) + for key, value in query_get.items(): query_get[key] = value[0] # respect the first value only query_post = {} if content_type in ['application/x-www-form-urlencoded', 'application/x-url-encoded']: - post_env = environ.copy() - post_env['QUERY_STRING'] = '' - fs = cgi.FieldStorage( - fp=environ['wsgi.input'], - environ=post_env, - keep_blank_values=True + post_body = environ['wsgi.input'].read(content_length) if content_length > 0 else b'' + post_data = urllib.parse.parse_qs( + post_body.decode('utf-8', errors='replace'), + keep_blank_values=True, ) - query_post = {} - for key in fs: - query_post[key] = fs.getvalue(key) + for key, value in post_data.items(): + query_post[key] = value[0] elif content_type == 'multipart/form-data': # XXX TODO - raise NotImplementedError, u'Currently not supported.' + raise NotImplementedError('Currently not supported.') # GET + POST query_data = query_get @@ -196,7 +208,7 @@ def handle_request(self, environ, start_response): # Make request request = StandaloneRequest() - request.method = request_method + request.method = request_method.upper() request.data = query_data request.remote_addr = environ.get('REMOTE_ADDR', '') @@ -204,17 +216,26 @@ def handle_request(self, environ, start_response): result = self._caller(request) status = '200 OK' + body = result['result'] + if isinstance(body, str): + body = body.encode('utf-8') + elif isinstance(body, bytes): + pass + elif body is None: + body = b'' + else: + body = str(body).encode('utf-8') headers = [('Content-type', result['mimetype'])] start_response(status, headers) - return [result['result'],] + return [body] def serve(self, host='', port=5050): httpd = make_server(host, port, self.handle_request) - self.logger.info(u"Started serving on port %d..." % port) + self.logger.info("Started serving on port %d..." % port) try: httpd.serve_forever() except KeyboardInterrupt: - self.logger.info(u"Server stopped.") + self.logger.info("Server stopped.") def profile_start(self): assert has_debug @@ -228,7 +249,7 @@ def profile_stop(self): def profile_stats(self): assert has_debug - self.logger.debug(u"Loading stats...") + self.logger.debug("Loading stats...") stats = pstats.Stats(SIMPLEAPI_DEBUG_FILENAME) stats.strip_dirs().sort_stats('time', 'calls') \ .print_stats() @@ -257,12 +278,12 @@ def _redefine_default_namespace(self): # - recalculate default namespace version - # if map has no default version, determine namespace with the # highest version - if self.nmap.has_key('default'): + if 'default' in self.nmap: del self.nmap['default'] self.nmap['default'] = self.nmap[max(self.nmap.keys())] def remove_namespace(self, version): - if self.nmap.has_key(version): + if version in self.nmap: del self.nmap[version] self._redefine_default_namespace() return True @@ -272,10 +293,10 @@ def remove_namespace(self, version): def add_namespace(self, namespace): version = getattr(namespace, '__version__', 1) assert isinstance(version, int), \ - u'version must be either an integer or not set' + 'version must be either an integer or not set' # make sure no version is assigned twice - assert not self.nmap.has_key(version), u'version is assigned twice' + assert version not in self.nmap, 'version is assigned twice' allowed_functions = [] @@ -284,17 +305,16 @@ def add_namespace(self, namespace): allowed_functions.append('introspect') # determine public and published functions - functions = filter(lambda item: '__' not in item[0] and item[0] not in + functions = [item for item in inspect.getmembers(namespace) if '__' not in item[0] and item[0] not in restricted_functions and ((getattr(item[1], 'published', False) == - True) or item[0] in allowed_functions), - inspect.getmembers(namespace)) + True) or item[0] in allowed_functions)] # determine arguments of each function functions = dict(functions) - for function_name, function_method in functions.iteritems(): + for function_name, function_method in functions.items(): # check for reserved function names assert function_name not in ['error', '__init__', 'get_name'],\ - u'Name %s is reserved.' % function_name + 'Name %s is reserved.' % function_name # ArgSpec(args=['self', 'a', 'b'], varargs=None, keywords=None, defaults=None) raw_args = inspect.getargspec(function_method) @@ -307,10 +327,10 @@ def add_namespace(self, namespace): # build a dict of optional arguments if raw_args[3] is not None: - default_args = zip( + default_args = list(zip( raw_args[0][-len(raw_args[3]):], raw_args[3] - ) + )) default_args = dict(default_args) else: default_args = {} @@ -333,7 +353,7 @@ def check(namespace, key, value): if constraint.match(value): return value else: - raise ValueError(u'%s does not match constraint') + raise ValueError('%s does not match constraint') else: if isinstance(constraint, bool): return bool(int(value)) @@ -386,7 +406,7 @@ def check(namespace, key, value): # configure authentication if hasattr(namespace, '__authentication__'): authentication = namespace.__authentication__ - if isinstance(authentication, basestring): + if isinstance(authentication, str): if hasattr(namespace, authentication): authentication = getattr(namespace, authentication) else: @@ -413,28 +433,42 @@ def check(namespace, key, value): # accept every ip address ip_restriction = lambda namespace, ip: True + allow_unsafe_pickle_input = bool( + getattr(namespace, '__allow_unsafe_pickle_input__', False) or + getattr(namespace, '__allow_unsafe_pickle__', False) + ) + # configure input formatters input_formatters = formatters.copy() - allowed_formatters = getattr(namespace, '__input__', - formatters.get_defaults()) - input_formatters = filter(lambda i: i[0] in allowed_formatters, - input_formatters.items()) + allowed_formatters = list(getattr(namespace, '__input__', + formatters.get_defaults())) + if 'pickle' in allowed_formatters and not allow_unsafe_pickle_input: + warnings.warn( + "Pickle input formatter is disabled by default for security reasons. " + "Set __allow_unsafe_pickle_input__ = True to re-enable it.", + RuntimeWarning, + ) + allowed_formatters = [fmt for fmt in allowed_formatters if fmt != 'pickle'] + if allow_unsafe_pickle_input and not hasattr(namespace, '__authentication__'): + raise AssertionError( + "Enabling unsafe pickle input requires explicit __authentication__." + ) + + input_formatters = [i for i in list(input_formatters.items()) if i[0] in allowed_formatters] input_formatters = dict(input_formatters) # configure output formatters output_formatters = formatters.copy() allowed_formatters = getattr(namespace, '__output__', formatters.get_defaults()) - output_formatters = filter(lambda i: i[0] in allowed_formatters, - output_formatters.items()) + output_formatters = [i for i in list(output_formatters.items()) if i[0] in allowed_formatters] output_formatters = dict(output_formatters) # configure wrappers useable_wrappers = wrappers.copy() if hasattr(namespace, '__wrapper__'): allowed_wrapper = namespace.__wrapper__ - useable_wrappers = filter(lambda i: i[0] in allowed_wrapper, - useable_wrappers.items()) + useable_wrappers = [i for i in list(useable_wrappers.items()) if i[0] in allowed_wrapper] useable_wrappers = dict(useable_wrappers) self.nmap[version] = { @@ -442,6 +476,7 @@ def check(namespace, key, value): 'functions': functions, 'ip_restriction': ip_restriction, 'authentication': authentication, + 'allow_unsafe_pickle_input': allow_unsafe_pickle_input, 'input_formatters': input_formatters, 'output_formatters': output_formatters, 'wrappers': useable_wrappers, @@ -452,12 +487,12 @@ def check(namespace, key, value): if hasattr(namespace, '__features__'): raw_features = namespace.__features__ for feature in raw_features: - assert isinstance(feature, basestring) or \ + assert isinstance(feature, str) or \ issubclass(feature, Feature) - if isinstance(feature, basestring): - assert feature in __features__.keys(), \ - u'%s is not a built-in feature' % feature + if isinstance(feature, str): + assert feature in list(__features__.keys()), \ + '%s is not a built-in feature' % feature features.append(__features__[feature](self.nmap[version])) elif issubclass(feature, Feature): @@ -470,7 +505,7 @@ def check(namespace, key, value): def __call__(self, http_request=None, **urlparameters): sapi_request = SAPIRequest(self, http_request) - request_items = dict(sapi_request.REQUEST.items()) + request_items = dict(list(sapi_request.REQUEST.items())) request_items.update(urlparameters) if SIMPLEAPI_DEBUG and SIMPLEAPI_DEBUG_LEVEL == 'call': @@ -496,11 +531,15 @@ def __call__(self, http_request=None, **urlparameters): wrapper_instance = None try: + if callback is not None and (not isinstance(callback, str) or not JSONP_CALLBACK_RE.match(callback)): + raise RequestException('Invalid JSONP callback name.') + if mimetype and not self.allow_custom_mimetype: + raise RequestException('_mimetype override is disabled.') try: version = int(version) except (ValueError, TypeError): pass - if not self.nmap.has_key(version): + if version not in self.nmap: # continue with wrong version to get the formatters/wrappers # raise the error later! namespace = self.nmap['default'] @@ -509,15 +548,16 @@ def __call__(self, http_request=None, **urlparameters): # check input formatter if input_formatter not in namespace['input_formatters']: - raise RequestException(u'Input formatter not allowed or ' \ + raise RequestException('Input formatter not allowed or ' \ 'unknown: %s' % input_formatter) # get input formatter - input_formatter_instancec = namespace['input_formatters'][input_formatter](sapi_request, callback) + input_formatter_instance = namespace['input_formatters'][input_formatter](sapi_request, callback) + input_formatter_instance.allow_unsafe_pickle_input = namespace['allow_unsafe_pickle_input'] # check output formatter if output_formatter not in namespace['output_formatters']: - raise RequestException(u'Output formatter not allowed or ' \ + raise RequestException('Output formatter not allowed or ' \ 'unknown: %s' % output_formatter) # get output formatter @@ -525,21 +565,21 @@ def __call__(self, http_request=None, **urlparameters): # check wrapper if wrapper not in namespace['wrappers']: - raise RequestException(u'Wrapper unknown or not allowed: %s' % \ + raise RequestException('Wrapper unknown or not allowed: %s' % \ wrapper) # get wrapper wrapper_instance = namespace['wrappers'][wrapper] # check whether version exists or not - if not self.nmap.has_key(version): - raise RouterException(u'Version %s not found (possible: %s)' % \ - (version, ", ".join(map(lambda i: str(i), self.nmap.keys())))) + if version not in self.nmap: + raise RouterException('Version %s not found (possible: %s)' % \ + (version, ", ".join([str(i) for i in list(self.nmap.keys())]))) request = Request( sapi_request=sapi_request, namespace=namespace, - input_formatter=input_formatter_instancec, + input_formatter=input_formatter_instance, output_formatter=output_formatter_instance, wrapper=wrapper_instance, callback=callback, @@ -566,10 +606,10 @@ def __call__(self, http_request=None, **urlparameters): try: responses.append(request.process_request(request_item)) except (NamespaceException, RequestException, \ - ResponseException, RouterException, FeatureException),e: + ResponseException, RouterException, FeatureException) as e: response = Response( sapi_request, - errors=e.message, + errors=str(e), output_formatter=output_formatter_instance, wrapper=wrapper_instance, mimetype=mimetype @@ -581,28 +621,28 @@ def __call__(self, http_request=None, **urlparameters): responses=responses, ) http_response = rm.build() - except Exception, e: + except Exception as e: if isinstance(e, (NamespaceException, RequestException, \ ResponseException, RouterException, \ FeatureException)): err_msg = repr(e) else: - err_msg = u'An internal error occurred during your request.' + err_msg = 'An internal error occurred during your request.' trace = inspect.trace() msgs = [] msgs.append('') - msgs.append(u"******* Exception raised *******") - msgs.append(u'Exception type: %s' % type(e)) - msgs.append(u'Exception msg: %s' % repr(e)) + msgs.append("******* Exception raised *******") + msgs.append('Exception type: %s' % type(e)) + msgs.append('Exception msg: %s' % repr(e)) msgs.append('') - msgs.append(u'------- Traceback follows -------') + msgs.append('------- Traceback follows -------') for idx, item in enumerate(trace): - msgs.append(u"(%s)\t%s:%s (%s)" % + msgs.append("(%s)\t%s:%s (%s)" % (idx+1, item[3], item[2], item[1])) if item[4]: for line in item[4]: - msgs.append(u"\t\t%s" % line.strip()) + msgs.append("\t\t%s" % line.strip()) msgs.append('') # blank line msgs.append(' -- End of traceback -- ') msgs.append('') @@ -625,4 +665,4 @@ def __call__(self, http_request=None, **urlparameters): self.profile_stop() self.profile_stats() - return http_response \ No newline at end of file + return http_response diff --git a/simpleapi/server/routemgr.py b/simpleapi/server/routemgr.py index 95d10a5..c3f8dca 100644 --- a/simpleapi/server/routemgr.py +++ b/simpleapi/server/routemgr.py @@ -14,6 +14,6 @@ def __init__(self, *routes): def __call__(self, *args, **kwargs): route_name = kwargs.pop('name') - if self.routes.has_key(route_name): + if route_name in self.routes: return self.routes[route_name](*args, **kwargs) raise RouteNotFound(route_name) \ No newline at end of file diff --git a/simpleapi/server/sapirequest.py b/simpleapi/server/sapirequest.py index bec6f83..c31993c 100644 --- a/simpleapi/server/sapirequest.py +++ b/simpleapi/server/sapirequest.py @@ -6,7 +6,7 @@ except ImportError: has_flask = False -from session import Session +from .session import Session __all__ = ('SAPIRequest', ) @@ -23,7 +23,7 @@ def __init__(self, route, request=None): elif route.is_appengine(): request = route.request else: - raise ValueError(u'HttpRequest-object is missing') + raise ValueError('HttpRequest-object is missing') self.request = request @@ -56,10 +56,16 @@ def REQUEST(self): if self.route.is_flask(): return self.request.form or self.request.args elif self.route.is_django(): - return self.request.REQUEST + if hasattr(self.request, 'REQUEST'): + return self.request.REQUEST + request_data = {} + if hasattr(self.request, 'GET'): + request_data.update(dict(self.request.GET)) + if hasattr(self.request, 'POST'): + request_data.update(dict(self.request.POST)) + return request_data elif self.route.is_appengine(): - return dict(map(lambda i: (i, self.request.get(i)), \ - self.request.arguments())) + return dict([(i, self.request.get(i)) for i in self.request.arguments()]) elif self.route.is_dummy() or self.route.is_standalone(): return self.request.data raise NotImplementedError diff --git a/simpleapi/server/serializer.py b/simpleapi/server/serializer.py index da5ba7c..85411b9 100644 --- a/simpleapi/server/serializer.py +++ b/simpleapi/server/serializer.py @@ -5,7 +5,7 @@ from django.db.models.query import QuerySet from django.utils.encoding import smart_unicode, is_protected_type has_django = True -except Exception, e: +except Exception as e: has_django = False try: diff --git a/simpleapi/server/session.py b/simpleapi/server/session.py index b4445a2..ed10f62 100644 --- a/simpleapi/server/session.py +++ b/simpleapi/server/session.py @@ -9,6 +9,6 @@ def __init__(self): self._internal = SessionObj() def clear(self): - for key in self.__dict__.keys(): + for key in list(self.__dict__.keys()): if not key.startswith('_'): del self.__dict__[key] \ No newline at end of file diff --git a/tests/namespace.py b/tests/namespace.py index a50a4c5..ef0c68d 100644 --- a/tests/namespace.py +++ b/tests/namespace.py @@ -11,10 +11,18 @@ def setUp(self): class TestNamespace(Namespace): pass - self.namespace = TestNamespace() + class DummyRequest(object): + def __init__(self): + self.session = object() + + self.namespace = TestNamespace(DummyRequest()) + + def test_namespace_init(self): + self.assertIsNotNone(self.namespace.request) + self.assertIsNotNone(self.namespace.session) def tearDown(self): pass if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/route.py b/tests/route.py index 43e1216..6f4d0f8 100644 --- a/tests/route.py +++ b/tests/route.py @@ -7,25 +7,26 @@ import json except ImportError: import simplejson as json -import cPickle +import pickle try: import django except ImportError: - raise Exception, 'Django is required for running the tests' + django = None import os os.environ['DJANGO_SETTINGS_MODULE'] = 'tests.settings' from simpleapi import * +@unittest.skipUnless(django is not None, "Django is required for legacy route tests") class RouteTest(unittest.TestCase): _value_simple = 5592.61 # in JSON key values of dict items must be of type string _value_complex = { - 'test1': u'test äöüß', + 'test1': 'test äöüß', 'test2': 592, 'test3': 1895.29596, 'test4': { @@ -36,7 +37,7 @@ class RouteTest(unittest.TestCase): '6': True, 'test7': False, 'test8': [True, 0, False, 1], - u'täst9': 9 + 'täst9': 9 } def setUp(self): @@ -125,11 +126,11 @@ class Request(object): for transporttype in transporttypes: # encode query parameters local_kwargs = copy.deepcopy(kwargs) - for key, value in local_kwargs.iteritems(): + for key, value in local_kwargs.items(): if transporttype == 'json': local_kwargs[key] = json.dumps(value) elif transporttype == 'pickle': - local_kwargs[key] = cPickle.dumps(value) + local_kwargs[key] = pickle.dumps(value) request.REQUEST.update(local_kwargs) @@ -142,14 +143,14 @@ class Request(object): if transporttype == 'json': response = json.loads(http_response.content) elif transporttype == 'pickle': - response = cPickle.loads(http_response.content) + response = pickle.loads(http_response.content) else: - self.fail(u'unknown transport type: %s' % transporttype) + self.fail('unknown transport type: %s' % transporttype) if not first_response: first_response = response else: - self.failUnlessEqual(response, first_response) + self.assertEqual(response, first_response) return ( response.get('success'), @@ -160,17 +161,17 @@ class Request(object): def test_published(self): # test: published-flag success, errors, result = self.call(self.route1, 'non_published') - self.failIf(success) + self.assertFalse(success) success, errors, result = self.call(self.route1, 'return_value', val=self._value_complex) - self.failUnless(success) - self.failUnlessEqual(result, self._value_complex) + self.assertTrue(success) + self.assertEqual(result, self._value_complex) def test_data(self): # test: _data success, errors, result = self.call(self.route1, 'power', _data={'a': 3, 'b': 10}) - self.failUnlessEqual(result, 59049) + self.assertEqual(result, 59049) def test_authentication(self): # test: __authentication__ @@ -181,8 +182,8 @@ def test_authentication(self): method='power', version='2' ) - self.failIf(success) - self.failUnlessEqual(u'Authentication failed.', errors[0]) + self.assertFalse(success) + self.assertEqual('Authentication failed.', errors[0]) success, errors, result = self.call( route=self.route2, @@ -192,7 +193,7 @@ def test_authentication(self): a=1, b=2 ) - self.failUnless(success) + self.assertTrue(success) # __authentication__ == lambda namespace, access_key: access_key == 'a' * 5 success, errors, result = self.call( @@ -200,8 +201,8 @@ def test_authentication(self): method='power', version='3' ) - self.failIf(success) - self.failUnlessEqual(u'Authentication failed.', errors[0]) + self.assertFalse(success) + self.assertEqual('Authentication failed.', errors[0]) success, errors, result = self.call( route=self.route2, @@ -211,7 +212,7 @@ def test_authentication(self): a=1, b=2 ) - self.failUnless(success) + self.assertTrue(success) def test_kwargs(self): # test: kwargs @@ -223,8 +224,8 @@ def test_kwargs(self): c=99, e=100 ) - self.failUnless(success) - self.failUnlessEqual(result, 211) + self.assertTrue(success) + self.assertEqual(result, 211) success, errors, result = self.call( route=self.route1, @@ -233,8 +234,8 @@ def test_kwargs(self): b=7, c=99 ) - self.failIf(success) - self.failUnlessEqual(u'Unused arguments: c', errors[0]) + self.assertFalse(success) + self.assertEqual('Unused arguments: c', errors[0]) def test_default_args(self): success, errors, result = self.call( @@ -245,8 +246,8 @@ def test_default_args(self): c='99', e='100' ) - self.failUnless(success) - self.failUnlessEqual(result, 199) + self.assertTrue(success) + self.assertEqual(result, 199) def test_constraints(self): # test: constraints[phone_number] = regular expression @@ -255,16 +256,16 @@ def test_constraints(self): method='call_phone', phone_number='0176123456' ) - self.failIf(success) - self.failUnlessEqual(u'Constraint failed for argument: phone_number', errors[0]) + self.assertFalse(success) + self.assertEqual('Constraint failed for argument: phone_number', errors[0]) success, errors, result = self.call( route=self.route1, method='call_phone', phone_number='+49 176 123456' ) - self.failUnless(success) - self.failUnlessEqual(result, True) + self.assertTrue(success) + self.assertEqual(result, True) # test: constraints = lambda namespace, key, value: int(value) success, errors, result = self.call( @@ -275,8 +276,8 @@ def test_constraints(self): c=99, e=100 ) - self.failIf(success) - self.failUnlessEqual(u'Constraint failed for argument: a', errors[0]) + self.assertFalse(success) + self.assertEqual('Constraint failed for argument: a', errors[0]) # test: type conversion success, errors, result = self.call( @@ -287,8 +288,8 @@ def test_constraints(self): c='99', e='1020' ) - self.failUnless(success) - self.failUnlessEqual(result, 1133) + self.assertTrue(success) + self.assertEqual(result, 1133) # test: constraints[a] = lambda value: float(value), b = int success, errors, result = self.call( @@ -297,8 +298,8 @@ def test_constraints(self): a='19.95', b='4' ) - self.failUnless(success) - self.failUnlessEqual(result, 158405.99000624998) + self.assertTrue(success) + self.assertEqual(result, 158405.99000624998) success, errors, result = self.call( route=self.route1, @@ -306,8 +307,8 @@ def test_constraints(self): a='19.95', b='4.5' ) - self.failIf(success) - self.failUnlessEqual(u'Constraint failed for argument: b', errors[0]) + self.assertFalse(success) + self.assertEqual('Constraint failed for argument: b', errors[0]) def test_versions(self): @@ -317,37 +318,37 @@ def test_versions(self): method='power', version='3' ) - self.failIf(success) - self.failUnless(u'Version 3 not found' in errors[0]) + self.assertFalse(success) + self.assertTrue('Version 3 not found' in errors[0]) success, errors, result = self.call( route=self.route2, method='get_version', version='1' ) - self.failUnless(success) - self.failUnlessEqual(result, 1) + self.assertTrue(success) + self.assertEqual(result, 1) success, errors, result = self.call( route=self.route2, method='get_version', version='4' ) - self.failUnless(success) - self.failUnlessEqual(result, 4) + self.assertTrue(success) + self.assertEqual(result, 4) success, errors, result = self.call( route=self.route2, method='get_version', version='default' ) - self.failUnless(success) - self.failUnlessEqual(result, 4) + self.assertTrue(success) + self.assertEqual(result, 4) # add new namespace with same version class TestNamespace(Namespace): __version__ = 4 - self.failUnlessRaises(AssertionError, lambda: self.route2.add_namespace(TestNamespace)) + self.assertRaises(AssertionError, lambda: self.route2.add_namespace(TestNamespace)) # add new namespace with new version class TestNamespace(Namespace): @@ -356,7 +357,7 @@ def get_version(self): return self.__version__ get_version.published = True - self.failUnlessEqual(self.route2.add_namespace(TestNamespace), 999) + self.assertEqual(self.route2.add_namespace(TestNamespace), 999) success, errors, result = self.call( route=self.route2, @@ -364,20 +365,20 @@ def get_version(self): version='default', transporttypes=['json',] ) - self.failUnless(success) - self.failUnlessEqual(result, 999) + self.assertTrue(success) + self.assertEqual(result, 999) # remove added namespace again - self.failUnless(self.route2.remove_namespace(999)) - self.failIf(self.route2.remove_namespace(999)) + self.assertTrue(self.route2.remove_namespace(999)) + self.assertFalse(self.route2.remove_namespace(999)) success, errors, result = self.call( route=self.route2, method='get_version', version='default' ) - self.failUnless(success) - self.failUnlessEqual(result, 4) + self.assertTrue(success) + self.assertEqual(result, 4) def test_pickle(self): # test: pickle @@ -388,8 +389,8 @@ def return_val(self, val): return val return_val.published = True self.route3 = Route(Test) - self.failUnlessRaises( - cPickle.UnpicklingError, + self.assertRaises( + pickle.UnpicklingError, lambda: self.call(route=self.route3, method='return_val') ) del self.route3 @@ -401,4 +402,4 @@ def tearDown(self): pass if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 0000000..1e4fd05 --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- + +import json +import pickle +import unittest + +from simpleapi import Namespace, Route + + +class DummyRequest(object): + def __init__(self, data, method='POST', remote_addr='127.0.0.1'): + self.data = data + self.method = method + self.remote_addr = remote_addr + + +def decode_response(http_response): + return json.loads(http_response['result']) + + +class SecurityHardeningTest(unittest.TestCase): + def test_pickle_input_is_disabled_by_default(self): + class NS(Namespace): + __input__ = ['json', 'pickle'] + __output__ = ['json'] + + def echo(self, val): + return val + echo.published = True + + route = Route(NS, framework='dummy') + payload = pickle.dumps("owned") + request = DummyRequest({ + '_call': 'echo', + '_input': 'pickle', + '_output': 'json', + 'val': payload, + }) + response = decode_response(route(request)) + self.assertFalse(response['success']) + self.assertIn('Input formatter not allowed or unknown: pickle', response['errors'][0]) + + def test_pickle_input_requires_explicit_opt_in_and_auth(self): + class NS(Namespace): + __input__ = ['json', 'pickle'] + __output__ = ['json'] + __allow_unsafe_pickle_input__ = True + __authentication__ = 'secret' + + def echo(self, val): + return val + echo.published = True + + route = Route(NS, framework='dummy') + payload = pickle.dumps("ok") + request = DummyRequest({ + '_call': 'echo', + '_input': 'pickle', + '_output': 'json', + '_access_key': 'secret', + 'val': payload, + }) + response = decode_response(route(request)) + self.assertTrue(response['success']) + self.assertEqual(response['result'], "ok") + + def test_invalid_jsonp_callback_is_rejected(self): + class NS(Namespace): + def ping(self): + return {'ok': True} + ping.published = True + + route = Route(NS, framework='dummy') + request = DummyRequest({ + '_call': 'ping', + '_output': 'jsonp', + '_callback': 'x);alert(1)//', + }) + response = decode_response(route(request)) + self.assertFalse(response['success']) + self.assertIn('Invalid JSONP callback name.', response['errors'][0]) + + def test_custom_mimetype_override_is_blocked_by_default(self): + class NS(Namespace): + def ping(self): + return {'ok': True} + ping.published = True + + route = Route(NS, framework='dummy') + request = DummyRequest({ + '_call': 'ping', + '_output': 'json', + '_mimetype': 'text/plain', + }) + response = decode_response(route(request)) + self.assertFalse(response['success']) + self.assertIn('_mimetype override is disabled.', response['errors'][0]) + + def test_json_flow_still_works(self): + class NS(Namespace): + def add(self, a, b): + return a + b + add.published = True + add.constraints = {'a': int, 'b': int} + + route = Route(NS, framework='dummy') + request = DummyRequest({ + '_call': 'add', + '_input': 'json', + '_output': 'json', + 'a': json.dumps(2), + 'b': json.dumps(3), + }) + response = decode_response(route(request)) + self.assertTrue(response['success']) + self.assertEqual(response['result'], 5) + + +if __name__ == '__main__': + unittest.main()