wd-smebiz-client/utils/sal_utils/crud_utils.py

297 lines
13 KiB
Python
Raw Normal View History

2023-09-11 10:37:07 +08:00
import json
from typing import Type, List, Any, Dict
from pydantic import BaseModel
from sqlalchemy import Column, asc, desc, func, String, and_
from sqlalchemy.orm import DeclarativeMeta, Session, Query
from .model_config_utils import SalBase, get_model_col_names
from .types.crud_schema import QueryParams, OrderParam, QueryParam
from .types.model_config import ModelConfig, TypeInfo
class TableModel(DeclarativeMeta, SalBase):
    """Marker type combining ``DeclarativeMeta`` and ``SalBase``.

    It declares no members of its own. Within this file it is only used as
    the annotation for the model class handed to ``ModelCRUD``.
    """
def query_common(db: Session, model, query_params: QueryParams):
    """Generic database query entry point.

    Starts from ``db.query(model)`` and, for each part of ``query_params``
    that is truthy, applies it in this order: ``order`` (sorting),
    ``query`` (field-based search), ``params`` (explicit filter list).

    Returns:
        A 4-tuple ``(count, query, page, page_size)``. ``count`` is the
        result of ``query.count()`` on the final query; ``page`` and
        ``page_size`` are passed through from ``query_params`` unchanged,
        so no limit/offset is applied here.
    """
    selection = db.query(model)
    page, page_size = query_params.page, query_params.page_size

    # Sorting
    ordering = query_params.order
    if ordering:
        selection = query_order(model, selection, ordering)

    # Field-based search
    criteria = query_params.query
    if criteria:
        selection = query_common_query_core(model, selection, criteria)

    # Explicit filter conditions
    conditions = query_params.params
    if conditions:
        selection = query_common_params_core(model, selection, conditions)

    total = selection.count()
    return total, selection, page, page_size
def query_order(model, query: Query, order_param: OrderParam):
    """Apply result ordering to ``query``.

    The column is looked up on ``model`` by ``order_param.order_by``.
    ``order_param.type`` selects the direction: ``'asc'`` or ``'desc'``.
    Any other value, a missing ``order_param``, or an empty ``order_by``
    leaves the query unchanged.
    """
    if not order_param or not order_param.order_by:
        return query

    sort_column: Column = getattr(model, order_param.order_by)
    direction = order_param.type
    if direction == 'asc':
        return query.order_by(asc(sort_column))
    if direction == 'desc':
        return query.order_by(desc(sort_column))
    return query
def _join_m2m_secondary(model, query, relation):
    """Join the relation's secondary table; return (query, target-key column)."""
    secondary = relation.secondary
    source_column = secondary.columns.get(relation.source_secondary_key)
    query = query.join(secondary, model.id == source_column)
    return query, secondary.columns.get(relation.target_secondary_key)


def query_common_params_core(model, query: Query, params: List[QueryParam]):
    """Apply a list of explicit filter conditions to ``query``.

    Each item carries ``name``, ``type`` (the operator) and ``value``.
    ``name`` may be ``"column"`` or ``"column.child"``:

    * for an ``m2o``/``o2o`` relation column the target model is joined and
      the condition is applied to the target's ``child`` attribute;
    * for a ``json`` column ``child`` is used to index into the column.

    Supported operators: ``=``/``==``, ``>``, ``<``, ``>=``, ``<=``, ``in``,
    ``like``, ``find_in_set`` and ``range`` (a 1- or 2-element sequence
    whose ``None`` entries mean "unbounded"). Unknown operators are ignored.

    Returns:
        The query with all joins and filters applied.

    Raises:
        KeyError: if ``name`` is not present in ``model.model_config.cols``.
    """
    cols: Dict[str, TypeInfo] = model.model_config.cols
    for item in params:
        name = item.name
        child_name = ""
        if '.' in name:
            # Split on the first dot only so the remainder is passed through intact.
            name, child_name = name.split('.', 1)
        query_type = item.type
        value = item.value
        col = cols[name]
        col_base_type = col.col_base_type
        relation = col.relation if col_base_type == 'relation' else None
        column: Column = getattr(model, name)
        # True once `column` points at an attribute of the joined target model.
        on_target_column = False

        if relation:
            if relation.relation_type in ('m2o', 'o2o'):
                target_model = relation.target_model
                if relation.source_foreign_key:
                    query = query.join(
                        target_model,
                        getattr(target_model, relation.target_id_key)
                        == getattr(model, relation.source_foreign_key),
                    )
                else:
                    query = query.join(target_model)
                # NOTE(review): with an empty child_name this getattr raises
                # AttributeError (unchanged behaviour) - confirm whether
                # m2o/o2o filters without a child attribute are expected.
                column = getattr(target_model, child_name)
                on_target_column = True
        elif col_base_type == 'json' and child_name:
            # Only index into the JSON document when a child path was given;
            # indexing with "" would address a key named "".
            column = column[child_name]

        if query_type in ('=', '=='):
            if relation and not child_name:
                # Only many-to-many relations can be matched by id here.
                if relation.relation_type == 'm2m':
                    query, target_key = _join_m2m_secondary(model, query, relation)
                    query = query.filter(target_key == value)
            elif col.col_base_type == 'datetime':
                # Compare on the date part of the datetime column.
                query = query.filter(func.date(column) == value)
            elif col_base_type == 'json':
                query = query.filter(column == json.dumps(value))
            else:
                query = query.filter(column == value)
        elif query_type == ">":
            query = query.filter(column > value)
        elif query_type == "<":
            query = query.filter(column < value)
        elif query_type == ">=":
            query = query.filter(column >= value)
        elif query_type == "<=":
            query = query.filter(column <= value)
        elif query_type == "in":
            if relation and relation.relation_type == 'm2m':
                query, target_key = _join_m2m_secondary(model, query, relation)
                if value and isinstance(value, list):
                    query = query.filter(target_key.in_(value))
                else:
                    query = query.filter(target_key == value)
            elif not relation or on_target_column:
                # Plain columns, and child columns of a joined m2o/o2o target.
                # Previously the joined-child case added the join but no
                # filter, so every joined row was returned.
                query = query.filter(column.in_(value))
            # Other relation kinds expose no comparable column here and are
            # left unfiltered, as before.
        elif query_type == "like":
            query = query.filter(column.cast(String).like(f'%{value}%'))
        elif query_type == "find_in_set":
            query = query.filter(func.find_in_set(value, column))
        elif query_type == "range":
            if value:
                if len(value) == 1 and value[0] is not None:
                    query = query.filter(column >= value[0])
                if len(value) == 2:
                    lower, upper = value[0], value[1]
                    if lower is not None and upper is None:
                        query = query.filter(column >= lower)
                    if lower is None and upper is not None:
                        query = query.filter(column <= upper)
                    if lower is not None and upper is not None:
                        query = query.filter(and_(column >= lower, column <= upper))
    return query
