2023-02-17 10:55:02 +08:00
|
|
|
from typing import TypeVar, Generic, Any, List, get_args, Union, Optional
|
2023-02-16 14:30:28 +08:00
|
|
|
from sqlalchemy.orm import Session
|
2023-02-17 10:55:02 +08:00
|
|
|
|
|
|
|
from Utils.CommonUtils import get_sqlalchemy_model_fields
|
2023-02-16 14:30:28 +08:00
|
|
|
from Utils.SqlAlchemyUtils import Base, get_db
|
2023-02-17 10:55:02 +08:00
|
|
|
from pydantic import BaseModel, create_model
|
2023-02-16 14:30:28 +08:00
|
|
|
from fastapi import APIRouter, Depends
|
|
|
|
from pydantic.generics import GenericModel
|
|
|
|
|
|
|
|
|
|
|
|
class QueryBase(BaseModel):
    """Pagination parameters shared by every generated query schema.

    Both fields are optional; pagination is only applied by the CRUD query
    logic when both are supplied.
    """

    # Number of rows per page.
    page_size: Optional[int] = None
    # Page number, counted from 1 (the query logic offsets by ``page - 1``).
    page: Optional[int] = None
|
|
|
|
|
|
|
|
|
|
|
|
class ItemId(BaseModel):
    """Request body that carries a single row identifier (used by the delete route)."""

    # NOTE(review): pydantic v1 tries Union members left to right, so a numeric
    # id may arrive here as a str — confirm the database comparison tolerates it.
    id: Union[str, int]
|
2023-02-16 14:30:28 +08:00
|
|
|
|
|
|
|
|
|
|
|
# ---- Type parameters for CRUDBase / BaseResponseModel ----

