109 lines
3.0 KiB
Python
109 lines
3.0 KiB
Python
from json import JSONDecodeError
|
|
from fastapi import Request, status
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.encoders import jsonable_encoder
|
|
from fastapi.exceptions import RequestValidationError
|
|
from pydantic import BaseModel
|
|
|
|
# The project logger is optional. When ``context.common`` cannot be imported
# for any reason, ``uvicorn_log`` is set to ``None`` so the handlers in this
# module can detect its absence and fall back to ``print``.
try:
    from context.common import uvicorn_log
except Exception:  # deliberately broad: any import-time failure disables the logger
    uvicorn_log = None
|
|
|
|
|
|
class ReqData(BaseModel):
    """Snapshot of an incoming HTTP request.

    Every field has an empty default, so an instance can be created first
    and filled in afterwards.
    """

    # Request URL rendered as a string.
    url: str = ""
    # HTTP method of the request.
    method: str = ""
    # Client address formatted as "host:port".
    client_ip: str = ""
    # Raw request body.
    body: bytes = b""
    # JSON-decoded request body; stays empty when the body is not JSON.
    data: dict = {}
|
|
|
|
|
|
async def set_body(request: Request):
    """Buffer the request body and make it re-readable.

    Pulls the body messages from the request's receive callable and replaces
    that callable with one that returns a single buffered message on every
    call, so the body can be read more than once.

    A body delivered in several chunks (``more_body=True``) is drained fully
    and merged into one message. Replaying only the first chunk would never
    signal the end of the body, and a later ``request.body()`` would loop
    forever.

    Args:
        request: The request whose private ``_receive`` callable is replaced
            in place.
    """
    upstream_receive = request._receive
    first = await upstream_receive()
    replay = first

    if first.get("type") == "http.request" and first.get("more_body", False):
        chunks = [first.get("body", b"")]
        current = first
        while current.get("more_body", False):
            current = await upstream_receive()
            if current.get("type") != "http.request":
                # The client went away mid-stream: stop draining.
                break
            chunks.append(current.get("body", b""))

        if current.get("type") == "http.request":
            # Collapse all chunks into one complete, final body message.
            replay = {**first, "body": b"".join(chunks), "more_body": False}
        else:
            # Replay the disconnect so readers fail as they normally would.
            replay = current

    async def receive():
        return replay

    request._receive = receive
|
|
|
|
|
|
async def format_request(req: Request) -> ReqData:
    """Capture the URL, method, client address and payload of *req*.

    ``set_body`` is awaited before the body is read, which replaces the
    request's receive callable with a replaying one.

    Args:
        req: The incoming request.

    Returns:
        A ``ReqData`` snapshot. ``client_ip`` is ``""`` when the request has
        no client address, ``body`` stays ``b''`` when reading the body raises
        ``RuntimeError``, and ``data`` stays ``{}`` when the body cannot be
        decoded as JSON.
    """
    # ``req.client`` can be ``None`` (no peer address supplied by the server);
    # keep the model's empty default instead of failing on the subscript.
    client = req.client
    client_ip = f"{client[0]}:{client[1]}" if client is not None else ""

    req_data = ReqData(url=str(req.url), method=req.method, client_ip=client_ip)

    body = b''
    data = {}
    try:
        await set_body(req)
        body = await req.body()
        try:
            data = await req.json()
        except (JSONDecodeError, UnicodeDecodeError):
            # Empty, non-JSON or non-UTF-8 (e.g. binary upload) payloads carry
            # no structured data.
            data = {}
    except RuntimeError:
        # Best effort: the body could not be read, keep the empty defaults.
        pass

    # Assigned after construction (as the original code did) so that a JSON
    # payload that is not an object, e.g. a list, is stored as-is.
    req_data.body = body
    req_data.data = data
    return req_data
|
|
|
|
|
|
async def exception_handler(req: Request, exc: Exception):
    """Log a snapshot of the failing request, then re-raise *exc*.

    Args:
        req: The request that was being processed when *exc* occurred.
        exc: The exception being handled.

    Raises:
        Exception: Always re-raises *exc* after logging.
    """
    snapshot = await format_request(req)

    # Use the project logger when it was imported, otherwise plain stdout.
    emit = uvicorn_log.info if uvicorn_log else print
    emit("请求错误")  # the message reads "request error"
    emit(snapshot)

    raise exc
|
|
|
|
|
|
async def validation_exception_handler(req: Request, exc: RequestValidationError):
    """Log a snapshot of the invalid request and answer with HTTP 422.

    Args:
        req: The request that failed validation.
        exc: The validation error raised for the request.

    Returns:
        A ``JSONResponse`` with status 422 whose JSON content holds the
        validation errors under ``"detail"`` and the offending body under
        ``"body"``.
    """
    snapshot = await format_request(req)

    # Use the project logger when it was imported, otherwise plain stdout.
    emit = uvicorn_log.info if uvicorn_log else print
    emit("请求错误")  # the message reads "request error"
    emit(snapshot)

    payload = {"detail": exc.errors(), "body": exc.body}
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content=jsonable_encoder(payload),
    )
|
|
|
|
|
|
async def logger_request(req: Request, call_next):
    """Middleware-style hook that snapshots the request, then dispatches it.

    Args:
        req: The incoming request.
        call_next: Awaitable callable that receives the request and returns
            the response.

    Returns:
        The response produced by ``call_next``, unmodified.
    """
    # The snapshot is not used here, but ``format_request`` is still awaited
    # because it reads the body and (via ``set_body``) replaces the request's
    # receive callable with a replaying one.
    # NOTE(review): nothing is logged despite the function name; confirm intent.
    await format_request(req)
    return await call_next(req)
|
|
|
|
|
|
# async def record_middleware(req: Request, call_next):
|
|
# """
|
|
# Middleware that records edit operations.
|
|
# """
|
|
# req_url_split = req.url.path.split("/")
|
|
# if len(req_url_split) > 1 and req_url_split[-1] in ['delete', 'update', 'add']:
|
|
# convert_str = "".join([item[0].upper() + item[1:] for item in req_url_split[-2].split("_")])
|
|
# if convert_str in recordTypeId._member_names_:
|
|
# req_data = await format_request(req)
|
|
#
|
|
# res = await call_next(req)
|
|
# # req_url_split = req.url.path.split("/")
|
|
# # if len(req_url_split) > 1 and req_url_split[-1] in ['delete', 'update', 'add']:
|
|
# # convert_str = "".join([item[0].upper() + item[1:] for item in req_url_split[-2].split("_")])
|
|
# # if convert_str in self.recordTypeId._member_names_:
|
|
# # pass
|
|
#
|
|
# # print(res)
|
|
# # req_data=await req.json()
|
|
# print("xxxxxxx",req_data)
|
|
# return res
|