diff options
author | Tiger Wang <ziwei.tiger@outlook.com> | 2021-03-08 17:37:36 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-03-08 17:37:36 +0100 |
commit | 01a4e696b3d2c973cdd1fb4345d747bd10e93ad9 (patch) | |
tree | 92996aef85fca3306b26535fa44feb812deb050b /src | |
parent | Some emplace_back replacements (#5149) (diff) | |
download | cuberite-01a4e696b3d2c973cdd1fb4345d747bd10e93ad9.tar cuberite-01a4e696b3d2c973cdd1fb4345d747bd10e93ad9.tar.gz cuberite-01a4e696b3d2c973cdd1fb4345d747bd10e93ad9.tar.bz2 cuberite-01a4e696b3d2c973cdd1fb4345d747bd10e93ad9.tar.lz cuberite-01a4e696b3d2c973cdd1fb4345d747bd10e93ad9.tar.xz cuberite-01a4e696b3d2c973cdd1fb4345d747bd10e93ad9.tar.zst cuberite-01a4e696b3d2c973cdd1fb4345d747bd10e93ad9.zip |
Diffstat (limited to 'src')
-rw-r--r-- | src/ClientHandle.cpp | 6 | ||||
-rw-r--r-- | src/ClientHandle.h | 2 | ||||
-rw-r--r-- | src/Protocol/Protocol.h | 6 | ||||
-rw-r--r-- | src/Protocol/ProtocolRecognizer.cpp | 14 | ||||
-rw-r--r-- | src/Protocol/ProtocolRecognizer.h | 11 | ||||
-rw-r--r-- | src/Protocol/Protocol_1_8.cpp | 259 | ||||
-rw-r--r-- | src/Protocol/Protocol_1_8.h | 11 | ||||
-rw-r--r-- | src/mbedTLS++/AesCfb128Decryptor.cpp | 48 | ||||
-rw-r--r-- | src/mbedTLS++/AesCfb128Decryptor.h | 19 |
9 files changed, 204 insertions, 172 deletions
diff --git a/src/ClientHandle.cpp b/src/ClientHandle.cpp index 683a243fe..6a6ea951f 100644 --- a/src/ClientHandle.cpp +++ b/src/ClientHandle.cpp @@ -3299,7 +3299,7 @@ bool cClientHandle::SetState(eState a_NewState) void cClientHandle::ProcessProtocolIn(void) { // Process received network data: - AString IncomingData; + decltype(m_IncomingData) IncomingData; { cCSLock Lock(m_CSIncomingData); std::swap(IncomingData, m_IncomingData); @@ -3312,7 +3312,7 @@ void cClientHandle::ProcessProtocolIn(void) try { - m_Protocol.HandleIncomingData(*this, IncomingData); + m_Protocol.HandleIncomingData(*this, std::move(IncomingData)); } catch (const std::exception & Oops) { @@ -3340,7 +3340,7 @@ void cClientHandle::OnReceivedData(const char * a_Data, size_t a_Length) // Queue the incoming data to be processed in the tick thread: cCSLock Lock(m_CSIncomingData); - m_IncomingData.append(a_Data, a_Length); + m_IncomingData.append(reinterpret_cast<const std::byte *>(a_Data), a_Length); } diff --git a/src/ClientHandle.h b/src/ClientHandle.h index 2e1e09a06..9f1113669 100644 --- a/src/ClientHandle.h +++ b/src/ClientHandle.h @@ -445,7 +445,7 @@ private: /** Queue for the incoming data received on the link until it is processed in Tick(). Protected by m_CSIncomingData. */ - AString m_IncomingData; + ContiguousByteBuffer m_IncomingData; /** Protects m_OutgoingData against multithreaded access. */ cCriticalSection m_CSOutgoingData; diff --git a/src/Protocol/Protocol.h b/src/Protocol/Protocol.h index 41c461e76..743a73aba 100644 --- a/src/Protocol/Protocol.h +++ b/src/Protocol/Protocol.h @@ -48,6 +48,7 @@ typedef unsigned char Byte; class cProtocol { public: + cProtocol(cClientHandle * a_Client) : m_Client(a_Client), m_OutPacketBuffer(64 KiB), @@ -354,8 +355,9 @@ public: Game = 3, }; - /** Called when client sends some data */ - virtual void DataReceived(cByteBuffer & a_Buffer, const char * a_Data, size_t a_Size) = 0; + /** Called to process them, when client sends some data. + The protocol uses the provided buffers for storage and processing, and must have exclusive access to them. */ + virtual void DataReceived(cByteBuffer & a_Buffer, ContiguousByteBuffer && a_Data) = 0; // Sending stuff to clients (alphabetically sorted): virtual void SendAttachEntity (const cEntity & a_Entity, const cEntity & a_Vehicle) = 0; diff --git a/src/Protocol/ProtocolRecognizer.cpp b/src/Protocol/ProtocolRecognizer.cpp index 3af4f9654..181998337 100644 --- a/src/Protocol/ProtocolRecognizer.cpp +++ b/src/Protocol/ProtocolRecognizer.cpp @@ -76,7 +76,7 @@ AString cMultiVersionProtocol::GetVersionTextFromInt(cProtocol::Version a_Protoc -void cMultiVersionProtocol::HandleIncomingDataInRecognitionStage(cClientHandle & a_Client, std::string_view a_Data) +void cMultiVersionProtocol::HandleIncomingDataInRecognitionStage(cClientHandle & a_Client, ContiguousByteBuffer && a_Data) { // NOTE: If a new protocol is added or an old one is removed, adjust MCS_CLIENT_VERSIONS and MCS_PROTOCOL_VERSIONS macros in the header file @@ -113,13 +113,13 @@ void cMultiVersionProtocol::HandleIncomingDataInRecognitionStage(cClientHandle & 1. m_Protocol != nullptr: the protocol is supported and we have a handler 2. m_Protocol == nullptr: the protocol is unsupported, handling is a special case done by ourselves 3. Exception: the data sent were garbage, the client handle deals with it by disconnecting */ - m_Protocol = TryRecognizeLengthedProtocol(a_Client, a_Data); + m_Protocol = TryRecognizeLengthedProtocol(a_Client); if (m_Protocol == nullptr) { // Got a server list ping for an unrecognised version, // switch into responding to unknown protocols mode: - HandleIncomingData = [this](cClientHandle & a_Clyent, const std::string_view a_In) + HandleIncomingData = [this](cClientHandle & a_Clyent, ContiguousByteBuffer && a_In) { HandleIncomingDataInOldPingResponseStage(a_Clyent, a_In); }; @@ -127,9 +127,9 @@ void cMultiVersionProtocol::HandleIncomingDataInRecognitionStage(cClientHandle & else { // The protocol recogniser succesfully identified, switch mode: - HandleIncomingData = [this](cClientHandle &, const std::string_view a_In) + HandleIncomingData = [this](cClientHandle &, ContiguousByteBuffer && a_In) { - m_Protocol->DataReceived(m_Buffer, a_In.data(), a_In.size()); + m_Protocol->DataReceived(m_Buffer, std::move(a_In)); }; } @@ -141,7 +141,7 @@ void cMultiVersionProtocol::HandleIncomingDataInRecognitionStage(cClientHandle & -void cMultiVersionProtocol::HandleIncomingDataInOldPingResponseStage(cClientHandle & a_Client, const std::string_view a_Data) +void cMultiVersionProtocol::HandleIncomingDataInOldPingResponseStage(cClientHandle & a_Client, const ContiguousByteBufferView a_Data) { if (!m_Buffer.Write(a_Data.data(), a_Data.size())) { @@ -215,7 +215,7 @@ void cMultiVersionProtocol::SendDisconnect(cClientHandle & a_Client, const AStri -std::unique_ptr<cProtocol> cMultiVersionProtocol::TryRecognizeLengthedProtocol(cClientHandle & a_Client, const std::string_view a_Data) +std::unique_ptr<cProtocol> cMultiVersionProtocol::TryRecognizeLengthedProtocol(cClientHandle & a_Client) { UInt32 PacketType; UInt32 ProtocolVersion; diff --git a/src/Protocol/ProtocolRecognizer.h b/src/Protocol/ProtocolRecognizer.h index 56d5645c0..03b379f17 100644 --- a/src/Protocol/ProtocolRecognizer.h +++ b/src/Protocol/ProtocolRecognizer.h @@ -19,6 +19,9 @@ protocol version instance and redirects everything to it. */ class cMultiVersionProtocol { + // Work around the style checker complaining about && in template. + using OwnedContiguousByteBuffer = ContiguousByteBuffer &&; + public: cMultiVersionProtocol(); @@ -39,7 +42,7 @@ public: } /** The function that's responsible for processing incoming protocol data. */ - std::function<void(cClientHandle &, std::string_view)> HandleIncomingData; + std::function<void(cClientHandle &, OwnedContiguousByteBuffer)> HandleIncomingData; /** Sends a disconnect to the client as a result of a recognition error. This function can be used to disconnect before any protocol has been recognised. */ @@ -50,14 +53,14 @@ private: /** Handles data reception in a newly-created client handle that doesn't yet have a known protocol. a_Data contains a view of data that were just received. Tries to recognize a protocol, populate m_Protocol, and transitions to another mode depending on success. */ - void HandleIncomingDataInRecognitionStage(cClientHandle & a_Client, std::string_view a_Data); + void HandleIncomingDataInRecognitionStage(cClientHandle & a_Client, ContiguousByteBuffer && a_Data); /** Handles and responds to unsupported clients sending pings. */ - void HandleIncomingDataInOldPingResponseStage(cClientHandle & a_Client, std::string_view a_Data); + void HandleIncomingDataInOldPingResponseStage(cClientHandle & a_Client, ContiguousByteBufferView a_Data); /** Tries to recognize a protocol in the lengthed family (1.7+), based on m_Buffer. Returns a cProtocol_XXX instance if recognized. */ - std::unique_ptr<cProtocol> TryRecognizeLengthedProtocol(cClientHandle & a_Client, std::string_view a_Data); + std::unique_ptr<cProtocol> TryRecognizeLengthedProtocol(cClientHandle & a_Client); /** Sends one packet inside a cByteBuffer. This is used only when handling an outdated server ping. */ diff --git a/src/Protocol/Protocol_1_8.cpp b/src/Protocol/Protocol_1_8.cpp index e3bb7e5a3..8e7b74614 100644 --- a/src/Protocol/Protocol_1_8.cpp +++ b/src/Protocol/Protocol_1_8.cpp @@ -173,31 +173,14 @@ cProtocol_1_8_0::cProtocol_1_8_0(cClientHandle * a_Client, const AString & a_Ser -void cProtocol_1_8_0::DataReceived(cByteBuffer & a_Buffer, const char * a_Data, size_t a_Size) +void cProtocol_1_8_0::DataReceived(cByteBuffer & a_Buffer, ContiguousByteBuffer && a_Data) { if (m_IsEncrypted) { - // An artefact of the protocol recogniser, will be removed when decryption done in-place: - if (a_Size == 0) - { - AddReceivedData(a_Buffer, nullptr, 0); - return; - } - - std::byte Decrypted[512]; - while (a_Size > 0) - { - size_t NumBytes = (a_Size > sizeof(Decrypted)) ? sizeof(Decrypted) : a_Size; - m_Decryptor.ProcessData(Decrypted, reinterpret_cast<const Byte *>(a_Data), NumBytes); - AddReceivedData(a_Buffer, reinterpret_cast<const char *>(Decrypted), NumBytes); - a_Size -= NumBytes; - a_Data += NumBytes; - } - } - else - { - AddReceivedData(a_Buffer, a_Data, a_Size); + m_Decryptor.ProcessData(a_Data.data(), a_Data.size()); } + + AddReceivedData(a_Buffer, a_Data); } @@ -2002,123 +1985,6 @@ UInt32 cProtocol_1_8_0::GetProtocolMobType(const eMonsterType a_MobType) -void cProtocol_1_8_0::AddReceivedData(cByteBuffer & a_Buffer, const char * a_Data, size_t a_Size) -{ - // Write the incoming data into the comm log file: - if (g_ShouldLogCommIn && m_CommLogFile.IsOpen()) - { - if (a_Buffer.GetReadableSpace() > 0) - { - ContiguousByteBuffer AllData; - size_t OldReadableSpace = a_Buffer.GetReadableSpace(); - a_Buffer.ReadAll(AllData); - a_Buffer.ResetRead(); - a_Buffer.SkipRead(a_Buffer.GetReadableSpace() - OldReadableSpace); - ASSERT(a_Buffer.GetReadableSpace() == OldReadableSpace); - AString Hex; - CreateHexDump(Hex, AllData.data(), AllData.size(), 16); - m_CommLogFile.Printf("Incoming data, %zu (0x%zx) unparsed bytes already present in buffer:\n%s\n", - AllData.size(), AllData.size(), Hex.c_str() - ); - } - AString Hex; - CreateHexDump(Hex, a_Data, a_Size, 16); - m_CommLogFile.Printf("Incoming data: %u (0x%x) bytes: \n%s\n", - static_cast<unsigned>(a_Size), static_cast<unsigned>(a_Size), Hex.c_str() - ); - m_CommLogFile.Flush(); - } - - if (!a_Buffer.Write(a_Data, a_Size)) - { - // Too much data in the incoming queue, report to caller: - m_Client->PacketBufferFull(); - return; - } - - // Handle all complete packets: - for (;;) - { - UInt32 PacketLen; - if (!a_Buffer.ReadVarInt(PacketLen)) - { - // Not enough data - a_Buffer.ResetRead(); - break; - } - if (!a_Buffer.CanReadBytes(PacketLen)) - { - // The full packet hasn't been received yet - a_Buffer.ResetRead(); - break; - } - - // Check packet for compression: - if (m_State == 3) - { - UInt32 NumBytesRead = static_cast<UInt32>(a_Buffer.GetReadableSpace()); - - UInt32 UncompressedSize; - if (!a_Buffer.ReadVarInt(UncompressedSize)) - { - m_Client->Kick("Compression packet incomplete"); - return; - } - - NumBytesRead -= static_cast<UInt32>(a_Buffer.GetReadableSpace()); // How many bytes has the UncompressedSize taken up? - ASSERT(PacketLen > NumBytesRead); - PacketLen -= NumBytesRead; - - if (UncompressedSize > 0) - { - // Decompress the data: - m_Extractor.ReadFrom(a_Buffer, PacketLen); - a_Buffer.CommitRead(); - - const auto UncompressedData = m_Extractor.Extract(UncompressedSize); - const auto Uncompressed = UncompressedData.GetView(); - cByteBuffer bb(Uncompressed.size()); - - // Compression was used, move the uncompressed data: - VERIFY(bb.Write(Uncompressed.data(), Uncompressed.size())); - - HandlePacket(bb); - continue; - } - } - - // Move the packet payload to a separate cByteBuffer, bb: - cByteBuffer bb(PacketLen); - - // No compression was used, move directly: - VERIFY(a_Buffer.ReadToByteBuffer(bb, static_cast<size_t>(PacketLen))); - a_Buffer.CommitRead(); - - HandlePacket(bb); - } // for (ever) - - // Log any leftover bytes into the logfile: - if (g_ShouldLogCommIn && (a_Buffer.GetReadableSpace() > 0) && m_CommLogFile.IsOpen()) - { - ContiguousByteBuffer AllData; - size_t OldReadableSpace = a_Buffer.GetReadableSpace(); - a_Buffer.ReadAll(AllData); - a_Buffer.ResetRead(); - a_Buffer.SkipRead(a_Buffer.GetReadableSpace() - OldReadableSpace); - ASSERT(a_Buffer.GetReadableSpace() == OldReadableSpace); - AString Hex; - CreateHexDump(Hex, AllData.data(), AllData.size(), 16); - m_CommLogFile.Printf("There are %zu (0x%zx) bytes of non-parse-able data left in the buffer:\n%s", - a_Buffer.GetReadableSpace(), a_Buffer.GetReadableSpace(), Hex.c_str() - ); - m_CommLogFile.Flush(); - } -} - - - - - UInt32 cProtocol_1_8_0::GetPacketID(ePacketType a_PacketType) { switch (a_PacketType) @@ -3961,6 +3827,123 @@ void cProtocol_1_8_0::WriteEntityProperties(cPacketizer & a_Pkt, const cEntity & +void cProtocol_1_8_0::AddReceivedData(cByteBuffer & a_Buffer, const ContiguousByteBufferView a_Data) +{ + // Write the incoming data into the comm log file: + if (g_ShouldLogCommIn && m_CommLogFile.IsOpen()) + { + if (a_Buffer.GetReadableSpace() > 0) + { + ContiguousByteBuffer AllData; + size_t OldReadableSpace = a_Buffer.GetReadableSpace(); + a_Buffer.ReadAll(AllData); + a_Buffer.ResetRead(); + a_Buffer.SkipRead(a_Buffer.GetReadableSpace() - OldReadableSpace); + ASSERT(a_Buffer.GetReadableSpace() == OldReadableSpace); + AString Hex; + CreateHexDump(Hex, AllData.data(), AllData.size(), 16); + m_CommLogFile.Printf("Incoming data, %zu (0x%zx) unparsed bytes already present in buffer:\n%s\n", + AllData.size(), AllData.size(), Hex.c_str() + ); + } + AString Hex; + CreateHexDump(Hex, a_Data.data(), a_Data.size(), 16); + m_CommLogFile.Printf("Incoming data: %zu (0x%zx) bytes: \n%s\n", + a_Data.size(), a_Data.size(), Hex.c_str() + ); + m_CommLogFile.Flush(); + } + + if (!a_Buffer.Write(a_Data.data(), a_Data.size())) + { + // Too much data in the incoming queue, report to caller: + m_Client->PacketBufferFull(); + return; + } + + // Handle all complete packets: + for (;;) + { + UInt32 PacketLen; + if (!a_Buffer.ReadVarInt(PacketLen)) + { + // Not enough data + a_Buffer.ResetRead(); + break; + } + if (!a_Buffer.CanReadBytes(PacketLen)) + { + // The full packet hasn't been received yet + a_Buffer.ResetRead(); + break; + } + + // Check packet for compression: + if (m_State == 3) + { + UInt32 NumBytesRead = static_cast<UInt32>(a_Buffer.GetReadableSpace()); + + UInt32 UncompressedSize; + if (!a_Buffer.ReadVarInt(UncompressedSize)) + { + m_Client->Kick("Compression packet incomplete"); + return; + } + + NumBytesRead -= static_cast<UInt32>(a_Buffer.GetReadableSpace()); // How many bytes has the UncompressedSize taken up? + ASSERT(PacketLen > NumBytesRead); + PacketLen -= NumBytesRead; + + if (UncompressedSize > 0) + { + // Decompress the data: + m_Extractor.ReadFrom(a_Buffer, PacketLen); + a_Buffer.CommitRead(); + + const auto UncompressedData = m_Extractor.Extract(UncompressedSize); + const auto Uncompressed = UncompressedData.GetView(); + cByteBuffer bb(Uncompressed.size()); + + // Compression was used, move the uncompressed data: + VERIFY(bb.Write(Uncompressed.data(), Uncompressed.size())); + + HandlePacket(bb); + continue; + } + } + + // Move the packet payload to a separate cByteBuffer, bb: + cByteBuffer bb(PacketLen); + + // No compression was used, move directly: + VERIFY(a_Buffer.ReadToByteBuffer(bb, static_cast<size_t>(PacketLen))); + a_Buffer.CommitRead(); + + HandlePacket(bb); + } // for (ever) + + // Log any leftover bytes into the logfile: + if (g_ShouldLogCommIn && (a_Buffer.GetReadableSpace() > 0) && m_CommLogFile.IsOpen()) + { + ContiguousByteBuffer AllData; + size_t OldReadableSpace = a_Buffer.GetReadableSpace(); + a_Buffer.ReadAll(AllData); + a_Buffer.ResetRead(); + a_Buffer.SkipRead(a_Buffer.GetReadableSpace() - OldReadableSpace); + ASSERT(a_Buffer.GetReadableSpace() == OldReadableSpace); + AString Hex; + CreateHexDump(Hex, AllData.data(), AllData.size(), 16); + m_CommLogFile.Printf("There are %zu (0x%zx) bytes of non-parse-able data left in the buffer:\n%s", + a_Buffer.GetReadableSpace(), a_Buffer.GetReadableSpace(), Hex.c_str() + ); + m_CommLogFile.Flush(); + } +} + + + + + void cProtocol_1_8_0::HandlePacket(cByteBuffer & a_Buffer) { UInt32 PacketType; diff --git a/src/Protocol/Protocol_1_8.h b/src/Protocol/Protocol_1_8.h index e2aadf147..29bc7420c 100644 --- a/src/Protocol/Protocol_1_8.h +++ b/src/Protocol/Protocol_1_8.h @@ -36,8 +36,9 @@ public: cProtocol_1_8_0(cClientHandle * a_Client, const AString & a_ServerAddress, State a_State); - /** Called when client sends some data: */ - virtual void DataReceived(cByteBuffer & a_Buffer, const char * a_Data, size_t a_Size) override; + /** Called to process them, when client sends some data. + The protocol uses the provided buffers for storage and processing, and must have exclusive access to them. */ + virtual void DataReceived(cByteBuffer & a_Buffer, ContiguousByteBuffer && a_Data) override; /** Sending stuff to clients (alphabetically sorted): */ virtual void SendAttachEntity (const cEntity & a_Entity, const cEntity & a_Vehicle) override; @@ -146,9 +147,6 @@ protected: /** State of the protocol. */ State m_State; - /** Adds the received (unencrypted) data to m_ReceivedData, parses complete packets */ - virtual void AddReceivedData(cByteBuffer & a_Buffer, const char * a_Data, size_t a_Size); - /** Nobody inherits 1.8, so it doesn't use this method */ virtual UInt32 GetPacketID(ePacketType a_Packet) override; @@ -257,6 +255,9 @@ private: /** The logfile where the comm is logged, when g_ShouldLogComm is true */ cFile m_CommLogFile; + /** Adds the received (unencrypted) data to m_ReceivedData, parses complete packets */ + void AddReceivedData(cByteBuffer & a_Buffer, ContiguousByteBufferView a_Data); + /** Handle a complete packet stored in the given buffer. */ void HandlePacket(cByteBuffer & a_Buffer); diff --git a/src/mbedTLS++/AesCfb128Decryptor.cpp b/src/mbedTLS++/AesCfb128Decryptor.cpp index 523e06161..6243a3ded 100644 --- a/src/mbedTLS++/AesCfb128Decryptor.cpp +++ b/src/mbedTLS++/AesCfb128Decryptor.cpp @@ -10,10 +10,17 @@ -cAesCfb128Decryptor::cAesCfb128Decryptor(void): +cAesCfb128Decryptor::cAesCfb128Decryptor(void) : m_IsValid(false) { +#ifdef _WIN32 + if (!CryptAcquireContext(&m_Aes, nullptr, nullptr, PROV_RSA_AES, CRYPT_VERIFYCONTEXT)) + { + throw std::system_error(GetLastError(), std::system_category()); + } +#else mbedtls_aes_init(&m_Aes); +#endif } @@ -22,8 +29,12 @@ cAesCfb128Decryptor::cAesCfb128Decryptor(void): cAesCfb128Decryptor::~cAesCfb128Decryptor() { - // Clear the leftover in-memory data, so that they can't be accessed by a backdoor + // Clear the leftover in-memory data, so that they can't be accessed by a backdoor: +#ifdef _WIN32 + CryptReleaseContext(m_Aes, 0); +#else mbedtls_aes_free(&m_Aes); +#endif } @@ -34,8 +45,27 @@ void cAesCfb128Decryptor::Init(const Byte a_Key[16], const Byte a_IV[16]) { ASSERT(!IsValid()); // Cannot Init twice - memcpy(m_IV, a_IV, 16); +#ifdef _WIN32 + struct Key + { + PUBLICKEYSTRUC Header; + DWORD Length; + Byte Key[16]; + } Key; + + const DWORD Mode = CRYPT_MODE_CFB; + Key.Header = { PLAINTEXTKEYBLOB, CUR_BLOB_VERSION, 0, CALG_AES_128 }; + Key.Length = 16; + std::copy_n(a_Key, 16, Key.Key); + + CryptImportKey(m_Aes, reinterpret_cast<const BYTE *>(&Key), sizeof(Key), 0, 0, &m_Key); + CryptSetKeyParam(m_Key, KP_MODE, reinterpret_cast<const BYTE *>(&Mode), 0); + CryptSetKeyParam(m_Key, KP_IV, a_IV, 0); +#else + std::copy_n(a_IV, 16, m_IV); mbedtls_aes_setkey_enc(&m_Aes, a_Key, 128); +#endif + m_IsValid = true; } @@ -43,8 +73,16 @@ void cAesCfb128Decryptor::Init(const Byte a_Key[16], const Byte a_IV[16]) -void cAesCfb128Decryptor::ProcessData(std::byte * a_DecryptedOut, const Byte * a_EncryptedIn, size_t a_Length) +void cAesCfb128Decryptor::ProcessData(std::byte * const a_EncryptedIn, const size_t a_Length) { ASSERT(IsValid()); // Must Init() first - mbedtls_aes_crypt_cfb8(&m_Aes, MBEDTLS_AES_DECRYPT, a_Length, m_IV, a_EncryptedIn, reinterpret_cast<unsigned char *>(a_DecryptedOut)); + +#ifdef _WIN32 + ASSERT(a_Length <= std::numeric_limits<DWORD>::max()); + + DWORD Length = static_cast<DWORD>(a_Length); + CryptDecrypt(m_Key, 0, FALSE, 0, reinterpret_cast<BYTE *>(a_EncryptedIn), &Length); +#else + mbedtls_aes_crypt_cfb8(&m_Aes, MBEDTLS_AES_DECRYPT, a_Length, m_IV, reinterpret_cast<unsigned char *>(a_EncryptedIn), reinterpret_cast<unsigned char *>(a_EncryptedIn)); +#endif } diff --git a/src/mbedTLS++/AesCfb128Decryptor.h b/src/mbedTLS++/AesCfb128Decryptor.h index 601699998..a2c9d6a05 100644 --- a/src/mbedTLS++/AesCfb128Decryptor.h +++ b/src/mbedTLS++/AesCfb128Decryptor.h @@ -9,7 +9,11 @@ #pragma once +#ifdef _WIN32 +#include <wincrypt.h> +#else #include "mbedtls/aes.h" +#endif @@ -26,14 +30,20 @@ public: /** Initializes the decryptor with the specified Key / IV */ void Init(const Byte a_Key[16], const Byte a_IV[16]); - /** Decrypts a_Length bytes of the encrypted data; produces a_Length output bytes */ - void ProcessData(std::byte * a_DecryptedOut, const Byte * a_EncryptedIn, size_t a_Length); + /** Decrypts a_Length bytes of the encrypted data in-place; produces a_Length output bytes */ + void ProcessData(std::byte * a_EncryptedIn, size_t a_Length); /** Returns true if the object has been initialized with the Key / IV */ bool IsValid(void) const { return m_IsValid; } protected: + +#ifdef _WIN32 + HCRYPTPROV m_Aes; + HCRYPTKEY m_Key; +#else mbedtls_aes_context m_Aes; +#endif /** The InitialVector, used by the CFB mode decryption */ Byte m_IV[16]; @@ -41,8 +51,3 @@ protected: /** Indicates whether the object has been initialized with the Key / IV */ bool m_IsValid; } ; - - - - - |