154 lines
4.6 KiB
Python
154 lines
4.6 KiB
Python
from typing import Any, Iterator, List, Literal, Optional
from urllib.parse import quote

from pydantic import BaseModel
from sqlalchemy import create_engine, Column, and_, asc, desc, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session, DeclarativeMeta, Query
|
|
|
|
|
|
class SqlalchemyConnect:
    """MySQL (PyMySQL driver) connection helper built around one SQLAlchemy engine.

    Every session handed out by this class is bound to ``self.engine``, so all
    of them share that engine's connection pool.
    """

    def __init__(self, Base: DeclarativeMeta, host="127.0.0.1", user="", password="", db=""):
        """Store the connection settings and create the engine.

        :param Base: declarative base class (the value returned by
            ``declarative_base()``); ``init_database`` creates the tables
            registered on its ``metadata``.
        :param host: database server host.
        :param user: login user name.
        :param password: login password (URL-escaped when the URL is built).
        :param db: database name.
        """
        self.Base = Base
        self.host = host
        self.user = user
        self.password = password
        self.db = db
        self.engine = self.init_engine()

    def _build_url(self) -> str:
        """Return the SQLAlchemy database URL for the stored settings."""
        # Escape the password so characters such as '@', ':' or '/' cannot be
        # mistaken for URL delimiters; SQLAlchemy unquotes it when parsing.
        password = quote(self.password, safe="")
        return f"mysql+pymysql://{self.user}:{password}@{self.host}/{self.db}?charset=utf8mb4"

    def init_engine(self):
        """Create and return the engine for the stored settings."""
        return create_engine(self._build_url())

    def _new_session(self) -> Session:
        """Open a new session (autocommit and autoflush off) bound to ``self.engine``."""
        factory = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
        return factory()

    def get_db(self) -> Iterator[Session]:
        """Yield a session and close it when the generator is resumed or closed.

        Nothing is committed here.
        """
        # Open the session before ``try``: if opening fails there is nothing to
        # close, and ``finally`` must not hit an unbound ``db``.
        db = self._new_session()
        try:
            yield db
        finally:
            db.close()

    def get_db_commit(self) -> Iterator[Session]:
        """Yield a session, commit when the generator is resumed, then close it.

        If an exception is thrown into the generator (or the commit fails),
        the transaction is rolled back and the exception propagates.
        """
        db = self._new_session()
        try:
            yield db
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

    def get_db_i(self) -> Session:
        """Return a new session bound to the shared engine; the caller must close it."""
        # Reuse ``self.engine`` instead of creating a new engine (and connection
        # pool) on every call.
        return self._new_session()

    def init_database(self):
        """Create the tables registered on ``self.Base.metadata`` via ``create_all``."""
        self.Base.metadata.create_all(bind=self.engine)
|
|
|
|
|
|
# Generic query helpers: filtering, ordering and pagination
|
|
|
|
# Filter operators understood by query_common_core.
QueryType = Literal[
    '=', '==',              # equality (both spellings accepted)
    '>', '>=', '<', '<=',   # ordered comparisons
    'in',                   # value is a collection of candidates
    'like',                 # substring match: LIKE '%value%'
    'range',                # [low, high] pair, both bounds inclusive
    'find_in_set',          # MySQL FIND_IN_SET(value, column)
]
|
|
|
|
|
|
class QueryParam(BaseModel):
    """One filter condition, applied by ``query_common_core``."""

    # Name of the model attribute (column) to filter on.
    name: str

    # Operator to apply; see QueryType for the accepted values.
    type: QueryType

    # Operand for the operator. Its expected shape depends on ``type``:
    # a collection for 'in', a [low, high] pair for 'range' (both bounds
    # inclusive), a single value otherwise.
    value: Any
|
|
|
|
|
|
class OrderParam(BaseModel):
    """Sort specification, applied by ``query_order``."""

    # Name of the model attribute (column) to sort by.
    order_by: str

    # Sort direction; ascending unless stated otherwise.
    type: Literal['asc', 'desc'] = 'asc'
|
|
|
|
|
|
class QueryParams(BaseModel):
    """Input for ``query_common``: filters, optional ordering and pagination."""

    # Filter conditions, each added as a separate ``filter`` call (AND-ed).
    # None or an empty list means no filtering. The explicit ``= None`` keeps
    # the field optional under pydantic v2, where ``Optional[X]`` alone no
    # longer implies a default.
    param_list: Optional[List[QueryParam]] = None

    # Optional sort specification; None means no ORDER BY is added.
    order: Optional[OrderParam] = None

    # 1-based page number.
    page: int

    # Rows per page; query_common caps it at 100.
    page_size: int
