wd-smebiz-client/utils/sal_utils/crud_utils.py

297 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
from typing import Type, List, Any, Dict, Tuple

from pydantic import BaseModel
from sqlalchemy import Column, asc, desc, func, String, and_
from sqlalchemy.orm import DeclarativeMeta, Session, Query

from .model_config_utils import SalBase, get_model_col_names
from .types.crud_schema import QueryParams, OrderParam, QueryParam
from .types.model_config import ModelConfig, TypeInfo

class TableModel(DeclarativeMeta, SalBase):
    """Type used to annotate the model class handed to ``ModelCRUD``.

    Adds nothing of its own on top of ``DeclarativeMeta`` and ``SalBase``.
    NOTE(review): ``DeclarativeMeta`` is SQLAlchemy's declarative *metaclass*,
    so this class is itself a metaclass; it appears to exist only for type
    hints — confirm it is never used as a model base.
    """
def query_common(db: Session, model, query_params: QueryParams):
    """Build the generic list query for ``model`` and count its matching rows.

    The query is assembled in this order: ordering from ``query_params.order``,
    then the typed filter object ``query_params.query``, then the generic
    filter list ``query_params.params``. Each step runs only when its value is
    truthy. No LIMIT/OFFSET is applied here; the page values are returned
    unchanged (presumably so the caller can paginate the returned query).

    :param db: SQLAlchemy session the query is started from.
    :param model: mapped model class to query.
    :param query_params: ordering, filtering and paging options.
    :return: ``(count, query, page, page_size)``. ``count`` is the number of
        rows matching the filters; ``query`` is the ordered, filtered query.
    """
    base = db.query(model)

    ordering = query_params.order
    if ordering:
        base = query_order(model, base, ordering)

    typed_filters = query_params.query
    if typed_filters:
        base = query_common_query_core(model, base, typed_filters)

    generic_filters = query_params.params
    if generic_filters:
        base = query_common_params_core(model, base, generic_filters)

    total = base.count()
    return total, base, query_params.page, query_params.page_size
def query_order(model, query: Query, order_param: OrderParam):
    """Add a single ORDER BY clause described by ``order_param``.

    ``query`` is returned unchanged when ``order_param`` is falsy, when it
    names no column, or when its ``type`` is neither ``'asc'`` nor ``'desc'``.

    :raises AttributeError: if ``order_param.order_by`` is not an attribute
        of ``model``.
    """
    if not order_param or not order_param.order_by:
        return query

    sort_column: Column = getattr(model, order_param.order_by)
    if order_param.type == 'asc':
        return query.order_by(asc(sort_column))
    if order_param.type == 'desc':
        return query.order_by(desc(sort_column))
    return query
def _m2m_join_filter(model, query: Query, relation, value, allow_in: bool = False):
    """Join ``model`` to an m2m association table and filter on its target key.

    The association table ``relation.secondary`` is joined on
    ``model.id == secondary.<source_secondary_key>``. The filter is
    ``secondary.<target_secondary_key> IN value`` when ``allow_in`` is set and
    ``value`` is a non-empty list, otherwise ``... == value``.
    """
    secondary = relation.secondary
    source_col = secondary.columns.get(relation.source_secondary_key)
    target_col = secondary.columns.get(relation.target_secondary_key)
    query = query.join(secondary, model.id == source_col)
    if allow_in and value and isinstance(value, list):
        return query.filter(target_col.in_(value))
    return query.filter(target_col == value)


