diff --git a/data/complex_array.mat b/data/complex_array.mat new file mode 100644 index 0000000..0a9d34e Binary files /dev/null and b/data/complex_array.mat differ diff --git a/data/test_data.json b/data/test_data.json index 33f390f..077c28e 100644 --- a/data/test_data.json +++ b/data/test_data.json @@ -1,5 +1,6 @@ { "loadmat": { + "complex_array.mat": {"a": [[{"__complex__":[0,0]},{"__complex__":[0,1]}],[{"__complex__":[1,0]},{"__complex__":[0,2]}],[{"__complex__":[2,2]},{"__complex__":[0,0]}]]}, "cell_array.mat": {"c": [["big", "little"], [1, 2, 3]]}, "char_array.mat": {"a": ["123", "456"]}, "struct_array.mat": {"s": {"color": ["red", "red"], "x": [3, 4], "type": ["big", "little"]}}, diff --git a/mat4py/.gitignore b/mat4py/.gitignore new file mode 100644 index 0000000..a348e50 --- /dev/null +++ b/mat4py/.gitignore @@ -0,0 +1 @@ +/__pycache__/ diff --git a/mat4py/loadmat.py b/mat4py/loadmat.py index 259818e..ff9cb4f 100644 --- a/mat4py/loadmat.py +++ b/mat4py/loadmat.py @@ -270,10 +270,16 @@ def read_numeric_array(fd, endian, header, data_etypes): """Read a numeric matrix. Returns an array with rows of the numeric matrix. """ - if header['is_complex']: - raise ParseError('Complex arrays are not supported') # read array data (stored as column-major) - data = read_elements(fd, endian, data_etypes) + if(header["is_complex"]): + realData = read_elements(fd, endian, data_etypes) + imagData = read_elements(fd, endian, data_etypes) + data = list() + for dataIndex in range(0,len(realData)): + data.append(complex(realData[dataIndex],imagData[dataIndex])) + data = tuple(data) + else: + data = read_elements(fd, endian, data_etypes) if not isinstance(data, Sequence): # not an array, just a value return data diff --git a/mat4py/savemat.py b/mat4py/savemat.py index c64bf7c..1511ed8 100644 --- a/mat4py/savemat.py +++ b/mat4py/savemat.py @@ -176,8 +176,13 @@ def write_var_header(fd, header): # write tag bytes, # and array flags + class and nzmax (null bytes) + if(not ('is_complex' in header.keys())): + header['is_complex'] = False fd.write(struct.pack('b3xI', etypes['miUINT32']['n'], 8)) - fd.write(struct.pack('b3x4x', mclasses[header['mclass']])) + if(header['is_complex']): + fd.write(struct.pack('bb2x4x', mclasses[header['mclass']],8)) + else: + fd.write(struct.pack('b3x4x', mclasses[header['mclass']])) # write dimensions array write_elements(fd, 'miINT32', header['dims']) @@ -221,7 +226,19 @@ def write_numeric_array(fd, header, array): array = list(chain.from_iterable(izip(*array))) # write matrix data to memory file - write_elements(bd, header['mtp'], array) + if(not ('is_complex' in header.keys())): + header['is_complex'] = False + + if(header['is_complex']): + arrayReal = list() + arrayImag = list() + for valueIndex in range(0,len(array)): + arrayReal.append(array[valueIndex].real) + arrayImag.append(array[valueIndex].imag) + write_elements(bd, header['mtp'], arrayReal) + write_elements(bd, header['mtp'], arrayImag) + else: + write_elements(bd, header['mtp'], array) # write the variable to disk file data = bd.getvalue() @@ -392,6 +409,13 @@ def guess_header(array, name=''): 'mclass': 'mxDOUBLE_CLASS', 'mtp': 'miDOUBLE', 'dims': (1, len(array))}) + elif isarray(array, lambda i: isinstance(i, complex), 1): + # 1D double array + header.update({ + 'mclass': 'mxDOUBLE_CLASS', 'mtp': 'miDOUBLE', + 'dims': (1, len(array)), + 'is_complex' : True}) + elif (isarray(array, lambda i: isinstance(i, Sequence), 1) and any(diff(len(s) for s in array))): # sequence of unequal length, assume cell array @@ -428,6 +452,14 @@ def guess_header(array, name=''): 'mtp': 'miDOUBLE', 'dims': (len(array), len(array[0]))}) + elif isarray(array, lambda i: isinstance(i, complex)): + # 2D double array + header.update({ + 'mclass': 'mxDOUBLE_CLASS', + 'mtp': 'miDOUBLE', + 'is_complex' : True, + 'dims': (len(array), len(array[0]))}) + elif isarray(array, lambda i: isinstance( i, (int, float, basestring, Sequence, Mapping))): # mixed contents, make it a cell array diff --git a/tests.py b/tests.py index e95ff0f..c844f55 100644 --- a/tests.py +++ b/tests.py @@ -1,17 +1,15 @@ - import sys if sys.version_info[0] == 2: # unittest2 required with python2 for subTest() functionality import unittest2 as unittest else: import unittest -import json +from json_tricks import dumps, load # https://json-tricks.readthedocs.io/en/latest/ import os import mat4py - -test_data = json.load(open('data/test_data.json')) +test_data = load(open('data/test_data.json')) class TestSequenceFunctions(unittest.TestCase): @@ -59,4 +57,6 @@ def test_save_load_mat2(self): if __name__ == '__main__': + # a = (complex(0,1),complex(2,0)) + # print(dumps(a)) unittest.main()