summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--Tools/ProtoProxy/Connection.cpp4
-rw-r--r--src/ClientHandle.cpp6
-rw-r--r--src/ClientHandle.h2
-rw-r--r--src/Protocol/Protocol.h6
-rw-r--r--src/Protocol/ProtocolRecognizer.cpp14
-rw-r--r--src/Protocol/ProtocolRecognizer.h11
-rw-r--r--src/Protocol/Protocol_1_8.cpp259
-rw-r--r--src/Protocol/Protocol_1_8.h11
-rw-r--r--src/mbedTLS++/AesCfb128Decryptor.cpp48
-rw-r--r--src/mbedTLS++/AesCfb128Decryptor.h19
10 files changed, 206 insertions, 174 deletions
diff --git a/Tools/ProtoProxy/Connection.cpp b/Tools/ProtoProxy/Connection.cpp
index f9b732142..ba4614382 100644
--- a/Tools/ProtoProxy/Connection.cpp
+++ b/Tools/ProtoProxy/Connection.cpp
@@ -376,13 +376,13 @@ bool cConnection::RelayFromServer(void)
}
case csEncryptedUnderstood:
{
- m_ServerDecryptor.ProcessData(reinterpret_cast<std::byte *>(Buffer), reinterpret_cast<const Byte *>(Buffer), static_cast<size_t>(res));
+ m_ServerDecryptor.ProcessData(reinterpret_cast<std::byte *>(Buffer), static_cast<size_t>(res));
DataLog(Buffer, static_cast<size_t>(res), "Decrypted %d bytes from the SERVER", res);
return DecodeServersPackets(Buffer, res);
}
case csEncryptedUnknown:
{
- m_ServerDecryptor.ProcessData(reinterpret_cast<std::byte *>(Buffer), reinterpret_cast<const Byte *>(Buffer), static_cast<size_t>(res));
+ m_ServerDecryptor.ProcessData(reinterpret_cast<std::byte *>(Buffer), static_cast<size_t>(res));
DataLog(Buffer, static_cast<size_t>(res), "Decrypted %d bytes from the SERVER", res);
return CLIENTSEND({ reinterpret_cast<const std::byte *>(Buffer), static_cast<size_t>(res) });
}
diff --git a/src/ClientHandle.cpp b/src/ClientHandle.cpp
index 683a243fe..6a6ea951f 100644
--- a/src/ClientHandle.cpp
+++ b/src/ClientHandle.cpp
@@ -3299,7 +3299,7 @@ bool cClientHandle::SetState(eState a_NewState)
void cClientHandle::ProcessProtocolIn(void)
{
// Process received network data:
- AString IncomingData;
+ decltype(m_IncomingData) IncomingData;
{
cCSLock Lock(m_CSIncomingData);
std::swap(IncomingData, m_IncomingData);
@@ -3312,7 +3312,7 @@ void cClientHandle::ProcessProtocolIn(void)
try
{
- m_Protocol.HandleIncomingData(*this, IncomingData);
+ m_Protocol.HandleIncomingData(*this, std::move(IncomingData));
}
catch (const std::exception & Oops)
{
@@ -3340,7 +3340,7 @@ void cClientHandle::OnReceivedData(const char * a_Data, size_t a_Length)
// Queue the incoming data to be processed in the tick thread:
cCSLock Lock(m_CSIncomingData);
- m_IncomingData.append(a_Data, a_Length);
+ m_IncomingData.append(reinterpret_cast<const std::byte *>(a_Data), a_Length);
}
diff --git a/src/ClientHandle.h b/src/ClientHandle.h
index 2e1e09a06..9f1113669 100644
--- a/src/ClientHandle.h
+++ b/src/ClientHandle.h
@@ -445,7 +445,7 @@ private:
/** Queue for the incoming data received on the link until it is processed in Tick().
Protected by m_CSIncomingData. */
- AString m_IncomingData;
+ ContiguousByteBuffer m_IncomingData;
/** Protects m_OutgoingData against multithreaded access. */
cCriticalSection m_CSOutgoingData;
diff --git a/src/Protocol/Protocol.h b/src/Protocol/Protocol.h
index 41c461e76..743a73aba 100644
--- a/src/Protocol/Protocol.h
+++ b/src/Protocol/Protocol.h
@@ -48,6 +48,7 @@ typedef unsigned char Byte;
class cProtocol
{
public:
+
cProtocol(cClientHandle * a_Client) :
m_Client(a_Client),
m_OutPacketBuffer(64 KiB),
@@ -354,8 +355,9 @@ public:
Game = 3,
};
- /** Called when client sends some data */
- virtual void DataReceived(cByteBuffer & a_Buffer, const char * a_Data, size_t a_Size) = 0;
+ /** Called to process them, when client sends some data.
+ The protocol uses the provided buffers for storage and processing, and must have exclusive access to them. */
+ virtual void DataReceived(cByteBuffer & a_Buffer, ContiguousByteBuffer && a_Data) = 0;
// Sending stuff to clients (alphabetically sorted):
virtual void SendAttachEntity (const cEntity & a_Entity, const cEntity & a_Vehicle) = 0;
diff --git a/src/Protocol/ProtocolRecognizer.cpp b/src/Protocol/ProtocolRecognizer.cpp
index 3af4f9654..181998337 100644
--- a/src/Protocol/ProtocolRecognizer.cpp
+++ b/src/Protocol/ProtocolRecognizer.cpp
@@ -76,7 +76,7 @@ AString cMultiVersionProtocol::GetVersionTextFromInt(cProtocol::Version a_Protoc
-void cMultiVersionProtocol::HandleIncomingDataInRecognitionStage(cClientHandle & a_Client, std::string_view a_Data)
+void cMultiVersionProtocol::HandleIncomingDataInRecognitionStage(cClientHandle & a_Client, ContiguousByteBuffer && a_Data)
{
// NOTE: If a new protocol is added or an old one is removed, adjust MCS_CLIENT_VERSIONS and MCS_PROTOCOL_VERSIONS macros in the header file
@@ -113,13 +113,13 @@ void cMultiVersionProtocol::HandleIncomingDataInRecognitionStage(cClientHandle &
1. m_Protocol != nullptr: the protocol is supported and we have a handler
2. m_Protocol == nullptr: the protocol is unsupported, handling is a special case done by ourselves
3. Exception: the data sent were garbage, the client handle deals with it by disconnecting */
- m_Protocol = TryRecognizeLengthedProtocol(a_Client, a_Data);
+ m_Protocol = TryRecognizeLengthedProtocol(a_Client);
if (m_Protocol == nullptr)
{
// Got a server list ping for an unrecognised version,
// switch into responding to unknown protocols mode:
- HandleIncomingData = [this](cClientHandle & a_Clyent, const std::string_view a_In)
+ HandleIncomingData = [this](cClientHandle & a_Clyent, ContiguousByteBuffer && a_In)
{
HandleIncomingDataInOldPingResponseStage(a_Clyent, a_In);
};
@@ -127,9 +127,9 @@ void cMultiVersionProtocol::HandleIncomingDataInRecognitionStage(cClientHandle &
else
{
// The protocol recogniser succesfully identified, switch mode:
- HandleIncomingData = [this](cClientHandle &, const std::string_view a_In)
+ HandleIncomingData = [this](cClientHandle &, ContiguousByteBuffer && a_In)
{
- m_Protocol->DataReceived(m_Buffer, a_In.data(), a_In.size());
+ m_Protocol->DataReceived(m_Buffer, std::move(a_In));
};
}
@@ -141,7 +141,7 @@ void cMultiVersionProtocol::HandleIncomingDataInRecognitionStage(cClientHandle &
-void cMultiVersionProtocol::HandleIncomingDataInOldPingResponseStage(cClientHandle & a_Client, const std::string_view a_Data)
+void cMultiVersionProtocol::HandleIncomingDataInOldPingResponseStage(cClientHandle & a_Client, const ContiguousByteBufferView a_Data)
{
if (!m_Buffer.Write(a_Data.data(), a_Data.size()))
{
@@ -215,7 +215,7 @@ void cMultiVersionProtocol::SendDisconnect(cClientHandle & a_Client, const AStri
-std::unique_ptr<cProtocol> cMultiVersionProtocol::TryRecognizeLengthedProtocol(cClientHandle & a_Client, const std::string_view a_Data)
+std::unique_ptr<cProtocol> cMultiVersionProtocol::TryRecognizeLengthedProtocol(cClientHandle & a_Client)
{
UInt32 PacketType;
UInt32 ProtocolVersion;
diff --git a/src/Protocol/ProtocolRecognizer.h b/src/Protocol/ProtocolRecognizer.h
index 56d5645c0..03b379f17 100644
--- a/src/Protocol/ProtocolRecognizer.h
+++ b/src/Protocol/ProtocolRecognizer.h
@@ -19,6 +19,9 @@
protocol version instance and redirects everything to it. */
class cMultiVersionProtocol
{
+ // Work around the style checker complaining about && in template.
+ using OwnedContiguousByteBuffer = ContiguousByteBuffer &&;
+
public:
cMultiVersionProtocol();
@@ -39,7 +42,7 @@ public:
}
/** The function that's responsible for processing incoming protocol data. */
- std::function<void(cClientHandle &, std::string_view)> HandleIncomingData;
+ std::function<void(cClientHandle &, OwnedContiguousByteBuffer)> HandleIncomingData;
/** Sends a disconnect to the client as a result of a recognition error.
This function can be used to disconnect before any protocol has been recognised. */
@@ -50,14 +53,14 @@ private:
/** Handles data reception in a newly-created client handle that doesn't yet have a known protocol.
a_Data contains a view of data that were just received.
Tries to recognize a protocol, populate m_Protocol, and transitions to another mode depending on success. */
- void HandleIncomingDataInRecognitionStage(cClientHandle & a_Client, std::string_view a_Data);
+ void HandleIncomingDataInRecognitionStage(cClientHandle & a_Client, ContiguousByteBuffer && a_Data);
/** Handles and responds to unsupported clients sending pings. */
- void HandleIncomingDataInOldPingResponseStage(cClientHandle & a_Client, std::string_view a_Data);
+ void HandleIncomingDataInOldPingResponseStage(cClientHandle & a_Client, ContiguousByteBufferView a_Data);
/** Tries to recognize a protocol in the lengthed family (1.7+), based on m_Buffer.
Returns a cProtocol_XXX instance if recognized. */
- std::unique_ptr<cProtocol> TryRecognizeLengthedProtocol(cClientHandle & a_Client, std::string_view a_Data);
+ std::unique_ptr<cProtocol> TryRecognizeLengthedProtocol(cClientHandle & a_Client);
/** Sends one packet inside a cByteBuffer.
This is used only when handling an outdated server ping. */
diff --git a/src/Protocol/Protocol_1_8.cpp b/src/Protocol/Protocol_1_8.cpp
index e3bb7e5a3..8e7b74614 100644
--- a/src/Protocol/Protocol_1_8.cpp
+++ b/src/Protocol/Protocol_1_8.cpp
@@ -173,31 +173,14 @@ cProtocol_1_8_0::cProtocol_1_8_0(cClientHandle * a_Client, const AString & a_Ser
-void cProtocol_1_8_0::DataReceived(cByteBuffer & a_Buffer, const char * a_Data, size_t a_Size)
+void cProtocol_1_8_0::DataReceived(cByteBuffer & a_Buffer, ContiguousByteBuffer && a_Data)
{
if (m_IsEncrypted)
{
- // An artefact of the protocol recogniser, will be removed when decryption done in-place:
- if (a_Size == 0)
- {
- AddReceivedData(a_Buffer, nullptr, 0);
- return;
- }
-
- std::byte Decrypted[512];
- while (a_Size > 0)
- {
- size_t NumBytes = (a_Size > sizeof(Decrypted)) ? sizeof(Decrypted) : a_Size;
- m_Decryptor.ProcessData(Decrypted, reinterpret_cast<const Byte *>(a_Data), NumBytes);
- AddReceivedData(a_Buffer, reinterpret_cast<const char *>(Decrypted), NumBytes);
- a_Size -= NumBytes;
- a_Data += NumBytes;
- }
- }
- else
- {
- AddReceivedData(a_Buffer, a_Data, a_Size);
+ m_Decryptor.ProcessData(a_Data.data(), a_Data.size());
}
+
+ AddReceivedData(a_Buffer, a_Data);
}
@@ -2002,123 +1985,6 @@ UInt32 cProtocol_1_8_0::GetProtocolMobType(const eMonsterType a_MobType)
-void cProtocol_1_8_0::AddReceivedData(cByteBuffer & a_Buffer, const char * a_Data, size_t a_Size)
-{
- // Write the incoming data into the comm log file:
- if (g_ShouldLogCommIn && m_CommLogFile.IsOpen())
- {
- if (a_Buffer.GetReadableSpace() > 0)
- {
- ContiguousByteBuffer AllData;
- size_t OldReadableSpace = a_Buffer.GetReadableSpace();
- a_Buffer.ReadAll(AllData);
- a_Buffer.ResetRead();
- a_Buffer.SkipRead(a_Buffer.GetReadableSpace() - OldReadableSpace);
- ASSERT(a_Buffer.GetReadableSpace() == OldReadableSpace);
- AString Hex;
- CreateHexDump(Hex, AllData.data(), AllData.size(), 16);
- m_CommLogFile.Printf("Incoming data, %zu (0x%zx) unparsed bytes already present in buffer:\n%s\n",
- AllData.size(), AllData.size(), Hex.c_str()
- );
- }
- AString Hex;
- CreateHexDump(Hex, a_Data, a_Size, 16);
- m_CommLogFile.Printf("Incoming data: %u (0x%x) bytes: \n%s\n",
- static_cast<unsigned>(a_Size), static_cast<unsigned>(a_Size), Hex.c_str()
- );
- m_CommLogFile.Flush();
- }
-
- if (!a_Buffer.Write(a_Data, a_Size))
- {
- // Too much data in the incoming queue, report to caller:
- m_Client->PacketBufferFull();
- return;
- }
-
- // Handle all complete packets:
- for (;;)
- {
- UInt32 PacketLen;
- if (!a_Buffer.ReadVarInt(PacketLen))
- {
- // Not enough data
- a_Buffer.ResetRead();
- break;
- }
- if (!a_Buffer.CanReadBytes(PacketLen))
- {
- // The full packet hasn't been received yet
- a_Buffer.ResetRead();
- break;
- }
-
- // Check packet for compression:
- if (m_State == 3)
- {
- UInt32 NumBytesRead = static_cast<UInt32>(a_Buffer.GetReadableSpace());
-
- UInt32 UncompressedSize;
- if (!a_Buffer.ReadVarInt(UncompressedSize))
- {
- m_Client->Kick("Compression packet incomplete");
- return;
- }
-
- NumBytesRead -= static_cast<UInt32>(a_Buffer.GetReadableSpace()); // How many bytes has the UncompressedSize taken up?
- ASSERT(PacketLen > NumBytesRead);
- PacketLen -= NumBytesRead;
-
- if (UncompressedSize > 0)
- {
- // Decompress the data:
- m_Extractor.ReadFrom(a_Buffer, PacketLen);
- a_Buffer.CommitRead();
-
- const auto UncompressedData = m_Extractor.Extract(UncompressedSize);
- const auto Uncompressed = UncompressedData.GetView();
- cByteBuffer bb(Uncompressed.size());
-
- // Compression was used, move the uncompressed data:
- VERIFY(bb.Write(Uncompressed.data(), Uncompressed.size()));
-
- HandlePacket(bb);
- continue;
- }
- }
-
- // Move the packet payload to a separate cByteBuffer, bb:
- cByteBuffer bb(PacketLen);
-
- // No compression was used, move directly:
- VERIFY(a_Buffer.ReadToByteBuffer(bb, static_cast<size_t>(PacketLen)));
- a_Buffer.CommitRead();
-
- HandlePacket(bb);
- } // for (ever)
-
- // Log any leftover bytes into the logfile:
- if (g_ShouldLogCommIn && (a_Buffer.GetReadableSpace() > 0) && m_CommLogFile.IsOpen())
- {
- ContiguousByteBuffer AllData;
- size_t OldReadableSpace = a_Buffer.GetReadableSpace();
- a_Buffer.ReadAll(AllData);
- a_Buffer.ResetRead();
- a_Buffer.SkipRead(a_Buffer.GetReadableSpace() - OldReadableSpace);
- ASSERT(a_Buffer.GetReadableSpace() == OldReadableSpace);
- AString Hex;
- CreateHexDump(Hex, AllData.data(), AllData.size(), 16);
- m_CommLogFile.Printf("There are %zu (0x%zx) bytes of non-parse-able data left in the buffer:\n%s",
- a_Buffer.GetReadableSpace(), a_Buffer.GetReadableSpace(), Hex.c_str()
- );
- m_CommLogFile.Flush();
- }
-}
-
-
-
-
-
UInt32 cProtocol_1_8_0::GetPacketID(ePacketType a_PacketType)
{
switch (a_PacketType)
@@ -3961,6 +3827,123 @@ void cProtocol_1_8_0::WriteEntityProperties(cPacketizer & a_Pkt, const cEntity &
+void cProtocol_1_8_0::AddReceivedData(cByteBuffer & a_Buffer, const ContiguousByteBufferView a_Data)
+{
+ // Write the incoming data into the comm log file:
+ if (g_ShouldLogCommIn && m_CommLogFile.IsOpen())
+ {
+ if (a_Buffer.GetReadableSpace() > 0)
+ {
+ ContiguousByteBuffer AllData;
+ size_t OldReadableSpace = a_Buffer.GetReadableSpace();
+ a_Buffer.ReadAll(AllData);
+ a_Buffer.ResetRead();
+ a_Buffer.SkipRead(a_Buffer.GetReadableSpace() - OldReadableSpace);
+ ASSERT(a_Buffer.GetReadableSpace() == OldReadableSpace);
+ AString Hex;
+ CreateHexDump(Hex, AllData.data(), AllData.size(), 16);
+ m_CommLogFile.Printf("Incoming data, %zu (0x%zx) unparsed bytes already present in buffer:\n%s\n",
+ AllData.size(), AllData.size(), Hex.c_str()
+ );
+ }
+ AString Hex;
+ CreateHexDump(Hex, a_Data.data(), a_Data.size(), 16);
+ m_CommLogFile.Printf("Incoming data: %zu (0x%zx) bytes: \n%s\n",
+ a_Data.size(), a_Data.size(), Hex.c_str()
+ );
+ m_CommLogFile.Flush();
+ }
+
+ if (!a_Buffer.Write(a_Data.data(), a_Data.size()))
+ {
+ // Too much data in the incoming queue, report to caller:
+ m_Client->PacketBufferFull();
+ return;
+ }
+
+ // Handle all complete packets:
+ for (;;)
+ {
+ UInt32 PacketLen;
+ if (!a_Buffer.ReadVarInt(PacketLen))
+ {
+ // Not enough data
+ a_Buffer.ResetRead();
+ break;
+ }
+ if (!a_Buffer.CanReadBytes(PacketLen))
+ {
+ // The full packet hasn't been received yet
+ a_Buffer.ResetRead();
+ break;
+ }
+
+ // Check packet for compression:
+ if (m_State == 3)
+ {
+ UInt32 NumBytesRead = static_cast<UInt32>(a_Buffer.GetReadableSpace());
+
+ UInt32 UncompressedSize;
+ if (!a_Buffer.ReadVarInt(UncompressedSize))
+ {
+ m_Client->Kick("Compression packet incomplete");
+ return;
+ }
+
+ NumBytesRead -= static_cast<UInt32>(a_Buffer.GetReadableSpace()); // How many bytes has the UncompressedSize taken up?
+ ASSERT(PacketLen > NumBytesRead);
+ PacketLen -= NumBytesRead;
+
+ if (UncompressedSize > 0)
+ {
+ // Decompress the data:
+ m_Extractor.ReadFrom(a_Buffer, PacketLen);
+ a_Buffer.CommitRead();
+
+ const auto UncompressedData = m_Extractor.Extract(UncompressedSize);
+ const auto Uncompressed = UncompressedData.GetView();
+ cByteBuffer bb(Uncompressed.size());
+
+ // Compression was used, move the uncompressed data:
+ VERIFY(bb.Write(Uncompressed.data(), Uncompressed.size()));
+
+ HandlePacket(bb);
+ continue;
+ }
+ }
+
+ // Move the packet payload to a separate cByteBuffer, bb:
+ cByteBuffer bb(PacketLen);
+
+ // No compression was used, move directly:
+ VERIFY(a_Buffer.ReadToByteBuffer(bb, static_cast<size_t>(PacketLen)));
+ a_Buffer.CommitRead();
+
+ HandlePacket(bb);
+ } // for (ever)
+
+ // Log any leftover bytes into the logfile:
+ if (g_ShouldLogCommIn && (a_Buffer.GetReadableSpace() > 0) && m_CommLogFile.IsOpen())
+ {
+ ContiguousByteBuffer AllData;
+ size_t OldReadableSpace = a_Buffer.GetReadableSpace();
+ a_Buffer.ReadAll(AllData);
+ a_Buffer.ResetRead();
+ a_Buffer.SkipRead(a_Buffer.GetReadableSpace() - OldReadableSpace);
+ ASSERT(a_Buffer.GetReadableSpace() == OldReadableSpace);
+ AString Hex;
+ CreateHexDump(Hex, AllData.data(), AllData.size(), 16);
+ m_CommLogFile.Printf("There are %zu (0x%zx) bytes of non-parse-able data left in the buffer:\n%s",
+ a_Buffer.GetReadableSpace(), a_Buffer.GetReadableSpace(), Hex.c_str()
+ );
+ m_CommLogFile.Flush();
+ }
+}
+
+
+
+
+
void cProtocol_1_8_0::HandlePacket(cByteBuffer & a_Buffer)
{
UInt32 PacketType;
diff --git a/src/Protocol/Protocol_1_8.h b/src/Protocol/Protocol_1_8.h
index e2aadf147..29bc7420c 100644
--- a/src/Protocol/Protocol_1_8.h
+++ b/src/Protocol/Protocol_1_8.h
@@ -36,8 +36,9 @@ public:
cProtocol_1_8_0(cClientHandle * a_Client, const AString & a_ServerAddress, State a_State);
- /** Called when client sends some data: */
- virtual void DataReceived(cByteBuffer & a_Buffer, const char * a_Data, size_t a_Size) override;
+ /** Called to process them, when client sends some data.
+ The protocol uses the provided buffers for storage and processing, and must have exclusive access to them. */
+ virtual void DataReceived(cByteBuffer & a_Buffer, ContiguousByteBuffer && a_Data) override;
/** Sending stuff to clients (alphabetically sorted): */
virtual void SendAttachEntity (const cEntity & a_Entity, const cEntity & a_Vehicle) override;
@@ -146,9 +147,6 @@ protected:
/** State of the protocol. */
State m_State;
- /** Adds the received (unencrypted) data to m_ReceivedData, parses complete packets */
- virtual void AddReceivedData(cByteBuffer & a_Buffer, const char * a_Data, size_t a_Size);
-
/** Nobody inherits 1.8, so it doesn't use this method */
virtual UInt32 GetPacketID(ePacketType a_Packet) override;
@@ -257,6 +255,9 @@ private:
/** The logfile where the comm is logged, when g_ShouldLogComm is true */
cFile m_CommLogFile;
+ /** Adds the received (unencrypted) data to m_ReceivedData, parses complete packets */
+ void AddReceivedData(cByteBuffer & a_Buffer, ContiguousByteBufferView a_Data);
+
/** Handle a complete packet stored in the given buffer. */
void HandlePacket(cByteBuffer & a_Buffer);
diff --git a/src/mbedTLS++/AesCfb128Decryptor.cpp b/src/mbedTLS++/AesCfb128Decryptor.cpp
index 523e06161..6243a3ded 100644
--- a/src/mbedTLS++/AesCfb128Decryptor.cpp
+++ b/src/mbedTLS++/AesCfb128Decryptor.cpp
@@ -10,10 +10,17 @@
-cAesCfb128Decryptor::cAesCfb128Decryptor(void):
+cAesCfb128Decryptor::cAesCfb128Decryptor(void) :
m_IsValid(false)
{
+#ifdef _WIN32
+ if (!CryptAcquireContext(&m_Aes, nullptr, nullptr, PROV_RSA_AES, CRYPT_VERIFYCONTEXT))
+ {
+ throw std::system_error(GetLastError(), std::system_category());
+ }
+#else
mbedtls_aes_init(&m_Aes);
+#endif
}
@@ -22,8 +29,12 @@ cAesCfb128Decryptor::cAesCfb128Decryptor(void):
cAesCfb128Decryptor::~cAesCfb128Decryptor()
{
- // Clear the leftover in-memory data, so that they can't be accessed by a backdoor
+ // Clear the leftover in-memory data, so that they can't be accessed by a backdoor:
+#ifdef _WIN32
+ CryptReleaseContext(m_Aes, 0);
+#else
mbedtls_aes_free(&m_Aes);
+#endif
}
@@ -34,8 +45,27 @@ void cAesCfb128Decryptor::Init(const Byte a_Key[16], const Byte a_IV[16])
{
ASSERT(!IsValid()); // Cannot Init twice
- memcpy(m_IV, a_IV, 16);
+#ifdef _WIN32
+ struct Key
+ {
+ PUBLICKEYSTRUC Header;
+ DWORD Length;
+ Byte Key[16];
+ } Key;
+
+ const DWORD Mode = CRYPT_MODE_CFB;
+ Key.Header = { PLAINTEXTKEYBLOB, CUR_BLOB_VERSION, 0, CALG_AES_128 };
+ Key.Length = 16;
+ std::copy_n(a_Key, 16, Key.Key);
+
+ CryptImportKey(m_Aes, reinterpret_cast<const BYTE *>(&Key), sizeof(Key), 0, 0, &m_Key);
+ CryptSetKeyParam(m_Key, KP_MODE, reinterpret_cast<const BYTE *>(&Mode), 0);
+ CryptSetKeyParam(m_Key, KP_IV, a_IV, 0);
+#else
+ std::copy_n(a_IV, 16, m_IV);
mbedtls_aes_setkey_enc(&m_Aes, a_Key, 128);
+#endif
+
m_IsValid = true;
}
@@ -43,8 +73,16 @@ void cAesCfb128Decryptor::Init(const Byte a_Key[16], const Byte a_IV[16])
-void cAesCfb128Decryptor::ProcessData(std::byte * a_DecryptedOut, const Byte * a_EncryptedIn, size_t a_Length)
+void cAesCfb128Decryptor::ProcessData(std::byte * const a_EncryptedIn, const size_t a_Length)
{
ASSERT(IsValid()); // Must Init() first
- mbedtls_aes_crypt_cfb8(&m_Aes, MBEDTLS_AES_DECRYPT, a_Length, m_IV, a_EncryptedIn, reinterpret_cast<unsigned char *>(a_DecryptedOut));
+
+#ifdef _WIN32
+ ASSERT(a_Length <= std::numeric_limits<DWORD>::max());
+
+ DWORD Length = static_cast<DWORD>(a_Length);
+ CryptDecrypt(m_Key, 0, FALSE, 0, reinterpret_cast<BYTE *>(a_EncryptedIn), &Length);
+#else
+ mbedtls_aes_crypt_cfb8(&m_Aes, MBEDTLS_AES_DECRYPT, a_Length, m_IV, reinterpret_cast<unsigned char *>(a_EncryptedIn), reinterpret_cast<unsigned char *>(a_EncryptedIn));
+#endif
}
diff --git a/src/mbedTLS++/AesCfb128Decryptor.h b/src/mbedTLS++/AesCfb128Decryptor.h
index 601699998..a2c9d6a05 100644
--- a/src/mbedTLS++/AesCfb128Decryptor.h
+++ b/src/mbedTLS++/AesCfb128Decryptor.h
@@ -9,7 +9,11 @@
#pragma once
+#ifdef _WIN32
+#include <wincrypt.h>
+#else
#include "mbedtls/aes.h"
+#endif
@@ -26,14 +30,20 @@ public:
/** Initializes the decryptor with the specified Key / IV */
void Init(const Byte a_Key[16], const Byte a_IV[16]);
- /** Decrypts a_Length bytes of the encrypted data; produces a_Length output bytes */
- void ProcessData(std::byte * a_DecryptedOut, const Byte * a_EncryptedIn, size_t a_Length);
+ /** Decrypts a_Length bytes of the encrypted data in-place; produces a_Length output bytes */
+ void ProcessData(std::byte * a_EncryptedIn, size_t a_Length);
/** Returns true if the object has been initialized with the Key / IV */
bool IsValid(void) const { return m_IsValid; }
protected:
+
+#ifdef _WIN32
+ HCRYPTPROV m_Aes;
+ HCRYPTKEY m_Key;
+#else
mbedtls_aes_context m_Aes;
+#endif
/** The InitialVector, used by the CFB mode decryption */
Byte m_IV[16];
@@ -41,8 +51,3 @@ protected:
/** Indicates whether the object has been initialized with the Key / IV */
bool m_IsValid;
} ;
-
-
-
-
-