wd-smebiz/utils/sqlalchemy_common_utils.py

732 lines
28 KiB
Python
Raw Normal View History

2023-08-02 14:24:28 +08:00
import json
from typing import Literal, List, Any, Optional, Union, Generic, TypeVar, Dict, Type
from pydantic import BaseModel
from sqlalchemy import create_engine, Column, and_, asc, desc, func, String
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session, Query, Load
from sqlalchemy_utils import database_exists, create_database, get_columns, get_column_key, get_type, get_primary_keys
from pydantic import BaseModel
from sqlalchemy import Column
from sqlalchemy.orm import InstrumentedAttribute, ColumnProperty, Relationship, DeclarativeMeta, Session
from sqlalchemy_utils import get_columns, get_column_key
class SqlalchemyConnect:
def __init__(self, Base: declarative_base, host="127.0.0.1", user="", password="", db="",
db_type: Literal['mysql', 'postgresql'] = 'mysql'):
self.Base = Base
self.host = host
self.user = user
self.password = password
self.db = db
if db_type == 'mysql':
self.engine = self.init_engine()
if db_type == 'postgresql':
self.engine = self.init_postgresql_engine()
# self.async_engine = self.init_async_engine()
def init_engine(self):
engine = create_engine(
f"mysql+pymysql://{self.user}:{self.password}@{self.host}/{self.db}?charset=utf8mb4", pool_recycle=3600 * 4)
return engine
def init_postgresql_engine(self):
engine = create_engine(f'postgresql+psycopg2://{self.user}:{self.password}@{self.host}/{self.db}',
pool_recycle=3600 * 4, echo=False)
return engine
def create_db(self):
if not database_exists(self.engine.url):
create_database(self.engine.url)
def get_db(self) -> sessionmaker:
try:
session = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
db = session()
yield db
finally:
db.close()
def get_db_commit(self) -> sessionmaker:
try:
session = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
db = session()
yield db
db.commit()
finally:
db.close()
def get_db_i(self) -> Session:
# engine = create_engine(
# f"mysql+pymysql://{self.user}:{self.password}@{self.host}/{self.db}?charset=utf8mb4")
session = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
db = session()
return db
def init_database(self, create_db=False):
if create_db:
self.create_db()
self.Base.metadata.create_all(bind=self.engine)
def init_async_engine(self):
async_engine = create_async_engine(
f"mysql+aiomysql://{self.user}:{self.password}@{self.host}/{self.db}?charset=utf8mb4", pool_recycle=3600 * 4
)
return async_engine
def get_model_column_keys(model):
return [get_column_key(model, item) for item in get_columns(model)]
def marge_col_names(col_names, include: List[str] = None, ex_include: List[str] = None):
if include:
col_names = [item for item in include if item in col_names]
if ex_include:
col_names = [item for item in col_names if item not in ex_include]
return set(col_names)
# 通用查询接口
QueryType = Literal['=', '==', '>', '>=', '<', '<=', 'in', 'like', 'range', "find_in_set"]
class QueryParam(BaseModel):
name: str
type: QueryType
value: Any
class OrderParam(BaseModel):
order_by: str
type: Literal['asc', 'desc'] = 'asc'
ModelInfo = TypeVar('ModelInfo', bound=BaseModel)
class QueryInclude(BaseModel):
include: Optional[List[str]] = None
ex_include: Optional[List[str]] = None
2023-08-02 14:24:28 +08:00
relation_use_id: bool = False # 为true的话m2m关系返回的是id数组
class QueryParams(QueryInclude, Generic[ModelInfo]):
params: Optional[List[QueryParam]] = None
order: Optional[OrderParam] = None
query: Optional[ModelInfo] = None
2023-08-02 14:24:28 +08:00
page: int
page_size: int
def query_common(db: Session, model, query_params: QueryParams):
"""
通用数据库查询接口
"""
query = db.query(model)
# if query_params.ex_include:
# query = query.options(Load(model).load_only(*query_params.include))
# if query_params.include:
# query = query.options(Load(model).noload(*query_params.ex_include))
page = query_params.page
page_size = query_params.page_size
# 排序
if query_params.order:
query = query_order(model, query, query_params.order)
if query_params.query:
query = query_common_query_core(model, query, query_params.query)
# 筛选
if query_params.params:
query = query_common_params_core(model, query, query_params.params)
count = query.count()
return count, query, page, page_size
2023-08-02 16:32:36 +08:00
def query_common_with_page(db: Session, model, query_params: QueryParams):
count, query, page, page_size = query_common(db, model, query_params)
query = query.offset((page - 1) * page_size).limit(page_size)
return count, query
2023-08-02 16:32:36 +08:00
2023-08-02 14:24:28 +08:00
def query_order(model, query: Query, order_param: OrderParam):
"""
查询结果排序
"""
if order_param:
if order_param.order_by:
column: Column = getattr(model, order_param.order_by)
if order_param.type == 'asc':
query = query.order_by(asc(column))
if order_param.type == 'desc':
query = query.order_by(desc(column))
return query
def query_common_params_core(model, query: Query, params: List[QueryParam]):
cols: Dict[str, TypeInfo] = model.model_config.cols
for item in params:
name = item.name
child_name = ""
if '.' in name:
[name, child_name] = name.split('.')
query_type = item.type
value = item.value
col = cols[name]
col_base_type = col.col_base_type
column: Column = getattr(model, name)
if col_base_type == 'relation' and col.relation:
if col.relation.relation_type in ['m2o', 'o2o']:
if col.relation.source_foreign_key:
query = query.join(col.relation.target_model,
getattr(col.relation.target_model, col.relation.target_id_key) == getattr(model,
col.relation.source_foreign_key))
else:
query = query.join(col.relation.target_model)
column = getattr(col.relation.target_model, child_name)
else:
if col_base_type == 'json':
column = column[child_name]
if query_type in ['=', '==']:
if col_base_type == 'relation' and col.relation and not child_name:
relation_type = col.relation.relation_type
secondary = col.relation.secondary
target_secondary_key = col.relation.target_secondary_key
source_secondary_key = col.relation.source_secondary_key
if relation_type == 'm2m':
query = query.join(secondary, model.id == secondary.columns.get(source_secondary_key)) \
.filter(secondary.columns.get(target_secondary_key) == value)
else:
if col.col_base_type == 'datetime':
query = query.filter(func.date(column) == value)
elif col_base_type == 'json':
query = query.filter(column == json.dumps(value))
else:
query = query.filter(column == value)
elif query_type == ">":
query = query.filter(column > value)
elif query_type == "<":
query = query.filter(column < value)
elif query_type == ">=":
query = query.filter(column >= value)
elif query_type == "<=":
query = query.filter(column <= value)
elif query_type == "in":
if col_base_type == 'relation' and col.relation:
relation_type = col.relation.relation_type
if relation_type == 'm2m':
secondary = col.relation.secondary
target_secondary_key = col.relation.target_secondary_key
source_secondary_key = col.relation.source_secondary_key
if value and type(value) == list:
query = query.join(secondary, model.id == secondary.columns.get(source_secondary_key)) \
.filter(secondary.columns.get(target_secondary_key).in_(value))
else:
query = query.join(secondary, model.id == secondary.columns.get(source_secondary_key)) \
.filter(secondary.columns.get(target_secondary_key) == value)
else:
query = query.filter(column.in_(value))
elif query_type == "like":
query = query.filter(column.cast(String).like(f'%{value}%'))
elif query_type == "find_in_set":
query = query.filter(func.find_in_set(value, column))
elif query_type == "range":
if value:
if len(value) == 1 and value[0] is not None:
query = query.filter(column >= value[0])
if len(value) == 2:
if value[0] is not None and value[1] is None:
query = query.filter(column >= value[0])
if value[0] is None and value[1] is not None:
query = query.filter(column <= value[1])
if value[0] is not None and value[1] is not None:
query = query.filter(and_(column >= value[0], column <= value[1]))
return query
def query_common_query_core(model, query: Query, query_data: BaseModel):
q_data = query_data.dict(exclude_unset=True)
cols: Dict[str, TypeInfo] = model.model_config.cols
for key, value in q_data:
annotation = query_data.__annotations__[key]
if key in cols:
column: Column = getattr(model, key)
col = cols[key]
col_base_type = col.col_base_type
if col_base_type in ['int', 'float', 'any', 'enum', 'str']:
if type(value) == list and value:
query = query.filter(column.in_(value))
else:
if col_base_type == 'str' and value:
query = query.filter(column.like(f'%{value}%'))
else:
query = query.filter(column == value)
elif col_base_type in ['date', 'datetime', 'time']:
if value:
if type(value) == list:
if len(value) == 2:
if value[0] is not None:
query = query.filter(column >= value)
if value[1] is not None:
query = query.filter(column <= value)
else:
if col_base_type == 'datetime':
# datetime转date来查
query = query.filter(func.date(column) == value)
else:
query = query.filter(column == value)
elif col_base_type == 'relation' and col.relation:
relation_type = col.relation.relation_type
if relation_type == 'm2m':
secondary = col.relation.secondary
target_secondary_key = col.relation.target_secondary_key
source_secondary_key = col.relation.source_secondary_key
if value and type(value) == list:
query = query.join(secondary, model.id == getattr(secondary, source_secondary_key)) \
.filter(getattr(secondary, target_secondary_key).in_(value))
else:
query = query.join(secondary, model.id == getattr(secondary, source_secondary_key)) \
.filter(getattr(secondary, target_secondary_key) == value)
elif relation_type == 'o2m':
target_model = col.relation.target_model
target_id_key = col.relation.target_id_key
if value and type(value) == list:
query = query.join(target_model).filter(getattr(target_model, target_id_key).in_(value))
else:
query = query.join(target_model).filter(getattr(target_model, target_id_key == value))
return query
def tree_table_to_json(db: Session, model, belong_str='belong', key_str='id'):
"""
树结构数据库存储到json
"""
item_list = db.query(model).all()
node_list = []
for item in item_list:
item_dic = item.to_dict()
item_dic['children'] = []
node_list.append(item_dic)
nodes_dic = {node[key_str]: node for node in node_list}
tree = []
for node in node_list:
belong = node[belong_str]
if belong:
if belong in nodes_dic:
nodes_dic[belong]['children'].append(node)
else:
tree.append(node)
return tree
def get_relation(col: InstrumentedAttribute):
prop = col.property
if isinstance(prop, Relationship):
model = prop.entity.entity
if prop.secondary is not None:
columns = prop.secondary.columns
foreign_key_0 = list(columns[0].foreign_keys)[0]
foreign_key_1 = list(columns[1].foreign_keys)[0]
if foreign_key_0.column.table.key == model.__tablename__:
target_secondary_key = columns[0].key
target_key = foreign_key_0.column.key
source_secondary_key = columns[1].key
source_key = foreign_key_1.column.key
else:
source_secondary_key = columns[0].key
source_key = foreign_key_0.column.key
target_secondary_key = columns[1].key
target_key = foreign_key_1.column.key
return {'relation': 'm2m',
'model': model,
'secondary': prop.secondary,
'source_secondary_key': source_secondary_key,
'source_key': source_key,
'target_secondary_key': target_secondary_key,
'target_key': target_key,
}
else:
return {'relation': 'm2o', 'model': model}
else:
return False
def get_relation_col_dic(model: Type[DeclarativeMeta]):
relation_col_dic = {
}
for col_name in dir(model):
col = getattr(model, col_name)
if isinstance(col, InstrumentedAttribute):
relation = get_relation(col)
if relation:
relation_col_dic[col_name] = relation
return relation_col_dic
# 一个模型的字段类型
# 基本类型
RelationType = Literal['m2m', 'm2o', 'o2m', 'o2o']
BaseType = Literal[
'int', 'float', 'bool', 'str', 'date', 'datetime', 'list', 'time', 'any', 'enum', 'set', 'relation', "json", "jsonb"]
class Relation:
relation_type: RelationType = None
source_model: Type[DeclarativeMeta] = None # 左表
target_model: Type[DeclarativeMeta] = None # 右表
secondary: Any = None # 中间表
source_secondary_key: str = None # 中间表左表字段
source_key: str = None # 左表在中间表使用的key
target_secondary_key: str = None # 中间表右表字段
target_key: str = None # 右表在中间表使用的key
source_id_key: str = None # 左表id
target_id_key: str = None # 右表id
source_foreign_key: str = None # 左表里字段对应的foreign_key字段
ColOrgTypeEnum = Literal[
"BigInteger",
"Boolean",
"Date",
"DateTime",
"Enum",
"Double",
"Float",
"Integer",
"Interval",
"LargeBinary",
"MatchType",
"Numeric",
"PickleType",
"SchemaType",
"SmallInteger",
"String",
"Text",
"Time",
"Unicode",
"UnicodeText",
"Uuid",
"JSON",
"JSONB",
'DeclarativeMeta',
]
class TypeInfo:
col_org_type: ColOrgTypeEnum = None
col_base_type: BaseType = None
relation: Relation = None
class ModelConfig:
cols: Dict[str, TypeInfo] = {}
id_key: str = 'id'
COL_ORG_TYPE_MAP = {
"BigInteger": 'int',
"Boolean": 'bool',
"Date": 'date',
"DateTime": 'datetime',
"Enum": 'enum',
"Double": 'float',
"Float": 'float',
"Integer": 'int',
"Interval": 'any',
"LargeBinary": 'any',
"MatchType": 'any',
"Numeric": 'any',
"PickleType": 'any',
"SchemaType": 'any',
"SmallInteger": 'int',
"String": 'str',
"Text": 'str',
"Time": 'time',
"Unicode": 'str',
"UnicodeText": 'str',
"Uuid": 'str',
"JSON": "json",
"JSONB": "jsonb",
'DeclarativeMeta': 'relation'
}
def get_model_id_key(model: Type[DeclarativeMeta]):
id_keys = list(get_primary_keys(model).keys())
print(id_keys)
if id_keys:
return id_keys[0]
def get_relation_info(col: InstrumentedAttribute) -> Relation:
prop = col.property
if isinstance(prop, Relationship):
r = Relation()
source_model = col.property.parent.entity
source_id_key = get_model_id_key(source_model)
target_model = prop.entity.entity
target_id_key = get_model_id_key(target_model)
source_uselist = prop.uselist
target_uselist = False
r.source_model = source_model
r.target_model = target_model
r.source_id_key = source_id_key
r.target_id_key = target_id_key
if col.property._user_defined_foreign_keys:
r.source_foreign_key = list(col.property._user_defined_foreign_keys)[0].name
print(r.source_foreign_key)
# 找到另一个表内反向引用的字段id
for name in get_model_col_names(target_model):
target_col = getattr(target_model, name)
col_org_type = get_type(target_col)
if source_model == col_org_type:
target_uselist = target_col.property.uselist
# 寻找是否右表有对应的foreign_key
if hasattr(target_col, 'foreign_keys'):
for foreign_key in list(target_col.foreign_keys):
if foreign_key.column.table.name == source_model.__tablename__:
r.source_key = foreign_key.column.key
# 找到本表内反向引用的字段id
# for name in get_model_col_names(source_model):
# source_col = getattr(target_model, name)
# col_org_type = get_type(source_col)
# if target_model == col_org_type:
# source_uselist = source_col.property.uselist
# # 寻找是否右表有对应的foreign_key
# if hasattr(source_col,'foreign_keys'):
# for foreign_key in List[source_col.foreign_keys]:
# if foreign_key.column.table.name==target_model.__tablename__:
# r.target_key=foreign_key.column.key
if (source_uselist and target_uselist) or prop.secondary is not None:
r.relation_type = 'm2m'
columns = prop.secondary.columns
foreign_key_0 = list(columns[0].foreign_keys)[0]
foreign_key_1 = list(columns[1].foreign_keys)[0]
if foreign_key_0.column.table.key == target_model.__tablename__:
target_secondary_key = columns[0].key
target_key = foreign_key_0.column.key
source_secondary_key = columns[1].key
source_key = foreign_key_1.column.key
else:
source_secondary_key = columns[0].key
source_key = foreign_key_0.column.key
target_secondary_key = columns[1].key
target_key = foreign_key_1.column.key
r.secondary = prop.secondary
r.source_secondary_key = source_secondary_key
r.target_secondary_key = target_secondary_key
r.source_key = source_key
r.target_key = target_key
elif source_uselist == True and target_uselist == False:
r.relation_type = 'o2m'
elif source_uselist == False and target_uselist == False:
r.relation_type = 'o2o'
elif source_uselist == False and target_uselist == True:
r.relation_type = 'm2o'
return r
else:
return None
def get_model_config(model: Type[DeclarativeMeta]) -> ModelConfig:
"""
这个接口获取一个sql模型的各种信息
:param model:
:return:
"""
# 映射
model_config = ModelConfig()
cols = {}
for col_name in dir(model):
col = getattr(model, col_name)
if isinstance(col, InstrumentedAttribute):
if hasattr(col, 'primary_key') and col.primary_key:
model_config.id_key = col_name
type_info = TypeInfo()
col_org_type = get_type(col).__class__.__name__
col_base_type = COL_ORG_TYPE_MAP.get(col_org_type) or 'any'
relation = get_relation_info(col)
type_info.col_org_type = col_org_type
type_info.col_base_type = col_base_type
type_info.relation = relation
cols[col_name] = type_info
model_config.cols = cols
return model_config
def get_model_col_names(model: Type[DeclarativeMeta]):
columns = []
for col_name in dir(model):
if isinstance(getattr(model, col_name), InstrumentedAttribute):
columns.append(col_name)
return columns
class SalBase:
__model_config: ModelConfig = None
@classmethod
@property
def model_config(cls) -> ModelConfig:
if cls.__model_config is None:
cls.__model_config = 'process_ing'
cls.__model_config = get_model_config(cls)
return cls.__model_config
def to_dict(self, include: List[str] = None, ex_include: List[str] = None):
col_names = marge_col_names([c.name for c in self.__table__.columns], include, ex_include)
data = {name: getattr(self, name) for name in col_names}
return data
def to_full_dict(self, include: List[str] = None, ex_include: List[str] = None, relation_use_id=False):
col_names = marge_col_names(set(self.model_config.cols.keys()), include, ex_include)
data = {}
for col_name in col_names:
col = self.model_config.cols[col_name]
col_val = getattr(self, col_name)
relation = col.relation
if relation:
if relation.relation_type == 'm2m':
if relation_use_id:
col_data = [getattr(item, relation.target_id_key) for item in col_val]
else:
col_data = [item.to_dict() for item in col_val]
data[col_name] = col_data
elif relation.relation_type == 'o2m':
if relation_use_id:
col_data = [getattr(item, relation.target_id_key) for item in col_val]
else:
col_data = [item.to_dict() for item in col_val]
data[col_name] = col_data
elif relation.relation_type == 'o2o':
data[col_name] = col_val and col_val.to_dict()
elif relation.relation_type == 'm2o':
data[col_name] = col_val and col_val.to_dict()
else:
data[col_name] = col_val
else:
data[col_name] = col_val
return data
class TableModel(DeclarativeMeta, SalBase):
pass
class ModelCRUD:
def __init__(self, model: Type[TableModel]):
self.model = model
def update(self, db: Session, data: BaseModel):
model = self.model
id_key = self.model.model_config.id_key
item = db.query(model).filter(getattr(model, id_key) == (getattr(data, id_key))).first()
data = data.dict(exclude_unset=True)
col_names = get_model_col_names(model)
model_config: ModelConfig = model.model_config
for k, v in data.items():
if '.' in k:
paths = k.split('.')
if paths[0] in col_names:
col_type_info: TypeInfo = model.model_config.cols[paths[0]]
if col_type_info.relation:
# 一对一时,{"config.age":1}
relation = col_type_info.relation
if relation.relation_type in ["o2o"]:
col = getattr(item, k)
if not col:
if v is not None:
setattr(item, paths[0], relation.target_model(**{paths[1]: v}))
else:
setattr(col, paths[1], v)
if k in col_names:
col_type_info: TypeInfo = model.model_config.cols[k]
# col = getattr(model, k)
relation = col_type_info.relation
if relation:
if relation.relation_type in ['m2m']:
if v is not None:
if v and type(v[0]) == dict:
v = [item[model_config.id_key] for item in v]
# 多对多更新的值是一个id数组所以先用id查出对应项再更新
target_model = relation.target_model
target_key = relation.target_id_key
update_values = {child for child in
db.query(target_model).filter(getattr(target_model, target_key).in_(v))}
setattr(item, k, update_values)
elif relation.relation_type == 'o2o':
if type(v) == dict:
col = getattr(item, k)
if not col:
setattr(item, k, relation.target_model(**v))
else:
for v_k, v_v in v.items():
setattr(col, v_k, v_v)
else:
setattr(item, k, v)
db.commit()
db.refresh(item)
return item
def get(self, db: Session, data: BaseModel):
2023-08-02 16:32:36 +08:00
item = db.query(self.model).filter(getattr(self.model, self.model.model_config.id_key) == (
getattr(data, self.model.model_config.id_key))).first()
2023-08-02 14:24:28 +08:00
return item
def add(self, db: Session, data: BaseModel):
model = self.model
item = model()
data = data.dict(exclude_unset=True)
col_names = get_model_col_names(model)
for k, v in data.items():
if k in col_names:
col_type_info: TypeInfo = model.model_config.cols[k]
# col = getattr(model, k)
relation = col_type_info.relation
if relation:
if relation.relation_type == 'm2m':
if v is not None:
target_model = relation.target_model
update_values = {child for child in
db.query(target_model).filter(
getattr(target_model, relation.target_id_key).in_(v))}
setattr(item, k, update_values)
elif relation.relation_type == 'o2o':
if type(v) == dict:
setattr(item, k, relation.target_model(**v))
else:
setattr(item, k, v)
db.add(item)
db.commit()
db.refresh(item)
return item
def query(self, db: Session, data: QueryParams) -> (int, List[Any]):
return query_common(db, self.model, data)
def delete(self, db: Session, data: BaseModel):
2023-08-02 16:32:36 +08:00
db.query(self.model).filter(getattr(self.model, self.model.model_config.id_key) == (
getattr(data, self.model.model_config.id_key))).delete()
2023-08-02 14:24:28 +08:00
db.commit()
return True