diff --git a/setup.py b/setup.py index a5dc600e..5c6f4a07 100755 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ readme = f.read() setup( - version='1.8.5', + version='1.8.6', name='testgres', packages=['testgres'], description='Testing utility for PostgreSQL and its extensions', diff --git a/testgres/__init__.py b/testgres/__init__.py index 9d5e37cf..1b33ba3b 100644 --- a/testgres/__init__.py +++ b/testgres/__init__.py @@ -32,7 +32,7 @@ ProcessType, \ DumpFormat -from .node import PostgresNode +from .node import PostgresNode, NodeApp from .utils import \ reserve_port, \ @@ -53,7 +53,7 @@ "NodeConnection", "DatabaseError", "InternalError", "ProgrammingError", "OperationalError", "TestgresException", "ExecUtilException", "QueryException", "TimeoutException", "CatchUpException", "StartNodeException", "InitNodeException", "BackupException", "XLogMethod", "IsolationLevel", "NodeStatus", "ProcessType", "DumpFormat", - "PostgresNode", + "PostgresNode", "NodeApp", "reserve_port", "release_port", "bound_ports", "get_bin_path", "get_pg_config", "get_pg_version", "First", "Any", ] diff --git a/testgres/node.py b/testgres/node.py index 378e6803..e6ac44b0 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -2,15 +2,31 @@ import io import os +import random +import shutil +import signal +import threading +from queue import Queue + import psutil import subprocess import time + try: from collections.abc import Iterable except ImportError: from collections import Iterable +# we support both pg8000 and psycopg2 +try: + import psycopg2 as pglib +except ImportError: + try: + import pg8000 as pglib + except ImportError: + raise ImportError("You must have psycopg2 or pg8000 modules installed") + from shutil import rmtree from six import raise_from, iteritems, text_type from tempfile import mkstemp, mkdtemp @@ -86,6 +102,10 @@ from .backup import NodeBackup +InternalError = pglib.InternalError +ProgrammingError = pglib.ProgrammingError +OperationalError = pglib.OperationalError + class ProcessProxy(object): """ @@ -95,6 +115,7 @@ class ProcessProxy(object): process: wrapped psutill.Process object ptype: instance of ProcessType """ + def __init__(self, process, ptype=None): self.process = process self.ptype = ptype or ProcessType.from_process(process) @@ -140,6 +161,9 @@ def __init__(self, name=None, port=None, base_dir=None): self.utils_log_name = self.utils_log_file self.pg_log_name = self.pg_log_file + # Node state + self.is_started = False + def __enter__(self): return self @@ -629,9 +653,39 @@ def get_control_data(self): return out_dict + def slow_start(self, replica=False, dbname='template1', username=default_username()): + """ + Starts the PostgreSQL instance and then polls the instance + until it reaches the expected state (primary or replica). The state is checked + using the pg_is_in_recovery() function. + + Args: + dbname: + username: + replica: If True, waits for the instance to be in recovery (i.e., replica mode). + If False, waits for the instance to be in primary mode. Default is False. + """ + self.start() + + if replica: + query = 'SELECT pg_is_in_recovery()' + else: + query = 'SELECT not pg_is_in_recovery()' + # Call poll_query_until until the expected value is returned + self.poll_query_until(query=query, + dbname=dbname, + username=username, + suppress={InternalError, + QueryException, + ProgrammingError, + OperationalError}) + def start(self, params=[], wait=True): """ - Start this node using pg_ctl. + Starts the PostgreSQL node using pg_ctl if node has not been started. + By default, it waits for the operation to complete before returning. + Optionally, it can return immediately without waiting for the start operation + to complete by setting the `wait` parameter to False. Args: params: additional arguments for pg_ctl. @@ -640,6 +694,8 @@ def start(self, params=[], wait=True): Returns: This instance of :class:`.PostgresNode`. """ + if self.is_started: + return self _params = [ get_bin_path("pg_ctl"), @@ -657,20 +713,22 @@ def start(self, params=[], wait=True): raise_from(StartNodeException(msg, files), e) self._maybe_start_logger() - + self.is_started = True return self def stop(self, params=[], wait=True): """ - Stop this node using pg_ctl. + Stops the PostgreSQL node using pg_ctl if the node has been started. Args: - params: additional arguments for pg_ctl. - wait: wait until operation completes. + params: A list of additional arguments for pg_ctl. Defaults to None. + wait: If True, waits until the operation is complete. Defaults to True. Returns: This instance of :class:`.PostgresNode`. """ + if not self.is_started: + return self _params = [ get_bin_path("pg_ctl"), @@ -682,9 +740,25 @@ def stop(self, params=[], wait=True): execute_utility(_params, self.utils_log_file) self._maybe_stop_logger() - + self.is_started = False return self + def kill(self, someone=None): + """ + Kills the PostgreSQL node or a specified auxiliary process if the node is running. + + Args: + someone: A key to the auxiliary process in the auxiliary_pids dictionary. + If None, the main PostgreSQL node process will be killed. Defaults to None. + """ + if self.is_started: + sig = signal.SIGKILL if os.name != 'nt' else signal.SIGBREAK + if someone is None: + os.kill(self.pid, sig) + else: + os.kill(self.auxiliary_pids[someone][0], sig) + self.is_started = False + def restart(self, params=[]): """ Restart this node using pg_ctl. @@ -894,7 +968,7 @@ def psql(self, return process.returncode, out, err @method_decorator(positional_args_hack(['dbname', 'query'])) - def safe_psql(self, query=None, **kwargs): + def safe_psql(self, query=None, expect_error=False, **kwargs): """ Execute a query using psql. @@ -904,6 +978,8 @@ def safe_psql(self, query=None, **kwargs): dbname: database name to connect to. username: database user name. input: raw input to be passed. + expect_error: if True - fail if we didn't get ret + if False - fail if we got ret **kwargs are passed to psql(). @@ -916,7 +992,12 @@ def safe_psql(self, query=None, **kwargs): ret, out, err = self.psql(query=query, **kwargs) if ret: - raise QueryException((err or b'').decode('utf-8'), query) + if expect_error: + out = (err or b'').decode('utf-8') + else: + raise QueryException((err or b'').decode('utf-8'), query) + elif expect_error: + assert False, f"Exception was expected, but query finished successfully: `{query}` " return out @@ -1359,3 +1440,211 @@ def connect(self, username=username, password=password, autocommit=autocommit) # yapf: disable + + def table_checksum(self, table, dbname="postgres"): + con = self.connect(dbname=dbname) + + curname = "cur_" + str(random.randint(0, 2 ** 48)) + + con.execute(""" + DECLARE %s NO SCROLL CURSOR FOR + SELECT t::text FROM %s as t + """ % (curname, table)) + + que = Queue(maxsize=50) + sum = 0 + + rows = con.execute("FETCH FORWARD 2000 FROM %s" % curname) + if not rows: + return 0 + que.put(rows) + + th = None + if len(rows) == 2000: + def querier(): + try: + while True: + rows = con.execute("FETCH FORWARD 2000 FROM %s" % curname) + if not rows: + break + que.put(rows) + except Exception as e: + que.put(e) + else: + que.put(None) + + th = threading.Thread(target=querier) + th.start() + else: + que.put(None) + + while True: + rows = que.get() + if rows is None: + break + if isinstance(rows, Exception): + raise rows + # hash uses SipHash since Python3.4, therefore it is good enough + for row in rows: + sum += hash(row[0]) + + if th is not None: + th.join() + + con.execute("CLOSE %s; ROLLBACK;" % curname) + + con.close() + return sum + + def pgbench_table_checksums(self, dbname="postgres", + pgbench_tables=('pgbench_branches', + 'pgbench_tellers', + 'pgbench_accounts', + 'pgbench_history') + ): + return {(table, self.table_checksum(table, dbname)) + for table in pgbench_tables} + + def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): + """ + Update or remove configuration options in the specified configuration file, + updates the options specified in the options dictionary, removes any options + specified in the rm_options set, and writes the updated configuration back to + the file. + + Args: + options (dict): A dictionary containing the options to update or add, + with the option names as keys and their values as values. + config (str, optional): The name of the configuration file to update. + Defaults to 'postgresql.auto.conf'. + rm_options (set, optional): A set containing the names of the options to remove. + Defaults to an empty set. + """ + # parse postgresql.auto.conf + path = os.path.join(self.data_dir, config) + + with open(path, 'r') as f: + raw_content = f.read() + + current_options = {} + current_directives = [] + for line in raw_content.splitlines(): + + # ignore comments + if line.startswith('#'): + continue + + if line == '': + continue + + if line.startswith('include'): + current_directives.append(line) + continue + + name, var = line.partition('=')[::2] + name = name.strip() + var = var.strip() + var = var.strip('"') + var = var.strip("'") + + # remove options specified in rm_options list + if name in rm_options: + continue + + current_options[name] = var + + for option in options: + current_options[option] = options[option] + + auto_conf = '' + for option in current_options: + auto_conf += "{0} = '{1}'\n".format( + option, current_options[option]) + + for directive in current_directives: + auto_conf += directive + "\n" + + with open(path, 'wt') as f: + f.write(auto_conf) + + +class NodeApp: + + def __init__(self, test_path, nodes_to_cleanup): + self.test_path = test_path + self.nodes_to_cleanup = nodes_to_cleanup + + def make_empty( + self, + base_dir=None): + real_base_dir = os.path.join(self.test_path, base_dir) + shutil.rmtree(real_base_dir, ignore_errors=True) + os.makedirs(real_base_dir) + + node = PostgresNode(base_dir=real_base_dir) + node.should_rm_dirs = True + self.nodes_to_cleanup.append(node) + + return node + + def make_simple( + self, + base_dir=None, + set_replication=False, + ptrack_enable=False, + initdb_params=[], + pg_options={}): + + node = self.make_empty(base_dir) + node.init( + initdb_params=initdb_params, allow_streaming=set_replication) + + # set major version + with open(os.path.join(node.data_dir, 'PG_VERSION')) as f: + node.major_version_str = str(f.read().rstrip()) + node.major_version = float(node.major_version_str) + + # Sane default parameters + options = {} + options['max_connections'] = 100 + options['shared_buffers'] = '10MB' + options['fsync'] = 'off' + + options['wal_level'] = 'logical' + options['hot_standby'] = 'off' + + options['log_line_prefix'] = '%t [%p]: [%l-1] ' + options['log_statement'] = 'none' + options['log_duration'] = 'on' + options['log_min_duration_statement'] = 0 + options['log_connections'] = 'on' + options['log_disconnections'] = 'on' + options['restart_after_crash'] = 'off' + options['autovacuum'] = 'off' + + # Allow replication in pg_hba.conf + if set_replication: + options['max_wal_senders'] = 10 + + if ptrack_enable: + options['ptrack.map_size'] = '1' + options['shared_preload_libraries'] = 'ptrack' + + if node.major_version >= 13: + options['wal_keep_size'] = '200MB' + else: + options['wal_keep_segments'] = '12' + + # set default values + node.set_auto_conf(options) + + # Apply given parameters + node.set_auto_conf(pg_options) + + # kludge for testgres + # https://github.com/postgrespro/testgres/issues/54 + # for PG >= 13 remove 'wal_keep_segments' parameter + if node.major_version >= 13: + node.set_auto_conf({}, 'postgresql.conf', ['wal_keep_segments']) + + return node diff --git a/tests/test_simple.py b/tests/test_simple.py index d79fa79a..94420b04 100755 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -171,8 +171,8 @@ def test_node_exit(self): def test_double_start(self): with get_new_node().init().start() as node: # can't start node more than once - with self.assertRaises(StartNodeException): - node.start() + node.start() + self.assertTrue(node.is_started) def test_uninitialized_start(self): with get_new_node() as node: