Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package co.nilin.opex.api.app.config

import co.nilin.opex.api.app.service.RateLimitCoordinatorService
import co.nilin.opex.api.core.inout.RateLimitEndpoint
import co.nilin.opex.api.core.spi.RateLimitConfigService
import co.nilin.opex.common.OpexError
import org.slf4j.LoggerFactory
import org.springframework.http.HttpStatus
import org.springframework.security.core.context.ReactiveSecurityContextHolder
import org.springframework.stereotype.Component
Expand All @@ -18,6 +19,7 @@ class RateLimitConfig(
private val coordinator: RateLimitCoordinatorService

) : WebFilter {
private val logger = LoggerFactory.getLogger(RateLimitConfig::class.java)
private val parser = PathPatternParser()

override fun filter(exchange: ServerWebExchange, chain: WebFilterChain): Mono<Void> {
Expand All @@ -36,22 +38,28 @@ class RateLimitConfig(
return chain.filter(exchange)
}

return applyRateLimitIfAuthenticated(exchange, chain, endpoint)
return applyRateLimitIfAuthenticated(
exchange,
chain,
endpoint.groupId
)
}


private fun applyRateLimitIfAuthenticated(
exchange: ServerWebExchange,
chain: WebFilterChain,
endpoint: RateLimitEndpoint
groupId: Long
): Mono<Void> {

return ReactiveSecurityContextHolder.getContext()
.mapNotNull { it.authentication }
.filter { it.isAuthenticated }
.flatMap { auth ->
if (auth != null && !auth.name.isNullOrBlank())
applyRateLimit(auth.name, exchange, chain, endpoint)
applyRateLimit(
auth.name, exchange, chain, groupId
)
else
chain.filter(exchange)
}
Expand All @@ -63,23 +71,29 @@ class RateLimitConfig(
identity: String,
exchange: ServerWebExchange,
chain: WebFilterChain,
endpoint: RateLimitEndpoint
groupId: Long
): Mono<Void> {

val group = rateLimitConfig.getGroup(endpoint.groupId)
val group = rateLimitConfig.getGroup(groupId)
?: return chain.filter(exchange)

val result = coordinator.check(
identity = identity,
groupId = endpoint.groupId,
groupId = groupId,
maxRequests = group.requestCount,
windowSeconds = group.requestWindowSeconds,
apiPath = endpoint.url,
apiMethod = endpoint.method
apiPath = exchange.request.uri.path,
apiMethod = exchange.request.method.name()
)

return if (result.blocked) {
tooManyRequests(exchange, identity, endpoint.url, endpoint.method, result.retryAfterSeconds)
tooManyRequests(
exchange,
identity,
exchange.request.uri.path,
exchange.request.method.name(),
result.retryAfterSeconds
)
} else {
chain.filter(exchange)
}
Expand All @@ -93,6 +107,7 @@ class RateLimitConfig(
method: String,
retryAfterSeconds: Int
): Mono<Void> {
logger.info("Rate limit exceeded ($identity) -- $method:$url")
exchange.response.statusCode = HttpStatus.TOO_MANY_REQUESTS
return exchange.response.writeWith(
Mono.just(
Expand Down
Loading