From 8a76a78427c43b943721d53309e931f31d135122 Mon Sep 17 00:00:00 2001 From: Bensong Liu <bensl@microsoft.com> Date: Thu, 30 Jul 2020 14:54:13 +0800 Subject: [PATCH] finished two filters --- src/filters/aes_encryption.hpp | 73 +++ src/filters/base.hpp | 59 +++ src/filters/quic_obfs.hpp | 0 src/filters/wechat_video_obfs.hpp | 0 src/filters/xor_encryption.hpp | 42 ++ src/forwarder.hpp | 29 +- src/lib/plusaes.hpp | 766 ++++++++++++++++++++++++++++++ src/main.cc | 19 +- src/protocols/base.hpp | 9 +- src/protocols/plain.hpp | 14 +- src/test/filter_test.cc | 20 + src/utils.hpp | 20 +- 12 files changed, 1031 insertions(+), 20 deletions(-) create mode 100644 src/filters/aes_encryption.hpp create mode 100644 src/filters/base.hpp create mode 100644 src/filters/quic_obfs.hpp create mode 100644 src/filters/wechat_video_obfs.hpp create mode 100644 src/filters/xor_encryption.hpp create mode 100644 src/lib/plusaes.hpp create mode 100644 src/test/filter_test.cc diff --git a/src/filters/aes_encryption.hpp b/src/filters/aes_encryption.hpp new file mode 100644 index 0000000..b79516b --- /dev/null +++ b/src/filters/aes_encryption.hpp @@ -0,0 +1,73 @@ +#ifndef UDP_FWD_FILTER_AES_ +#define UDP_FWD_FILTER_AES_ 1 + +#include "filters/base.hpp" +#include "lib/plusaes.hpp" +#include <rlib/string.hpp> +#include <unordered_map> +#include "utils.hpp" + +namespace Filters { + enum class AESFilterMode { + CBC, + ECB, + CTR + }; + + template <size_t BlockSize = 128, AESFilterMode mode = AESFilterMode::CBC> + class AESFilter : public BaseFilter { + public: + virtual void loadConfig(string config) override { + auto ar = rlib::string(config).split('@'); // Also works for ipv6. + if (ar.size() != 2) + throw std::invalid_argument("Wrong parameter string for filter 'aes'. Example: aes@MyPassword"); + auto tmpKey = pskToKey<16>(ar[1]); + char aesKey[17] = {0}; + static_assert(BlockSize == 128, "TODO: Change key size on compilation time to support more AES algo. "); + std::copy(tmpKey.begin(), tmpKey.end(), std::begin(aesKey)); + key = plusaes::key_from_string(&aesKey); + } + + static_assert(mode == AESFilterMode::CBC || mode == AESFilterMode::ECB, "Only supporting CBC and ECB"); + + // Encrypt + virtual string convertForward(string datagram) override { + const unsigned long encrypted_size = plusaes::get_padded_encrypted_size(datagram.size()); + std::string result(encrypted_size, '\0'); + if constexpr (mode == AESFilterMode::CBC) { + plusaes::encrypt_cbc((unsigned char *)datagram.data(), datagram.size(), key.data(), key.size(), &iv, (unsigned char *)result.data(), result.size(), true); + } + if constexpr (mode == AESFilterMode::ECB) { + plusaes::encrypt_ecb((unsigned char *)datagram.data(), datagram.size(), key.data(), key.size(), (unsigned char *)result.data(), result.size(), true); + } + return result; + } + + // Decrypt + virtual string convertBackward(string datagram) override { + std::string result(datagram.size(), '\0'); + unsigned long padded_size = 0; + if constexpr (mode == AESFilterMode::CBC) { + plusaes::decrypt_cbc((unsigned char *)datagram.data(), datagram.size(), key.data(), key.size(), &iv, (unsigned char *)result.data(), result.size(), &padded_size); + } + if constexpr (mode == AESFilterMode::ECB) { + plusaes::decrypt_ecb((unsigned char *)datagram.data(), datagram.size(), key.data(), key.size(), (unsigned char *)result.data(), result.size(), &padded_size); + } + + return result.substr(0, datagram.size() - padded_size); + } + + private: + std::vector<unsigned char> key; + + static_assert(BlockSize == 128, "TODO: Change IV size on compilation time to support more AES algo. "); + unsigned char iv[16] = { + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, + }; + + }; +} + +#endif + diff --git a/src/filters/base.hpp b/src/filters/base.hpp new file mode 100644 index 0000000..00bb003 --- /dev/null +++ b/src/filters/base.hpp @@ -0,0 +1,59 @@ +#ifndef UDP_FORWARDER_FILT_BASE_HPP_ +#define UDP_FORWARDER_FILT_BASE_HPP_ 1 + +#include <rlib/class_decorator.hpp> +#include <string> +#include <list> +#include <cstring> +using std::string; + +/* + UDP Forwarder + |--------------------------------------| +---> Inbound ====|FILTER|===> Outbound ---> NextHop + | <===|FILTER|==== | + |--------------------------------------| +*/ + +namespace Filters { + struct BaseFilter : rlib::noncopyable { + virtual ~BaseFilter() = default; + + // Init data structures. + virtual void loadConfig(string config) {} + + // Usually the encrypt/encode/obfs function. + virtual string convertForward(string binaryDatagram) = 0; + + // Usually the decrypt/decode/de-obfs function. + virtual string convertBackward(string binaryDatagram) = 0; + }; + + struct ChainedFilters : public BaseFilter { + ChainedFilters(const std::list<Filters::BaseFilter*>& chainedFilters) + : chainedFilters(chainedFilters) {} + + // Usually the encrypt/encode/obfs function. + virtual string convertForward(string binaryDatagram) override { + for (auto* filterPtr : chainedFilters) { + binaryDatagram = filterPtr->convertForward(binaryDatagram); + } + return binaryDatagram; + } + + // Usually the decrypt/decode/de-obfs function. + virtual string convertBackward(string binaryDatagram) { + for (auto iter = chainedFilters.rbegin(); iter != chainedFilters.rend(); ++iter) { + binaryDatagram = (*iter)->convertForward(binaryDatagram); + } + return binaryDatagram; + } + + const std::list<Filters::BaseFilter*>& chainedFilters; + }; +} + +#endif + + + diff --git a/src/filters/quic_obfs.hpp b/src/filters/quic_obfs.hpp new file mode 100644 index 0000000..e69de29 diff --git a/src/filters/wechat_video_obfs.hpp b/src/filters/wechat_video_obfs.hpp new file mode 100644 index 0000000..e69de29 diff --git a/src/filters/xor_encryption.hpp b/src/filters/xor_encryption.hpp new file mode 100644 index 0000000..639acdd --- /dev/null +++ b/src/filters/xor_encryption.hpp @@ -0,0 +1,42 @@ +#ifndef UDP_FWD_FILTER_XOR_ +#define UDP_FWD_FILTER_XOR_ 1 + +#include "filters/base.hpp" +#include <unordered_map> +#include <rlib/string.hpp> +#include "utils.hpp" + +namespace Filters { + template <size_t BlockSize = 32> + class XorFilter : public BaseFilter { + public: + virtual void loadConfig(string config) override { + auto ar = rlib::string(config).split('@'); // Also works for ipv6. + if (ar.size() != 2) + throw std::invalid_argument("Wrong parameter string for protocol 'plain'. Example: XOR@MyPassword"); + key = pskToKey<BlockSize>(ar[1]); + } + + // Encrypt + virtual string convertForward(string datagram) override { + auto curr_key_digit = 0; + for (auto offset = 0; offset < datagram.size(); ++offset) { + datagram[0] ^= key[curr_key_digit++]; + } + return datagram; + } + + // Decrypt + virtual string convertBackward(string datagram) override { + return convertForward(datagram); + } + + private: + string key; + + + }; +} + +#endif + diff --git a/src/forwarder.hpp b/src/forwarder.hpp index 5fa8094..c408b37 100644 --- a/src/forwarder.hpp +++ b/src/forwarder.hpp @@ -7,13 +7,17 @@ #include "protocols/base.hpp" #include "protocols/plain.hpp" +#include "filters/base.hpp" +#include "filters/aes_encryption.hpp" +#include "filters/xor_encryption.hpp" + using std::string; class Forwarder { public: - Forwarder(const rlib::string &inboundConfig, const rlib::string &outboundConfig) { + Forwarder(const rlib::string &inboundConfig, const rlib::string &outboundConfig, const std::list<rlib::string> &filterConfigs) { if (inboundConfig.starts_with("plain")) ptrInbound = new Protocols::PlainInbound; else if (inboundConfig.starts_with("misc")) @@ -29,22 +33,41 @@ public: else throw std::invalid_argument("Unknown protocol in outboundConfig " + outboundConfig); ptrOutbound->loadConfig(outboundConfig); + + + std::list<Filters::BaseFilter*> chainedFilters; + for (auto &&filterConfig : filterConfigs) { + Filters::BaseFilter *newFilter = nullptr; + if (filterConfig.starts_with("aes")) + newFilter = new Filters::AESFilter(); + else if (filterConfig.starts_with("xor")) + newFilter = new Filters::XorFilter(); // these filters were not deleted. just a note. + else + throw std::invalid_argument("Unknown filter in filterConfig item: " + filterConfig); + + newFilter->loadConfig(filterConfig); + chainedFilters.push_back(newFilter); + } + + ptrFilter = new Filters::ChainedFilters(chainedFilters); } ~Forwarder() { if (ptrInbound) delete ptrInbound; if (ptrOutbound) delete ptrOutbound; + if (ptrFilter) delete ptrFilter; } [[noreturn]] void runForever() { - std::thread([this] {ptrInbound->listenForever(ptrOutbound);}).detach(); - ptrOutbound->listenForever(ptrInbound); // Blocks + std::thread([this] {ptrInbound->listenForever(ptrOutbound, ptrFilter);}).detach(); + ptrOutbound->listenForever(ptrInbound, ptrFilter); // Blocks } private: Protocols::BaseInbound *ptrInbound; Protocols::BaseOutbound *ptrOutbound; + Filters::BaseFilter *ptrFilter; }; #endif diff --git a/src/lib/plusaes.hpp b/src/lib/plusaes.hpp new file mode 100644 index 0000000..5b2be1c --- /dev/null +++ b/src/lib/plusaes.hpp @@ -0,0 +1,766 @@ +// Copyright (C) 2015 kkAyataka +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#ifndef PLUSAES_HPP__ +#define PLUSAES_HPP__ + +#include <stdexcept> +#include <vector> + +/** Version number of plusaes. + * 0x01020304 -> 1.2.3.4 */ +#define PLUSAES_VERSION 0x00090100 + +/** AES cipher APIs */ +namespace plusaes { +namespace detail { + +const int kWordSize = 4; +typedef unsigned int Word; + +const int kBlockSize = 4; +/** @private */ +typedef struct { + Word w[4]; + Word & operator[](const int index) { + return w[index]; + } + const Word & operator[](const int index) const { + return w[index]; + } +} State; + +const int kStateSize = 16; // Word * BlockSize +typedef State RoundKey; +typedef std::vector<RoundKey> RoundKeys; + +inline void add_round_key(const RoundKey &key, State &state) { + for (int i = 0; i < kBlockSize; ++i) { + state[i] ^= key[i]; + } +} + +const unsigned char kSbox[] = { + 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, + 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, + 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, + 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, + 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, + 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, + 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, + 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, + 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, + 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, + 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, + 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, + 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, + 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, + 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, + 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16 +}; + +const unsigned char kInvSbox[] = { + 0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb, + 0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb, + 0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e, + 0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25, + 0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92, + 0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84, + 0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06, + 0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b, + 0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73, + 0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e, + 0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b, + 0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4, + 0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f, + 0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef, + 0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61, + 0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d +}; + +inline Word sub_word(const Word w) { + return kSbox[(w >> 0) & 0xFF] << 0 | + kSbox[(w >> 8) & 0xFF] << 8 | + kSbox[(w >> 16) & 0xFF] << 16 | + kSbox[(w >> 24) & 0xFF] << 24; +} + +inline Word inv_sub_word(const Word w) { + return kInvSbox[(w >> 0) & 0xFF] << 0 | + kInvSbox[(w >> 8) & 0xFF] << 8 | + kInvSbox[(w >> 16) & 0xFF] << 16 | + kInvSbox[(w >> 24) & 0xFF] << 24; +} + +inline void sub_bytes(State &state) { + for (int i = 0; i < kBlockSize; ++i) { + state[i] = sub_word(state[i]); + } +} + +inline void inv_sub_bytes(State &state) { + for (int i = 0; i < kBlockSize; ++i) { + state[i] = inv_sub_word(state[i]); + } +} + +inline void shift_rows(State &state) { + const State ori = { state[0], state[1], state[2], state[3] }; + for (int r = 1; r < kWordSize; ++r) { + const Word m2 = 0xFF << (r * 8); + const Word m1 = ~m2; + for (int c = 0; c < kBlockSize; ++c) { + state[c] = (state[c] & m1) | (ori[(c + r) % kBlockSize] & m2); + } + } +} + +inline void inv_shift_rows(State &state) { + const State ori = { state[0], state[1], state[2], state[3] }; + for (int r = 1; r < kWordSize; ++r) { + const Word m2 = 0xFF << (r * 8); + const Word m1 = ~m2; + for (int c = 0; c < kBlockSize; ++c) { + state[c] = (state[c] & m1) | (ori[(c + kBlockSize - r) % kWordSize] & m2); + } + } +} + +inline unsigned char mul2(const unsigned char b) { + unsigned char m2 = b << 1; + if (b & 0x80) { + m2 ^= 0x011B; + } + + return m2; +} + +inline unsigned char mul(const unsigned char b, const unsigned char m) { + unsigned char v = 0; + unsigned char t = b; + for (int i = 0; i < 8; ++i) { // 8-bits + if ((m >> i) & 0x01) { + v ^= t; + } + + t = mul2(t); + } + + return v; +} + +inline void mix_columns(State &state) { + for (int i = 0; i < kBlockSize; ++i) { + const unsigned char v0_1 = (state[i] >> 0) & 0xFF; + const unsigned char v1_1 = (state[i] >> 8) & 0xFF; + const unsigned char v2_1 = (state[i] >> 16) & 0xFF; + const unsigned char v3_1 = (state[i] >> 24) & 0xFF; + + const unsigned char v0_2 = mul2(v0_1); + const unsigned char v1_2 = mul2(v1_1); + const unsigned char v2_2 = mul2(v2_1); + const unsigned char v3_2 = mul2(v3_1); + + const unsigned char v0_3 = v0_2 ^ v0_1; + const unsigned char v1_3 = v1_2 ^ v1_1; + const unsigned char v2_3 = v2_2 ^ v2_1; + const unsigned char v3_3 = v3_2 ^ v3_1; + + state[i] = + (v0_2 ^ v1_3 ^ v2_1 ^ v3_1) << 0 | + (v0_1 ^ v1_2 ^ v2_3 ^ v3_1) << 8 | + (v0_1 ^ v1_1 ^ v2_2 ^ v3_3) << 16 | + (v0_3 ^ v1_1 ^ v2_1 ^ v3_2) << 24; + } +} + +inline void inv_mix_columns(State &state) { + for (int i = 0; i < kBlockSize; ++i) { + const unsigned char v0 = (state[i] >> 0) & 0xFF; + const unsigned char v1 = (state[i] >> 8) & 0xFF; + const unsigned char v2 = (state[i] >> 16) & 0xFF; + const unsigned char v3 = (state[i] >> 24) & 0xFF; + + state[i] = + (mul(v0, 0x0E) ^ mul(v1, 0x0B) ^ mul(v2, 0x0D) ^ mul(v3, 0x09)) << 0 | + (mul(v0, 0x09) ^ mul(v1, 0x0E) ^ mul(v2, 0x0B) ^ mul(v3, 0x0D)) << 8 | + (mul(v0, 0x0D) ^ mul(v1, 0x09) ^ mul(v2, 0x0E) ^ mul(v3, 0x0B)) << 16 | + (mul(v0, 0x0B) ^ mul(v1, 0x0D) ^ mul(v2, 0x09) ^ mul(v3, 0x0E)) << 24; + } +} + +inline Word rot_word(const Word v) { + return ((v >> 8) & 0x00FFFFFF) | ((v & 0xFF) << 24); +} + +/** + * @private + * @throws std::invalid_argument + */ +inline unsigned int get_round_count(const int key_size) { + switch (key_size) { + case 16: + return 10; + case 24: + return 12; + case 32: + return 14; + default: + throw std::invalid_argument("Invalid key size"); + } +} + +/** + * @private + * @throws std::invalid_argument + */ +inline RoundKeys expand_key(const unsigned char *key, const int key_size) { + if (key_size != 16 && key_size != 24 && key_size != 32) { + throw std::invalid_argument("Invalid key size"); + } + + const Word rcon[] = { + 0x00, 0x01, 0x02, 0x04, 0x08, 0x10, + 0x20, 0x40, 0x80, 0x1b, 0x36 + }; + + const int nb = kBlockSize; + const int nk = key_size / nb; + const int nr = get_round_count(key_size); + + std::vector<Word> w(nb * (nr + 1)); + for (int i = 0; i < nk; ++ i) { + memcpy(&w[i], key + (i * kWordSize), kWordSize); + } + + for (int i = nk; i < nb * (nr + 1); ++i) { + Word t = w[i - 1]; + if (i % nk == 0) { + t = sub_word(rot_word(t)) ^ rcon[i / nk]; + } + else if (nk > 6 && i % nk == 4) { + t = sub_word(t); + } + + w[i] = t ^ w[i - nk]; + } + + RoundKeys keys(nr + 1); + memcpy(&keys[0], &w[0], w.size() * kWordSize); + + return keys; +} + +inline void copy_bytes_to_state(const unsigned char data[16], State &state) { + memcpy(&state[0], data + 0, kWordSize); + memcpy(&state[1], data + 4, kWordSize); + memcpy(&state[2], data + 8, kWordSize); + memcpy(&state[3], data + 12, kWordSize); +} + +inline void copy_state_to_bytes(const State &state, unsigned char buf[16]) { + memcpy(buf + 0, &state[0], kWordSize); + memcpy(buf + 4, &state[1], kWordSize); + memcpy(buf + 8, &state[2], kWordSize); + memcpy(buf + 12, &state[3], kWordSize); +} + +inline void xor_data(unsigned char data[kStateSize], const unsigned char v[kStateSize]) { + for (int i = 0; i < kStateSize; ++i) { + data[i] ^= v[i]; + } +} + +/** increment counter (128-bit int) by 1 */ +inline void incr_counter(unsigned char counter[kStateSize]) { + unsigned n = kStateSize, c = 1; + do { + --n; + c += counter[n]; + counter[n] = c; + c >>= 8; + } while (n); +} + +inline void encrypt_state(const RoundKeys &rkeys, const unsigned char data[16], unsigned char encrypted[16]) { + State s; + copy_bytes_to_state(data, s); + + add_round_key(rkeys[0], s); + + for (unsigned int i = 1; i < rkeys.size() - 1; ++i) { + sub_bytes(s); + shift_rows(s); + mix_columns(s); + add_round_key(rkeys[i], s); + } + + sub_bytes(s); + shift_rows(s); + add_round_key(rkeys.back(), s); + + copy_state_to_bytes(s, encrypted); +} + +inline void decrypt_state(const RoundKeys &rkeys, const unsigned char data[16], unsigned char decrypted[16]) { + State s; + copy_bytes_to_state(data, s); + + add_round_key(rkeys.back(), s); + inv_shift_rows(s); + inv_sub_bytes(s); + + for (std::size_t i = rkeys.size() - 2; i > 0; --i) { + add_round_key(rkeys[i], s); + inv_mix_columns(s); + inv_shift_rows(s); + inv_sub_bytes(s); + } + + add_round_key(rkeys[0], s); + + copy_state_to_bytes(s, decrypted); +} + +template<int KeyLen> +std::vector<unsigned char> key_from_string(const char (*key_str)[KeyLen]) { + std::vector<unsigned char> key(KeyLen - 1); + memcpy(&key[0], *key_str, KeyLen - 1); + return key; +} + +inline bool is_valid_key_size(const unsigned long key_size) { + if (key_size != 16 && key_size != 24 && key_size != 32) { + return false; + } + else { + return true; + } +} + +} // namespace detail + +/** Version number of plusaes. */ +inline unsigned int version() { + return PLUSAES_VERSION; +} + +/** Create 128-bit key from string. */ +inline std::vector<unsigned char> key_from_string(const char (*key_str)[17]) { + return detail::key_from_string<17>(key_str); +} + +/** Create 192-bit key from string. */ +inline std::vector<unsigned char> key_from_string(const char (*key_str)[25]) { + return detail::key_from_string<25>(key_str); +} + +/** Create 256-bit key from string. */ +inline std::vector<unsigned char> key_from_string(const char (*key_str)[33]) { + return detail::key_from_string<33>(key_str); +} + +/** Calculates encrypted data size when padding is enabled. */ +inline unsigned long get_padded_encrypted_size(const unsigned long data_size) { + return data_size + detail::kStateSize - (data_size % detail::kStateSize); +} + +/** Error code */ +typedef enum { + kErrorOk = 0, + kErrorInvalidDataSize = 1, + kErrorInvalidKeySize, + kErrorInvalidBufferSize, + kErrorInvalidKey, + kErrorInvalidNonceSize +} Error; + +namespace detail { + +inline Error check_encrypt_cond( + const unsigned long data_size, + const unsigned long key_size, + const unsigned long encrypted_size, + const bool pads) { + // check data size + if (!pads && (data_size % kStateSize != 0)) { + return kErrorInvalidDataSize; + } + + // check key size + if (!detail::is_valid_key_size(key_size)) { + return kErrorInvalidKeySize; + } + + // check encrypted buffer size + if (pads) { + const unsigned long required_size = get_padded_encrypted_size(data_size); + if (encrypted_size < required_size) { + return kErrorInvalidBufferSize; + } + } + else { + if (encrypted_size < data_size) { + return kErrorInvalidBufferSize; + } + } + return kErrorOk; +} + +inline Error check_decrypt_cond( + const unsigned long data_size, + const unsigned long key_size, + const unsigned long decrypted_size, + const unsigned long * padded_size + ) { + // check data size + if (data_size % 16 != 0) { + return kErrorInvalidDataSize; + } + + // check key size + if (!detail::is_valid_key_size(key_size)) { + return kErrorInvalidKeySize; + } + + // check decrypted buffer size + if (!padded_size) { + if (decrypted_size < data_size) { + return kErrorInvalidBufferSize; + } + } + else { + if (decrypted_size < (data_size - kStateSize)) { + return kErrorInvalidBufferSize; + } + } + + return kErrorOk; +} + +inline bool check_padding(const unsigned long padding, const unsigned char data[kStateSize]) { + if (padding > kStateSize) { + return false; + } + + for (unsigned long i = 0; i < padding; ++i) { + if (data[kStateSize - 1 - i] != padding) { + return false; + } + } + + return true; +} + +} // namespace detail + +/** + * Encrypts data with ECB mode. + * @param [in] data Data. + * @param [in] data_size Data size. + * If the pads is false, data size must be multiple of 16. + * @param [in] key key bytes. The key length must be 16 (128-bit), 24 (192-bit) or 32 (256-bit). + * @param [in] key_size key size. + * @param [out] encrypted Encrypted data buffer. + * @param [in] encrypted_size Encrypted data buffer size. + * @param [in] pads If this value is true, encrypted data is padded by PKCS. + * Encrypted data size must be multiple of 16. + * If the pads is true, encrypted data is padded with PKCS. + * So the data is multiple of 16, encrypted data size needs additonal 16 bytes. + * @since 1.0.0 + */ +inline Error encrypt_ecb( + const unsigned char * data, + const unsigned long data_size, + const unsigned char * key, + const unsigned long key_size, + unsigned char *encrypted, + const unsigned long encrypted_size, + const bool pads + ) { + const Error e = detail::check_encrypt_cond(data_size, key_size, encrypted_size, pads); + if (e != kErrorOk) { + return e; + } + + const detail::RoundKeys rkeys = detail::expand_key(key, static_cast<int>(key_size)); + + const unsigned long bc = data_size / detail::kStateSize; + for (unsigned long i = 0; i < bc; ++i) { + detail::encrypt_state(rkeys, data + (i * detail::kStateSize), encrypted + (i * detail::kStateSize)); + } + + if (pads) { + const int rem = data_size % detail::kStateSize; + const char pad_v = detail::kStateSize - rem; + + std::vector<unsigned char> ib(detail::kStateSize, pad_v), ob(detail::kStateSize); + memcpy(&ib[0], data + data_size - rem, rem); + + detail::encrypt_state(rkeys, &ib[0], &ob[0]); + memcpy(encrypted + (data_size - rem), &ob[0], detail::kStateSize); + } + + return kErrorOk; +} + +/** + * Decrypts data with ECB mode. + * @param [in] data Data bytes. + * @param [in] data_size Data size. + * @param [in] key Key bytes. + * @param [in] key_size Key size. + * @param [out] decrypted Decrypted data buffer. + * @param [in] decrypted_size Decrypted data buffer size. + * @param [out] padded_size If this value is NULL, this function does not remove padding. + * If this value is not NULL, this function removes padding by PKCS + * and returns padded size using padded_size. + * @since 1.0.0 + */ +inline Error decrypt_ecb( + const unsigned char * data, + const unsigned long data_size, + const unsigned char * key, + const unsigned long key_size, + unsigned char * decrypted, + const unsigned long decrypted_size, + unsigned long * padded_size + ) { + const Error e = detail::check_decrypt_cond(data_size, key_size, decrypted_size, padded_size); + if (e != kErrorOk) { + return e; + } + + const detail::RoundKeys rkeys = detail::expand_key(key, static_cast<int>(key_size)); + + const unsigned long bc = data_size / detail::kStateSize - 1; + for (unsigned long i = 0; i < bc; ++i) { + detail::decrypt_state(rkeys, data + (i * detail::kStateSize), decrypted + (i * detail::kStateSize)); + } + + unsigned char last[detail::kStateSize] = {}; + detail::decrypt_state(rkeys, data + (bc * detail::kStateSize), last); + + if (padded_size) { + *padded_size = last[detail::kStateSize - 1]; + const unsigned long cs = detail::kStateSize - *padded_size; + + if (!detail::check_padding(*padded_size, last)) { + return kErrorInvalidKey; + } + else if (decrypted_size >= (bc * detail::kStateSize) + cs) { + memcpy(decrypted + (bc * detail::kStateSize), last, cs); + } + else { + return kErrorInvalidBufferSize; + } + } + else { + memcpy(decrypted + (bc * detail::kStateSize), last, sizeof(last)); + } + + return kErrorOk; +} + +/** + * Encrypt data with CBC mode. + * @param [in] data Data. + * @param [in] data_size Data size. + * If the pads is false, data size must be multiple of 16. + * @param [in] key key bytes. The key length must be 16 (128-bit), 24 (192-bit) or 32 (256-bit). + * @param [in] key_size key size. + * @param [in] iv Initialize vector. + * @param [out] encrypted Encrypted data buffer. + * @param [in] encrypted_size Encrypted data buffer size. + * @param [in] pads If this value is true, encrypted data is padded by PKCS. + * Encrypted data size must be multiple of 16. + * If the pads is true, encrypted data is padded with PKCS. + * So the data is multiple of 16, encrypted data size needs additonal 16 bytes. + * @since 1.0.0 + */ +inline Error encrypt_cbc( + const unsigned char * data, + const unsigned long data_size, + const unsigned char * key, + const unsigned long key_size, + const unsigned char (* iv)[16], + unsigned char * encrypted, + const unsigned long encrypted_size, + const bool pads + ) { + const Error e = detail::check_encrypt_cond(data_size, key_size, encrypted_size, pads); + if (e != kErrorOk) { + return e; + } + + const detail::RoundKeys rkeys = detail::expand_key(key, static_cast<int>(key_size)); + + unsigned char s[detail::kStateSize] = {}; // encrypting data + + // calculate padding value + const bool ge16 = (data_size >= detail::kStateSize); + const int rem = data_size % detail::kStateSize; + const unsigned char pad_v = detail::kStateSize - rem; + + // encrypt 1st state + if (ge16) { + memcpy(s, data, detail::kStateSize); + } + else { + memset(s, pad_v, detail::kStateSize); + memcpy(s, data, data_size); + } + if (iv) { + detail::xor_data(s, *iv); + } + detail::encrypt_state(rkeys, s, encrypted); + + // encrypt mid + const unsigned long bc = data_size / detail::kStateSize; + for (unsigned long i = 1; i < bc; ++i) { + const long offset = i * detail::kStateSize; + memcpy(s, data + offset, detail::kStateSize); + detail::xor_data(s, encrypted + offset - detail::kStateSize); + + detail::encrypt_state(rkeys, s, encrypted + offset); + } + + // enctypt last + if (pads && ge16) { + std::vector<unsigned char> ib(detail::kStateSize, pad_v), ob(detail::kStateSize); + memcpy(&ib[0], data + data_size - rem, rem); + + detail::xor_data(&ib[0], encrypted + (bc - 1) * detail::kStateSize); + + detail::encrypt_state(rkeys, &ib[0], &ob[0]); + memcpy(encrypted + (data_size - rem), &ob[0], detail::kStateSize); + } + + return kErrorOk; +} + +/** + * Decrypt data with CBC mode. + * @param [in] data Data bytes. + * @param [in] data_size Data size. + * @param [in] key Key bytes. + * @param [in] key_size Key size. + * @param [in] iv Initialize vector. + * @param [out] decrypted Decrypted data buffer. + * @param [in] decrypted_size Decrypted data buffer size. + * @param [out] padded_size If this value is NULL, this function does not remove padding. + * If this value is not NULL, this function removes padding by PKCS + * and returns padded size using padded_size. + * @since 1.0.0 + */ +inline Error decrypt_cbc( + const unsigned char * data, + const unsigned long data_size, + const unsigned char * key, + const unsigned long key_size, + const unsigned char (* iv)[16], + unsigned char * decrypted, + const unsigned long decrypted_size, + unsigned long * padded_size + ) { + const Error e = detail::check_decrypt_cond(data_size, key_size, decrypted_size, padded_size); + if (e != kErrorOk) { + return e; + } + + const detail::RoundKeys rkeys = detail::expand_key(key, static_cast<int>(key_size)); + + // decrypt 1st state + detail::decrypt_state(rkeys, data, decrypted); + if (iv) { + detail::xor_data(decrypted, *iv); + } + + // decrypt mid + const unsigned long bc = data_size / detail::kStateSize - 1; + for (unsigned long i = 1; i < bc; ++i) { + const long offset = i * detail::kStateSize; + detail::decrypt_state(rkeys, data + offset, decrypted + offset); + detail::xor_data(decrypted + offset, data + offset - detail::kStateSize); + } + + // decrypt last + unsigned char last[detail::kStateSize] = {}; + if (data_size > detail::kStateSize) { + detail::decrypt_state(rkeys, data + (bc * detail::kStateSize), last); + detail::xor_data(last, data + (bc * detail::kStateSize - detail::kStateSize)); + } + else { + memcpy(last, decrypted, data_size); + memset(decrypted, 0, decrypted_size); + } + + if (padded_size) { + *padded_size = last[detail::kStateSize - 1]; + const unsigned long cs = detail::kStateSize - *padded_size; + + if (!detail::check_padding(*padded_size, last)) { + return kErrorInvalidKey; + } + else if (decrypted_size >= (bc * detail::kStateSize) + cs) { + memcpy(decrypted + (bc * detail::kStateSize), last, cs); + } + else { + return kErrorInvalidBufferSize; + } + } + else { + memcpy(decrypted + (bc * detail::kStateSize), last, sizeof(last)); + } + + return kErrorOk; +} + +/** + * @note + * This is BETA API. I might change API in the future. + * + * Encrypts or decrypt data in-place with CTR mode. + * @param [in,out] data Data. + * @param [in,out] data_size Data size. + * @param [in] key key bytes. The key length must be 16 (128-bit), 24 (192-bit) or 32 (256-bit). + * @param [in] key_size key size. + * @param [in] nonce 16 bytes. + * @since 1.0.0 + */ +inline Error crypt_ctr( + unsigned char *data, + unsigned long data_size, + const unsigned char *key, + const unsigned long key_size, + const unsigned char *nonce, + const unsigned long nonce_size +) { + if (nonce_size > detail::kStateSize) return kErrorInvalidNonceSize; + if (!detail::is_valid_key_size(key_size)) return kErrorInvalidKeySize; + const detail::RoundKeys rkeys = detail::expand_key(key, static_cast<int>(key_size)); + + unsigned long pos = 0; + unsigned long blkpos = detail::kStateSize; + unsigned char blk[detail::kStateSize]; + unsigned char counter[detail::kStateSize] = {}; + memcpy(counter, nonce, nonce_size); + + while (pos < data_size) { + if (blkpos == detail::kStateSize) { + detail::encrypt_state(rkeys, counter, blk); + detail::incr_counter(counter); + blkpos = 0; + } + data[pos++] ^= blk[blkpos++]; + } + + return kErrorOk; +} + +} // namespace plusaes + +#endif // PLUSAES_HPP__ diff --git a/src/main.cc b/src/main.cc index 37da970..88680ea 100644 --- a/src/main.cc +++ b/src/main.cc @@ -21,10 +21,12 @@ using namespace std::chrono_literals; int real_main(int argc, char **argv) { rlib::opt_parser args(argc, argv); if(args.getBoolArg("--help", "-h")) { - rlog.info("Usage: {} -i $InboundConfig -o $OutboundConfig [--log=error/info/verbose/debug]"_rs.format(args.getSelf())); + rlog.info("Usage: {} -i $InboundConfig -o $OutboundConfig [--filter $filterConfig ...] [--log=error/info/verbose/debug]"_rs.format(args.getSelf())); rlog.info(" InboundConfig and OutboundConfig are in this format: "); - rlog.info(" '$method:$params', available methods: "); - rlog.info(" 'plain:$addr:$port', 'misc:$addr:$portRange:$psk'"); + rlog.info(" '$method@$params', available methods: "); + rlog.info(" 'plain@$addr@$port', 'misc@$addr@$portRange@$psk'"); + rlog.info("There could be multiple --filter, but they MUST be in correct order. "); + rlog.info("available filters: 'aes@$password' , 'xor@$password'"); return 0; } auto inboundConfig = args.getValueArg("-i"); @@ -42,7 +44,16 @@ int real_main(int argc, char **argv) { else throw std::runtime_error("Unknown log level: " + log_level); - Forwarder(inboundConfig, outboundConfig).runForever(); + std::list<rlib::string> filterConfigs; + while (true) { + if (auto filterConfig = args.getValueArg("--filter", false, ""); !filterConfig.empty()) { + filterConfigs.push_back(filterConfig); + } + else + break; + } + + Forwarder(inboundConfig, outboundConfig, filterConfigs).runForever(); return 0; } diff --git a/src/protocols/base.hpp b/src/protocols/base.hpp index d0eda69..5278790 100644 --- a/src/protocols/base.hpp +++ b/src/protocols/base.hpp @@ -1,6 +1,7 @@ #ifndef UDP_FORWARDER_PROT_BASE_HPP_ #define UDP_FORWARDER_PROT_BASE_HPP_ 1 +#include "filters/base.hpp" #include <rlib/class_decorator.hpp> #include <string> using std::string; @@ -25,14 +26,14 @@ namespace Protocols { virtual ~BaseOutbound() = default; // Init data structures. - virtual void loadConfig(string config) = 0; + virtual void loadConfig(string config) {} // InboundThread calls this function. Check the mapping between senderId and serverConn, wake up listenThread, and deliver the msg. virtual void forwardMessageToInbound(string binaryMessage, string senderId) = 0; // Listen the PIPE. handleMessage will wake up this thread from epoll. // Also listen the connection fileDescriptors. - virtual void listenForever(BaseInbound *previousHop) = 0; + virtual void listenForever(BaseInbound *previousHop, Filters::BaseFilter *filter) = 0; // Inbound.listenForever MUST initialize this field. volatile sockfd_t ipcPipe = -1; @@ -42,14 +43,14 @@ namespace Protocols { virtual ~BaseInbound() = default; // Init data structures. - virtual void loadConfig(string config) = 0; + virtual void loadConfig(string config) {} // OutboundThread calls this function. Wake up 'listenForever' thread, and send back a message. Outbound provides the senderId. virtual void forwardMessageToOutbound(string binaryMessage, string senderId) = 0; // Listen the addr:port in config, for inbound connection. // Also listen the accepted connection fileDescriptors, and listen the PIPE. - virtual void listenForever(BaseOutbound *nextHop) = 0; + virtual void listenForever(BaseOutbound *nextHop, Filters::BaseFilter *filter) = 0; // Inbound.listenForever MUST initialize this field. volatile sockfd_t ipcPipe = -1; diff --git a/src/protocols/plain.hpp b/src/protocols/plain.hpp index 56efb85..1018332 100644 --- a/src/protocols/plain.hpp +++ b/src/protocols/plain.hpp @@ -4,8 +4,8 @@ #include <protocols/base.hpp> #include <rlib/sys/sio.hpp> #include <rlib/string.hpp> -#include <utils.hpp> -#include <common.hpp> +#include "utils.hpp" +#include "common.hpp" #if RLIB_OS_ID == OS_LINUX #include <linux/sched.h> @@ -46,7 +46,7 @@ namespace Protocols { rlib::sockIO::send_msg(ipcPipe, senderId); rlib::sockIO::send_msg(ipcPipe, binaryMessage); } - virtual void listenForever(BaseOutbound* nextHop) override { + virtual void listenForever(BaseOutbound* nextHop, Filters::BaseFilter *filter) override { std::tie(this->ipcPipe, nextHop->ipcPipe) = mk_tcp_pipe(); // ----------------------- Initialization / Setup ------------------------------ @@ -59,8 +59,6 @@ namespace Protocols { epoll_add_fd(epollFd, ipcPipe); // ----------------------- Process an event ------------------------------ - auto udpSenderSocket = socket(AF_INET, SOCK_DGRAM, 0); - dynamic_assert(udpSenderSocket > 0, "socket create failed."); std::string msgBuffer(DGRAM_BUFFER_SIZE, '\0'); // WARN: If you want to modify this program to work for TCP, PLEASE use rlib::sockIO::recv instead of fixed buffer. auto onEvent = [&](auto activeFd) { @@ -80,7 +78,7 @@ namespace Protocols { auto msgLength = recvfrom(activeFd, msgBuffer.data(), msgBuffer.size(), 0, &clientAddr.addr, &clientAddr.len); dynamic_assert(msgLength != -1, "recvfrom failed"); - forwardMessageToOutbound(msgBuffer.substr(0, msgLength), ClientIdUtils::makeClientId(clientAddr)); + forwardMessageToOutbound(filter->convertForward(msgBuffer.substr(0, msgLength)), ClientIdUtils::makeClientId(clientAddr)); } }; @@ -120,7 +118,7 @@ namespace Protocols { // Listen the PIPE. handleMessage will wake up this thread from epoll. // Also listen the connection fileDescriptors. - virtual void listenForever(BaseInbound* previousHop) override { + virtual void listenForever(BaseInbound* previousHop, Filters::BaseFilter *filter) override { // ----------------------- Initialization / Setup ------------------------------ auto epollFd = epoll_create1(0); @@ -172,7 +170,7 @@ namespace Protocols { dynamic_assert(connectionMap.server2client.count(activeFd) > 0, "connectionMap MUST contain server connfd. "); - forwardMessageToInbound(msgBuffer.substr(0, msgLength), connectionMap.server2client.at(activeFd)); + forwardMessageToInbound(filter->convertBackward(msgBuffer.substr(0, msgLength)), connectionMap.server2client.at(activeFd)); } }; diff --git a/src/test/filter_test.cc b/src/test/filter_test.cc new file mode 100644 index 0000000..b2571e4 --- /dev/null +++ b/src/test/filter_test.cc @@ -0,0 +1,20 @@ + +#include "../filters/aes_encryption.hpp" +#include <rlib/stdio.hpp> + +int main() { + Filters::AESFilter f; + f.loadConfig("aes@hello world"); + + + std::string plain = "Hi, I'm asoinvdowaieviosandoiv dasf sda ifsdh ofisdaf oisdfoisanoids nfoisdoafnsdoif nsdo fisdnio |"; + auto r = f.convertForward(plain); + rlib::println(plain); + rlib::println(r); + auto p = f.convertBackward(r); + rlib::println(p); + + +} + + diff --git a/src/utils.hpp b/src/utils.hpp index bf3727c..89967f4 100644 --- a/src/utils.hpp +++ b/src/utils.hpp @@ -31,13 +31,15 @@ struct ClientIdUtils { // ClientId is a binary string. static_assert(sizeof(osStruct) == sizeof(SockAddr), "error: programming error detected."); string result(sizeof(osStruct), '\0'); - std::memcpy(result.data(), &osStruct, sizeof(osStruct)); + // cross-platform TODO: byte-order problem. + std::memcpy((void *)result.data(), &osStruct, sizeof(osStruct)); return result; } static SockAddr parseClientId(const string &clientId) { SockAddr result; if (clientId.size() != sizeof(result)) throw std::invalid_argument("parseClientId, invalid input binary string length."); + // cross-platform TODO: byte-order problem. std::memcpy(&result, clientId.data(), sizeof(result)); return result; } @@ -91,6 +93,22 @@ inline auto mk_tcp_pipe() { } while(false) +template <size_t result_length> +string pskToKey(string psk) { + // Convert user-provided variable-length psk to a 32-byte key. + using HashType = decltype(std::hash<std::string>{}.operator()(std::string())); // std::result_of_t<std::hash<std::string>::operator()>; + static_assert(result_length % sizeof(HashType) == 0, "pskToKey: result_length MUST be multiply of HashResultSize."); + + string result(result_length, '\0'); + auto hashObject = std::hash<std::string>{}; + for (auto i = 0; i < result.size(); i += sizeof(HashType)) { + auto hashResult = hashObject(psk); + // cross-platform TODO: byte-order problem. + std::memcpy(result.data() + i, &hashResult, sizeof(HashType)); + } + return result; +} + #endif -- GitLab