File size: 3,813 Bytes
8b7b267
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python3
"""
Simple Rate Limiter for API Endpoints
"""

import time
from collections import defaultdict
from typing import Dict, Tuple
import logging

# Module-level logger named after this module; no handlers are configured here.
logger = logging.getLogger(__name__)


class SimpleRateLimiter:
    """
    Simple in-memory sliding-window rate limiter.

    Every client has one list of ``time.time()`` timestamps, one per allowed
    request. A request is allowed while fewer than ``limit`` of those
    timestamps fall inside the last ``window`` seconds; ``limit`` is looked
    up by endpoint type.

    NOTE: the history is keyed by client only, not by endpoint type, so
    requests to every endpoint type count against the same history and only
    the limit they are compared against differs.
    """

    def __init__(self):
        # Store: {client_id: [timestamp, ...]} -- one entry per allowed request.
        self.requests: Dict[str, list] = defaultdict(list)

        # Rate limit configurations (requests per window)
        self.limits = {
            "default": 60,  # 60 requests per minute
            "sentiment": 30,  # 30 sentiment requests per minute
            "model_loading": 5,  # 5 model loads per minute
            "dataset_loading": 5,  # 5 dataset loads per minute
            "external_api": 100  # 100 external API calls per minute
        }

        # Time window in seconds
        self.window = 60  # 1 minute

    def is_allowed(
        self,
        client_id: str,
        endpoint_type: str = "default"
    ) -> Tuple[bool, Dict]:
        """
        Check if request is allowed based on rate limit

        An allowed request is recorded immediately; a denied one is not.

        Args:
            client_id: Client identifier (IP, API key, etc.)
            endpoint_type: Key into ``self.limits``; unknown types fall back
                to the "default" limit.

        Returns:
            Tuple of (is_allowed, info_dict). ``info_dict`` always holds
            "allowed", "requests_remaining", "limit", "window_seconds" and
            "reset_at"; a denied request additionally gets "retry_after"
            (seconds until the oldest recorded request leaves the window).
        """
        current_time = time.time()
        limit = self.limits.get(endpoint_type, self.limits["default"])

        # Keep only timestamps still inside the window. ``.get`` is used so
        # that merely checking a client does not create a defaultdict entry.
        recent = [
            ts for ts in self.requests.get(client_id, ())
            if current_time - ts < self.window
        ]
        request_count = len(recent)

        if request_count < limit:
            # Allow request and record it
            recent.append(current_time)
            self.requests[client_id] = recent

            return True, {
                "allowed": True,
                "requests_remaining": limit - request_count - 1,
                "limit": limit,
                "window_seconds": self.window,
                "reset_at": current_time + self.window
            }

        # Deny request. Store the cleaned history, or drop the client
        # entirely when nothing is left, so denied-only clients do not
        # accumulate empty entries.
        if recent:
            self.requests[client_id] = recent
        else:
            self.requests.pop(client_id, None)

        # ``recent`` is empty only when the limit is <= 0; fall back to the
        # current time instead of letting ``min`` raise ValueError.
        oldest_request = min(recent, default=current_time)
        reset_at = oldest_request + self.window

        return False, {
            "allowed": False,
            "requests_remaining": 0,
            "limit": limit,
            "window_seconds": self.window,
            "reset_at": reset_at,
            "retry_after": reset_at - current_time
        }

    def reset_client(self, client_id: str):
        """Reset rate limit for a specific client (no-op if unknown)."""
        if client_id in self.requests:
            del self.requests[client_id]
            logger.info("Rate limit reset for client: %s", client_id)

    def _purge_stale_clients(self, current_time: float) -> None:
        """Drop clients that have no timestamp inside the current window."""
        stale = [
            client_id
            for client_id, timestamps in self.requests.items()
            if not any(current_time - ts < self.window for ts in timestamps)
        ]
        # Deleted in a second pass: a dict must not change size while it is
        # being iterated.
        for client_id in stale:
            del self.requests[client_id]

    def get_stats(self) -> Dict:
        """
        Get rate limiter statistics

        As a side effect, clients without any request inside the current
        window are removed from ``self.requests`` so the store does not grow
        without bound.

        Returns:
            Dict with "active_clients" (clients with at least one recent
            request), "total_recent_requests", "window_seconds" and
            "limits" (the live ``self.limits`` dict, not a copy).
        """
        current_time = time.time()
        self._purge_stale_clients(current_time)

        active_clients = 0
        total_requests = 0

        for timestamps in self.requests.values():
            # Count only recent requests
            recent_count = sum(
                1 for ts in timestamps
                if current_time - ts < self.window
            )
            if recent_count:
                active_clients += 1
                total_requests += recent_count

        return {
            "active_clients": active_clients,
            "total_recent_requests": total_requests,
            "window_seconds": self.window,
            "limits": self.limits
        }


# Module-level limiter instance, created once at import time.
rate_limiter = SimpleRateLimiter()


# Names exported by a star-import of this module.
__all__ = [
    "SimpleRateLimiter",
    "rate_limiter",
]