297 lines
13 KiB
Python
297 lines
13 KiB
Python
import json
|
||
from typing import Type, List, Any, Dict
|
||
|
||
from pydantic import BaseModel
|
||
from sqlalchemy import Column, asc, desc, func, String, and_
|
||
from sqlalchemy.orm import DeclarativeMeta, Session, Query
|
||
|
||
from .model_config_utils import SalBase, get_model_col_names
|
||
from .types.crud_schema import QueryParams, OrderParam, QueryParam
|
||
from .types.model_config import ModelConfig, TypeInfo
|
||
|
||
|
||
class TableModel(DeclarativeMeta, SalBase):
    """Marker type combining ``DeclarativeMeta`` and ``SalBase``.

    Declares no members of its own; within this module it is only used as
    the annotation for the model class handed to ``ModelCRUD``.
    """
|
||
|
||
|
||
def query_common(db: Session, model, query_params: QueryParams):
    """
    Generic database query entry point.

    Starts from ``db.query(model)`` and runs it through up to three stages,
    each one only when the matching attribute of ``query_params`` is truthy:

    1. ``order``  -> ``query_order`` (sorting)
    2. ``query``  -> ``query_common_query_core`` (field-based search)
    3. ``params`` -> ``query_common_params_core`` (explicit filter list)

    :param db: session used to create the base query.
    :param model: mapped model class to query.
    :param query_params: carries ``order``, ``query``, ``params``, ``page``
        and ``page_size``.
    :return: ``(count, query, page, page_size)``. ``count`` comes from
        ``query.count()`` on the fully filtered query; the query itself is
        returned without any pagination applied, and ``page`` /
        ``page_size`` are passed through unchanged for the caller to use.
    """
    stmt = db.query(model)
    page, page_size = query_params.page, query_params.page_size

    # (attribute on query_params, stage function), applied in this order.
    stages = (
        ('order', query_order),
        ('query', query_common_query_core),
        ('params', query_common_params_core),
    )
    for attr_name, apply_stage in stages:
        spec = getattr(query_params, attr_name)
        if spec:
            stmt = apply_stage(model, stmt, spec)

    total = stmt.count()
    return total, stmt, page, page_size
|
||
|
||
|
||
def query_order(model, query: Query, order_param: OrderParam):
    """
    Apply result ordering to ``query``.

    :param model: mapped model class; ``order_param.order_by`` is resolved
        as an attribute of it via ``getattr``.
    :param query: query to extend.
    :param order_param: provides ``order_by`` (attribute name) and ``type``
        (``'asc'`` or ``'desc'``).
    :return: the query with ``order_by`` applied, or the query unchanged when
        ``order_param`` / ``order_by`` is falsy or ``type`` is neither
        ``'asc'`` nor ``'desc'``.
    :raises AttributeError: if ``model`` has no attribute named
        ``order_by``.
    """
    if not order_param or not order_param.order_by:
        return query

    # Resolved before the direction is inspected, so an unknown attribute
    # name fails even when no ordering would end up being applied.
    sort_column: Column = getattr(model, order_param.order_by)

    for label, wrap in (('asc', asc), ('desc', desc)):
        if order_param.type == label:
            query = query.order_by(wrap(sort_column))
    return query
