summaryrefslogtreecommitdiff
path: root/src/main/java/com/orbekk/protobuf/SimpleProtobufServer.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/com/orbekk/protobuf/SimpleProtobufServer.java')
-rw-r--r--src/main/java/com/orbekk/protobuf/SimpleProtobufServer.java91
1 files changed, 51 insertions, 40 deletions
diff --git a/src/main/java/com/orbekk/protobuf/SimpleProtobufServer.java b/src/main/java/com/orbekk/protobuf/SimpleProtobufServer.java
index 690da81..f597556 100644
--- a/src/main/java/com/orbekk/protobuf/SimpleProtobufServer.java
+++ b/src/main/java/com/orbekk/protobuf/SimpleProtobufServer.java
@@ -1,27 +1,32 @@
package com.orbekk.protobuf;
-import java.util.logging.Level;
-import java.util.logging.Logger;
-import java.io.OutputStream;
import java.io.IOException;
+import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
-import java.util.Scanner;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+import java.util.HashSet;
+
+import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;
-import com.google.protobuf.Service;
-import com.google.protobuf.RpcController;
import com.google.protobuf.RpcCallback;
-import com.google.protobuf.Descriptors;
-import java.util.Map;
-import java.util.HashMap;
+import com.google.protobuf.Service;
public class SimpleProtobufServer extends Thread {
private static Logger logger = Logger.getLogger(
SimpleProtobufServer.class.getName());
private ServerSocket serverSocket;
+ private Set<Socket> activeClientSockets =
+ Collections.synchronizedSet(new HashSet<Socket>());
private Map<String, Service> registeredServices =
- new HashMap<String, Service>();
+ Collections.synchronizedMap(
+ new HashMap<String, Service>());
public static SimpleProtobufServer create(int port) {
try {
@@ -44,7 +49,7 @@ public class SimpleProtobufServer extends Thread {
return serverSocket.getLocalPort();
}
- public synchronized void registerService(Service service) {
+ public void registerService(Service service) {
String serviceName = service.getDescriptorForType().getFullName();
if (registeredServices.containsKey(serviceName)) {
logger.warning("Already registered service with this name.");
@@ -53,25 +58,31 @@ public class SimpleProtobufServer extends Thread {
registeredServices.put(serviceName, service);
}
- public void handleRequest(Rpc.Request request, OutputStream out)
+ public void handleRequest(Data.Request request, OutputStream out)
throws IOException {
- Service service = registeredServices.get(request.getFullServiceName());
- final Rpc.Response.Builder response = Rpc.Response.newBuilder();
+ final Service service = registeredServices.get(request.getFullServiceName());
+ Rpc rpc = new Rpc();
+ final Data.Response.Builder response = Data.Response.newBuilder();
response.setRequestId(request.getRequestId());
if (service == null) {
- response.setError(Rpc.Response.Error.UNKNOWN_SERVICE);
+ response.setError(Data.Response.RpcError.UNKNOWN_SERVICE);
response.build().writeDelimitedTo(out);
return;
}
- Descriptors.MethodDescriptor method = service.getDescriptorForType()
+ final Descriptors.MethodDescriptor method = service.getDescriptorForType()
.findMethodByName(request.getMethodName());
if (method == null) {
- response.setError(Rpc.Response.Error.UNKNOWN_METHOD);
+ response.setError(Data.Response.RpcError.UNKNOWN_METHOD);
response.build().writeDelimitedTo(out);
return;
}
RpcCallback<Message> doneCallback = new RpcCallback<Message>() {
@Override public void run(Message responseMessage) {
+ if (responseMessage == null) {
+ responseMessage = service
+ .getResponsePrototype(method)
+ .toBuilder().build();
+ }
response.setResponseProto(responseMessage.toByteString());
}
};
@@ -79,16 +90,18 @@ public class SimpleProtobufServer extends Thread {
.toBuilder()
.mergeFrom(request.getRequestProto())
.build();
- service.callMethod(method, null, requestMessage, doneCallback);
+ service.callMethod(method, rpc, requestMessage, doneCallback);
+ rpc.writeTo(response);
response.build().writeDelimitedTo(out);
}
private void handleConnection(final Socket connection) {
new Thread(new Runnable() {
@Override public void run() {
+ activeClientSockets.add(connection);
try {
while (true) {
- Rpc.Request r1 = Rpc.Request.parseDelimitedFrom(
+ Data.Request r1 = Data.Request.parseDelimitedFrom(
connection.getInputStream());
if (r1 == null) {
try {
@@ -106,10 +119,28 @@ public class SimpleProtobufServer extends Thread {
connection.close();
} catch (IOException e) {
}
+ activeClientSockets.remove(connection);
}
}
}).start();
}
+
+ @Override public void interrupt() {
+ super.interrupt();
+ for (Socket socket : activeClientSockets) {
+ try {
+ socket.close();
+ } catch (IOException e) {
+ logger.log(Level.WARNING, "Error closing socket.", e);
+ }
+ }
+
+ try {
+ serverSocket.close();
+ } catch (IOException e) {
+ logger.log(Level.WARNING, "Error closing socket.", e);
+ }
+ }
public void run() {
logger.info("Running server on port " + serverSocket.getLocalPort());
@@ -122,26 +153,6 @@ public class SimpleProtobufServer extends Thread {
e);
}
}
+ logger.info("Server exits.");
}
-
-// public static void main(String[] args) {
-// SimpleProtobufServer server = SimpleProtobufServer.create(10000);
-// Test.TestService testService = new Test.TestService() {
-// @Override public void run(RpcController controller,
-// Test.TestRequest request,
-// RpcCallback<Test.TestResponse> done) {
-// System.out.println("Hello from TestService!");
-// done.run(Test.TestResponse.newBuilder()
-// .setId("Hello from server.")
-// .build());
-// }
-// };
-// server.registerService(testService);
-// server.start();
-// try {
-// server.join();
-// } catch (InterruptedException e) {
-// System.out.println("Stopped.");
-// }
-// }
}