|
|
|
|
|
|
def query_common(db: Session, model, query_params: QueryParams, max_page_size: int = 100):
    """
    Generic database query: filter, order and paginate rows of ``model``.

    :param db: session used to build the query.
    :param model: ORM model class to query.
    :param query_params: filter conditions (``param_list``), ordering
        (``order``) and 1-based pagination (``page``, ``page_size``).
    :param max_page_size: upper bound applied to ``page_size`` (default 100).
    :return: ``(count, rows)`` where ``count`` is the number of rows matching
        the filters (ignoring pagination) and ``rows`` is the requested page.
    """
    query = db.query(model)

    # Filtering
    if query_params.param_list:
        query = query_common_core(model, query, query_params.param_list)

    # Ordering
    if query_params.order:
        query = query_order(model, query, query_params.order)

    count = query.count()

    # Clamp client-supplied pagination: pages are 1-based, and a negative
    # OFFSET or LIMIT must never reach the database.
    page = max(query_params.page, 1)
    page_size = max(min(query_params.page_size, max_page_size), 0)

    rows = query.offset((page - 1) * page_size).limit(page_size).all()
    return count, rows
|
|
|
|
|
|
def query_order(model, query: Query, order_param: OrderParam):
    """
    Apply the sort described by ``order_param`` to ``query``.

    ``query`` is returned untouched when ``order_param`` is falsy, when its
    ``order_by`` is empty, or when its ``type`` is neither 'asc' nor 'desc'.
    Raises AttributeError when ``model`` has no attribute named ``order_by``.
    """
    if not order_param or not order_param.order_by:
        return query

    column: Column = getattr(model, order_param.order_by)

    direction = order_param.type
    sorter = asc if direction == 'asc' else (desc if direction == 'desc' else None)
    if sorter is None:
        return query
    return query.order_by(sorter(column))
|
|
|
|
|
|
# Maps each filter operator to a builder producing the SQLAlchemy condition
# from (column, value).
_FILTER_BUILDERS = {
    '=': lambda column, value: column == value,
    '==': lambda column, value: column == value,
    '>': lambda column, value: column > value,
    '<': lambda column, value: column < value,
    '>=': lambda column, value: column >= value,
    '<=': lambda column, value: column <= value,
    'in': lambda column, value: column.in_(value),
    'find_in_set': lambda column, value: func.find_in_set(value, column),
    'like': lambda column, value: column.like(f'%{value}%'),
    'range': lambda column, value: and_(column >= value[0], column <= value[1]),
}


def query_common_core(model, query: Query, param_list: List[QueryParam]):
    """
    Add one ``filter`` call to ``query`` per entry in ``param_list``.

    Each entry names a ``model`` attribute (``name``), an operator (``type``)
    and an operand (``value``). Entries with an unrecognised operator add no
    condition. Raises AttributeError when ``model`` lacks the named attribute.
    """
    for param in param_list:
        attr_name, operator_key, operand = param.name, param.type, param.value
        column: Column = getattr(model, attr_name)
        build_condition = _FILTER_BUILDERS.get(operator_key)
        if build_condition is not None:
            query = query.filter(build_condition(column, operand))
    return query
|
|
|
|
|
|
def tree_table_to_json(db: Session, model, belong_str='belong', key_str='id'):
    """
    Load every row of ``model`` and nest the rows into a list of root nodes.

    Each row is converted with ``to_dict()`` and given an empty ``'children'``
    list. A row whose ``belong_str`` value is falsy becomes a root. A row
    whose ``belong_str`` value equals another row's ``key_str`` value is
    appended to that row's children. Rows whose parent key matches no loaded
    row appear nowhere in the result.
    """
    flat_nodes = []
    for row in db.query(model).all():
        node = row.to_dict()
        node['children'] = []
        flat_nodes.append(node)

    lookup = {node[key_str]: node for node in flat_nodes}

    roots = []
    for node in flat_nodes:
        parent_key = node[belong_str]
        if not parent_key:
            roots.append(node)
        elif parent_key in lookup:
            lookup[parent_key]['children'].append(node)
    return roots
|