From cca7559f80d0e32330e211a3d55ae6d573782f03 Mon Sep 17 00:00:00 2001 From: Sourcery AI Date: Fri, 14 Apr 2023 21:44:47 +0000 Subject: [PATCH] 'Refactored by Sourcery' --- BeautifulSoup.py | 1138 ++++++++++++++---------------- CustomCookie.py | 47 +- appengine_utilities/cache.py | 37 +- appengine_utilities/cron.py | 129 ++-- appengine_utilities/event.py | 11 +- appengine_utilities/paginator.py | 20 +- appengine_utilities/rotmodel.py | 3 +- appengine_utilities/sessions.py | 89 +-- chardet/chardistribution.py | 26 +- chardet/charsetgroupprober.py | 11 +- chardet/escprober.py | 5 +- chardet/eucjpprober.py | 21 +- chardet/hebrewprober.py | 16 +- chardet/jpcntx.py | 28 +- chardet/latin1prober.py | 10 +- chardet/mbcharsetprober.py | 21 +- chardet/sbcharsetprober.py | 24 +- chardet/sjisprober.py | 21 +- chardet/utf8prober.py | 17 +- feedparser.py | 247 +++---- gae_utils.py | 64 +- jinja2/bccache.py | 10 +- jinja2/compiler.py | 156 ++-- jinja2/environment.py | 48 +- jinja2/exceptions.py | 9 +- jinja2/ext.py | 10 +- jinja2/filters.py | 23 +- jinja2/lexer.py | 33 +- jinja2/nodes.py | 40 +- jinja2/optimizer.py | 5 +- jinja2/parser.py | 29 +- jinja2/runtime.py | 12 +- jinja2/sandbox.py | 32 +- jinja2/utils.py | 48 +- jinja2/visitor.py | 2 +- readability/__init__.py | 2 +- readability/hn.py | 92 ++- urlgrabber/byterange.py | 17 +- urlgrabber/grabber.py | 83 +-- urlgrabber/keepalive.py | 39 +- urlgrabber/mirror.py | 13 +- urlgrabber/progress.py | 78 +- urlgrabber/sslfactory.py | 21 +- web/application.py | 51 +- web/browser.py | 10 +- web/contrib/template.py | 16 +- web/db.py | 127 ++-- web/debugerror.py | 17 +- web/form.py | 72 +- web/session.py | 20 +- web/template.py | 197 +++--- web/webapi.py | 46 +- web/webopenid.py | 10 +- web/wsgi.py | 25 +- web/wsgiserver/__init__.py | 144 ++-- 55 files changed, 1583 insertions(+), 1939 deletions(-) diff --git a/BeautifulSoup.py b/BeautifulSoup.py index 0e21463..0551f44 100755 --- a/BeautifulSoup.py +++ b/BeautifulSoup.py @@ -173,64 +173,61 @@ def _lastRecursiveChild(self): return lastChild def insert(self, position, newChild): - if (isinstance(newChild, basestring) - or isinstance(newChild, unicode)) \ - and not isinstance(newChild, NavigableString): - newChild = NavigableString(newChild) - - position = min(position, len(self.contents)) - if hasattr(newChild, 'parent') and newChild.parent != None: - # We're 'inserting' an element that's already one - # of this object's children. - if newChild.parent == self: - index = self.find(newChild) - if index and index < position: - # Furthermore we're moving it further down the - # list of this object's children. That means that - # when we extract this element, our target index - # will jump down one. - position = position - 1 - newChild.extract() - - newChild.parent = self - previousChild = None - if position == 0: - newChild.previousSibling = None - newChild.previous = self - else: - previousChild = self.contents[position-1] - newChild.previousSibling = previousChild - newChild.previousSibling.nextSibling = newChild - newChild.previous = previousChild._lastRecursiveChild() - if newChild.previous: - newChild.previous.next = newChild - - newChildsLastElement = newChild._lastRecursiveChild() - - if position >= len(self.contents): - newChild.nextSibling = None - - parent = self - parentsNextSibling = None - while not parentsNextSibling: - parentsNextSibling = parent.nextSibling - parent = parent.parent - if not parent: # This is the last element in the document. - break - if parentsNextSibling: - newChildsLastElement.next = parentsNextSibling - else: - newChildsLastElement.next = None - else: - nextChild = self.contents[position] - newChild.nextSibling = nextChild - if newChild.nextSibling: - newChild.nextSibling.previousSibling = newChild - newChildsLastElement.next = nextChild - - if newChildsLastElement.next: - newChildsLastElement.next.previous = newChildsLastElement - self.contents.insert(position, newChild) + if (isinstance( + newChild, + (basestring, unicode))) and not isinstance(newChild, NavigableString): + newChild = NavigableString(newChild) + + position = min(position, len(self.contents)) + if hasattr(newChild, 'parent') and newChild.parent != None: + # We're 'inserting' an element that's already one + # of this object's children. + if newChild.parent == self: + index = self.find(newChild) + if index and index < position: + # Furthermore we're moving it further down the + # list of this object's children. That means that + # when we extract this element, our target index + # will jump down one. + position = position - 1 + newChild.extract() + + newChild.parent = self + previousChild = None + if position == 0: + newChild.previousSibling = None + newChild.previous = self + else: + previousChild = self.contents[position-1] + newChild.previousSibling = previousChild + newChild.previousSibling.nextSibling = newChild + newChild.previous = previousChild._lastRecursiveChild() + if newChild.previous: + newChild.previous.next = newChild + + newChildsLastElement = newChild._lastRecursiveChild() + + if position >= len(self.contents): + newChild.nextSibling = None + + parent = self + parentsNextSibling = None + while not parentsNextSibling: + parentsNextSibling = parent.nextSibling + parent = parent.parent + if not parent: # This is the last element in the document. + break + newChildsLastElement.next = parentsNextSibling if parentsNextSibling else None + else: + nextChild = self.contents[position] + newChild.nextSibling = nextChild + if newChild.nextSibling: + newChild.nextSibling.previousSibling = newChild + newChildsLastElement.next = nextChild + + if newChildsLastElement.next: + newChildsLastElement.next.previous = newChildsLastElement + self.contents.insert(position, newChild) def append(self, tag): """Appends the given tag to the contents of this tag.""" @@ -290,15 +287,9 @@ def findPreviousSiblings(self, name=None, attrs={}, text=None, fetchPreviousSiblings = findPreviousSiblings # Compatibility with pre-3.x def findParent(self, name=None, attrs={}, **kwargs): - """Returns the closest parent of this Tag that matches the given + """Returns the closest parent of this Tag that matches the given criteria.""" - # NOTE: We can't use _findOne because findParents takes a different - # set of arguments. - r = None - l = self.findParents(name, attrs, 1) - if l: - r = l[0] - return r + return l[0] if (l := self.findParents(name, attrs, 1)) else None def findParents(self, name=None, attrs={}, limit=None, **kwargs): """Returns the parents of this Tag that match the given @@ -311,34 +302,29 @@ def findParents(self, name=None, attrs={}, limit=None, **kwargs): #These methods do the real heavy lifting. def _findOne(self, method, name, attrs, text, **kwargs): - r = None - l = method(name, attrs, text, 1, **kwargs) - if l: - r = l[0] - return r + return l[0] if (l := method(name, attrs, text, 1, **kwargs)) else None def _findAll(self, name, attrs, text, limit, generator, **kwargs): - "Iterates over a generator looking for things that match." - - if isinstance(name, SoupStrainer): - strainer = name - else: - # Build a SoupStrainer - strainer = SoupStrainer(name, attrs, text, **kwargs) - results = ResultSet(strainer) - g = generator() - while True: - try: - i = g.next() - except StopIteration: + "Iterates over a generator looking for things that match." + + if isinstance(name, SoupStrainer): + strainer = name + else: + # Build a SoupStrainer + strainer = SoupStrainer(name, attrs, text, **kwargs) + results = ResultSet(strainer) + g = generator() + while True: + try: + i = g.next() + except StopIteration: + break + if i: + if found := strainer.search(i): + results.append(found) + if limit and len(results) >= limit: break - if i: - found = strainer.search(i) - if found: - results.append(found) - if limit and len(results) >= limit: - break - return results + return results #These Generators can be used to navigate starting from both #NavigableStrings and Tags. @@ -378,22 +364,18 @@ def substituteEncoding(self, str, encoding=None): return str.replace("%SOUP-ENCODING%", encoding) def toEncoding(self, s, encoding=None): - """Encodes an object to a string in some encoding, or to Unicode. + """Encodes an object to a string in some encoding, or to Unicode. .""" - if isinstance(s, unicode): - if encoding: - s = s.encode(encoding) - elif isinstance(s, str): - if encoding: - s = s.encode(encoding) - else: - s = unicode(s) - else: - if encoding: - s = self.toEncoding(str(s), encoding) - else: - s = unicode(s) - return s + if (isinstance(s, unicode) and encoding + or not isinstance(s, unicode) and isinstance(s, str) and encoding): + s = s.encode(encoding) + elif isinstance(s, unicode): + pass + elif isinstance(s, str) or not encoding: + s = unicode(s) + else: + s = self.toEncoding(str(s), encoding) + return s class NavigableString(unicode, PageElement): @@ -413,53 +395,50 @@ def __getnewargs__(self): return (NavigableString.__str__(self),) def __getattr__(self, attr): - """text.string gives you text. This is for backwards + """text.string gives you text. This is for backwards compatibility for Navigable*String, but for CData* it lets you get the string without the CData wrapper.""" - if attr == 'string': - return self - else: - raise AttributeError, "'%s' object has no attribute '%s'" % (self.__class__.__name__, attr) + if attr == 'string': + return self + else: + raise ( + AttributeError, + f"'{self.__class__.__name__}' object has no attribute '{attr}'", + ) def __unicode__(self): return str(self).decode(DEFAULT_OUTPUT_ENCODING) def __str__(self, encoding=DEFAULT_OUTPUT_ENCODING): - if encoding: - return self.encode(encoding) - else: - return self + return self.encode(encoding) if encoding else self class CData(NavigableString): def __str__(self, encoding=DEFAULT_OUTPUT_ENCODING): - return "" % NavigableString.__str__(self, encoding) + return f"" class ProcessingInstruction(NavigableString): def __str__(self, encoding=DEFAULT_OUTPUT_ENCODING): - output = self - if "%SOUP-ENCODING%" in output: - output = self.substituteEncoding(output, encoding) - return "" % self.toEncoding(output, encoding) + output = self + if "%SOUP-ENCODING%" in output: + output = self.substituteEncoding(output, encoding) + return f"" class Comment(NavigableString): def __str__(self, encoding=DEFAULT_OUTPUT_ENCODING): - return "" % NavigableString.__str__(self, encoding) + return f"" class Declaration(NavigableString): def __str__(self, encoding=DEFAULT_OUTPUT_ENCODING): - return "" % NavigableString.__str__(self, encoding) + return f"" class Tag(PageElement): """Represents a found HTML tag with its attributes and contents.""" - def _invert(h): - "Cheap function to invert a hash." - i = {} - for k,v in h.items(): - i[v] = k - return i + def _invert(self): + "Cheap function to invert a hash." + return {v: k for k, v in self.items()} XML_ENTITIES_TO_SPECIAL_CHARS = { "apos" : "'", "quot" : '"', @@ -470,29 +449,29 @@ def _invert(h): XML_SPECIAL_CHARS_TO_ENTITIES = _invert(XML_ENTITIES_TO_SPECIAL_CHARS) def _convertEntities(self, match): - """Used in a call to re.sub to replace HTML, XML, and numeric + """Used in a call to re.sub to replace HTML, XML, and numeric entities with the appropriate Unicode characters. If HTML entities are being converted, any unrecognized entities are escaped.""" - x = match.group(1) - if self.convertHTMLEntities and x in name2codepoint: - return unichr(name2codepoint[x]) - elif x in self.XML_ENTITIES_TO_SPECIAL_CHARS: - if self.convertXMLEntities: - return self.XML_ENTITIES_TO_SPECIAL_CHARS[x] - else: - return u'&%s;' % x - elif len(x) > 0 and x[0] == '#': - # Handle numeric entities - if len(x) > 1 and x[1] == 'x': - return unichr(int(x[2:], 16)) - else: - return unichr(int(x[1:])) - - elif self.escapeUnrecognizedEntities: - return u'&%s;' % x + x = match.group(1) + if self.convertHTMLEntities and x in name2codepoint: + return unichr(name2codepoint[x]) + elif x in self.XML_ENTITIES_TO_SPECIAL_CHARS: + if self.convertXMLEntities: + return self.XML_ENTITIES_TO_SPECIAL_CHARS[x] else: - return u'&%s;' % x + return f'&{x};' + elif len(x) > 0 and x[0] == '#': + # Handle numeric entities + if len(x) > 1 and x[1] == 'x': + return unichr(int(x[2:], 16)) + else: + return unichr(int(x[1:])) + + elif self.escapeUnrecognizedEntities: + return f'&{x};' + else: + return f'&{x};' def __init__(self, parser, name, attrs=None, parent=None, previous=None): @@ -551,18 +530,18 @@ def __nonzero__(self): return True def __setitem__(self, key, value): - """Setting tag[key] sets the value of the 'key' attribute for the + """Setting tag[key] sets the value of the 'key' attribute for the tag.""" - self._getAttrMap() - self.attrMap[key] = value - found = False - for i in range(0, len(self.attrs)): - if self.attrs[i][0] == key: - self.attrs[i] = (key, value) - found = True - if not found: - self.attrs.append((key, value)) - self._getAttrMap()[key] = value + self._getAttrMap() + self.attrMap[key] = value + found = False + for i in range(len(self.attrs)): + if self.attrs[i][0] == key: + self.attrs[i] = (key, value) + found = True + if not found: + self.attrs.append((key, value)) + self._getAttrMap()[key] = value def __delitem__(self, key): "Deleting tag[key] deletes all 'key' attributes for the tag." @@ -582,25 +561,23 @@ def __call__(self, *args, **kwargs): return apply(self.findAll, args, kwargs) def __getattr__(self, tag): - #print "Getattr %s.%s" % (self.__class__, tag) - if len(tag) > 3 and tag.rfind('Tag') == len(tag)-3: - return self.find(tag[:-3]) - elif tag.find('__') != 0: - return self.find(tag) - raise AttributeError, "'%s' object has no attribute '%s'" % (self.__class__, tag) + #print "Getattr %s.%s" % (self.__class__, tag) + if len(tag) > 3 and tag.rfind('Tag') == len(tag)-3: + return self.find(tag[:-3]) + elif tag.find('__') != 0: + return self.find(tag) + raise (AttributeError, f"'{self.__class__}' object has no attribute '{tag}'") def __eq__(self, other): - """Returns true iff this tag has the same name, the same attributes, + """Returns true iff this tag has the same name, the same attributes, and the same contents (recursively) as the given tag. NOTE: right now this will return false if two tags have the same attributes in a different order. Should this be fixed?""" - if not hasattr(other, 'name') or not hasattr(other, 'attrs') or not hasattr(other, 'contents') or self.name != other.name or self.attrs != other.attrs or len(self) != len(other): - return False - for i in range(0, len(self.contents)): - if self.contents[i] != other.contents[i]: - return False - return True + if not hasattr(other, 'name') or not hasattr(other, 'attrs') or not hasattr(other, 'contents') or self.name != other.name or self.attrs != other.attrs or len(self) != len(other): + return False + return all(self.contents[i] == other.contents[i] + for i in range(len(self.contents))) def __ne__(self, other): """Returns true iff this tag is not identical to the other tag, @@ -619,104 +596,101 @@ def __unicode__(self): + ")") def _sub_entity(self, x): - """Used with a regular expression to substitute the + """Used with a regular expression to substitute the appropriate XML entity for an XML special character.""" - return "&" + self.XML_SPECIAL_CHARS_TO_ENTITIES[x.group(0)[0]] + ";" + return f"&{self.XML_SPECIAL_CHARS_TO_ENTITIES[x.group(0)[0]]};" def __str__(self, encoding=DEFAULT_OUTPUT_ENCODING, prettyPrint=False, indentLevel=0): - """Returns a string or Unicode representation of this tag and + """Returns a string or Unicode representation of this tag and its contents. To get Unicode, pass None for encoding. NOTE: since Python's HTML parser consumes whitespace, this method is not certain to reproduce the whitespace present in the original string.""" - encodedName = self.toEncoding(self.name, encoding) - - attrs = [] - if self.attrs: - for key, val in self.attrs: - fmt = '%s="%s"' - if isString(val): - if self.containsSubstitutions and '%SOUP-ENCODING%' in val: - val = self.substituteEncoding(val, encoding) - - # The attribute value either: - # - # * Contains no embedded double quotes or single quotes. - # No problem: we enclose it in double quotes. - # * Contains embedded single quotes. No problem: - # double quotes work here too. - # * Contains embedded double quotes. No problem: - # we enclose it in single quotes. - # * Embeds both single _and_ double quotes. This - # can't happen naturally, but it can happen if - # you modify an attribute value after parsing - # the document. Now we have a bit of a - # problem. We solve it by enclosing the - # attribute in single quotes, and escaping any - # embedded single quotes to XML entities. - if '"' in val: - fmt = "%s='%s'" - if "'" in val: - # TODO: replace with apos when - # appropriate. - val = val.replace("'", "&squot;") - - # Now we're okay w/r/t quotes. But the attribute - # value might also contain angle brackets, or - # ampersands that aren't part of entities. We need - # to escape those to XML entities too. - val = self.BARE_AMPERSAND_OR_BRACKET.sub(self._sub_entity, val) - - attrs.append(fmt % (self.toEncoding(key, encoding), - self.toEncoding(val, encoding))) - close = '' - closeTag = '' - if self.isSelfClosing: - close = ' /' - else: - closeTag = '' % encodedName - - indentTag, indentContents = 0, 0 - if prettyPrint: - indentTag = indentLevel - space = (' ' * (indentTag-1)) - indentContents = indentTag + 1 - contents = self.renderContents(encoding, prettyPrint, indentContents) - if self.hidden: - s = contents - else: - s = [] - attributeString = '' - if attrs: - attributeString = ' ' + ' '.join(attrs) - if prettyPrint: - s.append(space) - s.append('<%s%s%s>' % (encodedName, attributeString, close)) - if prettyPrint: - s.append("\n") - s.append(contents) - if prettyPrint and contents and contents[-1] != "\n": - s.append("\n") - if prettyPrint and closeTag: - s.append(space) - s.append(closeTag) - if prettyPrint and closeTag and self.nextSibling: - s.append("\n") - s = ''.join(s) - return s + encodedName = self.toEncoding(self.name, encoding) + + attrs = [] + if self.attrs: + for key, val in self.attrs: + fmt = '%s="%s"' + if isString(val): + if self.containsSubstitutions and '%SOUP-ENCODING%' in val: + val = self.substituteEncoding(val, encoding) + + # The attribute value either: + # + # * Contains no embedded double quotes or single quotes. + # No problem: we enclose it in double quotes. + # * Contains embedded single quotes. No problem: + # double quotes work here too. + # * Contains embedded double quotes. No problem: + # we enclose it in single quotes. + # * Embeds both single _and_ double quotes. This + # can't happen naturally, but it can happen if + # you modify an attribute value after parsing + # the document. Now we have a bit of a + # problem. We solve it by enclosing the + # attribute in single quotes, and escaping any + # embedded single quotes to XML entities. + if '"' in val: + fmt = "%s='%s'" + if "'" in val: + # TODO: replace with apos when + # appropriate. + val = val.replace("'", "&squot;") + + # Now we're okay w/r/t quotes. But the attribute + # value might also contain angle brackets, or + # ampersands that aren't part of entities. We need + # to escape those to XML entities too. + val = self.BARE_AMPERSAND_OR_BRACKET.sub(self._sub_entity, val) + + attrs.append(fmt % (self.toEncoding(key, encoding), + self.toEncoding(val, encoding))) + close = '' + closeTag = '' + if self.isSelfClosing: + close = ' /' + else: + closeTag = f'' + + indentTag, indentContents = 0, 0 + if prettyPrint: + indentTag = indentLevel + space = (' ' * (indentTag-1)) + indentContents = indentTag + 1 + contents = self.renderContents(encoding, prettyPrint, indentContents) + if self.hidden: + return contents + s = [] + attributeString = ' ' + ' '.join(attrs) if attrs else '' + if prettyPrint: + s.append(space) + s.append(f'<{encodedName}{attributeString}{close}>') + if prettyPrint: + s.append("\n") + s.append(contents) + if prettyPrint: + if contents and contents[-1] != "\n": + s.append("\n") + if closeTag: + s.append(space) + s.append(closeTag) + if prettyPrint and closeTag and self.nextSibling: + s.append("\n") + return ''.join(s) def decompose(self): - """Recursively destroys the contents of this tree.""" - contents = [i for i in self.contents] - for i in contents: - if isinstance(i, Tag): - i.decompose() - else: - i.extract() - self.extract() + """Recursively destroys the contents of this tree.""" + contents = list(self.contents) + for i in contents: + if isinstance(i, Tag): + i.decompose() + else: + i.extract() + self.extract() def prettify(self, encoding=DEFAULT_OUTPUT_ENCODING): return self.__str__(encoding, True) @@ -746,18 +720,15 @@ def renderContents(self, encoding=DEFAULT_OUTPUT_ENCODING, def find(self, name=None, attrs={}, recursive=True, text=None, **kwargs): - """Return only the first child of this Tag matching the given + """Return only the first child of this Tag matching the given criteria.""" - r = None - l = self.findAll(name, attrs, recursive, text, 1, **kwargs) - if l: - r = l[0] - return r + return (l[0] if (l := self.findAll(name, attrs, recursive, text, 1, ** + kwargs)) else None) findChild = find def findAll(self, name=None, attrs={}, recursive=True, text=None, limit=None, **kwargs): - """Extracts a list of Tag objects that match the given + """Extracts a list of Tag objects that match the given criteria. You can specify the name of the Tag and any attributes you want the Tag to have. @@ -766,10 +737,8 @@ def findAll(self, name=None, attrs={}, recursive=True, text=None, callable that takes a string and returns whether or not the string matches for some custom definition of 'matches'. The same is true of the tag name.""" - generator = self.recursiveChildGenerator - if not recursive: - generator = self.childGenerator - return self._findAll(name, attrs, text, limit, generator, **kwargs) + generator = self.recursiveChildGenerator if recursive else self.childGenerator + return self._findAll(name, attrs, text, limit, generator, **kwargs) findChildren = findAll # Pre-3.x compatibility methods @@ -785,19 +754,17 @@ def firstText(self, text=None, recursive=True): #Private methods def _getAttrMap(self): - """Initializes a map representation of this tag's attributes, + """Initializes a map representation of this tag's attributes, if not already initialized.""" - if not getattr(self, 'attrMap'): - self.attrMap = {} - for (key, value) in self.attrs: - self.attrMap[key] = value - return self.attrMap + if not getattr(self, 'attrMap'): + self.attrMap = dict(self.attrs) + return self.attrMap #Generator methods def childGenerator(self): - for i in range(0, len(self.contents)): - yield self.contents[i] - raise StopIteration + for i in range(len(self.contents)): + yield self.contents[i] + raise StopIteration def recursiveChildGenerator(self): stack = [(self, 0)] @@ -834,73 +801,61 @@ def __init__(self, name=None, attrs={}, text=None, **kwargs): self.text = text def __str__(self): - if self.text: - return self.text - else: - return "%s|%s" % (self.name, self.attrs) + return self.text if self.text else f"{self.name}|{self.attrs}" def searchTag(self, markupName=None, markupAttrs={}): - found = None - markup = None - if isinstance(markupName, Tag): - markup = markupName - markupAttrs = markup - callFunctionWithTagData = callable(self.name) \ - and not isinstance(markupName, Tag) - - if (not self.name) \ - or callFunctionWithTagData \ - or (markup and self._matches(markup, self.name)) \ - or (not markup and self._matches(markupName, self.name)): - if callFunctionWithTagData: - match = self.name(markupName, markupAttrs) - else: - match = True - markupAttrMap = None - for attr, matchAgainst in self.attrs.items(): - if not markupAttrMap: - if hasattr(markupAttrs, 'get'): - markupAttrMap = markupAttrs - else: - markupAttrMap = {} - for k,v in markupAttrs: - markupAttrMap[k] = v - attrValue = markupAttrMap.get(attr) - if not self._matches(attrValue, matchAgainst): - match = False - break - if match: - if markup: - found = markup - else: - found = markupName - return found + found = None + markup = None + if isinstance(markupName, Tag): + markup = markupName + markupAttrs = markup + callFunctionWithTagData = callable(self.name) \ + and not isinstance(markupName, Tag) + + if (not self.name) \ + or callFunctionWithTagData \ + or (markup and self._matches(markup, self.name)) \ + or (not markup and self._matches(markupName, self.name)): + if callFunctionWithTagData: + match = self.name(markupName, markupAttrs) + else: + match = True + markupAttrMap = None + for attr, matchAgainst in self.attrs.items(): + if not markupAttrMap: + if hasattr(markupAttrs, 'get'): + markupAttrMap = markupAttrs + else: + markupAttrMap = dict(markupAttrs) + attrValue = markupAttrMap.get(attr) + if not self._matches(attrValue, matchAgainst): + match = False + break + if match: + found = markup if markup else markupName + return found def search(self, markup): - #print 'looking for %s in %s' % (self, markup) - found = None + #print 'looking for %s in %s' % (self, markup) + found = None # If given a list of items, scan it for a text element that # matches. - if isList(markup) and not isinstance(markup, Tag): - for element in markup: - if isinstance(element, NavigableString) \ + if isList(markup) and not isinstance(markup, Tag): + for element in markup: + if isinstance(element, NavigableString) \ and self.search(element): - found = element - break - # If it's a Tag, make sure its name or attributes match. - # Don't bother with Tags if we're searching for text. - elif isinstance(markup, Tag): - if not self.text: - found = self.searchTag(markup) - # If it's text, make sure the text matches. - elif isinstance(markup, NavigableString) or \ - isString(markup): - if self._matches(markup, self.text): - found = markup - else: - raise Exception, "I don't know how to match against a %s" \ - % markup.__class__ - return found + found = element + break + elif isinstance(markup, Tag): + if not self.text: + found = self.searchTag(markup) + elif isinstance(markup, NavigableString) or \ + isString(markup): + if self._matches(markup, self.text): + found = markup + else: + raise (Exception, f"I don't know how to match against a {markup.__class__}") + return found def _matches(self, markup, matchAgainst): #print "Matching %s against %s" % (markup, matchAgainst) @@ -950,12 +905,12 @@ def isList(l): or (type(l) in (types.ListType, types.TupleType)) def isString(s): - """Convenience method that works with all 2.x versions of Python + """Convenience method that works with all 2.x versions of Python to determine whether or not something is stringlike.""" - try: - return isinstance(s, unicode) or isinstance(s, basestring) - except NameError: - return isinstance(s, str) + try: + return isinstance(s, (unicode, basestring)) + except NameError: + return isinstance(s, str) def buildTagMap(default, *args): """Turns a list of maps, lists, or scalars into a single map. @@ -1102,37 +1057,36 @@ def convert_charref(self, name): return self.convert_codepoint(n) def _feed(self, inDocumentEncoding=None, isHTML=False): - # Convert the document to Unicode. - markup = self.markup - if isinstance(markup, unicode): - if not hasattr(self, 'originalEncoding'): - self.originalEncoding = None - else: - dammit = UnicodeDammit\ - (markup, [self.fromEncoding, inDocumentEncoding], - smartQuotesTo=self.smartQuotesTo, isHTML=isHTML) - markup = dammit.unicode - self.originalEncoding = dammit.originalEncoding - self.declaredHTMLEncoding = dammit.declaredHTMLEncoding - if markup: - if self.markupMassage: - if not isList(self.markupMassage): - self.markupMassage = self.MARKUP_MASSAGE - for fix, m in self.markupMassage: - markup = fix.sub(m, markup) - # TODO: We get rid of markupMassage so that the - # soup object can be deepcopied later on. Some - # Python installations can't copy regexes. If anyone - # was relying on the existence of markupMassage, this - # might cause problems. - del(self.markupMassage) - self.reset() - - SGMLParser.feed(self, markup) - # Close out any unfinished strings and close all the open tags. - self.endData() - while self.currentTag.name != self.ROOT_TAG_NAME: - self.popTag() + # Convert the document to Unicode. + markup = self.markup + if isinstance(markup, unicode): + if not hasattr(self, 'originalEncoding'): + self.originalEncoding = None + else: + dammit = UnicodeDammit\ + (markup, [self.fromEncoding, inDocumentEncoding], + smartQuotesTo=self.smartQuotesTo, isHTML=isHTML) + markup = dammit.unicode + self.originalEncoding = dammit.originalEncoding + self.declaredHTMLEncoding = dammit.declaredHTMLEncoding + if markup and self.markupMassage: + if not isList(self.markupMassage): + self.markupMassage = self.MARKUP_MASSAGE + for fix, m in self.markupMassage: + markup = fix.sub(m, markup) + # TODO: We get rid of markupMassage so that the + # soup object can be deepcopied later on. Some + # Python installations can't copy regexes. If anyone + # was relying on the existence of markupMassage, this + # might cause problems. + del(self.markupMassage) + self.reset() + + SGMLParser.feed(self, markup) + # Close out any unfinished strings and close all the open tags. + self.endData() + while self.currentTag.name != self.ROOT_TAG_NAME: + self.popTag() def __getattr__(self, methodName): """This method routes method call requests to either the SGMLParser @@ -1185,53 +1139,52 @@ def pushTag(self, tag): self.currentTag = self.tagStack[-1] def endData(self, containerClass=NavigableString): - if self.currentData: - currentData = u''.join(self.currentData) - if (currentData.translate(self.STRIP_ASCII_SPACES) == '' and - not set([tag.name for tag in self.tagStack]).intersection( - self.PRESERVE_WHITESPACE_TAGS)): - if '\n' in currentData: - currentData = '\n' - else: - currentData = ' ' - self.currentData = [] - if self.parseOnlyThese and len(self.tagStack) <= 1 and \ - (not self.parseOnlyThese.text or \ - not self.parseOnlyThese.search(currentData)): - return - o = containerClass(currentData) - o.setup(self.currentTag, self.previous) - if self.previous: - self.previous.next = o - self.previous = o - self.currentTag.contents.append(o) + if not self.currentData: + return + currentData = u''.join(self.currentData) + if not currentData.translate(self.STRIP_ASCII_SPACES) and not { + tag.name + for tag in self.tagStack + }.intersection(self.PRESERVE_WHITESPACE_TAGS): + currentData = '\n' if '\n' in currentData else ' ' + self.currentData = [] + if self.parseOnlyThese and len(self.tagStack) <= 1 and \ + (not self.parseOnlyThese.text or \ + not self.parseOnlyThese.search(currentData)): + return + o = containerClass(currentData) + o.setup(self.currentTag, self.previous) + if self.previous: + self.previous.next = o + self.previous = o + self.currentTag.contents.append(o) def _popToTag(self, name, inclusivePop=True): - """Pops the tag stack up to and including the most recent + """Pops the tag stack up to and including the most recent instance of the given tag. If inclusivePop is false, pops the tag stack up to but *not* including the most recent instqance of the given tag.""" - #print "Popping to %s" % name - if name == self.ROOT_TAG_NAME: - return - - numPops = 0 - mostRecentTag = None - for i in range(len(self.tagStack)-1, 0, -1): - if name == self.tagStack[i].name: - numPops = len(self.tagStack)-i - break - if not inclusivePop: - numPops = numPops - 1 - - for i in range(0, numPops): - mostRecentTag = self.popTag() - return mostRecentTag + #print "Popping to %s" % name + if name == self.ROOT_TAG_NAME: + return + + mostRecentTag = None + numPops = next( + (len(self.tagStack) - i for i in range(len(self.tagStack) - 1, 0, -1) + if name == self.tagStack[i].name), + 0, + ) + if not inclusivePop: + numPops = numPops - 1 + + for _ in range(numPops): + mostRecentTag = self.popTag() + return mostRecentTag def _smartPop(self, name): - """We need to pop up to the previous tag of this type, unless + """We need to pop up to the previous tag of this type, unless one of this tag's nesting reset triggers comes between this tag and the previous tag of this type, OR unless this tag is a generic nesting trigger and another generic nesting trigger @@ -1247,33 +1200,32 @@ def _smartPop(self, name): ** should pop to 'tr', not the first 'td' """ - nestingResetTriggers = self.NESTABLE_TAGS.get(name) - isNestable = nestingResetTriggers != None - isResetNesting = self.RESET_NESTING_TAGS.has_key(name) - popTo = None - inclusive = True - for i in range(len(self.tagStack)-1, 0, -1): - p = self.tagStack[i] - if (not p or p.name == name) and not isNestable: - #Non-nestable tags get popped to the top or to their - #last occurance. - popTo = name - break - if (nestingResetTriggers != None - and p.name in nestingResetTriggers) \ - or (nestingResetTriggers == None and isResetNesting - and self.RESET_NESTING_TAGS.has_key(p.name)): - - #If we encounter one of the nesting reset triggers - #peculiar to this tag, or we encounter another tag - #that causes nesting to reset, pop up to but not - #including that tag. - popTo = p.name - inclusive = False - break - p = p.parent - if popTo: - self._popToTag(popTo, inclusive) + nestingResetTriggers = self.NESTABLE_TAGS.get(name) + isNestable = nestingResetTriggers != None + isResetNesting = self.RESET_NESTING_TAGS.has_key(name) + popTo = None + inclusive = True + for i in range(len(self.tagStack)-1, 0, -1): + p = self.tagStack[i] + if (not p or p.name == name) and not isNestable: + #Non-nestable tags get popped to the top or to their + #last occurance. + popTo = name + break + if ((nestingResetTriggers != None and p.name in nestingResetTriggers) + or nestingResetTriggers is None and isResetNesting + and self.RESET_NESTING_TAGS.has_key(p.name)): + + #If we encounter one of the nesting reset triggers + #peculiar to this tag, or we encounter another tag + #that causes nesting to reset, pop up to but not + #including that tag. + popTo = p.name + inclusive = False + break + p = p.parent + if popTo: + self._popToTag(popTo, inclusive) def unknown_starttag(self, name, attrs, selfClosing=0): #print "Start tag %s: %s" % (name, attrs) @@ -1307,16 +1259,16 @@ def unknown_starttag(self, name, attrs, selfClosing=0): def unknown_endtag(self, name): #print "End tag %s" % name - if self.quoteStack and self.quoteStack[-1] != name: + if self.quoteStack and self.quoteStack[-1] != name: #This is not a real end tag. #print " is not real!" % name - self.handle_data('' % name) - return - self.endData() - self._popToTag(name) - if self.quoteStack and self.quoteStack[-1] == name: - self.quoteStack.pop() - self.literal = (len(self.quoteStack) > 0) + self.handle_data(f'') + return + self.endData() + self._popToTag(name) + if self.quoteStack and self.quoteStack[-1] == name: + self.quoteStack.pop() + self.literal = (len(self.quoteStack) > 0) def handle_data(self, data): self.currentData.append(data) @@ -1341,29 +1293,26 @@ def handle_comment(self, text): self._toStringSubclass(text, Comment) def handle_charref(self, ref): - "Handle character references as data." - if self.convertEntities: - data = unichr(int(ref)) - else: - data = '&#%s;' % ref - self.handle_data(data) + "Handle character references as data." + data = unichr(int(ref)) if self.convertEntities else f'&#{ref};' + self.handle_data(data) def handle_entityref(self, ref): - """Handle entity references as data, possibly converting known + """Handle entity references as data, possibly converting known HTML and/or XML entity references to the corresponding Unicode characters.""" - data = None - if self.convertHTMLEntities: - try: - data = unichr(name2codepoint[ref]) - except KeyError: - pass - - if not data and self.convertXMLEntities: - data = self.XML_ENTITIES_TO_SPECIAL_CHARS.get(ref) - - if not data and self.convertHTMLEntities and \ - not self.XML_ENTITIES_TO_SPECIAL_CHARS.get(ref): + data = None + if self.convertHTMLEntities: + try: + data = unichr(name2codepoint[ref]) + except KeyError: + pass + + if not data and self.convertXMLEntities: + data = self.XML_ENTITIES_TO_SPECIAL_CHARS.get(ref) + + if not data and self.convertHTMLEntities and \ + not self.XML_ENTITIES_TO_SPECIAL_CHARS.get(ref): # TODO: We've got a problem here. We're told this is # an entity reference, but it's not an XML entity # reference or an HTML entity reference. Nonetheless, @@ -1380,16 +1329,16 @@ def handle_entityref(self, ref): # # The more common case is a misplaced ampersand, so I # escape the ampersand and omit the trailing semicolon. - data = "&%s" % ref - if not data: + data = f"&{ref}" + if not data: # This case is different from the one above, because we # haven't already gone through a supposedly comprehensive # mapping of entities to Unicode characters. We might not # have gone through any mapping at all. So the chances are # very high that this is a real entity, and not a # misplaced ampersand. - data = "&%s;" % ref - self.handle_data(data) + data = f"&{ref};" + self.handle_data(data) def handle_decl(self, data): "Handle DOCTYPEs and the like as Declaration objects." @@ -1522,51 +1471,51 @@ def __init__(self, *args, **kwargs): CHARSET_RE = re.compile("((^|;)\s*charset=)([^;]*)", re.M) def start_meta(self, attrs): - """Beautiful Soup can detect a charset included in a META tag, + """Beautiful Soup can detect a charset included in a META tag, try to convert the document to that charset, and re-parse the document from the beginning.""" - httpEquiv = None - contentType = None - contentTypeIndex = None - tagNeedsEncodingSubstitution = False - - for i in range(0, len(attrs)): - key, value = attrs[i] - key = key.lower() - if key == 'http-equiv': - httpEquiv = value - elif key == 'content': - contentType = value - contentTypeIndex = i - - if httpEquiv and contentType: # It's an interesting meta tag. - match = self.CHARSET_RE.search(contentType) - if match: - if (self.declaredHTMLEncoding is not None or + httpEquiv = None + contentType = None + contentTypeIndex = None + tagNeedsEncodingSubstitution = False + + for i in range(len(attrs)): + key, value = attrs[i] + key = key.lower() + if key == 'http-equiv': + httpEquiv = value + elif key == 'content': + contentType = value + contentTypeIndex = i + + if httpEquiv and contentType: # It's an interesting meta tag. + match = self.CHARSET_RE.search(contentType) + if match: + if (self.declaredHTMLEncoding is not None or self.originalEncoding == self.fromEncoding): # An HTML encoding was sniffed while converting # the document to Unicode, or an HTML encoding was # sniffed during a previous pass through the # document, or an encoding was specified # explicitly and it worked. Rewrite the meta tag. - def rewrite(match): - return match.group(1) + "%SOUP-ENCODING%" - newAttr = self.CHARSET_RE.sub(rewrite, contentType) - attrs[contentTypeIndex] = (attrs[contentTypeIndex][0], - newAttr) - tagNeedsEncodingSubstitution = True - else: - # This is our first pass through the document. - # Go through it again with the encoding information. - newCharset = match.group(3) - if newCharset and newCharset != self.originalEncoding: - self.declaredHTMLEncoding = newCharset - self._feed(self.declaredHTMLEncoding) - raise StopParsing - pass - tag = self.unknown_starttag("meta", attrs) - if tag and tagNeedsEncodingSubstitution: - tag.containsSubstitutions = True + def rewrite(match): + return f"{match.group(1)}%SOUP-ENCODING%" + + newAttr = self.CHARSET_RE.sub(rewrite, contentType) + attrs[contentTypeIndex] = (attrs[contentTypeIndex][0], + newAttr) + tagNeedsEncodingSubstitution = True + else: + # This is our first pass through the document. + # Go through it again with the encoding information. + newCharset = match.group(3) + if newCharset and newCharset != self.originalEncoding: + self.declaredHTMLEncoding = newCharset + self._feed(self.declaredHTMLEncoding) + raise StopParsing + tag = self.unknown_starttag("meta", attrs) + if tag and tagNeedsEncodingSubstitution: + tag.containsSubstitutions = True class StopParsing(Exception): pass @@ -1749,15 +1698,12 @@ def __init__(self, markup, overrideEncodings=[], if not u: self.originalEncoding = None def _subMSChar(self, orig): - """Changes a MS smart quote character to an XML or HTML + """Changes a MS smart quote character to an XML or HTML entity.""" - sub = self.MS_CHARS.get(orig) - if type(sub) == types.TupleType: - if self.smartQuotesTo == 'xml': - sub = '&#x%s;' % sub[1] - else: - sub = '&%s;' % sub[0] - return sub + sub = self.MS_CHARS.get(orig) + if type(sub) == types.TupleType: + sub = f'&#x{sub[1]};' if self.smartQuotesTo == 'xml' else f'&{sub[0]};' + return sub def _convertFrom(self, proposed): proposed = self.find_codec(proposed) @@ -1788,96 +1734,94 @@ def _convertFrom(self, proposed): return self.markup def _toUnicode(self, data, encoding): - '''Given a string and its encoding, decodes the string into Unicode. + '''Given a string and its encoding, decodes the string into Unicode. %encoding is a string recognized by encodings.aliases''' - # strip Byte Order Mark (if present) - if (len(data) >= 4) and (data[:2] == '\xfe\xff') \ - and (data[2:4] != '\x00\x00'): - encoding = 'utf-16be' - data = data[2:] - elif (len(data) >= 4) and (data[:2] == '\xff\xfe') \ + # strip Byte Order Mark (if present) + if (len(data) >= 4) and (data[:2] == '\xfe\xff') \ and (data[2:4] != '\x00\x00'): - encoding = 'utf-16le' - data = data[2:] - elif data[:3] == '\xef\xbb\xbf': - encoding = 'utf-8' - data = data[3:] - elif data[:4] == '\x00\x00\xfe\xff': - encoding = 'utf-32be' - data = data[4:] - elif data[:4] == '\xff\xfe\x00\x00': - encoding = 'utf-32le' - data = data[4:] - newdata = unicode(data, encoding) - return newdata + encoding = 'utf-16be' + data = data[2:] + elif (len(data) >= 4) and (data[:2] == '\xff\xfe') \ + and (data[2:4] != '\x00\x00'): + encoding = 'utf-16le' + data = data[2:] + elif data[:3] == '\xef\xbb\xbf': + encoding = 'utf-8' + data = data[3:] + elif data[:4] == '\x00\x00\xfe\xff': + encoding = 'utf-32be' + data = data[4:] + elif data[:4] == '\xff\xfe\x00\x00': + encoding = 'utf-32le' + data = data[4:] + return unicode(data, encoding) def _detectEncoding(self, xml_data, isHTML=False): - """Given a document, tries to detect its XML encoding.""" - xml_encoding = sniffed_xml_encoding = None - try: - if xml_data[:4] == '\x4c\x6f\xa7\x94': - # EBCDIC - xml_data = self._ebcdic_to_ascii(xml_data) - elif xml_data[:4] == '\x00\x3c\x00\x3f': - # UTF-16BE - sniffed_xml_encoding = 'utf-16be' - xml_data = unicode(xml_data, 'utf-16be').encode('utf-8') - elif (len(xml_data) >= 4) and (xml_data[:2] == '\xfe\xff') \ + """Given a document, tries to detect its XML encoding.""" + xml_encoding = sniffed_xml_encoding = None + try: + if xml_data[:4] == '\x4c\x6f\xa7\x94': + # EBCDIC + xml_data = self._ebcdic_to_ascii(xml_data) + elif xml_data[:4] == '\x00\x3c\x00\x3f': + # UTF-16BE + sniffed_xml_encoding = 'utf-16be' + xml_data = unicode(xml_data, 'utf-16be').encode('utf-8') + elif (len(xml_data) >= 4) and (xml_data[:2] == '\xfe\xff') \ and (xml_data[2:4] != '\x00\x00'): - # UTF-16BE with BOM - sniffed_xml_encoding = 'utf-16be' - xml_data = unicode(xml_data[2:], 'utf-16be').encode('utf-8') - elif xml_data[:4] == '\x3c\x00\x3f\x00': - # UTF-16LE - sniffed_xml_encoding = 'utf-16le' - xml_data = unicode(xml_data, 'utf-16le').encode('utf-8') - elif (len(xml_data) >= 4) and (xml_data[:2] == '\xff\xfe') and \ + # UTF-16BE with BOM + sniffed_xml_encoding = 'utf-16be' + xml_data = unicode(xml_data[2:], 'utf-16be').encode('utf-8') + elif xml_data[:4] == '\x3c\x00\x3f\x00': + # UTF-16LE + sniffed_xml_encoding = 'utf-16le' + xml_data = unicode(xml_data, 'utf-16le').encode('utf-8') + elif (len(xml_data) >= 4) and (xml_data[:2] == '\xff\xfe') and \ (xml_data[2:4] != '\x00\x00'): - # UTF-16LE with BOM - sniffed_xml_encoding = 'utf-16le' - xml_data = unicode(xml_data[2:], 'utf-16le').encode('utf-8') - elif xml_data[:4] == '\x00\x00\x00\x3c': - # UTF-32BE - sniffed_xml_encoding = 'utf-32be' - xml_data = unicode(xml_data, 'utf-32be').encode('utf-8') - elif xml_data[:4] == '\x3c\x00\x00\x00': - # UTF-32LE - sniffed_xml_encoding = 'utf-32le' - xml_data = unicode(xml_data, 'utf-32le').encode('utf-8') - elif xml_data[:4] == '\x00\x00\xfe\xff': - # UTF-32BE with BOM - sniffed_xml_encoding = 'utf-32be' - xml_data = unicode(xml_data[4:], 'utf-32be').encode('utf-8') - elif xml_data[:4] == '\xff\xfe\x00\x00': - # UTF-32LE with BOM - sniffed_xml_encoding = 'utf-32le' - xml_data = unicode(xml_data[4:], 'utf-32le').encode('utf-8') - elif xml_data[:3] == '\xef\xbb\xbf': - # UTF-8 with BOM - sniffed_xml_encoding = 'utf-8' - xml_data = unicode(xml_data[3:], 'utf-8').encode('utf-8') - else: - sniffed_xml_encoding = 'ascii' - pass - except: - xml_encoding_match = None - xml_encoding_match = re.compile( - '^<\?.*encoding=[\'"](.*?)[\'"].*\?>').match(xml_data) - if not xml_encoding_match and isHTML: - regexp = re.compile('<\s*meta[^>]+charset=([^>]*?)[;\'">]', re.I) - xml_encoding_match = regexp.search(xml_data) - if xml_encoding_match is not None: - xml_encoding = xml_encoding_match.groups()[0].lower() - if isHTML: - self.declaredHTMLEncoding = xml_encoding - if sniffed_xml_encoding and \ - (xml_encoding in ('iso-10646-ucs-2', 'ucs-2', 'csunicode', - 'iso-10646-ucs-4', 'ucs-4', 'csucs4', - 'utf-16', 'utf-32', 'utf_16', 'utf_32', - 'utf16', 'u16')): - xml_encoding = sniffed_xml_encoding - return xml_data, xml_encoding, sniffed_xml_encoding + # UTF-16LE with BOM + sniffed_xml_encoding = 'utf-16le' + xml_data = unicode(xml_data[2:], 'utf-16le').encode('utf-8') + elif xml_data[:4] == '\x00\x00\x00\x3c': + # UTF-32BE + sniffed_xml_encoding = 'utf-32be' + xml_data = unicode(xml_data, 'utf-32be').encode('utf-8') + elif xml_data[:4] == '\x3c\x00\x00\x00': + # UTF-32LE + sniffed_xml_encoding = 'utf-32le' + xml_data = unicode(xml_data, 'utf-32le').encode('utf-8') + elif xml_data[:4] == '\x00\x00\xfe\xff': + # UTF-32BE with BOM + sniffed_xml_encoding = 'utf-32be' + xml_data = unicode(xml_data[4:], 'utf-32be').encode('utf-8') + elif xml_data[:4] == '\xff\xfe\x00\x00': + # UTF-32LE with BOM + sniffed_xml_encoding = 'utf-32le' + xml_data = unicode(xml_data[4:], 'utf-32le').encode('utf-8') + elif xml_data[:3] == '\xef\xbb\xbf': + # UTF-8 with BOM + sniffed_xml_encoding = 'utf-8' + xml_data = unicode(xml_data[3:], 'utf-8').encode('utf-8') + else: + sniffed_xml_encoding = 'ascii' + except: + xml_encoding_match = None + xml_encoding_match = re.compile( + '^<\?.*encoding=[\'"](.*?)[\'"].*\?>').match(xml_data) + if not xml_encoding_match and isHTML: + regexp = re.compile('<\s*meta[^>]+charset=([^>]*?)[;\'">]', re.I) + xml_encoding_match = regexp.search(xml_data) + if xml_encoding_match is not None: + xml_encoding = xml_encoding_match.groups()[0].lower() + if isHTML: + self.declaredHTMLEncoding = xml_encoding + if sniffed_xml_encoding and \ + (xml_encoding in ('iso-10646-ucs-2', 'ucs-2', 'csunicode', + 'iso-10646-ucs-4', 'ucs-4', 'csucs4', + 'utf-16', 'utf-32', 'utf_16', 'utf_32', + 'utf16', 'u16')): + xml_encoding = sniffed_xml_encoding + return xml_data, xml_encoding, sniffed_xml_encoding def find_codec(self, charset): diff --git a/CustomCookie.py b/CustomCookie.py index a71027c..c1b11ea 100755 --- a/CustomCookie.py +++ b/CustomCookie.py @@ -317,10 +317,10 @@ def _quote(str, LegalChars=_LegalChars, # the string in doublequotes and precede quote (with a \) # special characters. # - if "" == translate(str, idmap, LegalChars): + if translate(str, idmap, LegalChars) == "": return str else: - return '"' + _nulljoin( map(_Translator.get, str, str) ) + '"' + return f'"{_nulljoin(map(_Translator.get, str, str))}"' # end _quote @@ -359,12 +359,10 @@ def _unquote(str): if Omatch: j = Omatch.start(0) if Qmatch: k = Qmatch.start(0) if Qmatch and ( not Omatch or k < j ): # QuotePatt matched - res.append(str[i:k]) - res.append(str[k+1]) + res.extend((str[i:k], str[k+1])) i = k+2 - else: # OctalPatt matched - res.append(str[i:j]) - res.append( chr( int(str[j+1:j+4], 8) ) ) + else: # OctalPatt matched + res.extend((str[i:j], chr( int(str[j+1:j+4], 8) ))) i = j+4 return _nulljoin(res) # end _unquote @@ -437,8 +435,8 @@ def __init__(self): def __setitem__(self, K, V): K = K.lower() - if not K in self._reserved: - raise CookieError("Invalid Attribute %s" % K) + if K not in self._reserved: + raise CookieError(f"Invalid Attribute {K}") dict.__setitem__(self, K, V) # end __setitem__ @@ -452,9 +450,9 @@ def set(self, key, val, coded_val, # First we verify that the key isn't a reserved word # Second we make sure it only contains legal characters if key.lower() in self._reserved: - raise CookieError("Attempt to set a reserved key: %s" % key) - if "" != translate(key, idmap, LegalChars): - raise CookieError("Illegal key value: %s" % key) + raise CookieError(f"Attempt to set a reserved key: {key}") + if translate(key, idmap, LegalChars) != "": + raise CookieError(f"Illegal key value: {key}") # It's a good key, so save it. self.key = key @@ -463,13 +461,12 @@ def set(self, key, val, coded_val, # end set def output(self, attrs=None, header = "Set-Cookie:"): - return "%s %s" % ( header, self.OutputString(attrs) ) + return f"{header} {self.OutputString(attrs)}" __str__ = output def __repr__(self): - return '<%s: %s=%s>' % (self.__class__.__name__, - self.key, repr(self.value) ) + return f'<{self.__class__.__name__}: {self.key}={repr(self.value)}>' def js_output(self, attrs=None): # Print javascript @@ -489,7 +486,7 @@ def OutputString(self, attrs=None): RA = result.append # First, the key=value pair - RA("%s=%s" % (self.key, self.coded_value)) + RA(f"{self.key}={self.coded_value}") # Now add any defined attributes if attrs is None: @@ -500,7 +497,7 @@ def OutputString(self, attrs=None): if V == "": continue if K not in attrs: continue if K == "expires" and type(V) == type(1): - RA("%s=%s" % (self._reserved[K], _getdate(V))) + RA(f"{self._reserved[K]}={_getdate(V)}") elif K == "max-age" and type(V) == type(1): RA("%s=%d" % (self._reserved[K], V)) elif K == "secure": @@ -508,7 +505,7 @@ def OutputString(self, attrs=None): elif K == "httponly": RA(str(self._reserved[K])) else: - RA("%s=%s" % (self._reserved[K], V)) + RA(f"{self._reserved[K]}={V}") # Return the result return _semispacejoin(result) @@ -589,31 +586,25 @@ def __setitem__(self, key, value): def output(self, attrs=None, header="Set-Cookie:", sep="\015\012"): """Return a string suitable for HTTP.""" - result = [] items = self.items() items.sort() - for K,V in items: - result.append( V.output(attrs, header) ) + result = [V.output(attrs, header) for K, V in items] return sep.join(result) # end output __str__ = output def __repr__(self): - L = [] items = self.items() items.sort() - for K,V in items: - L.append( '%s=%s' % (K,repr(V.value) ) ) - return '<%s: %s>' % (self.__class__.__name__, _spacejoin(L)) + L = [f'{K}={repr(V.value)}' for K, V in items] + return f'<{self.__class__.__name__}: {_spacejoin(L)}>' def js_output(self, attrs=None): """Return a string suitable for JavaScript.""" - result = [] items = self.items() items.sort() - for K,V in items: - result.append( V.js_output(attrs) ) + result = [V.js_output(attrs) for K, V in items] return _nulljoin(result) # end js_output diff --git a/appengine_utilities/cache.py b/appengine_utilities/cache.py index f8ec04f..6a17d9b 100644 --- a/appengine_utilities/cache.py +++ b/appengine_utilities/cache.py @@ -96,20 +96,20 @@ def _clean_cache(self): # result.delete() def _validate_key(self, key): - if key == None: + if key is None: raise KeyError def _validate_value(self, value): - if value == None: + if value is None: raise ValueError def _validate_timeout(self, timeout): - if timeout == None: + if timeout is None: timeout = datetime.datetime.now() +\ - datetime.timedelta(seconds=DEFAULT_TIMEOUT) + datetime.timedelta(seconds=DEFAULT_TIMEOUT) if type(timeout) == type(1): timeout = datetime.datetime.now() + \ - datetime.timedelta(seconds = timeout) + datetime.timedelta(seconds = timeout) if type(timeout) != datetime.datetime: raise TypeError if timeout < datetime.datetime.now(): @@ -145,7 +145,7 @@ def add(self, key = None, value = None, timeout = None): pass memcache_timeout = timeout - datetime.datetime.now() - memcache.set('cache-'+key, value, int(memcache_timeout.seconds)) + memcache.set(f'cache-{key}', value, int(memcache_timeout.seconds)) if 'AEU_Events' in __main__.__dict__: __main__.AEU_Events.fire_event('cacheAdded') @@ -172,7 +172,7 @@ def set(self, key = None, value = None, timeout = None): pass memcache_timeout = timeout - datetime.datetime.now() - memcache.set('cache-'+key, value, int(memcache_timeout.seconds)) + memcache.set(f'cache-{key}', value, int(memcache_timeout.seconds)) if 'AEU_Events' in __main__.__dict__: __main__.AEU_Events.fire_event('cacheSet') @@ -189,22 +189,14 @@ def _read(self, key = None): query.filter('cachekey', key) query.filter('timeout > ', datetime.datetime.now()) results = query.fetch(1) - if len(results) is 0: - return None - return results[0] - - if 'AEU_Events' in __main__.__dict__: - __main__.AEU_Events.fire_event('cacheReadFromDatastore') - if 'AEU_Events' in __main__.__dict__: - __main__.AEU_Events.fire_event('cacheRead') + return None if len(results) is 0 else results[0] def delete(self, key = None): """ Deletes a cache object determined by the key. """ - memcache.delete('cache-'+key) - result = self._read(key) - if result: + memcache.delete(f'cache-{key}') + if result := self._read(key): if 'AEU_Events' in __main__.__dict__: __main__.AEU_Events.fire_event('cacheDeleted') result.delete() @@ -213,19 +205,16 @@ def get(self, key): """ get is used to return the cache value associated with the key passed. """ - mc = memcache.get('cache-'+key) - if mc: + if mc := memcache.get(f'cache-{key}'): if 'AEU_Events' in __main__.__dict__: __main__.AEU_Events.fire_event('cacheReadFromMemcache') if 'AEU_Events' in __main__.__dict__: __main__.AEU_Events.fire_event('cacheRead') return mc - result = self._read(key) - if result: + if result := self._read(key): timeout = result.timeout - datetime.datetime.now() # print timeout.seconds - memcache.set('cache-'+key, pickle.loads(result.value), - int(timeout.seconds)) + memcache.set(f'cache-{key}', pickle.loads(result.value), int(timeout.seconds)) return pickle.loads(result.value) else: raise KeyError diff --git a/appengine_utilities/cron.py b/appengine_utilities/cron.py index a32835b..b6be71a 100644 --- a/appengine_utilities/cron.py +++ b/appengine_utilities/cron.py @@ -76,8 +76,8 @@ def __init__(self): one_second = datetime.timedelta(seconds = 1) before = datetime.datetime.now() for r in results: - if re.search(':' + APPLICATION_PORT, r.url): - r.url = re.sub(':' + APPLICATION_PORT, ':' + CRON_PORT, r.url) + if re.search(f':{APPLICATION_PORT}', r.url): + r.url = re.sub(f':{APPLICATION_PORT}', f':{CRON_PORT}', r.url) #result = urlfetch.fetch(r.url) diff = datetime.datetime.now() - before if int(diff.seconds) < 1: @@ -143,42 +143,43 @@ def _validate_type(self, v, t): All can * which will then return the range for that entire type. """ - if t == "dow": - if v >= 0 and v <= 7: - return [v] - elif v == "*": - return "*" - else: - raise ValueError, "Invalid day of week." - elif t == "mon": - if v >= 1 and v <= 12: - return [v] - elif v == "*": - return range(1, 12) - else: - raise ValueError, "Invalid month." - elif t == "day": + if t == "day": if v >= 1 and v <= 31: return [v] elif v == "*": return range(1, 31) else: raise ValueError, "Invalid day." + elif t == "dow": + if v >= 0 and v <= 7: + return [v] + elif v == "*": + return "*" + else: + raise ValueError, "Invalid day of week." elif t == "hour": if v >= 0 and v <= 23: return [v] elif v == "*": - return range(0, 23) + return range(23) else: raise ValueError, "Invalid hour." elif t == "min": if v >= 0 and v <= 59: return [v] elif v == "*": - return range(0, 59) + return range(59) else: raise ValueError, "Invalid minute." + elif t == "mon": + if v >= 1 and v <= 12: + return [v] + elif v == "*": + return range(1, 12) + else: + raise ValueError, "Invalid month." + def _validate_list(self, l, t): """ Validates a crontab list. Lists are numerical values seperated @@ -218,7 +219,7 @@ def _validate_range(self, r, t): elements = r.split('-') # a range should be 2 elements if len(elements) is not 2: - raise ValueError, "Invalid range passed: " + str(r) + raise (ValueError, f"Invalid range passed: {str(r)}") # validate the minimum and maximum are valid for the type for e in elements: self._validate_type(int(e), t) @@ -243,16 +244,15 @@ def _validate_step(self, s, t): elements = s.split('/') # a range should be 2 elements if len(elements) is not 2: - raise ValueError, "Invalid step passed: " + str(s) + raise (ValueError, f"Invalid step passed: {str(s)}") try: step = int(elements[1]) except: - raise ValueError, "Invalid step provided " + str(s) + raise (ValueError, f"Invalid step provided {str(s)}") r_list = [] # if the first element is *, use all valid numbers if elements[0] is "*" or elements[0] is "": r_list.extend(self._validate_type('*', t)) - # check and see if there is a list of ranges elif "," in elements[0]: ranges = elements[0].split(",") for r in ranges: @@ -263,7 +263,7 @@ def _validate_step(self, s, t): try: r_list.extend(int(r)) except: - raise ValueError, "Invalid step provided " + str(s) + raise (ValueError, f"Invalid step provided {str(s)}") elif "-" in elements[0]: r_list.extend(self._validate_range(elements[0], t)) return range(r_list[0], r_list[-1] + 1, step) @@ -288,10 +288,6 @@ def _validate_dow(self, dow): if dow in days: dow = days[dow] return [dow] - # if dow is * return it. This is for date parsing where * does not mean - # every day for crontab entries. - elif dow is "*": - return dow elif "/" in dow: return(self._validate_step(dow, "dow")) elif "," in dow: @@ -299,9 +295,9 @@ def _validate_dow(self, dow): elif "-" in dow: return(self._validate_range(dow, "dow")) else: - valid_numbers = range(0, 8) - if not int(dow) in valid_numbers: - raise ValueError, "Invalid day of week " + str(dow) + valid_numbers = range(8) + if int(dow) not in valid_numbers: + raise (ValueError, f"Invalid day of week {str(dow)}") else: return [int(dow)] @@ -333,8 +329,8 @@ def _validate_mon(self, mon): return(self._validate_range(mon, "mon")) else: valid_numbers = range(1, 13) - if not int(mon) in valid_numbers: - raise ValueError, "Invalid month " + str(mon) + if int(mon) not in valid_numbers: + raise (ValueError, f"Invalid month {str(mon)}") else: return [int(mon)] @@ -349,14 +345,14 @@ def _validate_day(self, day): return(self._validate_range(day, "day")) else: valid_numbers = range(1, 31) - if not int(day) in valid_numbers: - raise ValueError, "Invalid day " + str(day) + if int(day) not in valid_numbers: + raise (ValueError, f"Invalid day {str(day)}") else: return [int(day)] def _validate_hour(self, hour): if hour is "*": - return range(0, 24) + return range(24) elif "/" in hour: return(self._validate_step(hour, "hour")) elif "," in hour: @@ -364,15 +360,15 @@ def _validate_hour(self, hour): elif "-" in hour: return(self._validate_range(hour, "hour")) else: - valid_numbers = range(0, 23) - if not int(hour) in valid_numbers: - raise ValueError, "Invalid hour " + str(hour) + valid_numbers = range(23) + if int(hour) not in valid_numbers: + raise (ValueError, f"Invalid hour {str(hour)}") else: return [int(hour)] def _validate_min(self, min): if min is "*": - return range(0, 60) + return range(60) elif "/" in min: return(self._validate_step(min, "min")) elif "," in min: @@ -380,9 +376,9 @@ def _validate_min(self, min): elif "-" in min: return(self._validate_range(min, "min")) else: - valid_numbers = range(0, 59) - if not int(min) in valid_numbers: - raise ValueError, "Invalid min " + str(min) + valid_numbers = range(59) + if int(min) not in valid_numbers: + raise (ValueError, f"Invalid min {str(min)}") else: return [int(min)] @@ -390,7 +386,7 @@ def _validate_url(self, url): # kludge for issue 842, right now we use request headers # to set the host. if url[0] is not "/": - url = "/" + url + url = f"/{url}" url = 'http://' + str(os.environ['HTTP_HOST']) + url return url # content below is for when that issue gets fixed @@ -404,43 +400,38 @@ def _calc_month(self, next_run, cron): while True: if cron["mon"][-1] < next_run.month: next_run = next_run.replace(year=next_run.year+1, \ - month=cron["mon"][0], \ - day=1,hour=0,minute=0) + month=cron["mon"][0], \ + day=1,hour=0,minute=0) else: if next_run.month in cron["mon"]: return next_run - else: - one_month = datetime.timedelta(months=1) - next_run = next_run + one_month + one_month = datetime.timedelta(months=1) + next_run = next_run + one_month def _calc_day(self, next_run, cron): + # convert any integers to lists in order to easily compare values + m = next_run.month # start with dow as per cron if dow and day are set # then dow is used if it comes before day. If dow # is *, then ignore it. - if str(cron["dow"]) != str("*"): - # convert any integers to lists in order to easily compare values - m = next_run.month + if str(cron["dow"]) != "*": while True: if next_run.month is not m: next_run = next_run.replace(hour=0, minute=0) next_run = self._calc_month(next_run, cron) if next_run.weekday() in cron["dow"] or next_run.day in cron["day"]: return next_run - else: - one_day = datetime.timedelta(days=1) - next_run = next_run + one_day + one_day = datetime.timedelta(days=1) + next_run = next_run + one_day else: - m = next_run.month while True: if next_run.month is not m: next_run = next_run.replace(hour=0, minute=0) next_run = self._calc_month(next_run, cron) - # if cron["dow"] is next_run.weekday() or cron["day"] is next_run.day: if next_run.day in cron["day"]: return next_run - else: - one_day = datetime.timedelta(days=1) - next_run = next_run + one_day + one_day = datetime.timedelta(days=1) + next_run = next_run + one_day def _calc_hour(self, next_run, cron): m = next_run.month @@ -454,11 +445,10 @@ def _calc_hour(self, next_run, cron): next_run = self._calc_day(next_run, cron) if next_run.hour in cron["hour"]: return next_run - else: - m = next_run.month - d = next_run.day - one_hour = datetime.timedelta(hours=1) - next_run = next_run + one_hour + m = next_run.month + d = next_run.day + one_hour = datetime.timedelta(hours=1) + next_run = next_run + one_hour def _calc_minute(self, next_run, cron): one_minute = datetime.timedelta(minutes=1) @@ -478,11 +468,10 @@ def _calc_minute(self, next_run, cron): next_run = self._calc_day(next_run, cron) if next_run.minute in cron["min"]: return next_run - else: - m = next_run.month - d = next_run.day - h = next_run.hour - next_run = next_run + one_minute + m = next_run.month + d = next_run.day + h = next_run.hour + next_run = next_run + one_minute def _get_next_run(self, cron): one_minute = datetime.timedelta(minutes=1) diff --git a/appengine_utilities/event.py b/appengine_utilities/event.py index 51d1424..7529832 100644 --- a/appengine_utilities/event.py +++ b/appengine_utilities/event.py @@ -41,10 +41,13 @@ def subscribe(self, event, callback, args = None): """ This method will subscribe a callback function to an event name. """ - if not {"event": event, "callback": callback, "args": args, } \ - in self.events: + if { + "event": event, + "callback": callback, + "args": args, + } not in self.events: self.events.append({"event": event, "callback": callback, \ - "args": args, }) + "args": args, }) def unsubscribe(self, event, callback, args = None): """ @@ -66,7 +69,7 @@ def fire_event(self, event = None): e["callback"](*e["args"]) elif type(e["args"]) == type({}): e["callback"](**e["args"]) - elif e["args"] == None: + elif e["args"] is None: e["callback"]() else: e["callback"](e["args"]) diff --git a/appengine_utilities/paginator.py b/appengine_utilities/paginator.py index 090ef3b..33d9810 100644 --- a/appengine_utilities/paginator.py +++ b/appengine_utilities/paginator.py @@ -35,7 +35,7 @@ class Paginator(object): @classmethod def get(cls, count=10, q_filters={}, search=None, start=None, model=None, \ - order='ASC', order_by='__key__'): + order='ASC', order_by='__key__'): """ get queries the database on model, starting with key, ordered by order. It receives count + 1 items, returning count and setting a @@ -62,18 +62,18 @@ def get(cls, count=10, q_filters={}, search=None, start=None, model=None, \ """ # argument validation - if model == None: + if model is None: raise ValueError('You must pass a model to query') # a valid model object will have a gql method. - if callable(model.gql) == False: + if not callable(model.gql): raise TypeError('model must be a valid model object.') # cache check cache_string = "gae_paginator_" for q_filter in q_filters: cache_string = cache_string + q_filter + "_" + q_filters[q_filter] + "_" - cache_string = cache_string + "index" + cache_string = f"{cache_string}index" c = Cache() if c.has_key(cache_string): return c[cache_string] @@ -82,16 +82,16 @@ def get(cls, count=10, q_filters={}, search=None, start=None, model=None, \ query = model.all() if len(q_filters) > 0: for q_filter in q_filters: - query.filter(q_filter + " = ", q_filters[q_filter]) + query.filter(f"{q_filter} = ", q_filters[q_filter]) if start: if order.lower() == "DESC".lower(): - query.filter(order_by + " <", start) + query.filter(f"{order_by} <", start) else: - query.filter(order_by + " >", start) + query.filter(f"{order_by} >", start) if search: query.search(search) if order.lower() == "DESC".lower(): - query.order("-" + order_by) + query.order(f"-{order_by}") else: query.order(order_by) results = query.fetch(count + 1) @@ -101,13 +101,13 @@ def get(cls, count=10, q_filters={}, search=None, start=None, model=None, \ if start is not None: rquery = model.all() for q_filter in q_filters: - rquery.filter(q_filter + " = ", q_filters[q_filter]) + rquery.filter(f"{q_filter} = ", q_filters[q_filter]) if search: query.search(search) if order.lower() == "DESC".lower(): rquery.order(order_by) else: - rquery.order("-" + order_by) + rquery.order(f"-{order_by}") rresults = rquery.fetch(count) previous = getattr(results[0], order_by) else: diff --git a/appengine_utilities/rotmodel.py b/appengine_utilities/rotmodel.py index 23ec091..2d903e0 100644 --- a/appengine_utilities/rotmodel.py +++ b/appengine_utilities/rotmodel.py @@ -42,5 +42,4 @@ def put(self): return db.Model.put(self) except db.Timeout: count += 1 - else: - raise db.Timeout() + raise db.Timeout() diff --git a/appengine_utilities/sessions.py b/appengine_utilities/sessions.py index 821925c..a6fba6d 100644 --- a/appengine_utilities/sessions.py +++ b/appengine_utilities/sessions.py @@ -311,9 +311,10 @@ def new_sid(self): """ Create a new session id. """ - sid = str(self.session.key()) + md5.new(repr(time.time()) + \ - str(random.random())).hexdigest() - return sid + return ( + str(self.session.key()) + + md5.new(repr(time.time()) + str(random.random())).hexdigest() + ) def _get_session(self): """ @@ -328,12 +329,11 @@ def _get_session(self): results = query.fetch(1) if len(results) is 0: return None - else: - sessionAge = datetime.datetime.now() - results[0].last_activity - if sessionAge.seconds > self.session_expire_time: - results[0].delete() - return None - return results[0] + sessionAge = datetime.datetime.now() - results[0].last_activity + if sessionAge.seconds > self.session_expire_time: + results[0].delete() + return None + return results[0] def _get(self, keyname=None): """ @@ -355,9 +355,7 @@ def _get(self, keyname=None): if len(results) is 0: return None - if keyname != None: - return results[0] - return results + return results[0] if keyname != None else results def _validate_key(self, keyname): """ @@ -365,13 +363,11 @@ def _validate_key(self, keyname): """ if keyname is None: raise ValueError('You must pass a keyname for the session' + \ - ' data content.') + ' data content.') elif keyname in ('sid', 'flash'): - raise ValueError(keyname + ' is a reserved keyname.') + raise ValueError(f'{keyname} is a reserved keyname.') - if type(keyname) != type([str, unicode]): - return str(keyname) - return keyname + return str(keyname) if type(keyname) != type([str, unicode]) else keyname def _put(self, keyname, value): """ @@ -381,11 +377,7 @@ def _put(self, keyname, value): keyname: The keyname of the mapping. value: The value of the mapping. """ - if self.writer == "datastore": - writer = _DatastoreWriter() - else: - writer = _CookieWriter() - + writer = _DatastoreWriter() if self.writer == "datastore" else _CookieWriter() writer.put(keyname, value, self) def _delete_session(self): @@ -441,7 +433,7 @@ def delete_all_sessions(cls): all_sessions_deleted = True else: for result in results: - memcache.delete('sid-' + str(result.key())) + memcache.delete(f'sid-{str(result.key())}') result.delete() while not all_data_deleted: @@ -471,7 +463,7 @@ def _clean_old_sessions(self): data_results = data_query.fetch(1000) for data_result in data_results: data_result.delete() - memcache.delete('sid-'+str(result.key())) + memcache.delete(f'sid-{str(result.key())}') result.delete() # Implement Python container methods @@ -492,12 +484,10 @@ def __getitem__(self, keyname): if keyname in self.cookie_vals: return self.cookie_vals[keyname] if hasattr(self, "session"): - mc = memcache.get('sid-'+str(self.session.key())) - if mc is not None: - if keyname in mc: - return mc[keyname] - data = self._get(keyname) - if data: + mc = memcache.get(f'sid-{str(self.session.key())}') + if mc is not None and keyname in mc: + return mc[keyname] + if data := self._get(keyname): #UNPICKLING CACHE self.cache[keyname] = data.content self.cache[keyname] = pickle.loads(data.content) self._set_memcache() @@ -574,14 +564,11 @@ def __len__(self): """ # check memcache first if hasattr(self, "session"): - mc = memcache.get('sid-'+str(self.session.key())) + mc = memcache.get(f'sid-{str(self.session.key())}') if mc is not None: return len(mc) + len(self.cookie_vals) results = self._get() - if results is not None: - return len(results) + len(self.cookie_vals) - else: - return 0 + return len(results) + len(self.cookie_vals) if results is not None else 0 return len(self.cookie_vals) def __contains__(self, keyname): @@ -603,15 +590,13 @@ def __iter__(self): """ # try memcache first if hasattr(self, "session"): - mc = memcache.get('sid-'+str(self.session.key())) + mc = memcache.get(f'sid-{str(self.session.key())}') if mc is not None: - for k in mc: - yield k + yield from mc else: for k in self._get(): yield k.keyname - for k in self.cookie_vals: - yield k + yield from self.cookie_vals def __str__(self): """ @@ -619,7 +604,7 @@ def __str__(self): """ #if self._get(): - return '{' + ', '.join(['"%s" = "%s"' % (k, self[k]) for k in self]) + '}' + return '{' + ', '.join([f'"{k}" = "{self[k]}"' for k in self]) + '}' #else: # return [] @@ -637,8 +622,7 @@ def _set_memcache(self): for sd in sessiondata: data[sd.keyname] = pickle.loads(sd.content) - memcache.set('sid-'+str(self.session.key()), data, \ - self.session_expire_time) + memcache.set(f'sid-{str(self.session.key())}', data, self.session_expire_time) def cycle_key(self): """ @@ -694,19 +678,13 @@ def items(self): """ A copy of list of (key, value) pairs """ - op = {} - for k in self: - op[k] = self[k] - return op + return {k: self[k] for k in self} def keys(self): """ List of keys. """ - l = [] - for k in self: - l.append(k) - return l + return list(self) def update(*dicts): """ @@ -721,10 +699,7 @@ def values(self): """ A copy list of values. """ - v = [] - for k in self: - v.append(self[k]) - return v + return [self[k] for k in self] def get(self, keyname, default = None): """ @@ -733,9 +708,7 @@ def get(self, keyname, default = None): try: return self.__getitem__(keyname) except KeyError: - if default is not None: - return default - return None + return default if default is not None else None def setdefault(self, keyname, default = None): """ diff --git a/chardet/chardistribution.py b/chardet/chardistribution.py index b893341..eb473d3 100644 --- a/chardet/chardistribution.py +++ b/chardet/chardistribution.py @@ -51,17 +51,12 @@ def reset(self): def feed(self, aStr, aCharLen): """feed a character with known length""" - if aCharLen == 2: - # we only care about 2-bytes character in our distribution analysis - order = self.get_order(aStr) - else: - order = -1 + order = self.get_order(aStr) if aCharLen == 2 else -1 if order >= 0: self._mTotalChars += 1 # order is valid - if order < self._mTableSize: - if 512 > self._mCharToFreqOrder[order]: - self._mFreqChars += 1 + if order < self._mTableSize and self._mCharToFreqOrder[order] < 512: + self._mFreqChars += 1 def get_confidence(self): """return confidence based on existing data""" @@ -147,17 +142,12 @@ def __init__(self): self._mTypicalDistributionRatio = BIG5_TYPICAL_DISTRIBUTION_RATIO def get_order(self, aStr): - # for big5 encoding, we are interested - # first byte range: 0xa4 -- 0xfe - # second byte range: 0x40 -- 0x7e , 0xa1 -- 0xfe - # no validation needed here. State machine has done that - if aStr[0] >= '\xA4': - if aStr[1] >= '\xA1': - return 157 * (ord(aStr[0]) - 0xA4) + ord(aStr[1]) - 0xA1 + 63 - else: - return 157 * (ord(aStr[0]) - 0xA4) + ord(aStr[1]) - 0x40 - else: + if aStr[0] < '\xA4': return -1 + if aStr[1] >= '\xA1': + return 157 * (ord(aStr[0]) - 0xA4) + ord(aStr[1]) - 0xA1 + 63 + else: + return 157 * (ord(aStr[0]) - 0xA4) + ord(aStr[1]) - 0x40 class SJISDistributionAnalysis(CharDistributionAnalysis): def __init__(self): diff --git a/chardet/charsetgroupprober.py b/chardet/charsetgroupprober.py index 5188069..2ad7206 100644 --- a/chardet/charsetgroupprober.py +++ b/chardet/charsetgroupprober.py @@ -48,9 +48,11 @@ def reset(self): def get_charset_name(self): if not self._mBestGuessProber: self.get_confidence() - if not self._mBestGuessProber: return None -# self._mBestGuessProber = self._mProbers[0] - return self._mBestGuessProber.get_charset_name() + return ( + self._mBestGuessProber.get_charset_name() + if self._mBestGuessProber + else None + ) def feed(self, aBuf): for prober in self._mProbers: @@ -89,8 +91,7 @@ def get_confidence(self): if bestConf < cf: bestConf = cf self._mBestGuessProber = prober - if not self._mBestGuessProber: return 0.0 - return bestConf + return bestConf if self._mBestGuessProber else 0.0 # else: # self._mBestGuessProber = self._mProbers[0] # return self._mBestGuessProber.get_confidence() diff --git a/chardet/escprober.py b/chardet/escprober.py index 572ed7b..62acd12 100644 --- a/chardet/escprober.py +++ b/chardet/escprober.py @@ -54,10 +54,7 @@ def get_charset_name(self): return self._mDetectedCharset def get_confidence(self): - if self._mDetectedCharset: - return 0.99 - else: - return 0.00 + return 0.99 if self._mDetectedCharset else 0.00 def feed(self, aBuf): for c in aBuf: diff --git a/chardet/eucjpprober.py b/chardet/eucjpprober.py index 46a8b38..31ac23f 100644 --- a/chardet/eucjpprober.py +++ b/chardet/eucjpprober.py @@ -50,11 +50,14 @@ def get_charset_name(self): def feed(self, aBuf): aLen = len(aBuf) - for i in range(0, aLen): + for i in range(aLen): codingState = self._mCodingSM.next_state(aBuf[i]) if codingState == eError: if constants._debug: - sys.stderr.write(self.get_charset_name() + ' prober hit error at byte ' + str(i) + '\n') + sys.stderr.write( + f'{self.get_charset_name()} prober hit error at byte {str(i)}' + + '\n' + ) self._mState = constants.eNotMe break elif codingState == eItsMe: @@ -69,13 +72,15 @@ def feed(self, aBuf): else: self._mContextAnalyzer.feed(aBuf[i-1:i+1], charLen) self._mDistributionAnalyzer.feed(aBuf[i-1:i+1], charLen) - + self._mLastChar[0] = aBuf[aLen - 1] - - if self.get_state() == constants.eDetecting: - if self._mContextAnalyzer.got_enough_data() and \ - (self.get_confidence() > constants.SHORTCUT_THRESHOLD): - self._mState = constants.eFoundIt + + if ( + self.get_state() == constants.eDetecting + and self._mContextAnalyzer.got_enough_data() + and (self.get_confidence() > constants.SHORTCUT_THRESHOLD) + ): + self._mState = constants.eFoundIt return self.get_state() diff --git a/chardet/hebrewprober.py b/chardet/hebrewprober.py index a2b1eaa..5d96448 100644 --- a/chardet/hebrewprober.py +++ b/chardet/hebrewprober.py @@ -215,7 +215,7 @@ def feed(self, aBuf): return constants.eNotMe aBuf = self.filter_high_bit_only(aBuf) - + for cur in aBuf: if cur == ' ': # We stand on a space - a word just ended @@ -227,11 +227,9 @@ def feed(self, aBuf): elif self.is_non_final(self._mPrev): # case (2) [-2:not space][-1:Non-Final letter][cur:space] self._mFinalCharVisualScore += 1 - else: - # Not standing on a space - if (self._mBeforePrev == ' ') and (self.is_final(self._mPrev)) and (cur != ' '): - # case (3) [-2:space][-1:final letter][cur:not space] - self._mFinalCharVisualScore += 1 + elif (self._mBeforePrev == ' ') and (self.is_final(self._mPrev)) and (cur != ' '): + # case (3) [-2:space][-1:final letter][cur:not space] + self._mFinalCharVisualScore += 1 self._mBeforePrev = self._mPrev self._mPrev = cur @@ -255,11 +253,7 @@ def get_charset_name(self): return VISUAL_HEBREW_NAME # Still no good, back to final letter distance, maybe it'll save the day. - if finalsub < 0.0: - return VISUAL_HEBREW_NAME - - # (finalsub > 0 - Logical) or (don't know what to do) default to Logical. - return LOGICAL_HEBREW_NAME + return VISUAL_HEBREW_NAME if finalsub < 0.0 else LOGICAL_HEBREW_NAME def get_state(self): # Remain active as long as any of the model probers are active. diff --git a/chardet/jpcntx.py b/chardet/jpcntx.py index 93db4a9..3c39af6 100644 --- a/chardet/jpcntx.py +++ b/chardet/jpcntx.py @@ -174,17 +174,19 @@ def get_order(self, aStr): if not aStr: return -1, 1 # find out current char's byte length if ((aStr[0] >= '\x81') and (aStr[0] <= '\x9F')) or \ - ((aStr[0] >= '\xE0') and (aStr[0] <= '\xFC')): + ((aStr[0] >= '\xE0') and (aStr[0] <= '\xFC')): charLen = 2 else: charLen = 1 # return its order if it is hiragana - if len(aStr) > 1: - if (aStr[0] == '\202') and \ - (aStr[1] >= '\x9F') and \ - (aStr[1] <= '\xF1'): - return ord(aStr[1]) - 0x9F, charLen + if ( + len(aStr) > 1 + and (aStr[0] == '\202') + and (aStr[1] >= '\x9F') + and (aStr[1] <= '\xF1') + ): + return ord(aStr[1]) - 0x9F, charLen return -1, charLen @@ -193,7 +195,7 @@ def get_order(self, aStr): if not aStr: return -1, 1 # find out current char's byte length if (aStr[0] == '\x8E') or \ - ((aStr[0] >= '\xA1') and (aStr[0] <= '\xFE')): + ((aStr[0] >= '\xA1') and (aStr[0] <= '\xFE')): charLen = 2 elif aStr[0] == '\x8F': charLen = 3 @@ -201,10 +203,12 @@ def get_order(self, aStr): charLen = 1 # return its order if it is hiragana - if len(aStr) > 1: - if (aStr[0] == '\xA4') and \ - (aStr[1] >= '\xA1') and \ - (aStr[1] <= '\xF3'): - return ord(aStr[1]) - 0xA1, charLen + if ( + len(aStr) > 1 + and (aStr[0] == '\xA4') + and (aStr[1] >= '\xA1') + and (aStr[1] <= '\xF3') + ): + return ord(aStr[1]) - 0xA1, charLen return -1, charLen diff --git a/chardet/latin1prober.py b/chardet/latin1prober.py index b46129b..2886d8d 100644 --- a/chardet/latin1prober.py +++ b/chardet/latin1prober.py @@ -122,15 +122,11 @@ def feed(self, aBuf): def get_confidence(self): if self.get_state() == constants.eNotMe: return 0.01 - + total = reduce(operator.add, self._mFreqCounter) if total < 0.01: confidence = 0.0 else: confidence = (self._mFreqCounter[3] / total) - (self._mFreqCounter[1] * 20.0 / total) - if confidence < 0.0: - confidence = 0.0 - # lower the confidence of latin1 so that other more accurate detector - # can take priority. - confidence = confidence * 0.5 - return confidence + confidence = max(confidence, 0.0) + return confidence * 0.5 diff --git a/chardet/mbcharsetprober.py b/chardet/mbcharsetprober.py index a813144..f461d8a 100644 --- a/chardet/mbcharsetprober.py +++ b/chardet/mbcharsetprober.py @@ -51,11 +51,14 @@ def get_charset_name(self): def feed(self, aBuf): aLen = len(aBuf) - for i in range(0, aLen): + for i in range(aLen): codingState = self._mCodingSM.next_state(aBuf[i]) if codingState == eError: if constants._debug: - sys.stderr.write(self.get_charset_name() + ' prober hit error at byte ' + str(i) + '\n') + sys.stderr.write( + f'{self.get_charset_name()} prober hit error at byte {str(i)}' + + '\n' + ) self._mState = constants.eNotMe break elif codingState == eItsMe: @@ -68,13 +71,15 @@ def feed(self, aBuf): self._mDistributionAnalyzer.feed(self._mLastChar, charLen) else: self._mDistributionAnalyzer.feed(aBuf[i-1:i+1], charLen) - + self._mLastChar[0] = aBuf[aLen - 1] - - if self.get_state() == constants.eDetecting: - if self._mDistributionAnalyzer.got_enough_data() and \ - (self.get_confidence() > constants.SHORTCUT_THRESHOLD): - self._mState = constants.eFoundIt + + if ( + self.get_state() == constants.eDetecting + and self._mDistributionAnalyzer.got_enough_data() + and (self.get_confidence() > constants.SHORTCUT_THRESHOLD) + ): + self._mState = constants.eFoundIt return self.get_state() diff --git a/chardet/sbcharsetprober.py b/chardet/sbcharsetprober.py index da07116..37213fb 100644 --- a/chardet/sbcharsetprober.py +++ b/chardet/sbcharsetprober.py @@ -80,17 +80,19 @@ def feed(self, aBuf): self._mSeqCounters[self._mModel['precedenceMatrix'][(order * SAMPLE_SIZE) + self._mLastOrder]] += 1 self._mLastOrder = order - if self.get_state() == constants.eDetecting: - if self._mTotalSeqs > SB_ENOUGH_REL_THRESHOLD: - cf = self.get_confidence() - if cf > POSITIVE_SHORTCUT_THRESHOLD: - if constants._debug: - sys.stderr.write('%s confidence = %s, we have a winner\n' % (self._mModel['charsetName'], cf)) - self._mState = constants.eFoundIt - elif cf < NEGATIVE_SHORTCUT_THRESHOLD: - if constants._debug: - sys.stderr.write('%s confidence = %s, below negative shortcut threshhold %s\n' % (self._mModel['charsetName'], cf, NEGATIVE_SHORTCUT_THRESHOLD)) - self._mState = constants.eNotMe + if ( + self.get_state() == constants.eDetecting + and self._mTotalSeqs > SB_ENOUGH_REL_THRESHOLD + ): + cf = self.get_confidence() + if cf > POSITIVE_SHORTCUT_THRESHOLD: + if constants._debug: + sys.stderr.write('%s confidence = %s, we have a winner\n' % (self._mModel['charsetName'], cf)) + self._mState = constants.eFoundIt + elif cf < NEGATIVE_SHORTCUT_THRESHOLD: + if constants._debug: + sys.stderr.write('%s confidence = %s, below negative shortcut threshhold %s\n' % (self._mModel['charsetName'], cf, NEGATIVE_SHORTCUT_THRESHOLD)) + self._mState = constants.eNotMe return self.get_state() diff --git a/chardet/sjisprober.py b/chardet/sjisprober.py index fea2690..935b2b1 100644 --- a/chardet/sjisprober.py +++ b/chardet/sjisprober.py @@ -50,11 +50,14 @@ def get_charset_name(self): def feed(self, aBuf): aLen = len(aBuf) - for i in range(0, aLen): + for i in range(aLen): codingState = self._mCodingSM.next_state(aBuf[i]) if codingState == eError: if constants._debug: - sys.stderr.write(self.get_charset_name() + ' prober hit error at byte ' + str(i) + '\n') + sys.stderr.write( + f'{self.get_charset_name()} prober hit error at byte {str(i)}' + + '\n' + ) self._mState = constants.eNotMe break elif codingState == eItsMe: @@ -69,13 +72,15 @@ def feed(self, aBuf): else: self._mContextAnalyzer.feed(aBuf[i + 1 - charLen : i + 3 - charLen], charLen) self._mDistributionAnalyzer.feed(aBuf[i - 1 : i + 1], charLen) - + self._mLastChar[0] = aBuf[aLen - 1] - - if self.get_state() == constants.eDetecting: - if self._mContextAnalyzer.got_enough_data() and \ - (self.get_confidence() > constants.SHORTCUT_THRESHOLD): - self._mState = constants.eFoundIt + + if ( + self.get_state() == constants.eDetecting + and self._mContextAnalyzer.got_enough_data() + and (self.get_confidence() > constants.SHORTCUT_THRESHOLD) + ): + self._mState = constants.eFoundIt return self.get_state() diff --git a/chardet/utf8prober.py b/chardet/utf8prober.py index c1792bb..d4da03c 100644 --- a/chardet/utf8prober.py +++ b/chardet/utf8prober.py @@ -60,17 +60,18 @@ def feed(self, aBuf): if self._mCodingSM.get_current_charlen() >= 2: self._mNumOfMBChar += 1 - if self.get_state() == constants.eDetecting: - if self.get_confidence() > constants.SHORTCUT_THRESHOLD: - self._mState = constants.eFoundIt + if ( + self.get_state() == constants.eDetecting + and self.get_confidence() > constants.SHORTCUT_THRESHOLD + ): + self._mState = constants.eFoundIt return self.get_state() def get_confidence(self): unlike = 0.99 - if self._mNumOfMBChar < 6: - for i in range(0, self._mNumOfMBChar): - unlike = unlike * ONE_CHAR_PROB - return 1.0 - unlike - else: + if self._mNumOfMBChar >= 6: return unlike + for _ in range(self._mNumOfMBChar): + unlike = unlike * ONE_CHAR_PROB + return 1.0 - unlike diff --git a/feedparser.py b/feedparser.py index bb802df..39ab4fc 100755 --- a/feedparser.py +++ b/feedparser.py @@ -11,6 +11,7 @@ Recommended: CJKCodecs and iconv_codec """ + __version__ = "4.1"# + "$Revision: 1.92 $"[11:15] + "-cvs" __license__ = """Copyright (c) 2002-2006, Mark Pilgrim, All rights reserved. @@ -45,7 +46,7 @@ # HTTP "User-Agent" header to send to servers when downloading feeds. # If you are embedding feedparser in a larger application, you should # change this to your application name and URL. -USER_AGENT = "UniversalFeedParser/%s +http://feedparser.org/" % __version__ +USER_AGENT = f"UniversalFeedParser/{__version__} +http://feedparser.org/" # HTTP "Accept" header to send to servers when downloading feeds. If you don't # want to send an Accept header, set this to None. @@ -164,10 +165,7 @@ class UndeclaredNamespace(Exception): pass # Python 2.1 does not have dict from UserDict import UserDict def dict(aList): - rc = {} - for k, v in aList: - rc[k] = v - return rc + return dict(aList) class FeedParserDict(UserDict): keymap = {'channel': 'feed', @@ -208,10 +206,7 @@ def __setitem__(self, key, value): return UserDict.__setitem__(self, key, value) def get(self, key, default=None): - if self.has_key(key): - return self[key] - else: - return default + return self[key] if self.has_key(key) else default def setdefault(self, key, value): if not self.has_key(key): @@ -233,7 +228,7 @@ def __getattr__(self, key): assert not key.startswith('_') return self.__getitem__(key) except: - raise AttributeError, "object has no attribute '%s'" % key + raise (AttributeError, f"object has no attribute '{key}'") def __setattr__(self, key, value): if key.startswith('_') or key == 'data': @@ -500,12 +495,9 @@ def handle_charref(self, ref): if not self.elementstack: return ref = ref.lower() if ref in ('34', '38', '39', '60', '62', 'x22', 'x26', 'x27', 'x3c', 'x3e'): - text = '&#%s;' % ref + text = f'&#{ref};' else: - if ref[0] == 'x': - c = int(ref[1:], 16) - else: - c = int(ref) + c = int(ref[1:], 16) if ref[0] == 'x' else int(ref) text = unichr(c).encode('utf-8') self.elementstack[-1][2].append(text) @@ -514,7 +506,7 @@ def handle_entityref(self, ref): if not self.elementstack: return if _debug: sys.stderr.write('entering handle_entityref with %s\n' % ref) if ref in ('lt', 'gt', 'quot', 'amp', 'apos'): - text = '&%s;' % ref + text = f'&{ref};' else: # entity resolution graciously donated by Aaron Swartz def name2cp(k): @@ -522,11 +514,11 @@ def name2cp(k): if hasattr(htmlentitydefs, 'name2codepoint'): # requires Python 2.3 return htmlentitydefs.name2codepoint[k] k = htmlentitydefs.entitydefs[k] - if k.startswith('&#') and k.endswith(';'): - return int(k[2:-1]) # not in latin-1 - return ord(k) + return int(k[2:-1]) if k.startswith('&#') and k.endswith(';') else ord(k) + try: name2cp(ref) - except KeyError: text = '&%s;' % ref + except KeyError: + text = f'&{ref};' else: text = unichr(name2cp(ref)).encode('utf-8') self.elementstack[-1][2].append(text) @@ -601,7 +593,7 @@ def push(self, element, expectingText): def pop(self, element, stripWhitespace=1): if not self.elementstack: return if self.elementstack[-1][0] != element: return - + element, expectingText, pieces = self.elementstack.pop() output = ''.join(pieces) if stripWhitespace: @@ -616,11 +608,11 @@ def pop(self, element, stripWhitespace=1): pass except binascii.Incomplete: pass - + # resolve relative URIs if (element in self.can_be_relative_uri) and output: output = self.resolveURI(output) - + # decode entities within embedded markup if not self.contentparams.get('base64', 0): output = self.decodeEntities(element, output) @@ -636,14 +628,20 @@ def pop(self, element, stripWhitespace=1): pass # resolve relative URIs within embedded markup - if self.mapContentType(self.contentparams.get('type', 'text/html')) in self.html_types: - if element in self.can_contain_relative_uris: - output = _resolveRelativeURIs(output, self.baseuri, self.encoding) - + if ( + self.mapContentType(self.contentparams.get('type', 'text/html')) + in self.html_types + and element in self.can_contain_relative_uris + ): + output = _resolveRelativeURIs(output, self.baseuri, self.encoding) + # sanitize embedded markup - if self.mapContentType(self.contentparams.get('type', 'text/html')) in self.html_types: - if element in self.can_contain_dangerous_markup: - output = _sanitizeHTML(output, self.encoding) + if ( + self.mapContentType(self.contentparams.get('type', 'text/html')) + in self.html_types + and element in self.can_contain_dangerous_markup + ): + output = _sanitizeHTML(output, self.encoding) if self.encoding and type(output) != type(u''): try: @@ -654,7 +652,7 @@ def pop(self, element, stripWhitespace=1): # categories/tags/keywords/whatever are handled in _end_category if element == 'category': return output - + # store output in appropriate place(s) if self.inentry and not self.insource: if element == 'content': @@ -673,7 +671,7 @@ def pop(self, element, stripWhitespace=1): if self.incontent: contentparams = copy.deepcopy(self.contentparams) contentparams['value'] = output - self.entries[-1][element + '_detail'] = contentparams + self.entries[-1][f'{element}_detail'] = contentparams elif (self.infeed or self.insource) and (not self.intextinput) and (not self.inimage): context = self._getContext() if element == 'description': @@ -684,7 +682,7 @@ def pop(self, element, stripWhitespace=1): elif self.incontent: contentparams = copy.deepcopy(self.contentparams) contentparams['value'] = output - context[element + '_detail'] = contentparams + context[f'{element}_detail'] = contentparams return output def pushContent(self, tag, attrsD, defaultContentType, expectingText): @@ -721,13 +719,10 @@ def _isBase64(self, attrsD, contentparams): return 0 if self.contentparams['type'].endswith('+xml'): return 0 - if self.contentparams['type'].endswith('/xml'): - return 0 - return 1 + return 0 if self.contentparams['type'].endswith('/xml') else 1 def _itsAnHrefDamnIt(self, attrsD): - href = attrsD.get('url', attrsD.get('uri', attrsD.get('href', None))) - if href: + if href := attrsD.get('url', attrsD.get('uri', attrsD.get('href', None))): try: del attrsD['url'] except KeyError: @@ -744,14 +739,13 @@ def _save(self, key, value): context.setdefault(key, value) def _start_rss(self, attrsD): - versionmap = {'0.91': 'rss091u', - '0.92': 'rss092', - '0.93': 'rss093', - '0.94': 'rss094'} if not self.version: attr_version = attrsD.get('version', '') - version = versionmap.get(attr_version) - if version: + versionmap = {'0.91': 'rss091u', + '0.92': 'rss092', + '0.93': 'rss093', + '0.94': 'rss094'} + if version := versionmap.get(attr_version): self.version = version elif attr_version.startswith('2.'): self.version = 'rss20' @@ -778,16 +772,12 @@ def _cdf_common(self, attrsD): def _start_feed(self, attrsD): self.infeed = 1 - versionmap = {'0.1': 'atom01', - '0.2': 'atom02', - '0.3': 'atom03'} if not self.version: attr_version = attrsD.get('version') - version = versionmap.get(attr_version) - if version: - self.version = version - else: - self.version = 'atom' + versionmap = {'0.1': 'atom01', + '0.2': 'atom02', + '0.3': 'atom03'} + self.version = version if (version := versionmap.get(attr_version)) else 'atom' def _end_channel(self): self.infeed = 0 @@ -942,17 +932,16 @@ def _end_email(self): def _getContext(self): if self.insource: - context = self.sourcedata + return self.sourcedata elif self.inentry: - context = self.entries[-1] + return self.entries[-1] else: - context = self.feeddata - return context + return self.feeddata def _save_author(self, key, value, prefix='author'): context = self._getContext() - context.setdefault(prefix + '_detail', FeedParserDict()) - context[prefix + '_detail'][key] = value + context.setdefault(f'{prefix}_detail', FeedParserDict()) + context[f'{prefix}_detail'][key] = value self._sync_author_detail() def _save_contributor(self, key, value): @@ -962,12 +951,11 @@ def _save_contributor(self, key, value): def _sync_author_detail(self, key='author'): context = self._getContext() - detail = context.get('%s_detail' % key) - if detail: + if detail := context.get(f'{key}_detail'): name = detail.get('name') email = detail.get('email') if name and email: - context[key] = '%s (%s)' % (name, email) + context[key] = f'{name} ({email})' elif name: context[key] = name elif email: @@ -977,7 +965,7 @@ def _sync_author_detail(self, key='author'): if not author: return emailmatch = re.search(r'''(([a-zA-Z0-9\_\-\.\+]+)@((\[[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.)|(([a-zA-Z0-9\-]+\.)+))([a-zA-Z]{2,4}|[0-9]{1,3})(\]?))''', author) if not emailmatch: return - email = emailmatch.group(0) + email = emailmatch[0] # probably a better way to do the following, but it passes all the tests author = author.replace(email, '') author = author.replace('()', '') @@ -987,9 +975,9 @@ def _sync_author_detail(self, key='author'): if author and (author[-1] == ')'): author = author[:-1] author = author.strip() - context.setdefault('%s_detail' % key, FeedParserDict()) - context['%s_detail' % key]['name'] = author - context['%s_detail' % key]['email'] = email + context.setdefault(f'{key}_detail', FeedParserDict()) + context[f'{key}_detail']['name'] = author + context[f'{key}_detail']['email'] = email def _start_subtitle(self, attrsD): self.pushContent('subtitle', attrsD, 'text/plain', 1) @@ -1016,8 +1004,7 @@ def _start_item(self, attrsD): self.push('item', 0) self.inentry = 1 self.guidislink = 0 - id = self._getAttribute(attrsD, 'rdf:about') - if id: + if id := self._getAttribute(attrsD, 'rdf:about'): context = self._getContext() context['id'] = id self._cdf_common(attrsD) @@ -1090,8 +1077,7 @@ def _end_expirationdate(self): def _start_cc_license(self, attrsD): self.push('license', 1) - value = self._getAttribute(attrsD, 'rdf:resource') - if value: + if value := self._getAttribute(attrsD, 'rdf:resource'): self.elementstack[-1][2].append(value) self.pop('license') @@ -1255,8 +1241,7 @@ def _start_admin_generatoragent(self, attrsD): def _start_admin_errorreportsto(self, attrsD): self.push('errorreportsto', 1) - value = self._getAttribute(attrsD, 'rdf:resource') - if value: + if value := self._getAttribute(attrsD, 'rdf:resource'): self.elementstack[-1][2].append(value) self.pop('errorreportsto') @@ -1281,8 +1266,7 @@ def _end_summary(self): def _start_enclosure(self, attrsD): attrsD = self._itsAnHrefDamnIt(attrsD) self._getContext().setdefault('enclosures', []).append(FeedParserDict(attrsD)) - href = attrsD.get('href') - if href: + if href := attrsD.get('href'): context = self._getContext() if not context.get('id'): context['id'] = href @@ -1297,8 +1281,7 @@ def _end_source(self): def _start_content(self, attrsD): self.pushContent('content', attrsD, 'text/plain', 1) - src = attrsD.get('src') - if src: + if src := attrsD.get('src'): self.contentparams['src'] = src self.push('content', 1) @@ -1315,8 +1298,8 @@ def _start_content_encoded(self, attrsD): def _end_content(self): copyToDescription = self.mapContentType(self.contentparams.get('type')) in (['text/plain'] + self.html_types) - value = self.popContent('content') if copyToDescription: + value = self.popContent('content') self._save('description', value) _end_body = _end_content _end_xhtml_body = _end_content @@ -1331,11 +1314,11 @@ def _start_itunes_image(self, attrsD): def _end_itunes_block(self): value = self.pop('itunes_block', 0) - self._getContext()['itunes_block'] = (value == 'yes') and 1 or 0 + self._getContext()['itunes_block'] = 1 if value == 'yes' else 0 def _end_itunes_explicit(self): value = self.pop('itunes_explicit', 0) - self._getContext()['itunes_explicit'] = (value == 'yes') and 1 or 0 + self._getContext()['itunes_explicit'] = 1 if value == 'yes' else 0 if _XML_AVAILABLE: class _StrictFeedParser(_FeedParserMixin, xml.sax.handler.ContentHandler): @@ -1392,13 +1375,9 @@ def characters(self, text): def endElementNS(self, name, qname): namespace, localname = name lowernamespace = str(namespace or '').lower() - if qname and qname.find(':') > 0: - givenprefix = qname.split(':')[0] - else: - givenprefix = '' - prefix = self._matchnamespaces.get(lowernamespace, givenprefix) - if prefix: - localname = prefix + ':' + localname + givenprefix = qname.split(':')[0] if qname and qname.find(':') > 0 else '' + if prefix := self._matchnamespaces.get(lowernamespace, givenprefix): + localname = f'{prefix}:{localname}' localname = str(localname).lower() self.unknown_endtag(localname) @@ -1425,10 +1404,7 @@ def reset(self): def _shorttag_replace(self, match): tag = match.group(1) - if tag in self.elements_no_end_tag: - return '<' + tag + ' />' - else: - return '<' + tag + '>' + return f'<{tag} />' if tag in self.elements_no_end_tag else f'<{tag}>' def feed(self, data): data = re.compile(r'' % locals()) else: @@ -1509,13 +1487,10 @@ def _scan_name(self, i, declstartpos): n = len(rawdata) if i == n: return None, -1 - m = self._new_declname_match(rawdata, i) - if m: + if m := self._new_declname_match(rawdata, i): s = m.group() name = s.strip() - if (i + len(s)) == n: - return None, -1 # end of buffer - return name.lower(), m.end() + return (None, -1) if (i + len(s)) == n else (name.lower(), m.end()) else: self.handle_data(rawdata) # self.updatepos(declstartpos, i) @@ -1622,7 +1597,7 @@ def reset(self): self.unacceptablestack = 0 def unknown_starttag(self, tag, attrs): - if not tag in self.acceptable_elements: + if tag not in self.acceptable_elements: if tag in self.unacceptable_elements_with_end_tag: self.unacceptablestack += 1 return @@ -1631,7 +1606,7 @@ def unknown_starttag(self, tag, attrs): _BaseHTMLProcessor.unknown_starttag(self, tag, attrs) def unknown_endtag(self, tag): - if not tag in self.acceptable_elements: + if tag not in self.acceptable_elements: if tag in self.unacceptable_elements_with_end_tag: self.unacceptablestack -= 1 return @@ -1688,7 +1663,7 @@ def _tidy(data, **kwargs): class _FeedURLHandler(urllib2.HTTPDigestAuthHandler, urllib2.HTTPRedirectHandler, urllib2.HTTPDefaultErrorHandler): def http_error_default(self, req, fp, code, msg, headers): - if ((code / 100) == 3) and (code != 304): + if code == 300: return self.http_error_302(req, fp, code, msg, headers) infourl = urllib.addinfourl(fp, headers, req.get_full_url()) infourl.status = code @@ -1784,7 +1759,7 @@ def _open_resource(url_file_stream_or_string, etag, modified, agent, referrer, h if realhost: user_passwd, realhost = urllib.splituser(realhost) if user_passwd: - url_file_stream_or_string = '%s://%s%s' % (urltype, realhost, rest) + url_file_stream_or_string = f'{urltype}://{realhost}{rest}' auth = base64.encodestring(user_passwd).strip() # try to open with urllib2 (to use optional headers) request = urllib2.Request(url_file_stream_or_string) @@ -1810,7 +1785,7 @@ def _open_resource(url_file_stream_or_string, etag, modified, agent, referrer, h else: request.add_header('Accept-encoding', '') if auth: - request.add_header('Authorization', 'Basic %s' % auth) + request.add_header('Authorization', f'Basic {auth}') if ACCEPT_HEADER: request.add_header('Accept', ACCEPT_HEADER) request.add_header('A-IM', 'feed') # RFC 3229 support @@ -1820,7 +1795,7 @@ def _open_resource(url_file_stream_or_string, etag, modified, agent, referrer, h return opener.open(request) finally: opener.close() # JohnD - + # try to open with native open function (if url_file_stream_or_string is a filename) try: return open(url_file_stream_or_string) @@ -1875,10 +1850,7 @@ def _parse_date_iso8601(dateString): if m.span() == (0, 0): return params = m.groupdict() ordinal = params.get('ordinal', 0) - if ordinal: - ordinal = int(ordinal) - else: - ordinal = 0 + ordinal = int(ordinal) if ordinal else 0 year = params.get('year', '--') if not year or year == '--': year = time.gmtime()[0] @@ -1889,25 +1861,18 @@ def _parse_date_iso8601(dateString): year = int(year) month = params.get('month', '-') if not month or month == '-': - # ordinals are NOT normalized by mktime, we simulate them - # by setting month=1, day=ordinal - if ordinal: - month = 1 - else: - month = time.gmtime()[1] + month = 1 if ordinal else time.gmtime()[1] month = int(month) day = params.get('day', 0) - if not day: - # see above - if ordinal: - day = ordinal - elif params.get('century', 0) or \ - params.get('year', 0) or params.get('month', 0): - day = 1 - else: - day = time.gmtime()[2] - else: + if day: day = int(day) + elif ordinal: + day = ordinal + elif params.get('century', 0) or \ + params.get('year', 0) or params.get('month', 0): + day = 1 + else: + day = time.gmtime()[2] # special case of the century - is the first year of the 21st century # 2000 or 2001 ? The debate goes on... if 'century' in params.keys(): @@ -1980,7 +1945,7 @@ def _parse_date_nate(dateString): hour += 12 hour = str(hour) if len(hour) == 1: - hour = '0' + hour + hour = f'0{hour}' w3dtfdate = '%(year)s-%(month)s-%(day)sT%(hour)s:%(minute)s:%(second)s%(zonediff)s' % \ {'year': m.group(1), 'month': m.group(2), 'day': m.group(3),\ 'hour': hour, 'minute': m.group(6), 'second': m.group(7),\ @@ -2086,10 +2051,10 @@ def _parse_date_hungarian(dateString): month = _hungarian_months[m.group(2)] day = m.group(3) if len(day) == 1: - day = '0' + day + day = f'0{day}' hour = m.group(4) if len(hour) == 1: - hour = '0' + hour + hour = f'0{hour}' except: return w3dtfdate = '%(year)s-%(month)s-%(day)sT%(hour)s:%(minute)s%(zonediff)s' % \ @@ -2108,7 +2073,7 @@ def _parse_date_w3dtf(dateString): def __extract_date(m): year = int(m.group('year')) if year < 100: - year = 100 * int(time.gmtime()[0] / 100) + int(year) + year = 100 * int(time.gmtime()[0] / 100) + year if year < 1000: return 0, 0, 0 julian = m.group('julian') @@ -2140,10 +2105,7 @@ def __extract_date(m): else: month = int(month) day = m.group('day') - if day: - day = int(day) - else: - day = 1 + day = int(day) if day else 1 return year, month, day def __extract_time(m): @@ -2155,10 +2117,7 @@ def __extract_time(m): hours = int(hours) minutes = int(m.group('minutes')) seconds = m.group('seconds') - if seconds: - seconds = int(seconds) - else: - seconds = 0 + seconds = int(seconds) if seconds else 0 return hours, minutes, seconds def __extract_tzd(m): @@ -2172,14 +2131,9 @@ def __extract_tzd(m): return 0 hours = int(m.group('tzdhours')) minutes = m.group('tzdminutes') - if minutes: - minutes = int(minutes) - else: - minutes = 0 + minutes = int(minutes) if minutes else 0 offset = (hours*60 + minutes) * 60 - if tzd[0] == '+': - return -offset - return offset + return -offset if tzd[0] == '+' else offset __date_re = ('(?P\d\d\d\d)' '(?:(?P-|)' @@ -2190,7 +2144,7 @@ def __extract_tzd(m): __time_re = ('(?P\d\d)(?P:|)(?P\d\d)' '(?:(?P=tsep)(?P\d\d(?:[.,]\d+)?))?' + __tzd_re) - __datetime_re = '%s(?:T%s)?' % (__date_re, __time_re) + __datetime_re = f'{__date_re}(?:T{__time_re})?' __datetime_rx = re.compile(__datetime_re) m = __datetime_rx.match(dateString) if (m is None) or (m.group() != dateString): return @@ -2214,8 +2168,7 @@ def _parse_date_rfc822(dateString): dateString = " ".join(data) if len(data) < 5: dateString += ' 00:00:00 GMT' - tm = rfc822.parsedate_tz(dateString) - if tm: + if tm := rfc822.parsedate_tz(dateString): return time.gmtime(rfc822.mktime_tz(tm)) # rfc822.py defines several time zones, but we define some extra ones. # 'ET' is equivalent to 'EST', etc. @@ -2346,9 +2299,6 @@ def _parseHTTPContentType(content_type): # UTF-8 with BOM sniffed_xml_encoding = 'utf-8' xml_data = unicode(xml_data[3:], 'utf-8').encode('utf-8') - else: - # ASCII-compatible - pass xml_encoding_match = re.compile('^<\?.*encoding=[\'"](.*?)[\'"].*\?>').match(xml_data) except: xml_encoding_match = None @@ -2439,10 +2389,7 @@ def _stripDoctype(data): doctype_pattern = re.compile(r']*?)>', re.MULTILINE) doctype_results = doctype_pattern.findall(data) doctype = doctype_results and doctype_results[0] or '' - if doctype.lower().count('netscape'): - version = 'rss091n' - else: - version = None + version = 'rss091n' if doctype.lower().count('netscape') else None data = doctype_pattern.sub('', data) return version, data diff --git a/gae_utils.py b/gae_utils.py index c897fe5..c1699ef 100644 --- a/gae_utils.py +++ b/gae_utils.py @@ -35,43 +35,36 @@ from appengine_utilities.sessions import Session -DEV = False -if 'DEV' in os.listdir('.'): - DEV = True - +DEV = 'DEV' in os.listdir('.') ENV = Environment(loader=FileSystemLoader('templates/'), autoescape=True, auto_reload=DEV, line_statement_prefix='#') def cond(c, str, other=''): - if c: - return str - return other + return str if c else other ENV.globals['cond'] = cond ENV.globals['urlquote'] = urlquote def number_format(num, places=0): - """Format a number with grouped thousands and given decimal places""" - - # in utils, commify(n) adds commas to an int - - places = max(0,places) - tmp = "%.*f" % (places, num) - point = tmp.find(".") - integer = (point == -1) and tmp or tmp[:point] - decimal = (point != -1) and tmp[point:] or "" - - count = 0 - formatted = [] - for i in range(len(integer), 0, -1): - count += 1 - formatted.append(integer[i - 1]) - if count % 3 == 0 and i - 1: - formatted.append(",") - - integer = "".join(formatted[::-1]) - return integer+decimal + """Format a number with grouped thousands and given decimal places""" + + # in utils, commify(n) adds commas to an int + + places = max(0,places) + tmp = "%.*f" % (places, num) + point = tmp.find(".") + integer = (point == -1) and tmp or tmp[:point] + decimal = (point != -1) and tmp[point:] or "" + + formatted = [] + for count, i in enumerate(range(len(integer), 0, -1), start=1): + formatted.append(integer[i - 1]) + if count % 3 == 0 and i - 1: + formatted.append(",") + + integer = "".join(formatted[::-1]) + return integer+decimal ENV.filters['number_format'] = number_format @@ -91,11 +84,13 @@ def simple_name(name, sep='-'): def valid_email(email): if len(email) <= 5: # a@b.ca return False - - if not re.match("^.+\\@(\\[?)[a-zA-Z0-9\\-\\.]+\\.([a-zA-Z]{2,3}|[0-9]{1,3})(\\]?)$", email): - return False - - return True + + return bool( + re.match( + "^.+\\@(\\[?)[a-zA-Z0-9\\-\\.]+\\.([a-zA-Z]{2,3}|[0-9]{1,3})(\\]?)$", + email, + ) + ) class RenderHandler(webapp.RequestHandler): @@ -112,5 +107,6 @@ def get_session(): secure, httponly = False, True if DEV: secure = False - session = Session(cookie_name='andrewtrusty.appspot', secure=secure, httponly=httponly) - return session + return Session( + cookie_name='andrewtrusty.appspot', secure=secure, httponly=httponly + ) diff --git a/jinja2/bccache.py b/jinja2/bccache.py index 2c57616..97c876e 100644 --- a/jinja2/bccache.py +++ b/jinja2/bccache.py @@ -14,6 +14,7 @@ :copyright: Copyright 2008 by Armin Ronacher. :license: BSD. """ + from os import path, listdir import marshal import tempfile @@ -28,7 +29,7 @@ bc_version = 1 -bc_magic = 'j2' + pickle.dumps(bc_version, 2) +bc_magic = f'j2{pickle.dumps(bc_version, 2)}' class Bucket(object): @@ -64,10 +65,7 @@ def load_bytecode(self, f): return # now load the code. Because marshal is not able to load # from arbitrary streams we have to work around that - if isinstance(f, file): - self.code = marshal.load(f) - else: - self.code = marshal.loads(f.read()) + self.code = marshal.load(f) if isinstance(f, file) else marshal.loads(f.read()) def write_bytecode(self, f): """Dump the bytecode into the file or file like object passed.""" @@ -146,7 +144,7 @@ def get_cache_key(self, name, filename=None): if filename is not None: if isinstance(filename, unicode): filename = filename.encode('utf-8') - hash.update('|' + filename) + hash.update(f'|{filename}') return hash.hexdigest() def get_source_checksum(self, source): diff --git a/jinja2/compiler.py b/jinja2/compiler.py index 54a80ba..e2b1c7b 100644 --- a/jinja2/compiler.py +++ b/jinja2/compiler.py @@ -53,10 +53,7 @@ def has_safe_repr(value): xrange, Markup)): return True if isinstance(value, (tuple, list, set, frozenset)): - for item in value: - if not has_safe_repr(item): - return False - return True + return all(has_safe_repr(item) for item in value) elif isinstance(value, dict): for key, value in value.iteritems(): if not has_safe_repr(key): @@ -111,17 +108,15 @@ def is_declared(self, name, local_only=False): """Check if a name is declared in this or an outer scope.""" if name in self.declared_locally or name in self.declared_parameter: return True - if local_only: - return False - return name in self.declared + return False if local_only else name in self.declared def find_shadowed(self, extra=()): """Find all the shadowed names. extra is an iterable of variables that may be defined with `add_special` which may occour scoped. """ - return (self.declared | self.outer_undeclared) & \ - (self.declared_locally | self.declared_parameter) | \ - set(x for x in extra if self.is_declared(x)) + return (self.declared | self.outer_undeclared) & ( + self.declared_locally | self.declared_parameter + ) | {x for x in extra if self.is_declared(x)} class Frame(object): @@ -372,14 +367,14 @@ def temporary_identifier(self): def buffer(self, frame): """Enable buffering for the frame from that point onwards.""" frame.buffer = self.temporary_identifier() - self.writeline('%s = []' % frame.buffer) + self.writeline(f'{frame.buffer} = []') def return_buffer_contents(self, frame): """Return the buffer contents of the frame.""" if self.environment.autoescape: - self.writeline('return Markup(concat(%s))' % frame.buffer) + self.writeline(f'return Markup(concat({frame.buffer}))') else: - self.writeline('return concat(%s)' % frame.buffer) + self.writeline(f'return concat({frame.buffer})') def indent(self): """Indent by one.""" @@ -394,7 +389,7 @@ def start_write(self, frame, node=None): if frame.buffer is None: self.writeline('yield ', node) else: - self.writeline('%s.append(' % frame.buffer, node) + self.writeline(f'{frame.buffer}.append(', node) def end_write(self, frame): """End the writing process started by `start_write`.""" @@ -456,14 +451,10 @@ def signature(self, node, frame, extra_kwargs=None): error could occour. The extra keyword arguments should be given as python dict. """ - # if any of the given keyword arguments is a python keyword - # we have to make sure that no invalid call is created. - kwarg_workaround = False - for kwarg in chain((x.key for x in node.kwargs), extra_kwargs or ()): - if is_python_keyword(kwarg): - kwarg_workaround = True - break - + kwarg_workaround = any( + is_python_keyword(kwarg) + for kwarg in chain((x.key for x in node.kwargs), extra_kwargs or ()) + ) for arg in node.args: self.write(', ') self.visit(arg, frame) @@ -474,7 +465,7 @@ def signature(self, node, frame, extra_kwargs=None): self.visit(kwarg, frame) if extra_kwargs is not None: for key, value in extra_kwargs.iteritems(): - self.write(', %s=%s' % (key, value)) + self.write(f', {key}={value}') if node.dyn_args: self.write(', *') self.visit(node.dyn_args, frame) @@ -533,24 +524,24 @@ def push_scope(self, frame, extra_vars=()): aliases = {} for name in frame.identifiers.find_shadowed(extra_vars): aliases[name] = ident = self.temporary_identifier() - self.writeline('%s = l_%s' % (ident, name)) - to_declare = set() - for name in frame.identifiers.declared_locally: - if name not in aliases: - to_declare.add('l_' + name) - if to_declare: + self.writeline(f'{ident} = l_{name}') + if to_declare := { + f'l_{name}' + for name in frame.identifiers.declared_locally + if name not in aliases + }: self.writeline(' = '.join(to_declare) + ' = missing') return aliases def pop_scope(self, aliases, frame): """Restore all aliases and delete unused variables.""" for name, alias in aliases.iteritems(): - self.writeline('l_%s = %s' % (name, alias)) - to_delete = set() - for name in frame.identifiers.declared_locally: - if name not in aliases: - to_delete.add('l_' + name) - if to_delete: + self.writeline(f'l_{name} = {alias}') + if to_delete := { + f'l_{name}' + for name in frame.identifiers.declared_locally + if name not in aliases + }: self.writeline('del ' + ', '.join(to_delete)) def function_scoping(self, node, frame, children=None, @@ -573,17 +564,14 @@ def function_scoping(self, node, frame, children=None, func_frame = frame.inner() func_frame.inspect(children, hard_scope=True) - # variables that are undeclared (accessed before declaration) and - # declared locally *and* part of an outside scope raise a template - # assertion error. Reason: we can't generate reasonable code from - # it without aliasing all the variables. XXX: alias them ^^ - overriden_closure_vars = ( - func_frame.identifiers.undeclared & - func_frame.identifiers.declared & - (func_frame.identifiers.declared_locally | - func_frame.identifiers.declared_parameter) - ) - if overriden_closure_vars: + if overriden_closure_vars := ( + func_frame.identifiers.undeclared + & func_frame.identifiers.declared + & ( + func_frame.identifiers.declared_locally + | func_frame.identifiers.declared_parameter + ) + ): self.fail('It\'s not possible to set and access variables ' 'derived from an outer scope! (affects: %s' % ', '.join(sorted(overriden_closure_vars)), node.lineno) @@ -602,7 +590,7 @@ def function_scoping(self, node, frame, children=None, func_frame.accesses_kwargs = False func_frame.accesses_varargs = False func_frame.accesses_caller = False - func_frame.arguments = args = ['l_' + x.name for x in node.args] + func_frame.arguments = args = [f'l_{x.name}' for x in node.args] undeclared = find_undeclared(children, ('caller', 'kwargs', 'varargs')) @@ -626,7 +614,7 @@ def macro_body(self, node, frame, children=None): # macros are delayed, they never require output checks frame.require_output_check = False args = frame.arguments - self.writeline('def macro(%s):' % ', '.join(args), node) + self.writeline(f"def macro({', '.join(args)}):", node) self.indent() self.buffer(frame) self.pull_locals(frame) @@ -656,7 +644,7 @@ def position(self, node): """Return a human readable position for the node.""" rv = 'line %d' % node.lineno if self.name is not None: - rv += ' in ' + repr(self.name) + rv += f' in {repr(self.name)}' return rv # -- Statement Visitors @@ -684,10 +672,9 @@ def visit_Template(self, node, frame=None): self.import_aliases[imp] = alias = self.temporary_identifier() if '.' in imp: module, obj = imp.rsplit('.', 1) - self.writeline('from %s import %s as %s' % - (module, obj, alias)) + self.writeline(f'from {module} import {obj} as {alias}') else: - self.writeline('import %s as %s' % (imp, alias)) + self.writeline(f'import {imp} as {alias}') # add the load name self.writeline('name = %r' % self.name) @@ -728,8 +715,9 @@ def visit_Template(self, node, frame=None): block_frame = Frame() block_frame.inspect(block.body) block_frame.block = name - self.writeline('def block_%s(context, environment=environment):' - % name, block, 1) + self.writeline( + f'def block_{name}(context, environment=environment):', block, 1 + ) self.indent() undeclared = find_undeclared(block.body, ('self', 'super')) if 'self' in undeclared: @@ -836,7 +824,7 @@ def visit_Include(self, node, frame): def visit_Import(self, node, frame): """Visit regular imports.""" - self.writeline('l_%s = ' % node.target, node) + self.writeline(f'l_{node.target} = ', node) if frame.toplevel: self.write('context.vars[%r] = ' % node.target) self.write('environment.get_template(') @@ -860,8 +848,8 @@ def visit_FromImport(self, node, frame): else: self.write('module') - var_names = [] discarded_names = [] + var_names = [] for name in node.names: if isinstance(name, tuple): name, alias = name @@ -869,7 +857,7 @@ def visit_FromImport(self, node, frame): alias = name self.writeline('l_%s = getattr(included_template, ' '%r, missing)' % (alias, name)) - self.writeline('if l_%s is missing:' % alias) + self.writeline(f'if l_{alias} is missing:') self.indent() self.writeline('l_%s = environment.undefined(%r %% ' 'included_template.__name__, ' @@ -898,8 +886,9 @@ def visit_FromImport(self, node, frame): self.writeline('context.exported_vars.discard(%r)' % discarded_names[0]) else: - self.writeline('context.exported_vars.difference_' - 'update((%s))' % ', '.join(map(repr, discarded_names))) + self.writeline( + f"context.exported_vars.difference_update(({', '.join(map(repr, discarded_names))}))" + ) def visit_For(self, node, frame): # when calculating the nodes for the inner frame we have to exclude @@ -916,7 +905,7 @@ def visit_For(self, node, frame): # is necessary if the loop is in recursive mode if the special loop # variable is accessed in the body. extended_loop = node.recursive or 'loop' in \ - find_undeclared(node.iter_child_nodes( + find_undeclared(node.iter_child_nodes( only=('body',)), ('loop',)) # if we don't have an recursive loop we have to find the shadowed @@ -944,7 +933,7 @@ def visit_For(self, node, frame): self.pull_locals(loop_frame) if node.else_: iteration_indicator = self.temporary_identifier() - self.writeline('%s = 1' % iteration_indicator) + self.writeline(f'{iteration_indicator} = 1') # Create a fake parent loop if the else or test section of a # loop is accessing the special loop variable and no parent loop @@ -1002,11 +991,11 @@ def visit_For(self, node, frame): self.indent() self.blockvisit(node.body, loop_frame) if node.else_: - self.writeline('%s = 0' % iteration_indicator) + self.writeline(f'{iteration_indicator} = 0') self.outdent() if node.else_: - self.writeline('if %s:' % iteration_indicator) + self.writeline(f'if {iteration_indicator}:') self.indent() self.blockvisit(node.else_, loop_frame) self.outdent() @@ -1047,7 +1036,7 @@ def visit_Macro(self, node, frame): if not node.name.startswith('_'): self.write('context.exported_vars.add(%r)' % node.name) self.writeline('context.vars[%r] = ' % node.name) - self.write('l_%s = ' % node.name) + self.write(f'l_{node.name} = ') self.macro_def(node, macro_frame) def visit_CallBlock(self, node, frame): @@ -1106,10 +1095,7 @@ def visit_Output(self, node, frame): continue try: if self.environment.autoescape: - if hasattr(const, '__html__'): - const = const.__html__() - else: - const = escape(const) + const = const.__html__() if hasattr(const, '__html__') else escape(const) const = finalize(const) except: # if something goes wrong here we evaluate the node @@ -1134,9 +1120,9 @@ def visit_Output(self, node, frame): if isinstance(item, list): val = repr(concat(item)) if frame.buffer is None: - self.writeline('yield ' + val) + self.writeline(f'yield {val}') else: - self.writeline(val + ', ') + self.writeline(f'{val}, ') else: if frame.buffer is None: self.writeline('yield ', item) @@ -1159,7 +1145,6 @@ def visit_Output(self, node, frame): self.outdent() self.writeline(len(body) == 1 and ')' or '))') - # otherwise we create a format string as this is faster in that case else: format = [] arguments = [] @@ -1224,15 +1209,16 @@ def visit_Assign(self, node, frame): self.writeline('context.exported_vars.add(%r)' % public_names[0]) else: - self.writeline('context.exported_vars.update((%s))' % - ', '.join(map(repr, public_names))) + self.writeline( + f"context.exported_vars.update(({', '.join(map(repr, public_names))}))" + ) # -- Expression Visitors def visit_Name(self, node, frame): if node.ctx == 'store' and frame.toplevel: frame.assigned_names.add(node.name) - self.write('l_' + node.name) + self.write(f'l_{node.name}') def visit_Const(self, node, frame): val = node.value @@ -1275,16 +1261,18 @@ def binop(operator): def visitor(self, node, frame): self.write('(') self.visit(node.left, frame) - self.write(' %s ' % operator) + self.write(f' {operator} ') self.visit(node.right, frame) self.write(')') + return visitor def uaop(operator): def visitor(self, node, frame): - self.write('(' + operator) + self.write(f'({operator}') self.visit(node.node, frame) self.write(')') + return visitor visit_Add = binop('+') @@ -1315,7 +1303,7 @@ def visit_Compare(self, node, frame): self.visit(op, frame) def visit_Operand(self, node, frame): - self.write(' %s ' % operators[node.op]) + self.write(f' {operators[node.op]} ') self.visit(node.expr, frame) def visit_Getattr(self, node, frame): @@ -1348,7 +1336,7 @@ def visit_Slice(self, node, frame): self.visit(node.step, frame) def visit_Filter(self, node, frame): - self.write(self.filters[node.name] + '(') + self.write(f'{self.filters[node.name]}(') func = self.environment.filters.get(node.name) if func is None: self.fail('no filter named %r' % node.name, node.lineno) @@ -1362,14 +1350,14 @@ def visit_Filter(self, node, frame): if node.node is not None: self.visit(node.node, frame) elif self.environment.autoescape: - self.write('Markup(concat(%s))' % frame.buffer) + self.write(f'Markup(concat({frame.buffer}))') else: - self.write('concat(%s)' % frame.buffer) + self.write(f'concat({frame.buffer})') self.signature(node, frame) self.write(')') def visit_Test(self, node, frame): - self.write(self.tests[node.name] + '(') + self.write(f'{self.tests[node.name]}(') if node.name not in self.environment.tests: self.fail('no test named %r' % node.name, node.lineno) self.visit(node.node, frame) @@ -1407,12 +1395,12 @@ def visit_Call(self, node, frame, forward_caller=False): else: self.write('context.call(') self.visit(node.node, frame) - extra_kwargs = forward_caller and {'caller': 'caller'} or None + extra_kwargs = {'caller': 'caller'} if forward_caller else None self.signature(node, frame, extra_kwargs) self.write(')') def visit_Keyword(self, node, frame): - self.write(node.key + '=') + self.write(f'{node.key}=') self.visit(node.value, frame) # -- Unused nodes for extensions @@ -1423,7 +1411,7 @@ def visit_MarkSafe(self, node, frame): self.write(')') def visit_EnvironmentAttribute(self, node, frame): - self.write('environment.' + node.name) + self.write(f'environment.{node.name}') def visit_ExtensionAttribute(self, node, frame): self.write('environment.extensions[%r].%s' % (node.identifier, node.name)) diff --git a/jinja2/environment.py b/jinja2/environment.py index 9d43339..3473969 100644 --- a/jinja2/environment.py +++ b/jinja2/environment.py @@ -45,9 +45,7 @@ def create_cache(size): """Return the cache class for the given size.""" if size == 0: return None - if size < 0: - return {} - return LRUCache(size) + return {} if size < 0 else LRUCache(size) def copy_cache(cache): @@ -303,7 +301,7 @@ def overlay(self, block_start_string=missing, block_end_string=missing, for key, value in self.extensions.iteritems(): rv.extensions[key] = value.bind(rv) if extensions is not missing: - rv.extensions.update(load_extensions(extensions)) + rv.extensions |= load_extensions(extensions) return _environment_sanity_check(rv) @@ -511,9 +509,7 @@ def from_string(self, source, globals=None, template_class=None): def make_globals(self, d): """Return a dict for the globals.""" - if not d: - return self.globals - return dict(self.globals, **d) + return dict(self.globals, **d) if d else self.globals class Template(object): @@ -631,8 +627,7 @@ def generate(self, *args, **kwargs): """ vars = dict(*args, **kwargs) try: - for event in self.root_render_func(self.new_context(vars)): - yield event + yield from self.root_render_func(self.new_context(vars)) except: from jinja2.debug import translate_exception exc_type, exc_value, tb = translate_exception(sys.exc_info()) @@ -648,10 +643,7 @@ def new_context(self, vars=None, shared=False, locals=None): """ if vars is None: vars = {} - if shared: - parent = vars - else: - parent = dict(self.globals, **vars) + parent = vars if shared else dict(self.globals, **vars) if locals: # if the parent is shared a copy should be created because # we don't want to modify the dict passed @@ -692,17 +684,19 @@ def get_corresponding_lineno(self, lineno): """Return the source line number of a line number in the generated bytecode as they are not in sync. """ - for template_line, code_line in reversed(self.debug_info): - if code_line <= lineno: - return template_line - return 1 + return next( + ( + template_line + for template_line, code_line in reversed(self.debug_info) + if code_line <= lineno + ), + 1, + ) @property def is_up_to_date(self): """If this variable is `False` there is a newer version available.""" - if self._uptodate is None: - return True - return self._uptodate() + return True if self._uptodate is None else self._uptodate() @property def debug_info(self): @@ -711,11 +705,8 @@ def debug_info(self): self._debug_info.split('&')] def __repr__(self): - if self.name is None: - name = 'memory:%x' % id(self) - else: - name = repr(self.name) - return '<%s %s>' % (self.__class__.__name__, name) + name = 'memory:%x' % id(self) if self.name is None else repr(self.name) + return f'<{self.__class__.__name__} {name}>' class TemplateModule(object): @@ -736,11 +727,8 @@ def __str__(self): return unicode(self).encode('utf-8') def __repr__(self): - if self.__name__ is None: - name = 'memory:%x' % id(self) - else: - name = repr(self.__name__) - return '<%s %s>' % (self.__class__.__name__, name) + name = 'memory:%x' % id(self) if self.__name__ is None else repr(self.__name__) + return f'<{self.__class__.__name__} {name}>' class TemplateExpression(object): diff --git a/jinja2/exceptions.py b/jinja2/exceptions.py index 5bfca66..1e9356d 100644 --- a/jinja2/exceptions.py +++ b/jinja2/exceptions.py @@ -37,10 +37,9 @@ def __init__(self, message, lineno, name=None, filename=None): def __unicode__(self): location = 'line %d' % self.lineno - name = self.filename or self.name - if name: - location = 'File "%s", %s' % (name, location) - lines = [self.message, ' ' + location] + if name := self.filename or self.name: + location = f'File "{name}", {location}' + lines = [self.message, f' {location}'] # if the source is set, add the line to the output if self.source is not None: @@ -49,7 +48,7 @@ def __unicode__(self): except IndexError: line = None if line: - lines.append(' ' + line.strip()) + lines.append(f' {line.strip()}') return u'\n'.join(lines) diff --git a/jinja2/ext.py b/jinja2/ext.py index 353f265..5307732 100644 --- a/jinja2/ext.py +++ b/jinja2/ext.py @@ -30,7 +30,7 @@ class ExtensionRegistry(type): def __new__(cls, name, bases, d): rv = type.__new__(cls, name, bases, d) - rv.identifier = rv.__module__ + '.' + rv.__name__ + rv.identifier = f'{rv.__module__}.{rv.__name__}' return rv @@ -381,8 +381,7 @@ def extract_from_ast(node, gettext_functions=GETTEXT_FUNCTIONS, else: strings.append(None) - for arg in node.kwargs: - strings.append(None) + strings.extend(None for _ in node.kwargs) if node.dyn_args is not None: strings.append(None) if node.dyn_kwargs is not None: @@ -393,10 +392,7 @@ def extract_from_ast(node, gettext_functions=GETTEXT_FUNCTIONS, if not strings: continue else: - if len(strings) == 1: - strings = strings[0] - else: - strings = tuple(strings) + strings = strings[0] if len(strings) == 1 else tuple(strings) yield node.lineno, node.node.name, strings diff --git a/jinja2/filters.py b/jinja2/filters.py index afa7667..ff3c606 100644 --- a/jinja2/filters.py +++ b/jinja2/filters.py @@ -111,12 +111,12 @@ def do_xmlattr(_environment, d, autospace=True): if the filter returned something unless the second parameter is false. """ rv = u' '.join( - u'%s="%s"' % (escape(key), escape(value)) + f'{escape(key)}="{escape(value)}"' for key, value in d.iteritems() if value is not None and not isinstance(value, Undefined) ) if autospace and rv: - rv = u' ' + rv + rv = f' {rv}' if _environment.autoescape: rv = Markup(rv) return rv @@ -240,10 +240,7 @@ def do_join(environment, value, d=u''): do_escape = True else: value[idx] = unicode(item) - if do_escape: - d = escape(d) - else: - d = unicode(d) + d = escape(d) if do_escape else unicode(d) return d.join(value) # no html involved, to normal joining @@ -289,15 +286,15 @@ def do_filesizeformat(value, binary=False): prefixes are (mebi, gibi). """ bytes = float(value) - base = binary and 1024 or 1000 - middle = binary and 'i' or '' + base = 1024 if binary else 1000 + middle = 'i' if binary else '' if bytes < base: return "%d Byte%s" % (bytes, bytes != 1 and 's' or '') - elif bytes < base * base: + elif bytes < base**2: return "%.1f K%sB" % (bytes / base, middle) - elif bytes < base * base * base: - return "%.1f M%sB" % (bytes / (base * base), middle) - return "%.1f G%sB" % (bytes / (base * base * base), middle) + elif bytes < base**2 * base: + return "%.1f M%sB" % (bytes / base**2, middle) + return "%.1f G%sB" % (bytes / (base**2 * base), middle) def do_pprint(value, verbose=False): @@ -535,7 +532,7 @@ def do_round(value, precision=0, method='common'): {{ 42.55|round(1, 'floor') }} -> 42.5 """ - if not method in ('common', 'ceil', 'floor'): + if method not in ('common', 'ceil', 'floor'): raise FilterArgumentError('method must be common, ceil or floor') if precision < 0: raise FilterArgumentError('precision must be a postive integer ' diff --git a/jinja2/lexer.py b/jinja2/lexer.py index 6b26983..613873c 100644 --- a/jinja2/lexer.py +++ b/jinja2/lexer.py @@ -120,10 +120,7 @@ def test(self, expr): def test_any(self, *iterable): """Test against multiple token expressions.""" - for expr in iterable: - if self.test(expr): - return True - return False + return any(self.test(expr) for expr in iterable) def __repr__(self): return 'Token(%r, %r, %r)' % ( @@ -191,7 +188,7 @@ def look(self): def skip(self, n=1): """Got n tokens ahead.""" - for x in xrange(n): + for _ in xrange(n): self.next() def next_if(self, expr): @@ -310,7 +307,7 @@ def __init__(self, environment): root_tag_rules.insert(0, ('linestatement', '^\s*' + prefix)) # block suffix if trimming is enabled - block_suffix_re = environment.trim_blocks and '\\n?' or '' + block_suffix_re = '\\n?' if environment.trim_blocks else '' self.newline_sequence = environment.newline_sequence @@ -431,12 +428,11 @@ def tokeniter(self, source, name, filename=None, state=None): generator. Use this method if you just want to tokenize a template. """ source = '\n'.join(unicode(source).splitlines()) - pos = 0 lineno = 1 stack = ['root'] if state is not None and state != 'root': assert state in ('variable', 'block'), 'invalid state' - stack.append(state + '_begin') + stack.append(f'{state}_begin') else: state = 'root' statetokens = self.rules[stack[-1]] @@ -444,6 +440,7 @@ def tokeniter(self, source, name, filename=None, state=None): balancing_stack = [] + pos = 0 while 1: # tokenizer loop for regex, tokens, new_state in statetokens: @@ -457,7 +454,7 @@ def tokeniter(self, source, name, filename=None, state=None): # is the operator rule. do this only if the end tags look # like operators if balancing_stack and \ - tokens in ('variable_end', 'block_end', + tokens in ('variable_end', 'block_end', 'linestatement_end'): continue @@ -488,7 +485,6 @@ def tokeniter(self, source, name, filename=None, state=None): yield lineno, token, data lineno += data.count('\n') - # strings as token just are yielded as it. else: data = m.group() # update brace/parentheses balance @@ -501,16 +497,15 @@ def tokeniter(self, source, name, filename=None, state=None): balancing_stack.append(']') elif data in ('}', ')', ']'): if not balancing_stack: - raise TemplateSyntaxError('unexpected "%s"' % - data, lineno, name, - filename) + raise TemplateSyntaxError(f'unexpected "{data}"', lineno, name, filename) expected_op = balancing_stack.pop() if expected_op != data: - raise TemplateSyntaxError('unexpected "%s", ' - 'expected "%s"' % - (data, expected_op), - lineno, name, - filename) + raise TemplateSyntaxError( + f'unexpected "{data}", expected "{expected_op}"', + lineno, + name, + filename, + ) # yield items yield lineno, tokens, data lineno += data.count('\n') @@ -549,8 +544,6 @@ def tokeniter(self, source, name, filename=None, state=None): # publish new function and start again pos = pos2 break - # if loop terminated without break we havn't found a single match - # either we are at the end of the file or we have a problem else: # end of text if pos >= source_length: diff --git a/jinja2/nodes.py b/jinja2/nodes.py index 405622a..c6dc1af 100644 --- a/jinja2/nodes.py +++ b/jinja2/nodes.py @@ -57,8 +57,7 @@ class NodeType(type): def __new__(cls, name, bases, d): for attr in 'fields', 'attributes': - storage = [] - storage.extend(getattr(bases[0], attr, ())) + storage = list(getattr(bases[0], attr, ())) storage.extend(d.get(attr, ())) assert len(bases) == 1, 'multiple inheritance not allowed' assert len(storage) == len(set(storage)), 'layout conflict' @@ -150,8 +149,7 @@ def find_all(self, node_type): for child in self.iter_child_nodes(): if isinstance(child, node_type): yield child - for result in child.find_all(node_type): - yield result + yield from child.find_all(node_type) def set_ctx(self, ctx): """Reset the context of a node and all child nodes. Per default the @@ -172,9 +170,8 @@ def set_lineno(self, lineno, override=False): todo = deque([self]) while todo: node = todo.popleft() - if 'lineno' in node.attributes: - if node.lineno is None or override: - node.lineno = lineno + if 'lineno' in node.attributes and (node.lineno is None or override): + node.lineno = lineno todo.extend(node.iter_child_nodes()) return self @@ -407,9 +404,7 @@ class TemplateData(Literal): fields = ('data',) def as_const(self): - if self.environment.autoescape: - return Markup(self.data) - return self.data + return Markup(self.data) if self.environment.autoescape else self.data class Tuple(Literal): @@ -423,10 +418,7 @@ def as_const(self): return tuple(x.as_const() for x in self.items) def can_assign(self): - for item in self.items: - if not item.can_assign(): - return False - return True + return all(item.can_assign() for item in self.items) class List(Literal): @@ -508,7 +500,7 @@ def as_const(self, obj=None): raise Impossible() if self.dyn_kwargs is not None: try: - kwargs.update(self.dyn_kwargs.as_const()) + kwargs |= self.dyn_kwargs.as_const() except: raise Impossible() try: @@ -551,7 +543,7 @@ def as_const(self): raise Impossible() if self.dyn_kwargs is not None: try: - kwargs.update(self.dyn_kwargs.as_const()) + kwargs |= self.dyn_kwargs.as_const() except: raise Impossible() try: @@ -603,9 +595,8 @@ class Slice(Expr): def as_const(self): def const(obj): - if obj is None: - return obj - return obj.as_const() + return obj if obj is None else obj.as_const() + return slice(const(self.start), const(self.stop), const(self.step)) @@ -642,9 +633,14 @@ class Operand(Helper): fields = ('op', 'expr') if __debug__: - Operand.__doc__ += '\nThe following operators are available: ' + \ - ', '.join(sorted('``%s``' % x for x in set(_binop_to_func) | - set(_uaop_to_func) | set(_cmpop_to_func))) + Operand.__doc__ += '\nThe following operators are available: ' + ', '.join( + sorted( + f'``{x}``' + for x in set(_binop_to_func) + | set(_uaop_to_func) + | set(_cmpop_to_func) + ) + ) class Mul(BinExpr): diff --git a/jinja2/optimizer.py b/jinja2/optimizer.py index 43065df..109767b 100644 --- a/jinja2/optimizer.py +++ b/jinja2/optimizer.py @@ -42,10 +42,7 @@ def visit_If(self, node): val = self.visit(node.test).as_const() except nodes.Impossible: return self.generic_visit(node) - if val: - body = node.body - else: - body = node.else_ + body = node.body if val else node.else_ result = [] for node in body: result.extend(self.visit_list(node)) diff --git a/jinja2/parser.py b/jinja2/parser.py index d6f1b36..4a280f9 100644 --- a/jinja2/parser.py +++ b/jinja2/parser.py @@ -66,7 +66,7 @@ def parse_statement(self): if token.type is not 'name': self.fail('tag name expected', token.lineno) if token.value in _statement_keywords: - return getattr(self, 'parse_' + self.stream.current.value)() + return getattr(self, f'parse_{self.stream.current.value}')() if token.value == 'call': return self.parse_call_block() if token.value == 'filter': @@ -113,9 +113,7 @@ def parse_for(self): self.stream.expect('name:in') iter = self.parse_tuple(with_condexpr=False, extra_end_rules=('name:recursive',)) - test = None - if self.stream.skip_if('name:if'): - test = self.parse_expression() + test = self.parse_expression() if self.stream.skip_if('name:if') else None recursive = self.stream.skip_if('name:recursive') body = self.parse_statements(('name:endfor', 'name:else')) if self.stream.next().value == 'endfor': @@ -150,7 +148,7 @@ def parse_block(self): node = nodes.Block(lineno=self.stream.next().lineno) node.name = self.stream.expect('name').value node.body = self.parse_statements(('name:endblock',), drop_needle=True) - self.stream.skip_if('name:' + node.name) + self.stream.skip_if(f'name:{node.name}') return node def parse_extends(self): @@ -299,19 +297,14 @@ def parse_expression(self, with_condexpr=True): the optional `with_condexpr` parameter is set to `False` conditional expressions are not parsed. """ - if with_condexpr: - return self.parse_condexpr() - return self.parse_or() + return self.parse_condexpr() if with_condexpr else self.parse_or() def parse_condexpr(self): lineno = self.stream.current.lineno expr1 = self.parse_or() while self.stream.skip_if('name:if'): expr2 = self.parse_or() - if self.stream.skip_if('name:else'): - expr3 = self.parse_condexpr() - else: - expr3 = None + expr3 = self.parse_condexpr() if self.stream.skip_if('name:else') else None expr1 = nodes.CondExpr(expr2, expr1, expr3, lineno=lineno) lineno = self.stream.current.lineno return expr1 @@ -346,15 +339,13 @@ def parse_compare(self): elif self.stream.skip_if('name:in'): ops.append(nodes.Operand('in', self.parse_add())) elif self.stream.current.test('name:not') and \ - self.stream.look().test('name:in'): + self.stream.look().test('name:in'): self.stream.skip(2) ops.append(nodes.Operand('notin', self.parse_add())) else: break lineno = self.stream.current.lineno - if not ops: - return expr - return nodes.Compare(expr, ops, lineno=lineno) + return nodes.Compare(expr, ops, lineno=lineno) if ops else expr def parse_add(self): lineno = self.stream.current.lineno @@ -382,9 +373,7 @@ def parse_concat(self): while self.stream.current.type is 'tilde': self.stream.next() args.append(self.parse_mul()) - if len(args) == 1: - return args[0] - return nodes.Concat(args, lineno=lineno) + return args[0] if len(args) == 1 else nodes.Concat(args, lineno=lineno) def parse_mul(self): lineno = self.stream.current.lineno @@ -484,7 +473,7 @@ def parse_primary(self, with_postfix=True): elif token.type is 'lbrace': node = self.parse_dict() else: - self.fail("unexpected token '%s'" % (token,), token.lineno) + self.fail(f"unexpected token '{token}'", token.lineno) if with_postfix: node = self.parse_postfix(node) return node diff --git a/jinja2/runtime.py b/jinja2/runtime.py index 2ed3ac6..b90c7e3 100644 --- a/jinja2/runtime.py +++ b/jinja2/runtime.py @@ -73,7 +73,7 @@ def __init__(self, environment, parent, name, blocks): # create the initial mapping of blocks. Whenever template inheritance # takes place the runtime will update this mapping with the new blocks # from the template. - self.blocks = dict((k, [v]) for k, v in blocks.iteritems()) + self.blocks = {k: [v] for k, v in blocks.iteritems()} def super(self, name, current): """Render a parent block.""" @@ -108,7 +108,7 @@ def resolve(self, key): def get_exported(self): """Get a new dict with the exported variables.""" - return dict((k, self.vars[k]) for k in self.exported_vars) + return {k: self.vars[k] for k in self.exported_vars} def get_all(self): """Return a copy of the complete context as dict including the @@ -352,10 +352,7 @@ def __call__(self, *args, **kwargs): return self._func(*arguments) def __repr__(self): - return '<%s %s>' % ( - self.__class__.__name__, - self.name is None and 'anonymous' or repr(self.name) - ) + return f"<{self.__class__.__name__} {self.name is None and 'anonymous' or repr(self.name)}>" class Undefined(object): @@ -419,8 +416,7 @@ def __len__(self): return 0 def __iter__(self): - if 0: - yield None + pass def __nonzero__(self): return False diff --git a/jinja2/sandbox.py b/jinja2/sandbox.py index 7b28273..b3e2a17 100644 --- a/jinja2/sandbox.py +++ b/jinja2/sandbox.py @@ -12,6 +12,7 @@ :copyright: Copyright 2008 by Armin Ronacher. :license: BSD. """ + import operator from jinja2.runtime import Undefined from jinja2.environment import Environment @@ -24,11 +25,16 @@ MAX_RANGE = 100000 #: attributes of function objects that are considered unsafe. -UNSAFE_FUNCTION_ATTRIBUTES = set(['func_closure', 'func_code', 'func_dict', - 'func_defaults', 'func_globals']) +UNSAFE_FUNCTION_ATTRIBUTES = { + 'func_closure', + 'func_code', + 'func_dict', + 'func_defaults', + 'func_globals', +} #: unsafe method attributes. function attributes are unsafe for methods too -UNSAFE_METHOD_ATTRIBUTES = set(['im_class', 'im_func', 'im_self']) +UNSAFE_METHOD_ATTRIBUTES = {'im_class', 'im_func', 'im_self'} from collections import deque @@ -143,10 +149,14 @@ def modifies_known_mutable(obj, attr): >>> modifies_known_mutable("foo", "upper") False """ - for typespec, unsafe in _mutable_spec: - if isinstance(obj, typespec): - return attr in unsafe - return False + return next( + ( + attr in unsafe + for typespec, unsafe in _mutable_spec + if isinstance(obj, typespec) + ), + False, + ) class SandboxedEnvironment(Environment): @@ -245,6 +255,8 @@ class ImmutableSandboxedEnvironment(SandboxedEnvironment): """ def is_safe_attribute(self, obj, attr, value): - if not SandboxedEnvironment.is_safe_attribute(self, obj, attr, value): - return False - return not modifies_known_mutable(obj, attr) + return ( + not modifies_known_mutable(obj, attr) + if SandboxedEnvironment.is_safe_attribute(self, obj, attr, value) + else False + ) diff --git a/jinja2/utils.py b/jinja2/utils.py index 45be2c2..634fb69 100644 --- a/jinja2/utils.py +++ b/jinja2/utils.py @@ -43,7 +43,6 @@ try: def _test_gen_bug(): raise TypeError(_test_gen_bug) - yield None _concat(_test_gen_bug()) except TypeError, _error: if not _error.args or _error.args[0] is not _test_gen_bug: @@ -138,8 +137,6 @@ def default(var, default=''): def consume(iterable): """Consumes an iterable without doing anything with it.""" - for event in iterable: - pass def clear_caches(): @@ -219,10 +216,9 @@ def urlize(text, trim_url_limit=None, nofollow=False): and (x[:limit] + (len(x) >=limit and '...' or '')) or x words = _word_split_re.split(unicode(escape(text))) - nofollow_attr = nofollow and ' rel="nofollow"' or '' + nofollow_attr = ' rel="nofollow"' if nofollow else '' for i, word in enumerate(words): - match = _punctuation_re.match(word) - if match: + if match := _punctuation_re.match(word): lead, middle, trail = match.groups() if middle.startswith('www.') or ( '@' not in middle and @@ -233,15 +229,17 @@ def urlize(text, trim_url_limit=None, nofollow=False): middle.endswith('.net') or middle.endswith('.com') )): - middle = '%s' % (middle, - nofollow_attr, trim_url(middle)) + middle = f'{trim_url(middle)}' if middle.startswith('http://') or \ middle.startswith('https://'): - middle = '%s' % (middle, - nofollow_attr, trim_url(middle)) - if '@' in middle and not middle.startswith('www.') and \ - not ':' in middle and _simple_email_re.match(middle): - middle = '%s' % (middle, middle) + middle = f'{trim_url(middle)}' + if ( + '@' in middle + and not middle.startswith('www.') + and ':' not in middle + and _simple_email_re.match(middle) + ): + middle = f'{middle}' if lead + middle + trail != word: words[i] = lead + middle + trail return u''.join(words) @@ -286,14 +284,16 @@ def generate_lorem_ipsum(n=5, html=True, min=20, max=100): # ensure that the paragraph ends with a dot. p = u' '.join(p) if p.endswith(','): - p = p[:-1] + '.' + p = f'{p[:-1]}.' elif not p.endswith('.'): p += '.' result.append(p) - if not html: - return u'\n\n'.join(result) - return Markup(u'\n'.join(u'

