wd-rating/utils/sal_utils/model_config_utils.py

166 lines
6.6 KiB
Python

from typing import Literal, Type, Any, Dict, List, Optional

from sqlalchemy.orm import DeclarativeMeta, InstrumentedAttribute, Relationship
from sqlalchemy_utils import get_type, get_primary_keys

from .types.model_config import Relation, ModelConfig, TypeInfo, COL_ORG_TYPE_MAP
# Field types of a model
# Basic types
def get_model_id_key(model: Type[DeclarativeMeta]):
    """Return the key of the first primary-key column of *model*, or ``None``.

    :param model: declarative model class passed to
        ``sqlalchemy_utils.get_primary_keys``.
    :return: the first key of the mapping returned by ``get_primary_keys``,
        or ``None`` when that mapping is empty (no primary key).
    """
    # Iterating the mapping yields its keys in order; no debug output here,
    # this is called twice for every relationship inspected.
    return next(iter(get_primary_keys(model)), None)
def get_model_col_names(model: Type[DeclarativeMeta]):
    """List the names of every ORM-instrumented attribute on *model*.

    Scans ``dir(model)`` (so the result is in ``dir`` order) and keeps the
    names whose class-level value is an ``InstrumentedAttribute`` -- mapped
    columns and relationships alike.

    :param model: declarative model class to inspect.
    :return: list of attribute names.
    """
    return [
        attr_name
        for attr_name in dir(model)
        if isinstance(getattr(model, attr_name), InstrumentedAttribute)
    ]
def _scan_target_model(r, source_model, target_model):
    """Look through *target_model* for links back to *source_model*.

    Returns the ``uselist`` of the last attribute on *target_model* whose
    ``get_type`` equals *source_model* (``False`` if there is none). As a side
    effect, sets ``r.source_key`` to the referenced column key of any target
    attribute carrying a foreign key into the source table (last match wins).
    """
    target_uselist = False
    for name in get_model_col_names(target_model):
        target_col = getattr(target_model, name)
        col_org_type = get_type(target_col)
        if source_model == col_org_type:
            target_uselist = target_col.property.uselist
        # Look for a foreign key on the right-hand (target) table pointing at
        # the source table.
        # NOTE(review): this runs for every target attribute, not only the
        # back-reference; presumably only column attributes expose
        # `foreign_keys` -- confirm.
        if hasattr(target_col, 'foreign_keys'):
            for foreign_key in target_col.foreign_keys:
                if foreign_key.column.table.name == source_model.__tablename__:
                    r.source_key = foreign_key.column.key
    return target_uselist


def _fill_secondary_keys(r, secondary, target_model):
    """Record how the association table *secondary* links source and target.

    Uses the first two columns of *secondary* and the first foreign key of
    each. The column whose foreign key points at the target table becomes the
    target side; the other becomes the source side.
    """
    columns = secondary.columns
    foreign_key_0 = list(columns[0].foreign_keys)[0]
    foreign_key_1 = list(columns[1].foreign_keys)[0]
    # Compare bare table names (as the foreign-key scan above does): Table.key
    # is "schema.table" when a schema is set, which never equals __tablename__
    # and would silently swap the two sides.
    if foreign_key_0.column.table.name == target_model.__tablename__:
        target_col, target_fk = columns[0], foreign_key_0
        source_col, source_fk = columns[1], foreign_key_1
    else:
        source_col, source_fk = columns[0], foreign_key_0
        target_col, target_fk = columns[1], foreign_key_1
    r.secondary = secondary
    r.source_secondary_key = source_col.key
    r.target_secondary_key = target_col.key
    r.source_key = source_fk.column.key
    r.target_key = target_fk.column.key


def get_relation_info(col: InstrumentedAttribute) -> Optional[Relation]:
    """Describe the relationship behind an instrumented attribute.

    :param col: a class-level attribute of a mapped model.
    :return: a populated ``Relation`` when ``col.property`` is a
        ``Relationship``, otherwise ``None``.

    ``relation_type`` is derived from ``uselist`` on each side:
    ``'m2m'`` when a ``secondary`` table is configured or both sides are
    lists; ``'o2m'`` (source list, target scalar); ``'o2o'`` (both scalar);
    ``'m2o'`` (source scalar, target list).
    """
    prop = col.property
    if not isinstance(prop, Relationship):
        return None

    r = Relation()
    source_model = prop.parent.entity
    r.source_model = source_model
    r.source_id_key = get_model_id_key(source_model)
    target_model = prop.entity.entity
    r.target_model = target_model
    r.target_id_key = get_model_id_key(target_model)
    source_uselist = prop.uselist

    if prop._user_defined_foreign_keys:
        # Explicit relationship(foreign_keys=...): keep the first column's name.
        r.source_foreign_key = list(prop._user_defined_foreign_keys)[0].name

    target_uselist = _scan_target_model(r, source_model, target_model)

    if (source_uselist and target_uselist) or prop.secondary is not None:
        r.relation_type = 'm2m'
        # Both sides can be lists without an association table; only resolve
        # association keys when one exists (previously crashed on None.columns).
        if prop.secondary is not None:
            _fill_secondary_keys(r, prop.secondary, target_model)
    else:
        # Tuple lookup matches the original `== True` / `== False` tests, so a
        # uselist that is neither True nor False leaves relation_type untouched.
        relation_type = {
            (True, False): 'o2m',
            (False, False): 'o2o',
            (False, True): 'm2o',
        }.get((source_uselist, target_uselist))
        if relation_type is not None:
            r.relation_type = relation_type
    return r
