import json
from typing import Type, List, Any, Dict, Tuple

from pydantic import BaseModel
from sqlalchemy import Column, asc, desc, func, String, and_, Table
from sqlalchemy.orm import DeclarativeMeta, Session, Query

from .model_config_utils import SalBase, get_model_col_names
from .types.crud_schema import QueryParams, OrderParam, QueryParam
from .types.model_config import ModelConfig, TypeInfo


class TableModel(DeclarativeMeta, SalBase):
    # Typing helper for the model classes handled below. The code in this module
    # reads `model.model_config.cols` and `model.model_config.id_key` from them.
    pass


def query_common(db: Session, model, query_params: QueryParams):
    """Generic database query: build a filtered, ordered query for ``model``.

    Ordering is applied first, then the typed ``query_params.query`` filters,
    then the generic ``query_params.params`` filters. Pagination is NOT applied
    here; ``page`` and ``page_size`` are passed through for the caller.

    :returns: ``(count, query, page, page_size)`` where ``count`` is the number
        of rows matched by the filtered query.
    """
    query = db.query(model)
    if query_params.order:
        query = query_order(model, query, query_params.order)
    if query_params.query:
        query = query_common_query_core(model, query, query_params.query)
    if query_params.params:
        query = query_common_params_core(model, query, query_params.params)
    count = query.count()
    return count, query, query_params.page, query_params.page_size


def query_order(model, query: Query, order_param: OrderParam):
    """Apply ``ORDER BY`` to ``query`` according to ``order_param``.

    ``order_param.order_by`` names a model attribute (AttributeError if it does
    not exist); ``order_param.type`` is ``'asc'`` or ``'desc'``. Any other type,
    a missing ``order_param`` or an empty ``order_by`` leaves the query unchanged.
    """
    if not order_param or not order_param.order_by:
        return query
    column: Column = getattr(model, order_param.order_by)
    if order_param.type == 'asc':
        return query.order_by(asc(column))
    if order_param.type == 'desc':
        return query.order_by(desc(column))
    return query


def _secondary_column(secondary, key: str):
    """Return column ``key`` of an m2m association ``secondary``.

    A ``Table`` is looked up through ``.columns``; anything else (e.g. a mapped
    association class) is looked up as an attribute.
    """
    if isinstance(secondary, Table):
        return secondary.columns.get(key)
    return getattr(secondary, key)


def _filter_m2m(model, query: Query, relation, value):
    """Join the association table of an m2m ``relation`` and filter on target ids.

    A non-empty list ``value`` is matched with ``IN``; any other value with ``==``.
    """
    secondary = relation.secondary
    source_col = _secondary_column(secondary, relation.source_secondary_key)
    target_col = _secondary_column(secondary, relation.target_secondary_key)
    query = query.join(secondary, model.id == source_col)
    if value and isinstance(value, list):
        return query.filter(target_col.in_(value))
    return query.filter(target_col == value)


def _join_to_one(model, query: Query, relation):
    """Join the target model of an m2o/o2o ``relation``.

    Uses an explicit ON clause when ``relation.source_foreign_key`` is set,
    otherwise lets SQLAlchemy infer the join condition.
    """
    target = relation.target_model
    if relation.source_foreign_key:
        on_clause = getattr(target, relation.target_id_key) == getattr(model, relation.source_foreign_key)
        return query.join(target, on_clause)
    return query.join(target)