%s

' % escape(x) for x in result)) + return ( + Markup(u'\n'.join(f'

{escape(x)}

' for x in result)) + if html + else u'\n\n'.join(result) + ) class Markup(unicode): @@ -376,10 +376,7 @@ def __mod__(self, arg): return self.__class__(unicode.__mod__(self, arg)) def __repr__(self): - return '%s(%s)' % ( - self.__class__.__name__, - unicode.__repr__(self) - ) + return f'{self.__class__.__name__}({unicode.__repr__(self)})' def join(self, seq): return self.__class__(unicode.join(self, imap(escape, seq))) @@ -437,9 +434,7 @@ def escape(cls, s): correct subclass. """ rv = escape(s) - if rv.__class__ is not cls: - return cls(rv) - return rv + return cls(rv) if rv.__class__ is not cls else rv def make_wrapper(name): orig = getattr(unicode, name) @@ -686,9 +681,8 @@ def current(self): def next(self): """Goes one item ahead and returns it.""" - rv = self.current self.pos = (self.pos + 1) % len(self.items) - return rv + return self.current class Joiner(object): @@ -744,5 +738,5 @@ def __init__(self, _func, *args, **kwargs): self._args = args self._kwargs = kwargs def __call__(self, *args, **kwargs): - kwargs.update(self._kwargs) + kwargs |= self._kwargs return self._func(*(self._args + args), **kwargs) diff --git a/jinja2/visitor.py b/jinja2/visitor.py index ad11108..9ec3279 100644 --- a/jinja2/visitor.py +++ b/jinja2/visitor.py @@ -28,7 +28,7 @@ def get_visitor(self, node): exists for this node. In that case the generic visit function is used instead. """ - method = 'visit_' + node.__class__.__name__ + method = f'visit_{node.__class__.__name__}' return getattr(self, method, None) def visit(self, node, *args, **kwargs): diff --git a/readability/__init__.py b/readability/__init__.py index 258e708..296ce65 100644 --- a/readability/__init__.py +++ b/readability/__init__.py @@ -32,7 +32,7 @@ class Feed(db.Model): title = db.StringProperty(default='') def feed_url(self): - return '/readability/feed?url=%s' % urlquote(self.url) + return f'/readability/feed?url={urlquote(self.url)}' class ReadabilityHandler(RenderHandler): diff --git a/readability/hn.py b/readability/hn.py index a9634cc..85e47a3 100644 --- a/readability/hn.py +++ b/readability/hn.py @@ -49,56 +49,56 @@ def grabContent(link, html): replaceBrs = re.compile("
[ \r\n]*
") html = re.sub(replaceBrs, "

