You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

aiogram Telegram机器人消息发送速率限制的实现问题及优化方案咨询

aiogram Telegram机器人消息发送速率限制的实现问题及优化方案咨询

我来帮你拆解现有实现的问题,同时给出更可靠、符合aiogram框架习惯的优化方案:

一、现有代码的明显问题

1. 速率限制规则定义完全错误

这是导致限制不生效的核心原因!你需要的是3条/3秒的单聊天限制,但当前代码用了RateLimitItemPerSecond(chat_rate)——这个类的作用是定义「每秒允许的消息数」,所以RateLimitItemPerSecond(3)实际表示3条/秒,和你的需求完全不符。

正确的做法是用RateLimitItem类自定义时间窗口:

from limits import RateLimitItem

# 替换原来的RateLimitItemPerSecond初始化
self.global_limit = RateLimitItem(global_rate, global_per)  # 25条/1秒
self.chat_limit = RateLimitItem(chat_rate, chat_per)        # 3条/3秒

2. Test+Hit的非原子操作引发竞态条件

现有代码中先调用limiter.test再调用limiter.hit,这两个操作不是原子的。在高并发场景下,多个请求可能同时通过test检查,然后都执行hit,最终导致实际发送量超过限制。

应该直接使用limiter.hit的返回值判断是否获取到配额——hit是原子操作,会同时完成检查和计数:

async def wait_for_slot(self, key: str, limit: RateLimitItem):
    while True:
        # 原子性地尝试获取配额,成功则返回
        if await self.limiter.hit(limit, key):
            return
        # 短间隔重试,避免不必要的长等待
        await asyncio.sleep(0.1)

3. 睡眠时间计算逻辑不合理

原来的await asyncio.sleep(limit.get_expiry() / limit.amount)是错误的,它假设配额会平均分配,但实际场景中窗口内的配额可能被快速耗尽,此时应该等待到窗口重置,而不是固定等待平均间隔。短间隔重试的方式更灵活可靠。

4. 测试代码的潜在问题

  • asyncio.gather同时发起所有请求,容易触发竞态条件,无法模拟真实的请求分布
  • 跳过前50个时间戳的操作可能掩盖实际的限制违规情况
  • 没有精确统计每个时间窗口内的消息数量,无法准确验证限制是否生效

二、更符合aiogram习惯的实现方式:使用中间件

继承Bot类只能覆盖单个send_message方法,扩展性很差(比如send_photosend_document等方法不会被限制)。使用aiogram的中间件是框架推荐的全局逻辑处理方式,能自动覆盖所有消息发送类API,扩展性更强。

1. 实现API请求级别的速率限制中间件

from typing import Callable, Dict, Awaitable, Any
from aiogram import BaseMiddleware
from limits import RateLimitItem
from limits.aio.strategies import MovingWindowRateLimiter
from limits.storage import storage_from_string
import asyncio
from aiogram.methods import SendMessage, SendPhoto, SendDocument, SendVideo

class APIRateLimitMiddleware(BaseMiddleware):
    def __init__(
        self,
        redis_url: str,
        global_rate: int = 25,
        global_window: float = 1.0,
        chat_rate: int = 3,
        chat_window: float = 3.0,
    ):
        self.storage = storage_from_string(redis_url)
        self.limiter = MovingWindowRateLimiter(self.storage)
        self.global_limit = RateLimitItem(global_rate, global_window)
        self.chat_limit = RateLimitItem(chat_rate, chat_window)

    async def wait_for_quota(self, key: str, limit: RateLimitItem):
        while True:
            if await self.limiter.hit(limit, key):
                return
            await asyncio.sleep(0.1)

    async def __call__(
        self,
        handler: Callable[[Any, Dict[str, Any]], Awaitable[Any]],
        event: Any,
        data: Dict[str, Any],
    ) -> Any:
        # 只对发送消息类的API方法生效
        if isinstance(event, (SendMessage, SendPhoto, SendDocument, SendVideo)):
            # 应用全局限制
            await self.wait_for_quota("rate:global", self.global_limit)
            # 应用单聊天限制
            await self.wait_for_quota(f"rate:chat:{event.chat_id}", self.chat_limit)
        
        return await handler(event, data)

2. 注册中间件到Bot实例

from aiogram import Bot, Dispatcher
from aiogram.enums import ParseMode

async def main():
    bot = Bot(
        token="你的机器人Token",
        parse_mode=ParseMode.HTML
    )
    # 注册中间件到所有Bot API请求
    bot.middleware.setup(
        APIRateLimitMiddleware(redis_url="redis://localhost:6379/0")
    )
    
    dp = Dispatcher()
    # 注册路由逻辑...
    
    await dp.start_polling(bot)

三、优化后的测试代码

要准确验证速率限制,需要统计每个时间窗口内的消息数量,而不是跳过部分数据:

import asyncio
import time
import pytest
from aiogram import Bot
from unittest.mock import patch
from your_module import RateLimitedBot  # 替换为你的Bot类路径

@pytest.mark.asyncio
async def test_global_rate_limit():
    message_count = 100
    global_rate = 25
    bot = RateLimitedBot(
        token="test_token",
        redis_url="redis://localhost:6379/1",  # 用单独的Redis库避免影响生产数据
        global_rate=global_rate,
        global_per=1.0,
        chat_rate=100,  # 放大单聊天限制避免干扰
        chat_per=3.0
    )

    timestamps = []
    async def mocked_send_message(*args, **kwargs):
        timestamps.append(time.perf_counter())
        return type('Message', (), {'chat': type('Chat', (), {'id': kwargs['chat_id']})()})()

    with patch.object(Bot, "send_message", new=mocked_send_message):
        tasks = [bot.send_message(chat_id=123, text=f"Test {i}") for i in range(message_count)]
        await asyncio.gather(*tasks)

    # 统计每个1秒窗口内的消息数
    window_size = 1.0
    current_window_start = timestamps[0]
    current_count = 0
    max_count = 0
    for ts in timestamps:
        if ts - current_window_start > window_size:
            max_count = max(max_count, current_count)
            current_window_start = ts
            current_count = 1
        else:
            current_count += 1
    max_count = max(max_count, current_count)
    assert max_count <= global_rate, f"全局限制被突破: {max_count} > {global_rate}"

@pytest.mark.asyncio
async def test_chat_rate_limit():
    message_count = 20
    chat_rate = 3
    chat_window = 3.0
    bot = RateLimitedBot(
        token="test_token",
        redis_url="redis://localhost:6379/1",
        global_rate=100,  # 放大全局限制避免干扰
        global_per=1.0,
        chat_rate=chat_rate,
        chat_per=chat_window
    )

    timestamps = []
    async def mocked_send_message(*args, **kwargs):
        timestamps.append(time.perf_counter())
        return type('Message', (), {'chat': type('Chat', (), {'id': kwargs['chat_id']})()})()

    with patch.object(Bot, "send_message", new=mocked_send_message):
        tasks = [bot.send_message(chat_id=123, text=f"Test {i}") for i in range(message_count)]
        await asyncio.gather(*tasks)

    # 统计每个3秒窗口内的消息数
    max_count = 0
    for i, ts in enumerate(timestamps):
        end_ts = ts + chat_window
        window_count = sum(1 for t in timestamps[i:] if t <= end_ts)
        max_count = max(max_count, window_count)
    assert max_count <= chat_rate, f"单聊天限制被突破: {max_count} > {chat_rate}"

内容来源于stack exchange

火山引擎 最新活动