diff --git a/src/main/java/org/patinanetwork/codebloom/scheduled/auth/LeetcodeAuthStealer.java b/src/main/java/org/patinanetwork/codebloom/scheduled/auth/LeetcodeAuthStealer.java
index 38759d002..ae36f776f 100644
--- a/src/main/java/org/patinanetwork/codebloom/scheduled/auth/LeetcodeAuthStealer.java
+++ b/src/main/java/org/patinanetwork/codebloom/scheduled/auth/LeetcodeAuthStealer.java
@@ -83,16 +83,21 @@ private Timer timer() {
/**
* DO NOT RETURN THE TOKEN IN ANY API ENDPOINT.
This function utilizes Playwright in order to get an
* authentication key from Leetcode. That code is stored in the database and can then be used to run authenticated
- * queries such as used to retrieve code from our user submissions.
+ * queries such as being used to retrieve code from our user submissions.
*/
@Scheduled(initialDelay = 0, fixedDelay = 1, timeUnit = TimeUnit.HOURS)
public void stealAuthCookie() {
timer().record(() -> {
- LOCK.writeLock().lock();
+ boolean acquired = LOCK.writeLock().tryLock();
+ if (!acquired) {
+ log.info("Lock failed to be acquired, bouncing...");
+ return;
+ }
+
try {
Auth mostRecentAuth = authRepository.getMostRecentAuth();
- // The auth token should be refreshed every day.
+ // The auth token should be refreshed every 4 hours.
if (mostRecentAuth != null
&& mostRecentAuth
.getCreatedAt()
@@ -100,28 +105,22 @@ public void stealAuthCookie() {
log.info("Auth token already exists, using token from database.");
cookie = mostRecentAuth.getToken();
csrf = mostRecentAuth.getCsrf();
- if (env.isCi()) {
- log.info("in ci, stealing token and putting it in cache for 1 day");
- redisClient.setAuth(cookie, 4, ChronoUnit.HOURS); // 4 hours.
- }
return;
}
- if (env.isCi()) {
- log.info("in ci env, checking redis client...");
- Optional authToken = redisClient.getAuth();
-
- log.info("auth token in redis = {}", authToken.isPresent());
+ log.info("falling back to checking redis client...");
+ Optional authToken = redisClient.getAuth();
- if (authToken.isPresent()) {
- log.info("auth token found in redis client");
- cookie = authToken.get();
- csrf = null; // don't care in ci.
- return;
- }
+ log.info("auth token in redis = {}", authToken.isPresent());
- log.info("auth token not found in redis client");
+ if (authToken.isPresent()) {
+ log.info("auth token found in redis client");
+ cookie = authToken.get();
+ csrf = null; // don't care in ci.
+ return;
}
+
+ log.info("auth token not found in redis client");
log.info("Auth token is missing/expired. Attempting to receive token...");
stealCookieImpl();
@@ -139,7 +138,19 @@ public void stealAuthCookie() {
*/
@Async
public CompletableFuture> reloadCookie() {
- return timer().record(() -> CompletableFuture.completedFuture(Optional.ofNullable(stealCookieImpl())));
+ return timer().record(() -> {
+ boolean acquired = LOCK.writeLock().tryLock();
+ if (!acquired) {
+ log.info("Lock failed to be acquired, bouncing...");
+ return CompletableFuture.completedFuture(Optional.empty());
+ }
+
+ try {
+ return CompletableFuture.completedFuture(Optional.ofNullable(stealCookieImpl()));
+ } finally {
+ LOCK.writeLock().unlock();
+ }
+ });
}
public String getCookie() {
@@ -177,26 +188,19 @@ public String getCsrf() {
String stealCookieImpl() {
return timer().record(() -> {
- LOCK.writeLock().lock();
- try {
- Optional auth = playwrightClient.getLeetcodeCookie(githubUsername, githubPassword);
- if (auth.isPresent()) {
- var a = auth.get();
- this.csrf = a.getCsrf();
- this.cookie = a.getToken();
- if (env.isCi()) {
- log.info("in ci, stored in redis as well");
- redisClient.setAuth(a.getToken(), 4, ChronoUnit.HOURS); // 4 hours.
- }
- this.authRepository.createAuth(Auth.builder()
- .csrf(a.getCsrf())
- .token(a.getToken())
- .createdAt(StandardizedOffsetDateTime.now())
- .build());
- return cookie;
- }
- } finally {
- LOCK.writeLock().unlock();
+ Optional auth = playwrightClient.getLeetcodeCookie(githubUsername, githubPassword);
+ if (auth.isPresent()) {
+ var a = auth.get();
+ this.csrf = a.getCsrf();
+ this.cookie = a.getToken();
+ redisClient.setAuth(a.getToken(), 4, ChronoUnit.HOURS);
+ log.info("auth token stored in redis");
+ this.authRepository.createAuth(Auth.builder()
+ .csrf(a.getCsrf())
+ .token(a.getToken())
+ .createdAt(StandardizedOffsetDateTime.now())
+ .build());
+ return cookie;
}
return null;
});
diff --git a/src/test/java/org/patinanetwork/codebloom/scheduled/auth/LeetcodeAuthStealerTest.java b/src/test/java/org/patinanetwork/codebloom/scheduled/auth/LeetcodeAuthStealerTest.java
index c160d4cb6..9e2158446 100644
--- a/src/test/java/org/patinanetwork/codebloom/scheduled/auth/LeetcodeAuthStealerTest.java
+++ b/src/test/java/org/patinanetwork/codebloom/scheduled/auth/LeetcodeAuthStealerTest.java
@@ -3,6 +3,10 @@
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;
+import ch.qos.logback.classic.Level;
+import ch.qos.logback.classic.Logger;
+import ch.qos.logback.classic.spi.ILoggingEvent;
+import ch.qos.logback.core.read.ListAppender;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
import java.util.Optional;
@@ -12,6 +16,8 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
@@ -24,6 +30,7 @@
import org.patinanetwork.codebloom.common.reporter.Reporter;
import org.patinanetwork.codebloom.common.time.StandardizedOffsetDateTime;
import org.patinanetwork.codebloom.playwright.PlaywrightClient;
+import org.slf4j.LoggerFactory;
public class LeetcodeAuthStealerTest {
private LeetcodeAuthStealer leetcodeAuthStealer;
@@ -35,6 +42,8 @@ public class LeetcodeAuthStealerTest {
private MeterRegistry meterRegistry;
private PlaywrightClient playwrightClient;
+ private ListAppender logWatcher;
+
public LeetcodeAuthStealerTest() {
redisClient = mock(RedisClient.class);
authRepository = mock(AuthRepository.class);
@@ -42,17 +51,26 @@ public LeetcodeAuthStealerTest() {
env = mock(Env.class);
meterRegistry = new SimpleMeterRegistry();
playwrightClient = mock(PlaywrightClient.class);
-
- leetcodeAuthStealer = spy(
- new LeetcodeAuthStealer(redisClient, authRepository, reporter, env, meterRegistry, playwrightClient));
}
@BeforeEach
void setup() {
+ leetcodeAuthStealer = spy(
+ new LeetcodeAuthStealer(redisClient, authRepository, reporter, env, meterRegistry, playwrightClient));
+
+ logWatcher = new ListAppender<>();
+ logWatcher.start();
+ ((Logger) LoggerFactory.getLogger(leetcodeAuthStealer.getClass())).addAppender(logWatcher);
+
when(env.isCi()).thenReturn(false);
playwrightClientResolvesSlowly(Auth.builder().build());
}
+ @AfterEach
+ void teardown() {
+ ((Logger) LoggerFactory.getLogger(leetcodeAuthStealer.getClass())).detachAndStopAllAppenders();
+ }
+
private void playwrightClientResolvesSlowly(Auth authToReturn) {
when(playwrightClient.getLeetcodeCookie(any(), any())).thenAnswer(invocation -> {
FakeLag.sleep(1000);
@@ -591,4 +609,67 @@ void testReloadCookieReadWriteLockInteractionDifferentThreadPools() throws Inter
readBlockedByWrite.get(),
"Read operations from different thread pool should wait for write lock to be released");
}
+
+ @Test
+ @Timeout(10)
+ @DisplayName("reloadCookie - If one thread is stealing cookie, other thread will bounce")
+ void testReloadCookieIfOneThreadIsStealingCookieOtherThreadWillBounce() throws InterruptedException {
+ ExecutorService writePool = Executors.newFixedThreadPool(2);
+ CountDownLatch latch = new CountDownLatch(1);
+
+ writePool.execute(() -> {
+ leetcodeAuthStealer.reloadCookie();
+ });
+
+ AtomicReference> ref = new AtomicReference<>();
+ writePool.execute(() -> {
+ try {
+ try {
+ Thread.sleep(100);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+
+ ref.set(leetcodeAuthStealer.reloadCookie().join());
+ } finally {
+ latch.countDown();
+ }
+ });
+
+ latch.await(2, TimeUnit.SECONDS);
+
+ assertTrue(ref.get().isEmpty());
+ }
+
+ @Test
+ @Timeout(10)
+ @DisplayName("stealAuthCookie - If one thread is stealing cookie, other thread will bounce")
+ void testStealAuthCookieIfOneThreadIsStealingCookieOtherThreadWillBounce() throws InterruptedException {
+ ExecutorService writePool = Executors.newFixedThreadPool(2);
+ CountDownLatch latch = new CountDownLatch(1);
+
+ writePool.execute(() -> {
+ leetcodeAuthStealer.stealAuthCookie();
+ });
+
+ writePool.execute(() -> {
+ try {
+ try {
+ Thread.sleep(100);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+
+ leetcodeAuthStealer.stealAuthCookie();
+ } finally {
+ latch.countDown();
+ }
+ });
+
+ latch.await(2, TimeUnit.SECONDS);
+
+ assertTrue(logWatcher.list.stream()
+ .anyMatch(log -> log.getLevel().equals(Level.INFO)
+ && log.getFormattedMessage().contains("Lock failed to be acquired, bouncing...")));
+ }
}