|
||
|
||
|
||
def query_common_params_core(model, query: Query, params: List[QueryParam]):
    """
    Apply a list of explicit filter conditions to ``query``.

    Every entry of ``params`` supplies ``name`` (column), ``type`` (operator)
    and ``value``. ``name`` is looked up in ``model.model_config.cols``. A
    dotted name ``"a.b"`` is split into column ``a`` and child ``b``:

    * for ``m2o`` / ``o2o`` relation columns the target model is joined and
      ``b`` names the attribute of the target model that is compared;
    * for ``json`` columns ``b`` is used to subscript the column.

    Supported operators: ``=`` / ``==``, ``>``, ``<``, ``>=``, ``<=``,
    ``in``, ``like`` (substring match on the column cast to string),
    ``find_in_set`` and ``range`` (a one- or two-element sequence of
    inclusive bounds, where ``None`` leaves that side open). Any other
    operator adds no filter.

    :param model: mapped model class exposing ``model_config.cols`` and,
        for m2m filters, an ``id`` attribute.
    :param query: query to extend.
    :param params: filter conditions, applied in order.
    :return: the query with all joins and filters added.
    :raises KeyError: if a condition names a column missing from ``cols``.
    :raises ValueError: if ``name`` contains more than one dot (it is
        unpacked into exactly two parts).
    """
    cols: Dict[str, TypeInfo] = model.model_config.cols

    def join_secondary(current, relation):
        """Join the m2m association table; return (query, target-key column)."""
        secondary = relation.secondary
        joined = current.join(
            secondary,
            model.id == secondary.columns.get(relation.source_secondary_key))
        return joined, secondary.columns.get(relation.target_secondary_key)

    for item in params:
        name = item.name
        child_name = ""
        if '.' in name:
            name, child_name = name.split('.')
        query_type = item.type
        value = item.value
        col = cols[name]
        col_base_type = col.col_base_type
        # Only relation-typed columns with relation metadata take the
        # relation paths below.
        relation = col.relation if col_base_type == 'relation' else None
        column: Column = getattr(model, name)

        if relation:
            if relation.relation_type in ['m2o', 'o2o']:
                target_model = relation.target_model
                if relation.source_foreign_key:
                    # Explicit join condition: target id == source foreign key.
                    query = query.join(
                        target_model,
                        getattr(target_model, relation.target_id_key)
                        == getattr(model, relation.source_foreign_key))
                else:
                    query = query.join(target_model)
                # NOTE(review): with an empty child_name this looks up the
                # attribute '' on the target model - confirm callers always
                # pass a dotted name for m2o / o2o columns.
                column = getattr(target_model, child_name)
        elif col_base_type == 'json':
            column = column[child_name]

        if query_type in ['=', '==']:
            if relation and not child_name:
                # Only m2m is handled for a bare relation name; other
                # relation types add no filter here.
                if relation.relation_type == 'm2m':
                    query, target_col = join_secondary(query, relation)
                    query = query.filter(target_col == value)
            elif col.col_base_type == 'datetime':
                # Compare on the date part only.
                query = query.filter(func.date(column) == value)
            elif col_base_type == 'json':
                query = query.filter(column == json.dumps(value))
            else:
                query = query.filter(column == value)

        elif query_type == ">":
            query = query.filter(column > value)
        elif query_type == "<":
            query = query.filter(column < value)
        elif query_type == ">=":
            query = query.filter(column >= value)
        elif query_type == "<=":
            query = query.filter(column <= value)

        elif query_type == "in":
            if relation:
                # NOTE(review): relation columns are only filtered for m2m;
                # a dotted m2o / o2o name gets joined above but no IN
                # filter - confirm whether that is intended.
                if relation.relation_type == 'm2m':
                    query, target_col = join_secondary(query, relation)
                    if value and isinstance(value, list):
                        query = query.filter(target_col.in_(value))
                    else:
                        query = query.filter(target_col == value)
            else:
                query = query.filter(column.in_(value))

        elif query_type == "like":
            query = query.filter(column.cast(String).like(f'%{value}%'))
        elif query_type == "find_in_set":
            query = query.filter(func.find_in_set(value, column))

        elif query_type == "range":
            # Only one- or two-element ranges are honoured; anything else
            # adds no filter.
            if value and len(value) in (1, 2):
                lower = value[0]
                upper = value[1] if len(value) == 2 else None
                # Successive filter() calls are ANDed together.
                if lower is not None:
                    query = query.filter(column >= lower)
                if upper is not None:
                    query = query.filter(column <= upper)
    return query
