2023-04-04 14:28:53 +08:00
|
|
|
from typing import Literal, List, Any, Optional, Iterator

from pydantic import BaseModel
from sqlalchemy import create_engine, Column, and_, asc, desc, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session, DeclarativeMeta, Query
|
2023-02-28 13:52:51 +08:00
|
|
|
|
|
|
|
# Declarative base class; init_database() creates tables from Base.metadata.
Base = declarative_base()
|
|
|
|
|
2023-02-28 16:28:48 +08:00
|
|
|
# MySQL connection settings; get_engine() interpolates these values into the
# connection URL.
# NOTE(review): the credentials are hard-coded in source — consider loading
# them from the environment or a secrets store instead.
user: str = "root"
password: str = "123456"
host: str = "127.0.0.1"
db: str = "daily"
|
2023-02-28 13:52:51 +08:00
|
|
|
|
|
|
|
|
|
|
|
# Process-wide engine, created lazily by get_engine() on its first call.
_engine = None


def get_engine():
    """Return the shared SQLAlchemy engine for the configured MySQL database.

    The engine is created on the first call and reused afterwards, so every
    session shares one connection pool instead of building a new engine (and
    a new pool) on each call.

    Returns:
        The engine built from the module-level ``user``, ``password``,
        ``host`` and ``db`` settings.
    """
    global _engine
    if _engine is None:
        # Imported here so the block does not depend on a file-level import.
        from urllib.parse import quote_plus

        # Credentials are URL-quoted so characters such as "@" or "/" in the
        # user name or password cannot corrupt the connection URL.
        _engine = create_engine(
            f"mysql+pymysql://{quote_plus(user)}:{quote_plus(password)}"
            f"@{host}/{db}?charset=utf8mb4"
        )
    return _engine
|
|
|
|
|
|
|
|
|
|
|
|
def get_db() -> Iterator[Session]:
    """Yield a database session and close it once the caller is done.

    This is a generator: it yields exactly one session bound to the engine
    from get_engine() and closes that session when the generator is resumed
    or finalized.

    Yields:
        A new session created with ``autocommit=False`` and
        ``autoflush=False``.
    """
    # Build the session before entering ``try`` so that a failure while
    # creating the engine or the session propagates as-is, instead of being
    # masked by an UnboundLocalError raised from the ``finally`` clause.
    session_factory = sessionmaker(
        autocommit=False, autoflush=False, bind=get_engine()
    )
    session = session_factory()
    try:
        yield session
    finally:
        session.close()
|
|
|
|
|
|
|
|
|
2023-03-01 16:04:43 +08:00
|
|
|
def get_db_i() -> Session:
    """Create and return a new session bound to the engine from get_engine().

    Unlike a generator-style provider, this function does not close the
    session; the caller is responsible for closing it.
    """
    session_factory = sessionmaker(
        autocommit=False, autoflush=False, bind=get_engine()
    )
    return session_factory()
|
|
|
|
|
|
|
|
|
2023-02-28 13:52:51 +08:00
|
|
|
# def get_db():
|
|
|
|
# engine = create_engine("sqlite:///./data.db", connect_args={"check_same_thread": False})
|
|
|
|
# SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|
|
|
# db = SessionLocal()
|
|
|
|
# try:
|
|
|
|
# yield db
|
|
|
|
# finally:
|
|
|
|
# db.close()
|
|
|
|
|
|
|
|
|
|
|
|
def init_database():
    """Create the tables of every model registered on ``Base.metadata``.

    Uses the engine returned by get_engine() as the bind.
    """
    Base.metadata.create_all(bind=get_engine())
|
2023-03-01 16:04:43 +08:00
|
|
|
|
2023-04-04 14:28:53 +08:00
|
|
|
|
|
|
|
# Generic query interface

# Operators accepted in QueryParam.type; query_common_core() turns each one
# into the corresponding filter.
QueryType = Literal[
    '=', '==', '>', '>=', '<', '<=',
    'in', 'like', 'range', 'find_in_set',
]
|
|
|
|
|
|
|
|
|
|
|
|
class QueryParam(BaseModel):
    """One filter condition: an attribute name, an operator and an operand."""

    # Name of the model attribute the condition applies to.
    name: str
    # Operator to apply; one of the QueryType literals.
    type: QueryType
    # Operand for the operator. For 'range' the code that consumes this
    # reads value[0] and value[1]; for 'in' the value is passed to in_().
    value: Any
|
|
|
|
|
|
|
|
|
|
|
|
class OrderParam(BaseModel):
    """Sort instruction: which model attribute to order by, and the direction."""

    # Name of the model attribute to sort on.
    order_by: str
    # Sort direction; ascending unless 'desc' is given.
    type: Literal['asc', 'desc'] = 'asc'
