Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,7 @@
<include>src/test/java/com/rabbitmq/client/test/server/Permissions.java</include>
<include>src/test/java/com/rabbitmq/client/test/PublishWithByteBufferTest.java</include>
<include>src/test/java/com/rabbitmq/client/test/ByteBufferPublishTest.java</include>
<include>src/test/java/com/rabbitmq/client/test/InboundFrameMax.java</include>
</includes>
<googleJavaFormat>
<version>${google-java-format.version}</version>
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/rabbitmq/client/ConnectionFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -1084,7 +1084,7 @@ protected synchronized FrameHandlerFactory createFrameHandlerFactory() throws IO
this.nettyConf.enqueuingTimeout,
connectionTimeout,
socketConf,
maxInboundMessageBodySize,
AMQP.FRAME_MIN_SIZE,
this.automaticRecovery,
recoveryCondition);
}
Expand All @@ -1097,7 +1097,7 @@ protected synchronized FrameHandlerFactory createFrameHandlerFactory() throws IO
isSSL(),
this.shutdownExecutor,
sslContextFactory,
this.maxInboundMessageBodySize);
AMQP.FRAME_MIN_SIZE);
}
}

Expand Down
7 changes: 2 additions & 5 deletions src/main/java/com/rabbitmq/client/impl/AMQConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -435,11 +435,8 @@ public void start()

// Inbound payload limit: the smaller of frame_max (less framing
// overhead) and the configured message body cap.
if (frameMax > 0) {
_frameHandler.setMaxInboundFramePayloadSize(
Math.min(this.maxInboundMessageBodySize,
frameMax - AMQCommand.EMPTY_FRAME_SIZE + 1));
}
_frameHandler.setFrameMax(
Math.min(this.maxInboundMessageBodySize, frameMax));

int negotiatedHeartbeat =
negotiatedMaxValue(this.requestedHeartbeat,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ public abstract class AbstractFrameHandlerFactory implements FrameHandlerFactory
protected final int connectionTimeout;
protected final SocketConfigurator configurator;
protected final boolean ssl;
protected final int maxInboundMessageBodySize;
protected final int frameMax;

protected AbstractFrameHandlerFactory(int connectionTimeout, SocketConfigurator configurator,
boolean ssl, int maxInboundMessageBodySize) {
boolean ssl, int frameMax) {
this.connectionTimeout = connectionTimeout;
this.configurator = configurator;
this.ssl = ssl;
this.maxInboundMessageBodySize = maxInboundMessageBodySize;
this.frameMax = frameMax;
}
}
17 changes: 6 additions & 11 deletions src/main/java/com/rabbitmq/client/impl/Frame.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.function.IntSupplier;

import static java.lang.String.format;

