usermod/Utils/CrudUtils.py

156 lines
6.0 KiB
Python
Raw Normal View History

2023-02-16 14:30:28 +08:00
from typing import Any, Generic, List, Tuple, Type, TypeVar, Union, get_args

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from pydantic.generics import GenericModel
from sqlalchemy.orm import Session

from Utils.SqlAlchemyUtils import Base, get_db
class QueryBase(BaseModel):
    """Base schema for list/query request bodies.

    Subclasses add their own filter fields. Both pagination fields are
    optional; CRUDBase.query paginates only when both are provided, and
    treats ``page`` as 1-based.
    """

    page_size: Union[int, None] = None  # rows per page
    page: Union[int, None] = None  # 1-based page number
class ItemId(BaseModel):
    """Request body that carries a primary key, e.g. ``{"id": 3}``.

    ``id`` is deliberately untyped: CRUDBase.delete also accepts a list of
    ids in this field.
    """

    id: Any
# Type parameters shared by CRUDBase and BaseResponseModel.
ModelType = TypeVar("ModelType", bound=BaseModel)  # response (read) schema
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)  # body of /add
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)  # body of /update
QuerySchemaType = TypeVar("QuerySchemaType", bound=QueryBase)  # body of /query
DbModelType = TypeVar("DbModelType", bound=Base)  # SQLAlchemy model class
# Payload type of BaseResponseModel.data: a pydantic model, or anything via Any.
DataType = TypeVar("DataType", BaseModel, Any)
class BaseResponseModel(GenericModel, Generic[DataType]):
    """Uniform response envelope returned by the CRUD routes.

    Parameterize it with the payload type, e.g. ``BaseResponseModel[SomeModel]``.
    The routes in this module set ``state=1`` and an empty ``msg`` on success.
    """

    state: int  # status flag; this module's routes use 1 for success
    msg: str  # message text; empty on success in this module
    data: DataType  # payload, validated against the type parameter
def get_generic_type_arg(cls):
    """Return the first type argument of the first generic base of ``cls``.

    For ``class Foo(Bar[int, str])`` this returns ``int``.

    Raises:
        AttributeError: ``cls`` has no ``__orig_bases__`` (no generic base).
        IndexError: that first base carries no type arguments.
    """
    first_generic_base = cls.__orig_bases__[0]
    type_arguments = get_args(first_generic_base)
    return type_arguments[0]