def _secondary_column(secondary, key):
    """Look up column ``key`` on a many-to-many secondary table or class."""
    # presumably `secondary` is an association Table, as accessed through
    # `.columns` in query_common_params_core; plain attribute access is kept
    # as a fallback for objects without a `columns` collection.
    columns = getattr(secondary, 'columns', None)
    if columns is not None:
        return columns.get(key)
    return getattr(secondary, key)


def query_common_query_core(model, query: Query, query_data: BaseModel):
    """Apply field-based search criteria from ``query_data`` to ``query``.

    Only fields that were explicitly set on ``query_data`` and that exist in
    ``model.model_config.cols`` are considered. Per column base type:

    * ``int``/``float``/``any``/``enum``/``str``: a non-empty list becomes
      ``IN``; a non-empty ``str`` value becomes ``LIKE '%value%'``;
      everything else is an equality test.
    * ``date``/``datetime``/``time``: a 2-element list is an inclusive range
      whose ``None`` bounds are skipped; a scalar is an equality test
      (``datetime`` columns are compared on their date part). Falsy values
      are ignored.
    * ``relation``: ``m2m`` joins the secondary table and matches the target
      key; ``o2m`` joins the target model and matches its id key. A
      non-empty list uses ``IN``, otherwise equality.

    Returns:
        The query with all joins and filters applied.
    """
    q_data = query_data.dict(exclude_unset=True)
    cols: Dict[str, TypeInfo] = model.model_config.cols
    for key, value in q_data.items():
        if key not in cols:
            continue
        column: Column = getattr(model, key)
        col = cols[key]
        col_base_type = col.col_base_type

        if col_base_type in ('int', 'float', 'any', 'enum', 'str'):
            if isinstance(value, list) and value:
                query = query.filter(column.in_(value))
            elif col_base_type == 'str' and value:
                query = query.filter(column.like(f'%{value}%'))
            else:
                query = query.filter(column == value)

        elif col_base_type in ('date', 'datetime', 'time'):
            if not value:
                continue
            if isinstance(value, list):
                if len(value) == 2:
                    lower, upper = value
                    # Compare against the individual bounds, not the list.
                    if lower is not None:
                        query = query.filter(column >= lower)
                    if upper is not None:
                        query = query.filter(column <= upper)
            elif col_base_type == 'datetime':
                # Compare on the date part of the datetime column.
                query = query.filter(func.date(column) == value)
            else:
                query = query.filter(column == value)

        elif col_base_type == 'relation' and col.relation:
            relation = col.relation
            relation_type = relation.relation_type
            many = bool(value) and isinstance(value, list)
            if relation_type == 'm2m':
                secondary = relation.secondary
                source_column = _secondary_column(secondary, relation.source_secondary_key)
                target_column = _secondary_column(secondary, relation.target_secondary_key)
                query = query.join(secondary, model.id == source_column)
                if many:
                    query = query.filter(target_column.in_(value))
                else:
                    query = query.filter(target_column == value)
            elif relation_type == 'o2m':
                target_model = relation.target_model
                target_id = getattr(target_model, relation.target_id_key)
                query = query.join(target_model)
                if many:
                    query = query.filter(target_id.in_(value))
                else:
                    query = query.filter(target_id == value)
    return query
