Location via proxy:   [ UP ]  
[Report a bug]   [Manage cookies]                
Skip to content

Commit 3453c52

Browse files
committed
Add support for asynchronous working
1 parent 1444759 commit 3453c52

File tree

2 files changed

+335
-293
lines changed

2 files changed

+335
-293
lines changed

testgres/testgres.py

Lines changed: 100 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,10 @@
4545
from enum import Enum
4646
from distutils.version import LooseVersion
4747

48-
# Try to use psycopg2 by default. If psycopg2 isn't available then use
49-
# pg8000 which is slower but much more portable because uses only
50-
# pure-Python code
5148
try:
52-
import psycopg2 as pglib
49+
import asyncpg as pglib
5350
except ImportError:
54-
try:
55-
import pg8000 as pglib
56-
except ImportError:
57-
raise ImportError("You must have psycopg2 or pg8000 modules installed")
51+
raise ImportError("You must have asyncpg module installed")
5852

5953
# ports used by nodes
6054
bound_ports = set()
@@ -193,26 +187,34 @@ def __init__(self,
193187
password=None):
194188

195189
# Use default user if not specified
196-
username = username or default_username()
197-
190+
self.username = username or default_username()
191+
self.dbname = dbname
192+
self.host = host
193+
self.password = password
198194
self.parent_node = parent_node
195+
self.connection = None
196+
self.current_transaction = None
199197

200-
self.connection = pglib.connect(
201-
database=dbname,
202-
user=username,
203-
port=parent_node.port,
204-
host=host,
205-
password=password)
198+
async def init_connection(self):
199+
if self.connection:
200+
return
206201

207-
self.cursor = self.connection.cursor()
202+
self.connection = await pglib.connect(
203+
database=self.dbname,
204+
user=self.username,
205+
port=self.parent_node.port,
206+
host=self.host,
207+
password=self.password)
208208

209-
def __enter__(self):
209+
async def __aenter__(self):
210210
return self
211211

212-
def __exit__(self, type, value, traceback):
213-
self.close()
212+
async def __aexit__(self, type, value, traceback):
213+
await self.close()
214+
215+
async def begin(self, isolation_level=IsolationLevel.ReadCommitted):
216+
await self.init_connection()
214217

