daily/Utils/SqlAlchemyUtils.py

158 lines
4.4 KiB
Python

import functools
import os
from typing import Literal, List, Any, Optional, Iterator
from urllib.parse import quote

from pydantic import BaseModel
from sqlalchemy import create_engine, Column, and_, asc, desc, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session, DeclarativeMeta, Query
# Declarative base whose metadata init_database() passes to create_all();
# presumably the project's ORM models subclass it (verify against callers).
Base = declarative_base()

# MySQL connection settings. Each value can be overridden through an
# environment variable so real credentials need not live in source control.
# The fallbacks keep the previous hard-coded values for backward
# compatibility; set DAILY_DB_PASSWORD in any non-local deployment.
user = os.environ.get("DAILY_DB_USER", "root")
password = os.environ.get("DAILY_DB_PASSWORD", "123456")
host = os.environ.get("DAILY_DB_HOST", "127.0.0.1")
db = os.environ.get("DAILY_DB_NAME", "daily")
def _build_db_url(user: str, password: str, host: str, database: str) -> str:
    """Build the MySQL (PyMySQL driver) connection URL.

    The user name and password are percent-encoded, so characters such as
    ``@``, ``:`` or ``/`` in them cannot break URL parsing.
    """
    return (
        f"mysql+pymysql://{quote(user, safe='')}:{quote(password, safe='')}"
        f"@{host}/{database}?charset=utf8mb4"
    )


@functools.lru_cache(maxsize=None)
def get_engine():
    """Return the process-wide SQLAlchemy engine for the configured database.

    The engine and its connection pool are created on first call and reused
    afterwards. Building a new engine on every call creates a separate
    connection pool each time, and nothing in this module disposes of them.
    Use ``get_engine.cache_clear()`` to force re-creation, e.g. in tests.

    NOTE(review): if the process forks after the engine exists, the child
    should not reuse the parent's pooled connections. See SQLAlchemy's
    guidance on connection pools with multiprocessing.

    ``pool_recycle`` replaces connections older than 4 hours. This is
    presumably meant to stay below MySQL's ``wait_timeout``; confirm
    against the server configuration.
    """
    return create_engine(
        _build_db_url(user, password, host, db),
        pool_recycle=3600 * 4,
    )
def get_db() -> Iterator[Session]:
    """Yield a new database session and always close it afterwards.

    This is a generator, presumably used as a dependency-injection provider
    such as ``Depends(get_db)``; confirm against callers.

    The session is created *before* the ``try`` block. If creating it fails,
    the original exception propagates unchanged. Otherwise it would be
    masked by an ``UnboundLocalError`` raised from ``finally``.
    """
    session = get_db_i()
    try:
        yield session
    finally:
        session.close()
def get_db_i() -> Session:
    """Open and return a new ORM session bound to the engine from get_engine().

    Unlike the generator get_db(), this returns the session directly. The
    caller owns it and must call ``close()`` when done.
    Autocommit and autoflush are both disabled.
    """
    session_factory = sessionmaker(
        bind=get_engine(),
        autoflush=False,
        autocommit=False,
    )
    return session_factory()
# def get_db():
# engine = create_engine("sqlite:///./data.db", connect_args={"check_same_thread": False})
# SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# db = SessionLocal()
# try:
# yield db
# finally:
# db.close()
def init_database():
    """Create every table registered on ``Base.metadata`` that is missing.

    ``create_all`` skips tables that already exist, so repeated calls are
    harmless. Existing tables are never altered.
    """
    Base.metadata.create_all(bind=get_engine())
# Generic query interface
# Operators accepted by query_common_core() for one filter condition.
QueryType = Literal[
    "=",
    "==",
    ">",
    ">=",
    "<",
    "<=",
    "in",
    "like",
    "range",
    "find_in_set",
]
class QueryParam(BaseModel):
    """One filter condition of the form ``<name> <type> <value>``."""

    # Name of the model attribute (column) to filter on.
    name: str
    # Operator to apply; see QueryType for the supported set.
    type: QueryType
    # Operand. Its expected shape depends on ``type``: a sequence for "in",
    # a two-element [low, high] pair for "range", otherwise a single value.
    value: Any
class OrderParam(BaseModel):
    """Single-column ordering: a column name plus a sort direction."""

    # Name of the model attribute (column) to sort by.
    order_by: str
    # Sort direction; ascending unless specified.
    type: Literal['asc', 'desc'] = 'asc'
class QueryParams(BaseModel):
    """Filtering, ordering and pagination options for query_common().

    Every field defaults to ``None``. Pydantic v1 already treats an
    un-defaulted ``Optional`` field that way, but pydantic v2 makes such a
    field *required*. The explicit defaults keep the fields optional under
    both versions.
    """

    # Filter conditions, applied one after another (combined with AND).
    # None or an empty list means no filtering.
    param_list: Optional[List[QueryParam]] = None
    # Optional single-column ordering.
    order: Optional[OrderParam] = None
    # 1-based page number. Pagination only applies when both page and
    # page_size are given.
    page: Optional[int] = None
    page_size: Optional[int] = None
