from typing import Literal, Type, Any, Dict, List

from sqlalchemy.orm import DeclarativeMeta, InstrumentedAttribute, Relationship
from sqlalchemy_utils import get_type, get_primary_keys

from .types.model_config import Relation, ModelConfig, TypeInfo, COL_ORG_TYPE_MAP

# Introspection helpers describing the attributes of a SQLAlchemy model:
# the original (SQLAlchemy) type of each attribute, its basic type looked up
# in COL_ORG_TYPE_MAP, and, for relationships, how the two models are linked.


def get_model_id_key(model: Type[DeclarativeMeta]):
    """Return the name of the first primary-key entry of *model*.

    :param model: a mapped declarative model class.
    :return: the first key reported by ``sqlalchemy_utils.get_primary_keys``,
        or ``None`` when the model has no primary key. For a composite
        primary key only the first entry is returned.
    """
    id_keys = list(get_primary_keys(model).keys())
    return id_keys[0] if id_keys else None


def get_model_col_names(model: Type[DeclarativeMeta]) -> List[str]:
    """Return the names of every ``InstrumentedAttribute`` on *model*.

    Names come back in ``dir()`` order.

    :param model: a mapped declarative model class.
    :return: list of attribute names (columns and relationships).
    """
    # NOTE: getattr() evaluates every name in dir(model), including
    # class-level computed attributes such as SalBase.model_config.
    return [
        name for name in dir(model)
        if isinstance(getattr(model, name), InstrumentedAttribute)
    ]


def _scan_target_model(r: Relation, source_model, target_model) -> bool:
    """Inspect *target_model* for links back to *source_model*.

    Sets ``r.source_key`` from any target attribute whose foreign key
    references the source table. Returns the ``uselist`` flag of the last
    target attribute (in ``dir()`` order) whose ``get_type`` equals
    *source_model*, or ``False`` if there is none.
    """
    target_uselist = False
    for name in get_model_col_names(target_model):
        target_col = getattr(target_model, name)
        if source_model == get_type(target_col):
            target_uselist = target_col.property.uselist
        # Check whether the target (right-hand) table holds a foreign key
        # that points at the source table.
        if hasattr(target_col, 'foreign_keys'):
            for foreign_key in list(target_col.foreign_keys):
                if foreign_key.column.table.name == source_model.__tablename__:
                    r.source_key = foreign_key.column.key
    return target_uselist


def _fill_secondary_keys(r: Relation, secondary, target_model) -> None:
    """Populate the association-table fields of a many-to-many ``Relation``.

    :param r: the relation being built; its ``secondary``,
        ``source_secondary_key``, ``target_secondary_key``, ``source_key``
        and ``target_key`` are set.
    :param secondary: the relationship's association table.
    :param target_model: the related model class.
    :raises ValueError: if the association table has fewer than two
        foreign-key columns.
    """
    # Only columns carrying a foreign key link the two tables. This skips a
    # surrogate primary key or payload columns on the association table.
    fk_columns = [c for c in secondary.columns if c.foreign_keys]
    if len(fk_columns) < 2:
        raise ValueError(
            f'association table {secondary.name!r} needs two foreign-key columns'
        )
    col_0, col_1 = fk_columns[0], fk_columns[1]
    fk_0 = next(iter(col_0.foreign_keys))
    fk_1 = next(iter(col_1.foreign_keys))
    # Compare plain table names: Table.key is schema-qualified and would never
    # equal __tablename__ for tables living in an explicit schema.
    # For a self-referential m2m both keys match; the first one is then
    # treated as the target side.
    if fk_0.column.table.name == target_model.__tablename__:
        target_col, target_fk, source_col, source_fk = col_0, fk_0, col_1, fk_1
    else:
        source_col, source_fk, target_col, target_fk = col_0, fk_0, col_1, fk_1
    r.secondary = secondary
    r.source_secondary_key = source_col.key
    r.target_secondary_key = target_col.key
    r.source_key = source_fk.column.key
    r.target_key = target_fk.column.key


def get_relation_info(col: InstrumentedAttribute) -> Relation:
    """Describe the relationship behind the instrumented attribute *col*.

    :param col: an instrumented attribute of a mapped model.
    :return: a populated ``Relation`` if *col* is a relationship, otherwise
        ``None``. ``relation_type`` is one of:

        - ``'m2m'``: the relationship uses an association (secondary) table;
        - ``'o2m'``: this side holds a list and no secondary table is used;
        - ``'m2o'``: this side is scalar and the reverse side holds a list;
        - ``'o2o'``: this side is scalar and the reverse side (if any found)
          is scalar too.

        NOTE: a scalar relationship with no reverse relationship on the
        target model is reported as ``'o2o'`` because no target-side list
        can be detected.
    """
    prop = col.property
    if not isinstance(prop, Relationship):
        return None

    r = Relation()
    source_model = prop.parent.entity
    target_model = prop.entity.entity
    r.source_model = source_model
    r.target_model = target_model
    r.source_id_key = get_model_id_key(source_model)
    r.target_id_key = get_model_id_key(target_model)

    if prop._user_defined_foreign_keys:
        r.source_foreign_key = list(prop._user_defined_foreign_keys)[0].name

    source_uselist = prop.uselist
    target_uselist = _scan_target_model(r, source_model, target_model)

    if prop.secondary is not None:
        r.relation_type = 'm2m'
        _fill_secondary_keys(r, prop.secondary, target_model)
    elif source_uselist:
        # Without a secondary table a list-valued side can only be the "one"
        # side. This includes self-referential adjacency lists, where the
        # scan above also sees a list on the "target".
        r.relation_type = 'o2m'
    elif target_uselist:
        r.relation_type = 'm2o'
    else:
        r.relation_type = 'o2o'
    return r