def query_common_params_core(model, query: Query, params: List[QueryParam]):
    """Apply the generic filter list ``params`` to ``query``.

    Each item carries a column ``name``, an operator ``type`` and an operand
    ``value``. A dotted name ``"col.child"`` addresses either a column
    ``child`` of the target model of an m2o/o2o relation ``col`` (the target
    is joined, once per relation), or a key ``child`` inside JSON column ``col``.

    Operators: ``=``/``==`` (datetime columns compare on DATE(), JSON values
    are compared against ``json.dumps(value)``, a bare m2m relation name
    filters through its association table), ``>``, ``<``, ``>=``, ``<=``,
    ``in``, ``like`` (substring match on the column cast to string),
    ``find_in_set`` (MySQL ``FIND_IN_SET``) and ``range`` (``[low]`` or
    ``[low, high]``, inclusive, either bound may be ``None``). Unknown
    operators are ignored.

    :raises KeyError: if a name is not in ``model.model_config.cols``.
    :raises ValueError: if a name contains more than one ``'.'``.
    :raises AttributeError: if an m2o/o2o relation is used without a child
        column name.
    """
    cols: Dict[str, TypeInfo] = model.model_config.cols
    # m2o/o2o relations already joined: joining the same target again would
    # put the same table into FROM twice.
    joined_relations = set()
    for item in params:
        name = item.name
        child_name = ""
        if '.' in name:
            name, child_name = name.split('.')
        query_type = item.type
        value = item.value
        col = cols[name]
        col_base_type = col.col_base_type
        relation = col.relation
        column: Column = getattr(model, name)

        # Resolve the SQL expression the operator is applied to.
        if col_base_type == 'relation' and relation:
            if relation.relation_type in ['m2o', 'o2o']:
                target_model = relation.target_model
                if name not in joined_relations:
                    joined_relations.add(name)
                    if relation.source_foreign_key:
                        query = query.join(
                            target_model,
                            getattr(target_model, relation.target_id_key)
                            == getattr(model, relation.source_foreign_key),
                        )
                    else:
                        query = query.join(target_model)
                column = getattr(target_model, child_name)
        elif col_base_type == 'json' and child_name:
            # Only index into the JSON document when a key was given; a bare
            # JSON column name filters on the whole column.
            column = column[child_name]

        if query_type in ['=', '==']:
            if col_base_type == 'relation' and relation and not child_name:
                # A bare relation name is only supported for m2m; other
                # relation kinds add no filter.
                if relation.relation_type == 'm2m':
                    query = _m2m_join_filter(model, query, relation, value)
            elif col_base_type == 'datetime':
                # Match on the calendar date of the stored datetime.
                query = query.filter(func.date(column) == value)
            elif col_base_type == 'json':
                query = query.filter(column == json.dumps(value))
            else:
                query = query.filter(column == value)
        elif query_type == '>':
            query = query.filter(column > value)
        elif query_type == '<':
            query = query.filter(column < value)
        elif query_type == '>=':
            query = query.filter(column >= value)
        elif query_type == '<=':
            query = query.filter(column <= value)
        elif query_type == 'in':
            if col_base_type == 'relation' and relation:
                if relation.relation_type == 'm2m':
                    query = _m2m_join_filter(model, query, relation, value, allow_in=True)
                elif relation.relation_type in ['m2o', 'o2o']:
                    # ``column`` is the joined target column resolved above.
                    query = query.filter(column.in_(value))
            else:
                query = query.filter(column.in_(value))
        elif query_type == 'like':
            query = query.filter(column.cast(String).like(f'%{value}%'))
        elif query_type == 'find_in_set':
            query = query.filter(func.find_in_set(value, column))
        elif query_type == 'range':
            if value and len(value) in (1, 2):
                low = value[0]
                high = value[1] if len(value) == 2 else None
                if low is not None and high is not None:
                    query = query.filter(and_(column >= low, column <= high))
                elif low is not None:
                    query = query.filter(column >= low)
                elif high is not None:
                    query = query.filter(column <= high)
    return query
def _association_column(secondary, key):
    """Return column ``key`` of an m2m association object.

    A SQLAlchemy ``Table`` exposes its columns through ``.columns``; objects
    without a ``columns`` attribute fall back to plain attribute access.
    """
    columns = getattr(secondary, 'columns', None)
    if columns is not None:
        return columns.get(key)
    return getattr(secondary, key)


