"""Generic CRUD scaffolding that exposes a SQLAlchemy model as FastAPI routes.

A ``CRUDBase`` instance pairs a SQLAlchemy model with its Pydantic
read/create/update/query schemas and, via ``mount``, registers the POST
endpoints ``/{name}/add``, ``/{name}/update``, ``/{name}/delete``,
``/{name}/get`` and ``/{name}/query`` on an ``APIRouter``. Successful
responses are wrapped in ``BaseResponseModel``.
"""
from typing import Any, Generic, Iterable, List, Tuple, TypeVar, Union, get_args

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, create_model
from pydantic.generics import GenericModel
from sqlalchemy.orm import Session

from Utils.SqlAlchemyUtils import Base, get_db


class QueryBase(BaseModel):
    """Base schema for ``/query`` request bodies.

    Subclasses add filter fields. Pagination is applied only when both
    ``page`` (1-based) and ``page_size`` are provided.
    """

    page_size: Union[int, None] = None
    page: Union[int, None] = None


class ItemId(BaseModel):
    """Request body carrying a primary key; ``delete`` also accepts a list of keys."""

    id: Any


ModelType = TypeVar("ModelType", bound=BaseModel)  # read / response schema
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
QuerySchemaType = TypeVar("QuerySchemaType", bound=QueryBase)
DbModelType = TypeVar("DbModelType", bound=Base)  # SQLAlchemy ORM model
DataType = TypeVar("DataType", BaseModel, Any)


class BaseResponseModel(GenericModel, Generic[DataType]):
    """Uniform response envelope used by every generated route.

    Successful responses from the routes in this module carry ``state=1``
    and an empty ``msg``.
    """

    state: int
    msg: str
    data: DataType


def get_generic_type_arg(cls):
    """Return the first type argument of ``cls``'s first parameterised base.

    For ``class Foo(Bar[int, str])`` this returns ``int``.
    """
    t = cls.__orig_bases__[0]
    return get_args(t)[0]