|
||
|
||
|
||
def query_common_query_core(model, query: Query, query_data: BaseModel):
    """
    Apply field-based search conditions taken from a pydantic object.

    Only fields explicitly set on ``query_data`` (``exclude_unset=True``)
    and present in ``model.model_config.cols`` are considered. The filter
    depends on the column's ``col_base_type``:

    * ``int`` / ``float`` / ``any`` / ``enum`` / ``str``: a non-empty list
      becomes ``IN``; a truthy value on a ``str`` column becomes a substring
      ``LIKE``; everything else is an equality comparison.
    * ``date`` / ``datetime`` / ``time``: falsy values are ignored; a
      two-element list is an inclusive ``[start, end]`` range where ``None``
      leaves that side open (lists of any other length add no filter); a
      scalar is an equality comparison, on the date part for ``datetime``.
    * ``relation``: ``m2m`` joins the secondary table and filters on its
      target key; ``o2m`` joins the target model and filters on its id key.
      A non-empty list becomes ``IN``, anything else ``==``.

    :param model: mapped model class exposing ``model_config.cols`` and,
        for m2m filters, an ``id`` attribute.
    :param query: query to extend.
    :param query_data: object whose ``dict(exclude_unset=True)`` yields the
        field/value pairs to filter by.
    :return: the query with all joins and filters added.
    """
    q_data = query_data.dict(exclude_unset=True)
    cols: Dict[str, TypeInfo] = model.model_config.cols

    # Iterate key/value pairs; iterating the dict itself yields keys only.
    for key, value in q_data.items():
        if key not in cols:
            continue
        column: Column = getattr(model, key)
        col = cols[key]
        col_base_type = col.col_base_type

        if col_base_type in ['int', 'float', 'any', 'enum', 'str']:
            if isinstance(value, list) and value:
                query = query.filter(column.in_(value))
            elif col_base_type == 'str' and value:
                query = query.filter(column.like(f'%{value}%'))
            else:
                query = query.filter(column == value)

        elif col_base_type in ['date', 'datetime', 'time']:
            if not value:
                continue
            if isinstance(value, list):
                if len(value) == 2:
                    # Compare against each bound, not the whole list.
                    lower, upper = value
                    if lower is not None:
                        query = query.filter(column >= lower)
                    if upper is not None:
                        query = query.filter(column <= upper)
            elif col_base_type == 'datetime':
                # Compare on the date part only.
                query = query.filter(func.date(column) == value)
            else:
                query = query.filter(column == value)

        elif col_base_type == 'relation' and col.relation:
            relation = col.relation
            relation_type = relation.relation_type
            if relation_type == 'm2m':
                secondary = relation.secondary
                source_col = _secondary_column(secondary, relation.source_secondary_key)
                target_col = _secondary_column(secondary, relation.target_secondary_key)
                query = query.join(secondary, model.id == source_col)
                if value and isinstance(value, list):
                    query = query.filter(target_col.in_(value))
                else:
                    query = query.filter(target_col == value)
            elif relation_type == 'o2m':
                target_model = relation.target_model
                target_col = getattr(target_model, relation.target_id_key)
                query = query.join(target_model)
                if value and isinstance(value, list):
                    query = query.filter(target_col.in_(value))
                else:
                    query = query.filter(target_col == value)

    return query


def _secondary_column(secondary, key):
    """Return column ``key`` of an m2m association object.

    Uses the ``columns`` collection when ``secondary`` has one (the way
    ``query_common_params_core`` accesses it) and falls back to plain
    attribute access otherwise.
    """
    columns = getattr(secondary, 'columns', None)
    if columns is not None:
        return columns.get(key)
    return getattr(secondary, key)