", html) - + try: soup = BeautifulSoup(html) except HTMLParser.HTMLParseError: return u"" - + # REMOVE SCRIPTS for s in soup.findAll("script"): s.extract() - + allParagraphs = soup.findAll("p") topParent = None - + parents = [] for paragraph in allParagraphs: parent = paragraph.parent - + if (parent not in parents): parents.append(parent) parent.score = 0 - + if (parent.has_key("class")): if (NEGATIVE.match(parent["class"])): parent.score -= 50 if (POSITIVE.match(parent["class"])): parent.score += 25 - + if (parent.has_key("id")): if (NEGATIVE.match(parent["id"])): parent.score -= 50 if (POSITIVE.match(parent["id"])): parent.score += 25 - if (parent.score == None): + if parent.score is None: parent.score = 0 - + innerText = paragraph.renderContents() #"".join(paragraph.findAll(text=True)) if (len(innerText) > 10): parent.score += 1 - + parent.score += innerText.count(",") - + for parent in parents: if ((not topParent) or (parent.score > topParent.score)): topParent = parent if (not topParent): return u"" - + # REMOVE LINK'D STYLES styleLinks = soup.findAll("link", attrs={"type" : "text/css"}) for s in styleLinks: @@ -112,14 +112,14 @@ def grabContent(link, html): for ele in topParent.findAll(True): del(ele['style']) del(ele['class']) - + killDivs(topParent) clean(topParent, "form") clean(topParent, "object") clean(topParent, "iframe") - + fixLinks(topParent, link) - + return topParent.renderContents().decode('utf-8') @@ -152,60 +152,52 @@ def killDivs(parent): embed = len(d.findAll("embed")) pre = len(d.findAll("pre")) code = len(d.findAll("code")) - - if (d.renderContents().count(",") < 10): - if ((pre == 0) and (code == 0)): - if ((img > p ) or (li > p) or (a > p) or (p == 0) or (embed > 0)): - d.extract() + + if ( + (d.renderContents().count(",") < 10) + and ((pre == 0) and (code == 0)) + and ((img > p) or (li > p) or (a > p) or (p == 0) or (embed > 0)) + ): + d.extract() # gives me the content i want def upgradeLink(link, user_agent, graball=False): link = link.encode('utf-8') - - # TODO: handle other exceptions - - # XXX: also, better way to check file types would be content-type headers - # and don't mess with anything that isn't a webpage.. - if (not (link.startswith("http://news.ycombinator.com") or link.endswith(".pdf"))): - linkFile = "upgraded/" + re.sub(PUNCTUATION, "_", link) - if linkFile in CACHE: - return CACHE[linkFile] - else: - content = u"" - try: - html = urlgrabber.urlread(link, keepalive=0, user_agent=user_agent) - content = grabContent(link, html) - CACHE[linkFile] = content - except IOError: - pass - return content - else: + + if link.startswith("http://news.ycombinator.com") or link.endswith(".pdf"): return u"" + linkFile = "upgraded/" + re.sub(PUNCTUATION, "_", link) + if linkFile in CACHE: + return CACHE[linkFile] + content = u"" + try: + html = urlgrabber.urlread(link, keepalive=0, user_agent=user_agent) + content = grabContent(link, html) + CACHE[linkFile] = content + except IOError: + pass + return content def get_headers(feedUrl): - if 'headers-'+feedUrl not in CACHE: + if f'headers-{feedUrl}' not in CACHE: return None, None, None - headers = loads(CACHE['headers-'+feedUrl]) - + headers = loads(CACHE[f'headers-{feedUrl}']) + # headers are lowercased by feedparser last_modified = headers.get('last-modified', '') etag = headers.get('etag', '') expires = headers.get('expires', '') - - fp_last_modified = None - if last_modified: - fp_last_modified = rfc822.parsedate(last_modified) - fp_expires = None - if expires: - fp_expires = rfc822.parsedate(expires) + + fp_last_modified = rfc822.parsedate(last_modified) if last_modified else None + fp_expires = rfc822.parsedate(expires) if expires else None # fp if for 9 tuple feed parser required format return etag, fp_last_modified, fp_expires def save_headers(parsedFeed, feedUrl): - CACHE['headers-'+feedUrl] = dumps(parsedFeed.headers) + CACHE[f'headers-{feedUrl}'] = dumps(parsedFeed.headers) class NotFeedException(Exception): diff --git a/urlgrabber/byterange.py b/urlgrabber/byterange.py index 001b4e3..def791d 100644 --- a/urlgrabber/byterange.py +++ b/urlgrabber/byterange.py @@ -245,7 +245,7 @@ def open_local_file(self, req): headers = mimetools.Message(StringIO( 'Content-Type: %s\nContent-Length: %d\nLast-modified: %s\n' % (mtype or 'text/plain', size, modified))) - return urllib.addinfourl(fo, headers, 'file:'+file) + return urllib.addinfourl(fo, headers, f'file:{file}') # FTP Range Support @@ -342,8 +342,7 @@ def ftp_open(self, req): raise IOError, ('ftp error', msg), sys.exc_info()[2] def connect_ftp(self, user, passwd, host, port, dirs): - fw = ftpwrapper(user, passwd, host, port, dirs) - return fw + return ftpwrapper(user, passwd, host, port, dirs) class ftpwrapper(urllib.ftpwrapper): # range support note: @@ -417,8 +416,7 @@ def range_header_to_tuple(range_header): if _rangere is None: import re _rangere = re.compile(r'^bytes=(\d{1,})-(\d*)') - match = _rangere.match(range_header) - if match: + if match := _rangere.match(range_header): tup = range_tuple_normalize(match.group(1,2)) if tup and tup[1]: tup = (tup[0],tup[1]+1) @@ -431,8 +429,7 @@ def range_tuple_to_header(range_tup): if no range is needed. """ if range_tup is None: return None - range_tup = range_tuple_normalize(range_tup) - if range_tup: + if range_tup := range_tuple_normalize(range_tup): if range_tup[1]: range_tup = (range_tup[0],range_tup[1] - 1) return 'bytes=%s-%s' % range_tup @@ -447,8 +444,7 @@ def range_tuple_normalize(range_tup): if range_tup is None: return None # handle first byte fb = range_tup[0] - if fb in (None,''): fb = 0 - else: fb = int(fb) + fb = 0 if fb in (None,'') else int(fb) # handle last byte try: lb = range_tup[1] except IndexError: lb = '' @@ -458,6 +454,7 @@ def range_tuple_normalize(range_tup): # check if range is over the entire file if (fb,lb) == (0,''): return None # check that the range is valid - if lb < fb: raise RangeError('Invalid byte range: %s-%s' % (fb,lb)) + if lb < fb: + raise RangeError(f'Invalid byte range: {fb}-{lb}') return (fb,lb) diff --git a/urlgrabber/grabber.py b/urlgrabber/grabber.py index 4d72a9d..d935916 100644 --- a/urlgrabber/grabber.py +++ b/urlgrabber/grabber.py @@ -499,8 +499,7 @@ def _init_default_logger(): if level < 1: raise ValueError() formatter = logging.Formatter('%(asctime)s %(message)s') - if len(dbinfo) > 1: filename = dbinfo[1] - else: filename = '' + filename = dbinfo[1] if len(dbinfo) > 1 else '' if filename == '': handler = logging.StreamHandler(sys.stderr) elif filename == '-': handler = logging.StreamHandler(sys.stdout) else: handler = logging.FileHandler(filename) @@ -650,36 +649,33 @@ def parse(self, url, opts): opts.quote = None --> guess """ quote = opts.quote - + if opts.prefix: url = self.add_prefix(url, opts.prefix) - + parts = urlparse.urlparse(url) (scheme, host, path, parm, query, frag) = parts if not scheme or (len(scheme) == 1 and scheme in string.letters): # if a scheme isn't specified, we guess that it's "file:" if url[0] not in '/\\': url = os.path.abspath(url) - url = 'file:' + urllib.pathname2url(url) + url = f'file:{urllib.pathname2url(url)}' parts = urlparse.urlparse(url) quote = 0 # pathname2url quotes, so we won't do it again - + if scheme in ['http', 'https']: parts = self.process_http(parts) - + if quote is None: quote = self.guess_should_quote(parts) if quote: parts = self.quote(parts) - + url = urlparse.urlunparse(parts) return url, parts def add_prefix(self, url, prefix): - if prefix[-1] == '/' or url[0] == '/': - url = prefix + url - else: - url = prefix + '/' + url + url = prefix + url if prefix[-1] == '/' or url[0] == '/' else f'{prefix}/{url}' return url def process_http(self, parts): @@ -778,9 +774,9 @@ def _set_attributes(self, **kwargs): if have_range and kwargs.has_key('range'): # normalize the supplied range value self.range = range_tuple_normalize(self.range) - if not self.reget in [None, 'simple', 'check_timestamp']: + if self.reget not in [None, 'simple', 'check_timestamp']: raise URLGrabError(11, _('Illegal reget mode: %s') \ - % (self.reget, )) + % (self.reget, )) def _set_defaults(self): """Set all options to their default values. @@ -796,7 +792,7 @@ def _set_defaults(self): self.copy_local = 0 self.close_connection = 0 self.range = None - self.user_agent = 'urlgrabber/%s' % __version__ + self.user_agent = f'urlgrabber/{__version__}' self.keepalive = 1 self.proxies = None self.reget = None @@ -894,7 +890,7 @@ def urlgrab(self, url, filename=None, **kwargs): different from the passed-in filename if copy_local == 0. """ opts = self.opts.derive(**kwargs) - (url,parts) = opts.urlparser.parse(url, opts) + (url,parts) = opts.urlparser.parse(url, opts) (scheme, host, path, parm, query, frag) = parts if filename is None: filename = os.path.basename( urllib.unquote(path) ) @@ -903,7 +899,7 @@ def urlgrab(self, url, filename=None, **kwargs): # copy currently path = urllib.url2pathname(path) if host: - path = os.path.normpath('//' + host + path) + path = os.path.normpath(f'//{host}{path}') if not os.path.exists(path): raise URLGrabError(2, _('Local file does not exist: %s') % (path, )) @@ -912,14 +908,14 @@ def urlgrab(self, url, filename=None, **kwargs): _('Not a normal file: %s') % (path, )) elif not opts.range: return path - + def retryfunc(opts, url, filename): fo = URLGrabberFileObject(url, filename, opts) try: fo._do_grab() - if not opts.checkfunc is None: + if opts.checkfunc is not None: cb_func, cb_args, cb_kwargs = \ - self._make_callback(opts.checkfunc) + self._make_callback(opts.checkfunc) obj = CallbackObject() obj.filename = filename obj.url = url @@ -927,7 +923,7 @@ def retryfunc(opts, url, filename): finally: fo.close() return filename - + return self._retry(opts, retryfunc, url, filename) def urlread(self, url, limit=None, **kwargs): @@ -938,10 +934,10 @@ def urlread(self, url, limit=None, **kwargs): into memory, but don't use too much' """ opts = self.opts.derive(**kwargs) - (url,parts) = opts.urlparser.parse(url, opts) + (url,parts) = opts.urlparser.parse(url, opts) if limit is not None: limit = limit + 1 - + def retryfunc(opts, url, limit): fo = URLGrabberFileObject(url, filename=None, opts=opts) s = '' @@ -950,12 +946,10 @@ def retryfunc(opts, url, limit): # have a default "limit" of None, while the built-in (real) # file objects have -1. They each break the other, so for # now, we just force the default if necessary. - if limit is None: s = fo.read() - else: s = fo.read(limit) - - if not opts.checkfunc is None: + s = fo.read() if limit is None else fo.read(limit) + if opts.checkfunc is not None: cb_func, cb_args, cb_kwargs = \ - self._make_callback(opts.checkfunc) + self._make_callback(opts.checkfunc) obj = CallbackObject() obj.data = s obj.url = url @@ -963,7 +957,7 @@ def retryfunc(opts, url, limit): finally: fo.close() return s - + s = self._retry(opts, retryfunc, url, limit) if limit and len(s) > limit: raise URLGrabError(8, @@ -971,10 +965,7 @@ def retryfunc(opts, url, limit): return s def _make_callback(self, callback_obj): - if callable(callback_obj): - return callback_obj, (), {} - else: - return callback_obj + return (callback_obj, (), {}) if callable(callback_obj) else callback_obj # create the default URLGrabber used by urlXXX functions. # NOTE: actual defaults are set in URLGrabberOptions @@ -1020,7 +1011,7 @@ def _get_opener(self): handlers = [] need_keepalive_handler = (have_keepalive and self.opts.keepalive) need_range_handler = (range_handlers and \ - (self.opts.range or self.opts.reget)) + (self.opts.range or self.opts.reget)) # if you specify a ProxyHandler when creating the opener # it _must_ come before all other handlers in the list or urllib2 # chokes. @@ -1052,8 +1043,7 @@ def _get_opener(self): self.opts.ssl_context) if need_keepalive_handler: - handlers.append(HTTPHandler()) - handlers.append(HTTPSHandler(ssl_factory)) + handlers.extend((HTTPHandler(), HTTPSHandler(ssl_factory))) if need_range_handler: handlers.extend( range_handlers ) handlers.append( auth_handler ) @@ -1082,9 +1072,9 @@ def _do_open(self): modified_tuple = hdr.getdate_tz('last-modified') modified_stamp = rfc822.mktime_tz(modified_tuple) if modified_stamp > self.reget_time: fetch_again = 1 - except (TypeError,): + except TypeError: fetch_again = 1 - + if fetch_again: # the server version is newer than the (incomplete) local # version, so we should abandon the version we're getting @@ -1098,7 +1088,7 @@ def _do_open(self): (scheme, host, path, parm, query, frag) = urlparse.urlparse(self.url) path = urllib.unquote(path) if not (self.opts.progress_obj or self.opts.raw_throttle() \ - or self.opts.timeout): + or self.opts.timeout): # if we're not using the progress_obj, throttling, or timeout # we can get a performance boost by going directly to # the underlying fileobject for reads. @@ -1108,7 +1098,7 @@ def _do_open(self): elif self.opts.progress_obj: try: length = int(hdr['Content-Length']) - length = length + self._amount_read # Account for regets + length += self._amount_read except (KeyError, ValueError, TypeError): length = None @@ -1151,17 +1141,17 @@ def _build_range(self, req): rt = reget_length, '' self.append = 1 - + if self.opts.range: if not have_range: raise URLGrabError(10, _('Byte range requested but range '\ - 'support unavailable')) + 'support unavailable')) rt = self.opts.range if rt[0]: rt = (rt[0] + reget_length, rt[1]) if rt: - header = range_tuple_to_header(rt) - if header: req.add_header('Range', header) + if header := range_tuple_to_header(rt): + req.add_header('Range', header) def _make_request(self, req, opener): try: @@ -1286,11 +1276,10 @@ def readline(self, limit=-1): while i < 0 and not (0 < limit <= len(self._rbuf)): L = len(self._rbuf) self._fill_buffer(L + self._rbufsize) - if not len(self._rbuf) > L: break + if len(self._rbuf) <= L: break i = string.find(self._rbuf, '\n', L) - if i < 0: i = len(self._rbuf) - else: i = i+1 + i = len(self._rbuf) if i < 0 else i+1 if 0 <= limit < len(self._rbuf): i = limit s, self._rbuf = self._rbuf[:i], self._rbuf[i:] diff --git a/urlgrabber/keepalive.py b/urlgrabber/keepalive.py index 71393e2..8c089de 100644 --- a/urlgrabber/keepalive.py +++ b/urlgrabber/keepalive.py @@ -99,6 +99,7 @@ """ + # $Id: keepalive.py,v 1.16 2006/09/22 00:58:05 mstenner Exp $ import urllib2 @@ -111,8 +112,7 @@ import sslfactory import sys -if sys.version_info < (2, 4): HANDLE_ERRORS = 1 -else: HANDLE_ERRORS = 0 +HANDLE_ERRORS = 1 if sys.version_info < (2, 4) else 0 class ConnectionManager: """ @@ -138,15 +138,14 @@ def add(self, host, connection, ready): def remove(self, connection): self._lock.acquire() try: - try: - host = self._connmap[connection] - except KeyError: - pass - else: - del self._connmap[connection] - del self._readymap[connection] - self._hostmap[host].remove(connection) - if not self._hostmap[host]: del self._hostmap[host] + host = self._connmap[connection] + except KeyError: + pass + else: + del self._connmap[connection] + del self._readymap[connection] + self._hostmap[host].remove(connection) + if not self._hostmap[host]: del self._hostmap[host] finally: self._lock.release() @@ -169,10 +168,7 @@ def get_ready_conn(self, host): return conn def get_all(self, host=None): - if host: - return list(self._hostmap.get(host, [])) - else: - return dict(self._hostmap) + return list(self._hostmap.get(host, [])) if host else dict(self._hostmap) class KeepAliveHandler: def __init__(self): @@ -282,11 +278,11 @@ def _reuse_connection(self, h, req, host): # that it's now possible this call will raise # a DIFFERENT exception if DEBUG: DEBUG.error("unexpected exception - closing " + \ - "connection to %s (%d)", host, id(h)) + "connection to %s (%d)", host, id(h)) self._cm.remove(h) h.close() raise - + if r is None or r.version == 9: # httplib falls back to assuming HTTP 0.9 if it gets a # bad header back. This is most likely to happen if @@ -295,9 +291,7 @@ def _reuse_connection(self, h, req, host): if DEBUG: DEBUG.info("failed to re-use connection to %s (%d)", host, id(h)) r = None - else: - if DEBUG: DEBUG.info("re-using connection to %s (%d)", host, id(h)) - + elif DEBUG: DEBUG.info("re-using connection to %s (%d)", host, id(h)) return r def _start_transaction(self, h, req): @@ -406,7 +400,7 @@ def geturl(self): def read(self, amt=None): # the _rbuf test is only in this first if for speed. It's not # logically necessary - if self._rbuf and not amt is None: + if self._rbuf and amt is not None: L = len(self._rbuf) if amt > L: amt -= L @@ -428,8 +422,7 @@ def readline(self, limit=-1): i = new.find('\n') if i >= 0: i = i + len(self._rbuf) self._rbuf = self._rbuf + new - if i < 0: i = len(self._rbuf) - else: i = i+1 + i = len(self._rbuf) if i < 0 else i+1 if 0 <= limit < len(self._rbuf): i = limit data, self._rbuf = self._rbuf[:i], self._rbuf[i:] return data diff --git a/urlgrabber/mirror.py b/urlgrabber/mirror.py index 9664c6b..6f3b7b0 100644 --- a/urlgrabber/mirror.py +++ b/urlgrabber/mirror.py @@ -287,12 +287,7 @@ def _get_mirror(self, gr): return gr.mirrors[gr._next] def _failure(self, gr, cb_obj): - # OVERRIDE IDEAS: - # inspect the error - remove=1 for 404, remove=2 for connection - # refused, etc. (this can also be done via - # the callback) - cb = gr.kw.get('failure_callback') or self.failure_callback - if cb: + if cb := gr.kw.get('failure_callback') or self.failure_callback: if type(cb) == type( () ): cb, args, kwargs = cb else: @@ -306,7 +301,7 @@ def _failure(self, gr, cb_obj): #action = action or gr.kw.get('default_action') or self.default_action # the other is to fall through for each element in the action dict a = dict(self.default_action or {}) - a.update(gr.kw.get('default_action', {})) + a |= gr.kw.get('default_action', {}) a.update(action) action = a self.increment_mirror(gr, action) @@ -372,7 +367,7 @@ def _join_url(self, base_url, rel_url): if base_url.endswith('/') or rel_url.startswith('/'): return base_url + rel_url else: - return base_url + '/' + rel_url + return f'{base_url}/{rel_url}' def _mirror_try(self, func, url, kw): gr = GrabRequest() @@ -454,5 +449,3 @@ def __init__(self, grabber, mirrors, **kwargs): MirrorGroup.__init__(self, grabber, mirrors, **kwargs) random.shuffle(self.mirrors) -if __name__ == '__main__': - pass diff --git a/urlgrabber/progress.py b/urlgrabber/progress.py index 02db524..160bab7 100644 --- a/urlgrabber/progress.py +++ b/urlgrabber/progress.py @@ -47,7 +47,8 @@ def start(self, filename=None, url=None, basename=None, #size = None ######### TESTING self.size = size - if not size is None: self.fsize = format_number(size) + 'B' + if size is not None: + self.fsize = f'{format_number(size)}B' if now is None: now = time.time() self.start_time = now @@ -93,13 +94,10 @@ def _do_update(self, amount_read, now=None): fetime = format_time(etime) fread = format_number(amount_read) #self.size = None - if self.text is not None: - text = self.text - else: - text = self.basename + text = self.text if self.text is not None else self.basename if self.size is None: out = '\r%-60.60s %5sB %s ' % \ - (text, fread, fetime) + (text, fread, fetime) else: rtime = self.re.remaining_time() frtime = format_time(rtime) @@ -107,7 +105,7 @@ def _do_update(self, amount_read, now=None): bar = '='*int(25 * frac) out = '\r%-25.25s %3i%% |%-25.25s| %5sB %8s ETA ' % \ - (text, frac*100, bar, fread, frtime) + (text, frac*100, bar, fread, frtime) self.fo.write(out) self.fo.flush() @@ -115,17 +113,14 @@ def _do_update(self, amount_read, now=None): def _do_end(self, amount_read, now=None): total_time = format_time(self.re.elapsed_time()) total_size = format_number(amount_read) - if self.text is not None: - text = self.text - else: - text = self.basename + text = self.text if self.text is not None else self.basename if self.size is None: out = '\r%-60.60s %5sB %s ' % \ - (text, total_size, total_time) + (text, total_size, total_time) else: bar = '='*25 out = '\r%-25.25s %3i%% |%-25.25s| %5sB %8s ' % \ - (text, 100, bar, total_size, total_time) + (text, 100, bar, total_size, total_time) self.fo.write(out + '\n') self.fo.flush() @@ -213,11 +208,11 @@ def removeMeter(self, meter): ########################################################### # child functions - these should only be called by helpers def start_meter(self, meter, now): - if not meter in self.meters: + if meter not in self.meters: raise ValueError('attempt to use orphaned meter') self._lock.acquire() try: - if not meter in self.in_progress_meters: + if meter not in self.in_progress_meters: self.in_progress_meters.append(meter) self.open_files += 1 finally: @@ -228,10 +223,10 @@ def _do_start_meter(self, meter, now): pass def update_meter(self, meter, now): - if not meter in self.meters: + if meter not in self.meters: raise ValueError('attempt to use orphaned meter') if (now >= self.last_update_time + self.update_period) or \ - not self.last_update_time: + not self.last_update_time: self.re.update(self._amount_read(), now) self.last_update_time = now self._do_update_meter(meter, now) @@ -240,7 +235,7 @@ def _do_update_meter(self, meter, now): pass def end_meter(self, meter, now): - if not meter in self.meters: + if meter not in self.meters: raise ValueError('attempt to use orphaned meter') self._lock.acquire() try: @@ -257,7 +252,7 @@ def _do_end_meter(self, meter, now): pass def failure_meter(self, meter, message, now): - if not meter in self.meters: + if meter not in self.meters: raise ValueError('attempt to use orphaned meter') self._lock.acquire() try: @@ -298,7 +293,7 @@ def _do_update_meter(self, meter, now): self._lock.acquire() try: format = "files: %3i/%-3i %3i%% data: %6.6s/%-6.6s %3i%% " \ - "time: %8.8s/%8.8s" + "time: %8.8s/%8.8s" df = self.finished_files tf = self.numfiles or 1 pf = 100 * float(df)/tf + 0.49 @@ -307,14 +302,12 @@ def _do_update_meter(self, meter, now): pd = 100 * (self.re.fraction_read() or 0) + 0.49 dt = self.re.elapsed_time() rt = self.re.remaining_time() - if rt is None: tt = None - else: tt = dt + rt - - fdd = format_number(dd) + 'B' - ftd = format_number(td) + 'B' + tt = None if rt is None else dt + rt + fdd = f'{format_number(dd)}B' + ftd = f'{format_number(td)}B' fdt = format_time(dt, 1) ftt = format_time(tt, 1) - + out = '%-79.79s' % (format % (df, tf, pf, fdd, ftd, pd, fdt, ftt)) self.fo.write('\r' + out) self.fo.flush() @@ -327,11 +320,11 @@ def _do_end_meter(self, meter, now): format = "%-30.30s %6.6s %8.8s %9.9s" fn = meter.basename size = meter.last_amount_read - fsize = format_number(size) + 'B' + fsize = f'{format_number(size)}B' et = meter.re.elapsed_time() fet = format_time(et, 1) - frate = format_number(size / et) + 'B/s' - + frate = f'{format_number(size / et)}B/s' + out = '%-79.79s' % (format % (fn, fsize, fet, frate)) self.fo.write('\r' + out + '\n') finally: @@ -348,7 +341,8 @@ def _do_failure_meter(self, meter, message, now): if not message: message = [''] out = '%-79s' % (format % (fn, 'FAILED', message[0] or '')) self.fo.write('\r' + out + '\n') - for m in message[1:]: self.fo.write(' ' + m + '\n') + for m in message[1:]: + self.fo.write(f' {m}' + '\n') self._lock.release() finally: self._do_update_meter(meter, now) @@ -471,24 +465,20 @@ def _round_remaining_time(self, rt, start_time=15.0): if rt < 0: return 0.0 shift = int(math.log(rt/start_time)/math.log(2)) rt = int(rt) - if shift <= 0: return rt - return float(int(rt) >> shift << shift) + return rt if shift <= 0 else float(rt >> shift << shift) def format_time(seconds, use_hours=0): if seconds is None or seconds < 0: - if use_hours: return '--:--:--' - else: return '--:--' - else: - seconds = int(seconds) - minutes = seconds / 60 - seconds = seconds % 60 - if use_hours: - hours = minutes / 60 - minutes = minutes % 60 - return '%02i:%02i:%02i' % (hours, minutes, seconds) - else: - return '%02i:%02i' % (minutes, seconds) + return '--:--:--' if use_hours else '--:--' + seconds = int(seconds) + minutes = seconds / 60 + seconds %= 60 + if not use_hours: + return '%02i:%02i' % (minutes, seconds) + hours = minutes / 60 + minutes = minutes % 60 + return '%02i:%02i:%02i' % (hours, minutes, seconds) def format_number(number, SI=0, space=' '): """Turn numbers into human-readable metric-like numbers""" diff --git a/urlgrabber/sslfactory.py b/urlgrabber/sslfactory.py index f7e6d3d..4851e4c 100644 --- a/urlgrabber/sslfactory.py +++ b/urlgrabber/sslfactory.py @@ -45,13 +45,12 @@ def _get_ssl_context(self, ssl_ca_cert, ssl_context): then the supplied ssl context is used. If no ssl context was supplied, None is returned. """ - if ssl_ca_cert: - context = SSL.Context() - context.load_verify_locations(ssl_ca_cert) - context.set_verify(SSL.verify_peer, -1) - return context - else: + if not ssl_ca_cert: return ssl_context + context = SSL.Context() + context.load_verify_locations(ssl_ca_cert) + context.set_verify(SSL.verify_peer, -1) + return context def create_https_connection(self, host, response_class = None): connection = httplib.HTTPSConnection(host, self.ssl_context) @@ -80,10 +79,8 @@ def get_factory(ssl_ca_cert = None, ssl_context = None): """ Return an SSLFactory, based on if M2Crypto is available. """ if have_m2crypto: return M2SSLFactory(ssl_ca_cert, ssl_context) - else: # Log here if someone provides the args but we don't use them. - if ssl_ca_cert or ssl_context: - if DEBUG: - DEBUG.warning("SSL arguments supplied, but M2Crypto is not available. " - "Using Python SSL.") - return SSLFactory() + if (ssl_ca_cert or ssl_context) and DEBUG: + DEBUG.warning("SSL arguments supplied, but M2Crypto is not available. " + "Using Python SSL.") + return SSLFactory() diff --git a/web/application.py b/web/application.py index d9dd0fb..e5b9c78 100755 --- a/web/application.py +++ b/web/application.py @@ -158,11 +158,8 @@ def request(self, localpart='/', method='GET', data=None, """ path, maybe_query = urllib.splitquery(localpart) query = maybe_query or "" - - if 'env' in kw: - env = kw['env'] - else: - env = {} + + env = kw['env'] if 'env' in kw else {} env = dict(env, HTTP_HOST=host, REQUEST_METHOD=method, PATH_INFO=path, QUERY_STRING=query, HTTPS=str(https)) headers = headers or {} @@ -177,10 +174,7 @@ def request(self, localpart='/', method='GET', data=None, if data: import StringIO - if isinstance(data, dict): - q = urllib.urlencode(data) - else: - q = data + q = urllib.urlencode(data) if isinstance(data, dict) else data env['wsgi.input'] = StringIO.StringIO(q) if not env.get('CONTENT_TYPE', '').lower().startswith('multipart/') and 'CONTENT_LENGTH' not in env: env['CONTENT_LENGTH'] = len(q) @@ -189,6 +183,7 @@ def start_response(status, headers): response.status = status response.headers = dict(headers) response.header_items = headers + response.data = "".join(self.wsgifunc(cleanup_threadlocal=False)(env, start_response)) return response @@ -320,7 +315,7 @@ def load(self, env): ctx.protocol = 'https' else: ctx.protocol = 'http' - ctx.homedomain = ctx.protocol + '://' + env.get('HTTP_HOST', '[unknown]') + ctx.homedomain = f'{ctx.protocol}://' + env.get('HTTP_HOST', '[unknown]') ctx.homepath = os.environ.get('REAL_SCRIPT_NAME', env.get('SCRIPT_NAME', '')) ctx.home = ctx.homedomain + ctx.homepath #@@ home is changed when the request is handled to a sub-application. @@ -339,14 +334,14 @@ def load(self, env): ctx.query = '' ctx.fullpath = ctx.path + ctx.query - + for k, v in ctx.iteritems(): if isinstance(v, str): ctx[k] = safeunicode(v) # status must always be str ctx.status = '200 OK' - + ctx.app_stack = [] def _delegate(self, f, fvars, args=[]): @@ -358,9 +353,9 @@ def handle_class(cls): raise web.nomethod(cls) tocall = getattr(cls(), meth) return tocall(*args) - + def is_class(o): return isinstance(o, (types.ClassType, type)) - + if f is None: raise web.notfound() elif isinstance(f, application): @@ -373,7 +368,7 @@ def is_class(o): return isinstance(o, (types.ClassType, type)) if web.ctx.method == "GET": x = web.ctx.env.get('QUERY_STRING', '') if x: - url += '?' + x + url += f'?{x}' raise web.redirect(url) elif '.' in f: x = f.split('.') @@ -397,10 +392,10 @@ def _match(self, mapping, value): else: continue elif isinstance(what, basestring): - what, result = utils.re_subm('^' + pat + '$', what, value) + what, result = utils.re_subm(f'^{pat}$', what, value) else: - result = utils.re_compile('^' + pat + '$').match(value) - + result = utils.re_compile(f'^{pat}$').match(value) + if result: # it's a match return what, [x and urllib.unquote(x) for x in result.groups()] return None, None @@ -433,16 +428,14 @@ def get_parent_app(self): def notfound(self): """Returns HTTPError with '404 not found' message""" - parent = self.get_parent_app() - if parent: + if parent := self.get_parent_app(): return parent.notfound() else: return web._NotFound() def internalerror(self): """Returns HTTPError with '500 internal error' message""" - parent = self.get_parent_app() - if parent: + if parent := self.get_parent_app(): return parent.internalerror() elif web.config.get('debug'): import debugerror @@ -469,16 +462,19 @@ class auto_application(application): def __init__(self): application.__init__(self) + + class metapage(type): def __init__(klass, name, bases, attrs): type.__init__(klass, name, bases, attrs) - path = attrs.get('path', '/' + name) + path = attrs.get('path', f'/{name}') # path can be specified as None to ignore that class # typically required to create a abstract base class. if path is not None: self.add_mapping(path, klass) + class page: path = None __metaclass__ = metapage @@ -515,9 +511,9 @@ def handle(self): def _match(self, mapping, value): for pat, what in utils.group(mapping, 2): if isinstance(what, basestring): - what, result = utils.re_subm('^' + pat + '$', what, value) + what, result = utils.re_subm(f'^{pat}$', what, value) else: - result = utils.re_compile('^' + pat + '$').match(value) + result = utils.re_compile(f'^{pat}$').match(value) if result: # it's a match return what, [x and urllib.unquote(x) for x in result.groups()] @@ -577,11 +573,11 @@ def internal(self, arg): if '/' in arg: first, rest = arg.split('/', 1) func = prefix + first - args = ['/' + rest] + args = [f'/{rest}'] else: func = prefix + arg args = [] - + if hasattr(self, func): try: return getattr(self, func)(*args) @@ -589,6 +585,7 @@ def internal(self, arg): return web.notfound() else: return web.notfound() + return internal class Reloader: diff --git a/web/browser.py b/web/browser.py index f1d260f..1b06053 100644 --- a/web/browser.py +++ b/web/browser.py @@ -70,10 +70,8 @@ def open(self, url, data=None, headers={}): def show(self): """Opens the current page in real web browser.""" - f = open('page.html', 'w') - f.write(self.data) - f.close() - + with open('page.html', 'w') as f: + f.write(self.data) import webbrowser, os url = 'file://' + os.path.abspath('page.html') webbrowser.open(url) @@ -94,7 +92,7 @@ def get_text(self, e=None): def _get_links(self): soup = self.get_soup() - return [a for a in soup.findAll(name='a')] + return list(soup.findAll(name='a')) def get_links(self, text=None, text_regex=None, url=None, url_regex=None, predicate=None): """Returns all links in the document.""" @@ -224,7 +222,7 @@ def https_open(self, req): https_request = urllib2.HTTPHandler.do_request_ def _make_response(self, result, url): - data = "\r\n".join(["%s: %s" % (k, v) for k, v in result.header_items]) + data = "\r\n".join([f"{k}: {v}" for k, v in result.header_items]) headers = httplib.HTTPMessage(StringIO(data)) response = urllib.addinfourl(StringIO(result.data), headers, url) code, msg = result.status.split(None, 1) diff --git a/web/contrib/template.py b/web/contrib/template.py index 2642d65..0827517 100644 --- a/web/contrib/template.py +++ b/web/contrib/template.py @@ -23,8 +23,8 @@ def __init__(self, path): def __getattr__(self, name): from Cheetah.Template import Template - path = os.path.join(self.path, name + ".html") - + path = os.path.join(self.path, f"{name}.html") + def template(**kw): t = Template(file=path, searchList=[kw]) return t.respond() @@ -54,7 +54,7 @@ def __init__(self, *a, **kwargs): def __getattr__(self, name): # Assuming all templates are html - path = name + ".html" + path = f"{name}.html" if self._type == "text": from genshi.template import TextTemplate @@ -67,10 +67,8 @@ def __getattr__(self, name): t = self._loader.load(path, cls=cls) def template(**kw): stream = t.generate(**kw) - if type: - return stream.render(type) - else: - return stream.render() + return stream.render(type) if type else stream.render() + return template class render_jinja: @@ -87,7 +85,7 @@ def __init__(self, *a, **kwargs): def __getattr__(self, name): # Assuming all templates end with .html - path = name + '.html' + path = f'{name}.html' t = self._lookup.get_template(path) return t.render @@ -105,7 +103,7 @@ def __init__(self, *a, **kwargs): def __getattr__(self, name): # Assuming all templates are html - path = name + ".html" + path = f"{name}.html" t = self._lookup.get_template(path) return t.render diff --git a/web/db.py b/web/db.py index d74b650..93e5132 100644 --- a/web/db.py +++ b/web/db.py @@ -87,7 +87,7 @@ def __str__(self): return str(self.value) def __repr__(self): - return '' % repr(self.value) + return f'' sqlparam = SQLParam @@ -183,18 +183,18 @@ def values(self): """ return [i.value for i in self.items if isinstance(i, SQLParam)] - def join(items, sep=' '): + def join(self, sep=' '): """ Joins multiple queries. >>> SQLQuery.join(['a', 'b'], ', ') """ - if len(items) == 0: + if len(self) == 0: return SQLQuery("") - q = SQLQuery(items[0]) - for item in items[1:]: + q = SQLQuery(self[0]) + for item in self[1:]: q += sep q += item return q @@ -203,12 +203,12 @@ def join(items, sep=' '): def __str__(self): try: - return self.query() % tuple([sqlify(x) for x in self.values()]) + return self.query() % tuple(sqlify(x) for x in self.values()) except (ValueError, TypeError): return self.query() def __repr__(self): - return '' % repr(str(self)) + return f'' class SQLLiteral: """ @@ -271,7 +271,7 @@ def sqlify(obj): else: return repr(obj) -def sqllist(lst): +def sqllist(lst): """ Converts the arguments for use in something like a WHERE clause. @@ -282,10 +282,7 @@ def sqllist(lst): >>> sqllist(u'abc') u'abc' """ - if isinstance(lst, basestring): - return lst - else: - return ', '.join(lst) + return lst if isinstance(lst, basestring) else ', '.join(lst) def sqlors(left, lst): """ @@ -312,14 +309,16 @@ def sqlors(left, lst): lst = lst[0] if isinstance(lst, iters): - return SQLQuery(['('] + - sum([[left, sqlparam(x), ' OR '] for x in lst], []) + - ['1=2)'] + return SQLQuery( + ( + (['('] + sum(([left, sqlparam(x), ' OR '] for x in lst), [])) + + ['1=2)'] + ) ) else: return left + sqlparam(lst) -def sqlwhere(dictionary, grouping=' AND '): +def sqlwhere(dictionary, grouping=' AND '): """ Converts a `dictionary` to an SQL WHERE clause `SQLQuery`. @@ -330,7 +329,9 @@ def sqlwhere(dictionary, grouping=' AND '): >>> sqlwhere({'a': 'a', 'b': 'b'}).query() 'a = %s AND b = %s' """ - return SQLQuery.join([k + ' = ' + sqlparam(v) for k, v in dictionary.items()], grouping) + return SQLQuery.join( + [f'{k} = {sqlparam(v)}' for k, v in dictionary.items()], grouping + ) def sqlquote(a): """ @@ -503,7 +504,7 @@ def _param_marker(self): return '?' elif style == 'numeric': return ':1' - elif style in ['format', 'pyformat']: + elif style in {'format', 'pyformat'}: return '%s' raise UnknownParamstyle, style @@ -544,14 +545,11 @@ def _db_execute(self, cur, sql_query): def _where(self, where, vars): if isinstance(where, (int, long)): - where = "id = " + sqlparam(where) - #@@@ for backward-compatibility + where = f"id = {sqlparam(where)}" elif isinstance(where, (list, tuple)) and len(where) == 2: where = SQLQuery(where[0], where[1]) - elif isinstance(where, SQLQuery): - pass - else: - where = reparam(where, vars) + elif not isinstance(where, SQLQuery): + where = reparam(where, vars) return where def query(self, sql_query, vars=None, processed=False, _test=False): @@ -597,7 +595,7 @@ def iterwrapper(): return out def select(self, tables, vars=None, what='*', where=None, order=None, group=None, - limit=None, offset=None, _test=False): + limit=None, offset=None, _test=False): """ Selects `what` from `tables` with clauses `where`, `order`, `group`, `limit`, and `offset`. Uses vars to interpolate. @@ -613,8 +611,7 @@ def select(self, tables, vars=None, what='*', where=None, order=None, group=None sql_clauses = self.sql_clauses(what, tables, where, group, order, limit, offset) clauses = [self.gen_clause(sql, val, vars) for sql, val in sql_clauses if val is not None] qout = SQLQuery.join(clauses) - if _test: return qout - return self.query(qout, processed=True) + return qout if _test else self.query(qout, processed=True) def where(self, table, what='*', order=None, group=None, limit=None, offset=None, _test=False, **kwargs): @@ -627,9 +624,7 @@ def where(self, table, what='*', order=None, group=None, limit=None, >>> db.where('foo', source=2, crust='dewey', _test=True) """ - where = [] - for k, v in kwargs.iteritems(): - where.append(k + ' = ' + sqlquote(v)) + where = [f'{k} = {sqlquote(v)}' for k, v in kwargs.iteritems()] return self.select(table, what=what, order=order, group=group, limit=limit, offset=offset, _test=_test, where=SQLQuery.join(where, ' AND ')) @@ -646,11 +641,7 @@ def sql_clauses(self, what, tables, where, group, order, limit, offset): def gen_clause(self, sql, val, vars): if isinstance(val, (int, long)): - if sql == 'WHERE': - nout = 'id = ' + sqlquote(val) - else: - nout = SQLQuery(val) - #@@@ + nout = f'id = {sqlquote(val)}' if sql == 'WHERE' else SQLQuery(val) elif isinstance(val, (list, tuple)) and len(val) == 2: nout = SQLQuery(val[0], val[1]) # backwards-compatibility elif isinstance(val, SQLQuery): @@ -659,12 +650,11 @@ def gen_clause(self, sql, val, vars): nout = reparam(val, vars) def xjoin(a, b): - if a and b: return a + ' ' + b - else: return a or b + return f'{a} {b}' if a and b else a or b return xjoin(sql, nout) - def insert(self, tablename, seqname=None, _test=False, **values): + def insert(self, tablename, seqname=None, _test=False, **values): """ Inserts `values` into `tablename`. Returns current sequence ID. Set `seqname` to the ID if it's not the default, or to `False` @@ -679,17 +669,18 @@ def insert(self, tablename, seqname=None, _test=False, **values): >>> q.values() [2, 'bob'] """ - def q(x): return "(" + x + ")" - + def q(x): + return f"({x})" + if values: _keys = SQLQuery.join(values.keys(), ', ') _values = SQLQuery.join([sqlparam(v) for v in values.values()], ', ') - sql_query = "INSERT INTO %s " % tablename + q(_keys) + ' VALUES ' + q(_values) + sql_query = f"INSERT INTO {tablename} {q(_keys)} VALUES {q(_values)}" else: - sql_query = SQLQuery("INSERT INTO %s DEFAULT VALUES" % tablename) + sql_query = SQLQuery(f"INSERT INTO {tablename} DEFAULT VALUES") if _test: return sql_query - + db_cursor = self._db_cursor() if seqname is not False: sql_query = self._process_insert_query(sql_query, tablename, seqname) @@ -707,7 +698,7 @@ def q(x): return "(" + x + ")" out = db_cursor.fetchone()[0] except Exception: out = None - + if not self.ctx.transactions: self.ctx.commit() return out @@ -728,14 +719,10 @@ def multiple_insert(self, tablename, values, seqname=None, _test=False): """ if not values: return [] - + if not self.supports_multiple_insert: out = [self.insert(tablename, seqname=seqname, _test=_test, **v) for v in values] - if seqname is False: - return None - else: - return out - + return None if seqname is False else out keys = values[0].keys() #@@ make sure all keys are valid @@ -744,12 +731,12 @@ def multiple_insert(self, tablename, values, seqname=None, _test=False): if v.keys() != keys: raise ValueError, 'Bad data' - sql_query = SQLQuery('INSERT INTO %s (%s) VALUES ' % (tablename, ', '.join(keys))) + sql_query = SQLQuery(f"INSERT INTO {tablename} ({', '.join(keys)}) VALUES ") data = [] for row in values: d = SQLQuery.join([SQLParam(row[k]) for k in keys], ', ') - data.append('(' + d + ')') + data.append(f'({d})') sql_query += SQLQuery.join(data, ', ') if _test: return sql_query @@ -778,7 +765,7 @@ def multiple_insert(self, tablename, values, seqname=None, _test=False): return out - def update(self, tables, where, vars=None, _test=False, **values): + def update(self, tables, where, vars=None, _test=False, **values): """ Update `tables` with clause `where` (interpolated using `vars`) and setting `values`. @@ -798,19 +785,18 @@ def update(self, tables, where, vars=None, _test=False, **values): where = self._where(where, vars) query = ( - "UPDATE " + sqllist(tables) + - " SET " + sqlwhere(values, ', ') + - " WHERE " + where) + (f"UPDATE {sqllist(tables)} SET " + sqlwhere(values, ', ')) + " WHERE " + ) + where if _test: return query - + db_cursor = self._db_cursor() self._db_execute(db_cursor, query) if not self.ctx.transactions: self.ctx.commit() return db_cursor.rowcount - def delete(self, table, where, using=None, vars=None, _test=False): + def delete(self, table, where, using=None, vars=None, _test=False): """ Deletes from `table` with clauses `where` and `using`. @@ -822,9 +808,11 @@ def delete(self, table, where, using=None, vars=None, _test=False): if vars is None: vars = {} where = self._where(where, vars) - q = 'DELETE FROM ' + table - if where: q += ' WHERE ' + where - if using: q += ' USING ' + sqllist(using) + q = f'DELETE FROM {table}' + if where: + q += f' WHERE {where}' + if using: + q += f' USING {sqllist(using)}' if _test: return q @@ -869,8 +857,8 @@ def get_db_module(self): def _process_insert_query(self, query, tablename, seqname): if seqname is None: - seqname = tablename + "_id_seq" - return query + "; SELECT currval('%s')" % seqname + seqname = f"{tablename}_id_seq" + return f"{query}; SELECT currval('{seqname}')" def _connect(self, keywords): conn = DB._connect(self, keywords) @@ -959,7 +947,6 @@ def __init__(self, **keywords): import kinterbasdb as db except Exception: db = None - pass if 'pw' in keywords: keywords['passwd'] = keywords['pw'] del keywords['pw'] @@ -1014,7 +1001,7 @@ def _process_insert_query(self, query, tablename, seqname): # It is not possible to get seq name from table name in Oracle return query else: - return query + "; SELECT %s.currval FROM dual" % seqname + return f"{query}; SELECT {seqname}.currval FROM dual" _databases = {} def database(dburl=None, **params): @@ -1049,7 +1036,7 @@ def register_database(name, clazz): register_database('mssql', MSSQLDB) register_database('oracle', OracleDB) -def _interpolate(format): +def _interpolate(format): """ Takes a format string and returns a list of 2-tuples of the form (boolean, string) where boolean says whether string should be evaled @@ -1084,9 +1071,9 @@ def matchorfail(text, pos): tstart, tend = match.regs[3] token = format[tstart:tend] if token == "{": - level = level + 1 + level += 1 elif token == "}": - level = level - 1 + level -= 1 chunks.append((1, format[dollar + 2:pos - 1])) elif nextchar in namechars: @@ -1103,9 +1090,9 @@ def matchorfail(text, pos): tstart, tend = match.regs[3] token = format[tstart:tend] if token[0] in "([": - level = level + 1 + level += 1 elif token[0] in ")]": - level = level - 1 + level -= 1 else: break chunks.append((1, format[dollar + 1:pos])) diff --git a/web/debugerror.py b/web/debugerror.py index 5ff62ff..9e8492a 100644 --- a/web/debugerror.py +++ b/web/debugerror.py @@ -314,7 +314,7 @@ def emailerrors_internal(): error_value = tb[1] tb_txt = ''.join(traceback.format_exception(*tb)) path = web.ctx.path - request = web.ctx.method+' '+web.ctx.home+web.ctx.fullpath + request = f'{web.ctx.method} {web.ctx.home}{web.ctx.fullpath}' eaddr = email_address text = ("""\ ------here---- @@ -331,13 +331,16 @@ def emailerrors_internal(): """ % locals()) + str(djangoerror()) sendmail( - "your buggy site <%s>" % eaddr, - "the bugfixer <%s>" % eaddr, - "bug: %(error_name)s: %(error_value)s (%(path)s)" % locals(), - text, - headers={'Content-Type': 'multipart/mixed; boundary="----here----"'}) + f"your buggy site <{eaddr}>", + f"the bugfixer <{eaddr}>", + "bug: %(error_name)s: %(error_value)s (%(path)s)" % locals(), + text, + headers={ + 'Content-Type': 'multipart/mixed; boundary="----here----"' + }, + ) return error - + return emailerrors_internal if __name__ == "__main__": diff --git a/web/form.py b/web/form.py index 76f5b0b..6431984 100644 --- a/web/form.py +++ b/web/form.py @@ -9,8 +9,7 @@ def attrget(obj, attr, value=None): if hasattr(obj, 'has_key') and obj.has_key(attr): return obj[attr] - if hasattr(obj, attr): return getattr(obj, attr) - return value + return getattr(obj, attr) if hasattr(obj, attr) else value class Form: r""" @@ -36,25 +35,25 @@ def render(self): out += self.rendernote(self.note) out += '\n' for i in self.inputs: - out += ' ' % (i.id, net.websafe(i.description)) - out += "\n" + out += f' ' + out += f"\n" out += "
"+i.pre+i.render()+i.post+"
{i.pre}{i.render()}{i.post}" + "
" return out def render_css(self): - out = [] - out.append(self.rendernote(self.note)) + out = [self.rendernote(self.note)] for i in self.inputs: - out.append('' % (i.id, net.websafe(i.description))) - out.append(i.pre) - out.append(i.render()) - out.append(i.post) - out.append('\n') + out.extend( + ( + f'', + i.pre, + ) + ) + out.extend((i.render(), i.post, '\n')) return ''.join(out) def rendernote(self, note): - if note: return '%s' % net.websafe(note) - else: return "" + return f'{net.websafe(note)}' if note else "" def validates(self, source=None, _validate=True, **kw): source = source or kw or web.input() @@ -126,21 +125,18 @@ def validate(self, value): def render(self): raise NotImplementedError def rendernote(self, note): - if note: return '%s' % net.websafe(note) - else: return "" + return f'{net.websafe(note)}' if note else "" def addatts(self): - str = "" - for (n, v) in self.attrs.items(): - str += ' %s="%s"' % (n, net.websafe(v)) - return str + return "".join(f' {n}="{net.websafe(v)}"' for n, v in self.attrs.items()) #@@ quoting class Textbox(Input): def render(self, shownote=True): - x = ' {net.websafe(arg)} ' x += '' - x += self.rendernote(self.note) + x += self.rendernote(self.note) return x class Checkbox(Input): def render(self): - x = '{safename}' x += self.rendernote(self.note) return x @@ -228,15 +219,16 @@ def __init__(self, name, *validators, **attrs): self.description = "" def render(self): - x = '>> splitline('') ('', '') """ - index = text.find('\n') + 1 - if index: + if index := text.find('\n') + 1: return text[:index], text[index:] else: return text, '' @@ -76,12 +75,11 @@ def parse(self): return DefwithNode(defwith, suite) def read_defwith(self, text): - if text.startswith('$def with'): - defwith, text = splitline(text) - defwith = defwith[1:].strip() # strip $ and spaces - return defwith, text - else: + if not text.startswith('$def with'): return '', text + defwith, text = splitline(text) + defwith = defwith[1:].strip() # strip $ and spaces + return defwith, text def read_section(self, text): r"""Reads one section from the given text. @@ -126,11 +124,11 @@ def read_var(self, text): tokens = self.python_tokens(line) if len(tokens) < 4: raise SyntaxError('Invalid var statement') - + name = tokens[1] sep = tokens[2] value = line.split(sep, 1)[1].strip() - + if sep == '=': pass # no need to process value elif sep == ':': @@ -144,9 +142,9 @@ def read_var(self, text): nodes.append(TextNode('\n')) else: # single-line var statement linenode, _ = self.readline(value) - nodes = linenode.nodes + nodes = linenode.nodes parts = [node.emit('') for node in nodes] - value = "join_(%s)" % ", ".join(parts) + value = f'join_({", ".join(parts)})' else: raise SyntaxError('Invalid var statement') return VarNode(name, value), text @@ -395,14 +393,11 @@ def read_indented_block(self, text, indent): """ if indent == '': return '', text - + block = "" - while True: - if text.startswith(indent): - line, text = splitline(text) - block += line[len(indent):] - else: - break + while text.startswith(indent): + line, text = splitline(text) + block += line[len(indent):] return block, text def read_statement(self, text): @@ -452,7 +447,7 @@ def create_block_node(self, keyword, stmt, block, begin_indent): if keyword in STATEMENT_NODES: return STATEMENT_NODES[keyword](stmt, block, begin_indent) else: - raise ParseError, 'Unknown statement: %s' % repr(keyword) + raise (ParseError, f'Unknown statement: {repr(keyword)}') class PythonTokenizer: """Utility wrapper over python tokenizer.""" @@ -509,7 +504,7 @@ def emit(self, indent): return self.defwith + self.suite.emit(indent + INDENT) def __repr__(self): - return "" % (self.defwith, self.nodes) + return f"" class TextNode: def __init__(self, value): @@ -519,27 +514,24 @@ def emit(self, indent): return repr(self.value) def __repr__(self): - return 't' + repr(self.value) + return f't{repr(self.value)}' class ExpressionNode: def __init__(self, value, escape=True): self.value = value.strip() - + # convert ${...} to $(...) if value.startswith('{') and value.endswith('}'): - self.value = '(' + self.value[1:-1] + ')' - + self.value = f'({self.value[1:-1]})' + self.escape = escape def emit(self, indent): - return 'escape_(%s, %s)' % (self.value, bool(self.escape)) + return f'escape_({self.value}, {bool(self.escape)})' def __repr__(self): - if self.escape: - escape = '' - else: - escape = ':' - return "$%s%s" % (escape, self.value) + escape = '' if self.escape else ':' + return f"${escape}{self.value}" class AssignmentNode: def __init__(self, code): @@ -549,7 +541,7 @@ def emit(self, indent, begin_indent=''): return indent + self.code + "\n" def __repr__(self): - return "" % repr(self.code) + return f"" class LineNode: def __init__(self, nodes): @@ -562,7 +554,7 @@ def emit(self, indent, text_indent='', name=''): return indent + 'yield %s, join_(%s)\n' % (repr(name), ', '.join(text)) def __repr__(self): - return "" % repr(self.nodes) + return f"" INDENT = ' ' # 4 spaces @@ -574,14 +566,13 @@ def __init__(self, stmt, block, begin_indent=''): def emit(self, indent, text_indent=''): text_indent = self.begin_indent + text_indent - out = indent + self.stmt + self.suite.emit(indent + INDENT, text_indent) - return out + return indent + self.stmt + self.suite.emit(indent + INDENT, text_indent) def text(self): return '${' + self.stmt + '}' + "".join([node.text(indent) for node in self.nodes]) def __repr__(self): - return "" % (repr(self.stmt), repr(self.nodelist)) + return f"" class ForNode(BlockNode): def __init__(self, stmt, block, begin_indent=''): @@ -590,11 +581,11 @@ def __init__(self, stmt, block, begin_indent=''): tok.consume_till('in') a = stmt[:tok.index] # for i in b = stmt[tok.index:-1] # rest of for stmt excluding : - stmt = a + ' loop.setup(' + b.strip() + '):' + stmt = f'{a} loop.setup({b.strip()}):' BlockNode.__init__(self, stmt, block, begin_indent) def __repr__(self): - return "" % (repr(self.original_stmt), repr(self.suite)) + return f"" class CodeNode: def __init__(self, stmt, block, begin_indent=''): @@ -606,7 +597,7 @@ def emit(self, indent, text_indent=''): return rx.sub(indent, self.code).rstrip(' ') def __repr__(self): - return "" % repr(self.code) + return f"" class IfNode(BlockNode): pass @@ -629,7 +620,7 @@ def emit(self, indent, text_indent): return indent + 'yield %s, %s\n' % (repr(self.name), self.value) def __repr__(self): - return "" % (self.name, self.value) + return f"" class SuiteNode: """Suite is a list of sections.""" @@ -713,14 +704,10 @@ def __init__(self, forloop, parent): self.parent = parent def setup(self, seq): - if hasattr(seq, '__len__'): - n = len(seq) - else: - n = 0 - + n = len(seq) if hasattr(seq, '__len__') else 0 self.index = 0 seq = iter(seq) - + # Pre python-2.5 does not support yield in try-except. # This is a work-around to overcome that limitation. def next(seq): @@ -729,7 +716,7 @@ def next(seq): except: self._forloop._pop() raise - + while True: self._next(self.index + 1, n) yield next(seq) @@ -753,10 +740,7 @@ def __init__(self, code, filename, filter, globals, builtins): self.filter = filter self._globals = globals self._builtins = builtins - if code: - self.t = self._compile(code) - else: - self.t = lambda: '' + self.t = self._compile(code) if code else (lambda: '') def _compile(self, code): env = self.make_env(self._globals or {}, self._builtins) @@ -849,10 +833,10 @@ def __call__(self, *a, **kw): return BaseTemplate.__call__(self, *a, **kw) - def generate_code(text, filename): + def generate_code(self, filename): # parse the text - rootnode = Parser(text, filename).parse() - + rootnode = Parser(self, filename).parse() + # generate python code from the parse tree code = rootnode.emit(indent="").strip() return safestr(code) @@ -916,12 +900,8 @@ def __init__(self, loc='templates', cache=None, base=None, **keywords): if cache is None: cache = not config.get('debug', False) - - if cache: - self._cache = {} - else: - self._cache = None - + + self._cache = {} if cache else None if base and not hasattr(base, '__call__'): # make base a function, so that it can be passed to sub-renders self._base = lambda page: self._template(base)(page) @@ -933,33 +913,28 @@ def _lookup(self, name): if os.path.isdir(path): return 'dir', path else: - path = self._findfile(path) - if path: - return 'file', path - else: - return 'none', None + return ('file', path) if (path := self._findfile(path)) else ('none', None) def _load_template(self, name): kind, path = self._lookup(name) - + if kind == 'dir': return Render(path, cache=self._cache is not None, base=self._base, **self._keywords) elif kind == 'file': return Template(open(path).read(), filename=path, **self._keywords) else: - raise AttributeError, "No template named " + name + raise (AttributeError, f"No template named {name}") def _findfile(self, path_prefix): - p = [f for f in glob.glob(path_prefix + '.*') if not f.endswith('~')] # skip backup files + p = [f for f in glob.glob(f'{path_prefix}.*') if not f.endswith('~')] return p and p[0] def _template(self, name): - if self._cache is not None: - if name not in self._cache: - self._cache[name] = self._load_template(name) - return self._cache[name] - else: + if self._cache is None: return self._load_template(name) + if name not in self._cache: + self._cache[name] = self._load_template(name) + return self._cache[name] def __getattr__(self, name): t = self._template(name) @@ -1011,44 +986,43 @@ def frender(path, **keywords): def compile_templates(root): """Compiles templates to python code.""" re_start = re_compile('^', re.M) - + for dirpath, dirnames, filenames in os.walk(root): filenames = [f for f in filenames if not f.startswith('.') and not f.endswith('~') and not f.startswith('__init__.py')] - - out = open(os.path.join(dirpath, '__init__.py'), 'w') - out.write('from web.template import CompiledTemplate, ForLoop\n\n') - if dirnames: - out.write("import " + ", ".join(dirnames)) - for f in filenames: - path = os.path.join(dirpath, f) + with open(os.path.join(dirpath, '__init__.py'), 'w') as out: + out.write('from web.template import CompiledTemplate, ForLoop\n\n') + if dirnames: + out.write("import " + ", ".join(dirnames)) - # create template to make sure it compiles - t = Template(open(path).read(), path) - - if '.' in f: - name, _ = f.split('.', 1) - else: - name = f - - code = Template.generate_code(open(path).read(), path) - code = re_start.sub(' ', code) - - _gen = '' + \ - '\ndef %s():' + \ - '\n loop = ForLoop()' + \ - '\n _dummy = CompiledTemplate(lambda: None, "dummy")' + \ - '\n join_ = _dummy._join' + \ - '\n escape_ = _dummy._escape' + \ - '\n' + \ - '\n%s' + \ - '\n return __template__' - - gen_code = _gen % (name, code) - out.write(gen_code) - out.write('\n\n') - out.write('%s = CompiledTemplate(%s(), %s)\n\n' % (name, name, repr(path))) - out.close() + for f in filenames: + path = os.path.join(dirpath, f) + + # create template to make sure it compiles + t = Template(open(path).read(), path) + + if '.' in f: + name, _ = f.split('.', 1) + else: + name = f + + code = Template.generate_code(open(path).read(), path) + code = re_start.sub(' ', code) + + _gen = '' + \ + '\ndef %s():' + \ + '\n loop = ForLoop()' + \ + '\n _dummy = CompiledTemplate(lambda: None, "dummy")' + \ + '\n join_ = _dummy._join' + \ + '\n escape_ = _dummy._escape' + \ + '\n' + \ + '\n%s' + \ + '\n return __template__' + + gen_code = _gen % (name, code) + out.write(gen_code) + out.write('\n\n') + out.write('%s = CompiledTemplate(%s(), %s)\n\n' % (name, name, repr(path))) class ParseError(Exception): pass @@ -1116,22 +1090,21 @@ def visit(self, node, *args): "Recursively validate node and all of its children." def classname(obj): return obj.__class__.__name__ + nodename = classname(node) - fn = getattr(self, 'visit' + nodename, None) - + fn = getattr(self, f'visit{nodename}', None) + if fn: fn(node, *args) else: if nodename not in ALLOWED_AST_NODES: self.fail(node, *args) - + for child in node.getChildNodes(): self.visit(child, *args) def visitName(self, node, *args): "Disallow any attempts to access a restricted attr." - #self.assert_attr(node.getChildren()[0], node) - pass def visitGetattr(self, node, *args): "Disallow any attempts to access a restricted attribute." @@ -1179,7 +1152,7 @@ def __str__(self): return safestr(self.get('__body__', '')) def __repr__(self): - return "" % dict.__repr__(self) + return f"" def test(): r"""Doctest for testing template module. diff --git a/web/webapi.py b/web/webapi.py index 62f4824..0c1f321 100644 --- a/web/webapi.py +++ b/web/webapi.py @@ -87,10 +87,7 @@ def __init__(self, url, status='301 Moved Permanently', absolute=False): newloc = urlparse.urljoin(ctx.path, url) if newloc.startswith('/'): - if absolute: - home = ctx.realhome - else: - home = ctx.home + home = ctx.realhome if absolute else ctx.home newloc = home + newloc headers = { @@ -126,14 +123,11 @@ class NoMethod(HTTPError): """A `405 Method Not Allowed` error.""" def __init__(self, cls=None): status = '405 Method Not Allowed' - headers = {} - headers['Content-Type'] = 'text/html' - methods = ['GET', 'HEAD', 'POST', 'PUT', 'DELETE'] if cls: methods = [method for method in methods if hasattr(cls, method)] - headers['Allow'] = ', '.join(methods) + headers = {'Content-Type': 'text/html', 'Allow': ', '.join(methods)} data = None HTTPError.__init__(self, status, headers, data) @@ -190,27 +184,29 @@ def dictify(fs): fs.list = [] return dict([(k, fs[k]) for k in fs.keys()]) - + _method = defaults.pop('_method', 'both') - + e = ctx.env.copy() a = b = {} - - if _method.lower() in ['both', 'post', 'put']: - if e['REQUEST_METHOD'] in ['POST', 'PUT']: - if e.get('CONTENT_TYPE', '').lower().startswith('multipart/'): - # since wsgi.input is directly passed to cgi.FieldStorage, - # it can not be called multiple times. Saving the FieldStorage - # object in ctx to allow calling web.input multiple times. - a = ctx.get('_fieldstorage') - if not a: - fp = e['wsgi.input'] - a = cgi.FieldStorage(fp=fp, environ=e, keep_blank_values=1) - ctx._fieldstorage = a - else: - fp = StringIO(data()) + + if _method.lower() in ['both', 'post', 'put'] and e['REQUEST_METHOD'] in [ + 'POST', + 'PUT', + ]: + if e.get('CONTENT_TYPE', '').lower().startswith('multipart/'): + # since wsgi.input is directly passed to cgi.FieldStorage, + # it can not be called multiple times. Saving the FieldStorage + # object in ctx to allow calling web.input multiple times. + a = ctx.get('_fieldstorage') + if not a: + fp = e['wsgi.input'] a = cgi.FieldStorage(fp=fp, environ=e, keep_blank_values=1) - a = dictify(a) + ctx._fieldstorage = a + else: + fp = StringIO(data()) + a = cgi.FieldStorage(fp=fp, environ=e, keep_blank_values=1) + a = dictify(a) if _method.lower() in ['both', 'get']: e['REQUEST_METHOD'] = 'GET' diff --git a/web/webopenid.py b/web/webopenid.py index b482216..582bcb3 100644 --- a/web/webopenid.py +++ b/web/webopenid.py @@ -45,8 +45,7 @@ def _random_session(): n = random.random() while n in sessions: n = random.random() - n = str(n) - return n + return str(n) def status(): oid_hash = web.cookies().get('openid_identity_hash', '').split(',', 1) @@ -57,8 +56,7 @@ def status(): return None def form(openid_loc): - oid = status() - if oid: + if oid := status(): return '''

