From 823ac4bc6f6622b9ceb1fc4556839f488e7196a6 Mon Sep 17 00:00:00 2001 From: wcq <744800102@qq.com> Date: Fri, 24 Feb 2023 10:49:55 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AF=B9crud=E6=A8=A1=E5=9E=8B=E4=B8=AD?= =?UTF-8?q?=E5=AD=98=E5=82=A8id=E6=95=B0=E7=BB=84=E7=9A=84key=E7=9A=84?= =?UTF-8?q?=E7=AD=9B=E9=80=89=E5=8A=9F=E8=83=BD=E7=BC=96=E5=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Utils/CrudUtils.py | 52 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/Utils/CrudUtils.py b/Utils/CrudUtils.py index d8ac238..b8f7e43 100644 --- a/Utils/CrudUtils.py +++ b/Utils/CrudUtils.py @@ -1,4 +1,6 @@ from typing import TypeVar, Generic, Any, List, get_args, Union, Optional, Tuple + +from sqlalchemy import func from sqlalchemy.orm import Session from datetime import datetime, date from Utils.CommonUtils import get_sqlalchemy_model_fields @@ -40,7 +42,11 @@ class CRUDBase(Generic[DbModelType, ModelType, IdSchemaType, CreateSchemaType, U id_schema_type: IdSchemaType, create_schema_type: CreateSchemaType, update_schema_type: UpdateSchemaType, - query_schema_type: QuerySchemaType, name: str, chinese_name: str = ''): + query_schema_type: QuerySchemaType, + name: str, + chinese_name: str = '', + array_keys=[], + ): self.db_model = db_model self.name = name self.model_type = model_type @@ -49,6 +55,8 @@ class CRUDBase(Generic[DbModelType, ModelType, IdSchemaType, CreateSchemaType, U self.update_schema_type = update_schema_type self.query_schema_type = query_schema_type self.summary_name = chinese_name if chinese_name else self.name + self.query_func = None + self.array_keys = array_keys def mount(self, router: APIRouter): for key in self.__dir__(): @@ -146,6 +154,17 @@ class CRUDBase(Generic[DbModelType, ModelType, IdSchemaType, CreateSchemaType, U query = db.query(self.db_model) for key, value in params_dict.items(): if key not in ['page', 'page_size'] and value is not None: + # 在存储的数组值内查询 如存的 1,2,3,4 查询时则使用的 [1,2]这样的数据来查 + if key in self.array_keys: + for item in value: + query = query.filter(func.find_in_set(str(item), getattr(self.db_model, key))) + continue + # 如果执行query_func后有返回值即使用了key,value则跳过后面操作 + if self.query_func: + query_temp = self.query_func(query, key, value) + if query_temp: + query = query_temp + continue if type(value) == str: query = query.filter(getattr(self.db_model, key).like(f'%{value}%')) if type(value) in [int, float, bool]: @@ -172,13 +191,15 @@ class CRUDBase(Generic[DbModelType, ModelType, IdSchemaType, CreateSchemaType, U return db.query(self.db_model).get(item_id) -def create_crud_model(db_model: Base, name, auto_create_keys=['id'], index_key='id'): +def create_crud_model(db_model: Base, name, auto_create_keys=['id'], index_key='id', + array_keys=[]): """ 创建CRUD所需模型 :param db_model: 数据库模型 :param name: 名称 :param auto_create_keys: 自动生成的keys :param index_key: 索引key + :param array_keys: 用id数组存储数据的key的列表 :return: """ fields_dict = get_sqlalchemy_model_fields(db_model) @@ -195,26 +216,41 @@ def create_crud_model(db_model: Base, name, auto_create_keys=['id'], index_key=' eval(f'Optional[{item[0].__name__}]', globals()), item[1]) for key, item in fields_dict.items() if key not in auto_create_keys}) query_fields_dic = {} + for key, item in fields_dict.items(): - if item[0].__name__ == 'datetime': + if key in array_keys: query_fields_dic[key] = ( - eval(f'Optional[List[Optional[int]]]', globals(), locals()), None) + eval(f'Optional[List[Union[int,str]]]', globals(), locals()), None) else: - query_fields_dic[key] = (eval(f'Optional[{item[0].__name__}]', globals()), None) + if item[0].__name__ == 'datetime': + query_fields_dic[key] = ( + eval(f'Optional[List[Optional[int]]]', globals(), locals()), None) + else: + query_fields_dic[key] = (eval(f'Optional[{item[0].__name__}]', globals()), None) model_query = create_model(f"{name}Query", __config__=config, **query_fields_dic) return [model, model_id, model_create, model_update, model_query] -def auto_create_crud(db_model: Base, name, chinese_name="", auto_create_keys=['id'], index_key='id'): +def auto_create_crud(db_model: Base, name, chinese_name="", auto_create_keys=['id'], index_key='id', + array_keys=[]): """ 自动创建CRUD类 + :param db_model: + :param name: + :param chinese_name: + :param auto_create_keys: 自动生成的字段key + :param index_key: + :param array_keys: 由id数组组成的字符串key列表 + :return: """ [model, model_id, model_create, model_update, model_query] = create_crud_model(db_model, name, - auto_create_keys, index_key) + auto_create_keys, index_key, + array_keys=array_keys) class Crud(CRUDBase[db_model, model, model_id, model_create, model_update, model_query]): pass - crud = Crud(db_model, model, model_id, model_create, model_update, model_query, name, chinese_name) + crud = Crud(db_model, model, model_id, model_create, model_update, model_query, name, chinese_name, + array_keys=array_keys) return crud