summaryrefslogtreecommitdiff
path: root/src/main/java/com/orbekk/protobuf/RpcChannel.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/com/orbekk/protobuf/RpcChannel.java')
-rw-r--r--src/main/java/com/orbekk/protobuf/RpcChannel.java57
1 files changed, 34 insertions, 23 deletions
diff --git a/src/main/java/com/orbekk/protobuf/RpcChannel.java b/src/main/java/com/orbekk/protobuf/RpcChannel.java
index 56b54c2..94ab8a5 100644
--- a/src/main/java/com/orbekk/protobuf/RpcChannel.java
+++ b/src/main/java/com/orbekk/protobuf/RpcChannel.java
@@ -7,6 +7,8 @@ import java.net.UnknownHostException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -28,19 +30,20 @@ public class RpcChannel extends Thread implements
private Map<Long, RpcChannel.OngoingRequest> rpcs =
Collections.synchronizedMap(
new HashMap<Long, RpcChannel.OngoingRequest>());
+ private BlockingQueue<Socket> sockets = new LinkedBlockingQueue<Socket>();
private static class OngoingRequest implements Closeable {
long id;
- RpcController controller;
+ Rpc rpc;
RpcCallback<Message> done;
Message responsePrototype;
Map<Long, RpcChannel.OngoingRequest> rpcs;
- public OngoingRequest(long id, RpcController controller,
+ public OngoingRequest(long id, Rpc rpc,
RpcCallback<Message> done, Message responsePrototype,
Map<Long, RpcChannel.OngoingRequest> rpcs) {
this.id = id;
- this.controller = controller;
+ this.rpc = rpc;
this.done = done;
this.responsePrototype = responsePrototype;
this.rpcs = rpcs;
@@ -67,10 +70,8 @@ public class RpcChannel extends Thread implements
if (socket == null || socket.isClosed()) {
try {
logger.info("Creating new socket to " + host + ":" + port);
- synchronized (this) {
- socket = new Socket(host, port);
- notify();
- }
+ socket = new Socket(host, port);
+ sockets.add(socket);
} catch (UnknownHostException e) {
return null;
} catch (IOException e) {
@@ -82,17 +83,18 @@ public class RpcChannel extends Thread implements
return socket;
}
- private Rpc.Request createRequest(Descriptors.MethodDescriptor method,
+ private Data.Request createRequest(Descriptors.MethodDescriptor method,
RpcController controller,
Message requestMessage,
Message responsePrototype,
RpcCallback<Message> done) {
long id = nextId.incrementAndGet();
- OngoingRequest ongoingRequest = new OngoingRequest(id, controller,
+ Rpc rpc = (Rpc)controller;
+ OngoingRequest ongoingRequest = new OngoingRequest(id, rpc,
done, responsePrototype, rpcs);
rpcs.put(id, ongoingRequest);
- Rpc.Request request = Rpc.Request.newBuilder()
+ Data.Request request = Data.Request.newBuilder()
.setRequestId(id)
.setFullServiceName(method.getService().getFullName())
.setMethodName(method.getName())
@@ -102,12 +104,13 @@ public class RpcChannel extends Thread implements
return request;
}
- private void finishRequest(Rpc.Response response) {
+ private void finishRequest(Data.Response response) {
OngoingRequest ongoingRequest = rpcs.remove(response.getRequestId());
if (ongoingRequest != null) {
try {
Message responsePb = ongoingRequest.responsePrototype.toBuilder()
.mergeFrom(response.getResponseProto()).build();
+ ongoingRequest.rpc.readFrom(response);
ongoingRequest.done.run(responsePb);
} catch (InvalidProtocolBufferException e) {
throw new AssertionError("Should fail here.");
@@ -122,7 +125,7 @@ public class RpcChannel extends Thread implements
Message responsePrototype,
RpcCallback<Message> done) {
try {
- Rpc.Request request = createRequest(method, controller,
+ Data.Request request = createRequest(method, controller,
requestMessage, responsePrototype, done);
Socket socket = getSocket();
request.writeDelimitedTo(socket.getOutputStream());
@@ -135,13 +138,17 @@ public class RpcChannel extends Thread implements
try {
logger.info("Handling responses to socket " + socket);
while (!socket.isClosed()) {
- Rpc.Response response;
- response = Rpc.Response.parseDelimitedFrom(
+ Data.Response response;
+ response = Data.Response.parseDelimitedFrom(
socket.getInputStream());
finishRequest(response);
}
} catch (IOException e) {
- // Breaks the loop.
+ if (!rpcs.isEmpty()) {
+ logger.log(Level.WARNING, "IO Error. Canceling " +
+ rpcs.size() + " requests.", e);
+ cancelAllRpcs();
+ }
} finally {
if (socket != null && !socket.isClosed()) {
try {
@@ -153,17 +160,21 @@ public class RpcChannel extends Thread implements
}
}
+ private void cancelAllRpcs() {
+ synchronized (rpcs) {
+ for (OngoingRequest request : rpcs.values()) {
+ request.rpc.setFailed("connection closed");
+ request.done.run(null);
+ }
+ rpcs.clear();
+ }
+ }
+
public void run() {
while (!Thread.interrupted()) {
try {
- synchronized (this) {
- if (socket == null) {
- wait();
- }
- }
- if (socket != null) {
- handleResponses(socket);
- }
+ Socket socket = sockets.take();
+ handleResponses(socket);
} catch (InterruptedException e) {
// Interrupts handled by outer loop
}