aiogram Telegram机器人消息发送速率限制的实现问题及优化方案咨询
aiogram Telegram机器人消息发送速率限制的实现问题及优化方案咨询
我来帮你拆解现有实现的问题,同时给出更可靠、符合aiogram框架习惯的优化方案:
一、现有代码的明显问题
1. 速率限制规则定义完全错误
这是导致限制不生效的核心原因!你需要的是3条/3秒的单聊天限制,但当前代码用了RateLimitItemPerSecond(chat_rate)——这个类的作用是定义「每秒允许的消息数」,所以RateLimitItemPerSecond(3)实际表示3条/秒,和你的需求完全不符。
正确的做法是用RateLimitItem类自定义时间窗口:
from limits import RateLimitItem # 替换原来的RateLimitItemPerSecond初始化 self.global_limit = RateLimitItem(global_rate, global_per) # 25条/1秒 self.chat_limit = RateLimitItem(chat_rate, chat_per) # 3条/3秒
2. Test+Hit的非原子操作引发竞态条件
现有代码中先调用limiter.test再调用limiter.hit,这两个操作不是原子的。在高并发场景下,多个请求可能同时通过test检查,然后都执行hit,最终导致实际发送量超过限制。
应该直接使用limiter.hit的返回值判断是否获取到配额——hit是原子操作,会同时完成检查和计数:
async def wait_for_slot(self, key: str, limit: RateLimitItem): while True: # 原子性地尝试获取配额,成功则返回 if await self.limiter.hit(limit, key): return # 短间隔重试,避免不必要的长等待 await asyncio.sleep(0.1)
3. 睡眠时间计算逻辑不合理
原来的await asyncio.sleep(limit.get_expiry() / limit.amount)是错误的,它假设配额会平均分配,但实际场景中窗口内的配额可能被快速耗尽,此时应该等待到窗口重置,而不是固定等待平均间隔。短间隔重试的方式更灵活可靠。
4. 测试代码的潜在问题
- 用
asyncio.gather同时发起所有请求,容易触发竞态条件,无法模拟真实的请求分布 - 跳过前50个时间戳的操作可能掩盖实际的限制违规情况
- 没有精确统计每个时间窗口内的消息数量,无法准确验证限制是否生效
二、更符合aiogram习惯的实现方式:使用中间件
继承Bot类只能覆盖单个send_message方法,扩展性很差(比如send_photo、send_document等方法不会被限制)。使用aiogram的中间件是框架推荐的全局逻辑处理方式,能自动覆盖所有消息发送类API,扩展性更强。
1. 实现API请求级别的速率限制中间件
from typing import Callable, Dict, Awaitable, Any from aiogram import BaseMiddleware from limits import RateLimitItem from limits.aio.strategies import MovingWindowRateLimiter from limits.storage import storage_from_string import asyncio from aiogram.methods import SendMessage, SendPhoto, SendDocument, SendVideo class APIRateLimitMiddleware(BaseMiddleware): def __init__( self, redis_url: str, global_rate: int = 25, global_window: float = 1.0, chat_rate: int = 3, chat_window: float = 3.0, ): self.storage = storage_from_string(redis_url) self.limiter = MovingWindowRateLimiter(self.storage) self.global_limit = RateLimitItem(global_rate, global_window) self.chat_limit = RateLimitItem(chat_rate, chat_window) async def wait_for_quota(self, key: str, limit: RateLimitItem): while True: if await self.limiter.hit(limit, key): return await asyncio.sleep(0.1) async def __call__( self, handler: Callable[[Any, Dict[str, Any]], Awaitable[Any]], event: Any, data: Dict[str, Any], ) -> Any: # 只对发送消息类的API方法生效 if isinstance(event, (SendMessage, SendPhoto, SendDocument, SendVideo)): # 应用全局限制 await self.wait_for_quota("rate:global", self.global_limit) # 应用单聊天限制 await self.wait_for_quota(f"rate:chat:{event.chat_id}", self.chat_limit) return await handler(event, data)
2. 注册中间件到Bot实例
from aiogram import Bot, Dispatcher from aiogram.enums import ParseMode async def main(): bot = Bot( token="你的机器人Token", parse_mode=ParseMode.HTML ) # 注册中间件到所有Bot API请求 bot.middleware.setup( APIRateLimitMiddleware(redis_url="redis://localhost:6379/0") ) dp = Dispatcher() # 注册路由逻辑... await dp.start_polling(bot)
三、优化后的测试代码
要准确验证速率限制,需要统计每个时间窗口内的消息数量,而不是跳过部分数据:
import asyncio import time import pytest from aiogram import Bot from unittest.mock import patch from your_module import RateLimitedBot # 替换为你的Bot类路径 @pytest.mark.asyncio async def test_global_rate_limit(): message_count = 100 global_rate = 25 bot = RateLimitedBot( token="test_token", redis_url="redis://localhost:6379/1", # 用单独的Redis库避免影响生产数据 global_rate=global_rate, global_per=1.0, chat_rate=100, # 放大单聊天限制避免干扰 chat_per=3.0 ) timestamps = [] async def mocked_send_message(*args, **kwargs): timestamps.append(time.perf_counter()) return type('Message', (), {'chat': type('Chat', (), {'id': kwargs['chat_id']})()})() with patch.object(Bot, "send_message", new=mocked_send_message): tasks = [bot.send_message(chat_id=123, text=f"Test {i}") for i in range(message_count)] await asyncio.gather(*tasks) # 统计每个1秒窗口内的消息数 window_size = 1.0 current_window_start = timestamps[0] current_count = 0 max_count = 0 for ts in timestamps: if ts - current_window_start > window_size: max_count = max(max_count, current_count) current_window_start = ts current_count = 1 else: current_count += 1 max_count = max(max_count, current_count) assert max_count <= global_rate, f"全局限制被突破: {max_count} > {global_rate}" @pytest.mark.asyncio async def test_chat_rate_limit(): message_count = 20 chat_rate = 3 chat_window = 3.0 bot = RateLimitedBot( token="test_token", redis_url="redis://localhost:6379/1", global_rate=100, # 放大全局限制避免干扰 global_per=1.0, chat_rate=chat_rate, chat_per=chat_window ) timestamps = [] async def mocked_send_message(*args, **kwargs): timestamps.append(time.perf_counter()) return type('Message', (), {'chat': type('Chat', (), {'id': kwargs['chat_id']})()})() with patch.object(Bot, "send_message", new=mocked_send_message): tasks = [bot.send_message(chat_id=123, text=f"Test {i}") for i in range(message_count)] await asyncio.gather(*tasks) # 统计每个3秒窗口内的消息数 max_count = 0 for i, ts in enumerate(timestamps): end_ts = ts + chat_window window_count = sum(1 for t in timestamps[i:] if t <= end_ts) max_count = max(max_count, window_count) assert max_count <= chat_rate, f"单聊天限制被突破: {max_count} > {chat_rate}"
内容来源于stack exchange




