Validate response types

This commit is contained in:
s809
2023-09-19 20:19:25 +05:00
parent fccd197cd5
commit a80bb082b3
6 changed files with 39 additions and 27 deletions
@@ -88,7 +88,7 @@ public class RemoteFullDataFileHandler extends GeneratedFullDataFileHandler
this.writeChunkDataToFile(new DhSectionPos(pos.getDetailLevel(), pos.getX(), pos.getZ()), data);
};
this.networkState.getClient().<FullDataChangeSummaryResponseMessage>sendRequest(new FullDataChangeSummaryRequestMessage(level.getLevelWrapper(), block))
this.networkState.getClient().sendRequest(new FullDataChangeSummaryRequestMessage(level.getLevelWrapper(), block), FullDataChangeSummaryResponseMessage.class)
.handle((response, throwable) ->
{
try
@@ -101,7 +101,7 @@ public class WorldRemoteGenerationQueue implements IWorldGenerationQueue, IDebug
return;
};
CompletableFuture<GenTaskPriorityResponseMessage> request = this.networkState.getClient().sendRequest(new GenTaskPriorityRequestMessage(posList));
CompletableFuture<GenTaskPriorityResponseMessage> request = this.networkState.getClient().sendRequest(new GenTaskPriorityRequestMessage(posList), GenTaskPriorityResponseMessage.class);
genTaskPriorityRequest = request;
request.handleAsync((response, throwable) -> {
try
@@ -163,7 +163,7 @@ public class WorldRemoteGenerationQueue implements IWorldGenerationQueue, IDebug
DhSectionPos sectionPos = mapEntry.getKey();
WorldGenQueueEntry entry = mapEntry.getValue();
CompletableFuture<FullDataSourceResponseMessage> request = this.networkState.getClient().sendRequest(new FullDataSourceRequestMessage(sectionPos));
CompletableFuture<FullDataSourceResponseMessage> request = this.networkState.getClient().sendRequest(new FullDataSourceRequestMessage(sectionPos), FullDataSourceResponseMessage.class);
entry.request = request;
request.handleAsync((response, throwable) ->
{
@@ -47,12 +47,12 @@ public class ClientNetworkState implements Closeable
{
LOGGER.info("Connected to server: "+helloMessage.getChannelContext().channel().remoteAddress());
this.getClient().<AckMessage>sendRequest(new PlayerUUIDMessage(playerUUID))
.thenCompose(ack -> this.getClient().<RemotePlayerConfigMessage>sendRequest(new RemotePlayerConfigMessage(new MultiplayerConfig()
this.getClient().sendRequest(new PlayerUUIDMessage(playerUUID), AckMessage.class)
.thenCompose(ack -> this.getClient().sendRequest(new RemotePlayerConfigMessage(new MultiplayerConfig()
{{
renderDistance = Config.Client.Advanced.Graphics.Quality.lodChunkRenderDistance.get();
fullDataRequestRateLimit = Config.Client.Advanced.Multiplayer.serverNetworkingRateLimit.get();
}})))
}}), RemotePlayerConfigMessage.class))
.thenAccept(msg -> {
this.config = msg.payload;
})
@@ -168,9 +168,9 @@ public class NetworkClient extends NetworkEventSource implements AutoCloseable
});
}
public final <TResponse extends FutureTrackableNetworkMessage> CompletableFuture<TResponse> sendRequest(FutureTrackableNetworkMessage msg)
public final <TResponse extends FutureTrackableNetworkMessage> CompletableFuture<TResponse> sendRequest(FutureTrackableNetworkMessage msg, Class<TResponse> responseClass)
{
return this.sendRequest(this.channel.pipeline().context(MessageHandler.class), msg);
return this.sendRequest(this.channel.pipeline().context(MessageHandler.class), msg, responseClass);
}
@Override
@@ -19,9 +19,6 @@
package com.seibel.distanthorizons.core.network;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import com.google.common.collect.Tables;
import com.seibel.distanthorizons.core.logging.DhLoggerBuilder;
import com.seibel.distanthorizons.core.network.messages.base.CancelMessage;
import com.seibel.distanthorizons.core.network.messages.base.CloseEvent;
@@ -33,6 +30,7 @@ import io.netty.channel.ChannelException;
import io.netty.channel.ChannelHandlerContext;
import org.apache.logging.log4j.Logger;
import java.io.InvalidClassException;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CancellationException;
@@ -45,7 +43,7 @@ public abstract class NetworkEventSource
{
private static final Logger LOGGER = DhLoggerBuilder.getLogger();
protected final ConcurrentMap<Class<? extends NetworkMessage>, Set<Consumer<NetworkMessage>>> handlers = new ConcurrentHashMap<>();
private final ConcurrentMap<ChannelHandlerContext, ConcurrentMap<Long, CompletableFuture<FutureTrackableNetworkMessage>>> pendingFutures = new ConcurrentHashMap<>();
private final ConcurrentMap<ChannelHandlerContext, ConcurrentMap<Long, FutureResponseData>> pendingFutures = new ConcurrentHashMap<>();
protected boolean hasHandler(Class<? extends NetworkMessage> handlerClass)
{
@@ -70,18 +68,20 @@ public abstract class NetworkEventSource
if (message instanceof FutureTrackableNetworkMessage)
{
FutureTrackableNetworkMessage trackableMessage = (FutureTrackableNetworkMessage)message;
ConcurrentMap<Long, CompletableFuture<FutureTrackableNetworkMessage>> subMap = pendingFutures.get(message.getChannelContext());
ConcurrentMap<Long, FutureResponseData> subMap = pendingFutures.get(message.getChannelContext());
if (subMap != null)
{
CompletableFuture<FutureTrackableNetworkMessage> future = subMap.get(trackableMessage.futureId);
if (future != null)
FutureResponseData responseData = subMap.get(trackableMessage.futureId);
if (responseData != null)
{
handled = true;
if (message instanceof ExceptionMessage)
future.completeExceptionally(((ExceptionMessage) message).exception);
responseData.future.completeExceptionally(((ExceptionMessage) message).exception);
else if (message.getClass() != responseData.responseClass)
responseData.future.completeExceptionally(new InvalidClassException("Response with invalid type: expected " + responseData.responseClass.getSimpleName() + ", got:" + message));
else
future.complete(trackableMessage);
responseData.future.complete(trackableMessage);
}
}
}
@@ -97,6 +97,7 @@ public abstract class NetworkEventSource
public <T extends NetworkMessage> void registerHandler(Class<T> handlerClass, Consumer<T> handlerImplementation)
{
//noinspection unchecked
this.handlers.computeIfAbsent(handlerClass, missingHandlerClass ->
{
// Will throw if the handler class is not found
@@ -114,7 +115,7 @@ public abstract class NetworkEventSource
}
protected <TResponse extends FutureTrackableNetworkMessage> CompletableFuture<TResponse> sendRequest(ChannelHandlerContext ctx, FutureTrackableNetworkMessage msg)
protected <TResponse extends FutureTrackableNetworkMessage> CompletableFuture<TResponse> sendRequest(ChannelHandlerContext ctx, FutureTrackableNetworkMessage msg, Class<TResponse> responseClass)
{
msg.setChannelContext(ctx);
@@ -122,7 +123,7 @@ public abstract class NetworkEventSource
responseFuture.handle((response, throwable) -> {
if (!(throwable instanceof ChannelException))
{
ConcurrentMap<Long, CompletableFuture<FutureTrackableNetworkMessage>> subMap = pendingFutures.get(ctx);
ConcurrentMap<Long, FutureResponseData> subMap = pendingFutures.get(ctx);
if (subMap != null)
subMap.remove(msg.futureId);
}
@@ -133,14 +134,13 @@ public abstract class NetworkEventSource
return null;
});
ConcurrentMap<Long, CompletableFuture<FutureTrackableNetworkMessage>> subMap = pendingFutures.get(ctx);
ConcurrentMap<Long, FutureResponseData> subMap = pendingFutures.get(ctx);
if (subMap == null) {
// Was deleted before adding
responseFuture.completeExceptionally(ctx.channel().closeFuture().cause());
return responseFuture;
}
//noinspection unchecked
subMap.put(msg.futureId, (CompletableFuture<FutureTrackableNetworkMessage>) responseFuture);
subMap.put(msg.futureId, new FutureResponseData(responseClass, responseFuture));
if (!pendingFutures.containsKey(ctx)) {
// Was deleted while adding
responseFuture.completeExceptionally(ctx.channel().closeFuture().cause());
@@ -158,11 +158,11 @@ public abstract class NetworkEventSource
protected final void completeAllFuturesExceptionally(ChannelHandlerContext ctx, Throwable cause)
{
ConcurrentMap<Long, CompletableFuture<FutureTrackableNetworkMessage>> map = pendingFutures.remove(ctx);
ConcurrentMap<Long, FutureResponseData> map = pendingFutures.remove(ctx);
if (map == null) return;
for (CompletableFuture<FutureTrackableNetworkMessage> future : map.values())
future.completeExceptionally(cause);
for (FutureResponseData responseData : map.values())
responseData.future.completeExceptionally(cause);
}
protected final void completeAllFuturesExceptionally(Throwable cause)
@@ -176,4 +176,16 @@ public abstract class NetworkEventSource
this.handlers.clear();
completeAllFuturesExceptionally(new ChannelException(this.getClass().getSimpleName()+" is closed."));
}
private static class FutureResponseData
{
public final Class<? extends FutureTrackableNetworkMessage> responseClass;
public final CompletableFuture<FutureTrackableNetworkMessage> future;
private <T extends FutureTrackableNetworkMessage> FutureResponseData(Class<T> responseClass, CompletableFuture<T> future) {
this.responseClass = responseClass;
//noinspection unchecked
this.future = (CompletableFuture<FutureTrackableNetworkMessage>) future;
}
}
}
@@ -124,9 +124,9 @@ public class NetworkServer extends NetworkEventSource implements AutoCloseable
}
@Override
public <TResponse extends FutureTrackableNetworkMessage> CompletableFuture<TResponse> sendRequest(ChannelHandlerContext ctx, FutureTrackableNetworkMessage msg)
public <TResponse extends FutureTrackableNetworkMessage> CompletableFuture<TResponse> sendRequest(ChannelHandlerContext ctx, FutureTrackableNetworkMessage msg, Class<TResponse> responseClass)
{
return super.sendRequest(ctx, msg);
return super.sendRequest(ctx, msg, responseClass);
}
@Override