// SslContext.cpp // Implements the cSslContext class that holds everything a single SSL context needs to function #include "Globals.h" #include "SslContext.h" #include "EntropyContext.h" #include "CtrDrbgContext.h" cSslContext::cSslContext(void) : m_IsValid(false), m_HasHandshaken(false) { } cSslContext::~cSslContext() { if (m_IsValid) { ssl_free(&m_Ssl); } } int cSslContext::Initialize(bool a_IsClient, const SharedPtr & a_CtrDrbg) { // Check double-initialization: if (m_IsValid) { LOGWARNING("SSL: Double initialization is not supported."); return POLARSSL_ERR_SSL_MALLOC_FAILED; // There is no return value well-suited for this, reuse this one. } // Set the CtrDrbg context, create a new one if needed: m_CtrDrbg = a_CtrDrbg; if (m_CtrDrbg.get() == NULL) { m_CtrDrbg.reset(new cCtrDrbgContext); m_CtrDrbg->Initialize("MCServer", 8); } // Initialize PolarSSL's structures: memset(&m_Ssl, 0, sizeof(m_Ssl)); int res = ssl_init(&m_Ssl); if (res != 0) { return res; } ssl_set_endpoint(&m_Ssl, a_IsClient ? SSL_IS_CLIENT : SSL_IS_SERVER); ssl_set_authmode(&m_Ssl, SSL_VERIFY_OPTIONAL); ssl_set_rng(&m_Ssl, ctr_drbg_random, &m_CtrDrbg->m_CtrDrbg); ssl_set_bio(&m_Ssl, ReceiveEncrypted, this, SendEncrypted, this); #ifdef _DEBUG ssl_set_dbg(&m_Ssl, &SSLDebugMessage, this); #endif m_IsValid = true; return 0; } void cSslContext::SetCACerts(const cX509CertPtr & a_CACert, const AString & a_ExpectedPeerName) { // Store the data in our internal buffers, to avoid losing the pointers later on // PolarSSL will need these after this call returns, and the caller may move / delete the data before that: m_ExpectedPeerName = a_ExpectedPeerName; m_CACerts = a_CACert; // Set the trusted CA root cert store: ssl_set_authmode(&m_Ssl, SSL_VERIFY_REQUIRED); ssl_set_ca_chain(&m_Ssl, m_CACerts->GetInternal(), NULL, m_ExpectedPeerName.empty() ? NULL : m_ExpectedPeerName.c_str()); } int cSslContext::WritePlain(const void * a_Data, size_t a_NumBytes) { ASSERT(m_IsValid); // Need to call Initialize() first if (!m_HasHandshaken) { int res = Handshake(); if (res != 0) { return res; } } return ssl_write(&m_Ssl, (const unsigned char *)a_Data, a_NumBytes); } int cSslContext::ReadPlain(void * a_Data, size_t a_MaxBytes) { ASSERT(m_IsValid); // Need to call Initialize() first if (!m_HasHandshaken) { int res = Handshake(); if (res != 0) { return res; } } return ssl_read(&m_Ssl, (unsigned char *)a_Data, a_MaxBytes); } int cSslContext::Handshake(void) { ASSERT(m_IsValid); // Need to call Initialize() first ASSERT(!m_HasHandshaken); // Must not call twice int res = ssl_handshake(&m_Ssl); if (res == 0) { m_HasHandshaken = true; } return res; } int cSslContext::NotifyClose(void) { return ssl_close_notify(&m_Ssl); } #ifdef _DEBUG void cSslContext::SSLDebugMessage(void * a_UserParam, int a_Level, const char * a_Text) { if (a_Level > 3) { // Don't want the trace messages return; } // Remove the terminating LF: size_t len = strlen(a_Text) - 1; while ((len > 0) && (a_Text[len] <= 32)) { len--; } AString Text(a_Text, len + 1); LOGD("SSL (%d): %s", a_Level, Text.c_str()); } #endif // _DEBUG