Location via proxy:   [ UP ]  
[Report a bug]   [Manage cookies]                
Skip to content

Commit bdc264e

Browse files
author
vshepard
committed
Fix initdb error on Windows
1 parent 846c05f commit bdc264e

File tree

5 files changed

+170
-37
lines changed

5 files changed

+170
-37
lines changed

testgres/operations/local_ops.py

Lines changed: 84 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
import psutil
99

1010
from ..exceptions import ExecUtilException
11-
from .os_ops import ConnectionParams, OsOperations
12-
from .os_ops import pglib
11+
from .os_ops import ConnectionParams, OsOperations, pglib, get_default_encoding
1312

1413
try:
1514
from shutil import which as find_executable
@@ -22,6 +21,12 @@
2221
error_markers = [b'error', b'Permission denied', b'fatal']
2322

2423

24+
def has_errors(output):
25+
if isinstance(output, str):
26+
output = output.encode(get_default_encoding())
27+
return any(marker in output for marker in error_markers)
28+
29+
2530
class LocalOperations(OsOperations):
2631
def __init__(self, conn_params=None):
2732
if conn_params is None:
@@ -33,7 +38,38 @@ def __init__(self, conn_params=None):
3338
self.remote = False
3439
self.username = conn_params.username or self.get_user()
3540