class CRUDBase(Generic[DbModelType, ModelType, CreateSchemaType, UpdateSchemaType, QuerySchemaType]):
    """Generic CRUD helper exposing add/update/delete/get/query routes for one table.

    Each ``*_create`` method registers one POST endpoint under ``/{name}/...``
    on a FastAPI router, and ``mount`` registers all of them at once. The plain
    ``add``/``delete``/``update``/``query``/``get`` methods hold the SQLAlchemy
    logic and can also be called directly.
    """

    def __init__(self,
                 db_model: Type[DbModelType],
                 model_type: Type[BaseModel],
                 create_schema_type: Type[CreateSchemaType],
                 update_schema_type: Type[UpdateSchemaType],
                 query_schema_type: Type[QuerySchemaType], name: str, chinese_name: str = ''):
        """
        Args:
            db_model: SQLAlchemy model class the routes operate on; must have
                an ``id`` attribute (used by delete/update/get).
            model_type: Pydantic model used as the response schema for rows.
            create_schema_type: Body schema of ``/{name}/add``.
            update_schema_type: Body schema of ``/{name}/update``; must have
                an ``id`` field.
            query_schema_type: Body schema of ``/{name}/query``; must have
                ``page`` and ``page_size`` fields (see QueryBase).
            name: URL segment used in the route paths.
            chinese_name: Display name used in the route summaries; falls
                back to ``name`` when empty.
        """
        self.db_model = db_model
        self.name = name
        self.model_type = model_type
        self.create_schema_type = create_schema_type
        self.update_schema_type = update_schema_type
        self.query_schema_type = query_schema_type
        self.summary_name = chinese_name if chinese_name else self.name

    def mount(self, router: APIRouter):
        """Register every route factory (attribute ending in ``_create``) on ``router``."""
        for key in self.__dir__():
            if key.endswith('_create'):
                getattr(self, key)(router)

    def add_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/add``: insert a row and return it."""
        create_schema_type = self.create_schema_type

        @router.post(f'/{self.name}/add',
                     response_model=BaseResponseModel[self.model_type],
                     summary=f"添加{self.summary_name}")
        def add_route_func(create_model_obj: create_schema_type, db: Session = Depends(get_db)):
            new_item = self.add(db, create_model_obj)
            item_info = new_item.to_dict()
            return BaseResponseModel(state=1, msg="", data=self.model_type(**item_info))

        return add_route_func

    def update_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/update``: overwrite the row with the body's id."""
        update_schema_type = self.update_schema_type

        @router.post(f'/{self.name}/update',
                     response_model=BaseResponseModel[Any],
                     summary=f"修改{self.summary_name}")
        def update_route_func(update_model_obj: update_schema_type, db: Session = Depends(get_db)):
            self.update(db, update_model_obj)
            return BaseResponseModel(state=1, msg="", data={})

        return update_route_func

    def delete_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/delete``: delete one id or a list of ids."""
        @router.post(f'/{self.name}/delete',
                     response_model=BaseResponseModel[Any],
                     summary=f"删除{self.summary_name}")
        def delete_route_func(item_id: ItemId, db: Session = Depends(get_db)):
            self.delete(db, item_id.id)
            return BaseResponseModel(state=1, msg="", data={})

        return delete_route_func

    def get_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/get``: fetch one row by id (HTTP 404 if absent)."""
        @router.post(f'/{self.name}/get',
                     response_model=BaseResponseModel[self.model_type],
                     summary=f"获取{self.summary_name}")
        def get_route_func(item_id: ItemId, db: Session = Depends(get_db)):
            item = self.get(db, item_id.id)
            if item is None:
                # Query.get() returns None for an unknown key; calling
                # item.to_dict() on it used to surface as an HTTP 500.
                raise HTTPException(status_code=404, detail=f"{self.name} {item_id.id} not found")
            return BaseResponseModel(state=1, msg="", data=item.to_dict())

        return get_route_func

    def query_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/query``: filtered, optionally paginated listing."""
        model_type = self.model_type
        query_schema_type = self.query_schema_type

        class QueryRes(BaseModel):
            count: int
            item_list: List[model_type]

        # Each CRUDBase instance defines its own QueryRes. Give each one a
        # distinct name so the OpenAPI model names (including the
        # BaseResponseModel[...] concrete model built from it) do not
        # collide when several resources are mounted.
        QueryRes.__name__ = QueryRes.__qualname__ = f"{model_type.__name__}QueryRes"

        @router.post(f'/{self.name}/query',
                     response_model=BaseResponseModel[QueryRes],
                     summary=f"查询{self.summary_name}")
        def query_route_func(params: query_schema_type, db: Session = Depends(get_db)):
            count, rows = self.query(db, params)
            item_list = [row.to_dict() for row in rows]
            return BaseResponseModel(state=1, msg="", data=QueryRes(count=count, item_list=item_list))

        return query_route_func

    def add(self, db: Session, create_model_obj: CreateSchemaType) -> DbModelType:
        """Insert a row built from ``create_model_obj``, commit, and return the refreshed ORM instance."""
        new_item = self.db_model(**create_model_obj.dict())
        db.add(new_item)
        db.commit()
        db.refresh(new_item)
        return new_item

    def delete(self, db: Session, item_id) -> int:
        """Delete one row (scalar id) or several (list/tuple/set of ids) and commit.

        Returns:
            The number of rows deleted, as reported by ``Query.delete()``.
        """
        if isinstance(item_id, (list, tuple, set)):
            # The default 'evaluate' session sync cannot handle in_() on older
            # SQLAlchemy releases; 'fetch' works for IN criteria everywhere.
            deleted = (db.query(self.db_model)
                       .filter(self.db_model.id.in_(list(item_id)))
                       .delete(synchronize_session='fetch'))
        else:
            deleted = db.query(self.db_model).filter_by(id=item_id).delete()
        # add() and update() commit their own changes; delete() previously did
        # not, so nothing in this class committed the DELETE.
        db.commit()
        return deleted

    def update(self, db: Session, update_model_obj: UpdateSchemaType) -> int:
        """Overwrite the row whose id is ``update_model_obj.id`` with every schema field, then commit.

        All fields from ``.dict()`` are written, including ones left at their
        default values.

        Returns:
            The number of rows matched, as reported by ``Query.update()``.
        """
        updated = db.query(self.db_model).filter_by(id=update_model_obj.id).update(update_model_obj.dict())
        db.commit()
        return updated

    def query(self, db: Session, params: QuerySchemaType) -> Tuple[int, List[DbModelType]]:
        """Filter ``db_model`` by the non-pagination fields of ``params``.

        For each field other than ``page``/``page_size`` whose value is not None:
        an exact ``str`` becomes a ``LIKE '%value%'`` match, and an exact
        ``int``/``float`` becomes an equality filter. Values of any other type,
        including ``bool`` (the exact-type check does not match it), are ignored.

        Returns:
            ``(count, rows)``, where ``count`` is the number of matches before
            pagination. ``rows`` is a list of ORM instances, restricted to the
            1-based ``page`` of size ``page_size`` when both are set.
        """
        query = db.query(self.db_model)
        for key, value in params.dict().items():
            if key in ('page', 'page_size') or value is None:
                continue
            if type(value) is str:
                query = query.filter(getattr(self.db_model, key).like(f'%{value}%'))
            elif type(value) in (int, float):
                query = query.filter_by(**{key: value})
        count = query.count()
        if params.page is not None and params.page_size is not None:
            # The limit must be the page size; it previously used params.page.
            query = query.offset((params.page - 1) * params.page_size).limit(params.page_size)
        return count, query.all()

    def get(self, db: Session, item_id) -> Union[DbModelType, None]:
        """Return the row with primary key ``item_id``, or None if it does not exist."""
        return db.query(self.db_model).get(item_id)