diff --git a/setup.py b/setup.py index 16d4c300..e0287659 100755 --- a/setup.py +++ b/setup.py @@ -27,16 +27,16 @@ readme = f.read() setup( - version='1.9.2', + version='1.9.3', name='testgres', - packages=['testgres', 'testgres.operations'], + packages=['testgres', 'testgres.operations', 'testgres.helpers'], description='Testing utility for PostgreSQL and its extensions', url='https://github.com/postgrespro/testgres', long_description=readme, long_description_content_type='text/markdown', license='PostgreSQL', - author='Ildar Musin', - author_email='zildermann@gmail.com', + author='Postgres Professional', + author_email='testgres@postgrespro.ru', keywords=['test', 'testing', 'postgresql'], install_requires=install_requires, classifiers=[], diff --git a/testgres/__init__.py b/testgres/__init__.py index 383daf2d..8d0e38c6 100644 --- a/testgres/__init__.py +++ b/testgres/__init__.py @@ -52,6 +52,8 @@ from .operations.local_ops import LocalOperations from .operations.remote_ops import RemoteOperations +from .helpers.port_manager import PortManager + __all__ = [ "get_new_node", "get_remote_node", @@ -62,6 +64,6 @@ "XLogMethod", "IsolationLevel", "NodeStatus", "ProcessType", "DumpFormat", "PostgresNode", "NodeApp", "reserve_port", "release_port", "bound_ports", "get_bin_path", "get_pg_config", "get_pg_version", - "First", "Any", + "First", "Any", "PortManager", "OsOperations", "LocalOperations", "RemoteOperations", "ConnectionParams" ] diff --git a/testgres/helpers/__init__.py b/testgres/helpers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/testgres/helpers/port_manager.py b/testgres/helpers/port_manager.py new file mode 100644 index 00000000..6afdf8a9 --- /dev/null +++ b/testgres/helpers/port_manager.py @@ -0,0 +1,40 @@ +import socket +import random +from typing import Set, Iterable, Optional + + +class PortForException(Exception): + pass + + +class PortManager: + def __init__(self, ports_range=(1024, 65535)): + self.ports_range = ports_range + + @staticmethod + def is_port_free(port: int) -> bool: + """Check if a port is free to use.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("", port)) + return True + except OSError: + return False + + def find_free_port(self, ports: Optional[Set[int]] = None, exclude_ports: Optional[Iterable[int]] = None) -> int: + """Return a random unused port number.""" + if ports is None: + ports = set(range(1024, 65535)) + + if exclude_ports is None: + exclude_ports = set() + + ports.difference_update(set(exclude_ports)) + + sampled_ports = random.sample(tuple(ports), min(len(ports), 100)) + + for port in sampled_ports: + if self.is_port_free(port): + return port + + raise PortForException("Can't select a port") diff --git a/testgres/node.py b/testgres/node.py index 52e6d2ee..20cf4264 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -623,8 +623,8 @@ def status(self): "-D", self.data_dir, "status" ] # yapf: disable - status_code, out, err = execute_utility(_params, self.utils_log_file, verbose=True) - if 'does not exist' in err: + status_code, out, error = execute_utility(_params, self.utils_log_file, verbose=True) + if error and 'does not exist' in error: return NodeStatus.Uninitialized elif 'no server running' in out: return NodeStatus.Stopped @@ -717,7 +717,7 @@ def start(self, params=[], wait=True): try: exit_status, out, error = execute_utility(_params, self.utils_log_file, verbose=True) - if 'does not exist' in error: + if error and 'does not exist' in error: raise Exception except Exception as e: msg = 'Cannot start node' @@ -791,7 +791,7 @@ def restart(self, params=[]): try: error_code, out, error = execute_utility(_params, self.utils_log_file, verbose=True) - if 'could not start server' in error: + if error and 'could not start server' in error: raise ExecUtilException except ExecUtilException as e: msg = 'Cannot restart node' diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index 36b14058..93ebf012 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -8,8 +8,7 @@ import psutil from ..exceptions import ExecUtilException -from .os_ops import ConnectionParams, OsOperations -from .os_ops import pglib +from .os_ops import ConnectionParams, OsOperations, pglib, get_default_encoding try: from shutil import which as find_executable @@ -22,6 +21,14 @@ error_markers = [b'error', b'Permission denied', b'fatal'] +def has_errors(output): + if output: + if isinstance(output, str): + output = output.encode(get_default_encoding()) + return any(marker in output for marker in error_markers) + return False + + class LocalOperations(OsOperations): def __init__(self, conn_params=None): if conn_params is None: @@ -33,72 +40,80 @@ def __init__(self, conn_params=None): self.remote = False self.username = conn_params.username or self.get_user() - # Command execution - def exec_command(self, cmd, wait_exit=False, verbose=False, - expect_error=False, encoding=None, shell=False, text=False, - input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - get_process=None, timeout=None): - """ - Execute a command in a subprocess. - - Args: - - cmd: The command to execute. - - wait_exit: Whether to wait for the subprocess to exit before returning. - - verbose: Whether to return verbose output. - - expect_error: Whether to raise an error if the subprocess exits with an error status. - - encoding: The encoding to use for decoding the subprocess output. - - shell: Whether to use shell when executing the subprocess. - - text: Whether to return str instead of bytes for the subprocess output. - - input: The input to pass to the subprocess. - - stdout: The stdout to use for the subprocess. - - stderr: The stderr to use for the subprocess. - - proc: The process to use for subprocess creation. - :return: The output of the subprocess. - """ - if os.name == 'nt': - with tempfile.NamedTemporaryFile() as buf: - process = subprocess.Popen(cmd, stdout=buf, stderr=subprocess.STDOUT) - process.communicate() - buf.seek(0) - result = buf.read().decode(encoding) - return result - else: + @staticmethod + def _raise_exec_exception(message, command, exit_code, output): + """Raise an ExecUtilException.""" + raise ExecUtilException(message=message.format(output), + command=command, + exit_code=exit_code, + out=output) + + @staticmethod + def _process_output(encoding, temp_file_path): + """Process the output of a command from a temporary file.""" + with open(temp_file_path, 'rb') as temp_file: + output = temp_file.read() + if encoding: + output = output.decode(encoding) + return output, None # In Windows stderr writing in stdout + + def _run_command(self, cmd, shell, input, stdin, stdout, stderr, get_process, timeout, encoding): + """Execute a command and return the process and its output.""" + if os.name == 'nt' and stdout is None: # Windows + with tempfile.NamedTemporaryFile(mode='w+b', delete=False) as temp_file: + stdout = temp_file + stderr = subprocess.STDOUT + process = subprocess.Popen( + cmd, + shell=shell, + stdin=stdin or subprocess.PIPE if input is not None else None, + stdout=stdout, + stderr=stderr, + ) + if get_process: + return process, None, None + temp_file_path = temp_file.name + + # Wait process finished + process.wait() + + output, error = self._process_output(encoding, temp_file_path) + return process, output, error + else: # Other OS process = subprocess.Popen( cmd, shell=shell, - stdout=stdout, - stderr=stderr, + stdin=stdin or subprocess.PIPE if input is not None else None, + stdout=stdout or subprocess.PIPE, + stderr=stderr or subprocess.PIPE, ) if get_process: - return process - + return process, None, None try: - result, error = process.communicate(input, timeout=timeout) + output, error = process.communicate(input=input.encode(encoding) if input else None, timeout=timeout) + if encoding: + output = output.decode(encoding) + error = error.decode(encoding) + return process, output, error except subprocess.TimeoutExpired: process.kill() raise ExecUtilException("Command timed out after {} seconds.".format(timeout)) - exit_status = process.returncode - - error_found = exit_status != 0 or any(marker in error for marker in error_markers) - if encoding: - result = result.decode(encoding) - error = error.decode(encoding) - - if expect_error: - raise Exception(result, error) - - if exit_status != 0 or error_found: - if exit_status == 0: - exit_status = 1 - raise ExecUtilException(message='Utility exited with non-zero code. Error `{}`'.format(error), - command=cmd, - exit_code=exit_status, - out=result) - if verbose: - return exit_status, result, error - else: - return result + def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=False, + text=False, input=None, stdin=None, stdout=None, stderr=None, get_process=False, timeout=None): + """ + Execute a command in a subprocess and handle the output based on the provided parameters. + """ + process, output, error = self._run_command(cmd, shell, input, stdin, stdout, stderr, get_process, timeout, encoding) + if get_process: + return process + if process.returncode != 0 or (has_errors(error) and not expect_error): + self._raise_exec_exception('Utility exited with non-zero code. Error `{}`', cmd, process.returncode, error) + + if verbose: + return process.returncode, output, error + else: + return output # Environment setup def environ(self, var_name): @@ -210,7 +225,7 @@ def read(self, filename, encoding=None, binary=False): if binary: return content if isinstance(content, bytes): - return content.decode(encoding or 'utf-8') + return content.decode(encoding or get_default_encoding()) return content def readlines(self, filename, num_lines=0, binary=False, encoding=None): diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py index 9261cacf..dd6613cf 100644 --- a/testgres/operations/os_ops.py +++ b/testgres/operations/os_ops.py @@ -1,3 +1,5 @@ +import locale + try: import psycopg2 as pglib # noqa: F401 except ImportError: @@ -14,6 +16,10 @@ def __init__(self, host='127.0.0.1', ssh_key=None, username=None): self.username = username +def get_default_encoding(): + return locale.getdefaultlocale()[1] or 'UTF-8' + + class OsOperations: def __init__(self, username=None): self.ssh_key = None @@ -75,7 +81,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal def touch(self, filename): raise NotImplementedError() - def read(self, filename): + def read(self, filename, encoding, binary): raise NotImplementedError() def readlines(self, filename): diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index 0a545834..01251e1c 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -1,4 +1,3 @@ -import locale import logging import os import subprocess @@ -15,12 +14,7 @@ raise ImportError("You must have psycopg2 or pg8000 modules installed") from ..exceptions import ExecUtilException - -from .os_ops import OsOperations, ConnectionParams - -ConsoleEncoding = locale.getdefaultlocale()[1] -if not ConsoleEncoding: - ConsoleEncoding = 'UTF-8' +from .os_ops import OsOperations, ConnectionParams, get_default_encoding error_markers = [b'error', b'Permission denied', b'fatal', b'No such file or directory'] @@ -36,7 +30,7 @@ def kill(self): def cmdline(self): command = "ps -p {} -o cmd --no-headers".format(self.pid) - stdin, stdout, stderr = self.ssh.exec_command(command, verbose=True, encoding=ConsoleEncoding) + stdin, stdout, stderr = self.ssh.exec_command(command, verbose=True, encoding=get_default_encoding()) cmdline = stdout.strip() return cmdline.split() @@ -51,6 +45,10 @@ def __init__(self, conn_params: ConnectionParams): self.conn_params = conn_params self.host = conn_params.host self.ssh_key = conn_params.ssh_key + if self.ssh_key: + self.ssh_cmd = ["-i", self.ssh_key] + else: + self.ssh_cmd = [] self.remote = True self.username = conn_params.username or self.get_user() self.add_known_host(self.host) @@ -97,9 +95,9 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, """ ssh_cmd = [] if isinstance(cmd, str): - ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key, cmd] + ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + [cmd] elif isinstance(cmd, list): - ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key] + cmd + ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + cmd process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) if get_process: return process @@ -145,7 +143,7 @@ def environ(self, var_name: str) -> str: - var_name (str): The name of the environment variable. """ cmd = "echo ${}".format(var_name) - return self.exec_command(cmd, encoding=ConsoleEncoding).strip() + return self.exec_command(cmd, encoding=get_default_encoding()).strip() def find_executable(self, executable): search_paths = self.environ("PATH") @@ -176,11 +174,11 @@ def set_env(self, var_name: str, var_val: str): # Get environment variables def get_user(self): - return self.exec_command("echo $USER", encoding=ConsoleEncoding).strip() + return self.exec_command("echo $USER", encoding=get_default_encoding()).strip() def get_name(self): cmd = 'python3 -c "import os; print(os.name)"' - return self.exec_command(cmd, encoding=ConsoleEncoding).strip() + return self.exec_command(cmd, encoding=get_default_encoding()).strip() # Work with dirs def makedirs(self, path, remove_existing=False): @@ -227,7 +225,7 @@ def listdir(self, path): return result.splitlines() def path_exists(self, path): - result = self.exec_command("test -e {}; echo $?".format(path), encoding=ConsoleEncoding) + result = self.exec_command("test -e {}; echo $?".format(path), encoding=get_default_encoding()) return int(result.strip()) == 0 @property @@ -248,9 +246,9 @@ def mkdtemp(self, prefix=None): - prefix (str): The prefix of the temporary directory name. """ if prefix: - command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"] + command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"] else: - command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", "mktemp -d"] + command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", "mktemp -d"] result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) @@ -264,9 +262,9 @@ def mkdtemp(self, prefix=None): def mkstemp(self, prefix=None): if prefix: - temp_dir = self.exec_command("mktemp {}XXXXX".format(prefix), encoding=ConsoleEncoding) + temp_dir = self.exec_command("mktemp {}XXXXX".format(prefix), encoding=get_default_encoding()) else: - temp_dir = self.exec_command("mktemp", encoding=ConsoleEncoding) + temp_dir = self.exec_command("mktemp", encoding=get_default_encoding()) if temp_dir: if not os.path.isabs(temp_dir): @@ -283,7 +281,9 @@ def copytree(self, src, dst): return self.exec_command("cp -r {} {}".format(src, dst)) # Work with files - def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding=ConsoleEncoding): + def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding=None): + if not encoding: + encoding = get_default_encoding() mode = "wb" if binary else "w" if not truncate: mode = "ab" if binary else "a" @@ -292,7 +292,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file: if not truncate: - scp_cmd = ['scp', '-i', self.ssh_key, f"{self.username}@{self.host}:{filename}", tmp_file.name] + scp_cmd = ['scp'] + self.ssh_cmd + [f"{self.username}@{self.host}:{filename}", tmp_file.name] subprocess.run(scp_cmd, check=False) # The file might not exist yet tmp_file.seek(0, os.SEEK_END) @@ -302,18 +302,17 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal data = data.encode(encoding) if isinstance(data, list): - data = [(s if isinstance(s, str) else s.decode(ConsoleEncoding)).rstrip('\n') + '\n' for s in data] + data = [(s if isinstance(s, str) else s.decode(get_default_encoding())).rstrip('\n') + '\n' for s in data] tmp_file.writelines(data) else: tmp_file.write(data) tmp_file.flush() - - scp_cmd = ['scp', '-i', self.ssh_key, tmp_file.name, f"{self.username}@{self.host}:{filename}"] + scp_cmd = ['scp'] + self.ssh_cmd + [tmp_file.name, f"{self.username}@{self.host}:{filename}"] subprocess.run(scp_cmd, check=True) remote_directory = os.path.dirname(filename) - mkdir_cmd = ['ssh', '-i', self.ssh_key, f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"] + mkdir_cmd = ['ssh'] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"] subprocess.run(mkdir_cmd, check=True) os.remove(tmp_file.name) @@ -334,7 +333,7 @@ def read(self, filename, binary=False, encoding=None): result = self.exec_command(cmd, encoding=encoding) if not binary and result: - result = result.decode(encoding or ConsoleEncoding) + result = result.decode(encoding or get_default_encoding()) return result @@ -347,7 +346,7 @@ def readlines(self, filename, num_lines=0, binary=False, encoding=None): result = self.exec_command(cmd, encoding=encoding) if not binary and result: - lines = result.decode(encoding or ConsoleEncoding).splitlines() + lines = result.decode(encoding or get_default_encoding()).splitlines() else: lines = result.splitlines() @@ -375,10 +374,10 @@ def kill(self, pid, signal): def get_pid(self): # Get current process id - return int(self.exec_command("echo $$", encoding=ConsoleEncoding)) + return int(self.exec_command("echo $$", encoding=get_default_encoding())) def get_process_children(self, pid): - command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", f"pgrep -P {pid}"] + command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"pgrep -P {pid}"] result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) diff --git a/testgres/utils.py b/testgres/utils.py index db75fadc..b21fc2c8 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -4,7 +4,7 @@ from __future__ import print_function import os -import port_for + import sys from contextlib import contextmanager @@ -13,6 +13,7 @@ from six import iteritems +from .helpers.port_manager import PortManager from .exceptions import ExecUtilException from .config import testgres_config as tconf @@ -37,8 +38,8 @@ def reserve_port(): """ Generate a new port and add it to 'bound_ports'. """ - - port = port_for.select_random(exclude_ports=bound_ports) + port_mng = PortManager() + port = port_mng.find_free_port(exclude_ports=bound_ports) bound_ports.add(port) return port @@ -80,7 +81,8 @@ def execute_utility(args, logfile=None, verbose=False): lines = [u'\n'] + ['# ' + line for line in out.splitlines()] + [u'\n'] tconf.os_ops.write(filename=logfile, data=lines) except IOError: - raise ExecUtilException("Problem with writing to logfile `{}` during run command `{}`".format(logfile, args)) + raise ExecUtilException( + "Problem with writing to logfile `{}` during run command `{}`".format(logfile, args)) if verbose: return exit_status, out, error else: diff --git a/tests/test_remote.py b/tests/test_remote.py index 2e0f0676..e0e4a555 100755 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -11,10 +11,9 @@ class TestRemoteOperations: @pytest.fixture(scope="function", autouse=True) def setup(self): - conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '172.18.0.3', - username='dev', - ssh_key=os.getenv( - 'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519') + conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '127.0.0.1', + username=os.getenv('USER'), + ssh_key=os.getenv('RDBMS_TESTPOOL_SSHKEY')) self.operations = RemoteOperations(conn_params) def test_exec_command_success(self): @@ -41,7 +40,7 @@ def test_is_executable_true(self): """ Test is_executable for an existing executable. """ - cmd = "postgres" + cmd = os.getenv('PG_CONFIG') response = self.operations.is_executable(cmd) assert response is True diff --git a/tests/test_simple.py b/tests/test_simple.py old mode 100755 new mode 100644 index 45c28a21..9d31d4d9 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -74,6 +74,24 @@ def good_properties(f): return True +def rm_carriage_returns(out): + """ + In Windows we have additional '\r' symbols in output. + Let's get rid of them. + """ + if os.name == 'nt': + if isinstance(out, (int, float, complex)): + return out + elif isinstance(out, tuple): + return tuple(rm_carriage_returns(item) for item in out) + elif isinstance(out, bytes): + return out.replace(b'\r', b'') + else: + return out.replace('\r', '') + else: + return out + + @contextmanager def removing(f): try: @@ -123,7 +141,7 @@ def test_init_after_cleanup(self): node.cleanup() node.init().start().execute('select 1') - @unittest.skipUnless(util_exists('pg_resetwal'), 'might be missing') + @unittest.skipUnless(util_exists('pg_resetwal.exe' if os.name == 'nt' else 'pg_resetwal'), 'pgbench might be missing') @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+') def test_init_unique_system_id(self): # this function exists in PostgreSQL 9.6+ @@ -254,34 +272,34 @@ def test_psql(self): # check returned values (1 arg) res = node.psql('select 1') - self.assertEqual(res, (0, b'1\n', b'')) + self.assertEqual(rm_carriage_returns(res), (0, b'1\n', b'')) # check returned values (2 args) res = node.psql('postgres', 'select 2') - self.assertEqual(res, (0, b'2\n', b'')) + self.assertEqual(rm_carriage_returns(res), (0, b'2\n', b'')) # check returned values (named) res = node.psql(query='select 3', dbname='postgres') - self.assertEqual(res, (0, b'3\n', b'')) + self.assertEqual(rm_carriage_returns(res), (0, b'3\n', b'')) # check returned values (1 arg) res = node.safe_psql('select 4') - self.assertEqual(res, b'4\n') + self.assertEqual(rm_carriage_returns(res), b'4\n') # check returned values (2 args) res = node.safe_psql('postgres', 'select 5') - self.assertEqual(res, b'5\n') + self.assertEqual(rm_carriage_returns(res), b'5\n') # check returned values (named) res = node.safe_psql(query='select 6', dbname='postgres') - self.assertEqual(res, b'6\n') + self.assertEqual(rm_carriage_returns(res), b'6\n') # check feeding input node.safe_psql('create table horns (w int)') node.safe_psql('copy horns from stdin (format csv)', input=b"1\n2\n3\n\\.\n") _sum = node.safe_psql('select sum(w) from horns') - self.assertEqual(_sum, b'6\n') + self.assertEqual(rm_carriage_returns(_sum), b'6\n') # check psql's default args, fails with self.assertRaises(QueryException): @@ -455,7 +473,7 @@ def test_synchronous_replication(self): master.safe_psql( 'insert into abc select generate_series(1, 1000000)') res = standby1.safe_psql('select count(*) from abc') - self.assertEqual(res, b'1000000\n') + self.assertEqual(rm_carriage_returns(res), b'1000000\n') @unittest.skipUnless(pg_version_ge('10'), 'requires 10+') def test_logical_replication(self): @@ -589,7 +607,7 @@ def test_promotion(self): # make standby becomes writable master replica.safe_psql('insert into abc values (1)') res = replica.safe_psql('select * from abc') - self.assertEqual(res, b'1\n') + self.assertEqual(rm_carriage_returns(res), b'1\n') def test_dump(self): query_create = 'create table test as select generate_series(1, 2) as val' @@ -614,6 +632,7 @@ def test_users(self): with get_new_node().init().start() as node: node.psql('create role test_user login') value = node.safe_psql('select 1', username='test_user') + value = rm_carriage_returns(value) self.assertEqual(value, b'1\n') def test_poll_query_until(self): @@ -728,7 +747,7 @@ def test_logging(self): master.restart() self.assertTrue(master._logger.is_alive()) - @unittest.skipUnless(util_exists('pgbench'), 'might be missing') + @unittest.skipUnless(util_exists('pgbench.exe' if os.name == 'nt' else 'pgbench'), 'pgbench might be missing') def test_pgbench(self): with get_new_node().init().start() as node: @@ -744,6 +763,8 @@ def test_pgbench(self): out, _ = proc.communicate() out = out.decode('utf-8') + proc.stdout.close() + self.assertTrue('tps' in out) def test_pg_config(self): @@ -977,7 +998,9 @@ def test_child_pids(self): def test_child_process_dies(self): # test for FileNotFound exception during child_processes() function - with subprocess.Popen(["sleep", "60"]) as process: + cmd = ["timeout", "60"] if os.name == 'nt' else ["sleep", "60"] + + with subprocess.Popen(cmd, shell=True) as process: # shell=True might be needed on Windows self.assertEqual(process.poll(), None) # collect list of processes currently running children = psutil.Process(os.getpid()).children() diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index 1042f3c4..d51820ba 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -52,10 +52,9 @@ from testgres.utils import PgVer from testgres.node import ProcessProxy, ConnectionParams -conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '172.18.0.3', - username='dev', - ssh_key=os.getenv( - 'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519') +conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '127.0.0.1', + username=os.getenv('USER'), + ssh_key=os.getenv('RDBMS_TESTPOOL_SSHKEY')) os_ops = RemoteOperations(conn_params) testgres_config.set_os_ops(os_ops=os_ops)