From 86f2f82d2a81bac4739d135959b08c011d007303 Mon Sep 17 00:00:00 2001 From: Mattes D Date: Sat, 24 Jan 2015 23:17:13 +0100 Subject: BlockingSslClientSocket: Migrated to cNetwork API. --- src/PolarSSL++/BlockingSslClientSocket.cpp | 176 +++++++++++++++++++++++++++-- src/PolarSSL++/BlockingSslClientSocket.h | 35 +++++- src/PolarSSL++/SslContext.cpp | 2 +- 3 files changed, 198 insertions(+), 15 deletions(-) diff --git a/src/PolarSSL++/BlockingSslClientSocket.cpp b/src/PolarSSL++/BlockingSslClientSocket.cpp index 59e1281ac..821125b31 100644 --- a/src/PolarSSL++/BlockingSslClientSocket.cpp +++ b/src/PolarSSL++/BlockingSslClientSocket.cpp @@ -10,6 +10,80 @@ +//////////////////////////////////////////////////////////////////////////////// +// cBlockingSslClientSocketConnectCallbacks: + +class cBlockingSslClientSocketConnectCallbacks: + public cNetwork::cConnectCallbacks +{ + /** The socket object that is using this instance of the callbacks. */ + cBlockingSslClientSocket & m_Socket; + + virtual void OnConnected(cTCPLink & a_Link) override + { + m_Socket.OnConnected(); + } + + virtual void OnError(int a_ErrorCode, const AString & a_ErrorMsg) override + { + m_Socket.OnConnectError(a_ErrorMsg); + } + +public: + cBlockingSslClientSocketConnectCallbacks(cBlockingSslClientSocket & a_Socket): + m_Socket(a_Socket) + { + } +}; + + + + + +//////////////////////////////////////////////////////////////////////////////// +// cBlockingSslClientSocketLinkCallbacks: + +class cBlockingSslClientSocketLinkCallbacks: + public cTCPLink::cCallbacks +{ + cBlockingSslClientSocket & m_Socket; + + virtual void OnLinkCreated(cTCPLinkPtr a_Link) override + { + m_Socket.SetLink(a_Link); + } + + + virtual void OnReceivedData(const char * a_Data, size_t a_Length) + { + m_Socket.OnReceivedData(a_Data, a_Length); + } + + + virtual void OnRemoteClosed(void) + { + m_Socket.OnDisconnected(); + } + + + virtual void OnError(int a_ErrorCode, const AString & a_ErrorMsg) + { + m_Socket.OnDisconnected(); + } +public: + cBlockingSslClientSocketLinkCallbacks(cBlockingSslClientSocket & a_Socket): + m_Socket(a_Socket) + { + } +}; + + + + + +//////////////////////////////////////////////////////////////////////////////// +// cBlockingSslClientSocket: + cBlockingSslClientSocket::cBlockingSslClientSocket(void) : m_Ssl(*this), m_IsConnected(false) @@ -32,10 +106,19 @@ bool cBlockingSslClientSocket::Connect(const AString & a_ServerName, UInt16 a_Po } // Connect the underlying socket: - m_Socket.CreateSocket(cSocket::IPv4); - if (!m_Socket.ConnectIPv4(a_ServerName.c_str(), a_Port)) + m_ServerName = a_ServerName; + if (!cNetwork::Connect(a_ServerName, a_Port, + std::make_shared(*this), + std::make_shared(*this)) + ) + { + return false; + } + + // Wait for the connection to succeed or fail: + m_Event.Wait(); + if (!m_IsConnected) { - Printf(m_LastErrorText, "Socket connect failed: %s", m_Socket.GetLastErrorString().c_str()); return false; } @@ -102,7 +185,7 @@ bool cBlockingSslClientSocket::Send(const void * a_Data, size_t a_NumBytes) ASSERT(m_IsConnected); // Keep sending the data until all of it is sent: - const char * Data = (const char *)a_Data; + const char * Data = reinterpret_cast(a_Data); size_t NumBytes = a_NumBytes; for (;;) { @@ -156,7 +239,8 @@ void cBlockingSslClientSocket::Disconnect(void) } m_Ssl.NotifyClose(); - m_Socket.CloseSocket(); + m_Socket->Close(); + m_Socket.reset(); m_IsConnected = false; } @@ -166,13 +250,25 @@ void cBlockingSslClientSocket::Disconnect(void) int cBlockingSslClientSocket::ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes) { - int res = m_Socket.Receive((char *)a_Buffer, a_NumBytes, 0); - if (res < 0) + // Wait for any incoming data, if there is none: + cCSLock Lock(m_CSIncomingData); + while (m_IsConnected && m_IncomingData.empty()) + { + cCSUnlock Unlock(Lock); + m_Event.Wait(); + } + + // If we got disconnected, report an error after processing all data: + if (!m_IsConnected && m_IncomingData.empty()) { - // PolarSSL's net routines distinguish between connection reset and general failure, we don't need to return POLARSSL_ERR_NET_RECV_FAILED; } - return res; + + // Copy the data from the incoming buffer into the specified space: + size_t NumToCopy = std::min(a_NumBytes, m_IncomingData.size()); + memcpy(a_Buffer, m_IncomingData.data(), NumToCopy); + m_IncomingData.erase(0, NumToCopy); + return static_cast(NumToCopy); } @@ -181,13 +277,69 @@ int cBlockingSslClientSocket::ReceiveEncrypted(unsigned char * a_Buffer, size_t int cBlockingSslClientSocket::SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes) { - int res = m_Socket.Send((const char *)a_Buffer, a_NumBytes); - if (res < 0) + cTCPLinkPtr Socket(m_Socket); // Make a copy so that multiple threads don't race on deleting the socket. + if (Socket == nullptr) + { + return POLARSSL_ERR_NET_SEND_FAILED; + } + if (!Socket->Send(a_Buffer, a_NumBytes)) { // PolarSSL's net routines distinguish between connection reset and general failure, we don't need to return POLARSSL_ERR_NET_SEND_FAILED; } - return res; + return static_cast(a_NumBytes); +} + + + + +void cBlockingSslClientSocket::OnConnected(void) +{ + m_IsConnected = true; + m_Event.Set(); +} + + + + + +void cBlockingSslClientSocket::OnConnectError(const AString & a_ErrorMsg) +{ + LOG("Cannot connect to %s: %s", m_ServerName.c_str(), a_ErrorMsg.c_str()); + m_Event.Set(); +} + + + + + +void cBlockingSslClientSocket::OnReceivedData(const char * a_Data, size_t a_Size) +{ + { + cCSLock Lock(m_CSIncomingData); + m_IncomingData.append(a_Data, a_Size); + } + m_Event.Set(); +} + + + + + +void cBlockingSslClientSocket::SetLink(cTCPLinkPtr a_Link) +{ + m_Socket = a_Link; +} + + + + + +void cBlockingSslClientSocket::OnDisconnected(void) +{ + m_Socket.reset(); + m_IsConnected = false; + m_Event.Set(); } diff --git a/src/PolarSSL++/BlockingSslClientSocket.h b/src/PolarSSL++/BlockingSslClientSocket.h index 7af897582..319e82bf2 100644 --- a/src/PolarSSL++/BlockingSslClientSocket.h +++ b/src/PolarSSL++/BlockingSslClientSocket.h @@ -9,8 +9,8 @@ #pragma once +#include "OSSupport/Network.h" #include "CallbackSslContext.h" -#include "../OSSupport/Socket.h" @@ -51,25 +51,56 @@ public: const AString & GetLastErrorText(void) const { return m_LastErrorText; } protected: + friend class cBlockingSslClientSocketConnectCallbacks; + friend class cBlockingSslClientSocketLinkCallbacks; + /** The SSL context used for the socket */ cCallbackSslContext m_Ssl; /** The underlying socket to the SSL server */ - cSocket m_Socket; + cTCPLinkPtr m_Socket; + + /** The object used to signal state changes in the socket (the cause of the blocking). */ + cEvent m_Event; /** The trusted CA root cert store, if we are to verify the cert strictly. Set by SetTrustedRootCertsFromString(). */ cX509CertPtr m_CACerts; /** The expected SSL peer's name, if we are to verify the cert strictly. Set by SetTrustedRootCertsFromString(). */ AString m_ExpectedPeerName; + + /** The hostname to which the socket is connecting (stored for error reporting). */ + AString m_ServerName; /** Text of the last error that has occurred. */ AString m_LastErrorText; /** Set to true if the connection established successfully. */ bool m_IsConnected; + + /** Protects m_IncomingData against multithreaded access. */ + cCriticalSection m_CSIncomingData; + + /** Buffer for the data incoming on the network socket. + Protected by m_CSIncomingData. */ + AString m_IncomingData; + /** Called when the connection is established successfully. */ + void OnConnected(void); + + /** Called when an error occurs while connecting the socket. */ + void OnConnectError(const AString & a_ErrorMsg); + + /** Called when there's incoming data from the socket. */ + void OnReceivedData(const char * a_Data, size_t a_Size); + + /** Called when the link for the connection is created. */ + void SetLink(cTCPLinkPtr a_Link); + + /** Called when the link is disconnected, either gracefully or by an error. */ + void OnDisconnected(void); + // cCallbackSslContext::cDataCallbacks overrides: virtual int ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes) override; virtual int SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes) override; diff --git a/src/PolarSSL++/SslContext.cpp b/src/PolarSSL++/SslContext.cpp index 902267f90..66dfefc65 100644 --- a/src/PolarSSL++/SslContext.cpp +++ b/src/PolarSSL++/SslContext.cpp @@ -70,7 +70,7 @@ int cSslContext::Initialize(bool a_IsClient, const SharedPtr & // so they're disabled until someone needs them ssl_set_dbg(&m_Ssl, &SSLDebugMessage, this); ssl_set_verify(&m_Ssl, &SSLVerifyCert, this); - */ + //*/ /* // Set ciphersuite to the easiest one to decode, so that the connection can be wireshark-decoded: -- cgit v1.2.3