From 0b1df4e0aeb98145541a243f2e2d0dbbebc05278 Mon Sep 17 00:00:00 2001 From: Bensong Liu <bensl@microsoft.com> Date: Wed, 29 Jul 2020 16:49:05 +0800 Subject: [PATCH] PlainInbound PlainOutbound done --- src/protocols/base.hpp | 4 +- src/protocols/plain.hpp | 112 +++++++++++++++++++++++++++++++++++++++- src/utils.hpp | 5 +- 3 files changed, 114 insertions(+), 7 deletions(-) diff --git a/src/protocols/base.hpp b/src/protocols/base.hpp index 09a9bed..2a88f1a 100644 --- a/src/protocols/base.hpp +++ b/src/protocols/base.hpp @@ -38,7 +38,7 @@ namespace Protocols { virtual void listenForever(BaseInbound *previousHop) = 0; // Inbound.listenForever MUST initialize this field. - sockfd_t ipcPipe = -1; + volatile sockfd_t ipcPipe = -1; }; struct BaseInbound : rlib::noncopyable { @@ -58,7 +58,7 @@ namespace Protocols { virtual void listenForever(BaseOutbound *nextHop) = 0; // Inbound.listenForever MUST initialize this field. - sockfd_t ipcPipe = -1; + volatile sockfd_t ipcPipe = -1; }; // TODO: PIPE only works on linux epoll. The windows epoll only works on SOCKET. diff --git a/src/protocols/plain.hpp b/src/protocols/plain.hpp index 2bcc7bc..71194bc 100644 --- a/src/protocols/plain.hpp +++ b/src/protocols/plain.hpp @@ -7,7 +7,31 @@ #include <utils.hpp> #include <common.hpp> +#if RLIB_OS_ID == OS_LINUX +#include <linux/sched.h> +#endif + namespace Protocols { + template <typename ClientIdT, typename ServerIdT> + struct InjectiveConnectionMapping { + std::unordered_map<ClientIdT, ServerIdT> client2server; + std::unordered_map<ServerIdT, ClientIdT> server2client; + void add(const ClientIdT& clientId, const ServerIdT& serverId) { + client2server[clientId] = serverId; + server2client[serverId] = clientId; + } + void del(const ClientIdT& clientId) { + const auto& serverId = client2server[clientId]; + server2client.erase(serverId); + client2server.erase(clientId); + } + std::enable_if_t<! std::is_same_v<ClientIdT, ServerIdT>, void> del(const ServerIdT& serverId) { + const auto& clientId = server2client[serverId]; + client2server.erase(clientId); + server2client.erase(serverId); + } + }; + class PlainInbound : public BaseInbound { public: using BaseInbound::BaseInbound; @@ -26,6 +50,7 @@ namespace Protocols { virtual void listenForever(BaseOutbound* nextHop) override { std::tie(this->ipcPipe, nextHop->ipcPipe) = mk_tcp_pipe(); + // ----------------------- Initialization / Setup ------------------------------ auto listenFd = rlib::quick_listen(listenAddr, listenPort, true); rlib_defer([&] {rlib::sockIO::close_ex(listenFd);}); @@ -45,7 +70,7 @@ namespace Protocols { auto targetClientId = rlib::sockIO::recv_msg(activeFd); auto msg = rlib::sockIO::recv_msg(activeFd); - auto clientAddr = ConnectionMapping::parseClientId(targetClientId); + auto clientAddr = ClientIdUtils::parseClientId(targetClientId); auto status = sendto(udpSenderSocket, msg.data(), msg.size(), 0, &clientAddr.addr, clientAddr.len); dynamic_assert(status != -1, "sendto failed"); } @@ -54,7 +79,7 @@ namespace Protocols { auto msgLength = recvfrom(activeFd, msgBuffer.data(), msgBuffer.size(), 0, &clientAddr.addr, &clientAddr.len); dynamic_assert(msgLength != -1, "recvfrom failed"); - forwardMessageToOutbound(msgBuffer.substr(msgLength), ConnectionMapping::makeClientId(clientAddr)); + forwardMessageToOutbound(msgBuffer.substr(msgLength), ClientIdUtils::makeClientId(clientAddr)); } }; @@ -81,9 +106,92 @@ namespace Protocols { public: using BaseOutbound::BaseOutbound; virtual void loadConfig(string config) override { + auto ar = rlib::string(config).split('@'); // Also works for ipv6. + if (ar.size() != 3) + throw std::invalid_argument("Wrong parameter string for protocol 'plain'. Example: plain@fe00:1e10:ce95:1@10809"); + serverAddr = ar[1]; + serverPort = ar[2].as<uint16_t>(); + } + // InboundThread calls this function. Check the mapping between senderId and serverConn, wake up listenThread, and deliver the msg. + virtual void forwardMessageToInbound(string binaryMessage, string senderId) override { + rlib::sockIO::send_msg(ipcPipe, senderId); + rlib::sockIO::send_msg(ipcPipe, binaryMessage); + } + + // Listen the PIPE. handleMessage will wake up this thread from epoll. + // Also listen the connection fileDescriptors. + virtual void listenForever(BaseInbound* previousHop) override { + + // ----------------------- Initialization / Setup ------------------------------ + auto epollFd = epoll_create1(0); + dynamic_assert((int)epollFd != -1, "epoll_create1 failed"); + while (ipcPipe == -1) { + ; // Sleep until InboundThread initializes the pipe. +#ifdef cond_resched + cond_resched(); +#endif + } + epoll_add_fd(epollFd, ipcPipe); + + // ----------------------- Process an event ------------------------------ + std::string msgBuffer(DGRAM_BUFFER_SIZE, '\0'); + auto onEvent = [&](auto activeFd) { + if (activeFd == ipcPipe) { + // Inbound gave me a message to forward! Send it. + auto targetClientId = rlib::sockIO::recv_msg(activeFd); + auto msg = rlib::sockIO::recv_msg(activeFd); + + auto iter = connectionMap.client2server.find(targetClientId); + if (iter != connectionMap.client2server.end()) { + // Map contains ClientId. + rlib::sockIO::quick_send(iter->second, msg); // udp + } + else { + // This clientId is new. I don't know how to listen many sockets for response, so I just issue `connect` just like TCP does. + auto connFd = rlib::quick_connect(serverAddr, serverPort, true); + epoll_add_fd(epollFd, connFd); + connectionMap.add(targetClientId, connFd); + rlib::sockIO::quick_send(connFd, msg); // udp + } + } + else { + // Message from some connFd. Read and forward it. + auto status = recv(activeFd, msgBuffer.data(), msgBuffer.size(), 0); + dynamic_assert(status != -1, "recv failed"); + if (status == 0) { + // TODO: close the socket, and notify Inbound to destory data structures. + epoll_del_fd(epollFd, activeFd); + connectionMap.del(activeFd); + rlib::sockIO::close_ex(activeFd); + } + + dynamic_assert(connectionMap.server2client.count(activeFd) > 0, "connectionMap MUST contain server connfd. "); + + forwardMessageToInbound(msgBuffer.substr(0, status), connectionMap.server2client.at(activeFd)); + } + }; + + // ----------------------- listener main loop ------------------------------ + epoll_event events[EPOLL_MAX_EVENTS]; + rlog.info("PlainOutbound to {}:{} is up, listening for request ...", serverAddr, serverPort); + while (true) { + auto nfds = epoll_wait(epollFd, events, EPOLL_MAX_EVENTS, -1); + dynamic_assert(nfds != -1, "epoll_wait failed"); + + for (auto cter = 0; cter < nfds; ++cter) { + onEvent(events[cter].data.fd); + } + } } + private: + string serverAddr; + uint16_t serverPort; + + + InjectiveConnectionMapping<string, sockfd_t> connectionMap; + }; } diff --git a/src/utils.hpp b/src/utils.hpp index f124673..812b5ad 100644 --- a/src/utils.hpp +++ b/src/utils.hpp @@ -26,9 +26,7 @@ struct SockAddr { socklen_t len; }; -struct ConnectionMapping { - std::unordered_map<string, fd_t> client2server; - std::unordered_multimap<fd_t, string> server2client; +struct ClientIdUtils { static string makeClientId(const SockAddr &osStruct) { // ClientId is a binary string. static_assert(sizeof(osStruct) == sizeof(SockAddr), "error: programming error detected."); @@ -93,5 +91,6 @@ inline auto mk_tcp_pipe() { } while(false) + #endif -- GitLab