215-
def begin(self, isolation_level=IsolationLevel.ReadCommitted):
216218
# yapf: disable
217219
levels = [
218220
'read uncommitted',
@@ -245,37 +247,45 @@ def begin(self, isolation_level=IsolationLevel.ReadCommitted):
245247

246248
# Set isolation level
247249
cmd = 'SET TRANSACTION ISOLATION LEVEL {}'
248-
self.cursor.execute(cmd.format(isolation_level))
250+
self.current_transaction = self.connection.transaction()
251+
await self.current_transaction.start()
252+
await self.connection.execute(cmd.format(isolation_level))
249253

250254
return self
251255

252-
def commit(self):
253-
self.connection.commit()
256+
async def commit(self):
257+
if not self.current_transaction:
258+
raise QueryException("transaction is not started")
254259

255-
return self
256-
257-
def rollback(self):
258-
self.connection.rollback()
260+
await self.current_transaction.commit()
261+
self.current_transaction = None
259262

260-
return self
263+
async def rollback(self):
264+
if not self.current_transaction:
265+
raise QueryException("transaction is not started")
261266

262-
def execute(self, query, *args):
263-
self.cursor.execute(query, args)
267+
await self.current_transaction.rollback()
268+
self.current_transaction = None
264269

265-
try:
266-
res = self.cursor.fetchall()
267-
268-
# pg8000 might return tuples
269-
if isinstance(res, tuple):
270-
res = [tuple(t) for t in res]
270+
async def execute(self, query, *args):
271+
await self.init_connection()
272+
if self.current_transaction:
273+
return await self.connection.execute(query, *args)
274+
else:
275+
async with self.connection.transaction():
276+
return await self.connection.execute(query, *args)
271277

272-
return res
273-
except Exception:
274-
return None
278+
async def fetch(self, query, *args):
279+
await self.init_connection()
280+
if self.current_transaction:
281+
return await self.connection.fetch(query, *args)
282+
else:
283+
async with self.connection.transaction():
284+
return await self.connection.fetch(query, *args)
275285

276-
def close(self):
277-
self.cursor.close()
278-
self.connection.close()
286+
async def close(self):
287+
if self.connection:
288+
await self.connection.close()
279289

280290

281291
class NodeBackup(object):
@@ -943,7 +953,7 @@ def restore(self, dbname, filename, username=None):
943953

944954
self.psql(dbname=dbname, filename=filename, username=username)
945955

946-
def poll_query_until(self,
956+
async def poll_query_until(self,
947957
dbname,
948958
query,
949959
username=None,
@@ -973,41 +983,54 @@ def poll_query_until(self,
973983

974984
attempts = 0
975985
while max_attempts == 0 or attempts < max_attempts:
976-
try:
977-
res = self.execute(dbname=dbname,
978-
query=query,
979-
username=username,
980-
commit=True)
981-
982-
if expected is None and res is None:
983-
return # done
986+
res = await self.fetch(dbname=dbname,
987+
query=query,
988+
username=username,
989+
commit=True)
984990

985-
if res is None:
986-
raise QueryException('Query returned None')
991+
if expected is None and res is None:
992+
return # done
987993

988-
if len(res) == 0:
989-
raise QueryException('Query returned 0 rows')
994+
if res is None:
995+
raise QueryException('Query returned None')
990996

991-
if len(res[0]) == 0:
992-
raise QueryException('Query returned 0 columns')
997+
if len(res) == 0:
998+
raise QueryException('Query returned 0 rows')
993999

994-
if res[0][0]:
995-
return # done
1000+
if len(res[0]) == 0:
1001+
raise QueryException('Query returned 0 columns')
9961002

997-
except pglib.ProgrammingError as e:
998-
if raise_programming_error:
999-
raise e
1000-
1001-
except pglib.InternalError as e:
1002-
if raise_internal_error:
1003-
raise e
1003+
if res[0][0]:
1004+
return # done
10041005

10051006
time.sleep(sleep_time)
10061007
attempts += 1
10071008

10081009
raise TimeoutException('Query timeout')
10091010

1010-
def execute(self, dbname, query, username=None, commit=True):
1011+
async def execute(self, dbname, query, username=None, commit=True):
1012+
"""
1013+
Execute a query
1014+
1015+
Args:
1016+
dbname: database name to connect to.
1017+
query: query to be executed.
1018+
username: database user name.
1019+
commit: should we commit this query?
1020+
1021+
Returns:
1022+
A list of tuples representing rows.
1023+
"""
1024+
1025+
async with self.connect(dbname, username) as node_con:
1026+
if commit:
1027+
await node_con.begin()
1028+
1029+
await node_con.execute(query)
1030+
if commit:
1031+
await node_con.commit()
1032+
1033+
async def fetch(self, dbname, query, username=None, commit=True):
10111034
"""
10121035
Execute a query and return all rows as list.
10131036
@@ -1021,10 +1044,13 @@ def execute(self, dbname, query, username=None, commit=True):
10211044
A list of tuples representing rows.
10221045
"""
10231046

1024-
with self.connect(dbname, username) as node_con:
1025-
res = node_con.execute(query)
1047+
async with self.connect(dbname, username) as node_con:
1048+
if commit:
1049+
await node_con.begin()
1050+
1051+
res = await node_con.fetch(query)
10261052
if commit:
1027-
node_con.commit()
1053+
await node_con.commit()
10281054
return res
10291055

10301056
def backup(self, username=None, xlog_method=DEFAULT_XLOG_METHOD):
@@ -1059,7 +1085,7 @@ def replicate(self, name, username=None,
10591085
backup = self.backup(username=username, xlog_method=xlog_method)
10601086
return backup.spawn_replica(name, use_logging=use_logging)
10611087

1062-
def catchup(self, username=None):
1088+
async def catchup(self, username=None):
10631089
"""
10641090
Wait until async replica catches up with its master.
10651091
"""
@@ -1080,8 +1106,8 @@ def catchup(self, username=None):
10801106
raise CatchUpException("Master node is not specified")
10811107

10821108
try:
1083-
lsn = master.execute('postgres', poll_lsn)[0][0]
1084-
self.poll_query_until(dbname='postgres',
1109+
lsn = (await master.fetch('postgres', poll_lsn))[0][0]
1110+
await self.poll_query_until(dbname='postgres',
10851111
username=username,
10861112
query=wait_lsn.format(lsn),
10871113
max_attempts=0) # infinite

0 commit comments

Comments
 (0)