forked from sqlalchemy/sqlalchemy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpostgis.py
347 lines (265 loc) · 9.49 KB
/
postgis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
import binascii

from sqlalchemy import event
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import Table
from sqlalchemy.sql import expression
from sqlalchemy.sql import type_coerce
from sqlalchemy.types import UserDefinedType
# Python datatypes
class GisElement(object):
    """Mixin for Python-side geometry values.

    Subclasses provide a ``desc`` attribute or property holding the value's
    textual description; both ``str()`` and ``repr()`` are derived from it.
    """

    def __str__(self):
        # The human-readable form is simply the subclass-supplied description.
        return self.desc

    def __repr__(self):
        template = "<{0} at 0x{1:x}; {2!r}>"
        return template.format(self.__class__.__name__, id(self), self.desc)
class BinaryGisElement(GisElement, expression.Function):
    """Represents a Geometry value expressed as binary.

    As a SQL expression this renders ``ST_GeomFromEWKB(<data>)`` typed as a
    binary-coercing ``Geometry``; as a Python value its description is the
    hexadecimal text of ``data``.
    """

    def __init__(self, data):
        # Keep the raw payload so desc/as_hex can render it later.
        self.data = data
        expression.Function.__init__(
            self, "ST_GeomFromEWKB", data, type_=Geometry(coerce_="binary")
        )

    @property
    def desc(self):
        """Hex text of the binary payload (used by ``str()``/``repr()``)."""
        return self.as_hex

    @property
    def as_hex(self):
        """Return ``data`` as a lowercase hexadecimal *text* string.

        ``binascii.hexlify`` returns ``bytes`` on Python 3.  It is decoded
        here so that ``str()`` on this element works (``__str__`` must return
        ``str``) and so the value compares equal to ordinary hex string
        literals.
        """
        return binascii.hexlify(self.data).decode("ascii")
class TextualGisElement(GisElement, expression.Function):
    """Geometry value carried as text.

    In SQL it renders as ``ST_GeomFromText(desc, srid)`` typed as
    ``Geometry``; on the Python side ``desc`` is kept on the instance as the
    value's description.
    """

    def __init__(self, desc, srid=-1):
        self.desc = desc
        # Positional SQL arguments to ST_GeomFromText, in call order.
        sql_args = (desc, srid)
        expression.Function.__init__(
            self, "ST_GeomFromText", *sql_args, type_=Geometry
        )
# SQL datatypes.
class Geometry(UserDefinedType):
    """Base PostGIS Geometry column type.

    :param dimension: coordinate dimension; passed to ``AddGeometryColumn``
        by this module's DDL hooks.
    :param srid: spatial reference id; passed to ``AddGeometryColumn`` and
        used as the SRID argument of ``ST_GeomFromText`` when text values
        are bound to this type.
    :param coerce_: ``"text"`` (bind via ``ST_GeomFromText``, select via
        ``ST_AsText``) or ``"binary"`` (bind via ``ST_GeomFromEWKB``, select
        via ``ST_AsBinary``).  Any other value raises ``ValueError`` when the
        type is used for binding or result processing.
    """

    name = "GEOMETRY"

    def __init__(self, dimension=None, srid=-1, coerce_="text"):
        self.dimension = dimension
        self.srid = srid
        self.coerce = coerce_

    class comparator_factory(UserDefinedType.Comparator):
        """Define custom operations for geometry types."""

        # override the __eq__() operator
        def __eq__(self, other):
            return self.op("~=")(other)

        # add a custom operator
        def intersects(self, other):
            return self.op("&&")(other)

        # any number of GIS operators can be overridden/added here
        # using the techniques above.

    def _coerce_compared_value(self, op, value):
        # Use this Geometry type for the other side of comparisons, so plain
        # Python values (e.g. WKT strings) are bound as geometries.
        return self

    def _coerce_mode(self):
        """Return ``self.coerce`` after checking it is a supported mode.

        An explicit exception is used rather than ``assert``, because
        ``python -O`` strips asserts and the callers would then silently
        return ``None`` or hit an unbound local.
        """
        if self.coerce not in ("text", "binary"):
            raise ValueError(
                "Geometry coerce_ must be 'text' or 'binary', got %r"
                % (self.coerce,)
            )
        return self.coerce

    def get_col_spec(self):
        return self.name

    def bind_expression(self, bindvalue):
        """Wrap a bound value in the SQL function matching ``coerce``."""
        if self._coerce_mode() == "text":
            # Tag text input with this column's SRID so it matches the SRID
            # the column was declared with.
            return TextualGisElement(bindvalue, self.srid)
        return BinaryGisElement(bindvalue)

    def column_expression(self, col):
        """Wrap a selected column so it comes back as text or binary."""
        if self._coerce_mode() == "text":
            return func.ST_AsText(col, type_=self)
        return func.ST_AsBinary(col, type_=self)

    def bind_processor(self, dialect):
        def process(value):
            # GisElement instances are sent as their description; any other
            # value passes through unchanged.
            if isinstance(value, GisElement):
                return value.desc
            return value

        return process

    def result_processor(self, dialect, coltype):
        """Return a processor turning non-NULL results into GisElements."""
        if self._coerce_mode() == "text":
            fac = TextualGisElement
        else:
            fac = BinaryGisElement

        def process(value):
            if value is None:
                return None
            return fac(value)

        return process

    def adapt(self, impltype):
        # Carry all configuration over to the adapted type instance.
        return impltype(
            dimension=self.dimension, srid=self.srid, coerce_=self.coerce
        )
