112 lines
4.1 KiB
Python
112 lines
4.1 KiB
Python
from fastapi import HTTPException
|
|
|
|
from Utils.RedisUtils import RedisPool
|
|
|
|
|
|
class ApiLimit:
    """
    Usage accounting and rate limiting for important endpoints.
    (Original note: 重要接口使用统计与限制)

    Each ``ApiLimitCheck(api_key, ip, count, time_s)`` keeps a call counter in
    Redis under the key ``ApiLimitCheck_{api_key}_{ip}``. The counter expires
    ``time_s`` seconds after it is created, giving a fixed-window limit of
    ``count`` calls per window when ``check()`` is called before ``record()``.
    """

    def __init__(self, redis_pool: RedisPool):
        """
        :param redis_pool: pool whose ``get_redis_client()`` returns the Redis
            client used to store the counters (assumed to be a redis-py style
            client exposing get/incr/ttl/expire — TODO confirm).
        """
        self.redis_pool = redis_pool

        class ApiLimitCheck:
            """
            Fixed-window call limiter for one (api_key, ip) pair.

            Intended flow: call ``check()`` before handling a request (raises
            once ``count`` calls are recorded in the current window), then
            ``record()`` to count the accepted call.
            """

            def __init__(self, api_key="", ip="", count=1, time_s=10):
                # api_key and ip are combined into the Redis key, so every
                # (api_key, ip) pair gets its own counter.
                self.api_key = api_key
                self.ip = ip
                # Number of calls allowed per window.
                self.count = count
                # Window length in seconds (TTL of the counter key).
                self.time_s = time_s

            def get_record_id(self):
                """Return the Redis key that holds this limiter's counter."""
                return f"ApiLimitCheck_{self.api_key}_{self.ip}"

            def record(self):
                """
                Count one call in the current window.

                INCR is atomic (no lost updates under concurrent requests) and,
                unlike SET, keeps the key's existing TTL, so the window still
                expires after ``time_s`` seconds.
                """
                record_id = self.get_record_id()
                client = redis_pool.get_redis_client()
                current = client.incr(record_id)
                # A fresh counter (value 1) needs its window expiry. TTL == -1
                # means "exists but never expires" — e.g. a key whose TTL was
                # wiped by the previous SET-based implementation; give it one
                # so it cannot block the caller forever.
                if current == 1 or client.ttl(record_id) == -1:
                    client.expire(record_id, self.time_s)

            def check(self):
                """
                Reject the call if ``count`` calls are already recorded in the
                current window.

                :raises HTTPException: status 403 with detail "接口调用过快"
                    ("API called too frequently").
                """
                client = redis_pool.get_redis_client()
                last_count = client.get(self.get_record_id())
                # No key -> no calls recorded in the current window.
                if last_count is not None and int(last_count) >= self.count:
                    # NOTE(review): 429 Too Many Requests is the conventional
                    # rate-limit status; 403 is kept for existing clients.
                    raise HTTPException(detail="接口调用过快", status_code=403)

        # Exposed per instance so callers can write
        # ``api_limit.ApiLimitCheck(...)``; the nested class reaches the pool
        # through the enclosing ``redis_pool`` parameter (closure).
        self.ApiLimitCheck = ApiLimitCheck
|
|
|
|
# def limit_by_time_count_check(self, api_key, count, time_s):
|
|
# """
|
|
# 限制接口在一段时间内的调用次数
|
|
# @return:
|
|
# """
|
|
# api_id = f"limit_by_time_count_check_{api_key}"
|
|
# last_count = self.redis_pool.conn.get(api_id)
|
|
# if not last_count:
|
|
# self.redis_pool.conn.set(api_id, 1)
|
|
# self.redis_pool.conn.expire(api_id, time_s)
|
|
# else:
|
|
# last_count = int(last_count)
|
|
# if last_count > count:
|
|
# raise HTTPException(status_code=303, detail=f"{api_key}接口调用次数已达该时段上限")
|
|
# else:
|
|
# self.redis_pool.conn.set(api_id, last_count)
|
|
#
|
|
# def limit_by_time_ip_count_check(self, api_key, count, time_s, ip):
|
|
# """
|
|
# 限制某一ip在一段时间内接口的调用次数
|
|
# @return:
|
|
# """
|
|
# api_id = f"limit_by_time_ip_count_check_{api_key}_{ip}"
|
|
# last_count = self.redis_pool.conn.get(api_id)
|
|
# if not last_count:
|
|
# self.redis_pool.conn.set(api_id, 1)
|
|
# self.redis_pool.conn.expire(api_id, time_s)
|
|
# else:
|
|
# last_count = int(last_count)
|
|
# if last_count > count:
|
|
# raise HTTPException(status_code=303, detail=f"{api_key}接口{ip}调用次数已达该时段上限")
|
|
# else:
|
|
# self.redis_pool.conn.set(api_id, last_count)
|
|
#
|
|
# def limit_by_time_check(self, api_key, time_s: int):
|
|
# """
|
|
# 限制某一接口在一段时间内接口的调用次数
|
|
# @return:
|
|
# """
|
|
# api_id = f"limit_by_time_check_{api_key}"
|
|
# last_count = self.redis_pool.conn.get(api_id)
|
|
# if not last_count:
|
|
# self.redis_pool.conn.set(api_id, 1)
|
|
# self.redis_pool.conn.expire(api_id, time_s)
|
|
# else:
|
|
# raise HTTPException(status_code=303, detail=f"{api_key}接口调用过快")
|
|
#
|
|
# def limit_by_time_ip_check(self, api_key, time_s: int, ip: str):
|
|
# """
|
|
# 限制某一ip对接口在一段时间内的调用
|
|
# @return:
|
|
# """
|
|
# api_id = f"limit_by_time_ip_check_{api_key}_{ip}"
|
|
# last_count = self.redis_pool.conn.get(api_id)
|
|
# if not last_count:
|
|
# self.redis_pool.conn.set(api_id, 1)
|
|
# self.redis_pool.conn.expire(api_id, time_s)
|
|
# else:
|
|
# raise HTTPException(status_code=303, detail=f"{api_key}接口调用过快")
|
|
#
|
|
# def limit_api_record(self):
|
|
# pass
|
|
#
|
|
# def limit_api_check(self, api_key="", ip="", count=1, time_s=60):
|