# Pydantic schema used to serialise a single row in responses.
ModelType = TypeVar("ModelType", bound=BaseModel)
# Request-body schemas for the create / get / update routes.
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
IdSchemaType = TypeVar("IdSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
# Query schemas must carry the pagination fields declared on QueryBase.
QuerySchemaType = TypeVar("QuerySchemaType", bound=QueryBase)

# SQLAlchemy declarative model class.
DbModelType = TypeVar("DbModelType", bound=Base)

# Payload type of BaseResponseModel.data (a constrained TypeVar, not a bound one).
DataType = TypeVar("DataType", BaseModel, Any)
|
|
|
|
|
|
|
|
|
|
|
|
class BaseResponseModel(GenericModel, Generic[DataType]):
    """Envelope returned by every CRUD route: a status flag, a message and a payload."""

    # Status flag; the routes in this module set it to 1 on success.
    state: int
    # Message text; the routes in this module leave it empty on success.
    msg: str
    # Route-specific payload, parameterised per route via BaseResponseModel[...].
    data: DataType
|
|
|
|
|
|
|
|
|
2023-02-17 10:55:02 +08:00
|
|
|
class CRUDBase(Generic[DbModelType, ModelType, IdSchemaType, CreateSchemaType, UpdateSchemaType, QuerySchemaType]):
    """
    Generic CRUD helper bound to one SQLAlchemy model and its pydantic schemas.

    ``mount`` registers the POST routes ``/{name}/add``, ``/{name}/update``,
    ``/{name}/delete``, ``/{name}/get`` and ``/{name}/query`` on a FastAPI
    router. The plain ``add`` / ``update`` / ``delete`` / ``get`` / ``query``
    methods hold the database logic and can also be called directly.
    """

    def __init__(self,
                 db_model: DbModelType,
                 model_type: BaseModel,
                 id_schema_type: IdSchemaType,
                 create_schema_type: CreateSchemaType,
                 update_schema_type: UpdateSchemaType,
                 query_schema_type: QuerySchemaType, name: str, chinese_name: str = ''):
        """
        :param db_model: SQLAlchemy model class the routes operate on
        :param model_type: pydantic model used to serialise one row in responses
        :param id_schema_type: request-body model of the ``get`` route
        :param create_schema_type: request-body model of the ``add`` route
        :param update_schema_type: request-body model of the ``update`` route
        :param query_schema_type: request-body model of the ``query`` route
        :param name: URL path segment, also used as a prefix for generated schema names
        :param chinese_name: display name for route summaries; falls back to ``name``
        """
        self.db_model = db_model
        self.name = name
        self.model_type = model_type
        self.id_schema_type = id_schema_type
        self.create_schema_type = create_schema_type
        self.update_schema_type = update_schema_type
        self.query_schema_type = query_schema_type
        self.summary_name = chinese_name if chinese_name else self.name

    def mount(self, router: APIRouter):
        """Register every route produced by a ``*_create`` builder method on *router*."""
        # self.__dir__() is used instead of dir(self), which would sort the names
        # and therefore change the order routes are registered in.
        for key in self.__dir__():
            if key.endswith('_create'):
                getattr(self, key)(router)

    def add_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/add``: insert a row and return it."""
        create_schema_type = self.create_schema_type

        @router.post(f'/{self.name}/add',
                     response_model=BaseResponseModel[self.model_type],
                     summary=f"添加{self.summary_name}")
        def add_route_func(create_model_obj: create_schema_type, db: Session = Depends(get_db)):
            new_item = self.add(db, create_model_obj)
            item_info = new_item.to_dict()
            return BaseResponseModel(state=1, msg="", data=self.model_type(**item_info))

        return add_route_func

    def update_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/update``: update the row identified by the body's ``id``."""
        update_schema_type = self.update_schema_type

        @router.post(f'/{self.name}/update',
                     response_model=BaseResponseModel[Any],
                     summary=f"修改{self.summary_name}")
        def update_route_func(update_model_obj: update_schema_type, db: Session = Depends(get_db)):
            self.update(db, update_model_obj)
            return BaseResponseModel(state=1, msg="", data={})

        return update_route_func

    def delete_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/delete``: delete the row whose ``id`` is in the body."""

        @router.post(f'/{self.name}/delete',
                     response_model=BaseResponseModel[Any],
                     summary=f"删除{self.summary_name}")
        def delete_route_func(item_id: ItemId, db: Session = Depends(get_db)):
            self.delete(db, item_id.id)
            return BaseResponseModel(state=1, msg="", data={})

        return delete_route_func

    def get_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/get``: fetch one row by the id schema's value.

        Responds with HTTP 404 when no row matches, rather than failing with a
        500 while serialising ``None``.
        """
        from fastapi import HTTPException  # fastapi is already a dependency of this file

        id_schema_type = self.id_schema_type

        @router.post(f'/{self.name}/get',
                     response_model=BaseResponseModel[self.model_type],
                     summary=f"获取{self.summary_name}")
        def get_route_func(item_id: id_schema_type, db: Session = Depends(get_db)):
            # The id schema is expected to hold a single field: the index key.
            index_value = list(item_id.dict().values())[0]
            item = self.get(db, index_value)
            if item is None:
                raise HTTPException(status_code=404, detail=f"{self.name} not found")
            return BaseResponseModel(state=1, msg="", data=self.model_type(**item.to_dict()))

        return get_route_func

    def query_route_func_create(self, router: APIRouter):
        """Register ``POST /{name}/query``: filtered, optionally paginated listing."""
        model_type = self.model_type
        query_schema_type = self.query_schema_type
        # Response payload schema built with create_model rather than exec'ing
        # generated source; the model keeps the same class name as before.
        QueryRes = create_model(f"{self.name}QueryRes",
                                count=(int, ...),
                                item_list=(List[model_type], ...))

        @router.post(f'/{self.name}/query',
                     response_model=BaseResponseModel[QueryRes],
                     summary=f"查询{self.summary_name}")
        def query_route_func(params: query_schema_type, db: Session = Depends(get_db)):
            count, query = self.query(db, params)
            item_list = [item.to_dict() for item in query]
            return BaseResponseModel(state=1, msg="", data=QueryRes(count=count, item_list=item_list))

        return query_route_func

    def add(self, db: Session, create_model_obj: CreateSchemaType) -> DbModelType:
        """Insert a row built from *create_model_obj*, commit, and return the refreshed ORM instance."""
        new_item = self.db_model(**create_model_obj.dict())
        db.add(new_item)
        db.commit()
        # Reload server-generated values (e.g. autoincrement ids) onto the instance.
        db.refresh(new_item)
        return new_item

    def delete(self, db: Session, item_id):
        """
        Delete rows by ``id`` and commit.

        :param item_id: a single id, or a list/tuple of ids to delete in one statement
        """
        if isinstance(item_id, (list, tuple)):
            # 'fetch' works for IN criteria on every SQLAlchemy version, whereas the
            # default 'evaluate' strategy cannot evaluate IN on older releases.
            db.query(self.db_model).filter(self.db_model.id.in_(item_id)).delete(synchronize_session='fetch')
        else:
            db.query(self.db_model).filter_by(id=item_id).delete()
        # Without a commit the delete would be discarded when the session closes.
        db.commit()

    def update(self, db: Session, update_model_obj: UpdateSchemaType):
        """
        Update the row whose ``id`` matches *update_model_obj* and commit.

        Only fields the client actually sent are written (``exclude_unset``), so a
        partial update does not overwrite other columns with None or defaults.
        """
        values = update_model_obj.dict(exclude_unset=True)
        db.query(self.db_model).filter_by(id=update_model_obj.id).update(values)
        db.commit()

    def query(self, db: Session, params: QuerySchemaType):
        """
        Filter rows by the non-None fields of *params* and optionally paginate.

        String values match as a substring (LIKE wildcards in the value are
        escaped); every other value matches by equality. ``page``/``page_size``
        are not filters; pagination is applied only when both are set.

        :return: ``(total_count_before_pagination, list_of_orm_instances)``
        """
        query = db.query(self.db_model)
        for key, value in params.dict().items():
            if key in ('page', 'page_size') or value is None:
                continue
            if isinstance(value, str):
                query = query.filter(getattr(self.db_model, key).contains(value, autoescape=True))
            else:
                # Covers int/float as before, plus bool, dates, decimals, etc.,
                # which were previously ignored without any error.
                query = query.filter_by(**{key: value})
        count = query.count()
        if params.page is not None and params.page_size is not None:
            # Clamp so page <= 0 cannot produce a negative OFFSET.
            offset = max(params.page - 1, 0) * params.page_size
            query = query.offset(offset).limit(params.page_size)
        return count, query.all()

    def get(self, db: Session, item_id: Any) -> Optional[DbModelType]:
        """Return the row whose primary key equals *item_id*, or None if it does not exist."""
        return db.query(self.db_model).get(item_id)
|
2023-02-17 10:55:02 +08:00
|
|
|
|
|
|
|
|
|
|
|
def create_crud_model(db_model: Base, name, auto_create_keys=('id',), index_key='id'):
    """
    Build the pydantic models needed by ``CRUDBase`` for *db_model*.

    :param db_model: SQLAlchemy model class
    :param name: prefix for the generated model names
    :param auto_create_keys: fields generated by the database (e.g. autoincrement
        ids). They are left out of the create model, and out of the update model
        except for *index_key*.
    :param index_key: field that identifies a row. It is the only field of the id
        model and stays required in the update model.
    :return: ``[model, model_id, model_create, model_update, model_query]``
    """
    # Assumed shape: {field_name: (python_type, default)}, as consumed by
    # pydantic's create_model — TODO confirm against Utils.CommonUtils.
    fields_dict = get_sqlalchemy_model_fields(db_model)
    model = create_model(f"{name}Model", **fields_dict)

    class config:
        arbitrary_types_allowed = False

    model_id = create_model(f'{name}IdModel', **{index_key: fields_dict[index_key]})
    model_create = create_model(f"{name}Create",
                                **{key: item for key, item in fields_dict.items() if key not in auto_create_keys})
    # The index key is kept even when it is auto-generated, because
    # CRUDBase.update reads ``update_model_obj.id`` to find the row.
    # Optional[...] is built directly instead of eval'ing the type's name,
    # which raised NameError for types not imported here (e.g. datetime).
    model_update = create_model(
        f"{name}Update", __config__=config,
        **{key: item if key == index_key else (Optional[item[0]], item[1])
           for key, item in fields_dict.items()
           if key == index_key or key not in auto_create_keys})
    # Query models extend QueryBase so ``page``/``page_size`` exist (CRUDBase.query
    # reads them). Every filter defaults to None, which CRUDBase.query treats as
    # "no filter". __config__ is omitted: pydantic rejects it together with
    # __base__, and it only restated the default setting.
    model_query = create_model(
        f"{name}Query", __base__=QueryBase,
        **{key: (Optional[item[0]], None) for key, item in fields_dict.items()})
    return [model, model_id, model_create, model_update, model_query]
|
|
|
|
|
|
|
|
|
|
|
|
def auto_create_crud(db_model: Base, name, chinese_name="", auto_create_keys=('id',), index_key='id'):
    """
    Build the CRUD schemas for *db_model* and return a ready-to-mount CRUD object.

    :param db_model: SQLAlchemy model class
    :param name: URL path segment and prefix for generated schema names
    :param chinese_name: display name for route summaries (defaults to *name*)
    :param auto_create_keys: database-generated fields excluded from the create
        schema. An immutable tuple default avoids the shared-mutable-default pitfall.
    :param index_key: field that identifies a row
    :return: a ``CRUDBase`` subclass instance; call ``.mount(router)`` to register routes
    """
    model, model_id, model_create, model_update, model_query = create_crud_model(
        db_model, name, auto_create_keys, index_key)

    class Crud(CRUDBase[db_model, model, model_id, model_create, model_update, model_query]):
        pass

    crud = Crud(db_model, model, model_id, model_create, model_update, model_query, name, chinese_name)
    return crud
|