198 lines
7.1 KiB
Python
198 lines
7.1 KiB
Python
from pymongo import MongoClient, errors
|
||
from pydantic import BaseModel, ValidationError, parse_obj_as
|
||
from typing import Any, Dict, Optional, Union, Type
|
||
import hashlib
|
||
import json
|
||
from datetime import datetime
|
||
from typing import Literal
|
||
from jsonschema import validate
|
||
from pathlib import Path
|
||
from tempfile import TemporaryDirectory
|
||
from datamodel_code_generator import InputFileType, generate
|
||
from datamodel_code_generator import DataModelType
|
||
|
||
"""
|
||
{
|
||
name:"",
|
||
schema:{},
|
||
from :"man"|"api",
|
||
apiConfig:{
|
||
"url":"",
|
||
"params":{},
|
||
"paging":False,
|
||
"method":GET|POST
|
||
},
|
||
check_repeated:False|True|"normal"|"latest"
|
||
}
|
||
|
||
"""
|
||
|
||
|
||
def json_schema_to_pydantic(name: str, schema: Dict[str, Any]) -> Optional[Type[BaseModel]]:
    """Generate a pydantic model class named *name* from a JSON Schema.

    The schema is rendered to Python source with datamodel-code-generator in a
    temporary directory, and that source is executed to obtain the class.

    :param name: class name to generate; also used for the temp file names.
    :param schema: JSON Schema as a dict.
    :return: the generated ``BaseModel`` subclass, or ``None`` when generation
        fails (the error is printed, not raised).
    """
    try:
        json_schema: str = json.dumps(schema)
        # Used as both globals and locals for exec so the generated module's
        # names (imports, nested models) resolve against each other.
        namespace: Dict[str, Any] = {}
        with TemporaryDirectory() as temporary_directory_name:
            output = Path(temporary_directory_name) / f'{name}.py'
            generate(
                json_schema,
                input_file_type=InputFileType.JsonSchema,
                input_filename=f"{name}.json",
                output=output,
                class_name=name,
                # set up the output model types
                output_model_type=DataModelType.PydanticBaseModel,
                encoding="utf-8",
            )
            model_source: str = output.read_text(encoding="utf-8")
        # SECURITY NOTE(review): this executes generated code with full
        # privileges. Schemas are read from MongoDB, so only use this with
        # trusted schema sources.
        exec(model_source, namespace, namespace)
        model = namespace.get(name)
        # The generator may rename the class (e.g. to a valid identifier); fail
        # loudly instead of silently handing back None.
        if not (isinstance(model, type) and issubclass(model, BaseModel)):
            raise LookupError(f"generated code did not define a pydantic model named {name!r}")
    except Exception as e:
        print(f"jsonschema转pydantic出现错误:\n{name}:\n{schema}\n{e}")
        return None
    return model
