diff --git a/cutplace/__init__.py b/cutplace/__init__.py index 3e527267..19bff138 100644 --- a/cutplace/__init__.py +++ b/cutplace/__init__.py @@ -6,11 +6,10 @@ accessible through a Python API. """ from cutplace.data import DataFormat, FORMAT_DELIMITED, FORMAT_EXCEL, FORMAT_FIXED, FORMAT_ODS -from cutplace.errors import \ - CheckError, CutplaceError, DataError, DataFormatError, FieldValueError, InterfaceError, Location, RangeValueError +from cutplace.errors import Location from cutplace.interface import Cid from cutplace.ranges import Range -from cutplace.validio import Reader +from cutplace.validio import Reader, Writer from cutplace._version import get_versions #: Package version information. @@ -20,20 +19,13 @@ #: Public classes and functions. __all__ = [ 'Cid', - 'CheckError', - 'CutplaceError', - 'DataError', - 'DataFormat', - 'DataFormatError', 'FORMAT_DELIMITED', 'FORMAT_EXCEL', 'FORMAT_FIXED', 'FORMAT_ODS', - 'FieldValueError', - 'InterfaceError', 'Location', 'Range', - 'RangeValueError', 'Reader', + 'Writer', '__version__' ] diff --git a/cutplace/applications.py b/cutplace/applications.py index 7dbeb1da..b34c0f93 100644 --- a/cutplace/applications.py +++ b/cutplace/applications.py @@ -120,9 +120,10 @@ def validate(self, data_path): assert self.cid is not None _log.info('validate "%s"', data_path) - reader = validio.Reader(self.cid, data_path) + try: - reader.validate() + with validio.Reader(self.cid, data_path) as reader: + reader.validate_rows() _log.info(' accepted %d rows', reader.accepted_rows_count) except errors.CutplaceError as error: _log.error(' %s', error) @@ -137,8 +138,9 @@ def process(argv=None): Before calling this, module :py:mod:`logging` has to be set up properly. For example, by calling :py:func:`logging.basicConfig`. - :return: 0 unless ``argv`` requested to validate one or more files and at - least one of them contained rejected data. In this case, the result is 1. 
+ :return: 0 unless ``argv`` requested to validate one or more files and \ + at least one of them contained rejected data. In this case, the \ + result is 1. """ if argv is None: # pragma: no cover argv = sys.argv diff --git a/cutplace/interface.py b/cutplace/interface.py index f5c1f704..d1b972d4 100644 --- a/cutplace/interface.py +++ b/cutplace/interface.py @@ -101,7 +101,9 @@ def set_location_to_caller(self): def data_format(self): """ The data format used by the this CID; refer to the - py:mod:`cutplace.data` module for possible formats. + :py:mod:`cutplace.data` module for possible formats. + + :rtype: cutplace.data.DataFormat """ return self._data_format @@ -265,6 +267,7 @@ def read(self, cid_path, rows): self._location.advance_line() if self.data_format is None: raise errors.InterfaceError('data format must be specified', self._location) + self.data_format.validate() if len(self.field_names) == 0: raise errors.InterfaceError('fields must be specified', self._location) @@ -535,6 +538,24 @@ def create_cid_from_string(cid_text): return result +def field_lengths(fixed_cid): + """ + List of :py:class:`int`s for all field lengths in ``fixed_cid`` which + must be of data format py:attr:`~cutplace.data.FORMAT_FIXED`. 
+ """ + assert fixed_cid is not None + assert fixed_cid.data_format.format == data.FORMAT_FIXED, 'format=' + fixed_cid.data_format.format + result = [] + for field_format in fixed_cid.field_formats: + field_length_range = field_format.length.items[0] + lower, upper = field_length_range + assert lower is not None + assert lower == upper + field_length = lower + result.append(field_length) + return result + + def field_names_and_lengths(fixed_cid): """ List of tuples ``(field_name, field_length)`` for all field formats in diff --git a/cutplace/rowio.py b/cutplace/rowio.py index 62df9090..5d515ec0 100644 --- a/cutplace/rowio.py +++ b/cutplace/rowio.py @@ -121,8 +121,8 @@ def _excel_cell_value(cell, datemode): def excel_rows(source_path, sheet=1): """ - Rows read from an Excel document (both *.xls and *.xlsx thanks to - :py:mod:`xlrd`). + Rows read from an Excel document (both :file:`*.xls` and :file:`*.xlsx` + thanks to :py:mod:`xlrd`). :param str source_path: path to the Excel file to be read :param int sheet: the sheet in the file to be read @@ -180,6 +180,7 @@ def _as_delimited_keywords(delimited_data_format): } return result + def delimited_rows(delimited_source, data_format): """ Rows in ``delimited_source`` with using ``data_format``. In case @@ -455,34 +456,34 @@ def auto_rows(source): class AbstractRowWriter(object): """ - Base class for writers that can write rows using a certain - :py:class:`cutplace.data.DataFormat`. + Base class for writers that can write rows to ``target`` using a certain + :py:class:`~cutplace.data.DataFormat`. 
+ + :param target: :py:class:`str` or filelike object to write to; a \ + :py:class:`str` is assumed to be a path to a file which is \ + automatically opened during in the constructor and closed with \ + :py:meth:`~.cutplace.rowio.AbstractRowWriter.close` or by using the \ + ``with`` statement + :param cutplace.data.DataFormat: data format to use for writing """ - def __init__(self, target_path, data_format): - assert target_path is not None - assert data_format is not None - self._data_format = data_format - self._target_path = target_path - - -class DelimitedRowWriter(AbstractRowWriter): - def __init__(self, delimited_target, data_format): - assert delimited_target is not None + def __init__(self, target, data_format): + assert target is not None assert data_format is not None - assert data_format.format == data.FORMAT_DELIMITED assert data_format.is_valid + self._data_format = data_format self._has_opened_target_stream = False - if isinstance(delimited_target, six.string_types): - self._target_path = delimited_target + if isinstance(target, six.string_types): + self._target_path = target self._target_stream = io.open(self._target_path, 'w', encoding=data_format.encoding, newline='') self._has_opened_target_stream = True else: - self._target_path = delimited_target - self._target_stream = delimited_target + try: + self._target_path = target.name + except AttributeError: + self._target_path = '' + self._target_stream = target self._location = errors.Location(self.target_path, has_cell=True) - keywords = _as_delimited_keywords(data_format) - self._delimited_writer = _compat.csv_writer(self._target_stream, **keywords) def __enter__(self): return self @@ -500,7 +501,8 @@ def data_format(self): def location(self): """ The :py:class:`cutplace.errors.Location` to write the next row to. - This is automatically advanced by :py:meth:`~.write_row`. + This is automatically advanced by + :py:meth:`~.cutplace.rowio.AbstractRowWriter.write_row`. 
""" return self._location @@ -508,12 +510,12 @@ def location(self): def target_path(self): return self._target_path - def write_row(self, row_to_write): - try: - self._delimited_writer.writerow(row_to_write) - except UnicodeEncodeError as error: - raise errors.DataFormatError('cannot write data row: %s; row=%s' % (error, row_to_write), self.location) - self._location.advance_line() + @property + def target_stream(self): + return self._target_stream + + def write_row(self, rows_to_write): + raise NotImplementedError def write_rows(self, rows_to_write): assert self.target_stream is not None @@ -528,3 +530,73 @@ def close(self): self._has_opened_target_stream = False self._target_stream = None self._target_path = None + + +class DelimitedRowWriter(AbstractRowWriter): + def __init__(self, target, data_format): + assert target is not None + assert data_format is not None + assert data_format.format == data.FORMAT_DELIMITED + assert data_format.is_valid + + super(DelimitedRowWriter, self).__init__(target, data_format) + keywords = _as_delimited_keywords(data_format) + self._delimited_writer = _compat.csv_writer(self._target_stream, **keywords) + + def write_row(self, row_to_write): + try: + self._delimited_writer.writerow(row_to_write) + except UnicodeEncodeError as error: + raise errors.DataFormatError('cannot write data row: %s; row=%s' % (error, row_to_write), self.location) + self._location.advance_line() + + +class FixedRowWriter(AbstractRowWriter): + def __init__(self, target, data_format, field_lengths): + assert target is not None + assert data_format is not None + assert data_format.format == data.FORMAT_FIXED + assert data_format.is_valid + assert field_lengths is not None + for field_length in field_lengths: + assert field_length is not None + assert field_length >= 1, 'field_length=%r' % field_length + + super(FixedRowWriter, self).__init__(target, data_format) + self._field_lengths = field_lengths + self._expected_row_item_count = len(self._field_lengths) + 
if self.data_format.line_delimiter == 'any': + self._line_separator = os.linesep + else: + self._line_separator = self.data_format.line_delimiter + + def write_row(self, row_to_write): + """ + Write a row of fixed length strings. + + :param list row_to_write: a list of str where each item must have \ + exactly the same length as the corresponding entry in \ + :py:attr:`~.field_lengths` + :raises AssertionError: if ``row_to_write`` is not a list of \ + strings with each matching the corresponding ``field_lengths`` \ + as specified to :py:meth:`~.__init__`. + """ + assert row_to_write is not None + row_to_write_item_count = len(row_to_write) + assert row_to_write_item_count == self._expected_row_item_count, \ + 'row %d have %d items instead of %d: %s' \ + % (self.location.line, self._expected_row_item_count, row_to_write_item_count, row_to_write) + + for item_index, item in enumerate(row_to_write): + assert isinstance(item, six.text_type), \ + 'row %d, item %d must be a (unicode) str but is: %r' % (self.location.line, item_index, item) + assert len(item) == self._field_lengths[item_index], \ + 'row %d, item %d must have exactly %d characters instead of %d: %r' \ + % (self.location.line, item_index, len(item), self._field_lengths[item_index], item) + try: + self._target_stream.write(''.join(row_to_write)) + except UnicodeEncodeError as error: + raise errors.DataFormatError('cannot write data row: %s; row=%s' % (error, row_to_write), self.location) + if self._line_separator is not None: + self._target_stream.write(self._line_separator) + self.location.advance_line() diff --git a/cutplace/validio.py b/cutplace/validio.py index 15fd978c..0866e31b 100644 --- a/cutplace/validio.py +++ b/cutplace/validio.py @@ -34,34 +34,143 @@ def _create_field_map(field_names, field_values): return dict(zip(field_names, field_values)) -class Reader(object): - def __init__(self, cid, source_path): +class BaseValidator(object): + """ + A general validator to validate a single row (by 
validating its fields + and perform row checks), perform final checks when done with all rows + and finally release all resources required to do that. + + The :py:attr:`~.location` has to be set by descendants. While + :py:meth:`~.validate_row` takes care of advancing the cell, descendants + are responsible for advancing the row (by calling + :py:meth:`cutplace.errors.Location.advance_line`). + + It also provides a context manager and can consequently be used with the + ``with`` statement. + """ + def __init__(self, cid): + assert cid is not None + self._cid = cid - self._source_path = source_path + self._expected_item_count = len(self._cid.field_formats) self._location = None + self._is_closed = False + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Simply call :py:meth:`~.close()`. + """ + self.close() + + @property + def cid(self): + """ + The CID to validate the data. + + :rtype: cutplace.interface.Cid + """ + return self._cid + + @property + def location(self): + """ + The current location in the data to validate. + + :rtype: cutplace.errors.Location + """ + return self._location + + def validate_row(self, field_values): + """ + Validate ``row`` by: + + 1. Check if the number of items in ``row`` matches the number of + fields in the CID + 2. Check that all fields conform to their field format (as defined + by :py:class:`cutplace.fields.AbstractFieldFormat` and its + descendants) + 3. Check that the row conforms to all row checks (as defined by + :py:meth:`cutplace.checks.AbstractCheck.check_row`) + + The caller is responsible for :py:attr:`~.location` pointing to the + correct row in the data while ``validate_row`` take care of calling + :py:meth:`~.Location.advance_cell` appropriately. + """ + assert field_values is not None + assert self.location is not None + + # Validate that number of fields. 
+ actual_item_count = len(field_values) + if actual_item_count < self._expected_item_count: + raise errors.DataError( + 'row must contain %d fields but only has %d: %s' + % (self._expected_item_count, actual_item_count, field_values), + self._location) + if actual_item_count > self._expected_item_count: + raise errors.DataError( + 'row must contain %d fields but has %d, additional values are: %s' + % (self._expected_item_count, actual_item_count, field_values[self._expected_item_count:]), + self.location) + + # Validate each field according to its format. + for field_index, field_value in enumerate(field_values): + field_to_validate = self.cid.field_formats[field_index] + try: + field_to_validate.validated(field_value) + except errors.FieldValueError as error: + error.prepend_message('cannot accept field %s' % field_to_validate.field_name, self.location) + raise + self.location.advance_cell() + + # Validate the whole row according to row checks. + field_map = _create_field_map(self.cid.field_names, field_values) + for check_name in self.cid.check_names: + self.cid.check_map[check_name].check_row(field_map, self.location) + + def close(self): + """ + Validate final checks and release all resources. When called a second + time, do nothing. + + :raises cutplace.errors.CheckError: if any \ + :py:meth:`cutplace.checks.AbstractCheck.check_at_end` fails. 
+ """ + if not self._is_closed: + try: + for check_name in self.cid.check_names: + self.cid.check_map[check_name].check_at_end(self.location) + finally: + for check in self.cid.check_map.values(): + check.cleanup() + self._is_closed = True + + +class Reader(BaseValidator): + def __init__(self, cid, source_path): + assert cid is not None + assert source_path is not None + + super(Reader, self).__init__(cid) + self._source_path = source_path self.accepted_rows_count = None self.rejected_rows_count = None def _raw_rows(self): - data_format = self._cid.data_format + data_format = self.cid.data_format if data_format.format == data.FORMAT_EXCEL: return rowio.excel_rows(self._source_path, data_format.sheet) elif data_format.format == data.FORMAT_DELIMITED: return rowio.delimited_rows(self._source_path, data_format) elif data_format.format == data.FORMAT_FIXED: return rowio.fixed_rows( - self._source_path, data_format.encoding, interface.field_names_and_lengths(self._cid), + self._source_path, data_format.encoding, interface.field_names_and_lengths(self.cid), data_format.line_delimiter) elif data_format.format == data.FORMAT_ODS: return rowio.ods_rows(self._source_path, data_format.sheet) - @property - def cid(self): - """ - The :py:class:`cutplace.interface.Cid` used to validate data. 
- """ - return self._cid - @property def source_path(self): """ @@ -90,78 +199,83 @@ def rows(self, on_error='raise'): assert on_error in ('continue', 'raise', 'yield') self._location = errors.Location(self._source_path, has_cell=True) - expected_item_count = len(self._cid.field_formats) - - def validate_field_formats(field_values): - actual_item_count = len(field_values) - if actual_item_count < expected_item_count: - raise errors.DataError( - 'row must contain %d fields but only has %d: %s' - % (expected_item_count, actual_item_count, field_values), - self._location) - if actual_item_count > expected_item_count: - raise errors.DataError( - 'row must contain %d fields but has %d, additional values are: %s' % ( - expected_item_count, actual_item_count, field_values[expected_item_count:]), - self._location) - for i in range(actual_item_count): - field_to_validate = self._cid.field_formats[i] - try: - field_to_validate.validated(field_values[i]) - except errors.FieldValueError as error: - field_name = field_to_validate.field_name - error.prepend_message('cannot accept field %s' % field_name, self._location) - raise - self._location.advance_cell() - - def validate_row_checks(field_values): - field_map = _create_field_map(self._cid.field_names, field_values) - for check_name in self._cid.check_names: - self._cid.check_map[check_name].check_row(field_map, self._location) - - def validate_checks_at_end(): - for check_name in self._cid.check_names: - self._cid.check_map[check_name].check_at_end(self._location) - self.accepted_rows_count = 0 self.rejected_rows_count = 0 - for check in self._cid.check_map.values(): + for check in self.cid.check_map.values(): check.reset() - try: - for row in self._raw_rows(): - try: - validate_field_formats(row) - validate_row_checks(row) - self.accepted_rows_count += 1 - yield row - self._location.advance_line() - except errors.DataError as error: - if on_error == 'raise': - raise - self.rejected_rows_count += 1 - if on_error == 'yield': - 
yield error - else: - assert on_error == 'continue' - validate_checks_at_end() - finally: - for check in self._cid.check_map.values(): - check.cleanup() + for row in self._raw_rows(): + try: + self.validate_row(row) + self.accepted_rows_count += 1 + yield row + except errors.DataError as error: + if on_error == 'raise': + raise + self.rejected_rows_count += 1 + if on_error == 'yield': + yield error + else: + assert on_error == 'continue' + self._location.advance_line() - def validate(self): + def validate_rows(self): """ Validate that the data read from :py:meth:`~cutplace.validio.Reader.rows()` conform to :py:attr:`~cutplace.validio.Reader.cid`. + In order to check everything, :py:meth`~.close()` has to be + called to also validate the checks at the end of the data. + :raises cutplace.errors.DataError: on broken data """ for _ in self.rows(): pass - def close(self): + +class Writer(BaseValidator): + def __init__(self, cid, target): + assert cid is not None + assert target is not None + data_format = cid.data_format + assert data_format.is_valid + + super(Writer, self).__init__(cid) + self._delegated_writer = None + if data_format.format == data.FORMAT_DELIMITED: + self._delegated_writer = rowio.DelimitedRowWriter(target, data_format) + elif data_format.format == data.FORMAT_FIXED: + field_lengths = interface.field_lengths(cid) + self._delegated_writer = rowio.FixedRowWriter(target, data_format, field_lengths) + else: + raise NotImplementedError('data_format=%r' % data_format.format) + + @property + def location(self): """ - Release all resources allocated for reading. + The location in the :py:class:`cutplace.rowio.AbstractRowWriter` used + to actually write the data. """ - # TODO: Ponder: do we actually need this? 
- pass + return self._delegated_writer.location if self._delegated_writer is not None else None + + def write_row(self, row_to_write): + assert row_to_write is not None + assert self._delegated_writer is not None + + self.validate_row(row_to_write) + self._delegated_writer.write_row(row_to_write) + + def write_rows(self, rows_to_write): + assert rows_to_write is not None + assert self._delegated_writer is not None + + for row_to_write in rows_to_write: + self._delegated_writer.write_row(row_to_write) + + def close(self): + try: + super(Writer, self).close() + finally: + if self._delegated_writer is not None: + self._delegated_writer.close() + self._delegated_writer = None diff --git a/docs/api.rst b/docs/api.rst index ce82a62d..5f15c9c0 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -111,20 +111,20 @@ Here is an example that prints any data related errors detected during validation:: >>> broken_data_path = os.path.join(os.pardir, 'tests', 'data', 'broken_customers.csv') - >>> reader = cutplace.Reader(cid, broken_data_path) - >>> for row_or_error in reader.rows(on_error='yield'): - ... if isinstance(row_or_error, Exception): - ... if isinstance(row_or_error, cutplace.errors.CutplaceError): - ... # Print data related error details and move on. - ... print(row_or_error) + >>> with cutplace.Reader(cid, broken_data_path) as reader: + ... for row_or_error in reader.rows(on_error='yield'): + ... if isinstance(row_or_error, Exception): + ... if isinstance(row_or_error, cutplace.errors.CutplaceError): + ... # Print data related error details and move on. + ... print(row_or_error) + ... else: + ... # Let other, more severe errors terminate the validation. + ... raise row_or_error ... else: - ... # Let other, more severe errors terminate the validation. - ... raise row_or_error - ... else: - ... pass # We could also do something useful with the data in ``row`` here. + ... pass # We could also do something useful with the data in ``row`` here. 
broken_customers.csv (R4C1): cannot accept field branch_id: value '12345' must match regular expression: '38\\d\\d\\d' - broken_customers.csv (R4C2): cannot accept field customer_id: value must be an integer number: 'XX' - broken_customers.csv (R4C7): cannot accept field date_of_birth: date must match format DD.MM.YYYY (%d.%m.%Y) but is: '30.02.1994' (day is out of range for month) + broken_customers.csv (R5C2): cannot accept field customer_id: value must be an integer number: 'XX' + broken_customers.csv (R6C6): cannot accept field date_of_birth: date must match format DD.MM.YYYY (%d.%m.%Y) but is: '30.02.1994' (day is out of range for month) Note that it is possible for the reader to throw other exceptions, for example :py:exc:`IOError` in case the file cannot be read at all or :py:exc:`UnicodeError` diff --git a/docs/changes.rst b/docs/changes.rst index 4ba3a40c..2eb5961a 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -5,6 +5,18 @@ Revision history This chapter describes improvements compared to earlier versions of cutplace. +Version 0.8.3, 2015-01-xx +========================= + +* Improved API: + + * Removed shortcuts for exceptions from :py:mod:`cutplace`. Use the + originals in :py:mod:`cutplace.errors` instead. + * Added :py:class:`cutplace.Writer` for validated writing of delimited and + fixed data + (`issue #84 `_). + * Improved API documentation. + Version 0.8.2, 2015-01-19 ========================= @@ -15,7 +27,7 @@ Version 0.8.2, 2015-01-19 * Improved error reporting when parsing CIDs. In particular all errors related to the data format include a specific location, and some errors - provide more information about the context the occurred in. + provide more information about the context they occurred in. 
* Cleaned up :option:`--help`: diff --git a/tests/dev_test.py b/tests/dev_test.py index 4db7a4a4..91f84ec0 100644 --- a/tests/dev_test.py +++ b/tests/dev_test.py @@ -248,3 +248,11 @@ def assert_fnmatches(test_case, actual_value, expected_pattern): test_case.assertNotEqual(None, actual_value) if not fnmatch.fnmatch(actual_value, expected_pattern): test_case.fail('%r must match pattern %r' % (actual_value, expected_pattern)) + + +def unified_newlines(text): + """ + Same as ``text`` but with newline sequences unified to ``'\n'``. + """ + assert text is not None + return text.replace('\r\n', '\n').replace('\r', '\n') diff --git a/tests/test_performance.py b/tests/test_performance.py index 4ddf7475..0080c528 100644 --- a/tests/test_performance.py +++ b/tests/test_performance.py @@ -63,8 +63,8 @@ def _build_and_validate_many_customers(): # Validate the data using the API, so in case of errors we get specific information. customers_cid = interface.Cid(icd_ods_path) - reader = validio.Reader(customers_cid, many_customers_csv_path) - reader.validate() + with validio.Reader(customers_cid, many_customers_csv_path) as reader: + reader.validate_rows() # Validate the data using the command line application in order to use # the whole tool chain from an end user's point of view. diff --git a/tests/test_rowio.py b/tests/test_rowio.py index 73cf5e59..8f13762a 100644 --- a/tests/test_rowio.py +++ b/tests/test_rowio.py @@ -1,5 +1,5 @@ """ -Test for `iotools` module. +Tests for the :py:mod:`cutplace.rowio` module. 
""" # Copyright (C) 2009-2015 Thomas Aglassinger # @@ -31,6 +31,7 @@ _EURO_SIGN = '\u20ac' + class RowsTest(unittest.TestCase): def _assert_rows_contain_data(self, rows): self.assertTrue(rows is not None) @@ -245,12 +246,7 @@ def test_can_auto_read_ods_rows(self): class DelimitedRowWriterTest(unittest.TestCase): - @staticmethod - def _unified_newlines(text): - assert text is not None - return text.replace('\r\n', '\n').replace('\r', '\n') - - def test_can_write_delimited_to_string_io(self): + def test_can_write_delimited_data_to_string_io(self): delimited_data_format = data.DataFormat(data.FORMAT_DELIMITED) delimited_data_format.validate() with io.StringIO() as target: @@ -258,10 +254,10 @@ def test_can_write_delimited_to_string_io(self): delimited_writer.write_row(['a', 'b', _EURO_SIGN]) delimited_writer.write_row([]) delimited_writer.write_row([1, 2, 'end']) - data_written = DelimitedRowWriterTest._unified_newlines(target.getvalue()) + data_written = dev_test.unified_newlines(target.getvalue()) self.assertEqual('%r' % data_written, '%r' % 'a,b,\u20ac\n\n1,2,end\n') - def test_can_write_delimited_to_path(self): + def test_can_write_delimited_data_to_path(self): delimited_data_format = data.DataFormat(data.FORMAT_DELIMITED) delimited_data_format.set_property(data.KEY_ENCODING, 'utf-8') delimited_data_format.validate() @@ -278,7 +274,7 @@ def test_can_write_delimited_to_path(self): def test_fails_on_unicode_error_during_delimited_write(self): delimited_data_format = data.DataFormat(data.FORMAT_DELIMITED) - delimited_data_format.set_property(data.KEY_ENCODING, 'latin-1') + delimited_data_format.set_property(data.KEY_ENCODING, 'ascii') delimited_data_format.validate() delimited_path = dev_test.path_to_test_result('test_fails_on_unicode_error_during_delimited_write.csv') with io.open(delimited_path, 'w', encoding=delimited_data_format.encoding) as delimited_target_stream: @@ -293,5 +289,44 @@ def test_fails_on_unicode_error_during_delimited_write(self): self, 
anticipated_error_message, "*.csv (R2C1): cannot write data row: *; row=?'b', '\u20ac'?") +class FixedRowWriter(unittest.TestCase): + def test_can_write_fixed_data_to_string(self): + fixed_data_format = data.DataFormat(data.FORMAT_FIXED) + fixed_data_format.set_property(data.KEY_ENCODING, 'utf-8') + fixed_data_format.validate() + field_lengths = [1, 3] + with io.StringIO() as target: + with rowio.FixedRowWriter(target, fixed_data_format, field_lengths) as fixed_writer: + fixed_writer.write_row(['a', 'bcd']) + fixed_writer.write_row([_EURO_SIGN, ' ']) + data_written = dev_test.unified_newlines(target.getvalue()) + self.assertEqual('%r' % data_written, '%r' % 'abcd\n\u20ac \n') + + def test_can_write_fixed_data_without_line_delimiter(self): + fixed_data_format = data.DataFormat(data.FORMAT_FIXED) + fixed_data_format.set_property(data.KEY_LINE_DELIMITER, 'none') + fixed_data_format.validate() + with io.StringIO() as target: + with rowio.FixedRowWriter(target, fixed_data_format, [1]) as fixed_writer: + fixed_writer.write_rows([['1'], ['2'], ['3']]) + data_written = target.getvalue() + self.assertEqual(data_written, '123') + + def test_fails_on_unicode_error_during_fixed_write(self): + fixed_data_format = data.DataFormat(data.FORMAT_FIXED) + fixed_data_format.set_property(data.KEY_ENCODING, 'ascii') + fixed_data_format.validate() + fixed_path = dev_test.path_to_test_result('test_fails_on_unicode_error_during_fixed_write.txt') + with rowio.FixedRowWriter(fixed_path, fixed_data_format, [1]) as fixed_writer: + fixed_writer.write_row(['a']) + try: + fixed_writer.write_row([_EURO_SIGN]) + self.fail() + except errors.DataError as anticipated_error: + anticipated_error_message = str(anticipated_error) + dev_test.assert_fnmatches( + self, anticipated_error_message, "*.txt (R2C1): cannot write data row: *; row=?'\u20ac'?") + + if __name__ == "__main__": # pragma: no cover unittest.main() diff --git a/tests/test_validio.py b/tests/test_validio.py index a7aee23b..ec866c1b 100644 
--- a/tests/test_validio.py +++ b/tests/test_validio.py @@ -30,85 +30,58 @@ from tests import dev_test -class ValidatorTest(unittest.TestCase): +class ReaderTest(unittest.TestCase): """ Tests for data formats. """ - # TODO: Cleanup: rename all the ``cid_reader`` variables to ``cid``. _TEST_ENCODING = "cp1252" def test_can_open_and_validate_csv_source_file(self): - cid_reader = interface.Cid() - source_path = dev_test.path_to_test_cid("icd_customers.xls") - cid_reader.read(source_path, rowio.excel_rows(source_path)) - - reader = validio.Reader(cid_reader, dev_test.path_to_test_data("valid_customers.csv")) - reader.validate() + cid = interface.Cid(dev_test.path_to_test_cid("icd_customers.xls")) + with validio.Reader(cid, dev_test.path_to_test_data("valid_customers.csv")) as reader: + reader.validate_rows() def test_can_open_and_validate_excel_source_file(self): - cid_reader = interface.Cid() - source_path = dev_test.path_to_test_cid("icd_customers_excel.xls") - cid_reader.read(source_path, rowio.excel_rows(source_path)) - - reader = validio.Reader(cid_reader, dev_test.path_to_test_data("valid_customers.xls")) - reader.validate() + cid = interface.Cid(dev_test.path_to_test_cid("icd_customers_excel.xls")) + with validio.Reader(cid, dev_test.path_to_test_data("valid_customers.xls")) as reader: + reader.validate_rows() def test_can_open_and_validate_ods_source_file(self): - cid_reader = interface.Cid() - source_path = dev_test.path_to_test_cid("icd_customers_ods.xls") - cid_reader.read(source_path, rowio.excel_rows(source_path)) - - reader = validio.Reader(cid_reader, dev_test.path_to_test_data("valid_customers.ods")) - reader.validate() + cid = interface.Cid(dev_test.path_to_test_cid("icd_customers_ods.xls")) + with validio.Reader(cid, dev_test.path_to_test_data("valid_customers.ods")) as reader: + reader.validate_rows() def test_can_open_and_validate_fixed_source_file(self): - cid_reader = interface.Cid() - source_path = dev_test.path_to_test_cid("customers_fixed.xls") 
-        cid_reader.read(source_path, rowio.excel_rows(source_path))
-
-        reader = validio.Reader(cid_reader, dev_test.path_to_test_data("valid_customers_fixed.txt"))
-        reader.validate()
+        cid = interface.Cid(dev_test.path_to_test_cid("customers_fixed.xls"))
+        with validio.Reader(cid, dev_test.path_to_test_data("valid_customers_fixed.txt")) as reader:
+            reader.validate_rows()
 
     def test_fails_on_invalid_csv_source_file(self):
-        cid_reader = interface.Cid()
-        source_path = dev_test.path_to_test_cid("icd_customers.xls")
-        cid_reader.read(source_path, rowio.excel_rows(source_path))
-
-        reader = validio.Reader(cid_reader, dev_test.path_to_test_data("broken_customers.csv"))
-        self.assertRaises(errors.FieldValueError, reader.validate)
+        cid = interface.Cid(dev_test.path_to_test_cid("icd_customers.xls"))
+        with validio.Reader(cid, dev_test.path_to_test_data("broken_customers.csv")) as reader:
+            self.assertRaises(errors.FieldValueError, reader.validate_rows)
 
     def test_fails_on_csv_source_file_with_fewer_elements_than_expected(self):
-        cid_reader = interface.Cid()
-        source_path = dev_test.path_to_test_cid("icd_customers.xls")
-        cid_reader.read(source_path, rowio.excel_rows(source_path))
-
-        reader = validio.Reader(cid_reader, dev_test.path_to_test_data("broken_customers_fewer_elements.csv"))
-        self.assertRaises(errors.DataError, reader.validate)
+        cid = interface.Cid(dev_test.path_to_test_cid("icd_customers.xls"))
+        with validio.Reader(cid, dev_test.path_to_test_data("broken_customers_fewer_elements.csv")) as reader:
+            self.assertRaises(errors.DataError, reader.validate_rows)
 
     def test_fails_on_csv_source_file_with_more_elements_than_expected(self):
-        cid_reader = interface.Cid()
-        source_path = dev_test.path_to_test_cid("icd_customers.xls")
-        cid_reader.read(source_path, rowio.excel_rows(source_path))
-
-        reader = validio.Reader(cid_reader, dev_test.path_to_test_data("broken_customers_more_elements.csv"))
-        self.assertRaises(errors.DataError, reader.validate)
+        cid = interface.Cid(dev_test.path_to_test_cid("icd_customers.xls"))
+        with validio.Reader(cid, dev_test.path_to_test_data("broken_customers_more_elements.csv")) as reader:
+            self.assertRaises(errors.DataError, reader.validate_rows)
 
     def test_fails_on_invalid_csv_source_file_with_duplicates(self):
-        cid_reader = interface.Cid()
-        source_path = dev_test.path_to_test_cid("icd_customers.xls")
-        cid_reader.read(source_path, rowio.excel_rows(source_path))
-
-        reader = validio.Reader(cid_reader, dev_test.path_to_test_data("broken_customers_with_duplicates.csv"))
-        self.assertRaises(errors.CheckError, reader.validate)
+        cid = interface.Cid(dev_test.path_to_test_cid("icd_customers.xls"))
+        with validio.Reader(cid, dev_test.path_to_test_data("broken_customers_with_duplicates.csv")) as reader:
+            self.assertRaises(errors.CheckError, reader.validate_rows)
 
     def test_fails_on_invalid_csv_source_file_with_not_observed_count_expression(self):
-        cid_reader = interface.Cid()
-        source_path = dev_test.path_to_test_cid("icd_customers.xls")
-        # FIXME: either test `validator` or move to `test_tools`.
-        cid_reader.read(source_path, rowio.excel_rows(source_path))
-
-        reader = validio.Reader(cid_reader, dev_test.path_to_test_data("broken_customers_with_too_many_branches.csv"))
-        self.assertRaises(errors.CheckError, reader.validate)
+        cid = interface.Cid(dev_test.path_to_test_cid("icd_customers.xls"))
+        data_path = dev_test.path_to_test_data("broken_customers_with_too_many_branches.csv")
+        reader = validio.Reader(cid, data_path)
+        reader.validate_rows()
+        self.assertRaises(errors.CheckError, reader.close)
 
     def test_can_process_escape_character(self):
         """
