diff --git a/pystdf/IO.py b/pystdf/IO.py
index 42c4071..502565f 100644
--- a/pystdf/IO.py
+++ b/pystdf/IO.py
@@ -27,18 +27,32 @@ from pystdf.Pipeline import DataSource
 
-def appendFieldParser(fn, action):
-    """Append a field parsing function to a record parsing function.
-    This is used to build record parsing functions based on the record type specification."""
-    def newRecordParser(*args):
-        fields = fn(*args)
-        try:
-            fields.append(action(*args))
-        except EndOfRecordException: pass
-        return fields
-    return newRecordParser
+def memorize(func):
+    """Cache getFieldParser results per instance in _field_parser_cache."""
+    def wrapper(self, fieldType, count):
+        cache = self.__dict__.setdefault('_field_parser_cache', {})
+        # Test membership first: dict.setdefault would call func on every lookup.
+        if (fieldType, count) not in cache:
+            cache[(fieldType, count)] = func(self, fieldType, count)
+        return cache[(fieldType, count)]
+    return wrapper
+
+def groupConsecutiveDuplicates(fieldsList):
+    """Group consecutive identical field types into (type, count) pairs.
+
+    Examples:
+    >>> groupConsecutiveDuplicates(['U4', 'U4', 'U1', 'C1', 'C1', 'C1'])
+    [('U4', 2), ('U1', 1), ('C1', 3)]
+    >>> groupConsecutiveDuplicates([])
+    []
+    """
+    import itertools
+    return [(key, len(list(group)))
+            for key, group in itertools.groupby(fieldsList)]
 
 class Parser(DataSource):
+    _kFieldPattern = re.compile(r'k(\d+)([A-Z][a-z0-9]+)')
+
     def readAndUnpack(self, header, fmt):
         size = struct.calcsize(fmt)
         if (size > header.len):
             self.inp.read(header.len)
             header.len = 0
@@ -74,6 +88,30 @@ def readField(self, header, stdfFmt):
     def readFieldDirect(self, stdfFmt):
         return self.readAndUnpackDirect(packFormatMap[stdfFmt])
 
+    def batchReadFields(self, header, stdfFmt, count):
+        fmt = packFormatMap[stdfFmt]
+        size = struct.calcsize(fmt)
+        totalSize = size * count
+        if (totalSize > header.len):
+            # Truncated record: read the fields that fit, pad the rest with None.
+            fullCount = header.len // size
+            if not fullCount:
+                header.len = 0
+                return (None,) * count
+            tmpResult = list(self.batchReadFields(header, stdfFmt, fullCount))
+            header.len = 0
+            tmpResult.extend([None] * (count - fullCount))
+            return tuple(tmpResult)
+        buf = self.inp.read(totalSize)
+        if len(buf) == 0:
+            self.eof = 1
+            raise EofException()
+        header.len -= totalSize
+        vals = struct.unpack(self.endian + fmt * count, buf)
+        if isinstance(vals[0], bytes):
+            return tuple(val.decode("ascii") for val in vals)
+        return vals
+
     def readCn(self, header):
         if header.len == 0:
             raise EndOfRecordException()
@@ -195,18 +233,41 @@ def parse(self, count=0):
             self.cancel(exception)
             raise
 
-    def getFieldParser(self, fieldType):
+    @memorize
+    def getFieldParser(self, fieldType, count):
         if (fieldType.startswith("k")):
-            fieldIndex, arrayFmt = re.match('k(\d+)([A-Z][a-z0-9]+)', fieldType).groups()
-            return lambda self, header, fields: self.readArray(header, fields[int(fieldIndex)], arrayFmt)
-        else:
-            parseFn = self.unpackMap[fieldType]
-            return lambda self, header, fields: parseFn(header, fieldType)
+            fieldIndex, arrayFmt = self._kFieldPattern.match(fieldType).groups()
+            def parseDynamicArray(parser, header, fields):
+                return parser.readArray(header, fields[int(fieldIndex)], arrayFmt)
+            return parseDynamicArray, count
+        if fieldType in self._unpackMap:
+            # One batched read covers the whole run of `count` identical fields.
+            def parseBatchedFields(parser, header, fields):
+                return parser.batchReadFields(header, fieldType, count)
+            return parseBatchedFields, 1
+        parseFn = self.unpackMap[fieldType]
+        def parseIndividualFields(parser, header, fields):
+            return parseFn(header, fieldType)
+        return parseIndividualFields, count
 
     def createRecordParser(self, recType):
-        fn = lambda self, header, fields: fields
-        for stdfType in recType.fieldStdfTypes:
-            fn = appendFieldParser(fn, self.getFieldParser(stdfType))
+        fieldParsers = []
+        groupedFields = groupConsecutiveDuplicates(recType.fieldStdfTypes)
+        for (stdfType, count) in groupedFields:
+            func, times = self.getFieldParser(stdfType, count)
+            for _ in range(times):
+                fieldParsers.append(func)
+
+        def fn(parser, header, fields):
+            try:
+                for parseField in fieldParsers:
+                    result = parseField(parser, header, fields)
+                    if isinstance(result, tuple):
+                        fields.extend(result)
+                    else:
+                        fields.append(result)
+            except EndOfRecordException: pass
+            return fields
         return fn
 
     def __init__(self, recTypes=V4.records, inp=sys.stdin, reopen_fn=None, endian=None):
@@ -221,23 +282,19 @@ def __init__(self, recTypes=V4.records, inp=sys.stdin, reopen_fn=None, endian=None):
             [ ( (recType.typ, recType.sub), recType )
                 for recType in recTypes ])
 
+        self._unpackMap = {
+            ftype: self.readField
+            for ftype in ("C1", "B1", "U1", "U2", "U4", "U8",
+                          "I1", "I2", "I4", "I8", "R4", "R8")
+        }
         self.unpackMap = {
-            "C1": self.readField,
-            "B1": self.readField,
-            "U1": self.readField,
-            "U2": self.readField,
-            "U4": self.readField,
-            "U8": self.readField,
-            "I1": self.readField,
-            "I2": self.readField,
-            "I4": self.readField,
-            "I8": self.readField,
-            "R4": self.readField,
-            "R8": self.readField,
-            "Cn": lambda header, fmt: self.readCn(header),
-            "Bn": lambda header, fmt: self.readBn(header),
-            "Dn": lambda header, fmt: self.readDn(header),
-            "Vn": lambda header, fmt: self.readVn(header)
+            **self._unpackMap,
+            **{
+                "Cn": lambda header, _: self.readCn(header),
+                "Bn": lambda header, _: self.readBn(header),
+                "Dn": lambda header, _: self.readDn(header),
+                "Vn": lambda header, _: self.readVn(header)
+            }
         }
 
         self.recordParsers = dict(
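
Review note: createRecordParser now collapses each record's field-type list into (type, count) runs before choosing parsers, so a run of identical fixed-size fields costs one parser call instead of one per field. A minimal sketch of that grouping step, using a PTR-like field list as made-up input (the list below is illustrative, not taken from the patch):

    import itertools

    # PTR-like prefix: test number, head/site, two flag bytes, result, two strings.
    fieldStdfTypes = ["U4", "U1", "U1", "B1", "B1", "R4", "Cn", "Cn"]

    groups = [(key, len(list(g))) for key, g in itertools.groupby(fieldStdfTypes)]
    print(groups)  # [('U4', 1), ('U1', 2), ('B1', 2), ('R4', 1), ('Cn', 2)]

    # Fixed-size runs ('U4', 'U1', 'B1', 'R4') each become one batched read;
    # the variable-length 'Cn' fields keep one parser call per field.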
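
The speedup itself comes from batchReadFields replacing a Python-level loop of struct.unpack calls with a single call whose format is repeated count times. A self-contained sketch of the effect; the field type (U4, i.e. "I"), run length, and repeat count are arbitrary numbers for illustration, not measurements from this patch:

    import struct
    import timeit

    count = 200  # length of a run of consecutive U4 fields (arbitrary)
    buf = struct.pack("<" + "I" * count, *range(count))

    def per_field():
        # Old style: one unpack call (and one Python iteration) per field.
        return [struct.unpack_from("<I", buf, 4 * i)[0] for i in range(count)]

    def batched():
        # New style: one unpack call for the whole run, as in batchReadFields.
        return struct.unpack("<" + "I" * count, buf)

    assert per_field() == list(batched())
    print("per-field:", timeit.timeit(per_field, number=5000))
    print("batched:  ", timeit.timeit(batched, number=5000))

The batched form avoids the per-field Python overhead, which is where the win on field-heavy records (PTRs in particular) should come from.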