|
||
|
||
|
||
class JsonDataManage:
    """Store JSON documents in MongoDB, validated against named JSON Schemas.

    Schemas are persisted in the ``schema`` collection of the configured
    database (unique index on ``name``) and cached in memory: the raw JSON
    Schema in ``json_schema_dict`` and the pydantic model class generated from
    it by ``json_schema_to_pydantic`` in ``models``. Data validated against a
    schema is written to the collection with the same name.
    """

    def __init__(self, mongo_client: MongoClient, db_name: str = "json数据"):
        """Bind to database *db_name* and load every stored schema.

        :param mongo_client: a ``MongoClient``.
        :param db_name: database holding the ``schema`` collection and the data.
        """
        # NOTE: despite its name this holds the Database (mongo_client[db_name]),
        # not the client; the attribute name is kept for compatibility.
        self.mongo_client = mongo_client[db_name]
        self.json_schema_dict: Dict[str, Dict[str, Any]] = {}  # name -> raw JSON Schema
        self.models: Dict[str, Type[BaseModel]] = {}  # name -> generated model class
        self.schema_collection = self.mongo_client['schema']
        self.schema_collection.create_index('name', unique=True)
        self.load_schemas()

    def _cache_schema(self, name: str, schema: Dict[str, Any]) -> None:
        """Cache *schema* and the model generated from it under *name*."""
        self.json_schema_dict[name] = schema
        model = json_schema_to_pydantic(name, schema)
        if model:
            self.models[name] = model
        else:
            # Never keep validating against a model built from an older schema.
            self.models.pop(name, None)

    def load_schemas(self):
        """Load all schemas from MongoDB into the in-memory caches.

        A schema whose model cannot be generated stays in ``json_schema_dict``
        but gets no ``models`` entry. MongoDB errors are printed, not raised.
        """
        try:
            for document in self.schema_collection.find():
                self._cache_schema(document['name'], document['schema'])
        except errors.PyMongoError as e:
            print(f"Error loading schemas: {e}")

    def add_schema(self, name: str, schema: Dict[str, Any]):
        """Persist a new schema and cache its generated model.

        Prints an error (and changes nothing) if *name* already exists or
        MongoDB fails.
        """
        try:
            self.schema_collection.insert_one({'name': name, 'schema': schema})
        except errors.DuplicateKeyError:  # must precede its base PyMongoError
            print(f"Error: a schema with the name {name} already exists.")
            return
        except errors.PyMongoError as e:
            print(f"Error adding schema: {e}")
            return
        # Previously BaseModel.parse_obj(schema) was stored here: that is a
        # field-less model *instance*, which made validation accept anything.
        self._cache_schema(name, schema)

    def remove_schema(self, name: str):
        """Delete schema *name* from MongoDB and from the in-memory caches.

        Prints an error if the schema existed neither in MongoDB nor in memory.
        """
        try:
            result = self.schema_collection.delete_one({'name': name})
        except errors.PyMongoError as e:
            print(f"Error removing schema: {e}")
            return
        known = name in self.json_schema_dict
        self.json_schema_dict.pop(name, None)
        # A schema may have no model if generation failed, so don't `del` here.
        self.models.pop(name, None)
        if not known and result.deleted_count == 0:
            print(f"Error: no schema with the name {name} exists.")

    def update_schema(self, name: str, schema: Dict[str, Any]):
        """Replace the stored schema *name* and regenerate its model.

        Prints an error (and leaves the caches untouched) if no schema named
        *name* exists in MongoDB or MongoDB fails.
        """
        try:
            result = self.schema_collection.update_one({'name': name}, {'$set': {'schema': schema}})
        except errors.PyMongoError as e:
            print(f"Error updating schema: {e}")
            return
        # update_one does not raise for a missing document; check explicitly so
        # the in-memory cache never diverges from MongoDB.
        if result.matched_count == 0:
            print(f"Error: no schema with the name {name} exists.")
            return
        self._cache_schema(name, schema)

    @staticmethod
    def validation(model: Type[BaseModel], data: Dict[str, Any]) -> Dict[str, Any]:
        """Validate *data* with *model* and return the parsed data as a dict.

        :raises Exception: with the pydantic error text when validation fails.
        """
        try:
            return model.parse_obj(data).dict()
        except ValidationError as e:
            raise Exception(f"数据校验失败:{e}") from e

    def insert_data(self, collection_name: str, data: Dict[str, Any], index: Optional[Dict[str, Any]] = None,
                    check_repeated: Union[Literal["normal", "latest"], bool] = False):
        """Validate *data* against schema *collection_name* and insert it.

        The stored document is ``{index, data, create_time, update_time, hash}``.

        :param collection_name: schema name, also the target collection.
        :param data: raw data to validate.
        :param index: user-supplied key document stored alongside the data and
            used to scope duplicate checks; defaults to ``{}``.
        :param check_repeated: ``False`` for no duplicate check, ``True`` or
            ``"normal"`` to reject data already stored for *index*, ``"latest"``
            to reject data equal to the most recent record for *index*.
        :return: the inserted ``_id``, or ``None`` if MongoDB failed (printed).
        :raises KeyError: if no model is loaded for *collection_name*.
        :raises Exception: on validation failure or a detected duplicate.
        """
        if index is None:
            index = {}
        validation_data = self.validation(self.models[collection_name], data)
        data_hash = self.get_data_hash(validation_data)

        if check_repeated:
            check_repeated_model = "normal" if check_repeated is True else check_repeated
            if self.check_repeated(collection_name, validation_data, index=index, model=check_repeated_model):
                raise Exception("数据已经存在")

        # One timestamp so create_time and update_time are identical on insert.
        now = datetime.now()
        document = {
            "index": index,
            "data": validation_data,
            "create_time": now,
            "update_time": now,
            "hash": data_hash,
        }

        collection = self.mongo_client[collection_name]
        try:
            result = collection.insert_one(document)
        except errors.PyMongoError as e:
            print(f"Error inserting data: {e}")
            return None
        return result.inserted_id

    def check_repeated(self, collection_name: str, data: Dict[str, Any], index: Optional[Dict[str, Any]] = None,
                       model: Literal["normal", "latest"] = "normal") -> bool:
        """Return whether *data* counts as a duplicate for *index*.

        ``"normal"``: any stored document with the same index and data hash.
        ``"latest"``: the most recent document (highest ``_id``) for the index
        has the same data hash.

        Note that ``index`` is matched as an exact embedded document (field
        order matters in MongoDB). Returns ``False`` for an unknown *model* or
        on a MongoDB error (printed).
        """
        if index is None:
            index = {}
        data_hash = self.get_data_hash(data)
        collection = self.mongo_client[collection_name]
        try:
            if model == 'normal':
                return collection.find_one({"hash": data_hash, "index": index}) is not None
            if model == 'latest':
                # Compare against the last record for this index only; filtering
                # by hash here would make this identical to "normal" mode.
                latest = collection.find_one({"index": index}, sort=[('_id', -1)])
                return latest is not None and latest.get("hash") == data_hash
        except errors.PyMongoError as e:
            print(f"Error checking for repeated data: {e}")
        return False

    def query(self, collection_name: str, query: Dict[str, Any]):
        """Run *query* against *collection_name* and return the result cursor.

        :param collection_name: collection to search.
        :param query: MongoDB filter document.
        :return: a pymongo cursor, or ``None`` if the call failed (printed).
            The cursor is lazy, so most server errors surface while iterating.
        """
        collection = self.mongo_client[collection_name]
        try:
            return collection.find(query)
        except errors.PyMongoError as e:
            print(f"Error executing query: {e}")

    def get_data_hash(self, data: dict) -> str:
        """Return the SHA-256 hex digest of *data*'s canonical JSON form.

        Keys are sorted so dict ordering does not change the hash. ``default=str``
        lets values json cannot encode natively (e.g. datetime values from
        pydantic's ``.dict()``) be hashed instead of raising ``TypeError``; it
        never affects the hash of data json could already encode.
        ``ensure_ascii`` is left at its default so existing stored hashes stay valid.
        """
        data_bytes = json.dumps(data, sort_keys=True, default=str).encode()
        return hashlib.sha256(data_bytes).hexdigest()
|