Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,36 @@
import java.nio.channels.ScatteringByteChannel;

/**
* A size delimited Receive that consists of a 4 byte network-ordered size N followed by N bytes of content
* A size delimited Receive that consists of a 4 byte network-ordered size N
* followed by N bytes of content.
*/
public class NetworkReceive implements Receive {

public static final String UNKNOWN_SOURCE = "";
public static final int UNLIMITED = -1;

private static final Logger log = LoggerFactory.getLogger(NetworkReceive.class);
private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0);

private final String source;
private final ByteBuffer size;
private final int maxSize;
private final MemoryPool memoryPool;

private final ByteBuffer size;
private int requestedBufferSize = -1;
private ByteBuffer buffer;


public NetworkReceive(String source, ByteBuffer buffer) {
this(UNLIMITED, source);
this.buffer = buffer;
public NetworkReceive() {
this(UNKNOWN_SOURCE);
}

public NetworkReceive(String source) {
this(UNLIMITED, source);
this(UNLIMITED, source, MemoryPool.NONE);
}

public NetworkReceive(String source, ByteBuffer buffer) {
this(UNLIMITED, source, MemoryPool.NONE);
this.buffer = buffer;
}

public NetworkReceive(int maxSize, String source) {
Expand All @@ -59,14 +65,9 @@ public NetworkReceive(int maxSize, String source) {

public NetworkReceive(int maxSize, String source, MemoryPool memoryPool) {
this.source = source;
this.size = ByteBuffer.allocate(4);
this.buffer = null;
this.maxSize = maxSize;
this.memoryPool = memoryPool;
}

public NetworkReceive() {
this(UNKNOWN_SOURCE);
this.size = ByteBuffer.allocate(4);
}

@Override
Expand All @@ -79,35 +80,60 @@ public boolean complete() {
return !size.hasRemaining() && buffer != null && !buffer.hasRemaining();
}

@Override
public long readFrom(ScatteringByteChannel channel) throws IOException {
int read = 0;

// Step 1: Read the 4-byte size prefix
if (size.hasRemaining()) {
int bytesRead = channel.read(size);
if (bytesRead < 0)
throw new EOFException();

read += bytesRead;

if (!size.hasRemaining()) {
size.rewind();
int receiveSize = size.getInt();

if (receiveSize < 0)
throw new InvalidReceiveException("Invalid receive (size = " + receiveSize + ")");
throw new InvalidReceiveException(
"Invalid receive (size = " + receiveSize + ")"
);

if (maxSize != UNLIMITED && receiveSize > maxSize)
throw new InvalidReceiveException("Invalid receive (size = " + receiveSize + " larger than " + maxSize + ")");
requestedBufferSize = receiveSize; // may be 0 for some payloads (SASL)
throw new InvalidReceiveException(
"Invalid receive (size = " + receiveSize +
" larger than " + maxSize + ")"
);

requestedBufferSize = receiveSize;

if (receiveSize == 0) {
buffer = EMPTY_BUFFER;
}
}
}
if (buffer == null && requestedBufferSize != -1) { // we know the size we want but haven't been able to allocate it yet

// Step 2: Allocate payload buffer via MemoryPool
if (buffer == null && requestedBufferSize != -1) {
buffer = memoryPool.tryAllocate(requestedBufferSize);
if (buffer == null)
log.trace("Broker low on memory - could not allocate buffer of size {} for source {}", requestedBufferSize, source);
if (buffer == null) {
log.trace(
"Broker low on memory - could not allocate buffer of size {} for source {}",
requestedBufferSize,
source
);
return read;
}
}

// Step 3: Read payload
if (buffer != null) {
int bytesRead = channel.read(buffer);
if (bytesRead < 0)
throw new EOFException();

read += bytesRead;
}

Expand All @@ -124,31 +150,29 @@ public boolean memoryAllocated() {
return buffer != null;
}


@Override
public void close() throws IOException {
if (buffer != null && buffer != EMPTY_BUFFER) {
if (buffer != null) {
memoryPool.release(buffer);
buffer = null;
}
}

public ByteBuffer payload() {
return this.buffer;
return buffer;
}

public int bytesRead() {
if (buffer == null)
return size.position();
return buffer.position() + size.position();
return size.position() + buffer.position();
}

/**
* Returns the total size of the receive including payload and size buffer
* for use in metrics. This is consistent with {@link NetworkSend#size()}
* for use in metrics. This is consistent with {@link NetworkSend#size()}.
*/
public int size() {
return payload().limit() + size.limit();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.kafka.common.network;

import org.apache.kafka.common.memory.MemoryPool;
import org.apache.kafka.test.TestUtils;

import org.junit.jupiter.api.Test;
Expand All @@ -27,47 +28,90 @@
import java.nio.channels.ScatteringByteChannel;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class NetworkReceiveTest {

@Test
public void testBytesRead() throws IOException {
NetworkReceive receive = new NetworkReceive(128, "0");
assertEquals(0, receive.bytesRead());

ScatteringByteChannel channel = Mockito.mock(ScatteringByteChannel.class);
ArgumentCaptor<ByteBuffer> captor = ArgumentCaptor.forClass(ByteBuffer.class);

Mockito.when(channel.read(captor.capture()))
.thenAnswer(invocation -> {
ByteBuffer buf = captor.getValue();
if (buf.remaining() == 4) {
buf.putInt(128);
return 4;
}
int toWrite = Math.min(buf.remaining(), 64);
buf.put(TestUtils.randomBytes(toWrite));
return toWrite;
});

int totalRead = 0;
while (!receive.complete()) {
totalRead += receive.readFrom(channel);
}

ArgumentCaptor<ByteBuffer> bufferCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
Mockito.when(channel.read(bufferCaptor.capture())).thenAnswer(invocation -> {
bufferCaptor.getValue().putInt(128);
return 4;
}).thenReturn(0);

assertEquals(4, receive.readFrom(channel));
assertEquals(4, receive.bytesRead());
assertFalse(receive.complete());

Mockito.reset(channel);
Mockito.when(channel.read(bufferCaptor.capture())).thenAnswer(invocation -> {
bufferCaptor.getValue().put(TestUtils.randomBytes(64));
return 64;
});

assertEquals(64, receive.readFrom(channel));
assertEquals(68, receive.bytesRead());
assertFalse(receive.complete());

Mockito.reset(channel);
Mockito.when(channel.read(bufferCaptor.capture())).thenAnswer(invocation -> {
bufferCaptor.getValue().put(TestUtils.randomBytes(64));
return 64;
});

assertEquals(64, receive.readFrom(channel));
assertEquals(132, receive.bytesRead());
assertEquals(132, totalRead);
assertTrue(receive.complete());

receive.close();
}

@Test
public void testZeroSizePayload() throws IOException {
NetworkReceive receive = new NetworkReceive("zero");

ScatteringByteChannel channel = Mockito.mock(ScatteringByteChannel.class);
ArgumentCaptor<ByteBuffer> captor = ArgumentCaptor.forClass(ByteBuffer.class);

Mockito.when(channel.read(captor.capture()))
.thenAnswer(invocation -> {
ByteBuffer buf = captor.getValue();
if (buf.remaining() == 4) {
buf.putInt(0);
return 4;
}
return 0;
});

receive.readFrom(channel);

assertTrue(receive.complete());
assertEquals(0, receive.payload().remaining());

receive.close();
}

@Test
public void testWithMemoryPoolNone() throws IOException {
NetworkReceive receive =
new NetworkReceive(1024, "none", MemoryPool.NONE);

ScatteringByteChannel channel = Mockito.mock(ScatteringByteChannel.class);
ArgumentCaptor<ByteBuffer> captor = ArgumentCaptor.forClass(ByteBuffer.class);

Mockito.when(channel.read(captor.capture()))
.thenAnswer(invocation -> {
ByteBuffer buf = captor.getValue();
if (buf.remaining() == 4) {
buf.putInt(256);
return 4;
}
buf.put(TestUtils.randomBytes(buf.remaining()));
return buf.remaining();
});

while (!receive.complete()) {
receive.readFrom(channel);
}

assertTrue(receive.complete());
receive.close();
}
}
Loading