def query_common_query_core(model, query: Query, query_data: BaseModel):
    """Apply the typed filter object ``query_data`` to ``query``.

    Only fields explicitly set on ``query_data`` that are also keys of
    ``model.model_config.cols`` are used. Each is translated according to the
    column's base type:

    * ``int``/``float``/``any``/``enum``/``str``: a non-empty list becomes
      ``IN``; a truthy value on a ``str`` column becomes ``LIKE '%value%'``;
      anything else is compared with ``==``.
    * ``date``/``datetime``/``time``: falsy values are skipped; a two-element
      list ``[start, end]`` is an inclusive range where either bound may be
      ``None`` (lists of other lengths are ignored); a scalar on a
      ``datetime`` column matches the column's DATE(), otherwise ``==``.
    * ``relation``: ``m2m`` joins the association table and filters on the
      target key; ``o2m`` joins the target model and filters on its
      ``target_id_key``. A non-empty list uses ``IN``, anything else ``==``.
      Other relation kinds are ignored.
    """
    cols: Dict[str, TypeInfo] = model.model_config.cols
    for key, value in query_data.dict(exclude_unset=True).items():
        if key not in cols:
            continue
        column: Column = getattr(model, key)
        col = cols[key]
        col_base_type = col.col_base_type

        if col_base_type in ['int', 'float', 'any', 'enum', 'str']:
            if isinstance(value, list) and value:
                query = query.filter(column.in_(value))
            elif col_base_type == 'str' and value:
                query = query.filter(column.like(f'%{value}%'))
            else:
                query = query.filter(column == value)

        elif col_base_type in ['date', 'datetime', 'time']:
            if not value:
                continue
            if isinstance(value, list):
                if len(value) == 2:
                    start, end = value
                    if start is not None:
                        query = query.filter(column >= start)
                    if end is not None:
                        query = query.filter(column <= end)
            elif col_base_type == 'datetime':
                # Compare a datetime column on its calendar date.
                query = query.filter(func.date(column) == value)
            else:
                query = query.filter(column == value)

        elif col_base_type == 'relation' and col.relation:
            relation = col.relation
            use_in = bool(value) and isinstance(value, list)
            if relation.relation_type == 'm2m':
                secondary = relation.secondary
                source_col = _association_column(secondary, relation.source_secondary_key)
                target_col = _association_column(secondary, relation.target_secondary_key)
                query = query.join(secondary, model.id == source_col)
            elif relation.relation_type == 'o2m':
                target_model = relation.target_model
                target_col = getattr(target_model, relation.target_id_key)
                query = query.join(target_model)
            else:
                continue
            if use_in:
                query = query.filter(target_col.in_(value))
            else:
                query = query.filter(target_col == value)
    return query
