urban-investment-research/Utils/MiddlewareUtils.py

109 lines
3.0 KiB
Python
Raw Permalink Normal View History

2023-03-13 14:22:40 +08:00
from json import JSONDecodeError
from fastapi import Request, status
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel
# Use the project's shared uvicorn logger when it can be imported; otherwise
# leave ``uvicorn_log`` as None so the helpers below fall back to print().
try:
    from Context.common import uvicorn_log
except Exception:
    # Deliberately broad: any failure while importing the project context
    # (not only ImportError) degrades to print-based logging instead of
    # preventing this module from loading.
    uvicorn_log = None
class ReqData(BaseModel):
    """Snapshot of an HTTP request, filled in field by field by callers."""

    # Full request URL rendered as a string.
    url: str = ""
    # HTTP method name.
    method: str = ""
    # Peer address; callers in this module store it as "host:port".
    client_ip: str = ""
    # Raw request body (empty by default).
    body: bytes = b""
    # Parsed request payload (empty dict by default).
    data: dict = dict()
async def set_body(request: Request):
    """Drain the request body once and make it replayable.

    ASGI delivers a request body as one or more ``http.request`` messages;
    every message except the last has ``more_body`` set. This helper reads all
    of them, merges the chunks into a single final message, and replaces
    ``request._receive`` with a coroutine that returns that message on every
    call. As a result, the body can be read again after this helper ran.

    Caching only the first message (the previous behaviour) broke chunked
    bodies: that message carries ``more_body=True``, so a body reader that
    keeps calling ``receive()`` until ``more_body`` is false never terminates.

    NOTE(review): because the replay never ends, code that waits on
    ``receive()`` for an ``http.disconnect`` message will never see one. This
    trade-off is unchanged from the original implementation.

    :param request: the request whose private ``_receive`` is patched in place.
    """
    chunks = []
    message = await request._receive()
    while message["type"] == "http.request":
        chunks.append(message.get("body", b""))
        if not message.get("more_body", False):
            # Collapse every chunk into one terminal message.
            message = {
                "type": "http.request",
                "body": b"".join(chunks),
                "more_body": False,
            }
            break
        message = await request._receive()
    # Any other message type (e.g. "http.disconnect") is replayed unchanged.
    replay = message

    async def receive():
        return replay

    request._receive = receive
async def format_request(req: Request) -> ReqData:
    """Capture the URL, method, peer address and body of *req* for logging.

    The body is buffered through ``set_body`` before being read, so it can be
    read again afterwards.

    :param req: the incoming request.
    :return: a ``ReqData`` whose ``body`` holds the raw bytes (empty if the body
        could not be read) and whose ``data`` holds the parsed JSON payload, or
        an empty dict when the body is empty, not JSON, or unreadable.
    """
    req_data = ReqData()
    req_data.url = str(req.url)
    req_data.method = req.method
    # ``req.client`` is None when the ASGI scope carries no peer address;
    # indexing it unguarded raised TypeError and broke request logging.
    client = req.client
    req_data.client_ip = f"{client[0]}:{client[1]}" if client else ""

    body = b""
    data = {}
    try:
        await set_body(req)
        body = await req.body()
        try:
            data = await req.json()
        except ValueError:
            # ValueError covers JSONDecodeError (body is not JSON) and
            # UnicodeDecodeError (non-UTF-8 bytes such as binary uploads);
            # catching only JSONDecodeError let the latter escape.
            data = {}
    except RuntimeError:
        # Presumably Starlette's "stream already consumed" error; the request
        # is still logged, just without its body.
        pass
    req_data.body = body
    req_data.data = data
    return req_data
async def exception_handler(req: Request, exc: Exception):
    """Log a snapshot of the failing request, then re-raise *exc* unchanged.

    Output goes to ``uvicorn_log.info`` when that logger was importable,
    otherwise to ``print``.

    :param req: the request being handled when *exc* was raised.
    :param exc: the exception to propagate.
    :raises Exception: always re-raises *exc*.
    """
    snapshot = await format_request(req)
    emit = uvicorn_log.info if uvicorn_log else print
    # "请求错误" means "request error".
    for entry in ("请求错误", snapshot):
        emit(entry)
    raise exc
async def validation_exception_handler(req: Request, exc: RequestValidationError):
    """Log the rejected request and answer with HTTP 422.

    :param req: the request that failed validation.
    :param exc: the validation error; its ``errors()`` and ``body`` are echoed
        back to the client.
    :return: a 422 ``JSONResponse`` whose content has ``detail`` and ``body``.
    """
    snapshot = await format_request(req)
    emit = uvicorn_log.info if uvicorn_log else print
    # "请求错误" means "request error".
    emit("请求错误")
    emit(snapshot)
    payload = {"detail": exc.errors(), "body": exc.body}
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content=jsonable_encoder(payload),
    )
async def logger_request(req: Request, call_next):
    """Middleware: log a snapshot of each request, then forward it unchanged.

    ``format_request`` buffers the body via ``set_body`` so it can still be
    read after this middleware has inspected it.

    :param req: the incoming request.
    :param call_next: presumably Starlette's ``call_next``, which passes the
        request on to the rest of the application — confirm at registration.
    :return: the downstream response, unchanged.
    """
    req_data = await format_request(req)
    # The snapshot was previously built and then discarded; emit it the same
    # way the exception handlers in this module do.
    if uvicorn_log:
        uvicorn_log.info(req_data)
    else:
        print(req_data)
    return await call_next(req)
2023-03-20 14:11:37 +08:00
# async def record_middleware(req: Request, call_next):
# """
# 编辑记录的中间件
# """
# req_url_split = req.url.path.split("/")
# if len(req_url_split) > 1 and req_url_split[-1] in ['delete', 'update', 'add']:
# convert_str = "".join([item[0].upper() + item[1:] for item in req_url_split[-2].split("_")])
# if convert_str in recordTypeId._member_names_:
# req_data = await format_request(req)
#
# res = await call_next(req)
# # req_url_split = req.url.path.split("/")
# # if len(req_url_split) > 1 and req_url_split[-1] in ['delete', 'update', 'add']:
# # convert_str = "".join([item[0].upper() + item[1:] for item in req_url_split[-2].split("_")])
# # if convert_str in self.recordTypeId._member_names_:
# # pass
#
# # print(res)
# # req_data=await req.json()
# print("xxxxxxx",req_data)
# return res