# other datatypes can be added as needed.
class Point(Geometry):
    """Geometry column type whose type name is ``POINT``."""

    name = "POINT"


class Curve(Geometry):
    """Geometry column type whose type name is ``CURVE``."""

    name = "CURVE"


class LineString(Curve):
    """``Curve`` subtype whose type name is ``LINESTRING``."""

    name = "LINESTRING"
# ... etc.
# DDL integration
# PostGIS historically has required AddGeometryColumn/DropGeometryColumn
# and other management methods in order to create PostGIS columns. Newer
# versions don't appear to require these special steps anymore. However,
# here we illustrate how to set up these features in any case.
def setup_ddl_events():
    """Register Table DDL listeners that manage Geometry columns.

    For tables containing ``Geometry`` columns:

    * before CREATE/DROP TABLE, the Geometry columns are temporarily removed
      from the table's column collection so the plain DDL omits them;
    * before DROP, each Geometry column is removed with
      ``DropGeometryColumn``;
    * after CREATE, the full column collection is restored and each Geometry
      column is added with ``AddGeometryColumn``;
    * after DROP, the full column collection is restored.

    Tables with no Geometry columns are left untouched.
    """

    @event.listens_for(Table, "before_create")
    def before_create(target, connection, **kw):
        dispatch("before-create", target, connection)

    @event.listens_for(Table, "after_create")
    def after_create(target, connection, **kw):
        dispatch("after-create", target, connection)

    @event.listens_for(Table, "before_drop")
    def before_drop(target, connection, **kw):
        dispatch("before-drop", target, connection)

    @event.listens_for(Table, "after_drop")
    def after_drop(target, connection, **kw):
        dispatch("after-drop", target, connection)

    def _restore_columns(table):
        """Put back the column collection saved by a "before" event.

        Returns False when nothing was saved (the table was not patched).
        """
        saved = table.info.pop("_saved_columns", None)
        if saved is None:
            return False
        table.columns = saved
        return True

    def dispatch(event_name, table, bind):
        """Apply the Geometry-column handling for one DDL event."""
        if event_name in ("before-create", "before-drop"):
            # Declared order is kept so DropGeometryColumn runs
            # deterministically.
            gis_cols = [c for c in table.c if isinstance(c.type, Geometry)]
            if not gis_cols:
                # Nothing to manage; leave this table's columns alone.
                return
            regular_cols = [
                c for c in table.c if not isinstance(c.type, Geometry)
            ]
            table.info["_saved_columns"] = table.c
            # temporarily patch a set of columns not including the
            # Geometry columns
            table.columns = expression.ColumnCollection(*regular_cols)

            if event_name == "before-drop":
                for c in gis_cols:
                    bind.execute(
                        select(
                            [
                                func.DropGeometryColumn(
                                    "public", table.name, c.name
                                )
                            ],
                            autocommit=True,
                        )
                    )

        elif event_name == "after-create":
            if not _restore_columns(table):
                return
            for c in table.c:
                if isinstance(c.type, Geometry):
                    bind.execute(
                        select(
                            [
                                func.AddGeometryColumn(
                                    table.name,
                                    c.name,
                                    c.type.srid,
                                    c.type.name,
                                    c.type.dimension,
                                )
                            ],
                            autocommit=True,
                        )
                    )

        elif event_name == "after-drop":
            _restore_columns(table)