class ModelCRUD:
    """Generic create/read/update/delete helper bound to one model class.

    The model is expected to expose ``model_config`` with an ``id_key`` and
    a ``cols`` mapping describing each column (including relation info).
    """

    def __init__(self, model: Type[TableModel]):
        self.model = model

    def _id_condition(self, data):
        """Build the ``model.<id_key> == data.<id_key>`` filter expression."""
        id_key = self.model.model_config.id_key
        return getattr(self.model, id_key) == getattr(data, id_key)

    @staticmethod
    def _load_m2m_targets(db, relation, values):
        """Fetch the target rows of a many-to-many relation as a set.

        ``values`` holds target ids; dict entries are also accepted, in
        which case the id is read from the relation's ``target_id_key``.
        """
        target_model = relation.target_model
        target_key = relation.target_id_key
        ids = [entry[target_key] if isinstance(entry, dict) else entry for entry in values]
        return set(db.query(target_model).filter(getattr(target_model, target_key).in_(ids)))

    def update(self, db: Session, data: BaseModel):
        """Update the row identified by ``data``'s id with its explicitly set fields.

        Field handling:

        * ``"rel.field"`` keys on a one-to-one relation set ``field`` on the
          related object, creating it when missing and the value is not None;
        * many-to-many columns take a list of target ids (or dicts carrying
          the id) and replace the collection; ``None`` leaves it untouched;
        * one-to-one columns given as a dict create or patch the related object;
        * other relation kinds are ignored; plain columns are assigned directly.

        Returns:
            The refreshed item, or ``None`` when no row matches the id.
        """
        model = self.model
        model_config: ModelConfig = model.model_config
        item = db.query(model).filter(self._id_condition(data)).first()
        if item is None:
            # Nothing to update; mirrors get(), which also yields None.
            return None

        values = data.dict(exclude_unset=True)
        col_names = get_model_col_names(model)
        for k, v in values.items():
            if '.' in k:
                paths = k.split('.')
                if paths[0] in col_names:
                    col_type_info: TypeInfo = model_config.cols[paths[0]]
                    relation = col_type_info.relation
                    # One-to-one via dotted key, e.g. {"config.age": 1}
                    if relation and relation.relation_type in ["o2o"]:
                        # Read the related object by the relation name, not
                        # by the full dotted key (which is never an attribute).
                        related = getattr(item, paths[0])
                        if not related:
                            if v is not None:
                                setattr(item, paths[0], relation.target_model(**{paths[1]: v}))
                        else:
                            setattr(related, paths[1], v)
            if k in col_names:
                col_type_info: TypeInfo = model_config.cols[k]
                relation = col_type_info.relation
                if not relation:
                    setattr(item, k, v)
                elif relation.relation_type in ['m2m']:
                    if v is not None:
                        # The payload is a list of ids: load the matching
                        # target rows first, then replace the collection.
                        setattr(item, k, self._load_m2m_targets(db, relation, v))
                elif relation.relation_type == 'o2o':
                    if isinstance(v, dict):
                        related = getattr(item, k)
                        if not related:
                            setattr(item, k, relation.target_model(**v))
                        else:
                            for v_k, v_v in v.items():
                                setattr(related, v_k, v_v)
        db.commit()
        db.refresh(item)
        return item

    def get(self, db: Session, data: BaseModel):
        """Return the first row whose id matches ``data``'s id, or ``None``."""
        return db.query(self.model).filter(self._id_condition(data)).first()

    def add(self, db: Session, data: BaseModel):
        """Create a row from ``data``'s explicitly set fields.

        Many-to-many columns take a list of target ids (or dicts carrying
        the id); one-to-one columns given as a dict create the related
        object; other relation kinds are ignored; plain columns are
        assigned directly.

        Returns:
            The newly added, refreshed item.
        """
        model = self.model
        item = model()
        values = data.dict(exclude_unset=True)
        col_names = get_model_col_names(model)
        for k, v in values.items():
            if k not in col_names:
                continue
            col_type_info: TypeInfo = model.model_config.cols[k]
            relation = col_type_info.relation
            if not relation:
                setattr(item, k, v)
            elif relation.relation_type == 'm2m':
                if v is not None:
                    setattr(item, k, self._load_m2m_targets(db, relation, v))
            elif relation.relation_type == 'o2o':
                if isinstance(v, dict):
                    setattr(item, k, relation.target_model(**v))
        db.add(item)
        db.commit()
        db.refresh(item)
        return item

    def query(self, db: Session, data: QueryParams) -> tuple:
        """Run ``query_common`` for this model.

        Returns:
            The ``(count, query, page, page_size)`` tuple from ``query_common``.
        """
        return query_common(db, self.model, data)

    def delete(self, db: Session, data: BaseModel):
        """Delete rows whose id matches ``data``'s id and commit.

        Returns:
            Always ``True``, whether or not any row matched.
        """
        db.query(self.model).filter(self._id_condition(data)).delete()
        db.commit()
        return True