daily/Utils/SqlAlchemyUtils.py

188 lines
5.3 KiB
Python

import os
from datetime import date, datetime
from typing import Any, Iterator, List, Literal, Optional
from urllib.parse import quote_plus

from pydantic import BaseModel
from sqlalchemy import DATE, Column, and_, asc, cast, create_engine, desc, func, or_
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import DeclarativeMeta, Query, Session, sessionmaker
from sqlalchemy_utils import create_database, database_exists

Base = declarative_base()

# MySQL connection settings used by get_engine().
# Each value can be overridden through an environment variable; the defaults
# are the local-development values this module has always used.
# `host` may include a port, e.g. "db-host:3306".
# Credentials for other environments belong in the environment, not in this
# file (not even as commented-out lines).
user = os.environ.get("DAILY_DB_USER", "root")
password = os.environ.get("DAILY_DB_PASSWORD", "123456")
host = os.environ.get("DAILY_DB_HOST", "127.0.0.1")
db = os.environ.get("DAILY_DB_NAME", "daily")
# Engines already created by get_engine(), keyed by connection URL.
_engines = {}


def get_engine():
    """Return the SQLAlchemy engine for the configured MySQL database.

    The URL is built from the module-level ``user``, ``password``, ``host`` and
    ``db`` settings. ``user`` and ``password`` are URL-escaped, so characters
    such as ``@``, ``:`` or ``/`` in a password no longer corrupt the URL.

    One engine is created per distinct URL and returned again on later calls.
    All sessions therefore share one connection pool, instead of every call
    opening a fresh pool.
    """
    url = (
        f"mysql+pymysql://{quote_plus(user)}:{quote_plus(password)}"
        f"@{host}/{db}?charset=utf8mb4"
    )
    engine = _engines.get(url)
    if engine is None:
        # Recycle pooled connections after 4 hours.
        engine = create_engine(url, pool_recycle=3600 * 4)
        _engines[url] = engine
    return engine
def get_db() -> Iterator[Session]:
    """Yield a new database session and close it afterwards.

    The session is created before the ``try`` block. If creating it fails, the
    original exception propagates, instead of an ``UnboundLocalError`` raised
    from the ``finally`` clause. The session is closed when the generator is
    closed or exhausted.
    """
    session_factory = sessionmaker(autocommit=False, autoflush=False, bind=get_engine())
    session = session_factory()
    try:
        yield session
    finally:
        session.close()
def get_db_i() -> Session:
    """Open and return a new session bound to the engine from get_engine().

    Unlike the generator ``get_db``, nothing closes this session automatically;
    the caller owns it and must close it.
    """
    session_factory = sessionmaker(bind=get_engine(), autoflush=False, autocommit=False)
    return session_factory()
# def get_db():
# engine = create_engine("sqlite:///./data.db", connect_args={"check_same_thread": False})
# SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# db = SessionLocal()
# try:
# yield db
# finally:
# db.close()
def create_db(engine):
    """Create the database named in ``engine.url`` unless it already exists."""
    if database_exists(engine.url):
        return
    create_database(engine.url)
def init_database():
    """Make sure the database exists, then create the tables declared on ``Base``."""
    target = get_engine()
    create_db(target)
    metadata = Base.metadata
    metadata.create_all(bind=target)
# Generic query interface.
# Operators accepted in QueryParam.type; each one is interpreted by query_common_core.
QueryType = Literal[
    "=", "!=", "==",
    ">", ">=", "<", "<=",
    "in", "like", "range", "find_in_set", "date==",
]
class QueryParam(BaseModel):
    """A single filter condition; query_common_core turns it into one WHERE clause."""
    name: str  # attribute of the queried model, resolved with getattr
    type: QueryType  # comparison operator (see QueryType)
    value: Any  # operand: a sequence for 'in' and 'range' ([low, high]), a list or scalar for 'find_in_set'
class OrderParam(BaseModel):
    """Sort specification applied by query_order."""
    order_by: str  # attribute of the queried model to sort on, resolved with getattr
    type: Literal['asc', 'desc'] = 'asc'  # sort direction; ascending unless given
class QueryParams(BaseModel):
    """Filters, ordering and pagination consumed by query_common.

    Every part is optional. The explicit ``= None`` defaults keep them optional
    under pydantic v2 too, where ``Optional[...]`` alone no longer implies a
    default and would make the field required.
    """
    param_list: Optional[List[QueryParam]] = None  # filters, combined with AND
    order: Optional[OrderParam] = None  # sort specification
    page: Optional[int] = None  # 1-based page number; pagination needs page and page_size
    page_size: Optional[int] = None  # rows per page (capped by query_common)