36-
# Command execution
41+
@staticmethod
42+
def _run_command(cmd, shell, input, timeout, encoding, temp_file=None):
43+
"""Execute a command and return the process."""
44+
if temp_file is not None:
45+
stdout = temp_file
46+
stderr = subprocess.STDOUT
47+
else:
48+
stdout = subprocess.PIPE
49+
stderr = subprocess.PIPE
50+
51+
process = subprocess.Popen(
52+
cmd,
53+
shell=shell,
54+
stdin=subprocess.PIPE if input is not None else None,
55+
stdout=stdout,
56+
stderr=stderr,
57+
)
58+
59+
try:
60+
return process.communicate(input=input.encode(encoding) if input else None, timeout=timeout), process
61+
except subprocess.TimeoutExpired:
62+
process.kill()
63+
raise ExecUtilException("Command timed out after {} seconds.".format(timeout))
64+
65+
@staticmethod
66+
def _raise_exec_exception(message, command, exit_code, output):
67+
"""Raise an ExecUtilException."""
68+
raise ExecUtilException(message=message.format(output),
69+
command=command,
70+
exit_code=exit_code,
71+
out=output)
72+
3773
def exec_command(self, cmd, wait_exit=False, verbose=False,
3874
expect_error=False, encoding=None, shell=False, text=False,
3975
input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
@@ -56,16 +92,15 @@ def exec_command(self, cmd, wait_exit=False, verbose=False,
5692
:return: The output of the subprocess.
5793
"""
5894
if os.name == 'nt':
59-
with tempfile.NamedTemporaryFile() as buf:
60-
process = subprocess.Popen(cmd, stdout=buf, stderr=subprocess.STDOUT)
61-
process.communicate()
62-
buf.seek(0)
63-
result = buf.read().decode(encoding)
64-
return result
95+
self._exec_command_windows(cmd, wait_exit=wait_exit, verbose=verbose,
96+
expect_error=expect_error, encoding=encoding, shell=shell, text=text,
97+
input=input, stdin=stdin, stdout=stdout, stderr=stderr,
98+
get_process=get_process, timeout=timeout)
6599
else:
66100
process = subprocess.Popen(
67101
cmd,
68102
shell=shell,
103+
stdin=stdin,
69104
stdout=stdout,
70105
stderr=stderr,
71106
)
@@ -79,7 +114,7 @@ def exec_command(self, cmd, wait_exit=False, verbose=False,
79114
raise ExecUtilException("Command timed out after {} seconds.".format(timeout))
80115
exit_status = process.returncode
81116

82-
error_found = exit_status != 0 or any(marker in error for marker in error_markers)
117+
error_found = exit_status != 0 or has_errors(error)
83118

84119
if encoding:
85120
result = result.decode(encoding)
@@ -91,15 +126,49 @@ def exec_command(self, cmd, wait_exit=False, verbose=False,
91126
if exit_status != 0 or error_found:
92127
if exit_status == 0:
93128
exit_status = 1
94-
raise ExecUtilException(message='Utility exited with non-zero code. Error `{}`'.format(error),
95-
command=cmd,
96-
exit_code=exit_status,
97-
out=result)
129+
self._raise_exec_exception('Utility exited with non-zero code. Error `{}`', cmd, exit_status, result)
98130
if verbose:
99131
return exit_status, result, error
100132
else:
101133
return result
102134

135+
@staticmethod
136+
def _process_output(process, encoding, temp_file=None):
137+
"""Process the output of a command."""
138+
if temp_file is not None:
139+
temp_file.seek(0)
140+
output = temp_file.read()
141+
else:
142+
output = process.stdout.read()
143+
144+
if encoding:
145+
output = output.decode(encoding)
146+
147+
return output
148+
def _exec_command_windows(self, cmd, wait_exit=False, verbose=False,
149+
expect_error=False, encoding=None, shell=False, text=False,
150+
input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
151+
get_process=None, timeout=None):
152+
with tempfile.NamedTemporaryFile(mode='w+b') as temp_file:
153+
_, process = self._run_command(cmd, shell, input, timeout, encoding, temp_file)
154+
if get_process:
155+
return process
156+
output = self._process_output(process, encoding, temp_file)
157+
158+
if process.returncode != 0 or has_errors(output):
159+
if process.returncode == 0:
160+
process.returncode = 1
161+
if expect_error:
162+
if verbose:
163+
return process.returncode, output, output
164+
else:
165+
return output
166+
else:
167+
self._raise_exec_exception('Utility exited with non-zero code. Error `{}`', cmd, process.returncode,
168+
output)
169+
170+
return (process.returncode, output, output) if verbose else output
171+
103172
# Environment setup
104173
def environ(self, var_name):
105174
return os.environ.get(var_name)
@@ -210,7 +279,7 @@ def read(self, filename, encoding=None, binary=False):
210279
if binary:
211280
return content
212281
if isinstance(content, bytes):
213-
return content.decode(encoding or 'utf-8')
282+
return content.decode(encoding or get_default_encoding())
214283
return content
215284

216285
def readlines(self, filename, num_lines=0, binary=False, encoding=None):

testgres/operations/os_ops.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import locale
2+
13
try:
24
import psycopg2 as pglib # noqa: F401
35
except ImportError:
@@ -14,6 +16,10 @@ def __init__(self, host='127.0.0.1', ssh_key=None, username=None):
1416
self.username = username
1517

1618

19+
def get_default_encoding():
20+
return locale.getdefaultlocale()[1] or 'UTF-8'
21+
22+
1723
class OsOperations:
1824
def __init__(self, username=None):
1925
self.ssh_key = None

testgres/operations/remote_ops.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import locale
21
import logging
32
import os
43
import subprocess
@@ -15,12 +14,7 @@
1514
raise ImportError("You must have psycopg2 or pg8000 modules installed")
1615

1716
from ..exceptions import ExecUtilException
18-
19-
from .os_ops import OsOperations, ConnectionParams
20-
21-
ConsoleEncoding = locale.getdefaultlocale()[1]
22-
if not ConsoleEncoding:
23-
ConsoleEncoding = 'UTF-8'
17+
from .os_ops import OsOperations, ConnectionParams, get_default_encoding
2418

2519
error_markers = [b'error', b'Permission denied', b'fatal', b'No such file or directory']
2620

@@ -283,7 +277,9 @@ def copytree(self, src, dst):
283277
return self.exec_command("cp -r {} {}".format(src, dst))
284278

285279
# Work with files
286-
def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding=ConsoleEncoding):
280+
def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding=None):
281+
if not encoding:
282+
encoding = get_default_encoding()
287283
mode = "wb" if binary else "w"
288284
if not truncate:
289285
mode = "ab" if binary else "a"

testgres/utils.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,18 @@
33
from __future__ import division
44
from __future__ import print_function
55

6+
import locale
67
import os
7-
import port_for
8+
import random
9+
import socket
10+
811
import sys
912

1013
from contextlib import contextmanager
1114
from packaging.version import Version, InvalidVersion
1215
import re
1316

17+
from port_for import PortForException
1418
from six import iteritems
1519

1620
from .exceptions import ExecUtilException
@@ -37,13 +41,49 @@ def reserve_port():
3741
"""
3842
Generate a new port and add it to 'bound_ports'.
3943
"""
40-
41-
port = port_for.select_random(exclude_ports=bound_ports)
44+
port = select_random(exclude_ports=bound_ports)
4245
bound_ports.add(port)
4346

4447
return port
4548

4649

50+
def select_random(
51+
ports=None,
52+
exclude_ports=None,
53+
) -> int:
54+
"""
55+
Return random unused port number.
56+
Standard function from port_for does not work on Windows
57+
- an error 'port_for.exceptions.PortForException: Can't select a port'
58+
We should update it.
59+
"""
60+
if ports is None:
61+
ports = set(range(1024, 65535))
62+
63+
if exclude_ports is None:
64+
exclude_ports = set()
65+
66+
ports.difference_update(set(exclude_ports))
67+
68+
sampled_ports = random.sample(tuple(ports), min(len(ports), 100))
69+
70+
for port in sampled_ports:
71+
if is_port_free(port):
72+
return port
73+
74+
raise PortForException("Can't select a port")
75+
76+
77+
def is_port_free(port: int) -> bool:
78+
"""Check if a port is free to use."""
79+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
80+
try:
81+
s.bind(("", port))
82+
return True
83+
except OSError:
84+
return False
85+
86+
4787
def release_port(port):
4888
"""
4989
Free port provided by reserve_port().
@@ -80,7 +120,8 @@ def execute_utility(args, logfile=None, verbose=False):
80120
lines = [u'\n'] + ['# ' + line for line in out.splitlines()] + [u'\n']
81121
tconf.os_ops.write(filename=logfile, data=lines)
82122
except IOError:
83-
raise ExecUtilException("Problem with writing to logfile `{}` during run command `{}`".format(logfile, args))
123+
raise ExecUtilException(
124+
"Problem with writing to logfile `{}` during run command `{}`".format(logfile, args))
84125
if verbose:
85126
return exit_status, out, error
86127
else:

tests/test_simple.py

100755100644
Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,24 @@ def good_properties(f):
7474
return True
7575

7676

77+
def rm_carriage_returns(out):
78+
"""
79+
In Windows we have additional '\r' symbols in output.
80+
Let's get rid of them.
81+
"""
82+
if os.name == 'nt':
83+
if isinstance(out, (int, float, complex)):
84+
return out
85+
elif isinstance(out, tuple):
86+
return tuple(rm_carriage_returns(item) for item in out)
87+
elif isinstance(out, bytes):
88+
return out.replace(b'\r', b'')
89+
else:
90+
return out.replace('\r', '')
91+
else:
92+
return out
93+
94+
7795
@contextmanager
7896
def removing(f):
7997
try:
@@ -254,34 +272,34 @@ def test_psql(self):
254272

255273
# check returned values (1 arg)
256274
res = node.psql('select 1')
257-
self.assertEqual(res, (0, b'1\n', b''))
275+
self.assertEqual(rm_carriage_returns(res), (0, b'1\n', b''))
258276

259277
# check returned values (2 args)
260278
res = node.psql('postgres', 'select 2')
261-
self.assertEqual(res, (0, b'2\n', b''))
279+
self.assertEqual(rm_carriage_returns(res), (0, b'2\n', b''))
262280

263281
# check returned values (named)
264282
res = node.psql(query='select 3', dbname='postgres')
265-
self.assertEqual(res, (0, b'3\n', b''))
283+
self.assertEqual(rm_carriage_returns(res), (0, b'3\n', b''))
266284

267285
# check returned values (1 arg)
268286
res = node.safe_psql('select 4')
269-
self.assertEqual(res, b'4\n')
287+
self.assertEqual(rm_carriage_returns(res), b'4\n')
270288

271289
# check returned values (2 args)
272290
res = node.safe_psql('postgres', 'select 5')
273-
self.assertEqual(res, b'5\n')
291+
self.assertEqual(rm_carriage_returns(res), b'5\n')
274292

275293
# check returned values (named)
276294
res = node.safe_psql(query='select 6', dbname='postgres')
277-
self.assertEqual(res, b'6\n')
295+
self.assertEqual(rm_carriage_returns(res), b'6\n')
278296

279297
# check feeding input
280298
node.safe_psql('create table horns (w int)')
281299
node.safe_psql('copy horns from stdin (format csv)',
282300
input=b"1\n2\n3\n\\.\n")
283301
_sum = node.safe_psql('select sum(w) from horns')
284-
self.assertEqual(_sum, b'6\n')
302+
self.assertEqual(rm_carriage_returns(_sum), b'6\n')
285303

286304
# check psql's default args, fails
287305
with self.assertRaises(QueryException):
@@ -455,7 +473,7 @@ def test_synchronous_replication(self):
455473
master.safe_psql(
456474
'insert into abc select generate_series(1, 1000000)')
457475
res = standby1.safe_psql('select count(*) from abc')
458-
self.assertEqual(res, b'1000000\n')
476+
self.assertEqual(rm_carriage_returns(res), b'1000000\n')
459477

460478
@unittest.skipUnless(pg_version_ge('10'), 'requires 10+')
461479
def test_logical_replication(self):
@@ -589,7 +607,7 @@ def test_promotion(self):
589607
# make standby becomes writable master
590608
replica.safe_psql('insert into abc values (1)')
591609
res = replica.safe_psql('select * from abc')
592-
self.assertEqual(res, b'1\n')
610+
self.assertEqual(rm_carriage_returns(res), b'1\n')
593611

594612
def test_dump(self):
595613
query_create = 'create table test as select generate_series(1, 2) as val'
@@ -614,6 +632,7 @@ def test_users(self):
614632
with get_new_node().init().start() as node:
615633
node.psql('create role test_user login')
616634
value = node.safe_psql('select 1', username='test_user')
635+
value = rm_carriage_returns(value)
617636
self.assertEqual(value, b'1\n')
618637

619638
def test_poll_query_until(self):
@@ -977,7 +996,9 @@ def test_child_pids(self):
977996

978997
def test_child_process_dies(self):
979998
# test for FileNotFound exception during child_processes() function
980-
with subprocess.Popen(["sleep", "60"]) as process:
999+
cmd = ["timeout", "60"] if os.name == 'nt' else ["sleep", "60"]
1000+
1001+
with subprocess.Popen(cmd, shell=True) as process: # shell=True might be needed on Windows
9811002
self.assertEqual(process.poll(), None)
9821003
# collect list of processes currently running
9831004
children = psutil.Process(os.getpid()).children()

0 commit comments

Comments
 (0)