diff --git a/src/main/java/org/patinanetwork/codebloom/scheduled/auth/LeetcodeAuthStealer.java b/src/main/java/org/patinanetwork/codebloom/scheduled/auth/LeetcodeAuthStealer.java index 38759d002..ae36f776f 100644 --- a/src/main/java/org/patinanetwork/codebloom/scheduled/auth/LeetcodeAuthStealer.java +++ b/src/main/java/org/patinanetwork/codebloom/scheduled/auth/LeetcodeAuthStealer.java @@ -83,16 +83,21 @@ private Timer timer() { /** * DO NOT RETURN THE TOKEN IN ANY API ENDPOINT.
This function utilizes Playwright in order to get an * authentication key from Leetcode. That code is stored in the database and can then be used to run authenticated - * queries such as used to retrieve code from our user submissions. + * queries such as being used to retrieve code from our user submissions. */ @Scheduled(initialDelay = 0, fixedDelay = 1, timeUnit = TimeUnit.HOURS) public void stealAuthCookie() { timer().record(() -> { - LOCK.writeLock().lock(); + boolean acquired = LOCK.writeLock().tryLock(); + if (!acquired) { + log.info("Lock failed to be acquired, bouncing..."); + return; + } + try { Auth mostRecentAuth = authRepository.getMostRecentAuth(); - // The auth token should be refreshed every day. + // The auth token should be refreshed every 4 hours. if (mostRecentAuth != null && mostRecentAuth .getCreatedAt() @@ -100,28 +105,22 @@ public void stealAuthCookie() { log.info("Auth token already exists, using token from database."); cookie = mostRecentAuth.getToken(); csrf = mostRecentAuth.getCsrf(); - if (env.isCi()) { - log.info("in ci, stealing token and putting it in cache for 1 day"); - redisClient.setAuth(cookie, 4, ChronoUnit.HOURS); // 4 hours. - } return; } - if (env.isCi()) { - log.info("in ci env, checking redis client..."); - Optional authToken = redisClient.getAuth(); - - log.info("auth token in redis = {}", authToken.isPresent()); + log.info("falling back to checking redis client..."); + Optional authToken = redisClient.getAuth(); - if (authToken.isPresent()) { - log.info("auth token found in redis client"); - cookie = authToken.get(); - csrf = null; // don't care in ci. - return; - } + log.info("auth token in redis = {}", authToken.isPresent()); - log.info("auth token not found in redis client"); + if (authToken.isPresent()) { + log.info("auth token found in redis client"); + cookie = authToken.get(); + csrf = null; // don't care in ci. + return; } + + log.info("auth token not found in redis client"); log.info("Auth token is missing/expired. Attempting to receive token..."); stealCookieImpl(); @@ -139,7 +138,19 @@ public void stealAuthCookie() { */ @Async public CompletableFuture> reloadCookie() { - return timer().record(() -> CompletableFuture.completedFuture(Optional.ofNullable(stealCookieImpl()))); + return timer().record(() -> { + boolean acquired = LOCK.writeLock().tryLock(); + if (!acquired) { + log.info("Lock failed to be acquired, bouncing..."); + return CompletableFuture.completedFuture(Optional.empty()); + } + + try { + return CompletableFuture.completedFuture(Optional.ofNullable(stealCookieImpl())); + } finally { + LOCK.writeLock().unlock(); + } + }); } public String getCookie() { @@ -177,26 +188,19 @@ public String getCsrf() { String stealCookieImpl() { return timer().record(() -> { - LOCK.writeLock().lock(); - try { - Optional auth = playwrightClient.getLeetcodeCookie(githubUsername, githubPassword); - if (auth.isPresent()) { - var a = auth.get(); - this.csrf = a.getCsrf(); - this.cookie = a.getToken(); - if (env.isCi()) { - log.info("in ci, stored in redis as well"); - redisClient.setAuth(a.getToken(), 4, ChronoUnit.HOURS); // 4 hours. - } - this.authRepository.createAuth(Auth.builder() - .csrf(a.getCsrf()) - .token(a.getToken()) - .createdAt(StandardizedOffsetDateTime.now()) - .build()); - return cookie; - } - } finally { - LOCK.writeLock().unlock(); + Optional auth = playwrightClient.getLeetcodeCookie(githubUsername, githubPassword); + if (auth.isPresent()) { + var a = auth.get(); + this.csrf = a.getCsrf(); + this.cookie = a.getToken(); + redisClient.setAuth(a.getToken(), 4, ChronoUnit.HOURS); + log.info("auth token stored in redis"); + this.authRepository.createAuth(Auth.builder() + .csrf(a.getCsrf()) + .token(a.getToken()) + .createdAt(StandardizedOffsetDateTime.now()) + .build()); + return cookie; } return null; }); diff --git a/src/test/java/org/patinanetwork/codebloom/scheduled/auth/LeetcodeAuthStealerTest.java b/src/test/java/org/patinanetwork/codebloom/scheduled/auth/LeetcodeAuthStealerTest.java index c160d4cb6..9e2158446 100644 --- a/src/test/java/org/patinanetwork/codebloom/scheduled/auth/LeetcodeAuthStealerTest.java +++ b/src/test/java/org/patinanetwork/codebloom/scheduled/auth/LeetcodeAuthStealerTest.java @@ -3,6 +3,10 @@ import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.*; +import ch.qos.logback.classic.Level; +import ch.qos.logback.classic.Logger; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.read.ListAppender; import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.simple.SimpleMeterRegistry; import java.util.Optional; @@ -12,6 +16,8 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; @@ -24,6 +30,7 @@ import org.patinanetwork.codebloom.common.reporter.Reporter; import org.patinanetwork.codebloom.common.time.StandardizedOffsetDateTime; import org.patinanetwork.codebloom.playwright.PlaywrightClient; +import org.slf4j.LoggerFactory; public class LeetcodeAuthStealerTest { private LeetcodeAuthStealer leetcodeAuthStealer; @@ -35,6 +42,8 @@ public class LeetcodeAuthStealerTest { private MeterRegistry meterRegistry; private PlaywrightClient playwrightClient; + private ListAppender logWatcher; + public LeetcodeAuthStealerTest() { redisClient = mock(RedisClient.class); authRepository = mock(AuthRepository.class); @@ -42,17 +51,26 @@ public LeetcodeAuthStealerTest() { env = mock(Env.class); meterRegistry = new SimpleMeterRegistry(); playwrightClient = mock(PlaywrightClient.class); - - leetcodeAuthStealer = spy( - new LeetcodeAuthStealer(redisClient, authRepository, reporter, env, meterRegistry, playwrightClient)); } @BeforeEach void setup() { + leetcodeAuthStealer = spy( + new LeetcodeAuthStealer(redisClient, authRepository, reporter, env, meterRegistry, playwrightClient)); + + logWatcher = new ListAppender<>(); + logWatcher.start(); + ((Logger) LoggerFactory.getLogger(leetcodeAuthStealer.getClass())).addAppender(logWatcher); + when(env.isCi()).thenReturn(false); playwrightClientResolvesSlowly(Auth.builder().build()); } + @AfterEach + void teardown() { + ((Logger) LoggerFactory.getLogger(leetcodeAuthStealer.getClass())).detachAndStopAllAppenders(); + } + private void playwrightClientResolvesSlowly(Auth authToReturn) { when(playwrightClient.getLeetcodeCookie(any(), any())).thenAnswer(invocation -> { FakeLag.sleep(1000); @@ -591,4 +609,67 @@ void testReloadCookieReadWriteLockInteractionDifferentThreadPools() throws Inter readBlockedByWrite.get(), "Read operations from different thread pool should wait for write lock to be released"); } + + @Test + @Timeout(10) + @DisplayName("reloadCookie - If one thread is stealing cookie, other thread will bounce") + void testReloadCookieIfOneThreadIsStealingCookieOtherThreadWillBounce() throws InterruptedException { + ExecutorService writePool = Executors.newFixedThreadPool(2); + CountDownLatch latch = new CountDownLatch(1); + + writePool.execute(() -> { + leetcodeAuthStealer.reloadCookie(); + }); + + AtomicReference> ref = new AtomicReference<>(); + writePool.execute(() -> { + try { + try { + Thread.sleep(100); + } catch (InterruptedException e) { + e.printStackTrace(); + } + + ref.set(leetcodeAuthStealer.reloadCookie().join()); + } finally { + latch.countDown(); + } + }); + + latch.await(2, TimeUnit.SECONDS); + + assertTrue(ref.get().isEmpty()); + } + + @Test + @Timeout(10) + @DisplayName("stealAuthCookie - If one thread is stealing cookie, other thread will bounce") + void testStealAuthCookieIfOneThreadIsStealingCookieOtherThreadWillBounce() throws InterruptedException { + ExecutorService writePool = Executors.newFixedThreadPool(2); + CountDownLatch latch = new CountDownLatch(1); + + writePool.execute(() -> { + leetcodeAuthStealer.stealAuthCookie(); + }); + + writePool.execute(() -> { + try { + try { + Thread.sleep(100); + } catch (InterruptedException e) { + e.printStackTrace(); + } + + leetcodeAuthStealer.stealAuthCookie(); + } finally { + latch.countDown(); + } + }); + + latch.await(2, TimeUnit.SECONDS); + + assertTrue(logWatcher.list.stream() + .anyMatch(log -> log.getLevel().equals(Level.INFO) + && log.getFormattedMessage().contains("Lock failed to be acquired, bouncing..."))); + } }