diff options
Diffstat (limited to 'src/main/java/com/orbekk/protobuf/RpcChannel.java')
-rw-r--r-- | src/main/java/com/orbekk/protobuf/RpcChannel.java | 57 |
1 files changed, 34 insertions, 23 deletions
diff --git a/src/main/java/com/orbekk/protobuf/RpcChannel.java b/src/main/java/com/orbekk/protobuf/RpcChannel.java index 56b54c2..94ab8a5 100644 --- a/src/main/java/com/orbekk/protobuf/RpcChannel.java +++ b/src/main/java/com/orbekk/protobuf/RpcChannel.java @@ -7,6 +7,8 @@ import java.net.UnknownHostException; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; @@ -28,19 +30,20 @@ public class RpcChannel extends Thread implements private Map<Long, RpcChannel.OngoingRequest> rpcs = Collections.synchronizedMap( new HashMap<Long, RpcChannel.OngoingRequest>()); + private BlockingQueue<Socket> sockets = new LinkedBlockingQueue<Socket>(); private static class OngoingRequest implements Closeable { long id; - RpcController controller; + Rpc rpc; RpcCallback<Message> done; Message responsePrototype; Map<Long, RpcChannel.OngoingRequest> rpcs; - public OngoingRequest(long id, RpcController controller, + public OngoingRequest(long id, Rpc rpc, RpcCallback<Message> done, Message responsePrototype, Map<Long, RpcChannel.OngoingRequest> rpcs) { this.id = id; - this.controller = controller; + this.rpc = rpc; this.done = done; this.responsePrototype = responsePrototype; this.rpcs = rpcs; @@ -67,10 +70,8 @@ public class RpcChannel extends Thread implements if (socket == null || socket.isClosed()) { try { logger.info("Creating new socket to " + host + ":" + port); - synchronized (this) { - socket = new Socket(host, port); - notify(); - } + socket = new Socket(host, port); + sockets.add(socket); } catch (UnknownHostException e) { return null; } catch (IOException e) { @@ -82,17 +83,18 @@ public class RpcChannel extends Thread implements return socket; } - private Rpc.Request createRequest(Descriptors.MethodDescriptor method, + private Data.Request createRequest(Descriptors.MethodDescriptor method, RpcController controller, Message requestMessage, Message responsePrototype, RpcCallback<Message> done) { long id = nextId.incrementAndGet(); - OngoingRequest ongoingRequest = new OngoingRequest(id, controller, + Rpc rpc = (Rpc)controller; + OngoingRequest ongoingRequest = new OngoingRequest(id, rpc, done, responsePrototype, rpcs); rpcs.put(id, ongoingRequest); - Rpc.Request request = Rpc.Request.newBuilder() + Data.Request request = Data.Request.newBuilder() .setRequestId(id) .setFullServiceName(method.getService().getFullName()) .setMethodName(method.getName()) @@ -102,12 +104,13 @@ public class RpcChannel extends Thread implements return request; } - private void finishRequest(Rpc.Response response) { + private void finishRequest(Data.Response response) { OngoingRequest ongoingRequest = rpcs.remove(response.getRequestId()); if (ongoingRequest != null) { try { Message responsePb = ongoingRequest.responsePrototype.toBuilder() .mergeFrom(response.getResponseProto()).build(); + ongoingRequest.rpc.readFrom(response); ongoingRequest.done.run(responsePb); } catch (InvalidProtocolBufferException e) { throw new AssertionError("Should fail here."); @@ -122,7 +125,7 @@ public class RpcChannel extends Thread implements Message responsePrototype, RpcCallback<Message> done) { try { - Rpc.Request request = createRequest(method, controller, + Data.Request request = createRequest(method, controller, requestMessage, responsePrototype, done); Socket socket = getSocket(); request.writeDelimitedTo(socket.getOutputStream()); @@ -135,13 +138,17 @@ public class RpcChannel extends Thread implements try { logger.info("Handling responses to socket " + socket); while (!socket.isClosed()) { - Rpc.Response response; - response = Rpc.Response.parseDelimitedFrom( + Data.Response response; + response = Data.Response.parseDelimitedFrom( socket.getInputStream()); finishRequest(response); } } catch (IOException e) { - // Breaks the loop. + if (!rpcs.isEmpty()) { + logger.log(Level.WARNING, "IO Error. Canceling " + + rpcs.size() + " requests.", e); + cancelAllRpcs(); + } } finally { if (socket != null && !socket.isClosed()) { try { @@ -153,17 +160,21 @@ public class RpcChannel extends Thread implements } } + private void cancelAllRpcs() { + synchronized (rpcs) { + for (OngoingRequest request : rpcs.values()) { + request.rpc.setFailed("connection closed"); + request.done.run(null); + } + rpcs.clear(); + } + } + public void run() { while (!Thread.interrupted()) { try { - synchronized (this) { - if (socket == null) { - wait(); - } - } - if (socket != null) { - handleResponses(socket); - } + Socket socket = sockets.take(); + handleResponses(socket); } catch (InterruptedException e) { // Interrupts handled by outer loop } |