1
- import io
2
1
import os
3
2
import tempfile
4
- from contextlib import contextmanager
3
+ from typing import Optional
5
4
6
- from testgres .logger import log
5
+ import paramiko
6
+ from paramiko import SSHClient
7
7
8
+ from logger import log
8
9
from .os_ops import OsOperations
9
10
from .os_ops import pglib
10
11
11
- import paramiko
12
+ error_markers = [ b'error' , b'Permission denied' ]
12
13
13
14
14
15
class RemoteOperations (OsOperations ):
15
- """
16
- This class specifically supports work with Linux systems. It utilizes the SSH
17
- for making connections and performing various file and directory operations, command executions,
18
- environment setup and management, process control, and database connections.
19
- It uses the Paramiko library for SSH connections and operations.
20
-
21
- Some methods are designed to work with specific Linux shell commands, and thus may not work as expected
22
- on other non-Linux systems.
23
-
24
- Attributes:
25
- - hostname (str): The remote system's hostname. Default 'localhost'.
26
- - host (str): The remote system's IP address. Default '127.0.0.1'.
27
- - ssh_key (str): Path to the SSH private key for authentication.
28
- - username (str): Username for the remote system.
29
- - ssh (paramiko.SSHClient): SSH connection to the remote system.
30
- """
31
-
32
- def __init__ (
33
- self , hostname = "localhost" , host = "127.0.0.1" , ssh_key = None , username = None
34
- ):
16
+ def __init__ (self , hostname = "localhost" , host = "127.0.0.1" , ssh_key = None , username = None ):
35
17
super ().__init__ (username )
36
- self .hostname = hostname
37
18
self .host = host
38
19
self .ssh_key = ssh_key
39
20
self .remote = True
40
- self .ssh = self .connect ()
21
+ self .ssh = self .ssh_connect ()
41
22
self .username = username or self .get_user ()
42
23
43
24
def __del__ (self ):
44
25
if self .ssh :
45
26
self .ssh .close ()
46
27
47
- @contextmanager
48
- def ssh_connect (self ):
28
+ def ssh_connect (self ) -> Optional [SSHClient ]:
49
29
if not self .remote :
50
- yield None
30
+ return None
51
31
else :
32
+ key = self ._read_ssh_key ()
33
+ ssh = paramiko .SSHClient ()
34
+ ssh .set_missing_host_key_policy (paramiko .AutoAddPolicy ())
35
+ ssh .connect (self .host , username = self .username , pkey = key )
36
+ return ssh
37
+
38
+ def _read_ssh_key (self ):
39
+ try :
52
40
with open (self .ssh_key , "r" ) as f :
53
41
key_data = f .read ()
54
42
if "BEGIN OPENSSH PRIVATE KEY" in key_data :
55
43
key = paramiko .Ed25519Key .from_private_key_file (self .ssh_key )
56
44
else :
57
45
key = paramiko .RSAKey .from_private_key_file (self .ssh_key )
46
+ return key
47
+ except FileNotFoundError :
48
+ log .error (f"No such file or directory: '{ self .ssh_key } '" )
49
+ except Exception as e :
50
+ log .error (f"An error occurred while reading the ssh key: { e } " )
58
51
59
- with paramiko .SSHClient () as ssh :
60
- ssh .set_missing_host_key_policy (paramiko .AutoAddPolicy ())
61
- ssh .connect (self .host , username = self .username , pkey = key )
62
- yield ssh
63
-
64
- def connect (self ):
65
- with self .ssh_connect () as ssh :
66
- return ssh
67
-
68
- # Command execution
69
- def exec_command (self , cmd , wait_exit = False , verbose = False ,
70
- expect_error = False , encoding = None , shell = True , text = False ,
71
- input = None , stdout = None , stderr = None , proc = None ):
52
+ def exec_command (self , cmd : str , wait_exit = False , verbose = False , expect_error = False ,
53
+ encoding = None , shell = True , text = False , input = None , stdout = None ,
54
+ stderr = None , proc = None ):
55
+ """
56
+ Execute a command in the SSH session.
57
+ Args:
58
+ - cmd (str): The command to be executed.
59
+ """
72
60
if isinstance (cmd , list ):
73
61
cmd = " " .join (cmd )
74
- log .debug (f"os_ops.exec_command: `{ cmd } `; remote={ self .remote } " )
75
- # Source global profile file + execute command
76
62
try :
77
- cmd = f"source /etc/profile.d/custom.sh; { cmd } "
78
- with self .ssh_connect () as ssh :
79
- if input :
80
- # encode input and feed it to stdin
81
- stdin , stdout , stderr = ssh .exec_command (cmd )
82
- stdin .write (input )
83
- stdin .flush ()
84
- else :
85
- stdin , stdout , stderr = ssh .exec_command (cmd )
86
- exit_status = 0
87
- if wait_exit :
88
- exit_status = stdout .channel .recv_exit_status ()
89
- if encoding :
90
- result = stdout .read ().decode (encoding )
91
- error = stderr .read ().decode (encoding )
92
- else :
93
- # Save as binary string
94
- result = io .BytesIO (stdout .read ()).getvalue ()
95
- error = io .BytesIO (stderr .read ()).getvalue ()
96
- error_str = stderr .read ()
63
+ if input :
64
+ stdin , stdout , stderr = self .ssh .exec_command (cmd )
65
+ stdin .write (input .encode ("utf-8" ))
66
+ stdin .flush ()
67
+ else :
68
+ stdin , stdout , stderr = self .ssh .exec_command (cmd )
69
+ exit_status = 0
70
+ if wait_exit :
71
+ exit_status = stdout .channel .recv_exit_status ()
72
+
73
+ if encoding :
74
+ result = stdout .read ().decode (encoding )
75
+ error = stderr .read ().decode (encoding )
76
+ else :
77
+ result = stdout .read ()
78
+ error = stderr .read ()
97
79
98
80
if expect_error :
99
81
raise Exception (result , error )
100
- if exit_status != 0 or 'error' in error_str :
82
+
83
+ if encoding :
84
+ error_found = exit_status != 0 or any (
85
+ marker .decode (encoding ) in error for marker in error_markers )
86
+ else :
87
+ error_found = exit_status != 0 or any (
88
+ marker in error for marker in error_markers )
89
+
90
+ if error_found :
101
91
log .error (
102
92
f"Problem in executing command: `{ cmd } `\n error: { error } \n exit_code: { exit_status } "
103
93
)
94
+ if exit_status == 0 :
95
+ exit_status = 1
104
96
105
97
if verbose :
106
98
return exit_status , result , error
@@ -112,7 +104,12 @@ def exec_command(self, cmd, wait_exit=False, verbose=False,
112
104
return None
113
105
114
106
# Environment setup
115
- def environ (self , var_name ):
107
+ def environ (self , var_name : str ) -> str :
108
+ """
109
+ Get the value of an environment variable.
110
+ Args:
111
+ - var_name (str): The name of the environment variable.
112
+ """
116
113
cmd = f"echo ${ var_name } "
117
114
return self .exec_command (cmd ).strip ()
118
115
@@ -131,7 +128,8 @@ def find_executable(self, executable):
131
128
132
129
def is_executable (self , file ):
133
130
# Check if the file is executable
134
- return self .exec_command (f"test -x { file } && echo OK" ) == "OK\n "
131
+ is_exec = self .exec_command (f"test -x { file } && echo OK" )
132
+ return is_exec == b"OK\n "
135
133
136
134
def add_to_path (self , new_path ):
137
135
pathsep = self .pathsep
@@ -144,8 +142,13 @@ def add_to_path(self, new_path):
144
142
os .environ ["PATH" ] = f"{ new_path } { pathsep } { path } "
145
143
return pathsep
146
144
147
- def set_env (self , var_name , var_val ):
148
- # Check if the directory is already in PATH
145
+ def set_env (self , var_name : str , var_val : str ) -> None :
146
+ """
147
+ Set the value of an environment variable.
148
+ Args:
149
+ - var_name (str): The name of the environment variable.
150
+ - var_val (str): The value to be set for the environment variable.
151
+ """
149
152
return self .exec_command (f"export { var_name } ={ var_val } " )
150
153
151
154
# Get environment variables
@@ -158,22 +161,47 @@ def get_name(self):
158
161
159
162
# Work with dirs
160
163
def makedirs (self , path , remove_existing = False ):
164
+ """
165
+ Create a directory in the remote server.
166
+ Args:
167
+ - path (str): The path to the directory to be created.
168
+ - remove_existing (bool): If True, the existing directory at the path will be removed.
169
+ """
161
170
if remove_existing :
162
171
cmd = f"rm -rf { path } && mkdir -p { path } "
163
172
else :
164
173
cmd = f"mkdir -p { path } "
165
- return self .exec_command (cmd )
174
+ exit_status , result , error = self .exec_command (cmd , verbose = True )
175
+ if exit_status != 0 :
176
+ raise Exception (f"Couldn't create dir { path } because of error { error } " )
177
+ return result
166
178
167
- def rmdirs (self , path , ignore_errors = True ):
179
+ def rmdirs (self , path , verbose = False , ignore_errors = True ):
180
+ """
181
+ Remove a directory in the remote server.
182
+ Args:
183
+ - path (str): The path to the directory to be removed.
184
+ - verbose (bool): If True, return exit status, result, and error.
185
+ - ignore_errors (bool): If True, do not raise error if directory does not exist.
186
+ """
168
187
cmd = f"rm -rf { path } "
169
- return self .exec_command (cmd )
188
+ exit_status , result , error = self .exec_command (cmd , verbose = True )
189
+ if verbose :
190
+ return exit_status , result , error
191
+ else :
192
+ return result
170
193
171
194
def listdir (self , path ):
195
+ """
196
+ List all files and directories in a directory.
197
+ Args:
198
+ path (str): The path to the directory.
199
+ """
172
200
result = self .exec_command (f"ls { path } " )
173
201
return result .splitlines ()
174
202
175
203
def path_exists (self , path ):
176
- result = self .exec_command (f"test -e { path } ; echo $?" )
204
+ result = self .exec_command (f"test -e { path } ; echo $?" , encoding = 'utf-8' )
177
205
return int (result .strip ()) == 0
178
206
179
207
@property
@@ -188,7 +216,12 @@ def pathsep(self):
188
216
return pathsep
189
217
190
218
def mkdtemp (self , prefix = None ):
191
- temp_dir = self .exec_command (f"mkdtemp -d { prefix } " )
219
+ """
220
+ Creates a temporary directory in the remote server.
221
+ Args:
222
+ prefix (str): The prefix of the temporary directory name.
223
+ """
224
+ temp_dir = self .exec_command (f"mkdtemp -d { prefix } " , encoding = 'utf-8' )
192
225
return temp_dir .strip ()
193
226
194
227
def mkstemp (self , prefix = None ):
@@ -200,18 +233,19 @@ def copytree(self, src, dst):
200
233
return self .exec_command (f"cp -r { src } { dst } " )
201
234
202
235
# Work with files
203
- def write (self , filename , data , truncate = False , binary = False , read_and_write = False ):
236
+ def write (self , filename , data , truncate = False , binary = False , read_and_write = False , encoding = 'utf-8' ):
204
237
"""
205
238
Write data to a file on a remote host
239
+
206
240
Args:
207
- filename: The file path where the data will be written.
208
- data : The data to be written to the file.
209
- truncate: If True, the file will be truncated before writing ('w' or 'wb' option);
210
- if False (default), data will be appended ('a' or 'ab' option).
211
- binary: If True, the data will be written in binary mode ('wb' or 'ab' option);
212
- if False (default), the data will be written in text mode ('w' or 'a' option).
213
- read_and_write: If True, the file will be opened with read and write permissions ('r+' option);
214
- if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option)
241
+ - filename (str) : The file path where the data will be written.
242
+ - data (bytes or str) : The data to be written to the file.
243
+ - truncate (bool) : If True, the file will be truncated before writing ('w' or 'wb' option);
244
+ if False (default), data will be appended ('a' or 'ab' option).
245
+ - binary (bool) : If True, the data will be written in binary mode ('wb' or 'ab' option);
246
+ if False (default), the data will be written in text mode ('w' or 'a' option).
247
+ - read_and_write (bool) : If True, the file will be opened with read and write permissions ('r+' option);
248
+ if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option).
215
249
"""
216
250
mode = "wb" if binary else "w"
217
251
if not truncate :
@@ -220,15 +254,18 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
220
254
mode = "r+b" if binary else "r+"
221
255
222
256
with tempfile .NamedTemporaryFile (mode = mode ) as tmp_file :
223
- if isinstance (data , list ):
224
- tmp_file .writelines (data )
225
- else :
226
- tmp_file .write (data )
257
+ if isinstance (data , bytes ) and not binary :
258
+ data = data .decode (encoding )
259
+ elif isinstance (data , str ) and binary :
260
+ data = data .encode (encoding )
261
+
262
+ tmp_file .write (data )
227
263
tmp_file .flush ()
228
264
229
- sftp = self .ssh .open_sftp ()
230
- sftp .put (tmp_file .name , filename )
231
- sftp .close ()
265
+ with self .ssh_connect () as ssh :
266
+ sftp = ssh .open_sftp ()
267
+ sftp .put (tmp_file .name , filename )
268
+ sftp .close ()
232
269
233
270
def touch (self , filename ):
234
271
"""
@@ -281,8 +318,29 @@ def get_pid(self):
281
318
return self .exec_command ("echo $$" )
282
319
283
320
# Database control
284
- def db_connect (self , dbname , user , password = None , host = "localhost" , port = 5432 ):
285
- local_port = self .ssh .forward_remote_port (host , port )
321
+ def db_connect (self , dbname , user , password = None , host = "127.0.0.1" , hostname = "localhost" , port = 5432 ):
322
+ """
323
+ Connects to a PostgreSQL database on the remote system.
324
+ Args:
325
+ - dbname (str): The name of the database to connect to.
326
+ - user (str): The username for the database connection.
327
+ - password (str, optional): The password for the database connection. Defaults to None.
328
+ - host (str, optional): The IP address of the remote system. Defaults to "127.0.0.1".
329
+ - hostname (str, optional): The hostname of the remote system. Defaults to "localhost".
330
+ - port (int, optional): The port number of the PostgreSQL service. Defaults to 5432.
331
+
332
+ This function establishes a connection to a PostgreSQL database on the remote system using the specified
333
+ parameters. It returns a connection object that can be used to interact with the database.
334
+ """
335
+ transport = self .ssh .get_transport ()
336
+ local_port = 9090 # or any other available port
337
+
338
+ transport .open_channel (
339
+ 'direct-tcpip' ,
340
+ (hostname , port ),
341
+ (host , local_port )
342
+ )
343
+
286
344
conn = pglib .connect (
287
345
host = host ,
288
346
port = local_port ,
@@ -291,3 +349,4 @@ def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
291
349
password = password ,
292
350
)
293
351
return conn
352
+
0 commit comments