Location via proxy:   [ UP ]  
[Report a bug]   [Manage cookies]                
Skip to content

Commit 1c405ef

Browse files
author
v.shepard
committed
PBCKP-152 add tests for remote_ops.py
1 parent e098b97 commit 1c405ef

File tree

3 files changed

+313
-146
lines changed

3 files changed

+313
-146
lines changed

testgres/node.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,10 @@ def __init__(self, name=None, port=None, base_dir=None,
152152
self.host = host
153153
self.hostname = hostname
154154
self.ssh_key = ssh_key
155-
if hostname == 'localhost' or host == '127.0.0.1':
156-
self.os_ops = LocalOperations(username=username)
157-
else:
155+
if hostname != 'localhost' or host != '127.0.0.1':
158156
self.os_ops = RemoteOperations(host, hostname, ssh_key)
157+
else:
158+
self.os_ops = LocalOperations(username=username)
159159

160160
testgres_config.os_ops = self.os_ops
161161
# defaults for __exit__()

testgres/operations/remote_ops.py

Lines changed: 151 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,106 +1,98 @@
1-
import io
21
import os
32
import tempfile
4-
from contextlib import contextmanager
3+
from typing import Optional
54

6-
from testgres.logger import log
5+
import paramiko
6+
from paramiko import SSHClient
77

8+
from logger import log
89
from .os_ops import OsOperations
910
from .os_ops import pglib
1011

11-
import paramiko
12+
error_markers = [b'error', b'Permission denied']
1213

1314

1415
class RemoteOperations(OsOperations):
15-
"""
16-
This class specifically supports work with Linux systems. It utilizes the SSH
17-
for making connections and performing various file and directory operations, command executions,
18-
environment setup and management, process control, and database connections.
19-
It uses the Paramiko library for SSH connections and operations.
20-
21-
Some methods are designed to work with specific Linux shell commands, and thus may not work as expected
22-
on other non-Linux systems.
23-
24-
Attributes:
25-
- hostname (str): The remote system's hostname. Default 'localhost'.
26-
- host (str): The remote system's IP address. Default '127.0.0.1'.
27-
- ssh_key (str): Path to the SSH private key for authentication.
28-
- username (str): Username for the remote system.
29-
- ssh (paramiko.SSHClient): SSH connection to the remote system.
30-
"""
31-
32-
def __init__(
33-
self, hostname="localhost", host="127.0.0.1", ssh_key=None, username=None
34-
):
16+
def __init__(self, hostname="localhost", host="127.0.0.1", ssh_key=None, username=None):
3517
super().__init__(username)
36-
self.hostname = hostname
3718
self.host = host
3819
self.ssh_key = ssh_key
3920
self.remote = True
40-
self.ssh = self.connect()
21+
self.ssh = self.ssh_connect()
4122
self.username = username or self.get_user()
4223

4324
def __del__(self):
4425
if self.ssh:
4526
self.ssh.close()
4627

47-
@contextmanager
48-
def ssh_connect(self):
28+
def ssh_connect(self) -> Optional[SSHClient]:
4929
if not self.remote:
50-
yield None
30+
return None
5131
else:
32+
key = self._read_ssh_key()
33+
ssh = paramiko.SSHClient()
34+
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
35+
ssh.connect(self.host, username=self.username, pkey=key)
36+
return ssh
37+
38+
def _read_ssh_key(self):
39+
try:
5240
with open(self.ssh_key, "r") as f:
5341
key_data = f.read()
5442
if "BEGIN OPENSSH PRIVATE KEY" in key_data:
5543
key = paramiko.Ed25519Key.from_private_key_file(self.ssh_key)
5644
else:
5745
key = paramiko.RSAKey.from_private_key_file(self.ssh_key)
46+
return key
47+
except FileNotFoundError:
48+
log.error(f"No such file or directory: '{self.ssh_key}'")
49+
except Exception as e:
50+
log.error(f"An error occurred while reading the ssh key: {e}")
5851

59-
with paramiko.SSHClient() as ssh:
60-
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
61-
ssh.connect(self.host, username=self.username, pkey=key)
62-
yield ssh
63-
64-
def connect(self):
65-
with self.ssh_connect() as ssh:
66-
return ssh
67-
68-
# Command execution
69-
def exec_command(self, cmd, wait_exit=False, verbose=False,
70-
expect_error=False, encoding=None, shell=True, text=False,
71-
input=None, stdout=None, stderr=None, proc=None):
52+
def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False,
53+
encoding=None, shell=True, text=False, input=None, stdout=None,
54+
stderr=None, proc=None):
55+
"""
56+
Execute a command in the SSH session.
57+
Args:
58+
- cmd (str): The command to be executed.
59+
"""
7260
if isinstance(cmd, list):
7361
cmd = " ".join(cmd)
74-
log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}")
75-
# Source global profile file + execute command
7662
try:
77-
cmd = f"source /etc/profile.d/custom.sh; {cmd}"
78-
with self.ssh_connect() as ssh:
79-
if input:
80-
# encode input and feed it to stdin
81-
stdin, stdout, stderr = ssh.exec_command(cmd)
82-
stdin.write(input)
83-
stdin.flush()
84-
else:
85-
stdin, stdout, stderr = ssh.exec_command(cmd)
86-
exit_status = 0
87-
if wait_exit:
88-
exit_status = stdout.channel.recv_exit_status()
89-
if encoding:
90-
result = stdout.read().decode(encoding)
91-
error = stderr.read().decode(encoding)
92-
else:
93-
# Save as binary string
94-
result = io.BytesIO(stdout.read()).getvalue()
95-
error = io.BytesIO(stderr.read()).getvalue()
96-
error_str = stderr.read()
63+
if input:
64+
stdin, stdout, stderr = self.ssh.exec_command(cmd)
65+
stdin.write(input.encode("utf-8"))
66+
stdin.flush()
67+
else:
68+
stdin, stdout, stderr = self.ssh.exec_command(cmd)
69+
exit_status = 0
70+
if wait_exit:
71+
exit_status = stdout.channel.recv_exit_status()
72+
73+
if encoding:
74+
result = stdout.read().decode(encoding)
75+
error = stderr.read().decode(encoding)
76+
else:
77+
result = stdout.read()
78+
error = stderr.read()
9779

