This commit is contained in:
P3ngSaM 2023-02-24 15:27:30 +08:00
commit c1efed6d33
1 changed files with 44 additions and 8 deletions

View File

@ -1,4 +1,6 @@
from typing import TypeVar, Generic, Any, List, get_args, Union, Optional, Tuple from typing import TypeVar, Generic, Any, List, get_args, Union, Optional, Tuple
from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from datetime import datetime, date from datetime import datetime, date
from Utils.CommonUtils import get_sqlalchemy_model_fields from Utils.CommonUtils import get_sqlalchemy_model_fields
@ -40,7 +42,11 @@ class CRUDBase(Generic[DbModelType, ModelType, IdSchemaType, CreateSchemaType, U
id_schema_type: IdSchemaType, id_schema_type: IdSchemaType,
create_schema_type: CreateSchemaType, create_schema_type: CreateSchemaType,
update_schema_type: UpdateSchemaType, update_schema_type: UpdateSchemaType,
query_schema_type: QuerySchemaType, name: str, chinese_name: str = ''): query_schema_type: QuerySchemaType,
name: str,
chinese_name: str = '',
array_keys=[],
):
self.db_model = db_model self.db_model = db_model
self.name = name self.name = name
self.model_type = model_type self.model_type = model_type
@ -49,6 +55,8 @@ class CRUDBase(Generic[DbModelType, ModelType, IdSchemaType, CreateSchemaType, U
self.update_schema_type = update_schema_type self.update_schema_type = update_schema_type
self.query_schema_type = query_schema_type self.query_schema_type = query_schema_type
self.summary_name = chinese_name if chinese_name else self.name self.summary_name = chinese_name if chinese_name else self.name
self.query_func = None
self.array_keys = array_keys
def mount(self, router: APIRouter): def mount(self, router: APIRouter):
for key in self.__dir__(): for key in self.__dir__():
@ -146,6 +154,17 @@ class CRUDBase(Generic[DbModelType, ModelType, IdSchemaType, CreateSchemaType, U
query = db.query(self.db_model) query = db.query(self.db_model)
for key, value in params_dict.items(): for key, value in params_dict.items():
if key not in ['page', 'page_size'] and value is not None: if key not in ['page', 'page_size'] and value is not None:
# 在存储的数组值内查询 如存的 1,2,3,4 查询时则使用的 [1,2]这样的数据来查
if key in self.array_keys:
for item in value:
query = query.filter(func.find_in_set(str(item), getattr(self.db_model, key)))
continue
# 如果执行query_func后有返回值即使用了key,value则跳过后面操作
if self.query_func:
query_temp = self.query_func(query, key, value)
if query_temp:
query = query_temp
continue
if type(value) == str: if type(value) == str:
query = query.filter(getattr(self.db_model, key).like(f'%{value}%')) query = query.filter(getattr(self.db_model, key).like(f'%{value}%'))
if type(value) in [int, float, bool]: if type(value) in [int, float, bool]:
@ -172,13 +191,15 @@ class CRUDBase(Generic[DbModelType, ModelType, IdSchemaType, CreateSchemaType, U
return db.query(self.db_model).get(item_id) return db.query(self.db_model).get(item_id)
def create_crud_model(db_model: Base, name, auto_create_keys=['id'], index_key='id'): def create_crud_model(db_model: Base, name, auto_create_keys=['id'], index_key='id',
array_keys=[]):
""" """
创建CRUD所需模型 创建CRUD所需模型
:param db_model: 数据库模型 :param db_model: 数据库模型
:param name: 名称 :param name: 名称
:param auto_create_keys: 自动生成的keys :param auto_create_keys: 自动生成的keys
:param index_key: 索引key :param index_key: 索引key
:param array_keys: 用id数组存储数据的key的列表
:return: :return:
""" """
fields_dict = get_sqlalchemy_model_fields(db_model) fields_dict = get_sqlalchemy_model_fields(db_model)
@ -195,26 +216,41 @@ def create_crud_model(db_model: Base, name, auto_create_keys=['id'], index_key='
eval(f'Optional[{item[0].__name__}]', globals()), item[1]) for key, item in eval(f'Optional[{item[0].__name__}]', globals()), item[1]) for key, item in
fields_dict.items() if key not in auto_create_keys}) fields_dict.items() if key not in auto_create_keys})
query_fields_dic = {} query_fields_dic = {}
for key, item in fields_dict.items(): for key, item in fields_dict.items():
if item[0].__name__ == 'datetime': if key in array_keys:
query_fields_dic[key] = ( query_fields_dic[key] = (
eval(f'Optional[List[Optional[int]]]', globals(), locals()), None) eval(f'Optional[List[Union[int,str]]]', globals(), locals()), None)
else: else:
query_fields_dic[key] = (eval(f'Optional[{item[0].__name__}]', globals()), None) if item[0].__name__ == 'datetime':
query_fields_dic[key] = (
eval(f'Optional[List[Optional[int]]]', globals(), locals()), None)
else:
query_fields_dic[key] = (eval(f'Optional[{item[0].__name__}]', globals()), None)
model_query = create_model(f"{name}Query", __config__=config, model_query = create_model(f"{name}Query", __config__=config,
**query_fields_dic) **query_fields_dic)
return [model, model_id, model_create, model_update, model_query] return [model, model_id, model_create, model_update, model_query]
def auto_create_crud(db_model: Base, name, chinese_name="", auto_create_keys=['id'], index_key='id'): def auto_create_crud(db_model: Base, name, chinese_name="", auto_create_keys=['id'], index_key='id',
array_keys=[]):
""" """
自动创建CRUD类 自动创建CRUD类
:param db_model:
:param name:
:param chinese_name:
:param auto_create_keys: 自动生成的字段key
:param index_key:
:param array_keys: 由id数组组成的字符串key列表
:return:
""" """
[model, model_id, model_create, model_update, model_query] = create_crud_model(db_model, name, [model, model_id, model_create, model_update, model_query] = create_crud_model(db_model, name,
auto_create_keys, index_key) auto_create_keys, index_key,
array_keys=array_keys)
class Crud(CRUDBase[db_model, model, model_id, model_create, model_update, model_query]): class Crud(CRUDBase[db_model, model, model_id, model_create, model_update, model_query]):
pass pass
crud = Crud(db_model, model, model_id, model_create, model_update, model_query, name, chinese_name) crud = Crud(db_model, model, model_id, model_create, model_update, model_query, name, chinese_name,
array_keys=array_keys)
return crud return crud