def query_common_params_core(model, query: Query, params: List[QueryParam]):
    """Apply generic ``QueryParam`` filters to ``query``.

    Each param has ``name``, ``type`` (the operator) and ``value``. ``name`` is a
    column name, optionally ``"col.child"`` to reach a field of an m2o/o2o
    related model (the related model is joined) or a key of a JSON column.
    Supported operators: ``=``/``==``, ``>``, ``<``, ``>=``, ``<=``, ``in``,
    ``like`` (substring match on the column cast to string), ``find_in_set`` and
    ``range`` (a one- or two-element list; ``None`` bounds are skipped).
    Unknown operators are ignored. Raises KeyError for a name not in
    ``model.model_config.cols``.
    """
    cols: Dict[str, TypeInfo] = model.model_config.cols
    for item in params:
        name, child_name = item.name, ""
        if '.' in name:
            name, child_name = name.split('.')
        query_type = item.type
        value = item.value
        col = cols[name]
        col_base_type = col.col_base_type
        relation = col.relation if col_base_type == 'relation' else None
        column: Column = getattr(model, name)

        # "rel.field" on a to-one relation: join it and filter the target's column.
        # "json_col.key": filter on that JSON element.
        nested_to_one = bool(
            relation and child_name and relation.relation_type in ('m2o', 'o2o')
        )
        if nested_to_one:
            query = _join_to_one(model, query, relation)
            column = getattr(relation.target_model, child_name)
        elif col_base_type == 'json' and child_name:
            column = column[child_name]

        if query_type in ('=', '=='):
            if relation and not child_name:
                # Direct relation comparison is only supported for m2m.
                if relation.relation_type == 'm2m':
                    query = _filter_m2m(model, query, relation, value)
            elif col_base_type == 'datetime':
                query = query.filter(func.date(column) == value)
            elif col_base_type == 'json':
                query = query.filter(column == json.dumps(value))
            else:
                query = query.filter(column == value)
        elif query_type == '>':
            query = query.filter(column > value)
        elif query_type == '<':
            query = query.filter(column < value)
        elif query_type == '>=':
            query = query.filter(column >= value)
        elif query_type == '<=':
            query = query.filter(column <= value)
        elif query_type == 'in':
            if relation and relation.relation_type == 'm2m':
                query = _filter_m2m(model, query, relation, value)
            elif relation and not nested_to_one:
                pass  # `column` is still the relationship itself: nothing to compare
            else:
                query = query.filter(column.in_(value))
        elif query_type == 'like':
            query = query.filter(column.cast(String).like(f'%{value}%'))
        elif query_type == 'find_in_set':
            query = query.filter(func.find_in_set(value, column))
        elif query_type == 'range' and value:
            if len(value) == 1 and value[0] is not None:
                query = query.filter(column >= value[0])
            elif len(value) == 2:
                low, high = value[0], value[1]
                if low is not None and high is not None:
                    query = query.filter(and_(column >= low, column <= high))
                elif low is not None:
                    query = query.filter(column >= low)
                elif high is not None:
                    query = query.filter(column <= high)
    return query


def query_common_query_core(model, query: Query, query_data: BaseModel):
    """Apply filters from a typed pydantic query model to ``query``.

    Only fields explicitly set on ``query_data`` whose name is in
    ``model.model_config.cols`` are used. By column base type:

    * int/float/any/enum/str: non-empty list -> ``IN``; non-empty value on a
      str column -> ``LIKE '%value%'``; otherwise ``==``.
    * date/datetime/time (falsy values ignored): a two-element list is a range
      with optional ``None`` bounds; a scalar is compared with ``==``
      (datetime columns are compared by their date part).
    * relation m2m: join the association table and filter on target ids.
    * relation o2m: join the target model and filter on its id key.
    """
    q_data = query_data.dict(exclude_unset=True)
    cols: Dict[str, TypeInfo] = model.model_config.cols
    for key, value in q_data.items():
        col = cols.get(key)
        if col is None:
            continue
        column: Column = getattr(model, key)
        col_base_type = col.col_base_type
        if col_base_type in ('int', 'float', 'any', 'enum', 'str'):
            if value and isinstance(value, list):
                query = query.filter(column.in_(value))
            elif col_base_type == 'str' and value:
                query = query.filter(column.like(f'%{value}%'))
            else:
                query = query.filter(column == value)
        elif col_base_type in ('date', 'datetime', 'time'):
            if not value:
                continue
            if isinstance(value, list):
                if len(value) == 2:
                    low, high = value[0], value[1]
                    if low is not None:
                        query = query.filter(column >= low)
                    if high is not None:
                        query = query.filter(column <= high)
            elif col_base_type == 'datetime':
                # Convert the datetime column to a date for the comparison.
                query = query.filter(func.date(column) == value)
            else:
                query = query.filter(column == value)
        elif col_base_type == 'relation' and col.relation:
            relation = col.relation
            if relation.relation_type == 'm2m':
                query = _filter_m2m(model, query, relation, value)
            elif relation.relation_type == 'o2m':
                target_model = relation.target_model
                target_col = getattr(target_model, relation.target_id_key)
                query = query.join(target_model)
                if value and isinstance(value, list):
                    query = query.filter(target_col.in_(value))
                else:
                    query = query.filter(target_col == value)
    return query