def query_common(db: Session, model, query_params: QueryParams, max_page_size: int = 100):
    """Run a generic filtered, ordered and optionally paginated query.

    Args:
        db: Open SQLAlchemy session.
        model: Mapped ORM class to query.
        query_params: Filter conditions, ordering and pagination options.
        max_page_size: Upper bound applied to ``query_params.page_size``.

    Returns:
        A tuple ``(count, result)``:

        - ``count`` is the number of rows matching the filters, ignoring
          pagination.
        - ``result`` depends on pagination. When both ``page`` and
          ``page_size`` are set, it is a list holding at most one page of
          rows. Otherwise it is the un-executed ``Query``, kept as-is for
          backward compatibility with existing callers.
    """
    query = db.query(model)

    # Filtering
    if query_params.param_list:
        query = query_common_core(model, query, query_params.param_list)

    # Count before adding ORDER BY. Ordering cannot change the count and
    # only makes the COUNT(*) subquery more expensive.
    count = query.count()

    # Ordering
    if query_params.order:
        query = query_order(model, query, query_params.order)

    page = query_params.page
    page_size = query_params.page_size
    if page is None or page_size is None:
        return count, query

    # Clamp user-supplied paging values. A page below 1 would produce a
    # negative OFFSET, and a negative page_size an invalid LIMIT.
    # page_size == 0 still yields an empty page, as before.
    page = max(page, 1)
    page_size = max(0, min(page_size, max_page_size))
    rows = query.offset((page - 1) * page_size).limit(page_size).all()
    return count, rows
def query_order(model, query: Query, order_param: OrderParam):
    """Apply the single ORDER BY clause described by ``order_param``.

    ``query`` is returned unchanged in three cases:

    - ``order_param`` is falsy;
    - its ``order_by`` name is empty;
    - its ``type`` is neither "asc" nor "desc".

    An ``order_by`` name that is not an attribute of ``model`` raises
    AttributeError.
    """
    if not order_param or not order_param.order_by:
        return query

    target = getattr(model, order_param.order_by)
    direction = order_param.type
    if direction == 'asc':
        return query.order_by(asc(target))
    if direction == 'desc':
        return query.order_by(desc(target))
    return query
# Maps each supported QueryType operator to a function that builds the SQL
# filter expression from (column, value).
_FILTER_BUILDERS = {
    '=': lambda column, value: column == value,
    '==': lambda column, value: column == value,
    '>': lambda column, value: column > value,
    '<': lambda column, value: column < value,
    '>=': lambda column, value: column >= value,
    '<=': lambda column, value: column <= value,
    'in': lambda column, value: column.in_(value),
    'find_in_set': lambda column, value: func.find_in_set(value, column),
    'like': lambda column, value: column.like(f'%{value}%'),
    'range': lambda column, value: and_(column >= value[0], column <= value[1]),
}


def query_common_core(model, query: Query, param_list: List[QueryParam]):
    """Add one WHERE condition to ``query`` per entry of ``param_list``.

    Conditions are applied with successive ``filter`` calls, so they are
    ANDed together.

    Operator notes:

    - "like" wraps the value in ``%...%`` without escaping it, so ``%`` and
      ``_`` inside the value act as wildcards.
    - "range" expects a [low, high] pair and matches inclusively.

    An unknown column name raises AttributeError. An unrecognized operator
    is skipped.
    """
    for param in param_list:
        column = getattr(model, param.name)
        build = _FILTER_BUILDERS.get(param.type)
        if build is not None:
            query = query.filter(build(column, param.value))
    return query
def tree_table_to_json(db: Session, model, belong_str='belong', key_str='id'):
    """Load every row of ``model`` and assemble the rows into nested trees.

    Each row is converted with its ``to_dict()`` method and given an empty
    ``children`` list. Rows are then linked by their parent key:

    - A row whose ``belong_str`` value is falsy becomes a root.
    - Otherwise the row is appended to the ``children`` of the row whose
      ``key_str`` value equals that parent key.
    - A row pointing at a parent key that no row has is dropped from the
      result entirely.
    - A row whose parent key equals its own key ends up in its own
      ``children`` list, creating a cycle.

    Args:
        db: Open SQLAlchemy session.
        model: Mapped ORM class whose instances provide ``to_dict()``.
        belong_str: Key of the parent-id field in each row dict.
        key_str: Key of the row-id field in each row dict.

    Returns:
        List of root node dicts, each carrying nested ``children`` lists.
    """
    def _as_node(record):
        # Adds 'children' to the dict returned by to_dict() in place.
        node = record.to_dict()
        node['children'] = []
        return node

    nodes = [_as_node(record) for record in db.query(model).all()]
    by_key = {node[key_str]: node for node in nodes}

    roots = []
    for node in nodes:
        parent_key = node[belong_str]
        if not parent_key:
            roots.append(node)
        elif parent_key in by_key:
            by_key[parent_key]['children'].append(node)
        # Otherwise the parent is missing and the node is silently dropped.
    return roots