/**
Expand Down Expand Up @@ -106,7 +108,7 @@ public static Frame fromBodyFragment(int channelNumber, ByteBuffer body, int off
*
* @return a new Frame if we read a frame successfully, otherwise null
*/
public static Frame readFrom(DataInputStream is, int maxPayloadSize) throws IOException {
public static Frame readFrom(DataInputStream is, IntSupplier payloadLimit) throws IOException {
int type;
int channel;

Expand All @@ -130,16 +132,9 @@ public static Frame readFrom(DataInputStream is, int maxPayloadSize) throws IOEx
}

channel = is.readUnsignedShort();
int payloadSize = is.readInt();
if (payloadSize < 0 || payloadSize >= maxPayloadSize) {
throw new MalformedFrameException(format(
"Frame body size is invalid (%d), maximum configured size is %d. " +
"See ConnectionFactory#setMaxInboundMessageBodySize " +
"if you need to increase the limit.",
payloadSize, maxPayloadSize
));
}
byte[] payload = new byte[payloadSize];
int frameSize = is.readInt();
Utils.enforceFrameMax(frameSize, payloadLimit.getAsInt());
byte[] payload = new byte[frameSize];
is.readFully(payload);

int frameEndMarker = is.readUnsignedByte();
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/com/rabbitmq/client/impl/FrameHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ default void finishConnectionNegotiation() {
}

/** Cap inbound frame payloads, applied once frame_max is negotiated. */
default void setMaxInboundFramePayloadSize(int maxPayloadSize) {
default void setFrameMax(int frameMax) {

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
// info@rabbitmq.com.
package com.rabbitmq.client.impl;

import static java.lang.String.format;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;

Expand All @@ -32,6 +31,7 @@
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
Expand Down Expand Up @@ -89,10 +89,10 @@ public NettyFrameHandlerFactory(
Duration enqueuingTimeout,
int connectionTimeout,
SocketConfigurator configurator,
int maxInboundMessageBodySize,
int frameMax,
boolean automaticRecovery,
Predicate<ShutdownSignalException> recoveryCondition) {
super(connectionTimeout, configurator, sslContextFactory != null, maxInboundMessageBodySize);
super(connectionTimeout, configurator, sslContextFactory != null, frameMax);
this.eventLoopGroup = eventLoopGroup;
this.sslContextFactory = sslContextFactory == null ? connName -> null : sslContextFactory;
this.channelCustomizer = channelCustomizer == null ? Utils.noOpConsumer() : channelCustomizer;
Expand Down Expand Up @@ -151,7 +151,7 @@ private static void closeNettyState(Channel channel, EventLoopGroup eventLoopGro
public FrameHandler create(Address addr, String connectionName) throws IOException {
SslContext sslContext = this.sslContextFactory.apply(connectionName);
return new NettyFrameHandler(
this.maxInboundMessageBodySize,
this.frameMax,
addr,
sslContext,
this.eventLoopGroup,
Expand Down Expand Up @@ -191,7 +191,7 @@ private static final class NettyFrameHandler implements FrameHandler {
private final int zeroCopyThreshold = 1024;

private NettyFrameHandler(
int maxInboundMessageBodySize,
int frameMax,
Address addr,
SslContext sslContext,
EventLoopGroup elg,
Expand Down Expand Up @@ -232,13 +232,7 @@ private NettyFrameHandler(
b.option(ChannelOption.ALLOCATOR, Utils.byteBufAllocator());
}

// type + channel + payload size + payload + frame end marker
int maxFrameLength = 1 + 2 + 4 + maxInboundMessageBodySize + 1;
int lengthFieldOffset = 3;
int lengthFieldLength = 4;
int lengthAdjustement = 1;
AmqpHandler amqpHandler =
new AmqpHandler(maxInboundMessageBodySize, this::close, willRecover);
AmqpHandler amqpHandler = new AmqpHandler(frameMax, this::close, willRecover);
int port = ConnectionFactory.portOrDefault(addr.getPort(), sslContext != null);
b.handler(
new ChannelInitializer<SocketChannel>() {
Expand All @@ -251,15 +245,7 @@ public void initChannel(SocketChannel ch) {
FlushConsolidationHandler.DEFAULT_EXPLICIT_FLUSH_AFTER_FLUSHES, true));
ch.pipeline()
.addLast(HANDLER_PROTOCOL_VERSION_MISMATCH, new ProtocolVersionMismatchHandler());
ch.pipeline()
.addLast(
HANDLER_FRAME_DECODER,
new LengthFieldBasedFrameDecoder(
maxFrameLength,
lengthFieldOffset,
lengthFieldLength,
lengthAdjustement,
0));
ch.pipeline().addLast(HANDLER_FRAME_DECODER, createFrameDecoder(frameMax));
ch.pipeline().addLast(AmqpHandler.class.getSimpleName(), amqpHandler);
if (sslContext != null) {
SslHandler sslHandler = sslContext.newHandler(ch.alloc(), addr.getHost(), port);
Expand Down Expand Up @@ -352,8 +338,11 @@ public void finishConnectionNegotiation() {
}

@Override
public void setMaxInboundFramePayloadSize(int maxPayloadSize) {
this.handler.maxPayloadSize = maxPayloadSize;
public void setFrameMax(int frameMax) {
this.channel
.pipeline()
.replace(HANDLER_FRAME_DECODER, HANDLER_FRAME_DECODER, createFrameDecoder(frameMax));
this.handler.setFrameMax(frameMax);
}

@Override
Expand Down Expand Up @@ -507,11 +496,19 @@ InetSocketAddress maybeInetSocketAddress(SocketAddress socketAddress) {
return null;
}
}

private ChannelHandler createFrameDecoder(int frameMax) {
int lengthFieldOffset = 3;
int lengthFieldLength = 4;
int lengthAdjustement = 1; // frame-end byte
return new LengthFieldBasedFrameDecoder(
frameMax, lengthFieldOffset, lengthFieldLength, lengthAdjustement, 0);
}
}

private static class AmqpHandler extends ChannelInboundHandlerAdapter {

private volatile int maxPayloadSize;
private volatile int framePayloadLimit;
private final Runnable closeSequence;
private final Predicate<ShutdownSignalException> willRecover;
private volatile AMQConnection connection;
Expand All @@ -524,15 +521,20 @@ private static class AmqpHandler extends ChannelInboundHandlerAdapter {
private final String id;

private AmqpHandler(
int maxPayloadSize,
Runnable closeSequence,
Predicate<ShutdownSignalException> willRecover) {
this.maxPayloadSize = maxPayloadSize;
int frameMax, Runnable closeSequence, Predicate<ShutdownSignalException> willRecover) {
this.setFrameMax(frameMax);
this.closeSequence = closeSequence;
this.willRecover = willRecover;
this.id = "amqp-handler-" + SEQUENCE.getAndIncrement();
}

private void setFrameMax(int frameMax) {
if (frameMax > 0 && frameMax < AMQP.FRAME_MIN_SIZE) {
frameMax = AMQP.FRAME_MIN_SIZE;
}
this.framePayloadLimit = Utils.framePayloadLimit(frameMax);
}

@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
this.ch = ctx.channel();
Expand All @@ -545,17 +547,10 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
try {
int type = m.readUnsignedByte();
int channel = m.readUnsignedShort();
int payloadSize = m.readInt();
if (payloadSize < 0 || payloadSize >= maxPayloadSize) {
throw new MalformedFrameException(
format(
"Frame body size is invalid (%d), maximum configured size is %d. "
+ "See ConnectionFactory#setMaxInboundMessageBodySize "
+ "if you need to increase the limit.",
payloadSize, maxPayloadSize));
}
int frameSize = m.readInt();
Utils.enforceFrameMax(frameSize, this.framePayloadLimit);

byte[] payload = new byte[payloadSize];
byte[] payload = new byte[frameSize];
m.readBytes(payload);

int frameEndMarker = m.readUnsignedByte();
Expand Down
19 changes: 13 additions & 6 deletions src/main/java/com/rabbitmq/client/impl/SocketFrameHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.IntSupplier;

/**
* A socket-based frame handler.
Expand All @@ -57,7 +58,8 @@ public class SocketFrameHandler implements FrameHandler {
private final DataOutputStream _outputStream;
private final Lock _outputStreamLock = new ReentrantLock();

private volatile int maxInboundMessageBodySize;
private volatile int framePayloadLimit;
private final IntSupplier payloadLimitSupplier;

/** Time to linger before closing the socket forcefully. */
public static final int SOCKET_CLOSING_TIMEOUT = 1;
Expand All @@ -73,13 +75,14 @@ public SocketFrameHandler(Socket socket) throws IOException {
* @param socket the socket to use
*/
public SocketFrameHandler(Socket socket, ExecutorService shutdownExecutor,
int maxInboundMessageBodySize) throws IOException {
int frameMax) throws IOException {
_socket = socket;
_shutdownExecutor = shutdownExecutor;
this.maxInboundMessageBodySize = maxInboundMessageBodySize;
this.setFrameMax(frameMax);

_inputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream()));
_outputStream = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream()));
this.payloadLimitSupplier = () -> this.framePayloadLimit;
}

@Override
Expand Down Expand Up @@ -194,15 +197,19 @@ public void initialize(AMQConnection connection) {
}

@Override
public void setMaxInboundFramePayloadSize(int maxPayloadSize) {
this.maxInboundMessageBodySize = maxPayloadSize;
public void setFrameMax(int frameMax) {
this.framePayloadLimit = Utils.framePayloadLimit(frameMax);
}

@Override
public Frame readFrame() throws IOException {
_inputStreamLock.lock();
try {
return Frame.readFrom(_inputStream, this.maxInboundMessageBodySize);
// we need to check frameMax against the latest value, hence the supplier
// otherwise we can start waiting for a new frame with the current limit
// and the frameMax is changed while we wait, so the next frame is checked against
// a stale value
return Frame.readFrom(_inputStream, this.payloadLimitSupplier);
} finally {
_inputStreamLock.unlock();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ protected Socket createSocket(String connectionName) throws IOException {

public FrameHandler create(Socket sock) throws IOException
{
return new SocketFrameHandler(sock, this.shutdownExecutor, this.maxInboundMessageBodySize);
return new SocketFrameHandler(sock, this.shutdownExecutor, this.frameMax);
}

private static void quietTrySocketClose(Socket socket) {
Expand Down
29 changes: 29 additions & 0 deletions src/main/java/com/rabbitmq/client/impl/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@

package com.rabbitmq.client.impl;

import com.rabbitmq.client.AMQP;
import com.rabbitmq.client.MalformedFrameException;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.MultiThreadIoEventLoopGroup;
import io.netty.channel.nio.NioIoHandler;

import java.util.function.Consumer;

import static java.lang.String.format;

final class Utils {

@SuppressWarnings("rawtypes")
Expand Down Expand Up @@ -69,4 +73,29 @@ static ByteBufAllocator byteBufAllocator() {
static <T> Consumer<T> noOpConsumer() {
return (Consumer<T>) NO_OP_CONSUMER;
}

static int framePayloadLimit(int frameMax) {
if (frameMax <= 0) {
return Integer.MAX_VALUE;
} else if (frameMax < AMQP.FRAME_MIN_SIZE) {
return AMQP.FRAME_MIN_SIZE - AMQCommand.EMPTY_FRAME_SIZE;
} else {
return frameMax - AMQCommand.EMPTY_FRAME_SIZE;
}
}

static void enforceFrameMax(int framePayloadSize, int framePayloadLimit) throws MalformedFrameException {
if (framePayloadSize < 0 || framePayloadSize > framePayloadLimit) {
throw new MalformedFrameException(
format(
"Frame size is invalid (%d), maximum configured size is %d. "
+ "See ConnectionFactory#setMaxInboundMessageBodySize "
+ "if you need to increase the limit.",
frameSizeFromPayloadSize(framePayloadSize), frameSizeFromPayloadSize(framePayloadLimit)));
}
}

private static int frameSizeFromPayloadSize(int limit) {
return limit + AMQCommand.EMPTY_FRAME_SIZE;
}
}
Loading