9880
if expect_error:
9981
raise Exception(result, error)
100-
if exit_status != 0 or 'error' in error_str:
82+
83+
if encoding:
84+
error_found = exit_status != 0 or any(
85+
marker.decode(encoding) in error for marker in error_markers)
86+
else:
87+
error_found = exit_status != 0 or any(
88+
marker in error for marker in error_markers)
89+
90+
if error_found:
10191
log.error(
10292
f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}"
10393
)
94+
if exit_status == 0:
95+
exit_status = 1
10496

10597
if verbose:
10698
return exit_status, result, error
@@ -112,7 +104,12 @@ def exec_command(self, cmd, wait_exit=False, verbose=False,
112104
return None
113105

114106
# Environment setup
115-
def environ(self, var_name):
107+
def environ(self, var_name: str) -> str:
108+
"""
109+
Get the value of an environment variable.
110+
Args:
111+
- var_name (str): The name of the environment variable.
112+
"""
116113
cmd = f"echo ${var_name}"
117114
return self.exec_command(cmd).strip()
118115

@@ -131,7 +128,8 @@ def find_executable(self, executable):
131128

132129
def is_executable(self, file):
133130
# Check if the file is executable
134-
return self.exec_command(f"test -x {file} && echo OK") == "OK\n"
131+
is_exec = self.exec_command(f"test -x {file} && echo OK")
132+
return is_exec == b"OK\n"
135133

136134
def add_to_path(self, new_path):
137135
pathsep = self.pathsep
@@ -144,8 +142,13 @@ def add_to_path(self, new_path):
144142
os.environ["PATH"] = f"{new_path}{pathsep}{path}"
145143
return pathsep
146144

147-
def set_env(self, var_name, var_val):
148-
# Check if the directory is already in PATH
145+
def set_env(self, var_name: str, var_val: str) -> None:
146+
"""
147+
Set the value of an environment variable.
148+
Args:
149+
- var_name (str): The name of the environment variable.
150+
- var_val (str): The value to be set for the environment variable.
151+
"""
149152
return self.exec_command(f"export {var_name}={var_val}")
150153

151154
# Get environment variables
@@ -158,22 +161,47 @@ def get_name(self):
158161

159162
# Work with dirs
160163
def makedirs(self, path, remove_existing=False):
164+
"""
165+
Create a directory in the remote server.
166+
Args:
167+
- path (str): The path to the directory to be created.
168+
- remove_existing (bool): If True, the existing directory at the path will be removed.
169+
"""
161170
if remove_existing:
162171
cmd = f"rm -rf {path} && mkdir -p {path}"
163172
else:
164173
cmd = f"mkdir -p {path}"
165-
return self.exec_command(cmd)
174+
exit_status, result, error = self.exec_command(cmd, verbose=True)
175+
if exit_status != 0:
176+
raise Exception(f"Couldn't create dir {path} because of error {error}")
177+
return result
166178

167-
def rmdirs(self, path, ignore_errors=True):
179+
def rmdirs(self, path, verbose=False, ignore_errors=True):
180+
"""
181+
Remove a directory in the remote server.
182+
Args:
183+
- path (str): The path to the directory to be removed.
184+
- verbose (bool): If True, return exit status, result, and error.
185+
- ignore_errors (bool): If True, do not raise error if directory does not exist.
186+
"""
168187
cmd = f"rm -rf {path}"
169-
return self.exec_command(cmd)
188+
exit_status, result, error = self.exec_command(cmd, verbose=True)
189+
if verbose:
190+
return exit_status, result, error
191+
else:
192+
return result
170193

171194
def listdir(self, path):
195+
"""
196+
List all files and directories in a directory.
197+
Args:
198+
path (str): The path to the directory.
199+
"""
172200
result = self.exec_command(f"ls {path}")
173201
return result.splitlines()
174202

175203
def path_exists(self, path):
176-
result = self.exec_command(f"test -e {path}; echo $?")
204+
result = self.exec_command(f"test -e {path}; echo $?", encoding='utf-8')
177205
return int(result.strip()) == 0
178206

179207
@property
@@ -188,7 +216,12 @@ def pathsep(self):
188216
return pathsep
189217

190218
def mkdtemp(self, prefix=None):
191-
temp_dir = self.exec_command(f"mkdtemp -d {prefix}")
219+
"""
220+
Creates a temporary directory in the remote server.
221+
Args:
222+
prefix (str): The prefix of the temporary directory name.
223+
"""
224+
temp_dir = self.exec_command(f"mkdtemp -d {prefix}", encoding='utf-8')
192225
return temp_dir.strip()
193226

194227
def mkstemp(self, prefix=None):
@@ -200,18 +233,19 @@ def copytree(self, src, dst):
200233
return self.exec_command(f"cp -r {src} {dst}")
201234

202235
# Work with files
203-
def write(self, filename, data, truncate=False, binary=False, read_and_write=False):
236+
def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding='utf-8'):
204237
"""
205238
Write data to a file on a remote host
239+
206240
Args:
207-
filename: The file path where the data will be written.
208-
data: The data to be written to the file.
209-
truncate: If True, the file will be truncated before writing ('w' or 'wb' option);
210-
if False (default), data will be appended ('a' or 'ab' option).
211-
binary: If True, the data will be written in binary mode ('wb' or 'ab' option);
212-
if False (default), the data will be written in text mode ('w' or 'a' option).
213-
read_and_write: If True, the file will be opened with read and write permissions ('r+' option);
214-
if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option)
241+
- filename (str): The file path where the data will be written.
242+
- data (bytes or str): The data to be written to the file.
243+
- truncate (bool): If True, the file will be truncated before writing ('w' or 'wb' option);
244+
if False (default), data will be appended ('a' or 'ab' option).
245+
- binary (bool): If True, the data will be written in binary mode ('wb' or 'ab' option);
246+
if False (default), the data will be written in text mode ('w' or 'a' option).
247+
- read_and_write (bool): If True, the file will be opened with read and write permissions ('r+' option);
248+
if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option).
215249
"""
216250
mode = "wb" if binary else "w"
217251
if not truncate:
@@ -220,15 +254,18 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
220254
mode = "r+b" if binary else "r+"
221255

