diff options
Diffstat (limited to 'src/main')
| -rw-r--r-- | src/main/java/com/orbekk/protobuf/NewRpcChannel.java | 196 | 
1 files changed, 196 insertions, 0 deletions
| diff --git a/src/main/java/com/orbekk/protobuf/NewRpcChannel.java b/src/main/java/com/orbekk/protobuf/NewRpcChannel.java new file mode 100644 index 0000000..1e009fa --- /dev/null +++ b/src/main/java/com/orbekk/protobuf/NewRpcChannel.java @@ -0,0 +1,196 @@ +package com.orbekk.protobuf; + +import java.io.IOException; +import java.net.Socket; +import java.net.UnknownHostException; +import java.util.Map; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Level; +import java.util.logging.Logger; + +import com.google.protobuf.Descriptors.MethodDescriptor; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; +import com.google.protobuf.RpcCallback; +import com.google.protobuf.RpcController; + +public class NewRpcChannel implements com.google.protobuf.RpcChannel { +    public static int NUM_CONCURRENT_REQUESTS = 50; +    private static final Logger logger = +            Logger.getLogger(RpcChannel.class.getName()); +    private final String host; +    private final int port; +    private final AtomicLong nextId = new AtomicLong(0); +    private final ExecutorService responseHandlerPool = +            Executors.newSingleThreadExecutor(); +    private final BlockingQueue<Data.Request> requestQueue = +            new ArrayBlockingQueue(NUM_CONCURRENT_REQUESTS); +    private volatile Socket socket = null; +    private final ConcurrentHashMap<Long, RequestMetadata> ongoingRequests = +            new ConcurrentHashMap<Long, RequestMetadata>(); +     +    class RequestMetadata { +        public final long id; +        public final Rpc rpc; +        public final RpcCallback<Message> done; +        public final Message responsePrototype; +         +        public RequestMetadata(long id, Rpc rpc, RpcCallback<Message> done, +                Message responsePrototype) { +            this.id = id; +            this.rpc = rpc; +            this.done = done; +            this.responsePrototype = responsePrototype; +        } +    } +     +    class ResponseHandler implements Runnable { +        private final Data.Response response; +         +        public ResponseHandler(Data.Response response) { +            this.response = response; +        } +         +        @Override public void run() { +            handleResponse(response); +        } +    } +     +    class OutgoingHandler implements Runnable { +        private final Socket socket; +        private final BlockingQueue<Data.Request> requests; + +        public OutgoingHandler(Socket socket, +                BlockingQueue<Data.Request> requests) { +            this.socket = socket; +            this.requests = requests; +        } +         +        @Override public void run() { +            for (;;) { +                try { +                    Data.Request request = requests.take(); +                    request.writeDelimitedTo(socket.getOutputStream()); +                } catch (InterruptedException e) { +                    tryCloseSocket(socket); +                    return; +                } catch (IOException e) { +                    tryCloseSocket(socket); +                    return; +                } +            } +        } + + +    } +     +    class IncomingHandler implements Runnable { +        private Socket socket; +        private ExecutorService responseHandlerPool; +         +        public IncomingHandler(Socket socket, +                ExecutorService responseHandlerPool) { +            this.socket = socket; +            this.responseHandlerPool = responseHandlerPool; +        } +         +        @Override public void run() { +            for (;;) { +                try { +                    Data.Response response = Data.Response +                            .parseDelimitedFrom(socket.getInputStream()); +                    responseHandlerPool.execute(new ResponseHandler(response)); +                } catch (IOException e) { +                    responseHandlerPool.shutdown(); +                    tryCloseSocket(socket); +                    return; +                } +            } +        } +    } +     +    public static NewRpcChannel create(String host, int port) +            throws UnknownHostException, IOException { +        NewRpcChannel channel = new NewRpcChannel(host, port); +        channel.start(); +        return channel; +    } +     +    NewRpcChannel(String host, int port) { +        this.host = host; +        this.port = port; +    } +     +    public void start() throws UnknownHostException, IOException { +        socket = new Socket(host, port); +        OutgoingHandler outgoing = new OutgoingHandler(socket, +                requestQueue); +        IncomingHandler incoming = new IncomingHandler(socket, +                responseHandlerPool); +         +        new Thread(outgoing, "RequestSender: " + host + ":" + port).start(); +        new Thread(incoming, "RequestReceiver: " + host + ":" + port).start(); +    } + +    private void tryCloseSocket(Socket socket) { +        try { +            socket.close(); +        } catch (IOException e1) { +            logger.log(Level.WARNING, +                    "Unable to close socket " + socket, +                    e1); +        } +    } + +    @Override +    public void callMethod(MethodDescriptor method, +            RpcController rpc, Message requestMessage, +            Message responsePrototype, +            RpcCallback<Message> done) { +        long id = nextId.incrementAndGet(); +        Rpc rpc_ = (Rpc) rpc; +        RequestMetadata request_ = new RequestMetadata(id, rpc_, done, +                responsePrototype); +        ongoingRequests.put(id, request_); +         +        Data.Request requestData = Data.Request.newBuilder() +                .setRequestId(id) +                .setFullServiceName(method.getService().getFullName()) +                .setMethodName(method.getName()) +                .setRequestProto(requestMessage.toByteString()) +                .build(); +         +        try { +            requestQueue.put(requestData); +        } catch (InterruptedException e) { +            cancelRequest(request_, "channel closed"); +        } +    } +     +    private void cancelRequest(RequestMetadata request, String reason) { +        throw new IllegalStateException("Not implemented"); +    } +     +    private void handleResponse(Data.Response response) { +        RequestMetadata request = +                ongoingRequests.remove(response.getRequestId()); +        if (request == null) { +            logger.info("Unknown request. Possible timeout?" + response); +            return; +        } +        try { +            Message responsePb = request.responsePrototype.toBuilder() +                    .mergeFrom(response.getResponseProto()).build(); +            request.rpc.readFrom(response); +            request.done.run(responsePb); +        } catch (InvalidProtocolBufferException e) { +            cancelRequest(request, "invalid response from server"); +        } +    } +} + | 
