"""SQLAlchemy connection helper plus generic filter/sort/paginate utilities."""

from typing import Any, Iterator, List, Literal, Optional

from pydantic import BaseModel
from sqlalchemy import Column, and_, asc, create_engine, desc, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import DeclarativeMeta, Query, Session, sessionmaker

# Upper bound applied to the page size requested through QueryParams.
MAX_PAGE_SIZE = 100


class SqlalchemyConnect:
    """Own a MySQL (pymysql) engine and hand out sessions bound to it.

    Args:
        Base: Declarative base whose metadata ``init_database`` creates.
        host: Database host (may include ``:port``).
        user: Database user name.
        password: Database password.
        db: Database (schema) name.
    """

    def __init__(self, Base: DeclarativeMeta, host="127.0.0.1", user="", password="", db=""):
        self.Base = Base
        self.host = host
        self.user = user
        self.password = password
        self.db = db
        self.engine = self.init_engine()

    def _build_url(self) -> str:
        """Return the connection URL built from the stored settings."""
        # NOTE: credentials are interpolated verbatim, so a user name or
        # password containing URL-reserved characters (``@``, ``/``, ``:``)
        # must already be percent-encoded by the caller.
        return (
            f"mysql+pymysql://{self.user}:{self.password}@{self.host}/{self.db}"
            f"?charset=utf8mb4"
        )

    def init_engine(self):
        """Create the engine; pooled connections are recycled after 4 hours."""
        return create_engine(self._build_url(), pool_recycle=3600 * 4)

    def _new_session(self) -> Session:
        """Open a new session bound to ``self.engine``."""
        factory = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
        return factory()

    def get_db(self) -> Iterator[Session]:
        """Yield a session and close it once the generator is finalized.

        Nothing is committed here; the consumer commits explicitly if needed.
        """
        # Create the session *before* entering ``try`` so that a failure while
        # creating it is not masked by an UnboundLocalError in ``finally``.
        db = self._new_session()
        try:
            yield db
        finally:
            db.close()

    def get_db_commit(self) -> Iterator[Session]:
        """Yield a session, commit when the consumer finishes normally.

        If the consumer (or the commit itself) raises, the transaction is
        rolled back and the exception is re-raised. The session is always
        closed.
        """
        db = self._new_session()
        try:
            yield db
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

    def get_db_i(self) -> Session:
        """Return a new session; the caller is responsible for closing it.

        The session is bound to the shared ``self.engine`` instead of a
        freshly created engine, so repeated calls no longer build (and leak)
        a new connection pool each time.
        """
        return self._new_session()

    def init_database(self):
        """Create every table registered on ``Base.metadata``."""
        self.Base.metadata.create_all(bind=self.engine)


# Generic query interface
QueryType = Literal['=', '==', '>', '>=', '<', '<=', 'in', 'like', 'range', 'find_in_set']


class QueryParam(BaseModel):
    """A single filter: attribute ``name`` compared with ``value`` via ``type``."""

    name: str
    type: QueryType
    value: Any


class OrderParam(BaseModel):
    """Sort instruction: attribute name and direction (ascending by default)."""

    order_by: str
    type: Literal['asc', 'desc'] = 'asc'


class QueryParams(BaseModel):
    """Filters, optional ordering and pagination for ``query_common``."""

    # Explicit ``None`` defaults keep these fields omittable on pydantic v2 as
    # well as v1 (v2 treats ``Optional[X]`` without a default as required).
    param_list: Optional[List[QueryParam]] = None
    order: Optional[OrderParam] = None
    page: int
    page_size: int


def query_common(db: Session, model, query_params: QueryParams):
    """Run a filtered, ordered, paginated query against ``model``.

    Args:
        db: Session used to build and execute the query.
        model: Mapped class to select from.
        query_params: Filters, ordering and pagination settings.

    Returns:
        A ``(count, rows)`` tuple: the number of rows matching the filters
        (before pagination) and the list of rows for the requested page.

    Raises:
        AttributeError: If a filter or order name is not an attribute of
            ``model``.
    """
    filtered = db.query(model)
    if query_params.param_list:
        filtered = query_common_core(model, filtered, query_params.param_list)

    ordered = filtered
    if query_params.order:
        ordered = query_order(model, filtered, query_params.order)

    # Count on the unordered query: ordering cannot change the row count and
    # would otherwise put a needless ORDER BY inside the COUNT subquery.
    count = filtered.count()

    # Clamp pagination so that page <= 0 or a negative page size can never
    # produce a negative OFFSET / LIMIT.
    page = max(query_params.page, 1)
    page_size = max(min(query_params.page_size, MAX_PAGE_SIZE), 0)

    rows = ordered.offset((page - 1) * page_size).limit(page_size).all()
    return count, rows


def query_order(model, query: Query, order_param: OrderParam):
    """Apply the ordering described by ``order_param`` to ``query``.

    The query is returned unchanged when ``order_param`` or its ``order_by``
    is empty.

    Raises:
        AttributeError: If ``order_by`` is not an attribute of ``model``.
    """
    if not order_param or not order_param.order_by:
        return query

    column: Column = getattr(model, order_param.order_by)
    if order_param.type == 'asc':
        return query.order_by(asc(column))
    if order_param.type == 'desc':
        return query.order_by(desc(column))
    return query


# Maps each supported QueryType to a function building the filter expression.
_FILTER_BUILDERS = {
    '=': lambda column, value: column == value,
    '==': lambda column, value: column == value,
    '>': lambda column, value: column > value,
    '<': lambda column, value: column < value,
    '>=': lambda column, value: column >= value,
    '<=': lambda column, value: column <= value,
    'in': lambda column, value: column.in_(value),
    'find_in_set': lambda column, value: func.find_in_set(value, column),
    # The value is wrapped in wildcards, i.e. a "contains" match.
    'like': lambda column, value: column.like(f'%{value}%'),
    # Inclusive on both ends; value is indexed as [low, high].
    'range': lambda column, value: and_(column >= value[0], column <= value[1]),
}


def query_common_core(model, query: Query, param_list: List[QueryParam]):
    """Add one filter per entry of ``param_list`` to ``query``.

    Entries whose ``type`` is not a known operator add no filter.

    Raises:
        AttributeError: If an entry's ``name`` is not an attribute of
            ``model``.
    """
    for item in param_list:
        # Resolve the attribute first so an unknown name always raises,
        # regardless of the operator.
        column: Column = getattr(model, item.name)
        build_filter = _FILTER_BUILDERS.get(item.type)
        if build_filter is not None:
            query = query.filter(build_filter(column, item.value))
    return query


def tree_table_to_json(db: Session, model, belong_str='belong', key_str='id'):
    """Load every row of ``model`` and nest the rows into a tree of dicts.

    Each row is converted with its ``to_dict()`` method and gets a
    ``'children'`` list. A row whose ``belong_str`` value is falsy becomes a
    root; otherwise it is appended to the children of the row whose
    ``key_str`` equals that value.

    Note:
        A row that references a parent key which does not exist is attached
        nowhere and therefore does not appear in the result.

    Returns:
        The list of root nodes.
    """
    node_list = []
    for item in db.query(model).all():
        node = item.to_dict()
        node['children'] = []
        node_list.append(node)

    nodes_by_key = {node[key_str]: node for node in node_list}

    tree = []
    for node in node_list:
        parent_key = node[belong_str]
        if not parent_key:
            tree.append(node)
            continue
        parent = nodes_by_key.get(parent_key)
        if parent is not None:
            parent['children'].append(node)
    return tree