OpenID @@ -109,7 +107,9 @@ def GET(self): a = c.complete(web.input(), web.ctx.home + web.ctx.fullpath) if a.status.lower() == 'success': - web.setcookie('openid_identity_hash', _hmac(a.identity_url) + ',' + a.identity_url) + web.setcookie( + 'openid_identity_hash', f'{_hmac(a.identity_url)},{a.identity_url}' + ) del sessions[n] return web.redirect(return_to) diff --git a/web/wsgi.py b/web/wsgi.py index 2cde078..ef98bd9 100644 --- a/web/wsgi.py +++ b/web/wsgi.py @@ -33,33 +33,26 @@ def runwsgi(func): if (os.environ.has_key('PHP_FCGI_CHILDREN') #lighttpd fastcgi or os.environ.has_key('SERVER_SOFTWARE')): return runfcgi(func, None) - + if 'fcgi' in sys.argv or 'fastcgi' in sys.argv: args = sys.argv[1:] if 'fastcgi' in args: args.remove('fastcgi') elif 'fcgi' in args: args.remove('fcgi') - if args: - return runfcgi(func, validaddr(args[0])) - else: - return runfcgi(func, None) - + return runfcgi(func, validaddr(args[0])) if args else runfcgi(func, None) if 'scgi' in sys.argv: args = sys.argv[1:] args.remove('scgi') - if args: - return runscgi(func, validaddr(args[0])) - else: - return runscgi(func) - + return runscgi(func, validaddr(args[0])) if args else runscgi(func) return httpserver.runsimple(func, validip(listget(sys.argv, 1, ''))) def _is_dev_mode(): # quick hack to check if the program is running in dev mode. - if os.environ.has_key('SERVER_SOFTWARE') \ - or os.environ.has_key('PHP_FCGI_CHILDREN') \ - or 'fcgi' in sys.argv or 'fastcgi' in sys.argv: - return False - return True + return ( + not os.environ.has_key('SERVER_SOFTWARE') + and not os.environ.has_key('PHP_FCGI_CHILDREN') + and 'fcgi' not in sys.argv + and 'fastcgi' not in sys.argv + ) # When running the builtin-server, enable debug mode if not already set. web.config.setdefault('debug', _is_dev_mode()) \ No newline at end of file diff --git a/web/wsgiserver/__init__.py b/web/wsgiserver/__init__.py index f9396c8..3db00e2 100644 --- a/web/wsgiserver/__init__.py +++ b/web/wsgiserver/__init__.py @@ -169,12 +169,12 @@ def __call__(self, environ, start_response): path = environ["PATH_INFO"] or "/" for p, app in self.apps: # The apps list should be sorted by length, descending. - if path.startswith(p + "/") or path == p: + if path.startswith(f"{p}/") or path == p: environ = environ.copy() environ["SCRIPT_NAME"] = environ["SCRIPT_NAME"] + p environ["PATH_INFO"] = path[len(p):] return app(environ, start_response) - + start_response('404 Not Found', [('Content-Type', 'text/plain'), ('Content-Length', '0')]) return [''] @@ -224,13 +224,11 @@ def readlines(self, sizehint=0): # Shamelessly stolen from StringIO total = 0 lines = [] - line = self.readline() - while line: + while line := self.readline(): lines.append(line) total += len(line) if 0 < sizehint <= total: break - line = self.readline() return lines def close(self): @@ -465,17 +463,17 @@ def _parse_request(self): def read_headers(self): """Read header lines from the incoming stream.""" environ = self.environ - + while True: line = self.rfile.readline() if not line: # No more data--illegal end of headers raise ValueError("Illegal end of headers.") - + if line == '\r\n': # Normal end of headers break - + if line[0] in ' \t': # It's a continuation line. v = line.strip() @@ -483,13 +481,12 @@ def read_headers(self): k, v = line.split(":", 1) k, v = k.strip().upper(), v.strip() envname = "HTTP_" + k.replace("-", "_") - + if k in comma_separated_headers: - existing = environ.get(envname) - if existing: + if existing := environ.get(envname): v = ", ".join((existing, v)) environ[envname] = v - + ct = environ.pop("HTTP_CONTENT_TYPE", None) if ct is not None: environ["CONTENT_TYPE"] = ct @@ -548,11 +545,10 @@ def respond(self): return def _respond(self): - if self.chunked_read: - if not self.decode_chunked(): - self.close_connection = True - return - + if self.chunked_read and not self.decode_chunked(): + self.close_connection = True + return + response = self.wsgi_app(self.environ, self.start_response) try: for chunk in response: @@ -567,7 +563,7 @@ def _respond(self): finally: if hasattr(response, "close"): response.close() - + if (self.ready and not self.sent_headers): self.sent_headers = True self.send_headers() @@ -641,17 +637,12 @@ def send_headers(self): """Assert, process, and send the HTTP response message-headers.""" hkeys = [key.lower() for key, value in self.outheaders] status = int(self.status[:3]) - + if status == 413: # Request Entity Too Large. Close conn to avoid garbage. self.close_connection = True elif "content-length" not in hkeys: - # "All 1xx (informational), 204 (no content), - # and 304 (not modified) responses MUST NOT - # include a message-body." So no point chunking. - if status < 200 or status in (204, 205, 304): - pass - else: + if status >= 200 and status not in (204, 205, 304): if (self.response_protocol == 'HTTP/1.1' and self.environ["REQUEST_METHOD"] != 'HEAD'): # Use the chunked transfer-coding @@ -660,17 +651,15 @@ def send_headers(self): else: # Closing the conn is the only way to determine len. self.close_connection = True - + if "connection" not in hkeys: if self.response_protocol == 'HTTP/1.1': # Both server and client are HTTP/1.1 or better if self.close_connection: self.outheaders.append(("Connection", "close")) - else: - # Server and/or client are HTTP/1.0 - if not self.close_connection: - self.outheaders.append(("Connection", "Keep-Alive")) - + elif not self.close_connection: + self.outheaders.append(("Connection", "Keep-Alive")) + if (not self.close_connection) and (not self.chunked_read): # Read any remaining request body data on the socket. # "If an origin server receives a request that does not include an @@ -687,16 +676,16 @@ def send_headers(self): size = self.rfile.maxlen - self.rfile.bytes_read if size > 0: self.rfile.read(size) - + if "date" not in hkeys: self.outheaders.append(("Date", rfc822.formatdate())) - + if "server" not in hkeys: self.outheaders.append(("Server", self.environ['SERVER_SOFTWARE'])) - + buf = [self.environ['ACTUAL_SERVER_PROTOCOL'], " ", self.status, "\r\n"] try: - buf += [k + ": " + v + "\r\n" for k, v in self.outheaders] + buf += [f"{k}: {v}" + "\r\n" for k, v in self.outheaders] except TypeError: if not isinstance(k, str): raise TypeError("WSGI response header key %r is not a string.") @@ -764,11 +753,10 @@ def read(self, size=-1): # Read until EOF self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. while True: - data = self.recv(rbufsize) - if not data: + if data := self.recv(rbufsize): + buf.write(data) + else: break - buf.write(data) - return buf.getvalue() else: # Read until size bytes or EOF seen, whichever comes first buf_len = buf.tell() @@ -808,7 +796,8 @@ def read(self, size=-1): buf_len += n del data # explicit free #assert buf_len == buf.tell() - return buf.getvalue() + + return buf.getvalue() def readline(self, size=-1): buf = self._rbuf @@ -852,7 +841,6 @@ def readline(self, size=-1): del data break buf.write(data) - return buf.getvalue() else: # Read until size bytes or \n or EOF seen, whichever comes first buf.seek(0, 2) # seek end @@ -875,13 +863,12 @@ def readline(self, size=-1): nl += 1 # save the excess data to _rbuf self._rbuf.write(data[nl:]) - if buf_len: - buf.write(data[:nl]) - break - else: + if not buf_len: # Shortcut. Avoid data copy through buf when returning # a substring of our first recv(). return data[:nl] + buf.write(data[:nl]) + break n = len(data) if n == size and not buf_len: # Shortcut. Avoid data copy through buf when @@ -893,8 +880,9 @@ def readline(self, size=-1): break buf.write(data) buf_len += n - #assert buf_len == buf.tell() - return buf.getvalue() + #assert buf_len == buf.tell() + + return buf.getvalue() else: class CP_fileobject(socket._fileobject): @@ -933,17 +921,12 @@ def read(self, size=-1): # Read until EOF buffers = [self._rbuf] self._rbuf = "" - if self._rbufsize <= 1: - recv_size = self.default_bufsize - else: - recv_size = self._rbufsize - + recv_size = self.default_bufsize if self._rbufsize <= 1 else self._rbufsize while True: - data = self.recv(recv_size) - if not data: + if data := self.recv(recv_size): + buffers.append(data) + else: break - buffers.append(data) - return "".join(buffers) else: # Read until size bytes or EOF seen, whichever comes first data = self._rbuf @@ -968,7 +951,8 @@ def read(self, size=-1): buffers[-1] = data[:left] break buf_len += n - return "".join(buffers) + + return "".join(buffers) def readline(self, size=-1): data = self._rbuf @@ -1004,7 +988,6 @@ def readline(self, size=-1): self._rbuf = data[nl:] buffers[-1] = data[:nl] break - return "".join(buffers) else: # Read until size bytes or \n or EOF seen, whichever comes first nl = data.find('\n', 0, size) @@ -1038,7 +1021,8 @@ def readline(self, size=-1): buffers[-1] = data[:left] break buf_len += n - return "".join(buffers) + + return "".join(buffers) class SSL_fileobject(CP_fileobject): @@ -1282,10 +1266,10 @@ def __init__(self, server, min=10, max=-1): def start(self): """Start the pool of threads.""" - for i in xrange(self.min): + for _ in xrange(self.min): self._threads.append(WorkerThread(self.server)) for worker in self._threads: - worker.setName("CP WSGIServer " + worker.getName()) + worker.setName(f"CP WSGIServer {worker.getName()}") worker.start() for worker in self._threads: while not worker.ready: @@ -1303,11 +1287,11 @@ def put(self, obj): def grow(self, amount): """Spawn new worker threads (not above self.max).""" - for i in xrange(amount): + for _ in xrange(amount): if self.max > 0 and len(self._threads) >= self.max: break worker = WorkerThread(self.server) - worker.setName("CP WSGIServer " + worker.getName()) + worker.setName(f"CP WSGIServer {worker.getName()}") self._threads.append(worker) worker.start() @@ -1319,9 +1303,9 @@ def shrink(self, amount): if not t.isAlive(): self._threads.remove(t) amount -= 1 - + if amount > 0: - for i in xrange(min(amount, len(self._threads) - self.min)): + for _ in xrange(min(amount, len(self._threads) - self.min)): # Put a number of shutdown requests on the queue equal # to 'amount'. Once each of those is processed by a worker, # that worker will terminate and be culled from our list @@ -1744,30 +1728,26 @@ def populate_ssl_environ(self): "wsgi.url_scheme": "https", "HTTPS": "on", # pyOpenSSL doesn't provide access to any of these AFAICT -## 'SSL_PROTOCOL': 'SSLv2', -## SSL_CIPHER string The cipher specification name -## SSL_VERSION_INTERFACE string The mod_ssl program version -## SSL_VERSION_LIBRARY string The OpenSSL program version - } - - # Server certificate attributes - ssl_environ.update({ + # ## 'SSL_PROTOCOL': 'SSLv2', + # ## SSL_CIPHER string The cipher specification name + # ## SSL_VERSION_INTERFACE string The mod_ssl program version + # ## SSL_VERSION_LIBRARY string The OpenSSL program version + } | { 'SSL_SERVER_M_VERSION': cert.get_version(), 'SSL_SERVER_M_SERIAL': cert.get_serial_number(), -## 'SSL_SERVER_V_START': Validity of server's certificate (start time), -## 'SSL_SERVER_V_END': Validity of server's certificate (end time), - }) - + # ## 'SSL_SERVER_V_START': Validity of server's certificate (start time), + # ## 'SSL_SERVER_V_END': Validity of server's certificate (end time), + } for prefix, dn in [("I", cert.get_issuer()), ("S", cert.get_subject())]: # X509Name objects don't seem to have a way to get the # complete DN string. Use str() and slice it instead, # because str(dn) == "" dnstr = str(dn)[18:-2] - - wsgikey = 'SSL_SERVER_%s_DN' % prefix + + wsgikey = f'SSL_SERVER_{prefix}_DN' ssl_environ[wsgikey] = dnstr - + # The DN should be of the form: /k1=v1/k2=v2, but we must allow # for any value to contain slashes itself (in a URL). while dnstr: @@ -1776,7 +1756,7 @@ def populate_ssl_environ(self): pos = dnstr.rfind("/") dnstr, key = dnstr[:pos], dnstr[pos + 1:] if key and value: - wsgikey = 'SSL_SERVER_%s_DN_%s' % (prefix, key) + wsgikey = f'SSL_SERVER_{prefix}_DN_{key}' ssl_environ[wsgikey] = value - + self.environ.update(ssl_environ)