From 95e03a5035157e475ecb8df2a98ccad7a350cbf1 Mon Sep 17 00:00:00 2001 From: s0m01n3 Date: Fri, 20 Dec 2019 15:56:03 -0800 Subject: [PATCH] Fixed BulkGetFuture.get() API to honor operationTimeout instead of waiting forever. This also makes it consistent with single operation GetFuture class - both will use the same operation timeout by default. This commit fixes the problem in the enterprise environments when asyncBulkGet() API is called: client.asyncGetBulk(keys).get(); When multiple shards are used and one of them goes down without closing the connection gracefully, the above call will wait forever on the failed shard. The fix is to modify default BulkFuture.get() behavior to use operationTimeout instead of eternal wait. If a longer wait is required for bulk requests, the timeout can still be provided in the BulkFuture.get(timeout, timeoutUnit) API call. --- .../net/spy/memcached/MemcachedClient.java | 2 +- .../spy/memcached/internal/BulkGetFuture.java | 8 ++-- .../java/net/spy/memcached/TimeoutTest.java | 37 +++++++++++++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/src/main/java/net/spy/memcached/MemcachedClient.java b/src/main/java/net/spy/memcached/MemcachedClient.java index a04644b9e..8ccb8710b 100644 --- a/src/main/java/net/spy/memcached/MemcachedClient.java +++ b/src/main/java/net/spy/memcached/MemcachedClient.java @@ -1319,7 +1319,7 @@ public BulkFuture> asyncGetBulk(Iterator keyIter, int initialLatchCount = chunks.isEmpty() ? 0 : 1; final CountDownLatch latch = new CountDownLatch(initialLatchCount); final Collection ops = new ArrayList(chunks.size()); - final BulkGetFuture rv = new BulkGetFuture(m, ops, latch, executorService); + final BulkGetFuture rv = new BulkGetFuture(m, ops, latch, operationTimeout, executorService); GetOperation.Callback cb = new GetOperation.Callback() { @Override diff --git a/src/main/java/net/spy/memcached/internal/BulkGetFuture.java b/src/main/java/net/spy/memcached/internal/BulkGetFuture.java index 0f956c3a1..48ad9e7ae 100644 --- a/src/main/java/net/spy/memcached/internal/BulkGetFuture.java +++ b/src/main/java/net/spy/memcached/internal/BulkGetFuture.java @@ -56,16 +56,18 @@ public class BulkGetFuture private final Map> rvMap; private final Collection ops; private final CountDownLatch latch; + private final long defaultTimeoutMillis; private OperationStatus status; private boolean cancelled = false; private boolean timeout = false; public BulkGetFuture(Map> m, Collection getOps, - CountDownLatch l, ExecutorService service) { + CountDownLatch l, long defaultTimeoutMillis, ExecutorService service) { super(service); rvMap = m; ops = getOps; latch = l; + this.defaultTimeoutMillis = defaultTimeoutMillis; status = null; } @@ -86,9 +88,9 @@ public boolean cancel(boolean ign) { public Map get() throws InterruptedException, ExecutionException { try { - return get(Long.MAX_VALUE, TimeUnit.MILLISECONDS); + return get(defaultTimeoutMillis, TimeUnit.MILLISECONDS); } catch (TimeoutException e) { - throw new RuntimeException("Timed out waiting forever", e); + throw new ExecutionException("Bulk operation timed out after " + defaultTimeoutMillis + " millis", e); } } diff --git a/src/test/java/net/spy/memcached/TimeoutTest.java b/src/test/java/net/spy/memcached/TimeoutTest.java index 5a4508f0e..a3056ea14 100644 --- a/src/test/java/net/spy/memcached/TimeoutTest.java +++ b/src/test/java/net/spy/memcached/TimeoutTest.java @@ -23,6 +23,10 @@ package net.spy.memcached; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; + /** * A TimeoutTest. */ @@ -90,6 +94,39 @@ public void run() { }); } + public void testAsyncGetBulkCustomTimeout() { + tryTimeout("asyncGetBulk", new Runnable() { + public void run() { + try { + client.asyncGetBulk("k", "k2").get(500, TimeUnit.MILLISECONDS); + } catch (TimeoutException e) { + throw new OperationTimeoutException("Bulk op timed out - custom timeout", e); + } catch (Exception e) { + throw new RuntimeException("Unexpected exception in bulk op", e); + } + } + }); + } + + public void testAsyncGetBulkDefaultTimeout() { + tryTimeout("asyncGetBulk", new Runnable() { + public void run() { + try { + client.asyncGetBulk("k", "k2").get(); + } catch (ExecutionException e) { + if (e.getCause() instanceof TimeoutException) { + throw new OperationTimeoutException("Bulk op timed out - default timeout", e.getCause()); + } + else { + throw new RuntimeException("Unexpected execution exception in bulk op", e); + } + } catch (Exception e) { + throw new RuntimeException("Unexpected exception in bulk op", e); + } + } + }); + } + public void testIncrTimeout() { tryTimeout("incr", new Runnable() { public void run() {