class CRUDBase(Generic[DbModelType, ModelType, CreateSchemaType, UpdateSchemaType, QuerySchemaType]):
    """CRUD operations plus FastAPI route factories for one SQLAlchemy model.

    The ORM model is expected to expose an ``id`` attribute (used by update,
    get and delete) and a ``to_dict()`` method (used to build responses) --
    presumably provided by ``Base``; verify against Utils.SqlAlchemyUtils.
    """

    def __init__(self, db_model: DbModelType, model_type: BaseModel, create_schema_type: CreateSchemaType,
                 update_schema_type: UpdateSchemaType, query_schema_type: QuerySchemaType, name: str,
                 chinese_name: str = ''):
        """Store the model and schema types for this resource.

        Args:
            db_model: SQLAlchemy model class the operations act on.
            model_type: Pydantic schema used as the response ``data`` type.
            create_schema_type: request body schema for ``/add``.
            update_schema_type: request body schema for ``/update``; must have an ``id`` field.
            query_schema_type: request body schema for ``/query`` (a ``QueryBase`` subclass).
            name: URL segment for the routes (``/{name}/add`` etc.).
            chinese_name: human-readable name used in OpenAPI summaries; defaults to ``name``.
        """
        self.db_model = db_model
        self.name = name
        self.model_type = model_type
        self.create_schema_type = create_schema_type
        self.update_schema_type = update_schema_type
        self.query_schema_type = query_schema_type
        self.summary_name = chinese_name if chinese_name else self.name

    def mount(self, router: APIRouter):
        """Register every route built by a ``*_create`` method on ``router``."""
        # In CPython ``__dir__()`` is unsorted (attribute-definition order), unlike
        # ``dir()``; kept so route registration / docs order stays unchanged.
        for key in self.__dir__():
            if key.endswith('_create'):
                getattr(self, key)(router)

    def add_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/add``: insert a row and return it as ``model_type``."""
        # Bound to a local so it can serve as the endpoint's body annotation.
        create_schema_type = self.create_schema_type

        @router.post(f'/{self.name}/add', response_model=BaseResponseModel[self.model_type],
                     summary=f"添加{self.summary_name}")
        def add_route_func(create_model_obj: create_schema_type, db: Session = Depends(get_db)):
            new_item = self.add(db, create_model_obj)
            item_info = new_item.to_dict()
            return BaseResponseModel(state=1, msg="", data=self.model_type(**item_info))

        return add_route_func

    def update_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/update``: overwrite the row identified by the body's ``id``."""
        update_schema_type = self.update_schema_type

        @router.post(f'/{self.name}/update', response_model=BaseResponseModel[Any],
                     summary=f"修改{self.summary_name}")
        def update_route_func(update_model_obj: update_schema_type, db: Session = Depends(get_db)):
            self.update(db, update_model_obj)
            return BaseResponseModel(state=1, msg="", data={})

        return update_route_func

    def delete_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/delete``: delete one id or a list of ids."""

        @router.post(f'/{self.name}/delete', response_model=BaseResponseModel[Any],
                     summary=f"删除{self.summary_name}")
        def delete_route_func(item_id: ItemId, db: Session = Depends(get_db)):
            self.delete(db, item_id.id)
            return BaseResponseModel(state=1, msg="", data={})

        return delete_route_func

    def get_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/get``: fetch one row by id (HTTP 404 if it does not exist)."""

        @router.post(f'/{self.name}/get', response_model=BaseResponseModel[self.model_type],
                     summary=f"获取{self.summary_name}")
        def get_route_func(item_id: ItemId, db: Session = Depends(get_db)):
            item = self.get(db, item_id.id)
            if item is None:
                # Report a missing row as 404 instead of failing on ``None.to_dict()`` (HTTP 500).
                raise HTTPException(status_code=404, detail=f'{self.name} {item_id.id} not found')
            return BaseResponseModel(state=1, msg="", data=item.to_dict())

        return get_route_func

    def query_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/query``: filtered, optionally paginated listing."""
        model_type = self.model_type
        query_schema_type = self.query_schema_type
        # One result model per resource. A class literally named ``QueryRes`` shared
        # by every CRUDBase instance gives different models the same OpenAPI schema
        # name, which collides once two resources are mounted on the same app.
        QueryRes = create_model(
            f'{self.name}QueryRes',
            count=(int, ...),
            item_list=(List[model_type], ...),
        )

        @router.post(f'/{self.name}/query', response_model=BaseResponseModel[QueryRes],
                     summary=f"查询{self.summary_name}")
        def query_route_func(params: query_schema_type, db: Session = Depends(get_db)):
            count, query = self.query(db, params)
            item_list = [item.to_dict() for item in query]
            return BaseResponseModel(state=1, msg="", data=QueryRes(count=count, item_list=item_list))

        return query_route_func

    def add(self, db: Session, create_model_obj: CreateSchemaType) -> DbModelType:
        """Insert a row built from ``create_model_obj``, commit, and return the refreshed ORM object."""
        new_item = self.db_model(**create_model_obj.dict())
        db.add(new_item)
        db.commit()
        db.refresh(new_item)  # reload column values from the database after the commit
        return new_item

    def delete(self, db: Session, item_id):
        """Delete the row whose id is ``item_id`` (or every row whose id is in a list) and commit.

        Ids that do not exist are silently ignored.
        """
        if isinstance(item_id, (list, tuple)):
            # The default 'evaluate' synchronization may not support IN on older
            # SQLAlchemy releases; skipping it is safe because commit() below
            # expires the session.
            db.query(self.db_model).filter(self.db_model.id.in_(item_id)).delete(synchronize_session=False)
        else:
            db.query(self.db_model).filter_by(id=item_id).delete()
        # add/update commit explicitly; without this the delete would only persist
        # if get_db commits on its own (not visible from here).
        db.commit()

    def update(self, db: Session, update_model_obj: UpdateSchemaType):
        """Overwrite the row whose id is ``update_model_obj.id`` and commit.

        Every field of ``update_model_obj.dict()`` is written, including ``None``
        values, so this replaces all of the schema's columns rather than patching.
        """
        db.query(self.db_model).filter_by(id=update_model_obj.id).update(update_model_obj.dict())
        db.commit()

    def query(self, db: Session, params: QuerySchemaType) -> Tuple[int, Iterable[DbModelType]]:
        """Filter ``db_model`` by the non-pagination fields of ``params``.

        For each field other than ``page``/``page_size`` whose value is not
        ``None`` (the field name must be an attribute of ``db_model``):
          * ``str`` values -> substring match, ``column LIKE '%value%'``;
          * ``int`` / ``float`` values -> equality;
          * values of any other type are ignored.

        Returns:
            ``(count, rows)``. ``count`` is the total number of matches before
            pagination. When both ``page`` (1-based) and ``page_size`` are set,
            ``rows`` is a list holding that page; otherwise it is the
            not-yet-executed ``Query`` over all matches.
        """
        query = db.query(self.db_model)
        for key, value in params.dict().items():
            if key in ('page', 'page_size') or value is None:
                continue
            # Exact type checks (not isinstance): bool values and str/int
            # subclasses such as enum members are not used as filters.
            if type(value) is str:
                query = query.filter(getattr(self.db_model, key).like(f'%{value}%'))
            elif type(value) in (int, float):
                query = query.filter_by(**{key: value})
        count = query.count()
        if params.page is not None and params.page_size is not None:
            # Clamp so page <= 0 yields the first page rather than a negative OFFSET.
            offset = max(params.page - 1, 0) * params.page_size
            # LIMIT by page_size so each page holds at most page_size rows.
            query = query.offset(offset).limit(params.page_size).all()
        return count, query

    def get(self, db: Session, item_id) -> Union[DbModelType, None]:
        """Return the ORM object whose primary key is ``item_id``, or ``None`` if absent."""
        return db.query(self.db_model).get(item_id)