from typing import TypeVar, Generic, Any, List, Tuple, get_args, Union, Optional

from sqlalchemy.orm import Session
from Utils.CommonUtils import get_sqlalchemy_model_fields
from Utils.SqlAlchemyUtils import Base, get_db
from pydantic import BaseModel, create_model
from fastapi import APIRouter, Depends, HTTPException
from pydantic.generics import GenericModel


class QueryBase(BaseModel):
    """Base schema for query requests; carries the optional paging fields."""

    # Number of rows per page; paging is applied only when both fields are set.
    page_size: Union[int, None]
    # Page number; `CRUDBase.query` computes the offset as (page - 1) * page_size.
    page: Union[int, None]


class ItemId(BaseModel):
    """Request body used by the delete route to identify a row by ``id``."""

    id: Union[str, int]


ModelType = TypeVar("ModelType", bound=BaseModel)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
IdSchemaType = TypeVar("IdSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
QuerySchemaType = TypeVar("QuerySchemaType", bound=QueryBase)
DbModelType = TypeVar("DbModelType", bound=Base)
DataType = TypeVar("DataType", BaseModel, Any)


class BaseResponseModel(GenericModel, Generic[DataType]):
    """Response envelope returned by every generated route."""

    state: int
    msg: str
    data: DataType


class CRUDBase(Generic[DbModelType, ModelType, IdSchemaType, CreateSchemaType,
                       UpdateSchemaType, QuerySchemaType]):
    """Generic CRUD helper that registers add/update/delete/get/query routes.

    Every method whose name ends in ``_create`` is a route factory: it takes
    an ``APIRouter``, registers one POST endpoint on it and returns the
    endpoint function. ``mount`` calls all of them.
    """

    def __init__(self, db_model: DbModelType, model_type: BaseModel,
                 id_schema_type: IdSchemaType,
                 create_schema_type: CreateSchemaType,
                 update_schema_type: UpdateSchemaType,
                 query_schema_type: QuerySchemaType,
                 name: str, chinese_name: str = ''):
        """Store the ORM model, the request/response schemas and the names.

        :param db_model: SQLAlchemy model class the operations run against
        :param model_type: pydantic model used to serialise a single row
        :param id_schema_type: request schema for the ``get`` route
        :param create_schema_type: request schema for the ``add`` route
        :param update_schema_type: request schema for the ``update`` route
        :param query_schema_type: request schema for the ``query`` route
        :param name: URL prefix of the generated routes
        :param chinese_name: display name used in route summaries; falls
            back to ``name`` when empty
        """
        self.db_model = db_model
        self.name = name
        self.model_type = model_type
        self.id_schema_type = id_schema_type
        self.create_schema_type = create_schema_type
        self.update_schema_type = update_schema_type
        self.query_schema_type = query_schema_type
        self.summary_name = chinese_name if chinese_name else self.name

    def mount(self, router: APIRouter):
        """Register every ``*_create`` route factory of this object on *router*."""
        for key in self.__dir__():
            if key.endswith('_create'):
                getattr(self, key)(router)

    def add_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/add`` and return the endpoint function."""
        # Bound to a local so it can be used as a parameter annotation below.
        create_schema_type = self.create_schema_type

        @router.post(f'/{self.name}/add',
                     response_model=BaseResponseModel[self.model_type],
                     summary=f"添加{self.summary_name}")
        def add_route_func(create_model_obj: create_schema_type,
                           db: Session = Depends(get_db)):
            new_item = self.add(db, create_model_obj)
            item_info = new_item.to_dict()
            return BaseResponseModel(state=1, msg="",
                                     data=self.model_type(**item_info))

        return add_route_func

    def update_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/update`` and return the endpoint function."""
        update_schema_type = self.update_schema_type

        @router.post(f'/{self.name}/update',
                     response_model=BaseResponseModel[Any],
                     summary=f"修改{self.summary_name}")
        def update_route_func(update_model_obj: update_schema_type,
                              db: Session = Depends(get_db)):
            self.update(db, update_model_obj)
            return BaseResponseModel(state=1, msg="", data={})

        return update_route_func

    def delete_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/delete`` and return the endpoint function."""

        @router.post(f'/{self.name}/delete',
                     response_model=BaseResponseModel[Any],
                     summary=f"删除{self.summary_name}")
        def delete_route_func(item_id: ItemId, db: Session = Depends(get_db)):
            self.delete(db, item_id.id)
            return BaseResponseModel(state=1, msg="", data={})

        return delete_route_func

    def get_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/get`` and return the endpoint function.

        The endpoint answers 404 when no row matches the requested key.
        """
        id_schema_type = self.id_schema_type

        @router.post(f'/{self.name}/get',
                     response_model=BaseResponseModel[self.model_type],
                     summary=f"获取{self.summary_name}")
        def get_route_func(item_id: id_schema_type,
                           db: Session = Depends(get_db)):
            # The first field of the id schema is used as the lookup value.
            index_value = list(item_id.dict().values())[0]
            item = self.get(db, index_value)
            if item is None:
                # Without this guard `item.to_dict()` fails with a 500.
                raise HTTPException(status_code=404,
                                    detail=f"{self.name} not found")
            return BaseResponseModel(state=1, msg="",
                                     data=self.model_type(**item.to_dict()))

        return get_route_func

    def query_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/query`` and return the endpoint function.

        The response payload is a model named ``{name}QueryRes`` with the
        total ``count`` and the ``item_list`` of matching rows.
        """
        model_type = self.model_type
        query_schema_type = self.query_schema_type
        # Built with create_model instead of exec(): same two required
        # fields, no source-code string assembled from `self.name`.
        QueryRes = create_model(f"{self.name}QueryRes",
                                count=(int, ...),
                                item_list=(List[model_type], ...))

        @router.post(f'/{self.name}/query',
                     response_model=BaseResponseModel[QueryRes],
                     summary=f"查询{self.summary_name}")
        def query_route_func(params: query_schema_type,
                             db: Session = Depends(get_db)):
            count, query = self.query(db, params)
            item_list = [item.to_dict() for item in query]
            return BaseResponseModel(state=1, msg="",
                                     data=QueryRes(count=count,
                                                   item_list=item_list))

        return query_route_func

    def add(self, db: Session, create_model_obj: CreateSchemaType) -> DbModelType:
        """Insert a row built from *create_model_obj*, commit and return it."""
        new_item = self.db_model(**create_model_obj.dict())
        db.add(new_item)
        db.commit()
        db.refresh(new_item)
        return new_item

    def delete(self, db: Session, item_id):
        """Delete the row(s) whose ``id`` equals, or is contained in, *item_id*.

        :param db: active session
        :param item_id: a single id, or a list of ids for a bulk delete
        """
        if isinstance(item_id, list):
            # The session is committed right after, so there is nothing to
            # keep in sync; this also avoids in-Python evaluation of IN().
            db.query(self.db_model).filter(
                getattr(self.db_model, 'id').in_(item_id)
            ).delete(synchronize_session=False)
        else:
            db.query(self.db_model).filter_by(id=item_id).delete()
        # Commit like add() and update() do; otherwise the delete is lost
        # when the session is closed.
        db.commit()

    def update(self, db: Session, update_model_obj: UpdateSchemaType):
        """Write all fields of *update_model_obj* to the row with its ``id``.

        NOTE(review): every field of the schema is written, including ones
        that are ``None``; confirm whether partial updates are intended.
        """
        db.query(self.db_model).filter_by(
            id=update_model_obj.id
        ).update(update_model_obj.dict())
        db.commit()

    def query(self, db: Session, params: QuerySchemaType) -> Tuple[int, Any]:
        """Filter rows by the non-null fields of *params*.

        String values are matched with ``LIKE '%value%'``; int and float
        values are matched by equality; other value types are ignored.

        :return: ``(count, rows)`` where ``count`` is the number of matches
            before paging. ``rows`` is a list when both ``page`` and
            ``page_size`` are given, otherwise the un-executed query object.
        """
        query = db.query(self.db_model)
        for key, value in params.dict().items():
            if key in ('page', 'page_size') or value is None:
                continue
            if isinstance(value, str):
                query = query.filter(
                    getattr(self.db_model, key).like(f'%{value}%'))
            elif isinstance(value, (int, float)) and not isinstance(value, bool):
                # bool is an int subclass; it was never used as a filter.
                query = query.filter_by(**{key: value})
        count = query.count()
        if params.page is not None and params.page_size is not None:
            offset = (params.page - 1) * params.page_size
            query = query.offset(offset).limit(params.page_size).all()
        return count, query

    def get(self, db: Session, item_id: IdSchemaType) -> Optional[DbModelType]:
        """Return the row whose primary key is *item_id*, or ``None``."""
        return db.query(self.db_model).get(item_id)


def create_crud_model(db_model: Base, name, auto_create_keys=('id',), index_key='id'):
    """
    Create the pydantic models needed for CRUD.

    :param db_model: database model
    :param name: name used as the prefix of the generated model names
    :param auto_create_keys: keys generated automatically; they are left out
        of the create model and, except for ``index_key``, the update model
    :param index_key: key used to identify a row
    :return: ``[model, model_id, model_create, model_update, model_query]``
    """
    # Presumably maps field name -> (type, default), the tuple form that
    # create_model accepts - confirm against get_sqlalchemy_model_fields.
    fields_dict = get_sqlalchemy_model_fields(db_model)

    class config:
        arbitrary_types_allowed = False

    def as_optional(item):
        # Wrap the type object directly; evaluating its __name__ breaks for
        # any type that is not a global of this module.
        return Optional[item[0]], item[1]

    model = create_model(f"{name}Model", **fields_dict)
    model_id = create_model(f'{name}IdModel',
                            **{index_key: fields_dict[index_key]})
    model_create = create_model(
        f"{name}Create",
        **{key: item for key, item in fields_dict.items()
           if key not in auto_create_keys})

    # The index key must stay in the update model even when it is an
    # auto-created key, because CRUDBase.update reads it to find the row.
    update_fields = {}
    for key, item in fields_dict.items():
        if key == index_key:
            update_fields[key] = item
        elif key not in auto_create_keys:
            update_fields[key] = as_optional(item)
    model_update = create_model(f"{name}Update", __config__=config,
                                **update_fields)

    # Derive from QueryBase so the model has the `page` / `page_size` fields
    # that CRUDBase.query reads. pydantic does not allow `__base__` together
    # with `__config__`; the config above only restates the default.
    model_query = create_model(
        f"{name}Query", __base__=QueryBase,
        **{key: as_optional(item) for key, item in fields_dict.items()})
    return [model, model_id, model_create, model_update, model_query]


def auto_create_crud(db_model: Base, name, chinese_name="", auto_create_keys=('id',), index_key='id'):
    """
    Create the CRUD models for *db_model* and return a ready CRUD object.

    :param db_model: database model
    :param name: URL prefix and model-name prefix
    :param chinese_name: display name used in route summaries
    :param auto_create_keys: keys generated automatically
    :param index_key: key used to identify a row
    :return: a ``CRUDBase`` subclass instance; call ``mount(router)`` on it
        to register the routes
    """
    [model, model_id, model_create, model_update, model_query] = create_crud_model(
        db_model, name, auto_create_keys, index_key)

    class Crud(CRUDBase[db_model, model, model_id, model_create,
                        model_update, model_query]):
        pass

    crud = Crud(db_model, model, model_id, model_create, model_update,
                model_query, name, chinese_name)
    return crud