From e9265b1d0129e56b9b665019697055bd63229fde Mon Sep 17 00:00:00 2001 From: Tiger Wang Date: Sat, 10 Jul 2021 21:04:49 +0100 Subject: Symmetry in MultiVersionProtocol to fix a crash (#5261) + Add HandleOutgoingData, which tests for m_Protocol before calling DataPrepared. * Change std::function to bool + if/else to handle incoming data; it's almost certainly faster. * Fixes #5260 --- src/ClientHandle.cpp | 4 +- src/Protocol/Protocol.h | 2 +- src/Protocol/ProtocolRecognizer.cpp | 73 +++++++++++++++++++++++++------------ src/Protocol/ProtocolRecognizer.h | 21 ++++++----- src/Protocol/Protocol_1_8.cpp | 2 +- src/Protocol/Protocol_1_8.h | 2 +- 6 files changed, 67 insertions(+), 37 deletions(-) diff --git a/src/ClientHandle.cpp b/src/ClientHandle.cpp index 47bbe691a..9dc112011 100644 --- a/src/ClientHandle.cpp +++ b/src/ClientHandle.cpp @@ -241,7 +241,7 @@ void cClientHandle::ProcessProtocolOut() // to prevent it being reset between the null check and the Send: if (auto Link = m_Link; Link != nullptr) { - m_Protocol->DataPrepared(OutgoingData); + m_Protocol.HandleOutgoingData(OutgoingData); Link->Send(OutgoingData.data(), OutgoingData.size()); } } @@ -3308,7 +3308,7 @@ void cClientHandle::ProcessProtocolIn(void) try { - m_Protocol.HandleIncomingData(*this, std::move(IncomingData)); + m_Protocol.HandleIncomingData(*this, IncomingData); } catch (const std::exception & Oops) { diff --git a/src/Protocol/Protocol.h b/src/Protocol/Protocol.h index de420308b..f306abc1c 100644 --- a/src/Protocol/Protocol.h +++ b/src/Protocol/Protocol.h @@ -348,7 +348,7 @@ public: /** Called by cClientHandle to process data, when the client sends some. The protocol uses the provided buffers for storage and processing, and must have exclusive access to them. */ - virtual void DataReceived(cByteBuffer & a_Buffer, ContiguousByteBuffer && a_Data) = 0; + virtual void DataReceived(cByteBuffer & a_Buffer, ContiguousByteBuffer & a_Data) = 0; /** Called by cClientHandle to finalise a buffer of prepared data before they are sent to the client. Descendants may for example, encrypt the data if needed. diff --git a/src/Protocol/ProtocolRecognizer.cpp b/src/Protocol/ProtocolRecognizer.cpp index 181998337..ba179c1bf 100644 --- a/src/Protocol/ProtocolRecognizer.cpp +++ b/src/Protocol/ProtocolRecognizer.cpp @@ -38,8 +38,8 @@ struct TriedToJoinWithUnsupportedProtocolException : public std::runtime_error cMultiVersionProtocol::cMultiVersionProtocol() : - HandleIncomingData(std::bind(&cMultiVersionProtocol::HandleIncomingDataInRecognitionStage, this, std::placeholders::_1, std::placeholders::_2)), - m_Buffer(32 KiB) + m_Buffer(32 KiB), + m_WaitingForData(true) { } @@ -76,7 +76,7 @@ AString cMultiVersionProtocol::GetVersionTextFromInt(cProtocol::Version a_Protoc -void cMultiVersionProtocol::HandleIncomingDataInRecognitionStage(cClientHandle & a_Client, ContiguousByteBuffer && a_Data) +void cMultiVersionProtocol::HandleIncomingDataInRecognitionStage(cClientHandle & a_Client, ContiguousByteBuffer & a_Data) { // NOTE: If a new protocol is added or an old one is removed, adjust MCS_CLIENT_VERSIONS and MCS_PROTOCOL_VERSIONS macros in the header file @@ -115,26 +115,14 @@ void cMultiVersionProtocol::HandleIncomingDataInRecognitionStage(cClientHandle & 3. Exception: the data sent were garbage, the client handle deals with it by disconnecting */ m_Protocol = TryRecognizeLengthedProtocol(a_Client); - if (m_Protocol == nullptr) - { - // Got a server list ping for an unrecognised version, - // switch into responding to unknown protocols mode: - HandleIncomingData = [this](cClientHandle & a_Clyent, ContiguousByteBuffer && a_In) - { - HandleIncomingDataInOldPingResponseStage(a_Clyent, a_In); - }; - } - else - { - // The protocol recogniser succesfully identified, switch mode: - HandleIncomingData = [this](cClientHandle &, ContiguousByteBuffer && a_In) - { - m_Protocol->DataReceived(m_Buffer, std::move(a_In)); - }; - } + // Version recognised. Cause HandleIncomingData to stop calling us to handle data: + m_WaitingForData = false; // Explicitly process any remaining data (already written to m_Buffer) with the new handler: - HandleIncomingData(a_Client, {}); + { + ContiguousByteBuffer Empty; + HandleIncomingData(a_Client, Empty); + } } @@ -170,12 +158,10 @@ void cMultiVersionProtocol::HandleIncomingDataInOldPingResponseStage(cClientHand if ((PacketID == 0x00) && (PacketLen == 1)) // Request packet { HandlePacketStatusRequest(a_Client, OutPacketBuffer); - SendPacket(a_Client, OutPacketBuffer); } else if ((PacketID == 0x01) && (PacketLen == 9)) // Ping packet { HandlePacketStatusPing(a_Client, OutPacketBuffer); - SendPacket(a_Client, OutPacketBuffer); } else { @@ -191,6 +177,43 @@ void cMultiVersionProtocol::HandleIncomingDataInOldPingResponseStage(cClientHand +void cMultiVersionProtocol::HandleIncomingData(cClientHandle & a_Client, ContiguousByteBuffer & a_Data) +{ + if (m_WaitingForData) + { + HandleIncomingDataInRecognitionStage(a_Client, a_Data); + } + else if (m_Protocol == nullptr) + { + // Got a Handshake for an unrecognised version, process future data accordingly: + HandleIncomingDataInOldPingResponseStage(a_Client, a_Data); + } + else + { + // The protocol recogniser succesfully identified a supported version, direct data to that protocol: + m_Protocol->DataReceived(m_Buffer, a_Data); + } +} + + + + + +void cMultiVersionProtocol::HandleOutgoingData(ContiguousByteBuffer & a_Data) +{ + // Normally only the protocol sends data, so outgoing data are only present when m_Protocol != nullptr. + // However, for unrecognised protocols we send data too, and that's when m_Protocol == nullptr. Check to avoid crashing (GH #5260). + + if (m_Protocol != nullptr) + { + m_Protocol->DataPrepared(a_Data); + } +} + + + + + void cMultiVersionProtocol::SendDisconnect(cClientHandle & a_Client, const AString & a_Reason) { if (m_Protocol != nullptr) @@ -382,6 +405,8 @@ void cMultiVersionProtocol::HandlePacketStatusRequest(cClientHandle & a_Client, VERIFY(a_Out.WriteVarInt32(GetPacketID(cProtocol::ePacketType::pktStatusResponse))); VERIFY(a_Out.WriteVarUTF8String(Response)); + + SendPacket(a_Client, a_Out); } @@ -398,4 +423,6 @@ void cMultiVersionProtocol::HandlePacketStatusPing(cClientHandle & a_Client, cBy VERIFY(a_Out.WriteVarInt32(GetPacketID(cProtocol::ePacketType::pktPingResponse))); VERIFY(a_Out.WriteBEInt64(Timestamp)); + + SendPacket(a_Client, a_Out); } diff --git a/src/Protocol/ProtocolRecognizer.h b/src/Protocol/ProtocolRecognizer.h index 03b379f17..0a923e78f 100644 --- a/src/Protocol/ProtocolRecognizer.h +++ b/src/Protocol/ProtocolRecognizer.h @@ -19,9 +19,6 @@ protocol version instance and redirects everything to it. */ class cMultiVersionProtocol { - // Work around the style checker complaining about && in template. - using OwnedContiguousByteBuffer = ContiguousByteBuffer &&; - public: cMultiVersionProtocol(); @@ -41,8 +38,12 @@ public: return m_Protocol; } - /** The function that's responsible for processing incoming protocol data. */ - std::function HandleIncomingData; + /** Directs incoming protocol data along the correct pathway, depending on the state of the version recognition process. + The protocol modifies the provided buffer in-place. */ + void HandleIncomingData(cClientHandle & a_Client, ContiguousByteBuffer & a_Data); + + /** Allows the protocol (if any) to do a final pass on outgiong data, possibly modifying the provided buffer in-place. */ + void HandleOutgoingData(ContiguousByteBuffer & a_Data); /** Sends a disconnect to the client as a result of a recognition error. This function can be used to disconnect before any protocol has been recognised. */ @@ -53,7 +54,7 @@ private: /** Handles data reception in a newly-created client handle that doesn't yet have a known protocol. a_Data contains a view of data that were just received. Tries to recognize a protocol, populate m_Protocol, and transitions to another mode depending on success. */ - void HandleIncomingDataInRecognitionStage(cClientHandle & a_Client, ContiguousByteBuffer && a_Data); + void HandleIncomingDataInRecognitionStage(cClientHandle & a_Client, ContiguousByteBuffer & a_Data); /** Handles and responds to unsupported clients sending pings. */ void HandleIncomingDataInOldPingResponseStage(cClientHandle & a_Client, ContiguousByteBufferView a_Data); @@ -75,11 +76,13 @@ private: /* Ping handler for unrecognised versions. */ void HandlePacketStatusPing(cClientHandle & a_Client, cByteBuffer & a_Out); + /** Buffer for received protocol data. */ + cByteBuffer m_Buffer; + /** The actual protocol implementation. Created when recognition of the client version succeeds with a version we support. */ std::unique_ptr m_Protocol; - /** Buffer for received protocol data. */ - cByteBuffer m_Buffer; - + /** If we're still waiting for data required for version recognition to arrive. */ + bool m_WaitingForData; } ; diff --git a/src/Protocol/Protocol_1_8.cpp b/src/Protocol/Protocol_1_8.cpp index 205a899c1..02b76ccae 100644 --- a/src/Protocol/Protocol_1_8.cpp +++ b/src/Protocol/Protocol_1_8.cpp @@ -174,7 +174,7 @@ cProtocol_1_8_0::cProtocol_1_8_0(cClientHandle * a_Client, const AString & a_Ser -void cProtocol_1_8_0::DataReceived(cByteBuffer & a_Buffer, ContiguousByteBuffer && a_Data) +void cProtocol_1_8_0::DataReceived(cByteBuffer & a_Buffer, ContiguousByteBuffer & a_Data) { if (m_IsEncrypted) { diff --git a/src/Protocol/Protocol_1_8.h b/src/Protocol/Protocol_1_8.h index 838435ad0..704725bee 100644 --- a/src/Protocol/Protocol_1_8.h +++ b/src/Protocol/Protocol_1_8.h @@ -36,7 +36,7 @@ public: cProtocol_1_8_0(cClientHandle * a_Client, const AString & a_ServerAddress, State a_State); - virtual void DataReceived(cByteBuffer & a_Buffer, ContiguousByteBuffer && a_Data) override; + virtual void DataReceived(cByteBuffer & a_Buffer, ContiguousByteBuffer & a_Data) override; virtual void DataPrepared(ContiguousByteBuffer & a_Data) override; // Sending stuff to clients (alphabetically sorted): -- cgit v1.2.3