|
||
|
||
|
||
class ModelCRUD:
    """Generic create / read / update / delete / query helper for one model.

    The wrapped model class is expected to expose ``model_config`` with an
    ``id_key`` (name of the primary-key attribute) and ``cols`` (per-column
    type and relation metadata).
    """

    def __init__(self, model: Type[TableModel]):
        """Remember the model class all operations act on."""
        self.model = model

    def _load_m2m_targets(self, db: Session, relation, values):
        """Load the target rows for a many-to-many assignment as a set.

        ``values`` is normally a collection of target ids. A list/tuple of
        dicts is also accepted; the id is then read from each dict under the
        relation's ``target_id_key``, falling back to the source model's
        ``id_key`` when that key is absent.
        """
        target_model = relation.target_model
        target_key = relation.target_id_key
        if isinstance(values, (list, tuple)) and values and isinstance(values[0], dict):
            fallback_key = self.model.model_config.id_key
            values = [
                entry[target_key] if target_key in entry else entry[fallback_key]
                for entry in values
            ]
        return set(db.query(target_model).filter(getattr(target_model, target_key).in_(values)))

    def update(self, db: Session, data: BaseModel):
        """Update the row identified by ``data``'s id with its explicitly set fields.

        Key handling for ``data.dict(exclude_unset=True)``:

        * ``"rel.field"`` - for an ``o2o`` relation column ``rel``, set
          ``field`` on the related object, creating the related object
          first when there is none (only if the value is not ``None``).
        * m2m relation column - replace the collection with the target rows
          matching the given ids (skipped when the value is ``None``).
        * o2o relation column with a dict value - create the related object
          or update its attributes from the dict.
        * any other known non-relation column - plain attribute assignment.

        Keys that are not known columns are ignored.

        :return: the refreshed row, or ``None`` when no row has that id.
        """
        model = self.model
        model_config: ModelConfig = model.model_config
        id_key = model_config.id_key
        item = db.query(model).filter(getattr(model, id_key) == getattr(data, id_key)).first()
        if item is None:
            # Nothing to update; mirror get(), which returns None here too.
            return None

        values = data.dict(exclude_unset=True)
        col_names = get_model_col_names(model)
        for k, v in values.items():
            if '.' in k:
                paths = k.split('.')
                if paths[0] in col_names:
                    col_type_info: TypeInfo = model_config.cols[paths[0]]
                    relation = col_type_info.relation
                    # One-to-one, e.g. {"config.age": 1}
                    if relation and relation.relation_type in ["o2o"]:
                        # Look up the relation attribute itself, not the
                        # dotted key (which is never a valid attribute).
                        related = getattr(item, paths[0])
                        if not related:
                            if v is not None:
                                setattr(item, paths[0], relation.target_model(**{paths[1]: v}))
                        else:
                            setattr(related, paths[1], v)
            if k in col_names:
                col_type_info: TypeInfo = model_config.cols[k]
                relation = col_type_info.relation
                if relation:
                    if relation.relation_type in ['m2m']:
                        if v is not None:
                            # The m2m value is a collection of ids: load the
                            # matching rows, then assign them.
                            setattr(item, k, self._load_m2m_targets(db, relation, v))
                    elif relation.relation_type == 'o2o':
                        if isinstance(v, dict):
                            related = getattr(item, k)
                            if not related:
                                setattr(item, k, relation.target_model(**v))
                            else:
                                for v_k, v_v in v.items():
                                    setattr(related, v_k, v_v)
                else:
                    setattr(item, k, v)
        db.commit()
        db.refresh(item)
        return item

    def get(self, db: Session, data: BaseModel):
        """Return the row whose id equals ``data``'s id, or ``None``."""
        id_key = self.model.model_config.id_key
        return db.query(self.model).filter(
            getattr(self.model, id_key) == getattr(data, id_key)).first()

    def add(self, db: Session, data: BaseModel):
        """Create a row from the explicitly set fields of ``data``.

        Known columns are handled like in ``update`` (m2m ids are resolved
        to target rows, an o2o dict creates the related object, other
        non-relation columns are assigned directly); unknown keys are
        ignored.

        :return: the newly created, refreshed row.
        """
        model = self.model
        item = model()
        values = data.dict(exclude_unset=True)
        col_names = get_model_col_names(model)
        for k, v in values.items():
            if k not in col_names:
                continue
            col_type_info: TypeInfo = model.model_config.cols[k]
            relation = col_type_info.relation
            if relation:
                if relation.relation_type == 'm2m':
                    if v is not None:
                        setattr(item, k, self._load_m2m_targets(db, relation, v))
                elif relation.relation_type == 'o2o':
                    if isinstance(v, dict):
                        setattr(item, k, relation.target_model(**v))
            else:
                setattr(item, k, v)
        db.add(item)
        db.commit()
        db.refresh(item)
        return item

    def query(self, db: Session, data: QueryParams) -> (int, List[Any]):
        """Run the generic query for this model.

        Delegates to ``query_common``.
        NOTE(review): ``query_common`` returns the 4-tuple
        ``(count, query, page, page_size)``, which does not match the
        two-element return annotation here - confirm what callers expect.
        """
        return query_common(db, self.model, data)

    def delete(self, db: Session, data: BaseModel):
        """Delete the row(s) whose id equals ``data``'s id and commit.

        :return: always ``True``, whether or not a row matched.
        """
        id_key = self.model.model_config.id_key
        db.query(self.model).filter(
            getattr(self.model, id_key) == getattr(data, id_key)).delete()
        db.commit()
        return True
|