diff --git a/clients/src/main/java/org/apache/kafka/common/network/NetworkReceive.java b/clients/src/main/java/org/apache/kafka/common/network/NetworkReceive.java index 3a67ce077a..60451f3a8d 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/NetworkReceive.java +++ b/clients/src/main/java/org/apache/kafka/common/network/NetworkReceive.java @@ -27,30 +27,36 @@ import java.nio.channels.ScatteringByteChannel; /** - * A size delimited Receive that consists of a 4 byte network-ordered size N followed by N bytes of content + * A size delimited Receive that consists of a 4 byte network-ordered size N + * followed by N bytes of content. */ public class NetworkReceive implements Receive { public static final String UNKNOWN_SOURCE = ""; public static final int UNLIMITED = -1; + private static final Logger log = LoggerFactory.getLogger(NetworkReceive.class); private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0); private final String source; - private final ByteBuffer size; private final int maxSize; private final MemoryPool memoryPool; + + private final ByteBuffer size; private int requestedBufferSize = -1; private ByteBuffer buffer; - - public NetworkReceive(String source, ByteBuffer buffer) { - this(UNLIMITED, source); - this.buffer = buffer; + public NetworkReceive() { + this(UNKNOWN_SOURCE); } public NetworkReceive(String source) { - this(UNLIMITED, source); + this(UNLIMITED, source, MemoryPool.NONE); + } + + public NetworkReceive(String source, ByteBuffer buffer) { + this(UNLIMITED, source, MemoryPool.NONE); + this.buffer = buffer; } public NetworkReceive(int maxSize, String source) { @@ -59,14 +65,9 @@ public NetworkReceive(int maxSize, String source) { public NetworkReceive(int maxSize, String source, MemoryPool memoryPool) { this.source = source; - this.size = ByteBuffer.allocate(4); - this.buffer = null; this.maxSize = maxSize; this.memoryPool = memoryPool; - } - - public NetworkReceive() { - this(UNKNOWN_SOURCE); + this.size = ByteBuffer.allocate(4); } @Override @@ -79,35 +80,60 @@ public boolean complete() { return !size.hasRemaining() && buffer != null && !buffer.hasRemaining(); } + @Override public long readFrom(ScatteringByteChannel channel) throws IOException { int read = 0; + + // Step 1: Read the 4-byte size prefix if (size.hasRemaining()) { int bytesRead = channel.read(size); if (bytesRead < 0) throw new EOFException(); + read += bytesRead; + if (!size.hasRemaining()) { size.rewind(); int receiveSize = size.getInt(); + if (receiveSize < 0) - throw new InvalidReceiveException("Invalid receive (size = " + receiveSize + ")"); + throw new InvalidReceiveException( + "Invalid receive (size = " + receiveSize + ")" + ); + if (maxSize != UNLIMITED && receiveSize > maxSize) - throw new InvalidReceiveException("Invalid receive (size = " + receiveSize + " larger than " + maxSize + ")"); - requestedBufferSize = receiveSize; // may be 0 for some payloads (SASL) + throw new InvalidReceiveException( + "Invalid receive (size = " + receiveSize + + " larger than " + maxSize + ")" + ); + + requestedBufferSize = receiveSize; + if (receiveSize == 0) { buffer = EMPTY_BUFFER; } } } - if (buffer == null && requestedBufferSize != -1) { // we know the size we want but haven't been able to allocate it yet + + // Step 2: Allocate payload buffer via MemoryPool + if (buffer == null && requestedBufferSize != -1) { buffer = memoryPool.tryAllocate(requestedBufferSize); - if (buffer == null) - log.trace("Broker low on memory - could not allocate buffer of size {} for source {}", requestedBufferSize, source); + if (buffer == null) { + log.trace( + "Broker low on memory - could not allocate buffer of size {} for source {}", + requestedBufferSize, + source + ); + return read; + } } + + // Step 3: Read payload if (buffer != null) { int bytesRead = channel.read(buffer); if (bytesRead < 0) throw new EOFException(); + read += bytesRead; } @@ -124,31 +150,29 @@ public boolean memoryAllocated() { return buffer != null; } - @Override public void close() throws IOException { - if (buffer != null && buffer != EMPTY_BUFFER) { + if (buffer != null) { memoryPool.release(buffer); buffer = null; } } public ByteBuffer payload() { - return this.buffer; + return buffer; } public int bytesRead() { if (buffer == null) return size.position(); - return buffer.position() + size.position(); + return size.position() + buffer.position(); } /** * Returns the total size of the receive including payload and size buffer - * for use in metrics. This is consistent with {@link NetworkSend#size()} + * for use in metrics. This is consistent with {@link NetworkSend#size()}. */ public int size() { return payload().limit() + size.limit(); } - } diff --git a/clients/src/test/java/org/apache/kafka/common/network/NetworkReceiveTest.java b/clients/src/test/java/org/apache/kafka/common/network/NetworkReceiveTest.java index 80e7e9ce10..8d3f4e411b 100644 --- a/clients/src/test/java/org/apache/kafka/common/network/NetworkReceiveTest.java +++ b/clients/src/test/java/org/apache/kafka/common/network/NetworkReceiveTest.java @@ -16,6 +16,7 @@ */ package org.apache.kafka.common.network; +import org.apache.kafka.common.memory.MemoryPool; import org.apache.kafka.test.TestUtils; import org.junit.jupiter.api.Test; @@ -27,7 +28,6 @@ import java.nio.channels.ScatteringByteChannel; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; public class NetworkReceiveTest { @@ -35,39 +35,83 @@ public class NetworkReceiveTest { @Test public void testBytesRead() throws IOException { NetworkReceive receive = new NetworkReceive(128, "0"); - assertEquals(0, receive.bytesRead()); ScatteringByteChannel channel = Mockito.mock(ScatteringByteChannel.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuffer.class); + + Mockito.when(channel.read(captor.capture())) + .thenAnswer(invocation -> { + ByteBuffer buf = captor.getValue(); + if (buf.remaining() == 4) { + buf.putInt(128); + return 4; + } + int toWrite = Math.min(buf.remaining(), 64); + buf.put(TestUtils.randomBytes(toWrite)); + return toWrite; + }); + + int totalRead = 0; + while (!receive.complete()) { + totalRead += receive.readFrom(channel); + } - ArgumentCaptor bufferCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - Mockito.when(channel.read(bufferCaptor.capture())).thenAnswer(invocation -> { - bufferCaptor.getValue().putInt(128); - return 4; - }).thenReturn(0); - - assertEquals(4, receive.readFrom(channel)); - assertEquals(4, receive.bytesRead()); - assertFalse(receive.complete()); - - Mockito.reset(channel); - Mockito.when(channel.read(bufferCaptor.capture())).thenAnswer(invocation -> { - bufferCaptor.getValue().put(TestUtils.randomBytes(64)); - return 64; - }); - - assertEquals(64, receive.readFrom(channel)); - assertEquals(68, receive.bytesRead()); - assertFalse(receive.complete()); - - Mockito.reset(channel); - Mockito.when(channel.read(bufferCaptor.capture())).thenAnswer(invocation -> { - bufferCaptor.getValue().put(TestUtils.randomBytes(64)); - return 64; - }); - - assertEquals(64, receive.readFrom(channel)); assertEquals(132, receive.bytesRead()); + assertEquals(132, totalRead); assertTrue(receive.complete()); + + receive.close(); } + @Test + public void testZeroSizePayload() throws IOException { + NetworkReceive receive = new NetworkReceive("zero"); + + ScatteringByteChannel channel = Mockito.mock(ScatteringByteChannel.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuffer.class); + + Mockito.when(channel.read(captor.capture())) + .thenAnswer(invocation -> { + ByteBuffer buf = captor.getValue(); + if (buf.remaining() == 4) { + buf.putInt(0); + return 4; + } + return 0; + }); + + receive.readFrom(channel); + + assertTrue(receive.complete()); + assertEquals(0, receive.payload().remaining()); + + receive.close(); + } + + @Test + public void testWithMemoryPoolNone() throws IOException { + NetworkReceive receive = + new NetworkReceive(1024, "none", MemoryPool.NONE); + + ScatteringByteChannel channel = Mockito.mock(ScatteringByteChannel.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuffer.class); + + Mockito.when(channel.read(captor.capture())) + .thenAnswer(invocation -> { + ByteBuffer buf = captor.getValue(); + if (buf.remaining() == 4) { + buf.putInt(256); + return 4; + } + buf.put(TestUtils.randomBytes(buf.remaining())); + return buf.remaining(); + }); + + while (!receive.complete()) { + receive.readFrom(channel); + } + + assertTrue(receive.complete()); + receive.close(); + } }