From 2e04a17b5b459b32e05f493832486562ccbe29e8 Mon Sep 17 00:00:00 2001 From: Ulrich Diedrichsen Date: Fri, 30 Jan 2026 10:51:44 +0100 Subject: [PATCH] test(security): add comprehensive unit tests for Phase 1 --- src/security/intrusion-detector.test.ts | 404 +++++++++++++++++++ src/security/ip-manager.test.ts | 408 +++++++++++++++++++ src/security/rate-limiter.test.ts | 298 ++++++++++++++ src/security/shield.test.ts | 507 ++++++++++++++++++++++++ src/security/token-bucket.test.ts | 157 ++++++++ 5 files changed, 1774 insertions(+) create mode 100644 src/security/intrusion-detector.test.ts create mode 100644 src/security/ip-manager.test.ts create mode 100644 src/security/rate-limiter.test.ts create mode 100644 src/security/shield.test.ts create mode 100644 src/security/token-bucket.test.ts diff --git a/src/security/intrusion-detector.test.ts b/src/security/intrusion-detector.test.ts new file mode 100644 index 000000000..30fd1ce93 --- /dev/null +++ b/src/security/intrusion-detector.test.ts @@ -0,0 +1,404 @@ +import { describe, expect, it, beforeEach, vi, afterEach } from "vitest"; +import { IntrusionDetector } from "./intrusion-detector.js"; +import { SecurityActions, AttackPatterns, type SecurityEvent } from "./events/schema.js"; +import { ipManager } from "./ip-manager.js"; + +vi.mock("./ip-manager.js", () => ({ + ipManager: { + blockIp: vi.fn(), + }, +})); + +describe("IntrusionDetector", () => { + let detector: IntrusionDetector; + + beforeEach(() => { + vi.clearAllMocks(); + detector = new IntrusionDetector({ + enabled: true, + patterns: { + bruteForce: { threshold: 10, windowMs: 600_000 }, + ssrfBypass: { threshold: 3, windowMs: 300_000 }, + pathTraversal: { threshold: 5, windowMs: 300_000 }, + portScanning: { threshold: 20, windowMs: 10_000 }, + }, + anomalyDetection: { + enabled: false, + learningPeriodMs: 86_400_000, + sensitivityScore: 0.95, + }, + }); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + const createTestEvent = (action: string): SecurityEvent => ({ + timestamp: new Date().toISOString(), + eventId: `event-${Math.random()}`, + severity: "warn", + category: "authentication", + ip: "192.168.1.100", + action, + resource: "test_resource", + outcome: "deny", + details: {}, + }); + + describe("checkBruteForce", () => { + it("should detect brute force after threshold", () => { + const ip = "192.168.1.100"; + + // Submit 9 failed auth attempts (below threshold) + for (let i = 0; i < 9; i++) { + const result = detector.checkBruteForce({ + ip, + event: createTestEvent(SecurityActions.AUTH_FAILED), + }); + expect(result.detected).toBe(false); + } + + // 10th attempt should trigger detection + const result = detector.checkBruteForce({ + ip, + event: createTestEvent(SecurityActions.AUTH_FAILED), + }); + + expect(result.detected).toBe(true); + expect(result.pattern).toBe(AttackPatterns.BRUTE_FORCE); + expect(result.count).toBe(10); + expect(result.threshold).toBe(10); + expect(ipManager.blockIp).toHaveBeenCalledWith({ + ip, + reason: AttackPatterns.BRUTE_FORCE, + durationMs: 86_400_000, + source: "auto", + }); + }); + + it("should track different IPs independently", () => { + const ip1 = "192.168.1.1"; + const ip2 = "192.168.1.2"; + + // IP1: 5 attempts + for (let i = 0; i < 5; i++) { + detector.checkBruteForce({ + ip: ip1, + event: createTestEvent(SecurityActions.AUTH_FAILED), + }); + } + + // IP2: 5 attempts + for (let i = 0; i < 5; i++) { + detector.checkBruteForce({ + ip: ip2, + event: createTestEvent(SecurityActions.AUTH_FAILED), + }); + } + + // Neither should trigger (both under threshold) + const result1 = detector.checkBruteForce({ + ip: ip1, + event: createTestEvent(SecurityActions.AUTH_FAILED), + }); + const result2 = detector.checkBruteForce({ + ip: ip2, + event: createTestEvent(SecurityActions.AUTH_FAILED), + }); + + expect(result1.detected).toBe(false); + expect(result2.detected).toBe(false); + }); + + it("should not detect when disabled", () => { + const disabledDetector = new IntrusionDetector({ enabled: false }); + const ip = "192.168.1.100"; + + // Submit 20 attempts (well over threshold) + for (let i = 0; i < 20; i++) { + const result = disabledDetector.checkBruteForce({ + ip, + event: createTestEvent(SecurityActions.AUTH_FAILED), + }); + expect(result.detected).toBe(false); + } + + expect(ipManager.blockIp).not.toHaveBeenCalled(); + }); + }); + + describe("checkSsrfBypass", () => { + it("should detect SSRF bypass after threshold", () => { + const ip = "192.168.1.100"; + + // Submit 2 SSRF attempts (below threshold) + for (let i = 0; i < 2; i++) { + const result = detector.checkSsrfBypass({ + ip, + event: createTestEvent(SecurityActions.SSRF_BYPASS_ATTEMPT), + }); + expect(result.detected).toBe(false); + } + + // 3rd attempt should trigger detection + const result = detector.checkSsrfBypass({ + ip, + event: createTestEvent(SecurityActions.SSRF_BYPASS_ATTEMPT), + }); + + expect(result.detected).toBe(true); + expect(result.pattern).toBe(AttackPatterns.SSRF_BYPASS); + expect(result.count).toBe(3); + expect(ipManager.blockIp).toHaveBeenCalledWith({ + ip, + reason: AttackPatterns.SSRF_BYPASS, + durationMs: 86_400_000, + source: "auto", + }); + }); + + it("should handle lower threshold than brute force", () => { + const ip = "192.168.1.100"; + + // SSRF has lower threshold (3) than brute force (10) + for (let i = 0; i < 3; i++) { + detector.checkSsrfBypass({ + ip, + event: createTestEvent(SecurityActions.SSRF_BYPASS_ATTEMPT), + }); + } + + // Should detect with fewer attempts + expect(ipManager.blockIp).toHaveBeenCalled(); + }); + }); + + describe("checkPathTraversal", () => { + it("should detect path traversal after threshold", () => { + const ip = "192.168.1.100"; + + // Submit 4 attempts (below threshold) + for (let i = 0; i < 4; i++) { + const result = detector.checkPathTraversal({ + ip, + event: createTestEvent(SecurityActions.PATH_TRAVERSAL_ATTEMPT), + }); + expect(result.detected).toBe(false); + } + + // 5th attempt should trigger detection + const result = detector.checkPathTraversal({ + ip, + event: createTestEvent(SecurityActions.PATH_TRAVERSAL_ATTEMPT), + }); + + expect(result.detected).toBe(true); + expect(result.pattern).toBe(AttackPatterns.PATH_TRAVERSAL); + expect(result.count).toBe(5); + expect(ipManager.blockIp).toHaveBeenCalledWith({ + ip, + reason: AttackPatterns.PATH_TRAVERSAL, + durationMs: 86_400_000, + source: "auto", + }); + }); + }); + + describe("checkPortScanning", () => { + it("should detect port scanning after threshold", () => { + const ip = "192.168.1.100"; + + // Submit 19 connection attempts (below threshold) + for (let i = 0; i < 19; i++) { + const result = detector.checkPortScanning({ + ip, + event: createTestEvent(SecurityActions.CONNECTION_LIMIT_EXCEEDED), + }); + expect(result.detected).toBe(false); + } + + // 20th attempt should trigger detection + const result = detector.checkPortScanning({ + ip, + event: createTestEvent(SecurityActions.CONNECTION_LIMIT_EXCEEDED), + }); + + expect(result.detected).toBe(true); + expect(result.pattern).toBe(AttackPatterns.PORT_SCANNING); + expect(result.count).toBe(20); + expect(ipManager.blockIp).toHaveBeenCalledWith({ + ip, + reason: AttackPatterns.PORT_SCANNING, + durationMs: 86_400_000, + source: "auto", + }); + }); + + it("should handle rapid connection attempts", () => { + const ip = "192.168.1.100"; + + // Rapid-fire 25 connection attempts + for (let i = 0; i < 25; i++) { + detector.checkPortScanning({ + ip, + event: createTestEvent(SecurityActions.CONNECTION_LIMIT_EXCEEDED), + }); + } + + // Should auto-block + expect(ipManager.blockIp).toHaveBeenCalled(); + }); + }); + + describe("time window behavior", () => { + it("should reset detection after time window", () => { + vi.useFakeTimers(); + const ip = "192.168.1.100"; + + // Submit 9 attempts + for (let i = 0; i < 9; i++) { + detector.checkBruteForce({ + ip, + event: createTestEvent(SecurityActions.AUTH_FAILED), + }); + } + + // Advance past window (10 minutes) + vi.advanceTimersByTime(601_000); + + // Submit 9 more attempts (should not trigger, old attempts expired) + for (let i = 0; i < 9; i++) { + const result = detector.checkBruteForce({ + ip, + event: createTestEvent(SecurityActions.AUTH_FAILED), + }); + expect(result.detected).toBe(false); + } + + vi.useRealTimers(); + }); + }); + + describe("custom configuration", () => { + it("should respect custom thresholds", () => { + const customDetector = new IntrusionDetector({ + enabled: true, + patterns: { + bruteForce: { threshold: 3, windowMs: 60_000 }, + ssrfBypass: { threshold: 1, windowMs: 60_000 }, + pathTraversal: { threshold: 2, windowMs: 60_000 }, + portScanning: { threshold: 5, windowMs: 10_000 }, + }, + anomalyDetection: { + enabled: false, + learningPeriodMs: 86_400_000, + sensitivityScore: 0.95, + }, + }); + + const ip = "192.168.1.100"; + + // Should trigger with custom threshold (3) + for (let i = 0; i < 3; i++) { + customDetector.checkBruteForce({ + ip, + event: createTestEvent(SecurityActions.AUTH_FAILED), + }); + } + + expect(ipManager.blockIp).toHaveBeenCalled(); + }); + + it("should respect custom time windows", () => { + vi.useFakeTimers(); + + const customDetector = new IntrusionDetector({ + enabled: true, + patterns: { + bruteForce: { threshold: 5, windowMs: 10_000 }, // 10 seconds + ssrfBypass: { threshold: 3, windowMs: 300_000 }, + pathTraversal: { threshold: 5, windowMs: 300_000 }, + portScanning: { threshold: 20, windowMs: 10_000 }, + }, + anomalyDetection: { + enabled: false, + learningPeriodMs: 86_400_000, + sensitivityScore: 0.95, + }, + }); + + const ip = "192.168.1.100"; + + // Submit 4 attempts + for (let i = 0; i < 4; i++) { + customDetector.checkBruteForce({ + ip, + event: createTestEvent(SecurityActions.AUTH_FAILED), + }); + } + + // Advance past short window + vi.advanceTimersByTime(11_000); + + // Submit 4 more attempts (should not trigger, old attempts expired) + for (let i = 0; i < 4; i++) { + const result = customDetector.checkBruteForce({ + ip, + event: createTestEvent(SecurityActions.AUTH_FAILED), + }); + expect(result.detected).toBe(false); + } + + vi.useRealTimers(); + }); + }); + + describe("integration scenarios", () => { + it("should detect multiple attack patterns from same IP", () => { + const ip = "192.168.1.100"; + + // Trigger brute force + for (let i = 0; i < 10; i++) { + detector.checkBruteForce({ + ip, + event: createTestEvent(SecurityActions.AUTH_FAILED), + }); + } + + // Trigger SSRF bypass + for (let i = 0; i < 3; i++) { + detector.checkSsrfBypass({ + ip, + event: createTestEvent(SecurityActions.SSRF_BYPASS_ATTEMPT), + }); + } + + // Should auto-block for both patterns + expect(ipManager.blockIp).toHaveBeenCalledTimes(2); + expect(ipManager.blockIp).toHaveBeenCalledWith( + expect.objectContaining({ reason: AttackPatterns.BRUTE_FORCE }), + ); + expect(ipManager.blockIp).toHaveBeenCalledWith( + expect.objectContaining({ reason: AttackPatterns.SSRF_BYPASS }), + ); + }); + + it("should handle coordinated attack from multiple IPs", () => { + // Simulate distributed brute force attack + const ips = ["192.168.1.1", "192.168.1.2", "192.168.1.3"]; + + ips.forEach((ip) => { + for (let i = 0; i < 10; i++) { + detector.checkBruteForce({ + ip, + event: createTestEvent(SecurityActions.AUTH_FAILED), + }); + } + }); + + // Should block all attacking IPs + expect(ipManager.blockIp).toHaveBeenCalledTimes(3); + }); + }); +}); diff --git a/src/security/ip-manager.test.ts b/src/security/ip-manager.test.ts new file mode 100644 index 000000000..618722f32 --- /dev/null +++ b/src/security/ip-manager.test.ts @@ -0,0 +1,408 @@ +import { describe, expect, it, beforeEach, vi, afterEach } from "vitest"; +import { IpManager } from "./ip-manager.js"; +import fs from "node:fs"; +import os from "node:os"; +import path from "node:path"; + +vi.mock("node:fs", () => ({ + default: { + promises: { + mkdir: vi.fn().mockResolvedValue(undefined), + writeFile: vi.fn().mockResolvedValue(undefined), + readFile: vi.fn().mockResolvedValue("{}"), + unlink: vi.fn().mockResolvedValue(undefined), + }, + }, +})); + +describe("IpManager", () => { + let manager: IpManager; + + beforeEach(() => { + vi.clearAllMocks(); + manager = new IpManager(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("blockIp", () => { + it("should block an IP address", () => { + manager.blockIp({ + ip: "192.168.1.100", + reason: "brute_force", + durationMs: 86400000, + }); + + const blockReason = manager.isBlocked("192.168.1.100"); + expect(blockReason).toBe("brute_force"); + }); + + it("should block with auto source by default", () => { + manager.blockIp({ + ip: "192.168.1.100", + reason: "test", + durationMs: 86400000, + }); + + expect(manager.isBlocked("192.168.1.100")).toBe("test"); + }); + + it("should block with manual source", () => { + manager.blockIp({ + ip: "192.168.1.100", + reason: "manual_block", + durationMs: 86400000, + source: "manual", + }); + + expect(manager.isBlocked("192.168.1.100")).toBe("manual_block"); + }); + + it("should handle IPv6 addresses", () => { + manager.blockIp({ + ip: "2001:db8::1", + reason: "test", + durationMs: 86400000, + }); + + expect(manager.isBlocked("2001:db8::1")).toBe("test"); + }); + + it("should update existing block", () => { + manager.blockIp({ + ip: "192.168.1.100", + reason: "first_reason", + durationMs: 86400000, + }); + + manager.blockIp({ + ip: "192.168.1.100", + reason: "second_reason", + durationMs: 172800000, + }); + + expect(manager.isBlocked("192.168.1.100")).toBe("second_reason"); + }); + }); + + describe("unblockIp", () => { + it("should unblock a blocked IP", () => { + manager.blockIp({ + ip: "192.168.1.100", + reason: "test", + durationMs: 86400000, + }); + + expect(manager.isBlocked("192.168.1.100")).toBe("test"); + + manager.unblockIp("192.168.1.100"); + + expect(manager.isBlocked("192.168.1.100")).toBeNull(); + }); + + it("should handle unblocking non-existent IP", () => { + expect(() => manager.unblockIp("192.168.1.100")).not.toThrow(); + }); + }); + + describe("allowIp", () => { + it("should add IP to allowlist", () => { + manager.allowIp({ + ip: "192.168.1.200", + reason: "trusted", + }); + + expect(manager.isAllowed("192.168.1.200")).toBe(true); + }); + + it("should add CIDR range to allowlist", () => { + manager.allowIp({ + ip: "10.0.0.0/8", + reason: "internal_network", + }); + + expect(manager.isAllowed("10.5.10.20")).toBe(true); + expect(manager.isAllowed("11.0.0.1")).toBe(false); + }); + + it("should handle Tailscale CGNAT range", () => { + manager.allowIp({ + ip: "100.64.0.0/10", + reason: "tailscale", + }); + + expect(manager.isAllowed("100.64.0.1")).toBe(true); + expect(manager.isAllowed("100.127.255.254")).toBe(true); + expect(manager.isAllowed("100.128.0.1")).toBe(false); + }); + }); + + describe("removeFromAllowlist", () => { + it("should remove IP from allowlist", () => { + manager.allowIp({ + ip: "192.168.1.200", + reason: "trusted", + }); + + expect(manager.isAllowed("192.168.1.200")).toBe(true); + + manager.removeFromAllowlist("192.168.1.200"); + + expect(manager.isAllowed("192.168.1.200")).toBe(false); + }); + + it("should remove CIDR range from allowlist", () => { + manager.allowIp({ + ip: "10.0.0.0/8", + reason: "internal", + }); + + manager.removeFromAllowlist("10.0.0.0/8"); + + expect(manager.isAllowed("10.5.10.20")).toBe(false); + }); + }); + + describe("isBlocked", () => { + it("should return null for non-blocked IP", () => { + expect(manager.isBlocked("192.168.1.100")).toBeNull(); + }); + + it("should return block reason for blocked IP", () => { + manager.blockIp({ + ip: "192.168.1.100", + reason: "brute_force", + durationMs: 86400000, + }); + + expect(manager.isBlocked("192.168.1.100")).toBe("brute_force"); + }); + + it("should return null for expired blocks", () => { + vi.useFakeTimers(); + const now = Date.now(); + vi.setSystemTime(now); + + manager.blockIp({ + ip: "192.168.1.100", + reason: "test", + durationMs: 60000, // 1 minute + }); + + expect(manager.isBlocked("192.168.1.100")).toBe("test"); + + // Advance past expiration + vi.advanceTimersByTime(61000); + + expect(manager.isBlocked("192.168.1.100")).toBeNull(); + + vi.useRealTimers(); + }); + + it("should prioritize allowlist over blocklist", () => { + manager.blockIp({ + ip: "192.168.1.100", + reason: "test", + durationMs: 86400000, + }); + + manager.allowIp({ + ip: "192.168.1.100", + reason: "override", + }); + + expect(manager.isBlocked("192.168.1.100")).toBeNull(); + }); + }); + + describe("isAllowed", () => { + it("should return false for non-allowlisted IP", () => { + expect(manager.isAllowed("192.168.1.100")).toBe(false); + }); + + it("should return true for allowlisted IP", () => { + manager.allowIp({ + ip: "192.168.1.100", + reason: "trusted", + }); + + expect(manager.isAllowed("192.168.1.100")).toBe(true); + }); + + it("should match IP in CIDR range", () => { + manager.allowIp({ + ip: "192.168.0.0/16", + reason: "local_network", + }); + + expect(manager.isAllowed("192.168.1.100")).toBe(true); + expect(manager.isAllowed("192.168.255.255")).toBe(true); + expect(manager.isAllowed("192.169.0.1")).toBe(false); + }); + + it("should match localhost variations", () => { + manager.allowIp({ + ip: "127.0.0.0/8", + reason: "localhost", + }); + + expect(manager.isAllowed("127.0.0.1")).toBe(true); + expect(manager.isAllowed("127.0.0.2")).toBe(true); + expect(manager.isAllowed("127.255.255.255")).toBe(true); + }); + }); + + describe("getBlocklist", () => { + it("should return all blocked IPs", () => { + manager.blockIp({ + ip: "192.168.1.1", + reason: "test1", + durationMs: 86400000, + }); + + manager.blockIp({ + ip: "192.168.1.2", + reason: "test2", + durationMs: 86400000, + }); + + const blocklist = manager.getBlocklist(); + expect(blocklist).toHaveLength(2); + expect(blocklist.map((b) => b.ip)).toContain("192.168.1.1"); + expect(blocklist.map((b) => b.ip)).toContain("192.168.1.2"); + }); + + it("should include expiration timestamps", () => { + const now = new Date(); + manager.blockIp({ + ip: "192.168.1.1", + reason: "test", + durationMs: 86400000, + }); + + const blocklist = manager.getBlocklist(); + expect(blocklist[0]?.expiresAt).toBeDefined(); + expect(new Date(blocklist[0]!.expiresAt).getTime()).toBeGreaterThan(now.getTime()); + }); + }); + + describe("getAllowlist", () => { + it("should return all allowed IPs", () => { + manager.allowIp({ + ip: "192.168.1.100", + reason: "trusted1", + }); + + manager.allowIp({ + ip: "10.0.0.0/8", + reason: "trusted2", + }); + + const allowlist = manager.getAllowlist(); + expect(allowlist).toHaveLength(2); + expect(allowlist.map((a) => a.ip)).toContain("192.168.1.100"); + expect(allowlist.map((a) => a.ip)).toContain("10.0.0.0/8"); + }); + }); + + describe("CIDR matching", () => { + it("should match /24 network", () => { + manager.allowIp({ + ip: "192.168.1.0/24", + reason: "test", + }); + + expect(manager.isAllowed("192.168.1.0")).toBe(true); + expect(manager.isAllowed("192.168.1.100")).toBe(true); + expect(manager.isAllowed("192.168.1.255")).toBe(true); + expect(manager.isAllowed("192.168.2.1")).toBe(false); + }); + + it("should match /16 network", () => { + manager.allowIp({ + ip: "10.20.0.0/16", + reason: "test", + }); + + expect(manager.isAllowed("10.20.0.1")).toBe(true); + expect(manager.isAllowed("10.20.255.254")).toBe(true); + expect(manager.isAllowed("10.21.0.1")).toBe(false); + }); + + it("should match /8 network", () => { + manager.allowIp({ + ip: "172.0.0.0/8", + reason: "test", + }); + + expect(manager.isAllowed("172.16.0.1")).toBe(true); + expect(manager.isAllowed("172.255.255.254")).toBe(true); + expect(manager.isAllowed("173.0.0.1")).toBe(false); + }); + + it("should handle /32 single IP", () => { + manager.allowIp({ + ip: "192.168.1.100/32", + reason: "test", + }); + + expect(manager.isAllowed("192.168.1.100")).toBe(true); + expect(manager.isAllowed("192.168.1.101")).toBe(false); + }); + }); + + describe("integration scenarios", () => { + it("should handle mixed blocklist and allowlist", () => { + // Block entire subnet + manager.blockIp({ + ip: "192.168.1.0/24", + reason: "suspicious_network", + durationMs: 86400000, + }); + + // Allow specific IP from that subnet + manager.allowIp({ + ip: "192.168.1.100", + reason: "known_good", + }); + + // Blocked IP from subnet + expect(manager.isBlocked("192.168.1.50")).toBe("suspicious_network"); + + // Allowed IP overrides block + expect(manager.isBlocked("192.168.1.100")).toBeNull(); + }); + + it("should handle automatic cleanup of expired blocks", () => { + vi.useFakeTimers(); + + manager.blockIp({ + ip: "192.168.1.1", + reason: "short_block", + durationMs: 60000, + }); + + manager.blockIp({ + ip: "192.168.1.2", + reason: "long_block", + durationMs: 86400000, + }); + + // Both blocked initially + expect(manager.isBlocked("192.168.1.1")).toBe("short_block"); + expect(manager.isBlocked("192.168.1.2")).toBe("long_block"); + + // Advance past short block expiration + vi.advanceTimersByTime(61000); + + // Short block expired + expect(manager.isBlocked("192.168.1.1")).toBeNull(); + // Long block still active + expect(manager.isBlocked("192.168.1.2")).toBe("long_block"); + + vi.useRealTimers(); + }); + }); +}); diff --git a/src/security/rate-limiter.test.ts b/src/security/rate-limiter.test.ts new file mode 100644 index 000000000..a7b031647 --- /dev/null +++ b/src/security/rate-limiter.test.ts @@ -0,0 +1,298 @@ +import { describe, expect, it, beforeEach, vi, afterEach } from "vitest"; +import { RateLimiter, RateLimitKeys } from "./rate-limiter.js"; + +describe("RateLimiter", () => { + let limiter: RateLimiter; + + beforeEach(() => { + vi.useFakeTimers(); + limiter = new RateLimiter(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + limiter.resetAll(); + }); + + describe("check", () => { + it("should allow requests within rate limit", () => { + const limit = { max: 5, windowMs: 60000 }; + const key = "test:key"; + + for (let i = 0; i < 5; i++) { + const result = limiter.check(key, limit); + expect(result.allowed).toBe(true); + expect(result.remaining).toBeGreaterThanOrEqual(0); + } + }); + + it("should deny requests exceeding rate limit", () => { + const limit = { max: 3, windowMs: 60000 }; + const key = "test:key"; + + // Consume all tokens + limiter.check(key, limit); + limiter.check(key, limit); + limiter.check(key, limit); + + // Should be rate limited + const result = limiter.check(key, limit); + expect(result.allowed).toBe(false); + expect(result.retryAfterMs).toBeGreaterThan(0); + }); + + it("should track separate keys independently", () => { + const limit = { max: 2, windowMs: 60000 }; + + const result1 = limiter.check("key1", limit); + const result2 = limiter.check("key2", limit); + + expect(result1.allowed).toBe(true); + expect(result2.allowed).toBe(true); + }); + + it("should refill tokens after time window", () => { + const limit = { max: 5, windowMs: 10000 }; // 5 requests per 10 seconds + const key = "test:key"; + + // Consume all tokens + for (let i = 0; i < 5; i++) { + limiter.check(key, limit); + } + + // Should be rate limited + expect(limiter.check(key, limit).allowed).toBe(false); + + // Advance time to allow refill + vi.advanceTimersByTime(10000); + + // Should allow requests again + const result = limiter.check(key, limit); + expect(result.allowed).toBe(true); + }); + + it("should provide resetAt timestamp", () => { + const limit = { max: 5, windowMs: 60000 }; + const key = "test:key"; + + const now = Date.now(); + const result = limiter.check(key, limit); + + expect(result.resetAt).toBeInstanceOf(Date); + expect(result.resetAt.getTime()).toBeGreaterThanOrEqual(now); + }); + }); + + describe("peek", () => { + it("should check limit without consuming tokens", () => { + const limit = { max: 5, windowMs: 60000 }; + const key = "test:key"; + + // Peek multiple times + const result1 = limiter.peek(key, limit); + const result2 = limiter.peek(key, limit); + const result3 = limiter.peek(key, limit); + + expect(result1.allowed).toBe(true); + expect(result2.allowed).toBe(true); + expect(result3.allowed).toBe(true); + expect(result1.remaining).toBe(result2.remaining); + expect(result2.remaining).toBe(result3.remaining); + }); + + it("should reflect consumed tokens from check", () => { + const limit = { max: 5, windowMs: 60000 }; + const key = "test:key"; + + limiter.check(key, limit); // Consume 1 + limiter.check(key, limit); // Consume 1 + + const result = limiter.peek(key, limit); + expect(result.remaining).toBe(3); + }); + }); + + describe("reset", () => { + it("should reset specific key", () => { + const limit = { max: 3, windowMs: 60000 }; + const key = "test:key"; + + // Consume all tokens + limiter.check(key, limit); + limiter.check(key, limit); + limiter.check(key, limit); + + expect(limiter.check(key, limit).allowed).toBe(false); + + // Reset + limiter.reset(key); + + // Should allow requests again + const result = limiter.check(key, limit); + expect(result.allowed).toBe(true); + }); + + it("should not affect other keys", () => { + const limit = { max: 2, windowMs: 60000 }; + + limiter.check("key1", limit); + limiter.check("key2", limit); + + limiter.reset("key1"); + + // key1 should be reset + expect(limiter.peek("key1", limit).remaining).toBe(2); + // key2 should still have consumed token + expect(limiter.peek("key2", limit).remaining).toBe(1); + }); + }); + + describe("resetAll", () => { + it("should reset all keys", () => { + const limit = { max: 3, windowMs: 60000 }; + + limiter.check("key1", limit); + limiter.check("key2", limit); + limiter.check("key3", limit); + + limiter.resetAll(); + + expect(limiter.peek("key1", limit).remaining).toBe(3); + expect(limiter.peek("key2", limit).remaining).toBe(3); + expect(limiter.peek("key3", limit).remaining).toBe(3); + }); + }); + + describe("LRU cache behavior", () => { + it("should evict least recently used entries when cache is full", () => { + const smallLimiter = new RateLimiter({ maxSize: 3 }); + const limit = { max: 5, windowMs: 60000 }; + + // Add 3 entries + smallLimiter.check("key1", limit); + smallLimiter.check("key2", limit); + smallLimiter.check("key3", limit); + + // Add 4th entry, should evict key1 + smallLimiter.check("key4", limit); + + // key1 should be evicted (fresh entry) + expect(smallLimiter.peek("key1", limit).remaining).toBe(5); + // key2, key3, key4 should have consumed tokens + expect(smallLimiter.peek("key2", limit).remaining).toBe(4); + expect(smallLimiter.peek("key3", limit).remaining).toBe(4); + expect(smallLimiter.peek("key4", limit).remaining).toBe(4); + }); + }); + + describe("cleanup", () => { + it("should clean up stale entries", () => { + const limit = { max: 5, windowMs: 10000 }; + const key = "test:key"; + + limiter.check(key, limit); + + // Advance past cleanup interval + TTL + vi.advanceTimersByTime(180000); // 3 minutes (cleanup runs every 60s, TTL is 2min) + + // Trigger cleanup by checking + limiter.check("trigger:cleanup", limit); + + // Original entry should be cleaned up (fresh entry) + expect(limiter.peek(key, limit).remaining).toBe(5); + }); + }); + + describe("RateLimitKeys", () => { + it("should generate unique keys for auth attempts", () => { + const key1 = RateLimitKeys.authAttempt("192.168.1.1"); + const key2 = RateLimitKeys.authAttempt("192.168.1.2"); + + expect(key1).toBe("auth:192.168.1.1"); + expect(key2).toBe("auth:192.168.1.2"); + expect(key1).not.toBe(key2); + }); + + it("should generate unique keys for device auth attempts", () => { + const key1 = RateLimitKeys.authAttemptDevice("device-123"); + const key2 = RateLimitKeys.authAttemptDevice("device-456"); + + expect(key1).toBe("auth:device:device-123"); + expect(key2).toBe("auth:device:device-456"); + expect(key1).not.toBe(key2); + }); + + it("should generate unique keys for connections", () => { + const key = RateLimitKeys.connection("192.168.1.1"); + expect(key).toBe("conn:192.168.1.1"); + }); + + it("should generate unique keys for requests", () => { + const key = RateLimitKeys.request("192.168.1.1"); + expect(key).toBe("req:192.168.1.1"); + }); + + it("should generate unique keys for pairing requests", () => { + const key = RateLimitKeys.pairingRequest("telegram", "user123"); + expect(key).toBe("pair:telegram:user123"); + }); + + it("should generate unique keys for webhook tokens", () => { + const key = RateLimitKeys.webhookToken("token-abc"); + expect(key).toBe("hook:token:token-abc"); + }); + + it("should generate unique keys for webhook paths", () => { + const key = RateLimitKeys.webhookPath("/api/webhook"); + expect(key).toBe("hook:path:/api/webhook"); + }); + }); + + describe("integration scenarios", () => { + it("should handle burst traffic pattern", () => { + const limit = { max: 10, windowMs: 60000 }; + const key = "burst:test"; + + // Burst of 10 requests + for (let i = 0; i < 10; i++) { + expect(limiter.check(key, limit).allowed).toBe(true); + } + + // 11th request should be rate limited + expect(limiter.check(key, limit).allowed).toBe(false); + }); + + it("should handle sustained traffic under limit", () => { + const limit = { max: 100, windowMs: 60000 }; // 100 req/min + const key = "sustained:test"; + + // 50 requests should all pass + for (let i = 0; i < 50; i++) { + expect(limiter.check(key, limit).allowed).toBe(true); + } + + const result = limiter.peek(key, limit); + expect(result.remaining).toBe(50); + }); + + it("should handle multiple IPs with different patterns", () => { + const limit = { max: 5, windowMs: 60000 }; + + // IP1: consume 3 tokens + for (let i = 0; i < 3; i++) { + limiter.check(RateLimitKeys.authAttempt("192.168.1.1"), limit); + } + + // IP2: consume 5 tokens (rate limited) + for (let i = 0; i < 5; i++) { + limiter.check(RateLimitKeys.authAttempt("192.168.1.2"), limit); + } + + // IP1 should still have capacity + expect(limiter.check(RateLimitKeys.authAttempt("192.168.1.1"), limit).allowed).toBe(true); + + // IP2 should be rate limited + expect(limiter.check(RateLimitKeys.authAttempt("192.168.1.2"), limit).allowed).toBe(false); + }); + }); +}); diff --git a/src/security/shield.test.ts b/src/security/shield.test.ts new file mode 100644 index 000000000..fa984320a --- /dev/null +++ b/src/security/shield.test.ts @@ -0,0 +1,507 @@ +import { describe, expect, it, beforeEach, vi, afterEach } from "vitest"; +import { SecurityShield, type SecurityContext } from "./shield.js"; +import { rateLimiter } from "./rate-limiter.js"; +import { ipManager } from "./ip-manager.js"; +import type { IncomingMessage } from "node:http"; + +vi.mock("./rate-limiter.js", () => ({ + rateLimiter: { + check: vi.fn(), + }, + RateLimitKeys: { + authAttempt: (ip: string) => `auth:${ip}`, + authAttemptDevice: (deviceId: string) => `auth:device:${deviceId}`, + connection: (ip: string) => `conn:${ip}`, + request: (ip: string) => `req:${ip}`, + pairingRequest: (channel: string, sender: string) => `pair:${channel}:${sender}`, + webhookToken: (token: string) => `hook:token:${token}`, + webhookPath: (path: string) => `hook:path:${path}`, + }, +})); + +vi.mock("./ip-manager.js", () => ({ + ipManager: { + isBlocked: vi.fn(), + }, +})); + +describe("SecurityShield", () => { + let shield: SecurityShield; + + beforeEach(() => { + vi.clearAllMocks(); + shield = new SecurityShield({ + enabled: true, + rateLimiting: { + enabled: true, + perIp: { + authAttempts: { max: 5, windowMs: 300_000 }, + connections: { max: 10, windowMs: 60_000 }, + requests: { max: 100, windowMs: 60_000 }, + }, + perDevice: { + authAttempts: { max: 10, windowMs: 900_000 }, + }, + perSender: { + pairingRequests: { max: 3, windowMs: 3_600_000 }, + }, + webhook: { + perToken: { max: 200, windowMs: 60_000 }, + perPath: { max: 50, windowMs: 60_000 }, + }, + }, + intrusionDetection: { + enabled: true, + patterns: { + bruteForce: { threshold: 10, windowMs: 600_000 }, + ssrfBypass: { threshold: 3, windowMs: 300_000 }, + pathTraversal: { threshold: 5, windowMs: 300_000 }, + portScanning: { threshold: 20, windowMs: 10_000 }, + }, + anomalyDetection: { + enabled: false, + learningPeriodMs: 86_400_000, + sensitivityScore: 0.95, + }, + }, + ipManagement: { + autoBlock: { + enabled: true, + durationMs: 86_400_000, + }, + allowlist: ["100.64.0.0/10"], + firewall: { + enabled: true, + backend: "iptables", + }, + }, + }); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + const createContext = (ip: string, deviceId?: string): SecurityContext => ({ + ip, + deviceId, + userAgent: "test-agent", + requestId: "test-request-id", + }); + + describe("isEnabled", () => { + it("should return true when enabled", () => { + expect(shield.isEnabled()).toBe(true); + }); + + it("should return false when disabled", () => { + const disabledShield = new SecurityShield({ enabled: false }); + expect(disabledShield.isEnabled()).toBe(false); + }); + }); + + describe("isIpBlocked", () => { + it("should return true for blocked IP", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue("test_reason"); + expect(shield.isIpBlocked("192.168.1.100")).toBe(true); + }); + + it("should return false for non-blocked IP", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue(null); + expect(shield.isIpBlocked("192.168.1.100")).toBe(false); + }); + + it("should return false when shield disabled", () => { + const disabledShield = new SecurityShield({ enabled: false }); + expect(disabledShield.isIpBlocked("192.168.1.100")).toBe(false); + }); + }); + + describe("checkAuthAttempt", () => { + it("should allow auth when under rate limit", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue(null); + vi.mocked(rateLimiter.check).mockReturnValue({ + allowed: true, + remaining: 4, + resetAt: new Date(), + }); + + const result = shield.checkAuthAttempt(createContext("192.168.1.100")); + + expect(result.allowed).toBe(true); + expect(result.reason).toBeUndefined(); + }); + + it("should deny auth for blocked IP", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue("brute_force"); + + const result = shield.checkAuthAttempt(createContext("192.168.1.100")); + + expect(result.allowed).toBe(false); + expect(result.reason).toBe("IP blocked: brute_force"); + }); + + it("should deny auth when per-IP rate limit exceeded", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue(null); + vi.mocked(rateLimiter.check).mockReturnValue({ + allowed: false, + retryAfterMs: 60000, + remaining: 0, + resetAt: new Date(), + }); + + const result = shield.checkAuthAttempt(createContext("192.168.1.100")); + + expect(result.allowed).toBe(false); + expect(result.reason).toBe("Rate limit exceeded"); + expect(result.rateLimitInfo?.retryAfterMs).toBe(60000); + }); + + it("should deny auth when per-device rate limit exceeded", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue(null); + vi.mocked(rateLimiter.check) + .mockReturnValueOnce({ + allowed: true, + remaining: 4, + resetAt: new Date(), + }) + .mockReturnValueOnce({ + allowed: false, + retryAfterMs: 120000, + remaining: 0, + resetAt: new Date(), + }); + + const result = shield.checkAuthAttempt(createContext("192.168.1.100", "device-123")); + + expect(result.allowed).toBe(false); + expect(result.reason).toBe("Rate limit exceeded (device)"); + }); + + it("should allow auth when shield disabled", () => { + const disabledShield = new SecurityShield({ enabled: false }); + const result = disabledShield.checkAuthAttempt(createContext("192.168.1.100")); + + expect(result.allowed).toBe(true); + expect(ipManager.isBlocked).not.toHaveBeenCalled(); + expect(rateLimiter.check).not.toHaveBeenCalled(); + }); + }); + + describe("checkConnection", () => { + it("should allow connection when under rate limit", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue(null); + vi.mocked(rateLimiter.check).mockReturnValue({ + allowed: true, + remaining: 9, + resetAt: new Date(), + }); + + const result = shield.checkConnection(createContext("192.168.1.100")); + + expect(result.allowed).toBe(true); + }); + + it("should deny connection for blocked IP", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue("port_scanning"); + + const result = shield.checkConnection(createContext("192.168.1.100")); + + expect(result.allowed).toBe(false); + expect(result.reason).toBe("IP blocked: port_scanning"); + }); + + it("should deny connection when rate limit exceeded", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue(null); + vi.mocked(rateLimiter.check).mockReturnValue({ + allowed: false, + retryAfterMs: 30000, + remaining: 0, + resetAt: new Date(), + }); + + const result = shield.checkConnection(createContext("192.168.1.100")); + + expect(result.allowed).toBe(false); + expect(result.reason).toBe("Connection rate limit exceeded"); + }); + }); + + describe("checkRequest", () => { + it("should allow request when under rate limit", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue(null); + vi.mocked(rateLimiter.check).mockReturnValue({ + allowed: true, + remaining: 99, + resetAt: new Date(), + }); + + const result = shield.checkRequest(createContext("192.168.1.100")); + + expect(result.allowed).toBe(true); + }); + + it("should deny request for blocked IP", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue("malicious"); + + const result = shield.checkRequest(createContext("192.168.1.100")); + + expect(result.allowed).toBe(false); + expect(result.reason).toBe("IP blocked: malicious"); + }); + + it("should deny request when rate limit exceeded", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue(null); + vi.mocked(rateLimiter.check).mockReturnValue({ + allowed: false, + retryAfterMs: 10000, + remaining: 0, + resetAt: new Date(), + }); + + const result = shield.checkRequest(createContext("192.168.1.100")); + + expect(result.allowed).toBe(false); + expect(result.reason).toBe("Request rate limit exceeded"); + }); + }); + + describe("checkPairingRequest", () => { + it("should allow pairing when under rate limit", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue(null); + vi.mocked(rateLimiter.check).mockReturnValue({ + allowed: true, + remaining: 2, + resetAt: new Date(), + }); + + const result = shield.checkPairingRequest({ + channel: "telegram", + sender: "user123", + ip: "192.168.1.100", + }); + + expect(result.allowed).toBe(true); + }); + + it("should deny pairing for blocked IP", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue("spam"); + + const result = shield.checkPairingRequest({ + channel: "telegram", + sender: "user123", + ip: "192.168.1.100", + }); + + expect(result.allowed).toBe(false); + expect(result.reason).toBe("IP blocked: spam"); + }); + + it("should deny pairing when rate limit exceeded", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue(null); + vi.mocked(rateLimiter.check).mockReturnValue({ + allowed: false, + retryAfterMs: 3600000, + remaining: 0, + resetAt: new Date(), + }); + + const result = shield.checkPairingRequest({ + channel: "telegram", + sender: "user123", + ip: "192.168.1.100", + }); + + expect(result.allowed).toBe(false); + expect(result.reason).toBe("Pairing rate limit exceeded"); + }); + }); + + describe("checkWebhook", () => { + it("should allow webhook when under rate limit", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue(null); + vi.mocked(rateLimiter.check).mockReturnValue({ + allowed: true, + remaining: 199, + resetAt: new Date(), + }); + + const result = shield.checkWebhook({ + token: "webhook-token", + path: "/api/webhook", + ip: "192.168.1.100", + }); + + expect(result.allowed).toBe(true); + }); + + it("should deny webhook for blocked IP", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue("abuse"); + + const result = shield.checkWebhook({ + token: "webhook-token", + path: "/api/webhook", + ip: "192.168.1.100", + }); + + expect(result.allowed).toBe(false); + expect(result.reason).toBe("IP blocked: abuse"); + }); + + it("should deny webhook when per-token rate limit exceeded", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue(null); + vi.mocked(rateLimiter.check).mockReturnValueOnce({ + allowed: false, + retryAfterMs: 5000, + remaining: 0, + resetAt: new Date(), + }); + + const result = shield.checkWebhook({ + token: "webhook-token", + path: "/api/webhook", + ip: "192.168.1.100", + }); + + expect(result.allowed).toBe(false); + expect(result.reason).toBe("Webhook rate limit exceeded (token)"); + }); + + it("should deny webhook when per-path rate limit exceeded", () => { + vi.mocked(ipManager.isBlocked).mockReturnValue(null); + vi.mocked(rateLimiter.check) + .mockReturnValueOnce({ + allowed: true, + remaining: 199, + resetAt: new Date(), + }) + .mockReturnValueOnce({ + allowed: false, + retryAfterMs: 10000, + remaining: 0, + resetAt: new Date(), + }); + + const result = shield.checkWebhook({ + token: "webhook-token", + path: "/api/webhook", + ip: "192.168.1.100", + }); + + expect(result.allowed).toBe(false); + expect(result.reason).toBe("Webhook rate limit exceeded (path)"); + }); + }); + + describe("extractIp", () => { + it("should extract IP from X-Forwarded-For header", () => { + const req = { + headers: { + "x-forwarded-for": "203.0.113.1, 198.51.100.1", + }, + socket: { + remoteAddress: "192.168.1.1", + }, + } as unknown as IncomingMessage; + + const ip = SecurityShield.extractIp(req); + expect(ip).toBe("203.0.113.1"); + }); + + it("should extract IP from X-Real-IP header when X-Forwarded-For absent", () => { + const req = { + headers: { + "x-real-ip": "203.0.113.5", + }, + socket: { + remoteAddress: "192.168.1.1", + }, + } as unknown as IncomingMessage; + + const ip = SecurityShield.extractIp(req); + expect(ip).toBe("203.0.113.5"); + }); + + it("should fall back to socket remote address", () => { + const req = { + headers: {}, + socket: { + remoteAddress: "192.168.1.100", + }, + } as unknown as IncomingMessage; + + const ip = SecurityShield.extractIp(req); + expect(ip).toBe("192.168.1.100"); + }); + + it("should handle missing socket", () => { + const req = { + headers: {}, + } as unknown as IncomingMessage; + + const ip = SecurityShield.extractIp(req); + expect(ip).toBe("unknown"); + }); + + it("should handle array X-Forwarded-For", () => { + const req = { + headers: { + "x-forwarded-for": ["203.0.113.1, 198.51.100.1", "192.0.2.1"], + }, + socket: { + remoteAddress: "192.168.1.1", + }, + } as unknown as IncomingMessage; + + const ip = SecurityShield.extractIp(req); + expect(ip).toBe("203.0.113.1"); + }); + }); + + describe("integration scenarios", () => { + it("should coordinate IP blocklist and rate limiting", () => { + // First check: allow + vi.mocked(ipManager.isBlocked).mockReturnValueOnce(null); + vi.mocked(rateLimiter.check).mockReturnValueOnce({ + allowed: true, + remaining: 4, + resetAt: new Date(), + }); + + const result1 = shield.checkAuthAttempt(createContext("192.168.1.100")); + expect(result1.allowed).toBe(true); + + // Second check: IP now blocked + vi.mocked(ipManager.isBlocked).mockReturnValueOnce("brute_force"); + + const result2 = shield.checkAuthAttempt(createContext("192.168.1.100")); + expect(result2.allowed).toBe(false); + expect(result2.reason).toBe("IP blocked: brute_force"); + }); + + it("should handle per-IP and per-device limits together", () => { + const ctx = createContext("192.168.1.100", "device-123"); + + vi.mocked(ipManager.isBlocked).mockReturnValue(null); + + // Per-IP limit OK, per-device limit exceeded + vi.mocked(rateLimiter.check) + .mockReturnValueOnce({ + allowed: true, + remaining: 3, + resetAt: new Date(), + }) + .mockReturnValueOnce({ + allowed: false, + retryAfterMs: 60000, + remaining: 0, + resetAt: new Date(), + }); + + const result = shield.checkAuthAttempt(ctx); + + expect(result.allowed).toBe(false); + expect(result.reason).toBe("Rate limit exceeded (device)"); + }); + }); +}); diff --git a/src/security/token-bucket.test.ts b/src/security/token-bucket.test.ts new file mode 100644 index 000000000..9e6cfdcc6 --- /dev/null +++ b/src/security/token-bucket.test.ts @@ -0,0 +1,157 @@ +import { describe, expect, it, beforeEach, vi, afterEach } from "vitest"; +import { TokenBucket } from "./token-bucket.js"; + +describe("TokenBucket", () => { + beforeEach(() => { + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("constructor", () => { + it("should initialize with full tokens", () => { + const bucket = new TokenBucket({ max: 10, refillRate: 1 }); + expect(bucket.getTokens()).toBe(10); + }); + + it("should throw error for invalid max", () => { + expect(() => new TokenBucket({ max: 0, refillRate: 1 })).toThrow("max must be positive"); + expect(() => new TokenBucket({ max: -1, refillRate: 1 })).toThrow("max must be positive"); + }); + + it("should throw error for invalid refillRate", () => { + expect(() => new TokenBucket({ max: 10, refillRate: 0 })).toThrow( + "refillRate must be positive", + ); + expect(() => new TokenBucket({ max: 10, refillRate: -1 })).toThrow( + "refillRate must be positive", + ); + }); + }); + + describe("consume", () => { + it("should consume tokens successfully when available", () => { + const bucket = new TokenBucket({ max: 10, refillRate: 1 }); + expect(bucket.consume(3)).toBe(true); + expect(bucket.getTokens()).toBe(7); + }); + + it("should reject consumption when insufficient tokens", () => { + const bucket = new TokenBucket({ max: 5, refillRate: 1 }); + expect(bucket.consume(10)).toBe(false); + expect(bucket.getTokens()).toBe(5); + }); + + it("should consume exactly available tokens", () => { + const bucket = new TokenBucket({ max: 5, refillRate: 1 }); + expect(bucket.consume(5)).toBe(true); + expect(bucket.getTokens()).toBe(0); + }); + + it("should reject consumption when count is zero", () => { + const bucket = new TokenBucket({ max: 10, refillRate: 1 }); + expect(bucket.consume(0)).toBe(false); + }); + + it("should reject consumption when count is negative", () => { + const bucket = new TokenBucket({ max: 10, refillRate: 1 }); + expect(bucket.consume(-1)).toBe(false); + }); + }); + + describe("refill", () => { + it("should refill tokens based on elapsed time", () => { + const bucket = new TokenBucket({ max: 10, refillRate: 2 }); // 2 tokens/sec + bucket.consume(10); // Empty the bucket + expect(bucket.getTokens()).toBe(0); + + vi.advanceTimersByTime(1000); // Advance 1 second + expect(bucket.consume(1)).toBe(true); // Should refill 2 tokens + expect(bucket.getTokens()).toBeCloseTo(1, 1); + }); + + it("should not exceed max tokens on refill", () => { + const bucket = new TokenBucket({ max: 10, refillRate: 5 }); + bucket.consume(5); // 5 tokens left + + vi.advanceTimersByTime(10000); // Advance 10 seconds (should refill 50 tokens) + expect(bucket.getTokens()).toBe(10); // Capped at max + }); + + it("should handle partial second refills", () => { + const bucket = new TokenBucket({ max: 10, refillRate: 1 }); + bucket.consume(10); // Empty the bucket + + vi.advanceTimersByTime(500); // Advance 0.5 seconds + expect(bucket.getTokens()).toBeCloseTo(0.5, 1); + }); + }); + + describe("getRetryAfterMs", () => { + it("should return 0 when enough tokens available", () => { + const bucket = new TokenBucket({ max: 10, refillRate: 1 }); + expect(bucket.getRetryAfterMs(5)).toBe(0); + }); + + it("should calculate retry time for insufficient tokens", () => { + const bucket = new TokenBucket({ max: 10, refillRate: 2 }); // 2 tokens/sec + bucket.consume(10); // Empty the bucket + + // Need 5 tokens, refill rate is 2/sec, so need 2.5 seconds + const retryAfter = bucket.getRetryAfterMs(5); + expect(retryAfter).toBeGreaterThanOrEqual(2400); + expect(retryAfter).toBeLessThanOrEqual(2600); + }); + + it("should return Infinity when count exceeds max", () => { + const bucket = new TokenBucket({ max: 10, refillRate: 1 }); + expect(bucket.getRetryAfterMs(15)).toBe(Infinity); + }); + }); + + describe("reset", () => { + it("should restore bucket to full capacity", () => { + const bucket = new TokenBucket({ max: 10, refillRate: 1 }); + bucket.consume(8); + expect(bucket.getTokens()).toBe(2); + + bucket.reset(); + expect(bucket.getTokens()).toBe(10); + }); + }); + + describe("integration scenarios", () => { + it("should handle burst followed by gradual refill", () => { + const bucket = new TokenBucket({ max: 5, refillRate: 1 }); + + // Burst: consume all tokens + expect(bucket.consume(1)).toBe(true); + expect(bucket.consume(1)).toBe(true); + expect(bucket.consume(1)).toBe(true); + expect(bucket.consume(1)).toBe(true); + expect(bucket.consume(1)).toBe(true); + expect(bucket.consume(1)).toBe(false); // Depleted + + // Wait and refill + vi.advanceTimersByTime(2000); // 2 seconds = 2 tokens + expect(bucket.consume(1)).toBe(true); + expect(bucket.consume(1)).toBe(true); + expect(bucket.consume(1)).toBe(false); // Not enough yet + }); + + it("should maintain capacity during continuous consumption", () => { + const bucket = new TokenBucket({ max: 10, refillRate: 5 }); // 5 tokens/sec + + // Consume 5 tokens per second (sustainable rate) + for (let i = 0; i < 5; i++) { + expect(bucket.consume(1)).toBe(true); + vi.advanceTimersByTime(200); // 0.2 seconds = 1 token refill + } + + // Should still have tokens available + expect(bucket.getTokens()).toBeGreaterThan(0); + }); + }); +});