diff options
Diffstat (limited to 'src/main/java/com/orbekk/protobuf/RpcChannel.java')
-rw-r--r-- | src/main/java/com/orbekk/protobuf/RpcChannel.java | 182 |
1 files changed, 182 insertions, 0 deletions
diff --git a/src/main/java/com/orbekk/protobuf/RpcChannel.java b/src/main/java/com/orbekk/protobuf/RpcChannel.java new file mode 100644 index 0000000..56b54c2 --- /dev/null +++ b/src/main/java/com/orbekk/protobuf/RpcChannel.java @@ -0,0 +1,182 @@ +package com.orbekk.protobuf; + +import java.io.Closeable; +import java.io.IOException; +import java.net.Socket; +import java.net.UnknownHostException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Level; +import java.util.logging.Logger; + +import com.google.protobuf.Descriptors; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; +import com.google.protobuf.RpcCallback; +import com.google.protobuf.RpcController; + +public class RpcChannel extends Thread implements + com.google.protobuf.RpcChannel { + static final Logger logger = + Logger.getLogger(RpcChannel.class.getName()); + private String host; + private int port; + private volatile Socket socket = null; + private AtomicLong nextId = new AtomicLong(0); + private Map<Long, RpcChannel.OngoingRequest> rpcs = + Collections.synchronizedMap( + new HashMap<Long, RpcChannel.OngoingRequest>()); + + private static class OngoingRequest implements Closeable { + long id; + RpcController controller; + RpcCallback<Message> done; + Message responsePrototype; + Map<Long, RpcChannel.OngoingRequest> rpcs; + + public OngoingRequest(long id, RpcController controller, + RpcCallback<Message> done, Message responsePrototype, + Map<Long, RpcChannel.OngoingRequest> rpcs) { + this.id = id; + this.controller = controller; + this.done = done; + this.responsePrototype = responsePrototype; + this.rpcs = rpcs; + } + + @Override + public void close() throws IOException { + throw new AssertionError("Not implemented"); + } + } + + public static RpcChannel create(String host, int port) { + RpcChannel channel = new RpcChannel(host, port); + channel.start(); + return channel; + } + + private RpcChannel(String host, int port) { + this.host = host; + this.port = port; + } + + private Socket getSocket() { + if (socket == null || socket.isClosed()) { + try { + logger.info("Creating new socket to " + host + ":" + port); + synchronized (this) { + socket = new Socket(host, port); + notify(); + } + } catch (UnknownHostException e) { + return null; + } catch (IOException e) { + logger.log(Level.WARNING, + "Could not establish connection.", e); + return null; + } + } + return socket; + } + + private Rpc.Request createRequest(Descriptors.MethodDescriptor method, + RpcController controller, + Message requestMessage, + Message responsePrototype, + RpcCallback<Message> done) { + long id = nextId.incrementAndGet(); + OngoingRequest ongoingRequest = new OngoingRequest(id, controller, + done, responsePrototype, rpcs); + rpcs.put(id, ongoingRequest); + + Rpc.Request request = Rpc.Request.newBuilder() + .setRequestId(id) + .setFullServiceName(method.getService().getFullName()) + .setMethodName(method.getName()) + .setRequestProto(requestMessage.toByteString()) + .build(); + + return request; + } + + private void finishRequest(Rpc.Response response) { + OngoingRequest ongoingRequest = rpcs.remove(response.getRequestId()); + if (ongoingRequest != null) { + try { + Message responsePb = ongoingRequest.responsePrototype.toBuilder() + .mergeFrom(response.getResponseProto()).build(); + ongoingRequest.done.run(responsePb); + } catch (InvalidProtocolBufferException e) { + throw new AssertionError("Should fail here."); + } + } + } + + @Override public void callMethod( + Descriptors.MethodDescriptor method, + RpcController controller, + Message requestMessage, + Message responsePrototype, + RpcCallback<Message> done) { + try { + Rpc.Request request = createRequest(method, controller, + requestMessage, responsePrototype, done); + Socket socket = getSocket(); + request.writeDelimitedTo(socket.getOutputStream()); + } catch (IOException e) { + throw new AssertionError("Should return error."); + } + } + + private void handleResponses(Socket socket) { + try { + logger.info("Handling responses to socket " + socket); + while (!socket.isClosed()) { + Rpc.Response response; + response = Rpc.Response.parseDelimitedFrom( + socket.getInputStream()); + finishRequest(response); + } + } catch (IOException e) { + // Breaks the loop. + } finally { + if (socket != null && !socket.isClosed()) { + try { + socket.close(); + } catch (IOException e) { + // Socket is closed. + } + } + } + } + + public void run() { + while (!Thread.interrupted()) { + try { + synchronized (this) { + if (socket == null) { + wait(); + } + } + if (socket != null) { + handleResponses(socket); + } + } catch (InterruptedException e) { + // Interrupts handled by outer loop + } + } + } + + public void close() { + if (socket != null) { + try { + socket.close(); + } catch (IOException e) { + logger.info("Error closing socket."); + } + } + } +}
\ No newline at end of file |