@@ -124,11 +97,11 @@ def test_can_process_escape_character(self):
         ])
         cid = interface.create_cid_from_string(cid_text)
         with io.StringIO('"\\"x"\n') as data_starting_with_escape_character:
-            reader = validio.Reader(cid, data_starting_with_escape_character)
-            reader.validate()
+            with validio.Reader(cid, data_starting_with_escape_character) as reader:
+                reader.validate_rows()
         with io.StringIO('"x\\""\n') as data_ending_with_escape_character:
-            reader = validio.Reader(cid, data_ending_with_escape_character)
-            reader.validate()
+            with validio.Reader(cid, data_ending_with_escape_character) as reader:
+                reader.validate_rows()
 
     def test_can_yield_errors(self):
         cid_text = '\n'.join([
@@ -138,8 +111,8 @@ def test_can_yield_errors(self):
         ])
         cid = interface.create_cid_from_string(cid_text)
         with io.StringIO('1\nabc\n3') as partially_broken_data:
-            reader = validio.Reader(cid, partially_broken_data)
-            rows = list(reader.rows('yield'))
+            with validio.Reader(cid, partially_broken_data) as reader:
+                rows = list(reader.rows('yield'))
         self.assertEqual(3, len(rows), 'expected 3 rows but got: %s' % rows)
         self.assertEqual(['1'], rows[0])
         self.assertEqual(errors.FieldValueError, type(rows[1]), 'rows=%s' % rows)
