from typing import TypeVar, Generic, Any, List, get_args, Union, Optional, Tuple

from sqlalchemy import func
from sqlalchemy.orm import Session
from datetime import datetime, date
from Schemas.AuthSchemas import DefaultAuthConfigTypeEnum
from Utils.CommonUtils import get_sqlalchemy_model_fields
from Utils.SqlAlchemyUtils import Base, get_db
from pydantic import BaseModel, create_model
from fastapi import APIRouter, Depends, HTTPException
from pydantic.generics import GenericModel
from Schemas.DailySchemas import DailyTypeEnum
from Schemas.UserSchemas import DepartmentTypeEnum


class QueryBase(BaseModel):
    """Base schema for query requests carrying optional pagination fields."""

    page_size: Union[int, None]
    page: Union[int, None]


class ItemId(BaseModel):
    """Request body carrying a single item id (used by the delete route)."""

    id: Union[str, int]


ModelType = TypeVar("ModelType", bound=BaseModel)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
IdSchemaType = TypeVar("IdSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
QuerySchemaType = TypeVar("QuerySchemaType", bound=QueryBase)
DbModelType = TypeVar("DbModelType", bound=Base)
DataType = TypeVar("DataType", BaseModel, Any)


class BaseResponseModel(GenericModel, Generic[DataType]):
    """Uniform response envelope: ``state``, ``msg`` and a typed ``data`` payload."""

    state: int
    msg: str
    data: DataType


class CRUDBase(Generic[DbModelType, ModelType, IdSchemaType, CreateSchemaType, UpdateSchemaType, QuerySchemaType]):
    """Generic CRUD helper that registers add/update/delete/get/query routes.

    Every method whose name ends in ``_create`` is a route factory; ``mount``
    calls each of them to attach the routes to a router.
    """

    def __init__(self,
                 db_model: DbModelType,
                 model_type: BaseModel,
                 id_schema_type: IdSchemaType,
                 create_schema_type: CreateSchemaType,
                 update_schema_type: UpdateSchemaType,
                 query_schema_type: QuerySchemaType,
                 name: str,
                 chinese_name: str = '',
                 array_keys=None,
                 query_dependencies=None,
                 tags=None
                 ):
        """Store the models and routing options used by the route factories.

        :param db_model: SQLAlchemy model class the CRUD operates on
        :param model_type: pydantic model used for single-item responses
        :param id_schema_type: request schema of the ``get`` route
        :param create_schema_type: request schema of the ``add`` route
        :param update_schema_type: request schema of the ``update`` route
        :param query_schema_type: request schema of the ``query`` route
        :param name: URL prefix segment; also the fallback summary name
        :param chinese_name: display name used in route summaries (falls back to ``name``)
        :param array_keys: keys queried with ``find_in_set`` against a stored value
            such as ``1,2,3``
        :param query_dependencies: dependencies attached to the query route
        :param tags: tags attached to every route
        """
        self.db_model = db_model
        self.name = name
        self.model_type = model_type
        self.id_schema_type = id_schema_type
        self.create_schema_type = create_schema_type
        self.update_schema_type = update_schema_type
        self.query_schema_type = query_schema_type
        self.summary_name = chinese_name if chinese_name else self.name
        # Optional hook called as query_func(query, key, value); a truthy return
        # value replaces the query and skips the default filter for that key.
        self.query_func = None
        # None sentinels instead of mutable defaults, so instances never share
        # (and accidentally mutate) the same list object.
        self.array_keys = array_keys if array_keys is not None else []
        self.query_dependencies = query_dependencies if query_dependencies is not None else []
        self.tags = tags if tags is not None else []

    def mount(self, router: APIRouter):
        """Register every ``*_create`` route factory of this object on ``router``."""
        for key in self.__dir__():
            if key.endswith('_create'):
                getattr(self, key)(router)

    def add_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/add`` and return the handler."""
        create_schema_type = self.create_schema_type

        @router.post(f'/{self.name}/add', response_model=BaseResponseModel[self.model_type], tags=self.tags,
                     summary=f"添加{self.summary_name}")
        def add_route_func(create_model_obj: create_schema_type, db: Session = Depends(get_db)):
            new_item = self.add(db, create_model_obj)
            item_info = new_item.to_dict()
            return BaseResponseModel(state=1, msg="", data=self.model_type(**item_info))

        return add_route_func

    def update_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/update`` and return the handler."""
        update_schema_type = self.update_schema_type

        @router.post(f'/{self.name}/update', tags=self.tags, response_model=BaseResponseModel[Any],
                     summary=f"修改{self.summary_name}")
        def update_route_func(update_model_obj: update_schema_type, db: Session = Depends(get_db)):
            self.update(db, update_model_obj)
            return BaseResponseModel(state=1, msg="", data={})

        return update_route_func

    def delete_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/delete`` and return the handler."""

        @router.post(f'/{self.name}/delete', tags=self.tags, response_model=BaseResponseModel[Any],
                     summary=f"删除{self.summary_name}")
        def delete_route_func(item_id: ItemId, db: Session = Depends(get_db)):
            self.delete(db, item_id.id)
            return BaseResponseModel(state=1, msg="", data={})

        return delete_route_func

    def get_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/get`` and return the handler."""
        id_schema_type = self.id_schema_type

        @router.post(f'/{self.name}/get', tags=self.tags, response_model=BaseResponseModel[self.model_type],
                     summary=f"获取{self.summary_name}")
        def get_route_func(item_id: id_schema_type, db: Session = Depends(get_db)):
            # The id schema holds a single field; its value is the lookup key.
            index_value = list(item_id.dict().values())[0]
            item = self.get(db, index_value)
            if not item:
                # NOTE(review): 303 is a redirect status; 404 looks more appropriate
                # for "not found", but it is kept because clients may rely on it.
                raise HTTPException(detail="未查询到数据", status_code=303)
            return BaseResponseModel(state=1, msg="", data=self.model_type(**item.to_dict()))

        return get_route_func

    def query_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/query`` and return the handler."""
        model_type = self.model_type
        query_schema_type = self.query_schema_type
        class_name = f"{self.name}QueryRes"
        # Build the per-resource result model with create_model rather than
        # exec(): same required fields, and it no longer needs self.name to be
        # a valid Python identifier.
        QueryRes = create_model(class_name, count=(int, ...), item_list=(List[model_type], ...))

        @router.post(f'/{self.name}/query', response_model=BaseResponseModel[QueryRes], tags=self.tags,
                     summary=f"查询{self.summary_name}", dependencies=self.query_dependencies)
        def query_route_func(params: query_schema_type, db: Session = Depends(get_db)):
            count, query = self.query(db, params)
            item_list = [item.to_dict() for item in query]
            return BaseResponseModel(state=1, msg="", data=QueryRes(count=count, item_list=item_list))

        return query_route_func

    def add(self, db: Session, create_model_obj: CreateSchemaType) -> DbModelType:
        """Insert a row built from ``create_model_obj`` and return the refreshed DB object."""
        new_item = self.db_model(**create_model_obj.dict())
        db.add(new_item)
        db.commit()
        db.refresh(new_item)
        return new_item

    def delete(self, db: Session, item_id):
        """Delete the row with ``id == item_id``, or every row whose id is in ``item_id``.

        :param db: database session
        :param item_id: a single id, or a list/tuple/set of ids
        """
        if isinstance(item_id, (list, tuple, set)):
            # Skip in-Python session synchronisation: it may be unable to
            # evaluate an IN criterion, and the commit below expires the
            # session state anyway.
            db.query(self.db_model).filter(
                getattr(self.db_model, 'id').in_(list(item_id))
            ).delete(synchronize_session=False)
        else:
            db.query(self.db_model).filter_by(id=item_id).delete()
        db.commit()

    def update(self, db: Session, update_model_obj: UpdateSchemaType):
        """Update the row whose id equals ``update_model_obj.id``.

        Only fields that were explicitly provided are written, so omitted
        fields are left alone instead of being overwritten with their defaults.
        A field explicitly sent as ``None`` is still written.
        """
        update_values = update_model_obj.dict(exclude_unset=True)
        db.query(self.db_model).filter_by(id=update_model_obj.id).update(update_values)
        db.commit()

    def query(self, db: Session, params: QuerySchemaType) -> Tuple[int, Any]:
        """Filter rows by every non-``None`` field of ``params``, with optional paging.

        :param db: database session
        :param params: query schema instance; ``page``/``page_size`` are paging
            fields, every other key is treated as a filter
        :return: ``(count, rows)`` where ``count`` is the number of matches before
            paging. ``rows`` is a list when both ``page`` and ``page_size`` are
            given, otherwise the un-executed query object.
        """
        params_dict = params.dict()
        query = db.query(self.db_model)
        for key, value in params_dict.items():
            if key in ('page', 'page_size') or value is None:
                continue
            # Search inside a stored array value: the column holds e.g. "1,2,3,4"
            # and the request supplies something like [1, 2].
            if key in self.array_keys:
                for item in value:
                    query = query.filter(func.find_in_set(str(item), getattr(self.db_model, key)))
                continue
            # If query_func returns a value it has consumed key/value itself,
            # so the default handling below is skipped.
            if self.query_func:
                query_temp = self.query_func(query, key, value)
                if query_temp:
                    query = query_temp
                    continue
            # Exact type checks are deliberate: isinstance() would also match
            # subclasses (e.g. str/int based enum values), which were never
            # filtered by the branches below.
            value_type = type(value)
            if value_type is str:
                query = query.filter(getattr(self.db_model, key).like(f'%{value}%'))
            elif value_type in (int, float, bool):
                query = query.filter_by(**{key: value})
            elif value_type in (list, tuple):
                # Date range query: [start, end] as timestamps, either may be
                # None. Length guards avoid an IndexError on short lists.
                range_start = value[0] if len(value) > 0 else None
                range_end = value[1] if len(value) > 1 else None
                if range_start is not None:
                    query = query.filter(getattr(self.db_model, key) >= datetime.fromtimestamp(range_start))
                if range_end is not None:
                    query = query.filter(getattr(self.db_model, key) <= datetime.fromtimestamp(range_end))
        count = query.count()
        page = params_dict.get('page')
        page_size = params_dict.get('page_size')
        if page is not None and page_size is not None:
            query = query.offset((page - 1) * page_size).limit(page_size).all()
        return count, query

    def get(self, db: Session, item_id: IdSchemaType) -> DbModelType:
        """Return the row whose primary key is ``item_id`` (falsy result means not found)."""
        return db.query(self.db_model).get(item_id)


def create_crud_model(db_model: Base, name, auto_create_keys=None, index_key='id', array_keys=None):
    """
    Create the pydantic models needed by the CRUD routes.

    :param db_model: database model
    :param name: name used as prefix of the generated model names
    :param auto_create_keys: keys generated automatically (default ``['id']``);
        they are left out of the create and update models
    :param index_key: index key; always kept in the update model
    :param array_keys: list of keys whose data is stored as an array of ids
    :return: ``[model, model_id, model_create, model_update, model_query]``
    """
    if auto_create_keys is None:
        auto_create_keys = ['id']
    if array_keys is None:
        array_keys = []

    fields_dict = get_sqlalchemy_model_fields(db_model)
    model = create_model(f"{name}Model", **fields_dict)

    class config:
        arbitrary_types_allowed = False

    model_id = create_model(f'{name}IdModel', **{index_key: fields_dict[index_key]})
    model_create = create_model(
        f"{name}Create",
        **{key: item for key, item in fields_dict.items() if key not in auto_create_keys})

    # The index key must stay in the update model even when it is an
    # auto-created key (the default), because CRUDBase.update reads it to
    # locate the row. Every other field becomes Optional. The types are
    # referenced directly instead of eval()-ing their names, so field types
    # need not be importable from this module.
    update_fields = {}
    for key, item in fields_dict.items():
        if key == index_key:
            update_fields[key] = item
        elif key not in auto_create_keys:
            update_fields[key] = (Optional[item[0]], item[1])
    model_update = create_model(f"{name}Update", __config__=config, **update_fields)

    # NOTE(review): the generated query model has no page/page_size fields, so
    # CRUDBase.query never paginates for it -- confirm whether that is intended.
    query_fields_dic = {}
    for key, item in fields_dict.items():
        if key in array_keys:
            query_fields_dic[key] = (Optional[List[Union[int, str]]], None)
        elif getattr(item[0], '__name__', None) == 'datetime':
            # datetime columns are queried with a [start, end] timestamp pair
            query_fields_dic[key] = (Optional[List[Optional[int]]], None)
        else:
            query_fields_dic[key] = (Optional[item[0]], None)
    model_query = create_model(f"{name}Query", __config__=config, **query_fields_dic)
    return [model, model_id, model_create, model_update, model_query]


def auto_create_crud(db_model: Base, name, chinese_name="", auto_create_keys=None, index_key='id', array_keys=None,
                     query_dependencies=None, tags=None):
    """
    Automatically create a CRUD object for a database model.

    :param db_model: database model
    :param name: name used for the route prefix and generated model names
    :param chinese_name: display name used in route summaries
    :param auto_create_keys: automatically generated field keys (default ``['id']``)
    :param index_key: index key
    :param array_keys: list of keys whose value is a string made of an array of ids
    :param query_dependencies: dependencies attached to the query route
    :param tags: tags attached to every route
    :return: a ``CRUDBase`` subclass instance; call ``mount(router)`` on it to
        register the routes
    """
    [model, model_id, model_create, model_update, model_query] = create_crud_model(
        db_model, name, auto_create_keys, index_key, array_keys=array_keys)

    class Crud(CRUDBase[db_model, model, model_id, model_create, model_update, model_query]):
        pass

    crud = Crud(db_model, model, model_id, model_create, model_update, model_query, name, chinese_name,
                array_keys=array_keys, query_dependencies=query_dependencies, tags=tags)
    return crud