setup_ddl_events()
# illustrate usage
if __name__ == "__main__":
    from sqlalchemy import (
        create_engine,
        MetaData,
        Column,
        Integer,
        String,
        func,
        select,
    )
    from sqlalchemy.orm import sessionmaker
    from sqlalchemy.ext.declarative import declarative_base

    engine = create_engine(
        "postgresql://scott:tiger@localhost/test", echo=True
    )
    metadata = MetaData(engine)

    Base = declarative_base(metadata=metadata)

    class Road(Base):
        __tablename__ = "roads"

        road_id = Column(Integer, primary_key=True)
        road_name = Column(String)
        road_geom = Column(Geometry(2))

    metadata.drop_all()
    metadata.create_all()

    session = sessionmaker(bind=engine)()

    # Add objects. We can use strings...
    session.add_all(
        [
            Road(
                road_name="Jeff Rd",
                road_geom="LINESTRING(191232 243118,191108 243242)",
            ),
            Road(
                road_name="Geordie Rd",
                road_geom="LINESTRING(189141 244158,189265 244817)",
            ),
            Road(
                road_name="Paul St",
                road_geom="LINESTRING(192783 228138,192612 229814)",
            ),
            Road(
                road_name="Graeme Ave",
                road_geom="LINESTRING(189412 252431,189631 259122)",
            ),
            Road(
                road_name="Phil Tce",
                road_geom="LINESTRING(190131 224148,190871 228134)",
            ),
        ]
    )

    # or use an explicit TextualGisElement
    # (similar to saying func.ST_GeomFromText())
    r = Road(
        road_name="Dave Cres",
        road_geom=TextualGisElement(
            "LINESTRING(198231 263418,198213 268322)", -1
        ),
    )
    session.add(r)

    # pre flush, the TextualGisElement represents the string we sent.
    assert str(r.road_geom) == "LINESTRING(198231 263418,198213 268322)"

    session.commit()

    # After commit, accessing the attribute reloads it from the database.
    # With the default "text" coercion the column is selected through
    # ST_AsText() and Geometry.result_processor() turns the result into a
    # new TextualGisElement holding the same WKT.
    assert str(r.road_geom) == "LINESTRING(198231 263418,198213 268322)"

    r1 = session.query(Road).filter(Road.road_name == "Graeme Ave").one()

    # illustrate the overridden __eq__() operator.

    # strings come in as TextualGisElements
    r2 = (
        session.query(Road)
        .filter(Road.road_geom == "LINESTRING(189412 252431,189631 259122)")
        .one()
    )

    r3 = session.query(Road).filter(Road.road_geom == r1.road_geom).one()

    assert r1 is r2 is r3

    # core usage just fine:

    road_table = Road.__table__
    stmt = select([road_table]).where(
        road_table.c.road_geom.intersects(r1.road_geom)
    )
    print(session.execute(stmt).fetchall())

    # TODO: for some reason the auto-generated labels have the internal
    # replacement strings exposed, even though PG doesn't complain

    # Look up the hex binary version by re-typing the value with
    # type_coerce() (a Python-side type change; no SQL CAST is emitted).
    as_binary = session.scalar(
        select([type_coerce(r.road_geom, Geometry(coerce_="binary"))])
    )

    assert as_binary.as_hex == (
        "01020000000200000000000000b832084100000000"
        "e813104100000000283208410000000088601041"
    )

    # back again, same method !
    as_text = session.scalar(
        select([type_coerce(as_binary, Geometry(coerce_="text"))])
    )

    assert as_text.desc == "LINESTRING(198231 263418,198213 268322)"

    session.rollback()

    metadata.drop_all()