diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..359bb53 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml diff --git a/CrudModel/ItemCrudModel.py b/CrudModel/ItemCrudModel.py new file mode 100644 index 0000000..2899b35 --- /dev/null +++ b/CrudModel/ItemCrudModel.py @@ -0,0 +1,20 @@ +from sqlalchemy import Column, String, Float, Integer, DateTime, func +from Utils.CrudUtils import auto_create_crud +from Utils.SqlAlchemyUtils import Base + + +class ItemModel(Base): + __tablename__ = 'items' + id = Column(Integer, primary_key=True, index=True) + thickness = Column(Float) + mass = Column(Float) + color = Column(String(32)) + type = Column(String(32)) + create_time = Column(DateTime, server_default=func.now(), comment='创建时间') + update_time = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment='修改时间') + + def to_dict(self): + return {c.name: getattr(self, c.name) for c in self.__table__.columns} + + +item_crud = auto_create_crud(ItemModel, 'item', '条目') diff --git a/Router/ProjectRouter.py b/Router/ProjectRouter.py new file mode 100644 index 0000000..4c6d480 --- /dev/null +++ b/Router/ProjectRouter.py @@ -0,0 +1,6 @@ +from fastapi import APIRouter + +router = APIRouter( + tags=["项目管理"], + prefix="/api/project", +) diff --git a/Utils/CommonUtils.py b/Utils/CommonUtils.py new file mode 100644 index 0000000..c1b8960 --- /dev/null +++ b/Utils/CommonUtils.py @@ -0,0 +1,69 @@ +from typing import Container, Optional, Type + +from pydantic import BaseConfig, BaseModel, create_model +from sqlalchemy.inspection import inspect +from sqlalchemy.orm.properties import ColumnProperty +from hashlib import md5 + + +# 文件md +def file_md5(body): + md = md5() + md.update(body) + return md.hexdigest() + + + +class OrmConfig(BaseConfig): + orm_mode = True + + +def sqlalchemy_to_pydantic( + db_model: Type, *, config: Type = OrmConfig, exclude: Container[str] = [] +) -> Type[BaseModel]: + mapper = inspect(db_model) + fields = {} + for attr in mapper.attrs: + if isinstance(attr, ColumnProperty): + if attr.columns: + name = attr.key + if name in exclude: + continue + column = attr.columns[0] + python_type: Optional[type] = None + if hasattr(column.type, "impl"): + if hasattr(column.type.impl, "python_type"): + python_type = column.type.impl.python_type + elif hasattr(column.type, "python_type"): + python_type = column.type.python_type + assert python_type, f"Could not infer python_type for {column}" + default = None + if column.default is None and not column.nullable: + default = ... + fields[name] = (python_type, default) + pydantic_model = create_model( + db_model.__name__, __config__=config, **fields # type: ignore + ) + return pydantic_model + + +def get_sqlalchemy_model_fields(db_model): + mapper = inspect(db_model) + fields = {} + for attr in mapper.attrs: + if isinstance(attr, ColumnProperty): + if attr.columns: + name = attr.key + column = attr.columns[0] + python_type: Optional[type] = None + if hasattr(column.type, "impl"): + if hasattr(column.type.impl, "python_type"): + python_type = column.type.impl.python_type + elif hasattr(column.type, "python_type"): + python_type = column.type.python_type + assert python_type, f"Could not infer python_type for {column}" + default = None + if column.default is None and not column.nullable: + default = ... + fields[name] = (python_type, default) + return fields diff --git a/Utils/CrudUtils.py b/Utils/CrudUtils.py new file mode 100644 index 0000000..4dae233 --- /dev/null +++ b/Utils/CrudUtils.py @@ -0,0 +1,219 @@ +from typing import TypeVar, Generic, Any, List, get_args, Union, Optional, Tuple +from sqlalchemy.orm import Session +from datetime import datetime +from Utils.CommonUtils import get_sqlalchemy_model_fields +from Utils.SqlAlchemyUtils import Base, get_db +from pydantic import BaseModel, create_model +from fastapi import APIRouter, Depends +from pydantic.generics import GenericModel + + +class QueryBase(BaseModel): + page_size: Union[int, None] + page: Union[int, None] + + +class ItemId(BaseModel): + id: Union[str, int] + + +ModelType = TypeVar("ModelType", bound=BaseModel) +CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) +IdSchemaType = TypeVar("IdSchemaType", bound=BaseModel) +UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) +QuerySchemaType = TypeVar("QuerySchemaType", bound=QueryBase) +DbModelType = TypeVar("DbModelType", bound=Base) +DataType = TypeVar("DataType", BaseModel, Any) + + +class BaseResponseModel(GenericModel, Generic[DataType]): + state: int + msg: str + data: DataType + + +class CRUDBase(Generic[DbModelType, ModelType, IdSchemaType, CreateSchemaType, UpdateSchemaType, QuerySchemaType]): + def __init__(self, + db_model: DbModelType, + model_type: BaseModel, + id_schema_type: IdSchemaType, + create_schema_type: CreateSchemaType, + update_schema_type: UpdateSchemaType, + query_schema_type: QuerySchemaType, name: str, chinese_name: str = ''): + self.db_model = db_model + self.name = name + self.model_type = model_type + self.id_schema_type = id_schema_type + self.create_schema_type = create_schema_type + self.update_schema_type = update_schema_type + self.query_schema_type = query_schema_type + self.summary_name = chinese_name if chinese_name else self.name + + def mount(self, router: APIRouter): + for key in self.__dir__(): + if key.endswith('_create'): + getattr(self, key)(router) + + def add_route_func_create(self, router: APIRouter): + create_schema_type = self.create_schema_type + + @router.post(f'/{self.name}/add', + response_model=BaseResponseModel[self.model_type], + summary=f"添加{self.summary_name}") + def add_route_func(create_model_obj: create_schema_type, db: Session = Depends(get_db)): + new_item = self.add(db, create_model_obj) + item_info = new_item.to_dict() + return BaseResponseModel(state=1, msg="", data=self.model_type(**item_info)) + + return add_route_func + + def update_route_func_create(self, router: APIRouter): + update_schema_type = self.update_schema_type + + @router.post(f'/{self.name}/update', + response_model=BaseResponseModel[Any], + summary=f"修改{self.summary_name}") + def update_route_func(update_model_obj: update_schema_type, db: Session = Depends(get_db)): + self.update(db, update_model_obj) + return BaseResponseModel(state=1, msg="", data={}) + + return update_route_func + + def delete_route_func_create(self, router: APIRouter): + @router.post(f'/{self.name}/delete', + response_model=BaseResponseModel[Any], + summary=f"删除{self.summary_name}") + def delete_route_func(item_id: ItemId, db: Session = Depends(get_db)): + self.delete(db, item_id.id) + return BaseResponseModel(state=1, msg="", data={}) + + return delete_route_func + + def get_route_func_create(self, router: APIRouter): + id_schema_type = self.id_schema_type + + @router.post(f'/{self.name}/get', + response_model=BaseResponseModel[self.model_type], + summary=f"获取{self.summary_name}") + def get_route_func(item_id: id_schema_type, db: Session = Depends(get_db)): + index_value = list(item_id.dict().values())[0] + item = self.get(db, index_value) + return BaseResponseModel(state=1, msg="", data=self.model_type(**item.to_dict())) + + return get_route_func + + def query_route_func_create(self, router: APIRouter): + model_type = self.model_type + query_schema_type = self.query_schema_type + locals_temp = {'model_type': model_type, 'BaseModel': BaseModel, 'ModelType': ModelType, 'Generic': Generic, + 'List': List} + class_name = f"{self.name}QueryRes" + exec(f"class {class_name}(BaseModel, Generic[ModelType]):\n count: int\n item_list: List[model_type]", + locals_temp) + QueryRes = locals_temp[class_name] + + @router.post(f'/{self.name}/query', + response_model=BaseResponseModel[QueryRes], + summary=f"查询{self.summary_name}") + def query_route_func(params: query_schema_type, db: Session = Depends(get_db)): + count, query = self.query(db, params) + item_list = [item.to_dict() for item in query] + return BaseResponseModel(state=1, msg="", data=QueryRes(count=count, item_list=item_list)) + + return query_route_func + + def add(self, db: Session, create_model_obj: CreateSchemaType) -> ModelType: + new_item = self.db_model(**create_model_obj.dict()) + db.add(new_item) + db.commit() + db.refresh(new_item) + return new_item + + def delete(self, db: Session, item_id): + if type(item_id) == list: + db.query(self.db_model).filter(getattr(self.db_model, 'id').in_(item_id)).delete() + else: + db.query(self.db_model).filter_by(id=item_id).delete() + db.commit() + + def update(self, db: Session, update_model_obj: UpdateSchemaType): + db.query(self.db_model).filter_by(id=update_model_obj.id).update(update_model_obj.dict()) + db.commit() + + def query(self, db: Session, params: QuerySchemaType) -> [int]: + params_dict = params.dict() + query = db.query(self.db_model) + for key, value in params_dict.items(): + if key not in ['page', 'page_size'] and value is not None: + if type(value) == str: + query = query.filter(getattr(self.db_model, key).like(f'%{value}%')) + if type(value) in [int, float, bool]: + query = query.filter_by(**{key: value}) + # 日期范围查询 + if type(value) in [list, tuple]: + if value[0] is not None: + query = query.filter(getattr(self.db_model, key) >= datetime.fromtimestamp(value[0])) + if value[1] is not None: + query = query.filter(getattr(self.db_model, key) <= datetime.fromtimestamp(value[1])) + + count = query.count() + page=None + page_size=None + if 'page' in params_dict: + page=params_dict['page'] + if 'page_size' in params_dict: + page_size = params_dict['page_size'] + if page is not None and page_size is not None: + query = query.offset((page - 1) * page_size).limit(page_size).all() + return count, query + + def get(self, db: Session, item_id: IdSchemaType) -> ModelType: + return db.query(self.db_model).get(item_id) + + +def create_crud_model(db_model: Base, name, auto_create_keys=['id'], index_key='id'): + """ + 创建CRUD所需模型 + :param db_model: 数据库模型 + :param name: 名称 + :param auto_create_keys: 自动生成的keys + :param index_key: 索引key + :return: + """ + fields_dict = get_sqlalchemy_model_fields(db_model) + model = create_model(f"{name}Model", **fields_dict) + + class config: + arbitrary_types_allowed = False + + model_id = create_model(f'{name}IdModel', **{index_key: fields_dict[index_key]}) + model_create = create_model(f"{name}Create", + **{key: item for key, item in fields_dict.items() if key not in auto_create_keys}) + model_update = create_model(f"{name}Update", __config__=config, + **{key: item if key == index_key else ( + eval(f'Optional[{item[0].__name__}]', globals()), item[1]) for key, item in + fields_dict.items() if key not in auto_create_keys}) + query_fields_dic = {} + for key, item in fields_dict.items(): + if item[0].__name__ == 'datetime': + query_fields_dic[key] = ( + eval(f'Optional[List[Optional[int]]]', globals(), locals()), None) + else: + query_fields_dic[key] = (eval(f'Optional[{item[0].__name__}]', globals()), None) + model_query = create_model(f"{name}Query", __config__=config, + **query_fields_dic) + return [model, model_id, model_create, model_update, model_query] + + +def auto_create_crud(db_model: Base, name, chinese_name="", auto_create_keys=['id'], index_key='id'): + """ + 自动创建CRUD类 + """ + [model, model_id, model_create, model_update, model_query] = create_crud_model(db_model, name, + auto_create_keys, index_key) + + class Crud(CRUDBase[db_model, model, model_id, model_create, model_update, model_query]): + pass + + crud = Crud(db_model, model, model_id, model_create, model_update, model_query, name, chinese_name) + return crud diff --git a/Utils/SqlAlchemyUtils.py b/Utils/SqlAlchemyUtils.py new file mode 100644 index 0000000..e5c7d01 --- /dev/null +++ b/Utils/SqlAlchemyUtils.py @@ -0,0 +1,30 @@ +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +Base = declarative_base() + +USER="root" +PASSWORD="12345" +HOST="127.0.0.1" +DB="user" +def get_engine(): + engine = create_engine( + f"mysql+pymysql://{USER}:{PASSWORD}@{HOST}/{DB}?charset=utf8mb4") + return engine + + +def get_db() -> sessionmaker: + try: + engine = get_engine() + Session = sessionmaker(autocommit=False, autoflush=False, bind=engine) + db = Session() + yield db + finally: + db.close() + + + +def init_database(): + engine = get_engine() + Base.metadata.create_all(bind=engine) diff --git a/main.py b/main.py new file mode 100644 index 0000000..260610d --- /dev/null +++ b/main.py @@ -0,0 +1,24 @@ +import uvicorn +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from Router import ProjectRouter + +from CrudModel.ItemCrudModel import item_crud + +item_crud.mount(ProjectRouter.router) +app = FastAPI( + title="驾驶舱", + description="", + version="v1.0.0" +) +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +app.include_router(ProjectRouter.router) + +uvicorn.run(app=app, port=8001) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..a66c6f3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +fastapi==0.89.1 +uvicorn==0.20.0 +python-multipart==0.0.5 +SQLAlchemy==2.0.0 +passlib==1.7.4 +bcrypt==4.0.1 +python-jose==3.3.0 +cryptography==39.0.0 +requests==2.28.2 +captcha==0.4 +pymysql==1.0.2 +fastapi-crudrouter==0.8.6 +redis==4.5.1 \ No newline at end of file