@@ -153,8 +126,39 @@ def test_can_continue_after_errors(self):
         ])
         cid = interface.create_cid_from_string(cid_text)
         with io.StringIO('1\nabc\n3') as partially_broken_data:
-            reader = validio.Reader(cid, partially_broken_data)
-            rows = list(reader.rows('continue'))
-        self.assertEqual(2, len(rows), 'expected 3 rows but got: %s' % rows)
-        self.assertEqual(['1'], rows[0])
-        self.assertEqual(['3'], rows[1])
+            with validio.Reader(cid, partially_broken_data) as reader:
+                rows = list(reader.rows('continue'))
+        expected_row_count = 2
+        self.assertEqual(expected_row_count, len(rows), 'expected %d rows but got: %s' % (expected_row_count, rows))
+        self.assertEqual([['1'], ['3']], rows)
+
+
+class WriterTest(unittest.TestCase):
+    def setUp(self):
+        standard_cid_text = '\n'.join([
+            'd,format,delimited',
+            ' ,name   ,,empty,length,type,rule',
+            'f,surname',
+            'f,height ,,     ,      ,Integer',
+            'f,born_on,,     ,      ,DateTime,YYYY-MM-DD'
+        ])
+        self._standard_cid = interface.create_cid_from_string(standard_cid_text)
+
+    def test_can_write_delimited(self):
+        with io.StringIO() as delimited_stream:
+            with validio.Writer(self._standard_cid, delimited_stream) as delimited_writer:
+                delimited_writer.write_row(['Miller', '173', '1967-05-23'])
+                delimited_writer.write_row(['Webster', '167', '1983-11-02'])
+            data_written = dev_test.unified_newlines(delimited_stream.getvalue())
+        self.assertEqual('%r' % 'Miller,173,1967-05-23\nWebster,167,1983-11-02\n', '%r' % data_written)
+
+    def test_fails_on_writing_broken_field(self):
+        with io.StringIO() as delimited_stream:
+            with validio.Writer(self._standard_cid, delimited_stream) as delimited_writer:
+                delimited_writer.write_row(['Miller', '173', '1967-05-23'])
+                try:
+                    delimited_writer.write_row(['Webster', 'not_a_number', '1983-11-02'])
+                except errors.FieldValueError as anticipated_error:
+                    dev_test.assert_fnmatches(
+                        self, str(anticipated_error),
+                        "* (R2C2): cannot accept field height: value must be an integer number: 'not_a_number'")