How can I perform this Recursive Common Table Expression

Published 2020-07-18 07:34

Question:

In this table schema:

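from flask_sqlalchemy import SQLAlchemy

db = SQLAlchemy()  # assumes Flask-SQLAlchemy, given db.Model / db.Column

class Location(db.Model):
    id = db.Column(db.String(255), primary_key=True)
    # Self-referential foreign key; NULL for top-level rows such as countries.
    parent_id = db.Column(db.String(255), db.ForeignKey('location.id'), nullable=True)
    type = db.Column(db.String(255))
    name = db.Column(db.String(255))
    options = db.Column(db.Text, nullable=True)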

Suppose the table contains a parent and one child:

> select * from location;
+--------------------------------------+--------------------------------------+---------+-----------+---------+
| id                                   | parent_id                            | type    | name      | options |
+--------------------------------------+--------------------------------------+---------+-----------+---------+
| 8821da60-e7d2-11e1-951e-ae16bc2b7368 | 88227a60-e7d2-11e1-951e-ae16bc2b7368 | city    | sengkang  | NULL    |
| 88227a60-e7d2-11e1-951e-ae16bc2b7368 | NULL                                 | country | singapore | NULL    |
+--------------------------------------+--------------------------------------+---------+-----------+---------+

The parent is the country row named singapore. There can be further nested rows that are children of sengkang, just as sengkang is a child of singapore.

So a single hierarchy chain might look like a -> b -> sengkang -> singapore, where -> means "is a child of".

How can I get all Location objects that have singapore as an ancestor (directly or through intermediate parents), including singapore itself? (In SQLAlchemy, please.) Thanks!
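Conceptually, the rows form a tree linked by parent_id, and "singapore plus everything under it" is reached by starting at singapore and repeatedly collecting rows whose parent_id points at a row already collected. A recursive CTE performs that iteration: an anchor query, then a recursive step repeated until it returns no new rows. The plain-Python sketch below mirrors it; the ids are shortened, and the rows named "b", "a" and "malaysia" are hypothetical.

```python
# Hypothetical rows modelled on the location table: (id, parent_id, name).
# Ids are shortened; "b" and "a" come from the example chain, and
# "malaysia" is an unrelated row that should never be returned.
ROWS = [
    ("sg", None, "singapore"),
    ("sk", "sg", "sengkang"),
    ("b1", "sk", "b"),
    ("a1", "b1", "a"),
    ("my", None, "malaysia"),
]


def subtree(rows, root_name):
    """Return the names of root_name and all of its descendants.

    Mirrors a recursive CTE: the anchor selects the root row, then each
    round selects rows whose parent_id is among the previous round's ids.
    Like UNION ALL, this does not guard against cycles in parent_id.
    """
    frontier = [r for r in rows if r[2] == root_name]  # anchor member
    result = list(frontier)
    while frontier:
        parent_ids = {r[0] for r in frontier}
        # Recursive member: children of the rows found in the last round.
        frontier = [r for r in rows if r[1] in parent_ids]
        result.extend(frontier)
    return [r[2] for r in result]


print(subtree(ROWS, "singapore"))
```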

Answer 1:

SQLAlchemy documentation: Query.cte

I have to say the new CTE support (added in 0.7.6) is pretty good, and better than the old way, where I had to rework this.

Anyway, I think you're looking for something like this:

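from sqlalchemy.orm import aliased

# Anchor member: the starting row (singapore).
included_parts = session.query(Location).\
  filter(Location.name=='singapore').\
  cte(name='included_parts', recursive=True)

# Separate aliases for the CTE ("pr") and the location table ("p").
incl_alias = aliased(included_parts, name="pr")
parts_alias = aliased(Location, name='p')
# Recursive member: every location whose parent is already in the CTE.
included_parts = included_parts.union_all(
  session.query(parts_alias).filter(parts_alias.parent_id==incl_alias.c.id)
)

# Rows of the CTE's columns (id, parent_id, type, name, options), not Location instances.
q = session.query(included_parts).all()

# To get Location instances, join the CTE back to the mapped table.
locations = session.query(Location).join(included_parts, Location.id == included_parts.c.id).all()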
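For comparison, the query above corresponds roughly to the recursive CTE below (the exact SQL SQLAlchemy emits depends on the dialect and version; the pr and p aliases match the names passed to aliased()). This minimal sketch runs it on an in-memory database with Python's built-in sqlite3 module, using the two rows from the question with shortened ids plus hypothetical descendants "b" and "a" under sengkang.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """CREATE TABLE location (
           id VARCHAR(255) PRIMARY KEY,
           parent_id VARCHAR(255) REFERENCES location(id),
           type VARCHAR(255),
           name VARCHAR(255),
           options TEXT
       )"""
)
# The two rows from the question (ids shortened), plus hypothetical
# rows "b" and "a" so the hierarchy is more than one level deep.
conn.executemany(
    "INSERT INTO location (id, parent_id, type, name) VALUES (?, ?, ?, ?)",
    [
        ("sg", None, "country", "singapore"),
        ("sk", "sg", "city", "sengkang"),
        ("b1", "sk", "area", "b"),
        ("a1", "b1", "area", "a"),
    ],
)

# Anchor member selects the named row; the recursive member adds every
# location (p) whose parent_id matches a row already in the CTE (pr).
SQL = """
WITH RECURSIVE included_parts(id, parent_id, type, name, options) AS (
    SELECT id, parent_id, type, name, options
    FROM location
    WHERE name = ?
    UNION ALL
    SELECT p.id, p.parent_id, p.type, p.name, p.options
    FROM location AS p, included_parts AS pr
    WHERE p.parent_id = pr.id
)
SELECT name FROM included_parts
"""

names = [row[0] for row in conn.execute(SQL, ("singapore",))]
print(sorted(names))
```

UNION ALL, as in the answer, does no duplicate checking, so a parent_id loop in the data would make the recursion run forever; in SQLite and PostgreSQL, using UNION instead discards rows already produced, which stops that.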