class ModelCRUD:
    """Generic create/read/update/delete helpers for one SQLAlchemy model.

    The model must expose ``model_config`` with ``id_key`` (name of the column
    used to identify rows) and ``cols`` (per-column ``TypeInfo``).
    """

    def __init__(self, model: Type[TableModel]):
        self.model = model

    def _id_filter(self, data: BaseModel):
        """Return the criterion ``model.<id_key> == data.<id_key>``."""
        id_key = self.model.model_config.id_key
        return getattr(self.model, id_key) == getattr(data, id_key)

    @staticmethod
    def _load_m2m_targets(db: Session, relation, values):
        """Load the related rows to assign to an m2m relation.

        ``values`` holds target ids; entries that are dicts are reduced to the id
        stored under ``relation.target_id_key``. Returns a set of target rows.
        """
        target_model = relation.target_model
        target_key = relation.target_id_key
        ids = [entry[target_key] if isinstance(entry, dict) else entry for entry in values]
        return set(db.query(target_model).filter(getattr(target_model, target_key).in_(ids)))

    def update(self, db: Session, data: BaseModel):
        """Update the row identified by ``data.<id_key>`` with the fields set on ``data``.

        * Plain columns are assigned directly.
        * m2m relations receive a list of target ids (or dicts carrying the id);
          the matching rows replace the current collection. ``None`` is ignored.
        * o2o relations receive a dict: fields are written onto the existing
          related row, or a new related row is created when there is none.
        * For o2o, a dotted key such as ``{"config.age": 1}`` updates one field of
          the related row (creating it when absent and the value is not None).

        Commits, refreshes and returns the row.
        """
        model = self.model
        item = db.query(model).filter(self._id_filter(data)).first()
        changes = data.dict(exclude_unset=True)
        col_names = get_model_col_names(model)
        model_config: ModelConfig = model.model_config
        cols = model_config.cols
        for key, value in changes.items():
            if '.' in key:
                paths = key.split('.')
                head, field = paths[0], paths[1]
                if head in col_names:
                    relation = cols[head].relation
                    if relation and relation.relation_type == 'o2o':
                        child = getattr(item, head)
                        if not child:
                            if value is not None:
                                setattr(item, head, relation.target_model(**{field: value}))
                        else:
                            setattr(child, field, value)
            if key not in col_names:
                continue
            relation = cols[key].relation
            if not relation:
                setattr(item, key, value)
            elif relation.relation_type == 'm2m':
                if value is not None:
                    setattr(item, key, self._load_m2m_targets(db, relation, value))
            elif relation.relation_type == 'o2o' and isinstance(value, dict):
                child = getattr(item, key)
                if not child:
                    setattr(item, key, relation.target_model(**value))
                else:
                    for field, field_value in value.items():
                        setattr(child, field, field_value)
        db.commit()
        db.refresh(item)
        return item

    def get(self, db: Session, data: BaseModel):
        """Return the row whose id equals ``data.<id_key>``, or None if absent."""
        return db.query(self.model).filter(self._id_filter(data)).first()

    def add(self, db: Session, data: BaseModel):
        """Create a row from the fields set on ``data``; commit, refresh and return it.

        Plain columns are assigned directly; m2m relations take a list of target
        ids (or dicts carrying the id; ``None`` is ignored); o2o relations take a
        dict used to construct the related row. Other relation kinds are ignored.
        """
        model = self.model
        item = model()
        col_names = get_model_col_names(model)
        cols = model.model_config.cols
        for key, value in data.dict(exclude_unset=True).items():
            if key not in col_names:
                continue
            relation = cols[key].relation
            if not relation:
                setattr(item, key, value)
            elif relation.relation_type == 'm2m':
                if value is not None:
                    setattr(item, key, self._load_m2m_targets(db, relation, value))
            elif relation.relation_type == 'o2o' and isinstance(value, dict):
                setattr(item, key, relation.target_model(**value))
        db.add(item)
        db.commit()
        db.refresh(item)
        return item

    def query(self, db: Session, data: QueryParams) -> Tuple[int, Query, Any, Any]:
        """Run ``query_common``; returns ``(count, query, page, page_size)``."""
        return query_common(db, self.model, data)

    def delete(self, db: Session, data: BaseModel):
        """Delete rows whose id equals ``data.<id_key>``, commit, and return True."""
        db.query(self.model).filter(self._id_filter(data)).delete()
        db.commit()
        return True