urban-investment-research/Utils/SqlAlchemyUtils.py

124 lines
3.7 KiB
Python
Raw Normal View History

2023-03-22 17:06:48 +08:00
from typing import Literal, List, Any, Iterator
from urllib.parse import quote_plus

from pydantic import BaseModel
from sqlalchemy import create_engine, Column, and_
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session, DeclarativeMeta
2023-03-13 14:22:40 +08:00
class SqlalchemyConnect:
    """Own a ``mysql+pymysql`` engine for one database and hand out ORM sessions.

    Args:
        Base: Declarative base whose ``metadata`` is created by ``init_database``.
        host: Database host placed in the engine URL.
        user: Login user name (percent-encoded when the URL is built).
        password: Login password (percent-encoded when the URL is built).
        db: Database name placed in the engine URL.
    """

    def __init__(self, Base: DeclarativeMeta, host="127.0.0.1", user="", password="", db=""):
        self.Base = Base
        self.host = host
        self.user = user
        self.password = password
        self.db = db
        self.engine = self.init_engine()

    def init_engine(self):
        """Create and return the engine for this connection's database.

        ``user`` and ``password`` are percent-encoded with ``quote_plus`` so
        characters such as ``@``, ``:`` or ``/`` cannot corrupt the URL.
        NOTE(review): callers that previously pre-encoded their password by
        hand must now pass the raw value instead.
        """
        user = quote_plus(self.user)
        password = quote_plus(self.password)
        return create_engine(
            f"mysql+pymysql://{user}:{password}@{self.host}/{self.db}?charset=utf8mb4")

    def _new_session(self) -> Session:
        """Open a new non-autocommit, non-autoflush Session bound to ``self.engine``."""
        return sessionmaker(autocommit=False, autoflush=False, bind=self.engine)()

    def get_db(self) -> Iterator[Session]:
        """Yield a Session and close it when the generator finishes; never commits.

        The session is created before ``try`` so that a failure while opening
        it surfaces as itself rather than as an unbound-name error in ``finally``.
        """
        db = self._new_session()
        try:
            yield db
        finally:
            db.close()

    def get_db_commit(self) -> Iterator[Session]:
        """Yield a Session, commit after the consumer resumes us, then close it.

        If the consumer's work (thrown back into the generator) or the commit
        itself raises, the transaction is rolled back explicitly and the error
        is re-raised.
        """
        db = self._new_session()
        try:
            yield db
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

    def get_db_i(self) -> Session:
        """Return a new Session on the shared engine; the caller must close it.

        Reuses ``self.engine`` instead of creating a fresh engine (and a fresh,
        never-disposed connection pool) on every call.
        """
        return self._new_session()

    def init_database(self):
        """Create every table registered on ``self.Base.metadata`` via ``create_all``."""
        self.Base.metadata.create_all(bind=self.engine)
2023-03-22 17:06:48 +08:00
# Generic query interface.
# Operators a QueryParam may use when filtering through query_common_core.
QueryType = Literal[
    "=", "==",              # equality
    ">", ">=", "<", "<=",   # ordering comparisons
    "in",                   # membership in a collection of values
    "like",                 # substring match (value is wrapped as %value%)
    "range",                # inclusive [low, high] bounds
]
# One filter condition: attribute `name` of the target model is compared with
# `value` using operator `type`.
# Documented with comments rather than a docstring because pydantic publishes
# a model's docstring as its JSON-schema description.
class QueryParam(BaseModel):
    name: str  # attribute name, looked up on the model with getattr()
    type: QueryType  # operator; pydantic validates it against the QueryType literals
    value: Any  # operand; a collection for 'in', a [low, high] pair for 'range'
# Input for query_common: a list of filters plus 1-based pagination.
class QueryParams(BaseModel):
    param_list: List[QueryParam]  # filters; each is chained with query.filter, i.e. ANDed
    page: int  # 1-based page number; the offset is (page - 1) * page_size
    page_size: int  # rows per page; query_common caps it at 100
