136 lines
4.4 KiB
C#
136 lines
4.4 KiB
C#
using System.Collections.Concurrent;
|
|
|
|
namespace WebApp.Authentication;
|
|
|
|
/// <summary>
/// In-memory, per-IP-address login rate limiter.
/// After <see cref="MaxAttempts"/> recorded failures an address is locked out for
/// <see cref="LockoutDuration"/>. While running as a hosted service, a timer
/// periodically evicts idle entries so the tracking map does not grow without bound.
/// </summary>
/// <remarks>
/// All state lives in this instance's dictionary; nothing is persisted.
/// Each <see cref="LoginAttemptTracker"/> is mutated only while holding a lock on that
/// tracker instance, so concurrent requests for the same address cannot lose counter
/// updates or observe a half-written lockout timestamp.
/// </remarks>
public class LoginRateLimitService : IHostedService, IDisposable
{
    private const int MaxAttempts = 5;
    private static readonly TimeSpan LockoutDuration = TimeSpan.FromMinutes(15);
    private static readonly TimeSpan CleanupInterval = TimeSpan.FromHours(1);

    // How long an entry may sit idle (no failed attempt, no active lockout) before cleanup evicts it.
    private static readonly TimeSpan EntryRetention = TimeSpan.FromHours(24);

    private readonly ConcurrentDictionary<string, LoginAttemptTracker> _attempts = new();
    private readonly ILogger<LoginRateLimitService> _logger;
    private Timer? _cleanupTimer;

    /// <summary>
    /// Mutable per-address bookkeeping. Instances are never handed out by this service;
    /// internally they are only read or written while holding <c>lock (tracker)</c>.
    /// </summary>
    public class LoginAttemptTracker
    {
        /// <summary>Number of failures recorded since the last reset.</summary>
        public int FailedAttempts { get; set; }

        /// <summary>UTC instant until which the address is locked out, or <c>null</c> when not locked.</summary>
        public DateTime? LockoutUntil { get; set; }

        /// <summary>UTC time of the most recent recorded failure (creation time for a fresh tracker).</summary>
        public DateTime LastAttemptTime { get; set; } = DateTime.UtcNow;
    }

    /// <summary>Creates the service.</summary>
    /// <param name="logger">Logger used for lockout, attempt, and cleanup messages.</param>
    public LoginRateLimitService(ILogger<LoginRateLimitService> logger)
    {
        _logger = logger;
    }

    /// <summary>
    /// Reports whether <paramref name="ipAddress"/> is currently locked out.
    /// If a stored lockout has expired, the failure counter and lockout are cleared as a side effect.
    /// </summary>
    /// <param name="ipAddress">Key identifying the client address.</param>
    /// <returns><c>true</c> while a lockout is active; otherwise <c>false</c>.</returns>
    public bool IsLockedOut(string ipAddress)
    {
        if (!_attempts.TryGetValue(ipAddress, out var tracker))
        {
            return false;
        }

        lock (tracker)
        {
            var lockoutUntil = tracker.LockoutUntil;
            if (lockoutUntil is null)
            {
                return false;
            }

            if (lockoutUntil.Value > DateTime.UtcNow)
            {
                return true;
            }

            // Lockout expired: start a fresh attempt window.
            tracker.FailedAttempts = 0;
            tracker.LockoutUntil = null;
        }

        // Logged outside the lock so a slow log sink cannot stall other requests for this address.
        _logger.LogInformation("Lockout expired for IP: {IpAddress}", ipAddress);
        return false;
    }

    /// <summary>
    /// Records one failed login for <paramref name="ipAddress"/> and starts (or extends) a lockout
    /// once the failure count reaches <see cref="MaxAttempts"/>.
    /// </summary>
    /// <param name="ipAddress">Key identifying the client address.</param>
    public void RecordFailedAttempt(string ipAddress)
    {
        var tracker = _attempts.GetOrAdd(ipAddress, _ => new LoginAttemptTracker());

        int attempts;
        DateTime? lockoutUntil;

        lock (tracker)
        {
            var now = DateTime.UtcNow;

            // A lockout that has already run out must not carry its old count into the new window,
            // even if IsLockedOut was not called first to clear it.
            if (tracker.LockoutUntil.HasValue && tracker.LockoutUntil.Value <= now)
            {
                tracker.FailedAttempts = 0;
                tracker.LockoutUntil = null;
            }

            // Increment under the lock: a bare ++ on shared state can lose updates under parallel requests.
            attempts = ++tracker.FailedAttempts;
            tracker.LastAttemptTime = now;

            if (attempts >= MaxAttempts)
            {
                tracker.LockoutUntil = now.Add(LockoutDuration);
            }

            lockoutUntil = tracker.LockoutUntil;
        }

        if (attempts >= MaxAttempts)
        {
            _logger.LogWarning(
                "IP address locked out due to {Attempts} failed login attempts: {IpAddress}. Lockout until: {LockoutUntil}",
                attempts, ipAddress, lockoutUntil);
        }
        else
        {
            _logger.LogInformation(
                "Failed login attempt {Attempt}/{MaxAttempts} for IP: {IpAddress}",
                attempts, MaxAttempts, ipAddress);
        }
    }

    /// <summary>Drops all tracking state for <paramref name="ipAddress"/> after a successful login.</summary>
    /// <param name="ipAddress">Key identifying the client address.</param>
    public void RecordSuccessfulLogin(string ipAddress)
    {
        if (_attempts.TryRemove(ipAddress, out _))
        {
            _logger.LogDebug("Cleared rate limit tracking for IP: {IpAddress}", ipAddress);
        }
    }

    /// <summary>
    /// Returns how much longer <paramref name="ipAddress"/> stays locked out.
    /// Unlike <see cref="IsLockedOut"/>, this method never modifies the stored state.
    /// </summary>
    /// <param name="ipAddress">Key identifying the client address.</param>
    /// <returns>The positive remaining duration, or <c>null</c> when no lockout is active.</returns>
    public TimeSpan? GetRemainingLockoutTime(string ipAddress)
    {
        if (!_attempts.TryGetValue(ipAddress, out var tracker))
        {
            return null;
        }

        DateTime? lockoutUntil;
        lock (tracker)
        {
            // DateTime? is larger than a machine word, so read it under the same lock writers use.
            lockoutUntil = tracker.LockoutUntil;
        }

        if (lockoutUntil is null)
        {
            return null;
        }

        var remaining = lockoutUntil.Value - DateTime.UtcNow;
        return remaining > TimeSpan.Zero ? remaining : null;
    }

    /// <summary>
    /// Timer callback: evicts entries whose last failure is older than <see cref="EntryRetention"/>
    /// and that have no active lockout. Exceptions are logged and swallowed so a failure here
    /// cannot escape onto the timer thread.
    /// </summary>
    /// <param name="state">Unused timer state.</param>
    private void CleanupExpiredEntries(object? state)
    {
        try
        {
            var now = DateTime.UtcNow;
            var cutoffTime = now - EntryRetention;
            var removedCount = 0;

            // Enumerating a ConcurrentDictionary while other threads add or remove entries is supported.
            foreach (var entry in _attempts)
            {
                var tracker = entry.Value;
                bool removed;

                lock (tracker)
                {
                    var isStale = tracker.LastAttemptTime < cutoffTime
                        && (!tracker.LockoutUntil.HasValue || tracker.LockoutUntil < now);

                    // Decide and remove under the tracker lock, and only remove this exact instance:
                    // an entry that was refreshed or replaced since enumeration is left alone.
                    removed = isStale && _attempts.TryRemove(entry);
                }

                if (removed)
                {
                    removedCount++;
                    _logger.LogDebug("Removed expired rate limit entry for IP: {IpAddress}", entry.Key);
                }
            }

            if (removedCount > 0)
            {
                _logger.LogInformation("Cleaned up {Count} expired rate limit entries", removedCount);
            }
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error during rate limit cleanup");
        }
    }

    /// <summary>Starts the periodic cleanup timer (first run after <see cref="CleanupInterval"/>).</summary>
    /// <param name="cancellationToken">Not used; startup completes synchronously.</param>
    /// <returns>A completed task.</returns>
    public Task StartAsync(CancellationToken cancellationToken)
    {
        _logger.LogInformation("Login rate limiting service started");
        _cleanupTimer = new Timer(CleanupExpiredEntries, null, CleanupInterval, CleanupInterval);
        return Task.CompletedTask;
    }

    /// <summary>Stops scheduling further cleanup runs; the timer itself is released in <see cref="Dispose"/>.</summary>
    /// <param name="cancellationToken">Not used; shutdown completes synchronously.</param>
    /// <returns>A completed task.</returns>
    public Task StopAsync(CancellationToken cancellationToken)
    {
        _logger.LogInformation("Login rate limiting service stopped");
        _cleanupTimer?.Change(Timeout.Infinite, 0);
        return Task.CompletedTask;
    }

    /// <summary>Releases the cleanup timer.</summary>
    public void Dispose()
    {
        _cleanupTimer?.Dispose();

        // The class is not sealed; keep a derived type's finalizer from running after explicit disposal.
        GC.SuppressFinalize(this);
    }
}
|