diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index fc3e3954..5c79bb7e 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -235,27 +235,54 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal Args: filename: The file path where the data will be written. data: The data to be written to the file. - truncate: If True, the file will be truncated before writing ('w' or 'wb' option); - if False (default), data will be appended ('a' or 'ab' option). - binary: If True, the data will be written in binary mode ('wb' or 'ab' option); - if False (default), the data will be written in text mode ('w' or 'a' option). - read_and_write: If True, the file will be opened with read and write permissions ('r+' option); - if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option) + truncate: If True, the file will be truncated before writing ('w' option); + if False (default), data will be appended ('a' option). + binary: If True, the data will be written in binary mode ('b' option); + if False (default), the data will be written in text mode. + read_and_write: If True, the file will be opened with read and write permissions ('+' option); + if False (default), only write permission will be used. """ - # If it is a bytes str or list if isinstance(data, bytes) or isinstance(data, list) and all(isinstance(item, bytes) for item in data): binary = True - mode = "wb" if binary else "w" - if not truncate: - mode = "ab" if binary else "a" + + mode = "w" if truncate else "a" + if read_and_write: - mode = "r+b" if binary else "r+" + mode += "+" + + # If it is a bytes str or list + if binary: + mode += "b" + + assert type(mode) == str # noqa: E721 + assert mode != "" with open(filename, mode) as file: if isinstance(data, list): - file.writelines(data) + data2 = [__class__._prepare_line_to_write(s, binary) for s in data] + file.writelines(data2) else: - file.write(data) + data2 = __class__._prepare_data_to_write(data, binary) + file.write(data2) + + def _prepare_line_to_write(data, binary): + data = __class__._prepare_data_to_write(data, binary) + + if binary: + assert type(data) == bytes # noqa: E721 + return data.rstrip(b'\n') + b'\n' + + assert type(data) == str # noqa: E721 + return data.rstrip('\n') + '\n' + + def _prepare_data_to_write(data, binary): + if isinstance(data, bytes): + return data if binary else data.decode() + + if isinstance(data, str): + return data if not binary else data.encode() + + raise InvalidOperationException("Unknown type of data type [{0}].".format(type(data).__name__)) def touch(self, filename): """ diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index 3ebc2e60..f690e063 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -278,10 +278,6 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal if not encoding: encoding = get_default_encoding() mode = "wb" if binary else "w" - if not truncate: - mode = "ab" if binary else "a" - if read_and_write: - mode = "r+b" if binary else "r+" with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file: # For scp the port is specified by a "-P" option @@ -292,16 +288,12 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal subprocess.run(scp_cmd, check=False) # The file might not exist yet tmp_file.seek(0, os.SEEK_END) - if isinstance(data, bytes) and not binary: - data = data.decode(encoding) - elif isinstance(data, str) and binary: - data = data.encode(encoding) - if isinstance(data, list): - data = [(s if isinstance(s, str) else s.decode(get_default_encoding())).rstrip('\n') + '\n' for s in data] - tmp_file.writelines(data) + data2 = [__class__._prepare_line_to_write(s, binary, encoding) for s in data] + tmp_file.writelines(data2) else: - tmp_file.write(data) + data2 = __class__._prepare_data_to_write(data, binary, encoding) + tmp_file.write(data2) tmp_file.flush() scp_cmd = ['scp'] + scp_args + [tmp_file.name, f"{self.ssh_dest}:{filename}"] @@ -313,6 +305,25 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal os.remove(tmp_file.name) + def _prepare_line_to_write(data, binary, encoding): + data = __class__._prepare_data_to_write(data, binary, encoding) + + if binary: + assert type(data) == bytes # noqa: E721 + return data.rstrip(b'\n') + b'\n' + + assert type(data) == str # noqa: E721 + return data.rstrip('\n') + '\n' + + def _prepare_data_to_write(data, binary, encoding): + if isinstance(data, bytes): + return data if binary else data.decode(encoding) + + if isinstance(data, str): + return data if not binary else data.encode(encoding) + + raise InvalidOperationException("Unknown type of data type [{0}].".format(type(data).__name__)) + def touch(self, filename): """ Create a new file or update the access and modification times of an existing file on the remote server. diff --git a/tests/test_local.py b/tests/test_local.py index 568a4bc5..4051bfb5 100644 --- a/tests/test_local.py +++ b/tests/test_local.py @@ -2,6 +2,7 @@ import pytest import re +import tempfile from testgres import ExecUtilException from testgres import InvalidOperationException @@ -273,3 +274,71 @@ def test_cwd(self): # Comp result assert v == expectedValue + + class tagWriteData001: + def __init__(self, sign, source, cp_rw, cp_truncate, cp_binary, cp_data, result): + self.sign = sign + self.source = source + self.call_param__rw = cp_rw + self.call_param__truncate = cp_truncate + self.call_param__binary = cp_binary + self.call_param__data = cp_data + self.result = result + + sm_write_data001 = [ + tagWriteData001("A001", "1234567890", False, False, False, "ABC", "1234567890ABC"), + tagWriteData001("A002", b"1234567890", False, False, True, b"ABC", b"1234567890ABC"), + + tagWriteData001("B001", "1234567890", False, True, False, "ABC", "ABC"), + tagWriteData001("B002", "1234567890", False, True, False, "ABC1234567890", "ABC1234567890"), + tagWriteData001("B003", b"1234567890", False, True, True, b"ABC", b"ABC"), + tagWriteData001("B004", b"1234567890", False, True, True, b"ABC1234567890", b"ABC1234567890"), + + tagWriteData001("C001", "1234567890", True, False, False, "ABC", "1234567890ABC"), + tagWriteData001("C002", b"1234567890", True, False, True, b"ABC", b"1234567890ABC"), + + tagWriteData001("D001", "1234567890", True, True, False, "ABC", "ABC"), + tagWriteData001("D002", "1234567890", True, True, False, "ABC1234567890", "ABC1234567890"), + tagWriteData001("D003", b"1234567890", True, True, True, b"ABC", b"ABC"), + tagWriteData001("D004", b"1234567890", True, True, True, b"ABC1234567890", b"ABC1234567890"), + + tagWriteData001("E001", "\0001234567890\000", False, False, False, "\000ABC\000", "\0001234567890\000\000ABC\000"), + tagWriteData001("E002", b"\0001234567890\000", False, False, True, b"\000ABC\000", b"\0001234567890\000\000ABC\000"), + + tagWriteData001("F001", "a\nb\n", False, False, False, ["c", "d"], "a\nb\nc\nd\n"), + tagWriteData001("F002", b"a\nb\n", False, False, True, [b"c", b"d"], b"a\nb\nc\nd\n"), + + tagWriteData001("G001", "a\nb\n", False, False, False, ["c\n\n", "d\n"], "a\nb\nc\nd\n"), + tagWriteData001("G002", b"a\nb\n", False, False, True, [b"c\n\n", b"d\n"], b"a\nb\nc\nd\n"), + ] + + @pytest.fixture( + params=sm_write_data001, + ids=[x.sign for x in sm_write_data001], + ) + def write_data001(self, request): + assert isinstance(request, pytest.FixtureRequest) + assert type(request.param) == __class__.tagWriteData001 # noqa: E721 + return request.param + + def test_write(self, write_data001): + assert type(write_data001) == __class__.tagWriteData001 # noqa: E721 + + mode = "w+b" if write_data001.call_param__binary else "w+" + + with tempfile.NamedTemporaryFile(mode=mode, delete=True) as tmp_file: + tmp_file.write(write_data001.source) + tmp_file.flush() + + self.operations.write( + tmp_file.name, + write_data001.call_param__data, + read_and_write=write_data001.call_param__rw, + truncate=write_data001.call_param__truncate, + binary=write_data001.call_param__binary) + + tmp_file.seek(0) + + s = tmp_file.read() + + assert s == write_data001.result diff --git a/tests/test_remote.py b/tests/test_remote.py index 30c5d348..3e6b79dd 100755 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -2,6 +2,7 @@ import pytest import re +import tempfile from testgres import ExecUtilException from testgres import InvalidOperationException @@ -402,3 +403,71 @@ def test_cwd(self): assert v is not None assert type(v) == str # noqa: E721 assert v != "" + + class tagWriteData001: + def __init__(self, sign, source, cp_rw, cp_truncate, cp_binary, cp_data, result): + self.sign = sign + self.source = source + self.call_param__rw = cp_rw + self.call_param__truncate = cp_truncate + self.call_param__binary = cp_binary + self.call_param__data = cp_data + self.result = result + + sm_write_data001 = [ + tagWriteData001("A001", "1234567890", False, False, False, "ABC", "1234567890ABC"), + tagWriteData001("A002", b"1234567890", False, False, True, b"ABC", b"1234567890ABC"), + + tagWriteData001("B001", "1234567890", False, True, False, "ABC", "ABC"), + tagWriteData001("B002", "1234567890", False, True, False, "ABC1234567890", "ABC1234567890"), + tagWriteData001("B003", b"1234567890", False, True, True, b"ABC", b"ABC"), + tagWriteData001("B004", b"1234567890", False, True, True, b"ABC1234567890", b"ABC1234567890"), + + tagWriteData001("C001", "1234567890", True, False, False, "ABC", "1234567890ABC"), + tagWriteData001("C002", b"1234567890", True, False, True, b"ABC", b"1234567890ABC"), + + tagWriteData001("D001", "1234567890", True, True, False, "ABC", "ABC"), + tagWriteData001("D002", "1234567890", True, True, False, "ABC1234567890", "ABC1234567890"), + tagWriteData001("D003", b"1234567890", True, True, True, b"ABC", b"ABC"), + tagWriteData001("D004", b"1234567890", True, True, True, b"ABC1234567890", b"ABC1234567890"), + + tagWriteData001("E001", "\0001234567890\000", False, False, False, "\000ABC\000", "\0001234567890\000\000ABC\000"), + tagWriteData001("E002", b"\0001234567890\000", False, False, True, b"\000ABC\000", b"\0001234567890\000\000ABC\000"), + + tagWriteData001("F001", "a\nb\n", False, False, False, ["c", "d"], "a\nb\nc\nd\n"), + tagWriteData001("F002", b"a\nb\n", False, False, True, [b"c", b"d"], b"a\nb\nc\nd\n"), + + tagWriteData001("G001", "a\nb\n", False, False, False, ["c\n\n", "d\n"], "a\nb\nc\nd\n"), + tagWriteData001("G002", b"a\nb\n", False, False, True, [b"c\n\n", b"d\n"], b"a\nb\nc\nd\n"), + ] + + @pytest.fixture( + params=sm_write_data001, + ids=[x.sign for x in sm_write_data001], + ) + def write_data001(self, request): + assert isinstance(request, pytest.FixtureRequest) + assert type(request.param) == __class__.tagWriteData001 # noqa: E721 + return request.param + + def test_write(self, write_data001): + assert type(write_data001) == __class__.tagWriteData001 # noqa: E721 + + mode = "w+b" if write_data001.call_param__binary else "w+" + + with tempfile.NamedTemporaryFile(mode=mode, delete=True) as tmp_file: + tmp_file.write(write_data001.source) + tmp_file.flush() + + self.operations.write( + tmp_file.name, + write_data001.call_param__data, + read_and_write=write_data001.call_param__rw, + truncate=write_data001.call_param__truncate, + binary=write_data001.call_param__binary) + + tmp_file.seek(0) + + s = tmp_file.read() + + assert s == write_data001.result