From a46bb05825e7fadcf54da5b9cf80f13bc4df9119 Mon Sep 17 00:00:00 2001 From: Bensong Liu <bensl@microsoft.com> Date: Tue, 28 Jul 2020 17:26:06 +0800 Subject: [PATCH] some minimal adjustment. no comment --- CMakeLists.txt | 14 ++++++++++++-- src/forwarder.hpp | 16 +++++++++++++++- src/lib/rlib/string.hpp | 20 ++++++++++++++++++++ src/main.cc | 4 ++-- src/protocols/plain.hpp | 4 ++++ 5 files changed, 53 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 22dfff1..c4bc7bb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,8 +32,7 @@ endif() include_directories(./src) include_directories(./src/lib) -# TODO -set(SRC src/main.cc) +AUX_SOURCE_DIRECTORY(src SRC) add_executable(udp-forwarder ${SRC}) target_link_libraries(udp-forwarder Threads::Threads) @@ -43,3 +42,14 @@ target_link_libraries(udp-forwarder Threads::Threads) # set_target_properties(... PROPERTIES COMPILE_FLAGS -m32 LINK_FLAGS -m32 ) #endif(FOR_M32) +#macro(print_all_variables) +# message(STATUS "print_all_variables------------------------------------------{") +# get_cmake_property(_variableNames VARIABLES) +# foreach (_variableName ${_variableNames}) +# message(STATUS "${_variableName}=${${_variableName}}") +# endforeach() +# message(STATUS "print_all_variables------------------------------------------}") +#endmacro() +# +#print_all_variables() + diff --git a/src/forwarder.hpp b/src/forwarder.hpp index 91ba1e3..3698999 100644 --- a/src/forwarder.hpp +++ b/src/forwarder.hpp @@ -5,6 +5,9 @@ #include <rlib/sys/sio.hpp> #include "utils.hpp" +#include "protocols/base.hpp" +#include "protocols/plain.hpp" + using std::string; struct ConnectionMapping { @@ -18,12 +21,23 @@ struct ConnectionMapping { class Forwarder { public: - Forwarder(const std::string& inboundConfig, const std::string& outboundConfig) { + Forwarder(const rlib::string &inboundConfig, const rlib::string &outboundConfig) { + if (inboundConfig.starts_with("plain")) + ptrInbound = new Protocols::PlainInbound(inboundConfig); + else if (inboundConfig.starts_with("misc")) + ptrInbound = nullptr; // TODO + + if (outboundConfig.starts_with("plain")) + ptrOutbound = nullptr; // TODO + else if (outboundConfig.starts_with("misc")) + ptrOutbound = nullptr; // TODO } private: + Protocols::BaseInbound *ptrInbound; + Protocols::BaseOutbound *ptrOutbound; }; #endif diff --git a/src/lib/rlib/string.hpp b/src/lib/rlib/string.hpp index 9fe7968..c8789c1 100644 --- a/src/lib/rlib/string.hpp +++ b/src/lib/rlib/string.hpp @@ -374,6 +374,26 @@ namespace rlib { return *this; } + bool starts_with(const std::string &what) const { + if(size() < what.size()) return false; + + std::string::value_type diffBits = 0; + for(auto i = 0; i < what.size(); ++i) { + diffBits = diffBits | (what[i] ^ (*this)[i]); + } + return diffBits == 0; + } + bool ends_with(const std::string &what) const { + if(size() < what.size()) return false; + + std::string::value_type diffBits = 0; + auto offset = size() - what.size(); + for(auto i = 0; i < what.size(); ++i) { + diffBits = diffBits | (what[i] ^ (*this)[offset+i]); + } + return diffBits == 0; + } + template <typename... Args> string &format(Args... args) { return operator=(std::move(impl::format_string(*this, args ...))); diff --git a/src/main.cc b/src/main.cc index dd18404..65169f2 100644 --- a/src/main.cc +++ b/src/main.cc @@ -1,14 +1,14 @@ #include <rlib/stdio.hpp> #include <rlib/opt.hpp> #include <rlib/sys/os.hpp> +#include <thread> #include "common.hpp" rlib::logger rlog(std::cerr); using namespace rlib::literals; +using namespace std::chrono_literals; #if RLIB_OS_ID == OS_WINDOWS - #include <thread> - using namespace std::chrono_literals; #define windows_main main #else #define real_main main diff --git a/src/protocols/plain.hpp b/src/protocols/plain.hpp index b9caa8b..fd0a327 100644 --- a/src/protocols/plain.hpp +++ b/src/protocols/plain.hpp @@ -10,6 +10,7 @@ namespace Protocols { class PlainInbound : public BaseInbound { public: + using BaseInbound::BaseInbound; virtual loadConfig(string config) override { auto ar = rlib::string(config).split('@'); // Also works for ipv6. if (ar.size() != 3) @@ -28,6 +29,7 @@ namespace Protocols { if(epollFd == -1) throw std::runtime_error("Failed to create epoll fd."); epoll_add_fd(epollFd, listenFd); + epoll_add_fd(epollFd, ipcPipeInboundEnd); epoll_event events[MAX_EVENTS]; char buffer[DGRAM_BUFFER_SIZE]; @@ -48,6 +50,8 @@ namespace Protocols { class PlainOutbound : public BaseOutbound { + public: + using BaseOutbound::BaseOutbound; }; } -- GitLab