def get_model_config(model: Type[DeclarativeMeta]) -> ModelConfig:
    """Collect type and relationship information for a SQLAlchemy model.

    :param model: a mapped declarative model class.
    :return: a ``ModelConfig`` whose ``id_key`` is the name of a primary-key
        attribute (the last one in ``dir()`` order) and whose ``cols`` maps
        every instrumented attribute name to a ``TypeInfo`` holding its
        original type name, basic type (``'any'`` when unmapped) and
        relationship info (``None`` for plain columns).
    """
    model_config = ModelConfig()
    cols = {}
    for col_name in dir(model):
        col = getattr(model, col_name)
        if not isinstance(col, InstrumentedAttribute):
            continue
        if getattr(col, 'primary_key', False):
            model_config.id_key = col_name
        col_org_type = get_type(col).__class__.__name__
        col_base_type = COL_ORG_TYPE_MAP.get(col_org_type) or 'any'
        relation = get_relation_info(col)

        type_info = TypeInfo()
        type_info.col_org_type = col_org_type
        type_info.col_base_type = col_base_type
        type_info.relation = relation
        cols[col_name] = type_info
    model_config.cols = cols
    return model_config


def marge_col_names(col_names, include: List[str] = None, ex_include: List[str] = None):
    """Filter column names through an allow-list and a deny-list.

    (The historical name means "merge"; ``ex_include`` means "exclude".)

    :param col_names: iterable of candidate names.
    :param include: if non-empty, keep only these names; names not present
        in *col_names* are ignored.
    :param ex_include: if non-empty, drop these names.
    :return: the surviving names as a ``set``.
    """
    names = set(col_names)
    if include:
        names &= set(include)
    if ex_include:
        names -= set(ex_include)
    return names


# Name-mangled attribute that caches each class's ModelConfig.
_MODEL_CONFIG_ATTR = '_SalBase__model_config'
# Placeholder stored while a config is being built (see SalBase.model_config).
_MODEL_CONFIG_PENDING = 'process_ing'


class _classproperty:
    """Read-only property evaluated on the class.

    Replaces the ``@classmethod @property`` chain, which is deprecated since
    Python 3.11 and no longer works in 3.13.
    """

    def __init__(self, fget):
        self.fget = fget
        self.__doc__ = fget.__doc__

    def __get__(self, instance, owner=None):
        if owner is None:
            owner = type(instance)
        return self.fget(owner)


class SalBase:
    """Mixin adding model introspection and dict serialization to models."""

    # Per-class cache of the computed ModelConfig.
    __model_config: ModelConfig = None

    @_classproperty
    def model_config(cls) -> ModelConfig:
        """Lazily computed ``ModelConfig`` for this model class.

        The cache is read from the class's own ``__dict__``, so a model
        subclassing another model gets its own config instead of the
        parent's.
        """
        cached = vars(cls).get(_MODEL_CONFIG_ATTR)
        if cached is None:
            # get_model_config() walks dir(cls) and reads this attribute
            # again. Storing a placeholder first makes that nested read
            # return immediately instead of recursing forever.
            setattr(cls, _MODEL_CONFIG_ATTR, _MODEL_CONFIG_PENDING)
            try:
                cached = get_model_config(cls)
            finally:
                # On failure `cached` is still None, so the placeholder is
                # cleared and the next access retries.
                setattr(cls, _MODEL_CONFIG_ATTR, cached)
        return cached

    def to_dict(self, include: List[str] = None, ex_include: List[str] = None):
        """Serialize this row's table columns (no relationships) to a dict.

        :param include: optional allow-list of column names.
        :param ex_include: optional deny-list of column names.
        :return: ``{column name: value}``, in table column order.
        """
        column_names = [c.name for c in self.__table__.columns]
        selected = marge_col_names(column_names, include, ex_include)
        return {name: getattr(self, name) for name in column_names if name in selected}

    def to_full_dict(self, include: List[str] = None, ex_include: List[str] = None,
                     relation_use_id=False):
        """Serialize all instrumented attributes, including relationships.

        Related objects are serialized with their own ``to_dict()``.

        :param include: optional allow-list of attribute names.
        :param ex_include: optional deny-list of attribute names.
        :param relation_use_id: for list-valued relationships (``m2m`` and
            ``o2m``), emit the related objects' primary-key values instead of
            their dicts. Scalar relationships are always serialized as dicts
            (or ``None``).
        :return: ``{attribute name: value}``, in ``model_config.cols`` order.
        """
        config_cols = self.model_config.cols
        selected = marge_col_names(set(config_cols.keys()), include, ex_include)
        data = {}
        for col_name, type_info in config_cols.items():
            if col_name not in selected:
                continue
            value = getattr(self, col_name)
            relation = type_info.relation
            relation_type = relation.relation_type if relation else None
            if relation_type in ('m2m', 'o2m'):
                if relation_use_id:
                    data[col_name] = [getattr(item, relation.target_id_key) for item in value]
                else:
                    data[col_name] = [item.to_dict() for item in value]
            elif relation_type in ('o2o', 'm2o'):
                # `value and ...` leaves a missing (None) related object as-is.
                data[col_name] = value and value.to_dict()
            else:
                data[col_name] = value
        return data