class ModelCRUD:
    """Generic create / read / update / delete helpers bound to one model class.

    The model's ``model_config`` supplies the primary-key attribute name
    (``id_key``) and per-column type info (``cols``, each with an optional
    ``relation``) that decide how incoming fields are written. ``update``,
    ``add`` and ``delete`` commit the session themselves.
    """

    def __init__(self, model: Type[TableModel]):
        self.model = model

    def _id_clause(self, data: BaseModel):
        """Build ``model.<id_key> == data.<id_key>`` to select one row."""
        id_key = self.model.model_config.id_key
        return getattr(self.model, id_key) == getattr(data, id_key)

    @staticmethod
    def _load_m2m_targets(db: Session, relation, ids):
        """Load the target rows of m2m ``relation`` whose id is in ``ids``.

        ``ids`` may hold raw id values or dicts (e.g. serialized target
        records); dicts are reduced to their ``relation.target_id_key`` entry.
        Returns a ``set`` of target instances.
        """
        target_model = relation.target_model
        target_key = relation.target_id_key
        if ids and isinstance(ids[0], dict):
            ids = [entry[target_key] for entry in ids]
        return set(db.query(target_model).filter(getattr(target_model, target_key).in_(ids)))

    def _update_dotted_field(self, item, key: str, value, col_names) -> None:
        """Handle an update key of the form ``"relation.field"``.

        Only one-to-one relations are handled: ``field`` is set on the related
        object, or, when that object is missing and ``value`` is not ``None``,
        a new ``relation.target_model(field=value)`` is attached. Anything
        else is ignored.
        """
        paths = key.split('.')
        relation_name, field = paths[0], paths[1]
        if relation_name not in col_names:
            return
        relation = self.model.model_config.cols[relation_name].relation
        if not relation or relation.relation_type != 'o2o':
            return
        related = getattr(item, relation_name)
        if related:
            setattr(related, field, value)
        elif value is not None:
            setattr(item, relation_name, relation.target_model(**{field: value}))

    def update(self, db: Session, data: BaseModel):
        """Apply the explicitly-set fields of ``data`` to an existing row and commit.

        The row is the one whose ``id_key`` equals ``data``'s. Per field:

        * ``"rel.field"`` on a one-to-one relation ``rel``: sets ``field`` on
          the related object, creating it when missing (unless value is None).
        * many-to-many relation: a list of target ids, or of dicts carrying
          the id under the relation's ``target_id_key``; the collection is
          replaced by the matching target rows. ``None`` leaves it untouched.
        * one-to-one relation: a dict; an existing related object is updated
          in place, otherwise a new one is created. Non-dict values are ignored.
        * other relation kinds are ignored; plain columns are assigned directly.

        Keys not returned by ``get_model_col_names`` are skipped.

        :return: the updated instance, refreshed after commit.
        """
        model = self.model
        item = db.query(model).filter(self._id_clause(data)).first()
        # NOTE(review): ``first()`` yields None when no row matches; the
        # attribute writes / refresh below then raise.
        fields = data.dict(exclude_unset=True)
        col_names = get_model_col_names(model)
        cols = model.model_config.cols
        for key, value in fields.items():
            if '.' in key:
                self._update_dotted_field(item, key, value, col_names)
            if key not in col_names:
                continue
            relation = cols[key].relation
            if not relation:
                setattr(item, key, value)
            elif relation.relation_type == 'm2m':
                if value is not None:
                    setattr(item, key, self._load_m2m_targets(db, relation, value))
            elif relation.relation_type == 'o2o':
                if isinstance(value, dict):
                    related = getattr(item, key)
                    if related:
                        for attr, attr_value in value.items():
                            setattr(related, attr, attr_value)
                    else:
                        setattr(item, key, relation.target_model(**value))
        db.commit()
        db.refresh(item)
        return item

    def get(self, db: Session, data: BaseModel):
        """Return the row whose ``id_key`` equals ``data``'s, or ``None``."""
        return db.query(self.model).filter(self._id_clause(data)).first()

    def add(self, db: Session, data: BaseModel):
        """Create a row from the explicitly-set fields of ``data`` and commit.

        Many-to-many fields take a list of target ids (or of dicts carrying
        the id under the relation's ``target_id_key``) and are resolved to
        target rows; ``None`` is skipped. One-to-one fields take a dict used
        to construct the related object; non-dict values are ignored. Other
        relation kinds are ignored; plain columns are assigned directly. Keys
        not returned by ``get_model_col_names`` are skipped.

        :return: the new instance, refreshed after commit.
        """
        model = self.model
        item = model()
        fields = data.dict(exclude_unset=True)
        col_names = get_model_col_names(model)
        cols = model.model_config.cols
        for key, value in fields.items():
            if key not in col_names:
                continue
            relation = cols[key].relation
            if not relation:
                setattr(item, key, value)
            elif relation.relation_type == 'm2m':
                if value is not None:
                    setattr(item, key, self._load_m2m_targets(db, relation, value))
            elif relation.relation_type == 'o2o':
                if isinstance(value, dict):
                    setattr(item, key, relation.target_model(**value))
        db.add(item)
        db.commit()
        db.refresh(item)
        return item

    def query(self, db: Session, data: QueryParams) -> Tuple[int, Query, int, int]:
        """Run the generic list query for this model via ``query_common``.

        :return: ``(count, query, page, page_size)``.
        """
        return query_common(db, self.model, data)

    def delete(self, db: Session, data: BaseModel):
        """Delete the row(s) whose ``id_key`` equals ``data``'s and commit.

        Uses ``Query.delete()``. Always returns ``True``, whether or not a
        row matched.
        """
        db.query(self.model).filter(self._id_clause(data)).delete()
        db.commit()
        return True