Skip to content
56 changes: 45 additions & 11 deletions src/main/java/io/termd/core/http/HttpTtyConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
import io.termd.core.function.BiConsumer;
import io.termd.core.function.Consumer;
import io.termd.core.io.BinaryDecoder;
import io.termd.core.io.BinaryEncoder;
import io.termd.core.tty.TtyConnectionSupport;
import io.termd.core.tty.TtyEvent;
import io.termd.core.tty.TtyEventDecoder;
import io.termd.core.tty.TtyOutputMode;
import io.termd.core.io.BufferBinaryEncoder;
import io.termd.core.tty.*;
import io.termd.core.util.Helper;
import io.termd.core.util.Vector;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;
import java.nio.charset.Charset;
import java.util.Map;

Expand All @@ -52,6 +52,7 @@
*
* @author <a href="mailto:julien@julienviet.com">Julien Viet</a>
* @author <a href="mailto:matejonnet@gmail.com">Matej Lazar</a>
* @author gongdewei 2020/05/20
*/
public abstract class HttpTtyConnection extends TtyConnectionSupport {

Expand All @@ -62,10 +63,12 @@ public abstract class HttpTtyConnection extends TtyConnectionSupport {
private Consumer<Vector> sizeHandler;
private final TtyEventDecoder eventDecoder;
private final BinaryDecoder decoder;
private final Consumer<int[]> stdout;
private final Consumer<IntBuffer> stdout;
private final Consumer<int[]> stdoutWrapper;
private Consumer<Void> closeHandler;
private Consumer<String> termHandler;
private long lastAccessedTime = System.currentTimeMillis();
private final IntBuffer codePointBuf = IntBuffer.allocate(8192);

public HttpTtyConnection() {
this(Charset.forName("UTF-8"), DEFAULT_SIZE);
Expand All @@ -76,12 +79,38 @@ public HttpTtyConnection(Charset charset, Vector size) {
this.size = size;
this.eventDecoder = new TtyEventDecoder(3, 26, 4);
this.decoder = new BinaryDecoder(512, charset, eventDecoder);
this.stdout = new TtyOutputMode(new BinaryEncoder(charset, new Consumer<byte[]>() {
this.stdout = new BufferTtyOutputMode(new BufferBinaryEncoder(charset, new Consumer<ByteBuffer>() {
@Override
public void accept(byte[] bytes) {
write(bytes);
public void accept(ByteBuffer data) {
write(data.array(), data.position(), data.remaining());
}
}));
this.stdoutWrapper = new Consumer<int[]>() {
@Override
public void accept(int[] data) {
stdout.accept(IntBuffer.wrap(data));
}
};
}

@Override
public TtyConnection write(String s) {
synchronized (this) {
int count = Helper.codePointCount(s);
IntBuffer buffer = null;
if (count <= codePointBuf.capacity()) {
buffer = codePointBuf;
} else {
buffer = IntBuffer.allocate(count);
}

buffer.clear();
Helper.toCodePoints(s, buffer);
buffer.flip();

stdout.accept(buffer);
return this;
}
}

@Override
Expand All @@ -104,7 +133,11 @@ public String terminalType() {
return "vt100";
}

protected abstract void write(byte[] buffer);
protected void write(byte[] buffer) {
this.write(buffer, 0, buffer.length);
}

protected abstract void write(byte[] buffer, int offset, int length);

/**
* Special case to handle tty events.
Expand Down Expand Up @@ -195,7 +228,8 @@ public void setStdinHandler(Consumer<int[]> handler) {
}

public Consumer<int[]> stdoutHandler() {
return stdout;
//TODO replace with Consumer<IntBuffer>
return stdoutWrapper;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,6 @@ protected void initChannel(SocketChannel ch) throws Exception {

pipeline.addLast(httpRequestHandler);
pipeline.addLast(new WebSocketServerProtocolHandler("/ws"));
pipeline.addLast(new TtyWebSocketFrameHandler(group, handler));
pipeline.addLast(new TtyWebSocketFrameHandler(group, handler, HttpRequestHandler.class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
package io.termd.core.http.netty;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.group.ChannelGroup;
Expand All @@ -26,6 +27,7 @@
import io.termd.core.function.Consumer;
import io.termd.core.http.HttpTtyConnection;
import io.termd.core.tty.TtyConnection;
import io.termd.core.util.ByteBufPool;

import java.util.concurrent.TimeUnit;

Expand All @@ -38,10 +40,21 @@ public class TtyWebSocketFrameHandler extends SimpleChannelInboundHandler<TextWe
private final Consumer<TtyConnection> handler;
private ChannelHandlerContext context;
private HttpTtyConnection conn;
private Class removingHandlerClass;
private final ByteBufPool byteBufPool;

public TtyWebSocketFrameHandler(ChannelGroup group, Consumer<TtyConnection> handler) {
/**
* Create TtyWebSocketFrameHandler
*
* @param group
* @param handler tty connection handler
* @param removingHandlerClass removing specify handler class after protocol upgrade
*/
public TtyWebSocketFrameHandler(ChannelGroup group, Consumer<TtyConnection> handler, Class removingHandlerClass) {
this.group = group;
this.handler = handler;
this.removingHandlerClass = removingHandlerClass;
this.byteBufPool = new ByteBufPool();
}

@Override
Expand All @@ -53,16 +66,50 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt == WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) {
ctx.pipeline().remove(HttpRequestHandler.class);
if (removingHandlerClass != null) {
ctx.pipeline().remove(removingHandlerClass);
}
group.add(ctx.channel());
conn = new HttpTtyConnection() {
@Override
protected void write(byte[] buffer) {
ByteBuf byteBuf = Unpooled.buffer();
byteBuf.writeBytes(buffer);
if (context != null) {
context.writeAndFlush(new TextWebSocketFrame(byteBuf));
protected void write(byte[] buffer, int offset, int length) {

int start = offset;
int remain = length;
while (remain > 0) {
if (context == null) {
break;
}

//ByteBuf byteBuf = PooledByteBufAllocator.DEFAULT.buffer(remain<=32?32: (remain<=64?64: byteBufSize));
final ByteBuf byteBuf = byteBufPool.get(50, TimeUnit.MILLISECONDS);
boolean done = false;
int size = 0;

try {
//write segment
size = Math.min(remain, byteBuf.writableBytes());
byteBuf.writeBytes(buffer, start, size);
if (context != null) {
context.writeAndFlush(new TextWebSocketFrame(byteBuf)).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
byteBufPool.put(byteBuf);
}
});
done = true;
}
} finally {
if (!done) {
//discard
byteBufPool.discard(byteBuf);
}
}

start += size;
remain -= size;
}

}

@Override
Expand All @@ -84,6 +131,7 @@ public void close() {
if (context != null) {
context.close();
}
byteBufPool.release();
}
};
handler.accept(conn);
Expand Down
127 changes: 127 additions & 0 deletions src/main/java/io/termd/core/io/BufferBinaryEncoder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Copyright 2015 Julien Viet
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.termd.core.io;

import io.termd.core.function.Consumer;

import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.IntBuffer;
import java.nio.charset.Charset;
import java.nio.charset.CharsetEncoder;
import java.nio.charset.CodingErrorAction;

/**
* @author <a href="mailto:julien@julienviet.com">Julien Viet</a>
* @author gongdewei 2020/05/19
*/
public class BufferBinaryEncoder implements Consumer<IntBuffer> {

private CharsetEncoder charsetEncoder;
private volatile Charset charset;
private final Consumer<ByteBuffer> onByte;
private final char[] charsBuf = new char[2];
private int capacity = 8192;
private ByteBuffer cachedByteBuffer;
private CharBuffer cachedCharBuffer;

public BufferBinaryEncoder(Charset charset, Consumer<ByteBuffer> onByte) {
this.setCharset(charset);
this.onByte = onByte;
}

/**
* Set a new charset on the encoder.
*
* @param charset the new charset
*/
public void setCharset(Charset charset) {
this.charset = charset;
charsetEncoder = charset.newEncoder()
.onMalformedInput(CodingErrorAction.REPLACE)
.onUnmappableCharacter(CodingErrorAction.REPLACE);

//check buffer
ensureBuffer();
}

@Override
public void accept(IntBuffer codePoints) {
int[] array = codePoints.array();
int offset = codePoints.position();
int limit = codePoints.limit();

int capacity = 0;
for (int i = offset; i < limit; i++) {
capacity += Character.charCount(array[i]);
}

//charsetEncoder/charsBuf/cachedBuffer are not thread-safe
synchronized (this) {
//convert code points to chars
CharBuffer charBuffer = getCharBuffer(capacity);
for (int i = offset; i < limit; i++) {
int size = Character.toChars(array[i], charsBuf, 0);
charBuffer.put(charsBuf, 0, size);
}
charBuffer.flip();

//encode chars to bytes
ByteBuffer byteBuffer = getByteBuffer(getByteCapacity(capacity));
charsetEncoder.encode(charBuffer, byteBuffer, true);
byteBuffer.flip();

onByte.accept(byteBuffer);
}
}

private CharBuffer getCharBuffer(int capacity) {
CharBuffer charBuffer = null;
if (capacity <= cachedCharBuffer.capacity()) {
charBuffer = cachedCharBuffer;
charBuffer.clear();
} else {
charBuffer = CharBuffer.allocate(capacity);
}
return charBuffer;
}

private ByteBuffer getByteBuffer(int capacity) {
ByteBuffer byteBuffer = null;
if (capacity <= cachedByteBuffer.capacity()) {
byteBuffer = cachedByteBuffer;
byteBuffer.clear();
} else {
byteBuffer = ByteBuffer.allocate(capacity);
}
return byteBuffer;
}

private void ensureBuffer() {
if (cachedCharBuffer ==null || cachedCharBuffer.limit() != capacity) {
cachedCharBuffer = CharBuffer.allocate(capacity);
}
int byteBufCapacity = getByteCapacity(capacity);
if (cachedByteBuffer ==null || cachedByteBuffer.limit() != byteBufCapacity) {
cachedByteBuffer = ByteBuffer.allocate(byteBufCapacity);
}
}

private int getByteCapacity(int capacity) {
return (int) (capacity *charsetEncoder.averageBytesPerChar());
}
}
Loading