From 1b10a5499b26cb21d28d2c14feda5d9ce98fe7ce Mon Sep 17 00:00:00 2001 From: vshepard Date: Fri, 1 Nov 2024 23:30:38 +0100 Subject: [PATCH 1/9] Add ability to skip ssl when connect to PostgresNode --- testgres/api.py | 4 +- testgres/node.py | 13 ++--- testgres/operations/local_ops.py | 18 +------ testgres/operations/os_ops.py | 28 ++++++++-- testgres/operations/remote_ops.py | 21 +------- testgres/utils.py | 4 +- tests/test_remote.py | 28 ++++++++-- tests/test_simple_remote.py | 89 ++++++++++++++++--------------- 8 files changed, 106 insertions(+), 99 deletions(-) diff --git a/testgres/api.py b/testgres/api.py index e4b1cdd5..10bfd669 100644 --- a/testgres/api.py +++ b/testgres/api.py @@ -42,7 +42,7 @@ def get_new_node(name=None, base_dir=None, **kwargs): return PostgresNode(name=name, base_dir=base_dir, **kwargs) -def get_remote_node(name=None, conn_params=None): +def get_remote_node(name=None): """ Simply a wrapper around :class:`.PostgresNode` constructor for remote node. See :meth:`.PostgresNode.__init__` for details. @@ -51,4 +51,4 @@ def get_remote_node(name=None, conn_params=None): ssh_key=None, username=default_username()) """ - return get_new_node(name=name, conn_params=conn_params) + return get_new_node(name=name) diff --git a/testgres/node.py b/testgres/node.py index c8c8c087..78ebd87b 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -126,7 +126,8 @@ def __repr__(self): class PostgresNode(object): - def __init__(self, name=None, base_dir=None, port=None, conn_params: ConnectionParams = ConnectionParams(), bin_dir=None, prefix=None): + def __init__(self, name=None, base_dir=None, port=None, conn_params: ConnectionParams = ConnectionParams(), + bin_dir=None, prefix=None): """ PostgresNode constructor. @@ -150,13 +151,9 @@ def __init__(self, name=None, base_dir=None, port=None, conn_params: ConnectionP self.name = name or generate_app_name() if testgres_config.os_ops: self.os_ops = testgres_config.os_ops - elif conn_params.ssh_key: - self.os_ops = RemoteOperations(conn_params) - else: - self.os_ops = LocalOperations(conn_params) self.host = self.os_ops.host - self.port = port or reserve_port() + self.port = port or self.os_ops.port or reserve_port() self.ssh_key = self.os_ops.ssh_key @@ -1005,7 +1002,7 @@ def psql(self, # select query source if query: - if self.os_ops.remote: + if self.os_ops.conn_params.remote: psql_params.extend(("-c", '"{}"'.format(query))) else: psql_params.extend(("-c", query)) @@ -1016,7 +1013,7 @@ def psql(self, # should be the last one psql_params.append(dbname) - if not self.os_ops.remote: + if not self.os_ops.conn_params.remote: # start psql process process = subprocess.Popen(psql_params, stdin=subprocess.PIPE, diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index a0a9926d..796e15c2 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -40,12 +40,7 @@ class LocalOperations(OsOperations): def __init__(self, conn_params=None): if conn_params is None: conn_params = ConnectionParams() - super(LocalOperations, self).__init__(conn_params.username) - self.conn_params = conn_params - self.host = conn_params.host - self.ssh_key = None - self.remote = False - self.username = conn_params.username or getpass.getuser() + super(LocalOperations, self).__init__(conn_params) @staticmethod def _raise_exec_exception(message, command, exit_code, output): @@ -305,14 +300,3 @@ def get_pid(self): def get_process_children(self, pid): return psutil.Process(pid).children() - - # Database control - def db_connect(self, dbname, user, password=None, host="localhost", port=5432): - conn = pglib.connect( - host=host, - port=port, - database=dbname, - user=user, - password=password, - ) - return conn diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py index 34242040..f027bfef 100644 --- a/testgres/operations/os_ops.py +++ b/testgres/operations/os_ops.py @@ -12,11 +12,16 @@ class ConnectionParams: - def __init__(self, host='127.0.0.1', port=None, ssh_key=None, username=None): + def __init__(self, host='127.0.0.1', port=None, ssh_key=None, username=None, remote=False, skip_ssl=False): + """ + skip_ssl: if is True, the connection is established without SSL. + """ + self.remote = remote self.host = host self.port = port self.ssh_key = ssh_key self.username = username + self.skip_ssl = skip_ssl def get_default_encoding(): @@ -26,9 +31,12 @@ def get_default_encoding(): class OsOperations: - def __init__(self, username=None): - self.ssh_key = None - self.username = username or getpass.getuser() + def __init__(self, conn_params=ConnectionParams()): + self.ssh_key = conn_params.ssh_key + self.username = conn_params.username or getpass.getuser() + self.host = conn_params.host + self.port = conn_params.port + self.conn_params = conn_params # Command execution def exec_command(self, cmd, **kwargs): @@ -115,4 +123,14 @@ def get_process_children(self, pid): # Database control def db_connect(self, dbname, user, password=None, host="localhost", port=5432): - raise NotImplementedError() + ssl_options = {"sslmode": "disable"} if self.conn_params.skip_ssl and 'psycopg2' in globals() else {} + conn = pglib.connect( + host=host, + port=port, + database=dbname, + user=user, + password=password, + **({"ssl_context": None} if self.conn_params.skip_ssl and 'pg8000' in globals() else ssl_options) + ) + + return conn diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index 20095051..24a8b9fe 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -37,23 +37,17 @@ def cmdline(self): class RemoteOperations(OsOperations): def __init__(self, conn_params: ConnectionParams): - if not platform.system().lower() == "linux": raise EnvironmentError("Remote operations are supported only on Linux!") + super().__init__(conn_params) - super().__init__(conn_params.username) - self.conn_params = conn_params - self.host = conn_params.host - self.port = conn_params.port - self.ssh_key = conn_params.ssh_key self.ssh_args = [] if self.ssh_key: self.ssh_args += ["-i", self.ssh_key] if self.port: self.ssh_args += ["-p", self.port] - self.remote = True self.username = conn_params.username or getpass.getuser() - self.ssh_dest = f"{self.username}@{self.host}" if conn_params.username else self.host + self.ssh_dest = f"{self.username}@{self.host}" if self.username else self.host def __enter__(self): return self @@ -361,17 +355,6 @@ def get_process_children(self, pid): else: raise ExecUtilException(f"Error in getting process children. Error: {result.stderr}") - # Database control - def db_connect(self, dbname, user, password=None, host="localhost", port=5432): - conn = pglib.connect( - host=host, - port=port, - database=dbname, - user=user, - password=password, - ) - return conn - def normalize_error(error): if isinstance(error, bytes): diff --git a/testgres/utils.py b/testgres/utils.py index a4ee7877..aa61d270 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -97,7 +97,7 @@ def get_bin_path(filename): # check if it's already absolute if os.path.isabs(filename): return filename - if tconf.os_ops.remote: + if tconf.os_ops.conn_params.remote: pg_config = os.environ.get("PG_CONFIG_REMOTE") or os.environ.get("PG_CONFIG") else: # try PG_CONFIG - get from local machine @@ -154,7 +154,7 @@ def cache_pg_config_data(cmd): return _pg_config_data # try specified pg_config path or PG_CONFIG - if tconf.os_ops.remote: + if tconf.os_ops.conn_params.remote: pg_config = pg_config_path or os.environ.get("PG_CONFIG_REMOTE") or os.environ.get("PG_CONFIG") else: # try PG_CONFIG - get from local machine diff --git a/tests/test_remote.py b/tests/test_remote.py index e0e4a555..bb13108d 100755 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -2,7 +2,7 @@ import pytest -from testgres import ExecUtilException +from testgres import ExecUtilException, get_remote_node, testgres_config from testgres import RemoteOperations from testgres import ConnectionParams @@ -34,7 +34,7 @@ def test_exec_command_failure(self): exit_status, result, error = self.operations.exec_command(cmd, verbose=True, wait_exit=True) except ExecUtilException as e: error = e.message - assert error == b'Utility exited with non-zero code. Error: bash: line 1: nonexistent_command: command not found\n' + assert error == 'Utility exited with non-zero code. Error: bash: line 1: nonexistent_command: command not found\n' def test_is_executable_true(self): """ @@ -87,7 +87,7 @@ def test_makedirs_and_rmdirs_failure(self): exit_status, result, error = self.operations.rmdirs(path, verbose=True) except ExecUtilException as e: error = e.message - assert error == b"Utility exited with non-zero code. Error: rm: cannot remove '/root/test_dir': Permission denied\n" + assert error == "Utility exited with non-zero code. Error: rm: cannot remove '/root/test_dir': Permission denied\n" def test_listdir(self): """ @@ -192,3 +192,25 @@ def test_isfile_false(self): response = self.operations.isfile(filename) assert response is False + + def test_skip_ssl(self): + conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '127.0.0.1', + username=os.getenv('USER'), + remote=True, + skip_ssl=True) + os_ops = RemoteOperations(conn_params) + testgres_config.set_os_ops(os_ops=os_ops) + with get_remote_node().init().start() as node: + with node.connect() as con: + con.begin() + con.execute('create table test(val int)') + con.execute('insert into test values (1)') + con.commit() + + con.begin() + con.execute('insert into test values (2)') + res = con.execute('select * from test order by val asc') + if isinstance(res, list): + res.sort() + assert res == [(1,), (2,)] + diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index d51820ba..0d6f3dd5 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -96,16 +96,16 @@ def removing(f): class TestgresRemoteTests(unittest.TestCase): def test_node_repr(self): - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: pattern = r"PostgresNode\(name='.+', port=.+, base_dir='.+'\)" self.assertIsNotNone(re.match(pattern, str(node))) def test_custom_init(self): - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: # enable page checksums node.init(initdb_params=['-k']).start() - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: node.init( allow_streaming=True, initdb_params=['--auth-local=reject', '--auth-host=reject']) @@ -120,13 +120,13 @@ def test_custom_init(self): self.assertFalse(any('trust' in s for s in lines)) def test_double_init(self): - with get_remote_node(conn_params=conn_params).init() as node: + with get_remote_node().init() as node: # can't initialize node more than once with self.assertRaises(InitNodeException): node.init() def test_init_after_cleanup(self): - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: node.init().start().execute('select 1') node.cleanup() node.init().start().execute('select 1') @@ -138,7 +138,7 @@ def test_init_unique_system_id(self): query = 'select system_identifier from pg_control_system()' with scoped_config(cache_initdb=False): - with get_remote_node(conn_params=conn_params).init().start() as node0: + with get_remote_node().init().start() as node0: id0 = node0.execute(query)[0] with scoped_config(cache_initdb=True, @@ -147,8 +147,8 @@ def test_init_unique_system_id(self): self.assertTrue(config.cached_initdb_unique) # spawn two nodes; ids must be different - with get_remote_node(conn_params=conn_params).init().start() as node1, \ - get_remote_node(conn_params=conn_params).init().start() as node2: + with get_remote_node().init().start() as node1, \ + get_remote_node().init().start() as node2: id1 = node1.execute(query)[0] id2 = node2.execute(query)[0] @@ -158,7 +158,7 @@ def test_init_unique_system_id(self): def test_node_exit(self): with self.assertRaises(QueryException): - with get_remote_node(conn_params=conn_params).init() as node: + with get_remote_node().init() as node: base_dir = node.base_dir node.safe_psql('select 1') @@ -166,26 +166,26 @@ def test_node_exit(self): self.assertTrue(os_ops.path_exists(base_dir)) os_ops.rmdirs(base_dir, ignore_errors=True) - with get_remote_node(conn_params=conn_params).init() as node: + with get_remote_node().init() as node: base_dir = node.base_dir # should have been removed by default self.assertFalse(os_ops.path_exists(base_dir)) def test_double_start(self): - with get_remote_node(conn_params=conn_params).init().start() as node: + with get_remote_node().init().start() as node: # can't start node more than once node.start() self.assertTrue(node.is_started) def test_uninitialized_start(self): - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: # node is not initialized yet with self.assertRaises(StartNodeException): node.start() def test_restart(self): - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: node.init().start() # restart, ok @@ -201,7 +201,7 @@ def test_restart(self): node.restart() def test_reload(self): - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: node.init().start() # change client_min_messages and save old value @@ -217,7 +217,7 @@ def test_reload(self): self.assertNotEqual(cmm_old, cmm_new) def test_pg_ctl(self): - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: node.init().start() status = node.pg_ctl(['status']) @@ -229,7 +229,7 @@ def test_status(self): self.assertFalse(NodeStatus.Uninitialized) # check statuses after each operation - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: self.assertEqual(node.pid, 0) self.assertEqual(node.status(), NodeStatus.Uninitialized) @@ -254,7 +254,7 @@ def test_status(self): self.assertEqual(node.status(), NodeStatus.Uninitialized) def test_psql(self): - with get_remote_node(conn_params=conn_params).init().start() as node: + with get_remote_node().init().start() as node: # check returned values (1 arg) res = node.psql('select 1') self.assertEqual(res, (0, b'1\n', b'')) @@ -297,7 +297,7 @@ def test_psql(self): node.safe_psql('select 1') def test_transactions(self): - with get_remote_node(conn_params=conn_params).init().start() as node: + with get_remote_node().init().start() as node: with node.connect() as con: con.begin() con.execute('create table test(val int)') @@ -320,7 +320,7 @@ def test_transactions(self): con.commit() def test_control_data(self): - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: # node is not initialized yet with self.assertRaises(ExecUtilException): node.get_control_data() @@ -333,7 +333,7 @@ def test_control_data(self): self.assertTrue(any('pg_control' in s for s in data.keys())) def test_backup_simple(self): - with get_remote_node(conn_params=conn_params) as master: + with get_remote_node() as master: # enable streaming for backups master.init(allow_streaming=True) @@ -353,7 +353,7 @@ def test_backup_simple(self): self.assertListEqual(res, [(1,), (2,), (3,), (4,)]) def test_backup_multiple(self): - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: node.init(allow_streaming=True).start() with node.backup(xlog_method='fetch') as backup1, \ @@ -366,7 +366,7 @@ def test_backup_multiple(self): self.assertNotEqual(node1.base_dir, node2.base_dir) def test_backup_exhaust(self): - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: node.init(allow_streaming=True).start() with node.backup(xlog_method='fetch') as backup: @@ -379,7 +379,7 @@ def test_backup_exhaust(self): backup.spawn_primary() def test_backup_wrong_xlog_method(self): - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: node.init(allow_streaming=True).start() with self.assertRaises(BackupException, @@ -387,7 +387,7 @@ def test_backup_wrong_xlog_method(self): node.backup(xlog_method='wrong') def test_pg_ctl_wait_option(self): - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: node.init().start(wait=False) while True: try: @@ -399,7 +399,7 @@ def test_pg_ctl_wait_option(self): pass def test_replicate(self): - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: node.init(allow_streaming=True).start() with node.replicate().start() as replica: @@ -415,7 +415,7 @@ def test_replicate(self): @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+') def test_synchronous_replication(self): - with get_remote_node(conn_params=conn_params) as master: + with get_remote_node() as master: old_version = not pg_version_ge('9.6') master.init(allow_streaming=True).start() @@ -456,7 +456,7 @@ def test_synchronous_replication(self): @unittest.skipUnless(pg_version_ge('10'), 'requires 10+') def test_logical_replication(self): - with get_remote_node(conn_params=conn_params) as node1, get_remote_node(conn_params=conn_params) as node2: + with get_remote_node() as node1, get_remote_node() as node2: node1.init(allow_logical=True) node1.start() node2.init().start() @@ -526,7 +526,7 @@ def test_logical_replication(self): @unittest.skipUnless(pg_version_ge('10'), 'requires 10+') def test_logical_catchup(self): """ Runs catchup for 100 times to be sure that it is consistent """ - with get_remote_node(conn_params=conn_params) as node1, get_remote_node(conn_params=conn_params) as node2: + with get_remote_node() as node1, get_remote_node() as node2: node1.init(allow_logical=True) node1.start() node2.init().start() @@ -551,12 +551,12 @@ def test_logical_catchup(self): @unittest.skipIf(pg_version_ge('10'), 'requires <10') def test_logical_replication_fail(self): - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: with self.assertRaises(InitNodeException): node.init(allow_logical=True) def test_replication_slots(self): - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: node.init(allow_streaming=True).start() with node.replicate(slot='slot1').start() as replica: @@ -567,7 +567,7 @@ def test_replication_slots(self): node.replicate(slot='slot1') def test_incorrect_catchup(self): - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: node.init(allow_streaming=True).start() # node has no master, can't catch up @@ -575,7 +575,7 @@ def test_incorrect_catchup(self): node.catchup() def test_promotion(self): - with get_remote_node(conn_params=conn_params) as master: + with get_remote_node() as master: master.init().start() master.safe_psql('create table abc(id serial)') @@ -592,12 +592,12 @@ def test_dump(self): query_create = 'create table test as select generate_series(1, 2) as val' query_select = 'select * from test order by val asc' - with get_remote_node(conn_params=conn_params).init().start() as node1: + with get_remote_node().init().start() as node1: node1.execute(query_create) for format in ['plain', 'custom', 'directory', 'tar']: with removing(node1.dump(format=format)) as dump: - with get_remote_node(conn_params=conn_params).init().start() as node3: + with get_remote_node().init().start() as node3: if format == 'directory': self.assertTrue(os_ops.isdir(dump)) else: @@ -608,13 +608,13 @@ def test_dump(self): self.assertListEqual(res, [(1,), (2,)]) def test_users(self): - with get_remote_node(conn_params=conn_params).init().start() as node: + with get_remote_node().init().start() as node: node.psql('create role test_user login') value = node.safe_psql('select 1', username='test_user') self.assertEqual(b'1\n', value) def test_poll_query_until(self): - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: node.init().start() get_time = 'select extract(epoch from now())' @@ -728,7 +728,7 @@ def test_logging(self): @unittest.skipUnless(util_exists('pgbench'), 'might be missing') def test_pgbench(self): - with get_remote_node(conn_params=conn_params).init().start() as node: + with get_remote_node().init().start() as node: # initialize pgbench DB and run benchmarks node.pgbench_init(scale=2, foreign_keys=True, options=['-q']).pgbench_run(time=2) @@ -796,7 +796,7 @@ def test_config_stack(self): self.assertEqual(TestgresConfig.cached_initdb_dir, d0) def test_unix_sockets(self): - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: node.init(unix_sockets=False, allow_streaming=True) node.start() @@ -812,7 +812,7 @@ def test_unix_sockets(self): self.assertEqual(res_psql, b'1\n') def test_auto_name(self): - with get_remote_node(conn_params=conn_params).init(allow_streaming=True).start() as m: + with get_remote_node().init(allow_streaming=True).start() as m: with m.replicate().start() as r: # check that nodes are running self.assertTrue(m.status()) @@ -849,7 +849,7 @@ def test_file_tail(self): self.assertEqual(lines[0], s3) def test_isolation_levels(self): - with get_remote_node(conn_params=conn_params).init().start() as node: + with get_remote_node().init().start() as node: with node.connect() as con: # string levels con.begin('Read Uncommitted').commit() @@ -871,7 +871,7 @@ def test_ports_management(self): # check that no ports have been bound yet self.assertEqual(len(bound_ports), 0) - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: # check that we've just bound a port self.assertEqual(len(bound_ports), 1) @@ -904,7 +904,7 @@ def test_version_management(self): self.assertTrue(d > f) version = get_pg_version() - with get_remote_node(conn_params=conn_params) as node: + with get_remote_node() as node: self.assertTrue(isinstance(version, six.string_types)) self.assertTrue(isinstance(node.version, PgVer)) self.assertEqual(node.version, PgVer(version)) @@ -922,12 +922,15 @@ def test_child_pids(self): if pg_version_ge('10'): master_processes.append(ProcessType.LogicalReplicationLauncher) + if pg_version_ge('14'): + master_processes.remove(ProcessType.StatsCollector) + repl_processes = [ ProcessType.Startup, ProcessType.WalReceiver, ] - with get_remote_node(conn_params=conn_params).init().start() as master: + with get_remote_node().init().start() as master: # master node doesn't have a source walsender! with self.assertRaises(TestgresException): From 104a127152cc4d1e4aad5013d54050294a29eec5 Mon Sep 17 00:00:00 2001 From: vshepard Date: Fri, 1 Nov 2024 23:45:34 +0100 Subject: [PATCH 2/9] Don't reserve a new port if port was set up --- testgres/node.py | 88 ++++++++++++++++++++++++++++++------------------ 1 file changed, 56 insertions(+), 32 deletions(-) diff --git a/testgres/node.py b/testgres/node.py index 78ebd87b..3028e1bc 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -96,7 +96,6 @@ from .operations.os_ops import ConnectionParams from .operations.local_ops import LocalOperations -from .operations.remote_ops import RemoteOperations InternalError = pglib.InternalError ProgrammingError = pglib.ProgrammingError @@ -487,7 +486,7 @@ def init(self, initdb_params=None, cached=True, **kwargs): os_ops=self.os_ops, params=initdb_params, bin_path=self.bin_dir, - cached=False) + cached=cached) # initialize default config files self.default_conf(**kwargs) @@ -717,9 +716,9 @@ def slow_start(self, replica=False, dbname='template1', username=None, max_attem OperationalError}, max_attempts=max_attempts) - def start(self, params=[], wait=True): + def start(self, params=None, wait: bool = True) -> 'PostgresNode': """ - Starts the PostgreSQL node using pg_ctl if node has not been started. + Starts the PostgreSQL node using pg_ctl if the node has not been started. By default, it waits for the operation to complete before returning. Optionally, it can return immediately without waiting for the start operation to complete by setting the `wait` parameter to False. @@ -731,45 +730,62 @@ def start(self, params=[], wait=True): Returns: This instance of :class:`.PostgresNode`. """ + if params is None: + params = [] if self.is_started: return self _params = [ - self._get_bin_path("pg_ctl"), - "-D", self.data_dir, - "-l", self.pg_log_file, - "-w" if wait else '-W', # --wait or --no-wait - "start" - ] + params # yapf: disable + self._get_bin_path("pg_ctl"), + "-D", self.data_dir, + "-l", self.pg_log_file, + "-w" if wait else '-W', # --wait or --no-wait + "start" + ] + params # yapf: disable - startup_retries = 5 - while True: + max_retries = 5 + sleep_interval = 5 # seconds + + for attempt in range(max_retries): try: exit_status, out, error = execute_utility(_params, self.utils_log_file, verbose=True) if error and 'does not exist' in error: raise Exception + break # Exit the loop if successful except Exception as e: - files = self._collect_special_files() - if any(len(file) > 1 and 'Is another postmaster already ' - 'running on port' in file[1].decode() for - file in files): - logging.warning("Detected an issue with connecting to port {0}. " - "Trying another port after a 5-second sleep...".format(self.port)) - self.port = reserve_port() - options = {'port': str(self.port)} - self.set_auto_conf(options) - startup_retries -= 1 - time.sleep(5) - continue - - msg = 'Cannot start node' - raise_from(StartNodeException(msg, files), e) - break + if self._handle_port_conflict(): + if attempt < max_retries - 1: + logging.info(f"Retrying start operation (Attempt {attempt + 2}/{max_retries})...") + time.sleep(sleep_interval) + continue + else: + logging.error("Reached maximum retry attempts. Unable to start node.") + raise StartNodeException("Cannot start node after multiple attempts", + self._collect_special_files()) from e + raise StartNodeException("Cannot start node", self._collect_special_files()) from e + self._maybe_start_logger() self.is_started = True return self - def stop(self, params=[], wait=True): + def _handle_port_conflict(self) -> bool: + """ + Checks for a port conflict and attempts to resolve it by changing the port. + Returns True if the port was changed, False otherwise. + """ + files = self._collect_special_files() + if any(len(file) > 1 and 'Is another postmaster already running on port' in file[1].decode() for file in files): + logging.warning(f"Port conflict detected on port {self.port}.") + if self._should_free_port: + logging.warning("Port reservation skipped due to _should_free_port setting.") + return False + self.port = reserve_port() + self.set_auto_conf({'port': str(self.port)}) + logging.info(f"Port changed to {self.port}.") + return True + return False + + def stop(self, params=None, wait=True): """ Stops the PostgreSQL node using pg_ctl if the node has been started. @@ -780,6 +796,8 @@ def stop(self, params=[], wait=True): Returns: This instance of :class:`.PostgresNode`. """ + if params is None: + params = [] if not self.is_started: return self @@ -812,7 +830,7 @@ def kill(self, someone=None): os.kill(self.auxiliary_pids[someone][0], sig) self.is_started = False - def restart(self, params=[]): + def restart(self, params=None): """ Restart this node using pg_ctl. @@ -823,6 +841,8 @@ def restart(self, params=[]): This instance of :class:`.PostgresNode`. """ + if params is None: + params = [] _params = [ self._get_bin_path("pg_ctl"), "-D", self.data_dir, @@ -844,7 +864,7 @@ def restart(self, params=[]): return self - def reload(self, params=[]): + def reload(self, params=None): """ Asynchronously reload config files using pg_ctl. @@ -855,6 +875,8 @@ def reload(self, params=[]): This instance of :class:`.PostgresNode`. """ + if params is None: + params = [] _params = [ self._get_bin_path("pg_ctl"), "-D", self.data_dir, @@ -1587,7 +1609,7 @@ def pgbench_table_checksums(self, dbname="postgres", return {(table, self.table_checksum(table, dbname)) for table in pgbench_tables} - def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): + def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options=None): """ Update or remove configuration options in the specified configuration file, updates the options specified in the options dictionary, removes any options @@ -1603,6 +1625,8 @@ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): Defaults to an empty set. """ # parse postgresql.auto.conf + if rm_options is None: + rm_options = {} path = os.path.join(self.data_dir, config) lines = self.os_ops.readlines(path) From 19ef23f130cb0c534aa4128faf80237c6c0c5666 Mon Sep 17 00:00:00 2001 From: vshepard Date: Mon, 18 Nov 2024 11:39:50 +0100 Subject: [PATCH 3/9] Fix flake8 style --- testgres/node.py | 12 +++++------- testgres/operations/local_ops.py | 3 +-- testgres/operations/remote_ops.py | 9 --------- 3 files changed, 6 insertions(+), 18 deletions(-) diff --git a/testgres/node.py b/testgres/node.py index 3028e1bc..d580e49e 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -735,13 +735,11 @@ def start(self, params=None, wait: bool = True) -> 'PostgresNode': if self.is_started: return self - _params = [ - self._get_bin_path("pg_ctl"), - "-D", self.data_dir, - "-l", self.pg_log_file, - "-w" if wait else '-W', # --wait or --no-wait - "start" - ] + params # yapf: disable + _params = [self._get_bin_path("pg_ctl"), + "-D", self.data_dir, + "-l", self.pg_log_file, + "-w" if wait else '-W', # --wait or --no-wait + "start"] + params # yapf: disable max_retries = 5 sleep_interval = 5 # seconds diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index 796e15c2..7d3a99eb 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -1,4 +1,3 @@ -import getpass import logging import os import shutil @@ -10,7 +9,7 @@ import psutil from ..exceptions import ExecUtilException -from .os_ops import ConnectionParams, OsOperations, pglib, get_default_encoding +from .os_ops import ConnectionParams, OsOperations, get_default_encoding try: from shutil import which as find_executable diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index 24a8b9fe..ed88d1e4 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -4,15 +4,6 @@ import subprocess import tempfile -# we support both pg8000 and psycopg2 -try: - import psycopg2 as pglib -except ImportError: - try: - import pg8000 as pglib - except ImportError: - raise ImportError("You must have psycopg2 or pg8000 modules installed") - from ..exceptions import ExecUtilException from .os_ops import OsOperations, ConnectionParams, get_default_encoding From 43c91c06b886830ee0d0b1970595e9121c671b5b Mon Sep 17 00:00:00 2001 From: vshepard Date: Mon, 18 Nov 2024 11:40:04 +0100 Subject: [PATCH 4/9] Fix test_the_same_port --- tests/test_remote.py | 1 - tests/test_simple.py | 9 ++++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/test_remote.py b/tests/test_remote.py index bb13108d..bab08f93 100755 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -213,4 +213,3 @@ def test_skip_ssl(self): if isinstance(res, list): res.sort() assert res == [(1,), (2,)] - diff --git a/tests/test_simple.py b/tests/test_simple.py index 8f85a23b..068dbca2 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -1039,10 +1039,9 @@ def test_parse_pg_version(self): def test_the_same_port(self): with get_new_node() as node: node.init().start() - - with get_new_node() as node2: - node2.port = node.port - node2.init().start() + with get_new_node() as node2: + node2.port = node.port + node2.init().start() def test_make_simple_with_bin_dir(self): with get_new_node() as node: @@ -1059,7 +1058,7 @@ def test_make_simple_with_bin_dir(self): wrong_bin_dir.slow_start() raise RuntimeError("Error was expected.") # We should not reach this except FileNotFoundError: - pass # Expected error + pass # Expected error if __name__ == '__main__': From 1e8d91280bd2a42462b18e69c25ba1db7b984663 Mon Sep 17 00:00:00 2001 From: vshepard Date: Mon, 18 Nov 2024 14:48:08 +0100 Subject: [PATCH 5/9] Fix failed test_ports_management --- testgres/node.py | 5 +++-- tests/test_simple.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/testgres/node.py b/testgres/node.py index d580e49e..53de3163 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -152,8 +152,6 @@ def __init__(self, name=None, base_dir=None, port=None, conn_params: ConnectionP self.os_ops = testgres_config.os_ops self.host = self.os_ops.host - self.port = port or self.os_ops.port or reserve_port() - self.ssh_key = self.os_ops.ssh_key # defaults for __exit__() @@ -161,6 +159,8 @@ def __init__(self, name=None, base_dir=None, port=None, conn_params: ConnectionP self.cleanup_on_bad_exit = testgres_config.node_cleanup_on_bad_exit self.shutdown_max_attempts = 3 + self.port = port or self.os_ops.port or reserve_port() + # NOTE: for compatibility self.utils_log_name = self.utils_log_file self.pg_log_name = self.pg_log_file @@ -810,6 +810,7 @@ def stop(self, params=None, wait=True): self._maybe_stop_logger() self.is_started = False + release_port(self.port) return self def kill(self, someone=None): diff --git a/tests/test_simple.py b/tests/test_simple.py index 068dbca2..b8c07958 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -1052,6 +1052,7 @@ def test_make_simple_with_bin_dir(self): correct_bin_dir = app.make_simple(base_dir=node.base_dir, bin_dir=bin_dir) correct_bin_dir.slow_start() correct_bin_dir.safe_psql("SELECT 1;") + correct_bin_dir.stop() try: wrong_bin_dir = app.make_empty(base_dir=node.base_dir, bin_dir="wrong/path") From f1d28b44952a4aa1cf6eff4b0e1e4761b35e89c9 Mon Sep 17 00:00:00 2001 From: vshepard Date: Mon, 18 Nov 2024 18:14:01 +0100 Subject: [PATCH 6/9] Add env variable TESTGRES_SKIP_SSL --- testgres/operations/os_ops.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py index f027bfef..5dfd7841 100644 --- a/testgres/operations/os_ops.py +++ b/testgres/operations/os_ops.py @@ -1,5 +1,6 @@ import getpass import locale +import os import sys try: @@ -12,15 +13,17 @@ class ConnectionParams: - def __init__(self, host='127.0.0.1', port=None, ssh_key=None, username=None, remote=False, skip_ssl=False): + def __init__(self, host='127.0.0.1', port=None, ssh_key=None, username=None, remote=False, skip_ssl=None): """ - skip_ssl: if is True, the connection is established without SSL. + skip_ssl: if is True, the connection to database is established without SSL. """ self.remote = remote self.host = host self.port = port self.ssh_key = ssh_key self.username = username + if skip_ssl is None: + skip_ssl = os.getenv("TESTGRES_SKIP_SSL", False) self.skip_ssl = skip_ssl From e729c2f2fecc0a63fef96c50cfcf4624e66901ef Mon Sep 17 00:00:00 2001 From: vshepard Date: Mon, 18 Nov 2024 20:00:12 +0100 Subject: [PATCH 7/9] Fix sys.modules instead globals --- testgres/operations/os_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py index 5dfd7841..1e1575ce 100644 --- a/testgres/operations/os_ops.py +++ b/testgres/operations/os_ops.py @@ -126,14 +126,14 @@ def get_process_children(self, pid): # Database control def db_connect(self, dbname, user, password=None, host="localhost", port=5432): - ssl_options = {"sslmode": "disable"} if self.conn_params.skip_ssl and 'psycopg2' in globals() else {} + ssl_options = {"sslmode": "disable"} if self.conn_params.skip_ssl and 'psycopg2' in sys.modules else {} conn = pglib.connect( host=host, port=port, database=dbname, user=user, password=password, - **({"ssl_context": None} if self.conn_params.skip_ssl and 'pg8000' in globals() else ssl_options) + **({"ssl_context": None} if self.conn_params.skip_ssl and 'pg8000' in sys.modules else ssl_options) ) return conn From dc4b4c3dbf26e1d019b7fe0395a2fb8f933b0c38 Mon Sep 17 00:00:00 2001 From: vshepard Date: Mon, 18 Nov 2024 20:16:26 +0100 Subject: [PATCH 8/9] Move _get_ssl_options in separate function --- testgres/operations/os_ops.py | 15 +++++++++++++-- tests/test_simple.py | 8 +++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py index 1e1575ce..b77762df 100644 --- a/testgres/operations/os_ops.py +++ b/testgres/operations/os_ops.py @@ -124,16 +124,27 @@ def get_pid(self): def get_process_children(self, pid): raise NotImplementedError() + def _get_ssl_options(self): + """ + Determine the SSL options based on available modules. + """ + if self.conn_params.skip_ssl: + if 'psycopg2' in sys.modules: + return {"sslmode": "disable"} + elif 'pg8000' in sys.modules: + return {"ssl_context": None} + return {} + # Database control def db_connect(self, dbname, user, password=None, host="localhost", port=5432): - ssl_options = {"sslmode": "disable"} if self.conn_params.skip_ssl and 'psycopg2' in sys.modules else {} + ssl_options = self._get_ssl_options() conn = pglib.connect( host=host, port=port, database=dbname, user=user, password=password, - **({"ssl_context": None} if self.conn_params.skip_ssl and 'pg8000' in sys.modules else ssl_options) + **ssl_options ) return conn diff --git a/tests/test_simple.py b/tests/test_simple.py index b8c07958..9cc48e7e 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -1039,9 +1039,11 @@ def test_parse_pg_version(self): def test_the_same_port(self): with get_new_node() as node: node.init().start() - with get_new_node() as node2: - node2.port = node.port - node2.init().start() + with get_new_node() as node2: + node2.port = node.port + # _should_free_port is true if port was set up manually + node2._should_free_port = False + node2.init().start() def test_make_simple_with_bin_dir(self): with get_new_node() as node: From fa6d7519018de0d94a3474d332d5f049cd7889ea Mon Sep 17 00:00:00 2001 From: "d.kovalenko" Date: Tue, 24 Dec 2024 18:21:41 +0300 Subject: [PATCH 9/9] [BUG FIX] TestgresRemoteTests.test_safe_psql__expect_error is corrected get_remote_node() must be called without any parameters. --- tests/test_simple_remote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index 8d700e52..da671a5d 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -297,7 +297,7 @@ def test_psql(self): node.safe_psql('select 1') def test_safe_psql__expect_error(self): - with get_remote_node(conn_params=conn_params).init().start() as node: + with get_remote_node().init().start() as node: err = node.safe_psql('select_or_not_select 1', expect_error=True) self.assertTrue(type(err) == str) # noqa: E721 self.assertIn('select_or_not_select', err)