From 0b1df4e0aeb98145541a243f2e2d0dbbebc05278 Mon Sep 17 00:00:00 2001
From: Bensong Liu <bensl@microsoft.com>
Date: Wed, 29 Jul 2020 16:49:05 +0800
Subject: [PATCH] PlainInbound PlainOutbound done

---
 src/protocols/base.hpp  |   4 +-
 src/protocols/plain.hpp | 112 +++++++++++++++++++++++++++++++++++++++-
 src/utils.hpp           |   5 +-
 3 files changed, 114 insertions(+), 7 deletions(-)

diff --git a/src/protocols/base.hpp b/src/protocols/base.hpp
index 09a9bed..2a88f1a 100644
--- a/src/protocols/base.hpp
+++ b/src/protocols/base.hpp
@@ -38,7 +38,7 @@ namespace Protocols {
 		virtual void listenForever(BaseInbound *previousHop) = 0;
 
 		// Inbound.listenForever MUST initialize this field. 
-		sockfd_t ipcPipe = -1;
+		volatile sockfd_t ipcPipe = -1;
 	};
 
 	struct BaseInbound : rlib::noncopyable {
@@ -58,7 +58,7 @@ namespace Protocols {
 		virtual void listenForever(BaseOutbound *nextHop) = 0;
 
 		// Inbound.listenForever MUST initialize this field. 
-		sockfd_t ipcPipe = -1;
+		volatile sockfd_t ipcPipe = -1;
 	};
 
 	// TODO: PIPE only works on linux epoll. The windows epoll only works on SOCKET. 
diff --git a/src/protocols/plain.hpp b/src/protocols/plain.hpp
index 2bcc7bc..71194bc 100644
--- a/src/protocols/plain.hpp
+++ b/src/protocols/plain.hpp
@@ -7,7 +7,31 @@
 #include <utils.hpp>
 #include <common.hpp>
 
+#if RLIB_OS_ID == OS_LINUX
+#include <linux/sched.h>
+#endif
+
 namespace Protocols {
+	template <typename ClientIdT, typename ServerIdT>
+	struct InjectiveConnectionMapping {
+	    std::unordered_map<ClientIdT, ServerIdT> client2server;
+	    std::unordered_map<ServerIdT, ClientIdT> server2client;
+		void add(const ClientIdT& clientId, const ServerIdT& serverId) {
+			client2server[clientId] = serverId;
+			server2client[serverId] = clientId;
+		}
+		void del(const ClientIdT& clientId) {
+			const auto& serverId = client2server[clientId];
+			server2client.erase(serverId);
+			client2server.erase(clientId);
+		}
+		std::enable_if_t<! std::is_same_v<ClientIdT, ServerIdT>, void> del(const ServerIdT& serverId) {
+			const auto& clientId = server2client[serverId];
+			client2server.erase(clientId);
+			server2client.erase(serverId);
+		}
+	};
+
 	class PlainInbound : public BaseInbound {
 	public:
 		using BaseInbound::BaseInbound;
@@ -26,6 +50,7 @@ namespace Protocols {
 		virtual void listenForever(BaseOutbound* nextHop) override {
 			std::tie(this->ipcPipe, nextHop->ipcPipe) = mk_tcp_pipe();
 
+			// ----------------------- Initialization / Setup ------------------------------
 			auto listenFd = rlib::quick_listen(listenAddr, listenPort, true);
 			rlib_defer([&] {rlib::sockIO::close_ex(listenFd);});
 
@@ -45,7 +70,7 @@ namespace Protocols {
 					auto targetClientId = rlib::sockIO::recv_msg(activeFd);
 					auto msg = rlib::sockIO::recv_msg(activeFd);
 
-					auto clientAddr = ConnectionMapping::parseClientId(targetClientId);
+					auto clientAddr = ClientIdUtils::parseClientId(targetClientId);
 					auto status = sendto(udpSenderSocket, msg.data(), msg.size(), 0, &clientAddr.addr, clientAddr.len);
 					dynamic_assert(status != -1, "sendto failed");
 				}
@@ -54,7 +79,7 @@ namespace Protocols {
 					auto msgLength = recvfrom(activeFd, msgBuffer.data(), msgBuffer.size(), 0, &clientAddr.addr, &clientAddr.len);
 					dynamic_assert(msgLength != -1, "recvfrom failed");
 
-					forwardMessageToOutbound(msgBuffer.substr(msgLength), ConnectionMapping::makeClientId(clientAddr));
+					forwardMessageToOutbound(msgBuffer.substr(msgLength), ClientIdUtils::makeClientId(clientAddr));
 				}
 			};
 
@@ -81,9 +106,92 @@ namespace Protocols {
 	public:
 		using BaseOutbound::BaseOutbound;
 		virtual void loadConfig(string config) override {
+			auto ar = rlib::string(config).split('@'); // Also works for ipv6.
+			if (ar.size() != 3)
+				throw std::invalid_argument("Wrong parameter string for protocol 'plain'. Example:    plain@fe00:1e10:ce95:1@10809");
+			serverAddr = ar[1];
+			serverPort = ar[2].as<uint16_t>();
+		}
+		// InboundThread calls this function. Check the mapping between senderId and serverConn, wake up listenThread, and deliver the msg. 
+		virtual void forwardMessageToInbound(string binaryMessage, string senderId) override {
+			rlib::sockIO::send_msg(ipcPipe, senderId);
+			rlib::sockIO::send_msg(ipcPipe, binaryMessage);
+		}
+
+		// Listen the PIPE. handleMessage will wake up this thread from epoll.
+		// Also listen the connection fileDescriptors.
+		virtual void listenForever(BaseInbound* previousHop) override {
+
+			// ----------------------- Initialization / Setup ------------------------------
+			auto epollFd = epoll_create1(0);
+			dynamic_assert((int)epollFd != -1, "epoll_create1 failed");
 
+			while (ipcPipe == -1) {
+				; // Sleep until InboundThread initializes the pipe.
+#ifdef cond_resched
+				cond_resched();
+#endif
+			}
+			epoll_add_fd(epollFd, ipcPipe);
+
+			// ----------------------- Process an event ------------------------------
+			std::string msgBuffer(DGRAM_BUFFER_SIZE, '\0');
+			auto onEvent = [&](auto activeFd) {
+				if (activeFd == ipcPipe) {
+					// Inbound gave me a message to forward! Send it. 
+					auto targetClientId = rlib::sockIO::recv_msg(activeFd);
+					auto msg = rlib::sockIO::recv_msg(activeFd);
+
+					auto iter = connectionMap.client2server.find(targetClientId);
+					if (iter != connectionMap.client2server.end()) {
+						// Map contains ClientId.
+						rlib::sockIO::quick_send(iter->second, msg); // udp
+					}
+					else {
+						// This clientId is new. I don't know how to listen many sockets for response, so I just issue `connect` just like TCP does. 
+						auto connFd = rlib::quick_connect(serverAddr, serverPort, true);
+						epoll_add_fd(epollFd, connFd);
+						connectionMap.add(targetClientId, connFd);
+						rlib::sockIO::quick_send(connFd, msg); // udp
+					}
+				}
+				else {
+					// Message from some connFd. Read and forward it. 
+					auto status = recv(activeFd, msgBuffer.data(), msgBuffer.size(), 0);
+					dynamic_assert(status != -1, "recv failed");
+					if (status == 0) {
+						// TODO: close the socket, and notify Inbound to destory data structures.
+						epoll_del_fd(epollFd, activeFd);
+						connectionMap.del(activeFd);
+						rlib::sockIO::close_ex(activeFd);
+					}
+
+					dynamic_assert(connectionMap.server2client.count(activeFd) > 0, "connectionMap MUST contain server connfd. ");
+
+					forwardMessageToInbound(msgBuffer.substr(0, status), connectionMap.server2client.at(activeFd));
+				}
+			};
+
+			// ----------------------- listener main loop ------------------------------
+			epoll_event events[EPOLL_MAX_EVENTS];
+			rlog.info("PlainOutbound to {}:{} is up, listening for request ...", serverAddr, serverPort);
+			while (true) {
+				auto nfds = epoll_wait(epollFd, events, EPOLL_MAX_EVENTS, -1);
+				dynamic_assert(nfds != -1, "epoll_wait failed");
+				
+				for (auto cter = 0; cter < nfds; ++cter) {
+					onEvent(events[cter].data.fd);
+				}
+			}
 		}
 
+	private:
+		string serverAddr;
+		uint16_t serverPort;
+
+
+		InjectiveConnectionMapping<string, sockfd_t> connectionMap;
+
 	};
 }
 
diff --git a/src/utils.hpp b/src/utils.hpp
index f124673..812b5ad 100644
--- a/src/utils.hpp
+++ b/src/utils.hpp
@@ -26,9 +26,7 @@ struct SockAddr {
     socklen_t len;
 };
 
-struct ConnectionMapping {
-    std::unordered_map<string, fd_t> client2server;
-    std::unordered_multimap<fd_t, string> server2client;
+struct ClientIdUtils {
     static string makeClientId(const SockAddr &osStruct) {
         // ClientId is a binary string.
         static_assert(sizeof(osStruct) == sizeof(SockAddr), "error: programming error detected.");
@@ -93,5 +91,6 @@ inline auto mk_tcp_pipe() {
 } while(false)
 
 
+
 #endif
 
-- 
GitLab