45
45
from enum import Enum
46
46
from distutils .version import LooseVersion
47
47
48
- # Try to use psycopg2 by default. If psycopg2 isn't available then use
49
- # pg8000 which is slower but much more portable because uses only
50
- # pure-Python code
51
48
try :
52
- import psycopg2 as pglib
49
+ import asyncpg as pglib
53
50
except ImportError :
54
- try :
55
- import pg8000 as pglib
56
- except ImportError :
57
- raise ImportError ("You must have psycopg2 or pg8000 modules installed" )
51
+ raise ImportError ("You must have asyncpg module installed" )
58
52
59
53
# ports used by nodes
60
54
bound_ports = set ()
@@ -193,26 +187,34 @@ def __init__(self,
193
187
password = None ):
194
188
195
189
# Use default user if not specified
196
- username = username or default_username ()
197
-
190
+ self .username = username or default_username ()
191
+ self .dbname = dbname
192
+ self .host = host
193
+ self .password = password
198
194
self .parent_node = parent_node
195
+ self .connection = None
196
+ self .current_transaction = None
199
197
200
- self .connection = pglib .connect (
201
- database = dbname ,
202
- user = username ,
203
- port = parent_node .port ,
204
- host = host ,
205
- password = password )
198
+ async def init_connection (self ):
199
+ if self .connection :
200
+ return
206
201
207
- self .cursor = self .connection .cursor ()
202
+ self .connection = await pglib .connect (
203
+ database = self .dbname ,
204
+ user = self .username ,
205
+ port = self .parent_node .port ,
206
+ host = self .host ,
207
+ password = self .password )
208
208
209
- def __enter__ (self ):
209
+ async def __aenter__ (self ):
210
210
return self
211
211
212
- def __exit__ (self , type , value , traceback ):
213
- self .close ()
212
+ async def __aexit__ (self , type , value , traceback ):
213
+ await self .close ()
214
+
215
+ async def begin (self , isolation_level = IsolationLevel .ReadCommitted ):
216
+ await self .init_connection ()
214
217
215
- def begin (self , isolation_level = IsolationLevel .ReadCommitted ):
216
218
# yapf: disable
217
219
levels = [
218
220
'read uncommitted' ,
@@ -245,37 +247,45 @@ def begin(self, isolation_level=IsolationLevel.ReadCommitted):
245
247
246
248
# Set isolation level
247
249
cmd = 'SET TRANSACTION ISOLATION LEVEL {}'
248
- self .cursor .execute (cmd .format (isolation_level ))
250
+ self .current_transaction = self .connection .transaction ()
251
+ await self .current_transaction .start ()
252
+ await self .connection .execute (cmd .format (isolation_level ))
249
253
250
254
return self
251
255
252
- def commit (self ):
253
- self .connection .commit ()
256
+ async def commit (self ):
257
+ if not self .current_transaction :
258
+ raise QueryException ("transaction is not started" )
254
259
255
- return self
256
-
257
- def rollback (self ):
258
- self .connection .rollback ()
260
+ await self .current_transaction .commit ()
261
+ self .current_transaction = None
259
262
260
- return self
263
+ async def rollback (self ):
264
+ if not self .current_transaction :
265
+ raise QueryException ("transaction is not started" )
261
266
262
- def execute ( self , query , * args ):
263
- self .cursor . execute ( query , args )
267
+ await self . current_transaction . rollback ()
268
+ self .current_transaction = None
264
269
265
- try :
266
- res = self .cursor .fetchall ()
267
-
268
- # pg8000 might return tuples
269
- if isinstance (res , tuple ):
270
- res = [tuple (t ) for t in res ]
270
+ async def execute (self , query , * args ):
271
+ await self .init_connection ()
272
+ if self .current_transaction :
273
+ return await self .connection .execute (query , * args )
274
+ else :
275
+ async with self .connection .transaction ():
276
+ return await self .connection .execute (query , * args )
271
277
272
- return res
273
- except Exception :
274
- return None
278
+ async def fetch (self , query , * args ):
279
+ await self .init_connection ()
280
+ if self .current_transaction :
281
+ return await self .connection .fetch (query , * args )
282
+ else :
283
+ async with self .connection .transaction ():
284
+ return await self .connection .fetch (query , * args )
275
285
276
- def close (self ):
277
- self .cursor . close ()
278
- self .connection .close ()
286
+ async def close (self ):
287
+ if self .connection :
288
+ await self .connection .close ()
279
289
280
290
281
291
class NodeBackup (object ):
@@ -943,7 +953,7 @@ def restore(self, dbname, filename, username=None):
943
953
944
954
self .psql (dbname = dbname , filename = filename , username = username )
945
955
946
- def poll_query_until (self ,
956
+ async def poll_query_until (self ,
947
957
dbname ,
948
958
query ,
949
959
username = None ,
@@ -973,41 +983,54 @@ def poll_query_until(self,
973
983
974
984
attempts = 0
975
985
while max_attempts == 0 or attempts < max_attempts :
976
- try :
977
- res = self .execute (dbname = dbname ,
978
- query = query ,
979
- username = username ,
980
- commit = True )
981
-
982
- if expected is None and res is None :
983
- return # done
986
+ res = await self .fetch (dbname = dbname ,
987
+ query = query ,
988
+ username = username ,
989
+ commit = True )
984
990
985
- if res is None :
986
- raise QueryException ( 'Query returned None' )
991
+ if expected is None and res is None :
992
+ return # done
987
993
988
- if len ( res ) == 0 :
989
- raise QueryException ('Query returned 0 rows ' )
994
+ if res is None :
995
+ raise QueryException ('Query returned None ' )
990
996
991
- if len (res [ 0 ] ) == 0 :
992
- raise QueryException ('Query returned 0 columns ' )
997
+ if len (res ) == 0 :
998
+ raise QueryException ('Query returned 0 rows ' )
993
999
994
- if res [0 ][ 0 ] :
995
- return # done
1000
+ if len ( res [0 ]) == 0 :
1001
+ raise QueryException ( 'Query returned 0 columns' )
996
1002
997
- except pglib .ProgrammingError as e :
998
- if raise_programming_error :
999
- raise e
1000
-
1001
- except pglib .InternalError as e :
1002
- if raise_internal_error :
1003
- raise e
1003
+ if res [0 ][0 ]:
1004
+ return # done
1004
1005
1005
1006
time .sleep (sleep_time )
1006
1007
attempts += 1
1007
1008
1008
1009
raise TimeoutException ('Query timeout' )
1009
1010
1010
- def execute (self , dbname , query , username = None , commit = True ):
1011
+ async def execute (self , dbname , query , username = None , commit = True ):
1012
+ """
1013
+ Execute a query
1014
+
1015
+ Args:
1016
+ dbname: database name to connect to.
1017
+ query: query to be executed.
1018
+ username: database user name.
1019
+ commit: should we commit this query?
1020
+
1021
+ Returns:
1022
+ A list of tuples representing rows.
1023
+ """
1024
+
1025
+ async with self .connect (dbname , username ) as node_con :
1026
+ if commit :
1027
+ await node_con .begin ()
1028
+
1029
+ await node_con .execute (query )
1030
+ if commit :
1031
+ await node_con .commit ()
1032
+
1033
+ async def fetch (self , dbname , query , username = None , commit = True ):
1011
1034
"""
1012
1035
Execute a query and return all rows as list.
1013
1036
@@ -1021,10 +1044,13 @@ def execute(self, dbname, query, username=None, commit=True):
1021
1044
A list of tuples representing rows.
1022
1045
"""
1023
1046
1024
- with self .connect (dbname , username ) as node_con :
1025
- res = node_con .execute (query )
1047
+ async with self .connect (dbname , username ) as node_con :
1048
+ if commit :
1049
+ await node_con .begin ()
1050
+
1051
+ res = await node_con .fetch (query )
1026
1052
if commit :
1027
- node_con .commit ()
1053
+ await node_con .commit ()
1028
1054
return res
1029
1055
1030
1056
def backup (self , username = None , xlog_method = DEFAULT_XLOG_METHOD ):
@@ -1059,7 +1085,7 @@ def replicate(self, name, username=None,
1059
1085
backup = self .backup (username = username , xlog_method = xlog_method )
1060
1086
return backup .spawn_replica (name , use_logging = use_logging )
1061
1087
1062
- def catchup (self , username = None ):
1088
+ async def catchup (self , username = None ):
1063
1089
"""
1064
1090
Wait until async replica catches up with its master.
1065
1091
"""
@@ -1080,8 +1106,8 @@ def catchup(self, username=None):
1080
1106
raise CatchUpException ("Master node is not specified" )
1081
1107
1082
1108
try :
1083
- lsn = master .execute ('postgres' , poll_lsn )[0 ][0 ]
1084
- self .poll_query_until (dbname = 'postgres' ,
1109
+ lsn = ( await master .fetch ('postgres' , poll_lsn ) )[0 ][0 ]
1110
+ await self .poll_query_until (dbname = 'postgres' ,
1085
1111
username = username ,
1086
1112
query = wait_lsn .format (lsn ),
1087
1113
max_attempts = 0 ) # infinite
0 commit comments