def get_model_config(model: Type[DeclarativeMeta]) -> ModelConfig:
    """Collect type and relationship metadata for every mapped attribute.

    For each name in ``dir(model)`` whose value is an ``InstrumentedAttribute``,
    a ``TypeInfo`` is built holding the class name of ``get_type(attr)``, its
    base type from ``COL_ORG_TYPE_MAP`` (``'any'`` when unmapped) and the
    result of ``get_relation_info``. Any attribute reporting a truthy
    ``primary_key`` is stored as ``id_key`` (the last one in ``dir`` order wins).

    :param model: declarative model class to inspect.
    :return: a ``ModelConfig`` whose ``cols`` maps attribute name -> ``TypeInfo``.
    """
    config = ModelConfig()
    type_infos = {}
    for attr_name in dir(model):
        attr = getattr(model, attr_name)
        if not isinstance(attr, InstrumentedAttribute):
            continue
        if getattr(attr, 'primary_key', False):
            config.id_key = attr_name
        info = TypeInfo()
        org_type_name = get_type(attr).__class__.__name__
        base_type_name = COL_ORG_TYPE_MAP.get(org_type_name) or 'any'
        relation_info = get_relation_info(attr)
        info.col_org_type = org_type_name
        info.col_base_type = base_type_name
        info.relation = relation_info
        type_infos[attr_name] = info
    config.cols = type_infos
    return config
def marge_col_names(col_names, include: List[str] = None, ex_include: List[str] = None):
    """Filter *col_names* by an allow-list and a deny-list.

    :param col_names: candidate names.
    :param include: when non-empty, only names listed here that also appear in
        *col_names* are kept.
    :param ex_include: when non-empty, names listed here are removed.
    :return: the surviving names as a ``set``.
    """
    kept = [name for name in include if name in col_names] if include else col_names
    if not ex_include:
        return set(kept)
    return {name for name in kept if name not in ex_include}
class _classproperty:
    """Read-only descriptor that evaluates ``fget(owner_class)`` on access.

    Stands in for stacking ``@classmethod`` on ``@property``, which Python 3.13
    no longer supports (class methods can no longer wrap other descriptors).
    Works for both class access and instance access.
    """

    def __init__(self, fget):
        self.fget = fget
        self.__doc__ = getattr(fget, '__doc__', None)

    def __get__(self, instance, owner=None):
        if owner is None:
            owner = type(instance)
        return self.fget(owner)


class SalBase:
    """Mixin adding cached model metadata and dict serialization to models."""

    # Cache slot for the computed ModelConfig (name-mangled to
    # `_SalBase__model_config`); each subclass gets its own entry on first use.
    __model_config: ModelConfig = None

    @_classproperty
    def model_config(cls) -> ModelConfig:
        """Return the ``ModelConfig`` for ``cls``, computing it on first access.

        While the config is being built, this returns the placeholder string
        ``'process_ing'``: ``get_model_config`` calls ``getattr`` on every name in
        ``dir(cls)``, including ``model_config`` itself, and the placeholder stops
        that re-entry from recursing.
        """
        # Read only this class's own namespace so a subclass never reuses a
        # config cached on a mapped parent class.
        cached = cls.__dict__.get('_SalBase__model_config')
        if cached is None:
            cls.__model_config = 'process_ing'
            config = None
            try:
                config = get_model_config(cls)
            finally:
                # On failure this resets the slot to None so the next access
                # retries instead of returning the placeholder forever.
                cls.__model_config = config
            cached = config
        return cached

    def to_dict(self, include: List[str] = None, ex_include: List[str] = None):
        """Serialize this object's table columns into a plain dict.

        Keys come from ``self.__table__.columns`` (column names); values are read
        with ``getattr(self, name)``. Relationships are not included.

        :param include: when non-empty, keep only these names.
        :param ex_include: when non-empty, drop these names.
        """
        # NOTE(review): attributes are looked up by *column* name; a mapped
        # attribute whose key differs from its column name would not match.
        col_names = marge_col_names([c.name for c in self.__table__.columns], include, ex_include)
        return {name: getattr(self, name) for name in col_names}

    def to_full_dict(self, include: List[str] = None, ex_include: List[str] = None, relation_use_id=False):
        """Serialize every attribute known to ``model_config``, including relationships.

        List relationships (``'m2m'``/``'o2m'``) become lists of related
        ``to_dict()`` results, or of related primary-key values when
        *relation_use_id* is true. Scalar relationships (``'o2o'``/``'m2o'``)
        become the related object's ``to_dict()`` (or the falsy value itself,
        e.g. ``None``). Everything else is copied as-is.

        :param include: when non-empty, keep only these names.
        :param ex_include: when non-empty, drop these names.
        :param relation_use_id: emit related ids instead of dicts for list relationships.
        """
        cols = self.model_config.cols
        col_names = marge_col_names(set(cols.keys()), include, ex_include)
        data = {}
        for col_name in col_names:
            col_val = getattr(self, col_name)
            relation = cols[col_name].relation
            relation_type = relation.relation_type if relation else None
            if relation_type in ('m2m', 'o2m'):
                if relation_use_id:
                    data[col_name] = [getattr(item, relation.target_id_key) for item in col_val]
                else:
                    data[col_name] = [item.to_dict() for item in col_val]
            elif relation_type in ('o2o', 'm2o'):
                # NOTE(review): relation_use_id is not applied to scalar
                # relationships -- confirm whether that is intended.
                data[col_name] = col_val and col_val.to_dict()
            else:
                data[col_name] = col_val
        return data