diff --git a/tests/test_xmljson.py b/tests/test_xmljson.py
index a4ebb18..7eb5c55 100644
--- a/tests/test_xmljson.py
+++ b/tests/test_xmljson.py
@@ -13,13 +13,14 @@
import unittest
from collections import OrderedDict as Dict
-from lxml.etree import tostring as tostring, fromstring
+from lxml.etree import tostring as tostring, fromstring, ElementTree
from lxml.doctestcompare import LXMLOutputChecker
import lxml.html
import lxml.etree
import xml.etree.cElementTree
import xmljson
+
# For Python 3, decode byte strings as UTF-8
if sys.version_info[0] == 3:
def decode(s):
@@ -54,7 +55,24 @@ def compare(jsonstring, xmlstring):
first = json.loads(jsonstring, object_pairs_hook=Dict)
second = conv.data(fromstring(xmlstring))
self.assertEqual(first, second)
+ return compare
+ def check_nsmap(self, conv):  # build a checker that round-trips XML through `conv` and compares namespaces
+ def compare(xmlstring):  # nested closure returned to the caller (indentation is collapsed in this view)
+ result = conv.data(fromstring(xmlstring))  # XML text -> parsed element -> data structure
+ root = conv.etree(result)  # data structure -> converted elements
+ t1 = fromstring(xmlstring)  # fresh parse of the input, used as the reference
+ t2 = root[0]  # first item produced by the round trip
+ try:
+ t1.nsmap  # probe only: the value is discarded
+ except:  # NOTE(review): bare except; presumably meant for elements lacking `nsmap` - confirm and narrow
+ ns = {'charlie': "http://some-other-namespace"}  # prefix map handed to find() below
+
+ r1 = t1.find('charlie:joe', ns)
+ r2 = t2.find('charlie:joe', ns)
+ self.assertEqual(r1.tag, r2.tag)  # fallback comparison: same tag for the 'charlie:joe' lookup
+ return  # presumably ends the fallback path inside the except block - TODO confirm nesting
+ self.assertEqual(t1.nsmap, t2.nsmap)  # when `nsmap` is available the two maps must be equal
return compare
@@ -163,9 +181,9 @@ def test_data(self):
'bob')
def test_xml_namespace(self):
- 'XML namespaces are not yet implemented'
- with self.assertRaises(ValueError):
- xmljson.badgerfish.etree({'alice': {'@xmlns': {'$': 'http:\/\/some-namespace'}}})
+ 'Checks nsmap attribute of root tag'
+ eq = self.check_nsmap(xmljson.badgerfish)  # checker bound to the BadgerFish converter
+ eq('bob')  # NOTE(review): 'bob' alone is not well-formed XML - confirm the intended fixture string
def test_custom_dict(self):
'Conversion to dict uses OrderedDict'
diff --git a/xmljson/__init__.py b/xmljson/__init__.py
index d63c1d3..b103f4c 100644
--- a/xmljson/__init__.py
+++ b/xmljson/__init__.py
@@ -2,10 +2,12 @@
import sys
from collections import Counter, OrderedDict
+from io import BytesIO
+
try:
- from lxml.etree import Element
+ from lxml.etree import Element, iterparse, ElementTree, tostring
except ImportError:
- from xml.etree.cElementTree import Element
+ from xml.etree.cElementTree import Element, iterparse, ElementTree
__author__ = 'S Anand'
__email__ = 'root.node@gmail.com'
@@ -19,7 +21,7 @@
class XMLData(object):
def __init__(self, xml_fromstring=True, xml_tostring=True, element=None, dict_type=None,
- list_type=None, attr_prefix=None, text_content=None, simple_text=False):
+ list_type=None, attr_prefix=None, text_content=None, ns_name=None, simple_text=False):
# xml_fromstring == False(y) => '1' -> '1'
# xml_fromstring == True => '1' -> 1
# xml_fromstring == fn => '1' -> fn(1)
@@ -44,6 +46,15 @@ def __init__(self, xml_fromstring=True, xml_tostring=True, element=None, dict_ty
# simple_text == True => 'a' = {'x': 'a'}
self.simple_text = simple_text
+ self.ns_name = ns_name
+ try:
+ elem = Element("html", nsmap={None: 'test'})
+ elem.nsmap
+ self.lxml_lib = True
+ self.root_count = 0
+ except:
+ self.lxml_lib = False
+
@staticmethod
def _tostring(value):
'Convert value to XML compatible string'
@@ -51,6 +62,8 @@ def _tostring(value):
value = 'true'
elif value is False:
value = 'false'
+ else:
+ value = str(value)
return unicode(value) # noqa: convert to whatever native unicode repr
@staticmethod
@@ -64,7 +77,10 @@ def _fromstring(value):
elif std_value == 'false':
return False
try:
- return int(std_value)
+ if std_value.startswith('0'):
+ return std_value
+ else:
+ return int(std_value)
except ValueError:
pass
try:
@@ -76,6 +92,7 @@ def _fromstring(value):
def etree(self, data, root=None):
'Convert data structure into a list of etree.Element'
result = self.list() if root is None else root
+
if isinstance(data, (self.dict, dict)):
for key, value in data.items():
value_is_list = isinstance(value, (self.list, list))
@@ -88,7 +105,21 @@ def etree(self, data, root=None):
key = key.lstrip(self.attr_prefix)
# @xmlns: {$: xxx, svg: yyy} becomes xmlns="xxx" xmlns:svg="yyy"
if value_is_dict:
- raise ValueError('XML namespaces not yet supported')
+ if self.lxml_lib:
+ if key == self.ns_name.lstrip(self.attr_prefix):
+ # Actually nothing to do here
+ pass
+ else:
+ for k in value.keys():
+ if len(k) > 0:
+ if k == self.text_content:
+ k_default = 'ns0'
+ self.ns_counter += 1
+ result.set('xmlns:' + k_default, self._tostring(value[k]))
+ else:
+ result.set('xmlns:' + k, self._tostring(value[k]))
+ else:
+ result.set('xmlns' + k, self._tostring(value[k]))
else:
result.set(key, self._tostring(value))
continue
@@ -105,8 +136,34 @@ def etree(self, data, root=None):
# Add other keys as one or more children
values = value if value_is_list else [value]
for value in values:
- elem = self.element(key)
- result.append(elem)
+ if value_is_dict:
+ # Add namespaces to nodes if @xmlns present
+ if self.ns_name in value.keys() and self.lxml_lib:
+ NS_MAP = self.dict()
+ for k in value[self.ns_name]:
+ prefix = k
+ if prefix == self.text_content:
+ prefix = 'ns0'
+ uri = value[self.ns_name][k]
+
+ if ':' in key:
+ prefix, tag = key.split(':')
+ key = tag
+
+ NS_MAP[prefix] = uri
+ continue
+
+ if len(value[self.ns_name]) > 1:
+ uri = ''
+ elem = self.element('{0}{1}'.format('{' + uri + '}', key), nsmap=NS_MAP)
+ result.append(elem)
+ else:
+ elem = self.element(key)
+ result.append(elem)
+ else:
+ elem = self.element(key)
+ result.append(elem)
+
# Treat scalars as text content, not children (Parker)
if not isinstance(value, (self.dict, dict, self.list, list)):
if self.text_content:
@@ -122,10 +179,47 @@ def etree(self, data, root=None):
def data(self, root):
'Convert etree.Element into a dictionary'
value = self.dict()
+ root = XMLData._process_ns(self, element=root)
+
children = [node for node in root if isinstance(node.tag, basestring)]
- for attr, attrval in root.attrib.items():
- attr = attr if self.attr_prefix is None else self.attr_prefix + attr
- value[attr] = self._fromstring(attrval)
+
+ # form lxml.Element with namespaces if present
+ if self.lxml_lib:
+ if root.tag.startswith('{'):
+ uri, root.tag = root.tag.split('}')
+ uri = uri.lstrip('{')
+ nsmap = root.nsmap
+ value[self.ns_name] = {}
+
+ # pushing namespaces to dic; Filtering namespaces by prefix except root node
+ for key in nsmap.keys():
+ if self.root_count == 0:
+ value[self.ns_name].update({key: nsmap[key]})
+ else:
+ if nsmap[key] == uri:
+ value[self.ns_name].update({key: nsmap[key]})
+ self.root_count += 1
+ else:
+ for attr, attrval in root.attrib.items():
+ attr = attr if self.attr_prefix is None else self.attr_prefix + attr
+ value[attr] = self._fromstring(attrval)
+ else:
+ for attr, attrval in root.attrib.items():
+ attr = attr if self.attr_prefix is None else self.attr_prefix + attr
+
+ if self.attr_prefix:
+ if self.ns_name in attr:
+ if not attr.endswith(':'):
+ prefix = attr.split(':')[1]
+ value[attr.replace(prefix, '')] = {prefix: self._fromstring(attrval)}
+ else:
+ prefix = attr.split(':')[1]
+ value['@xmlns'] = {prefix: self._fromstring(attrval)}
+ else:
+ value[attr] = self._fromstring(attrval)
+ else:
+ value[attr] = self._fromstring(attrval)
+
if root.text and self.text_content is not None:
text = root.text.strip()
if text:
@@ -133,8 +227,10 @@ def data(self, root):
value = self._fromstring(text)
else:
value[self.text_content] = self._fromstring(text)
+
count = Counter(child.tag for child in children)
for child in children:
+ child = XMLData._process_ns(self, child)
if count[child.tag] == 1:
value.update(self.data(child))
else:
@@ -142,11 +238,65 @@ def data(self, root):
result += self.data(child).values()
return self.dict([(root.tag, value)])
+ @staticmethod  # NOTE(review): static, yet the first parameter is named `cls`; callers in this file pass the instance
+ def _process_ns(cls, element):  # rewrite a '{uri}name' tag using the element's own xmlns attributes; returns `element`
+ if element.tag.startswith('{'):  # tag carries a namespace in '{uri}name' form
+ if any([True if k.split(':')[0] == 'xmlns' else False for k in element.attrib.keys()]):  # some attribute is named 'xmlns' or 'xmlns:...'
+ revers_attr = {v:k for k,v in element.attrib.items()}  # attribute value -> attribute name (all attributes, not only xmlns ones)
+
+ end_prefix = element.tag.find('}')  # index of the brace that closes the URI
+ uri = element.tag[:end_prefix+1]  # '{uri}' including both braces
+ key_prefix = revers_attr[uri.strip('{}')]  # attribute whose value equals the tag's URI; KeyError when there is none
+ prefix = key_prefix.split(':')[1]  # text after the first ':'; NOTE(review): IndexError for a key without ':' - confirm it cannot occur
+
+ if len(prefix) > 1:  # NOTE(review): a one-character prefix takes the else branch - confirm `> 1` is intended
+ element.tag = element.tag.replace(uri, prefix + ':')  # '{uri}name' -> 'prefix:name'
+ else:
+ element.tag = element.tag.replace(uri, '')  # '{uri}name' -> 'name'
+
+ # trick to determine if given element is root element
+ try:
+ _ = element.getroot()  # NOTE(review): if elements never expose getroot(), the pop below is always skipped - confirm
+ element.attrib.pop(key_prefix, None)  # drop the xmlns attribute that matched the tag's URI
+ except:  # NOTE(review): bare except hides every failure raised in the try body - consider narrowing
+ pass
+ else:  # presumably pairs with the outer startswith('{') test - TODO confirm nesting
+ ns_keys = [k if k.split(':')[0] == 'xmlns' else None for k in element.attrib.keys()]  # xmlns attribute names, None for the rest
+ for key in ns_keys:
+ if key:
+ element.attrib.pop(key, None)  # remove each xmlns attribute
+ return element
+
+ @classmethod
+ def parse_nsmap(cls, file):  # parse XML, copy open namespace declarations onto elements as 'xmlns:<prefix>' attributes, return the root
+ # Parse given file-like xml object for namespaces
+ if isinstance(file, (str)):  # a str argument is treated as XML text rather than as a file object
+ file = BytesIO(file.encode('utf-8'))  # wrap the encoded text so iterparse can read from it
+
+ events = "start", "start-ns", "end-ns"
+ root = None
+ ns_map = []  # stack of namespace declarations that are currently open
+
+ for event, elem in iterparse(file, events):
+ if event == "start-ns":
+ ns_map.append(elem)  # for this event `elem` is indexed below as (prefix, uri)
+ elif event == "end-ns":
+ ns_map.pop()  # the most recently opened declaration is dropped
+ elif event == "start":
+ if root is None:
+ root = elem  # the first element started is kept as the root
+ if ns_map:
+ for ns in ns_map:
+ ns_prefix = ns[0]
+ ns_uri = ns[1]
+ elem.set('xmlns:{}'.format(ns_prefix), ns_uri)  # an empty prefix yields the attribute name 'xmlns:'
+ return ElementTree(root).getroot()  # NOTE(review): presumably equivalent to returning `root` - confirm
+
class BadgerFish(XMLData):
'Converts between XML and data using the BadgerFish convention'
def __init__(self, **kwargs):
- super(BadgerFish, self).__init__(attr_prefix='@', text_content='$', **kwargs)
+ super(BadgerFish, self).__init__(attr_prefix='@', text_content='$', ns_name='@xmlns', **kwargs)  # ns_name: dict key that holds namespace maps
class GData(XMLData):