def query_common(db: Session, model, query_params: QueryParams, max_page_size: int = 100):
    """Generic database query: filter, order and optionally paginate ``model``.

    Args:
        db: session used to build the query.
        model: mapped class to query; filter and order names are looked up on it.
        query_params: filters (``param_list``), ordering (``order``) and
            pagination (``page`` is 1-based, ``page_size`` is rows per page).
        max_page_size: upper bound applied to ``page_size`` (default 100).

    Returns:
        ``(count, result)``. ``count`` is the number of rows matching the
        filters, ignoring pagination. ``result`` is a list of rows when both
        ``page`` and ``page_size`` are given; otherwise it is the ordered,
        not-yet-executed Query.
    """
    query = db.query(model)
    if query_params.param_list:
        query = query_common_core(model, query, query_params.param_list)
    # Count before ordering: ORDER BY cannot change the row count, but
    # Query.count() would otherwise carry it into the COUNT(*) subquery.
    count = query.count()
    if query_params.order:
        query = query_order(model, query, query_params.order)

    page, page_size = query_params.page, query_params.page_size
    if page is None or page_size is None:
        return count, query

    # Clamp client-supplied values so OFFSET/LIMIT can never be negative
    # (a page_size of 0 still yields an empty page, as before).
    page = max(page, 1)
    page_size = max(0, min(page_size, max_page_size))
    rows = query.offset((page - 1) * page_size).limit(page_size).all()
    return count, rows
def query_order(model, query: Query, order_param: OrderParam):
    """Sort query results according to ``order_param``.

    ``order_by`` names an attribute of ``model`` (resolved with getattr). The
    query is returned unchanged when ``order_param`` or its ``order_by`` is
    falsy, or when ``type`` is neither 'asc' nor 'desc'.
    """
    if not order_param or not order_param.order_by:
        return query
    column: Column = getattr(model, order_param.order_by)
    direction = order_param.type
    if direction == 'asc':
        return query.order_by(asc(column))
    if direction == 'desc':
        return query.order_by(desc(column))
    return query
def _to_date(value):
    """Convert the operand of a 'date==' filter to a ``date``.

    Accepted forms:
      * ``datetime`` -> its date part; ``date`` -> returned unchanged;
      * ``str`` -> parsed from its leading ``YYYY-MM-DD`` (ISO 8601);
      * ``int`` -> Unix timestamp in milliseconds;
      * any other number -> Unix timestamp in seconds.
    Timestamps go through ``date.fromtimestamp``, i.e. the process's local time zone.
    """
    if isinstance(value, datetime):  # must precede the date check: datetime subclasses date
        return value.date()
    if isinstance(value, date):
        return value
    if isinstance(value, str):
        return date.fromisoformat(value[:10])
    if isinstance(value, int) and not isinstance(value, bool):
        value = value / 1000  # integers are millisecond timestamps
    # NOTE(review): floats are still read as *seconds* (unchanged behaviour);
    # a float millisecond timestamp would be misread -- confirm with callers.
    return date.fromtimestamp(value)


def query_common_core(model, query: Query, param_list: List[QueryParam]):
    """Add one filter to ``query`` per entry in ``param_list`` and return it.

    Each ``QueryParam`` names a model attribute (resolved with getattr), an
    operator and an operand. Successive filters are combined with AND.

    Operators:
        '=' / '==', '!=', '>', '>=', '<', '<=': plain comparisons.
        'in': ``column IN value`` (value is a sequence).
        'like': substring match, ``column LIKE '%value%'``.
        'range': inclusive bounds, ``value[0] <= column <= value[1]``.
        'find_in_set': MySQL FIND_IN_SET; with a list/tuple operand a row
            matches when any element is found.
        'date==': compares the DATE part of ``column`` with the operand,
            converted by ``_to_date``.
    Unknown operators are ignored.
    """
    for param in param_list:
        column: Column = getattr(model, param.name)
        op = param.type
        value = param.value
        if op in ('=', '=='):
            condition = column == value
        elif op == '!=':
            condition = column != value
        elif op == '>':
            condition = column > value
        elif op == '>=':
            condition = column >= value
        elif op == '<':
            condition = column < value
        elif op == '<=':
            condition = column <= value
        elif op == 'in':
            condition = column.in_(value)
        elif op == 'like':
            # NOTE(review): '%' and '_' inside value are not escaped and act as wildcards.
            condition = column.like(f'%{value}%')
        elif op == 'range':
            condition = and_(column >= value[0], column <= value[1])
        elif op == 'find_in_set':
            if isinstance(value, (list, tuple)):
                condition = or_(*[func.find_in_set(element, column) for element in value])
            else:
                condition = func.find_in_set(value, column)
        elif op == 'date==':
            condition = cast(column, DATE) == _to_date(value)
        else:
            continue  # unknown operator: no filter added
        query = query.filter(condition)
    return query
def tree_table_to_json(db: Session, model, belong_str='belong', key_str='id'):
    """Convert a tree stored as table rows into nested dicts.

    Every row of ``model`` is turned into a dict via ``row.to_dict()`` and given
    an empty ``'children'`` list. A node whose ``belong_str`` field is falsy
    becomes a root. Otherwise the node is appended to the children of the node
    whose ``key_str`` field equals that value. Nodes whose parent key is not
    found are left out of the result, together with their subtrees. Children
    keep the order in which the query returned them.

    Returns:
        List of root node dicts, each carrying its descendants under 'children'.
    """
    def as_node(row):
        node = row.to_dict()
        node['children'] = []
        return node

    nodes = [as_node(row) for row in db.query(model).all()]
    index = {node[key_str]: node for node in nodes}
    roots = []
    for node in nodes:
        parent_key = node[belong_str]
        if not parent_key:
            roots.append(node)
        elif parent_key in index:
            index[parent_key]['children'].append(node)
        # otherwise the parent is missing and the node is dropped
    return roots