diff --git a/data_url/__init__.py b/data_url/__init__.py
index a8617e8..df2f304 100644
--- a/data_url/__init__.py
+++ b/data_url/__init__.py
@@ -2,9 +2,17 @@
 import base64
 
 DATA_URL_RE = re.compile(
-    r"data:(?P<MIME>([\w-]+\/[\w+\.-]+(;[\w-]+\=[\w-]+)?)?)(?P<encoded>;base64)?,(?P<data>[\w\d.~%\=\/\+-]+)"
+    r"""
+    data:                                             # literal data:
+    (?P<MIME>[a-z][a-z0-9\-]+/[a-z][\w\-\.\+]+)?      # optional media type
+    (?P<parameters>(?:;[\w\-\.+]+=[\w\-\.+%]+)*)      # optional attribute=values, value can be url encoded
+    (?P<encoded>;base64)?,                            # optional base64 flag
+    (?P<data>[\w\d.~%\=\/\+-]+)                       # the data
+    """,
+    re.MULTILINE | re.VERBOSE
 )
 
+
 def construct_data_url(mime_type, base64_encoded, data):
     """
     Helper method for just creating a data URL from some data. If this
@@ -47,8 +55,9 @@ def from_url(cls, url):
         """
         data_url = cls()
         data_url._url = url
-        data_url.__parse_url()
-        return data_url
+        if data_url.__parse_url():
+            return data_url
+        return None
 
     @classmethod
     def from_data(cls, mime_type, base64_encoded, data):
@@ -107,13 +116,16 @@ def __parse_url(self):
         """Parses a data URL to get each individual element and sets the
         respecting class attributes."""
         match = DATA_URL_RE.fullmatch(self._url)
-        self._is_base64_encoded = match.group('encoded') is not None
-        self._mime_type = match.group("MIME")
-        raw_data = match.group('data')
-        if self._is_base64_encoded:
-            self._data = base64.b64decode(raw_data)
-        else:
-            self._data = raw_data
+        if match:
+            self._is_base64_encoded = match.group('encoded') is not None
+            self._mime_type = match.group("MIME") or ""
+            raw_data = match.group('data')
+            if self._is_base64_encoded:
+                self._data = base64.b64decode(raw_data)
+            else:
+                self._data = raw_data
+            return True
+        return False
 
     def __construct_url(self):
         """Constructs an actual data URL string from class attributes."""
diff --git a/test/test_url.py b/test/test_url.py
index f9635d5..43110ed 100644
--- a/test/test_url.py
+++ b/test/test_url.py
@@ -60,6 +60,11 @@ def test_construct_data_url(self):
         self.assertEqual(raw_data, deconstructed_url.data)
         self.assertEqual(data, deconstructed_url.encoded_data)
 
+    def test_non_compliant_url(self):
+        """from_url returns None for input that is not a data URL."""
+        url = DataURL.from_url("not a url")
+        self.assertIsNone(url)
+
 class TestFromData(unittest.TestCase):
     def test_typing(self):
         with self.assertRaises(Exception) as context: