Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,10 @@

namespace Microsoft.AspNetCore.Routing
{
public class ClientIdentification
{
public string IpAddress { get; set; } = string.Empty;
public string UserAgent { get; set; } = string.Empty;
public string Host { get; set; } = string.Empty;
public string Protocol { get; set; } = string.Empty;
public Dictionary<string, string> AdditionalHeaders { get; set; } = new();
}

public static class IdentityComponentsEndpointRouteBuilderExtensions
{
public delegate Task LoginHandler(string username, string group, IList<string> roles, ClientIdentification clientInfo);
public delegate Task LogoutHandler(string username);
public delegate Task LoginHandler(string username, string group, IList<string> roles, string ipAddress);
public delegate Task LogoutHandler(string username, string ipAddress);

// These endpoints are required by the Identity Razor components defined in the /Components/Account/Pages directory of this project.
public static IEndpointRouteBuilder MapAdditionalIdentityEndpoints(this IEndpointRouteBuilder endpoints, LoginHandler? loginHandler = null, LogoutHandler? logoutHandler = null)
Expand Down Expand Up @@ -61,8 +52,8 @@ public static IEndpointRouteBuilder MapAdditionalIdentityEndpoints(this IEndpoin
if (user != null && loginHandler != null)
{
var roles = await userManager.GetRolesAsync(user);
var clientInfo = GetClientIdentification(context);
await loginHandler(username, user.Group, roles, clientInfo);
string ipAddress = GetClientIpAddress(context);
await loginHandler(username, user.Group, roles, ipAddress);
}

return TypedResults.LocalRedirect(string.IsNullOrEmpty(returnUrl) ? "/" : !returnUrl.StartsWith("/") ? "/" + returnUrl : returnUrl.StartsWith("//") ? "/" + returnUrl.TrimStart('/') : returnUrl);
Expand Down Expand Up @@ -90,7 +81,8 @@ public static IEndpointRouteBuilder MapAdditionalIdentityEndpoints(this IEndpoin

if (!string.IsNullOrEmpty(username) && logoutHandler != null)
{
await logoutHandler(username);
string ipAddress = GetClientIpAddress(context);
await logoutHandler(username, ipAddress);
}

await signInManager.SignOutAsync();
Expand Down Expand Up @@ -131,8 +123,8 @@ public static IEndpointRouteBuilder MapAdditionalIdentityEndpoints(this IEndpoin
// Same user - log out
if (logoutHandler != null)
{
var clientInfo = GetClientIdentification(context);
await logoutHandler(currentUser.UserName!);
string ipAddress = GetClientIpAddress(context);
await logoutHandler(currentUser.UserName!, ipAddress);
}
await signInManager.SignOutAsync();
return TypedResults.LocalRedirect(string.IsNullOrEmpty(returnUrl) ? "/" : !returnUrl.StartsWith("/") ? "/" + returnUrl : returnUrl.StartsWith("//") ? "/" + returnUrl.TrimStart('/') : returnUrl);
Expand All @@ -142,8 +134,8 @@ public static IEndpointRouteBuilder MapAdditionalIdentityEndpoints(this IEndpoin
// Different user - log out current and log in new
if (logoutHandler != null)
{
var clientInfo = GetClientIdentification(context);
await logoutHandler(currentUser.UserName!);
string ipAddress = GetClientIpAddress(context);
await logoutHandler(currentUser.UserName!, ipAddress);
}
await signInManager.SignOutAsync();
}
Expand All @@ -156,8 +148,8 @@ public static IEndpointRouteBuilder MapAdditionalIdentityEndpoints(this IEndpoin
if (loginHandler != null)
{
var roles = await userManager.GetRolesAsync(user);
var clientInfo = GetClientIdentification(context);
await loginHandler(user.UserName!, user.Group, roles, clientInfo);
string ipAddress = GetClientIpAddress(context);
await loginHandler(user.UserName!, user.Group, roles, ipAddress);
}

if (!string.IsNullOrEmpty(returnUrl))
Expand All @@ -183,30 +175,6 @@ public static IEndpointRouteBuilder MapAdditionalIdentityEndpoints(this IEndpoin
return endpoints;
}

private static ClientIdentification GetClientIdentification(HttpContext context)
{
var clientInfo = new ClientIdentification
{
IpAddress = GetClientIpAddress(context),
UserAgent = context.Request.Headers["User-Agent"].FirstOrDefault() ?? "Unknown",
Host = context.Request.Host.ToString(),
Protocol = context.Request.Protocol
};

// Add additional headers that might be useful for client identification
var headersToCapture = new[] { "Referer", "Accept-Language", "X-Requested-With", "Origin" };
foreach (var header in headersToCapture)
{
var value = context.Request.Headers[header].FirstOrDefault();
if (!string.IsNullOrEmpty(value))
{
clientInfo.AdditionalHeaders[header] = value;
}
}

return clientInfo;
}

private static string GetClientIpAddress(HttpContext context)
{
// Try to get IP from X-Forwarded-For header (for reverse proxy scenarios)
Expand Down
Loading