|
|
|
|
|
|
|
|
|
|
|
|
class QueryParams(BaseModel):
    """Request body for query_common(): filters, ordering and pagination.

    Every field is optional. The explicit ``None`` defaults keep the fields
    optional under pydantic v2 as well, where an ``Optional[...]`` annotation
    without a default is a required field; under pydantic v1 the behavior is
    unchanged.
    """

    # Filter conditions; no filtering is applied when empty or None.
    param_list: Optional[List[QueryParam]] = None
    # Sort instruction; no ordering is applied when None.
    order: Optional[OrderParam] = None
    # Page number; pagination only applies when both page and page_size
    # are set.
    page: Optional[int] = None
    # Number of rows per page.
    page_size: Optional[int] = None
|
|
|
|
|
|
|
|
|
|
|
|
def query_common(db: Session, model, query_params: QueryParams,
                 max_page_size: int = 100):
    """Generic database query: filter, order, count and optionally paginate.

    Args:
        db: Session used to build the query.
        model: Mapped class to query.
        query_params: Filters, ordering and pagination settings.
        max_page_size: Upper bound applied to ``query_params.page_size``.

    Returns:
        A ``(count, result)`` tuple. ``count`` is the number of rows matching
        the filters, before pagination. ``result`` is the list of rows for
        the requested page when both ``page`` and ``page_size`` are set;
        otherwise it is the un-executed query object, which the caller can
        iterate or refine further.
    """
    query = db.query(model)

    # Filtering
    if query_params.param_list:
        query = query_common_core(model, query, query_params.param_list)

    # Ordering
    if query_params.order:
        query = query_order(model, query, query_params.order)

    count = query.count()

    page = query_params.page
    page_size = query_params.page_size
    if page is None or page_size is None:
        return count, query

    # Clamp the inputs so a page below 1 or a negative page size cannot
    # produce a negative OFFSET/LIMIT.
    page = max(page, 1)
    page_size = max(0, min(page_size, max_page_size))
    rows = query.offset((page - 1) * page_size).limit(page_size).all()
    return count, rows
|
|
|
|
|
|
|
|
|
|
|
|
def query_order(model, query: Query, order_param: OrderParam):
    """Apply the ordering described by *order_param* to *query*.

    The query is returned unchanged when *order_param* is empty, when it
    names no attribute, or when its direction is neither 'asc' nor 'desc'.
    """
    if not order_param or not order_param.order_by:
        return query

    sort_column: Column = getattr(model, order_param.order_by)
    direction = order_param.type
    if direction == 'asc':
        return query.order_by(asc(sort_column))
    if direction == 'desc':
        return query.order_by(desc(sort_column))
    return query
|
|
|
|
|
|
|
|
|
|
|
|
# Maps each supported operator to a function that builds the filter
# expression from (column, value).
_FILTER_BUILDERS = {
    '=': lambda col, val: col == val,
    '==': lambda col, val: col == val,
    '>': lambda col, val: col > val,
    '<': lambda col, val: col < val,
    '>=': lambda col, val: col >= val,
    '<=': lambda col, val: col <= val,
    'in': lambda col, val: col.in_(val),
    'find_in_set': lambda col, val: func.find_in_set(val, col),
    'like': lambda col, val: col.like(f'%{val}%'),
    'range': lambda col, val: and_(col >= val[0], col <= val[1]),
}


def query_common_core(model, query: Query, param_list: List[QueryParam]):
    """Add one filter to *query* for every condition in *param_list*.

    Each condition names a model attribute, an operator and an operand.
    Conditions whose operator is not recognized add no filter. For 'range'
    the operand is indexed as ``[low, high]`` and both bounds are inclusive;
    'like' matches the operand anywhere in the column value.
    """
    for condition in param_list:
        name, operator, operand = condition.name, condition.type, condition.value
        target: Column = getattr(model, name)
        build = _FILTER_BUILDERS.get(operator)
        if build is not None:
            query = query.filter(build(target, operand))
    return query
|
|
|
|
|
|
|
|
|
|
|
|
def tree_table_to_json(db: Session, model, belong_str='belong', key_str='id'):
    """Load every row of *model* and arrange the rows into a nested tree.

    Each row is converted with its ``to_dict()`` method and given a
    ``'children'`` list. ``belong_str`` names the field holding the parent's
    key and ``key_str`` names the field holding the row's own key.

    Returns:
        The list of root nodes, i.e. rows whose parent field is falsy. Rows
        whose parent key matches no loaded row are not included anywhere in
        the result.
    """
    entries = []
    for record in db.query(model).all():
        entry = record.to_dict()
        entry['children'] = []
        entries.append(entry)

    by_key = {entry[key_str]: entry for entry in entries}

    roots = []
    for entry in entries:
        parent_key = entry[belong_str]
        if not parent_key:
            # No parent reference: this is a top-level node.
            roots.append(entry)
            continue
        parent = by_key.get(parent_key)
        if parent is not None:
            parent['children'].append(entry)
    return roots
|