def query_common(db: Session, model, query_params: QueryParams, max_page_size: int = 100):
    """Run a filtered, paginated query over *model*.

    Args:
        db: Session used to build the query.
        model: Mapped class to query.
        query_params: Filters plus 1-based ``page`` and ``page_size``.
        max_page_size: Upper bound applied to ``page_size`` (default 100).

    Returns:
        ``(count, rows)``: the number of rows matching the filters (ignoring
        pagination) and the list of rows on the requested page.
    """
    query = query_common_core(model, db.query(model), query_params.param_list)
    # Count before offset/limit so callers get the total for their pager.
    count = query.count()
    # Clamp paging inputs: page < 1 or a negative page_size would otherwise
    # produce a negative OFFSET/LIMIT, which MySQL rejects.
    page = max(query_params.page, 1)
    page_size = max(min(query_params.page_size, max_page_size), 0)
    rows = query.offset((page - 1) * page_size).limit(page_size).all()
    return count, rows
# Operators that map onto exactly one column expression. 'range' is handled
# separately because it can yield zero, one or two conditions.
_QUERY_OPERATORS = {
    "=": lambda column, value: column == value,
    "==": lambda column, value: column == value,
    ">": lambda column, value: column > value,
    ">=": lambda column, value: column >= value,
    "<": lambda column, value: column < value,
    "<=": lambda column, value: column <= value,
    "in": lambda column, value: column.in_(value),
    # Substring match: the value is wrapped in '%' wildcards. Any '%' or '_'
    # inside the value is NOT escaped and still acts as a wildcard.
    "like": lambda column, value: column.like(f"%{value}%"),
}


def _range_conditions(column, name, value):
    """Build conditions for an inclusive [low, high] range; a None bound is left open."""
    try:
        low, high = value
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"'range' filter on {name!r} expects a [low, high] pair, got {value!r}") from err
    conditions = []
    if low is not None:
        conditions.append(column >= low)
    if high is not None:
        conditions.append(column <= high)
    return conditions


def query_common_core(model, query, param_list: "List[QueryParam]"):
    """Apply every filter in *param_list* to *query* and return the refined query.

    Each item names an attribute of *model* (fetched with ``getattr``) and an
    operator from ``QueryType``. Filters are chained with ``query.filter``,
    so they are combined with AND. For ``range``, a ``None`` bound means
    "unbounded" on that side.

    Args:
        model: Mapped class whose attributes are filtered.
        query: Query object to refine (not executed here).
        param_list: Filters to apply, in order.

    Returns:
        The refined query.

    Raises:
        AttributeError: if a filter names an attribute *model* lacks.
        ValueError: for an unsupported operator, or a ``range`` value that is
            not a two-item sequence.
    """
    for item in param_list:
        column: Column = getattr(model, item.name)
        if item.type == "range":
            conditions = _range_conditions(column, item.name, item.value)
        elif item.type in _QUERY_OPERATORS:
            conditions = [_QUERY_OPERATORS[item.type](column, item.value)]
        else:
            raise ValueError(f"unsupported query type {item.type!r} for {item.name!r}")
        if conditions:
            query = query.filter(*conditions)
    return query
2023-03-23 16:08:39 +08:00
def tree_table_to_json(db: "Session", model, belong_str='belong', key_str='id', keep_orphans=False):
    """
    Load every row of *model* and nest the rows into a tree of plain dicts.

    Each row is converted with ``item.to_dict()`` and given an empty
    ``children`` list. A row whose *belong_str* value is falsy is a root.
    Otherwise the row is appended to the ``children`` of the row whose
    *key_str* equals that value.

    Args:
        db: Session used to run ``db.query(model).all()``.
        model: Mapped class whose instances provide ``to_dict()``.
        belong_str: Dict key holding the parent row's key.
        key_str: Dict key holding each row's own key.
        keep_orphans: If True, rows whose parent key matches no loaded row are
            returned as extra roots. If False (the default, matching the
            original behaviour), such rows and their subtrees are dropped.

    Returns:
        List of root node dicts, with descendants nested under ``children``.

    Raises:
        KeyError: if a row's dict lacks *key_str* or *belong_str*.
    """
    node_list = []
    for item in db.query(model).all():
        node = item.to_dict()
        node['children'] = []
        node_list.append(node)
    nodes_by_key = {node[key_str]: node for node in node_list}
    tree = []
    for node in node_list:
        parent_key = node[belong_str]
        if not parent_key:
            tree.append(node)
        elif parent_key in nodes_by_key:
            nodes_by_key[parent_key]['children'].append(node)
        elif keep_orphans:
            # Parent row is missing from the table: surface the node as a root.
            tree.append(node)
    return tree