222256
with tempfile.NamedTemporaryFile(mode=mode) as tmp_file:
223-
if isinstance(data, list):
224-
tmp_file.writelines(data)
225-
else:
226-
tmp_file.write(data)
257+
if isinstance(data, bytes) and not binary:
258+
data = data.decode(encoding)
259+
elif isinstance(data, str) and binary:
260+
data = data.encode(encoding)
261+
262+
tmp_file.write(data)
227263
tmp_file.flush()
228264

229-
sftp = self.ssh.open_sftp()
230-
sftp.put(tmp_file.name, filename)
231-
sftp.close()
265+
with self.ssh_connect() as ssh:
266+
sftp = ssh.open_sftp()
267+
sftp.put(tmp_file.name, filename)
268+
sftp.close()
232269

233270
def touch(self, filename):
234271
"""
@@ -281,8 +318,29 @@ def get_pid(self):
281318
return self.exec_command("echo $$")
282319

283320
# Database control
284-
def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
285-
local_port = self.ssh.forward_remote_port(host, port)
321+
def db_connect(self, dbname, user, password=None, host="127.0.0.1", hostname="localhost", port=5432):
322+
"""
323+
Connects to a PostgreSQL database on the remote system.
324+
Args:
325+
- dbname (str): The name of the database to connect to.
326+
- user (str): The username for the database connection.
327+
- password (str, optional): The password for the database connection. Defaults to None.
328+
- host (str, optional): The IP address of the remote system. Defaults to "127.0.0.1".
329+
- hostname (str, optional): The hostname of the remote system. Defaults to "localhost".
330+
- port (int, optional): The port number of the PostgreSQL service. Defaults to 5432.
331+
332+
This function establishes a connection to a PostgreSQL database on the remote system using the specified
333+
parameters. It returns a connection object that can be used to interact with the database.
334+
"""
335+
transport = self.ssh.get_transport()
336+
local_port = 9090 # or any other available port
337+
338+
transport.open_channel(
339+
'direct-tcpip',
340+
(hostname, port),
341+
(host, local_port)
342+
)
343+
286344
conn = pglib.connect(
287345
host=host,
288346
port=local_port,
@@ -291,3 +349,4 @@ def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
291349
password=password,
292350
)
293351
return conn
352+

0 commit comments

Comments
 (0)