urban-investment-research/Utils/SqlAlchemyUtils.py

101 lines
3.1 KiB
Python

import operator
from typing import Any, Iterator, List, Literal

from pydantic import BaseModel
from sqlalchemy import Column, and_, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import DeclarativeMeta, Session, sessionmaker
class SqlalchemyConnect:
    """Own a MySQL engine (PyMySQL driver) and hand out ORM sessions bound to it.

    ``get_db`` and ``get_db_commit`` are generator functions that yield one
    session and clean it up afterwards (the shape FastAPI's ``Depends``
    expects -- presumably how they are used; confirm against callers).
    ``get_db_i`` returns a plain session that the caller must close.
    """

    def __init__(self, Base: DeclarativeMeta, host="127.0.0.1", user="", password="", db=""):
        """Store the connection settings and create the engine.

        :param Base: declarative base class whose ``metadata`` is used by
            ``init_database``.
        :param host: database host; inserted verbatim into the connection URL.
        :param user: database user name.
        :param password: database password.
        :param db: database (schema) name.
        """
        self.Base = Base
        self.host = host
        self.user = user
        self.password = password
        self.db = db
        self.engine = self.init_engine()

    def init_engine(self):
        """Create the SQLAlchemy engine for the stored settings (utf8mb4 charset)."""
        # NOTE(review): user/password are interpolated without escaping, so a
        # password containing characters such as '@', '/', ':' or '%' breaks
        # URL parsing. Callers currently have to URL-escape it themselves
        # (urllib.parse.quote_plus); consider sqlalchemy.engine.URL.create.
        engine = create_engine(
            f"mysql+pymysql://{self.user}:{self.password}@{self.host}/{self.db}?charset=utf8mb4")
        return engine

    def _new_session(self) -> Session:
        """Return a new session bound to ``self.engine`` with autoflush disabled."""
        # Constructing Session directly avoids building a throwaway
        # sessionmaker class on every call.
        return Session(bind=self.engine, autocommit=False, autoflush=False)

    def get_db(self) -> Iterator[Session]:
        """Yield a session and close it once the consumer is done.

        Nothing is committed here; use ``get_db_commit`` for that.
        """
        # Create the session before ``try``: if creation fails, the original
        # error propagates instead of ``finally`` hitting an unbound ``db``.
        db = self._new_session()
        try:
            yield db
        finally:
            db.close()

    def get_db_commit(self) -> Iterator[Session]:
        """Yield a session, commit it if the consumer finishes cleanly, then close it.

        If the consumer (or the commit itself) raises, the transaction is
        rolled back and the exception is re-raised.
        """
        db = self._new_session()
        try:
            yield db
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

    def get_db_i(self) -> Session:
        """Return a new session bound to ``self.engine``; the caller must close it."""
        # Reuse self.engine instead of calling create_engine() again: each
        # Engine owns its own connection pool, and one built per call was
        # never disposed.
        return self._new_session()

    def init_database(self):
        """Create every table in ``self.Base.metadata`` that does not exist yet."""
        self.Base.metadata.create_all(bind=self.engine)
# Generic query interface
# Operators accepted in QueryParam.type and applied by query_common_core:
# '=' and '==' are equivalent; 'in' expects a collection of values;
# 'like' performs a substring match; 'range' expects an inclusive
# (lower, upper) pair.
QueryType = Literal['=', '==', '>', '>=', '<', '<=', 'in', 'like', 'range']
# One filter condition, applied to a model by query_common_core.
class QueryParam(BaseModel):
    # Name of the model attribute to filter on (looked up with getattr).
    name: str
    # Operator to apply; one of the QueryType literals.
    type: QueryType
    # Operand; its expected shape depends on ``type`` (a collection for
    # 'in', a two-element (lower, upper) sequence for 'range').
    value: Any
# Filter list plus pagination settings consumed by query_common.
class QueryParams(BaseModel):
    # Filters applied in order; successive Query.filter calls AND them together.
    param_list: List[QueryParam]
    # 1-based page number: page 1 starts at offset 0 in query_common.
    page: int
    # Rows per page; query_common caps it at 100.
    page_size: int
def query_common(db: Session, model, query_params: QueryParams, *, max_page_size: int = 100):
    """Run a filtered, paginated query over ``model``.

    The filters in ``query_params.param_list`` are applied via
    ``query_common_core``; the total is counted before pagination.

    :param db: session used to build the query.
    :param model: mapped class to query.
    :param query_params: filters plus 1-based ``page`` and ``page_size``.
    :param max_page_size: upper bound applied to ``page_size`` (default 100).
    :return: ``(count, rows)`` -- the number of rows matching the filters
        (ignoring pagination) and the list of rows on the requested page.
    """
    filtered = query_common_core(model, db.query(model), query_params.param_list)
    count = filtered.count()
    # Clamp caller-supplied paging values: a page below 1 would produce a
    # negative OFFSET and a negative page_size a negative LIMIT, and MySQL
    # only accepts non-negative values for both.
    page = max(query_params.page, 1)
    page_size = max(min(query_params.page_size, max_page_size), 0)
    rows = filtered.offset((page - 1) * page_size).limit(page_size).all()
    return count, rows
# Binary comparison operators, keyed by the QueryType literal that selects them.
_COMPARISON_OPERATORS = {
    "=": operator.eq,
    "==": operator.eq,
    ">": operator.gt,
    ">=": operator.ge,
    "<": operator.lt,
    "<=": operator.le,
}


def query_common_core(model, query, param_list: "List[QueryParam]"):
    """Apply each filter in ``param_list`` to ``query`` and return the result.

    Each filter narrows the query through another ``filter`` call, so all
    conditions are ANDed together.

    :param model: mapped class; each ``item.name`` is looked up on it with
        ``getattr``, so an unknown name raises ``AttributeError``.
    :param query: query object to narrow, e.g. ``db.query(model)``.
    :param param_list: filters to apply, in order.
    :return: the filtered query.
    :raises ValueError: if an item's ``type`` is not a supported operator
        (previously such items were silently ignored, returning unfiltered data).

    ``range`` reads ``value[0]`` / ``value[1]`` as inclusive lower / upper
    bounds; a ``None`` bound leaves that side open instead of comparing
    against NULL, which never matches.
    """
    for item in param_list:
        # NOTE(review): ``item.name`` may come straight from client input, and
        # getattr exposes any mapped attribute (including sensitive columns).
        # Consider an allow-list of filterable columns.
        column = getattr(model, item.name)
        query_type = item.type
        value = item.value
        if query_type in _COMPARISON_OPERATORS:
            query = query.filter(_COMPARISON_OPERATORS[query_type](column, value))
        elif query_type == "in":
            query = query.filter(column.in_(value))
        elif query_type == "like":
            # NOTE(review): '%' and '_' inside ``value`` are not escaped and act
            # as LIKE wildcards.
            query = query.filter(column.like(f'%{value}%'))
        elif query_type == "range":
            lower, upper = value[0], value[1]
            bounds = []
            if lower is not None:
                bounds.append(column >= lower)
            if upper is not None:
                bounds.append(column <= upper)
            if bounds:
                query = query.filter(*bounds)
        else:
            raise ValueError(f"Unsupported query type: {query_type!r}")
    return query