diff options
Diffstat (limited to '')
-rw-r--r-- | src/core/hle/service/ssl/ssl.cpp | 349 | ||||
-rw-r--r-- | src/core/hle/service/ssl/ssl_backend.h | 44 | ||||
-rw-r--r-- | src/core/hle/service/ssl/ssl_backend_none.cpp | 15 | ||||
-rw-r--r-- | src/core/hle/service/ssl/ssl_backend_openssl.cpp | 342 | ||||
-rw-r--r-- | src/core/hle/service/ssl/ssl_backend_schannel.cpp | 529 |
5 files changed, 1262 insertions, 17 deletions
diff --git a/src/core/hle/service/ssl/ssl.cpp b/src/core/hle/service/ssl/ssl.cpp index 2b99dd7ac..a3b54c7f0 100644 --- a/src/core/hle/service/ssl/ssl.cpp +++ b/src/core/hle/service/ssl/ssl.cpp @@ -1,10 +1,18 @@ // SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project // SPDX-License-Identifier: GPL-2.0-or-later +#include "common/string_util.h" + +#include "core/core.h" #include "core/hle/service/ipc_helpers.h" #include "core/hle/service/server_manager.h" #include "core/hle/service/service.h" +#include "core/hle/service/sm/sm.h" +#include "core/hle/service/sockets/bsd.h" #include "core/hle/service/ssl/ssl.h" +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" namespace Service::SSL { @@ -20,6 +28,18 @@ enum class ContextOption : u32 { CrlImportDateCheckEnable = 1, }; +// This is nn::ssl::Connection::IoMode +enum class IoMode : u32 { + Blocking = 1, + NonBlocking = 2, +}; + +// This is nn::ssl::sf::OptionType +enum class OptionType : u32 { + DoNotCloseSocket = 0, + GetServerCertChain = 1, +}; + // This is nn::ssl::sf::SslVersion struct SslVersion { union { @@ -34,35 +54,42 @@ struct SslVersion { }; }; +struct SslContextSharedData { + u32 connection_count = 0; +}; + class ISslConnection final : public ServiceFramework<ISslConnection> { public: - explicit ISslConnection(Core::System& system_, SslVersion version) - : ServiceFramework{system_, "ISslConnection"}, ssl_version{version} { + explicit ISslConnection(Core::System& system_, SslVersion version, + std::shared_ptr<SslContextSharedData>& shared_data, + std::unique_ptr<SSLConnectionBackend>&& backend) + : ServiceFramework{system_, "ISslConnection"}, ssl_version{version}, + shared_data_{shared_data}, backend_{std::move(backend)} { // clang-format off static const FunctionInfo functions[] = { - {0, nullptr, "SetSocketDescriptor"}, - {1, nullptr, "SetHostName"}, - {2, nullptr, "SetVerifyOption"}, - {3, nullptr, "SetIoMode"}, + {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"}, + {1, &ISslConnection::SetHostName, "SetHostName"}, + {2, &ISslConnection::SetVerifyOption, "SetVerifyOption"}, + {3, &ISslConnection::SetIoMode, "SetIoMode"}, {4, nullptr, "GetSocketDescriptor"}, {5, nullptr, "GetHostName"}, {6, nullptr, "GetVerifyOption"}, {7, nullptr, "GetIoMode"}, - {8, nullptr, "DoHandshake"}, - {9, nullptr, "DoHandshakeGetServerCert"}, - {10, nullptr, "Read"}, - {11, nullptr, "Write"}, - {12, nullptr, "Pending"}, + {8, &ISslConnection::DoHandshake, "DoHandshake"}, + {9, &ISslConnection::DoHandshakeGetServerCert, "DoHandshakeGetServerCert"}, + {10, &ISslConnection::Read, "Read"}, + {11, &ISslConnection::Write, "Write"}, + {12, &ISslConnection::Pending, "Pending"}, {13, nullptr, "Peek"}, {14, nullptr, "Poll"}, {15, nullptr, "GetVerifyCertError"}, {16, nullptr, "GetNeededServerCertBufferSize"}, - {17, nullptr, "SetSessionCacheMode"}, + {17, &ISslConnection::SetSessionCacheMode, "SetSessionCacheMode"}, {18, nullptr, "GetSessionCacheMode"}, {19, nullptr, "FlushSessionCache"}, {20, nullptr, "SetRenegotiationMode"}, {21, nullptr, "GetRenegotiationMode"}, - {22, nullptr, "SetOption"}, + {22, &ISslConnection::SetOption, "SetOption"}, {23, nullptr, "GetOption"}, {24, nullptr, "GetVerifyCertErrors"}, {25, nullptr, "GetCipherInfo"}, @@ -80,21 +107,295 @@ public: // clang-format on RegisterHandlers(functions); + + shared_data->connection_count++; + } + + ~ISslConnection() { + shared_data_->connection_count--; + if (fd_to_close_.has_value()) { + s32 fd = *fd_to_close_; + if (!do_not_close_socket_) { + LOG_ERROR(Service_SSL, + "do_not_close_socket was changed after setting socket; is this right?"); + } else { + auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u"); + if (bsd) { + auto err = bsd->CloseImpl(fd); + if (err != Service::Sockets::Errno::SUCCESS) { + LOG_ERROR(Service_SSL, "failed to close duplicated socket: {}", err); + } + } + } + } } private: SslVersion ssl_version; + std::shared_ptr<SslContextSharedData> shared_data_; + std::unique_ptr<SSLConnectionBackend> backend_; + std::optional<int> fd_to_close_; + bool do_not_close_socket_ = false; + bool get_server_cert_chain_ = false; + std::shared_ptr<Network::SocketBase> socket_; + bool did_set_host_name_ = false; + bool did_handshake_ = false; + + ResultVal<s32> SetSocketDescriptorImpl(s32 fd) { + LOG_DEBUG(Service_SSL, "called, fd={}", fd); + ASSERT(!did_handshake_); + auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u"); + ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; }); + s32 ret_fd; + // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor + if (do_not_close_socket_) { + auto res = bsd->DuplicateSocketImpl(fd); + if (!res.has_value()) { + LOG_ERROR(Service_SSL, "failed to duplicate socket"); + return ResultInvalidSocket; + } + fd = *res; + fd_to_close_ = fd; + ret_fd = fd; + } else { + ret_fd = -1; + } + std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd); + if (!sock.has_value()) { + LOG_ERROR(Service_SSL, "invalid socket fd {}", fd); + return ResultInvalidSocket; + } + socket_ = std::move(*sock); + backend_->SetSocket(socket_); + return ret_fd; + } + + Result SetHostNameImpl(const std::string& hostname) { + LOG_DEBUG(Service_SSL, "SetHostNameImpl({})", hostname); + ASSERT(!did_handshake_); + Result res = backend_->SetHostName(hostname); + if (res == ResultSuccess) { + did_set_host_name_ = true; + } + return res; + } + + Result SetVerifyOptionImpl(u32 option) { + ASSERT(!did_handshake_); + LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option); + return ResultSuccess; + } + + Result SetIOModeImpl(u32 _mode) { + auto mode = static_cast<IoMode>(_mode); + ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking); + ASSERT_OR_EXECUTE(socket_, { return ResultNoSocket; }); + + bool non_block = mode == IoMode::NonBlocking; + Network::Errno e = socket_->SetNonBlock(non_block); + if (e != Network::Errno::SUCCESS) { + LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block); + } + return ResultSuccess; + } + + Result SetSessionCacheModeImpl(u32 mode) { + ASSERT(!did_handshake_); + LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode); + return ResultSuccess; + } + + Result DoHandshakeImpl() { + ASSERT_OR_EXECUTE(!did_handshake_ && socket_, { return ResultNoSocket; }); + ASSERT_OR_EXECUTE_MSG( + did_set_host_name_, { return ResultInternalError; }, + "Expected SetHostName before DoHandshake"); + Result res = backend_->DoHandshake(); + did_handshake_ = res.IsSuccess(); + return res; + } + + std::vector<u8> SerializeServerCerts(const std::vector<std::vector<u8>>& certs) { + struct Header { + u64 magic; + u32 count; + u32 pad; + }; + struct EntryHeader { + u32 size; + u32 offset; + }; + if (!get_server_cert_chain_) { + // Just return the first one, unencoded. + ASSERT_OR_EXECUTE_MSG( + !certs.empty(), { return {}; }, "Should be at least one server cert"); + return certs[0]; + } + std::vector<u8> ret; + Header header{0x4E4D684374726543, static_cast<u32>(certs.size()), 0}; + ret.insert(ret.end(), reinterpret_cast<u8*>(&header), reinterpret_cast<u8*>(&header + 1)); + size_t data_offset = sizeof(Header) + certs.size() * sizeof(EntryHeader); + for (auto& cert : certs) { + EntryHeader entry_header{static_cast<u32>(cert.size()), static_cast<u32>(data_offset)}; + data_offset += cert.size(); + ret.insert(ret.end(), reinterpret_cast<u8*>(&entry_header), + reinterpret_cast<u8*>(&entry_header + 1)); + } + for (auto& cert : certs) { + ret.insert(ret.end(), cert.begin(), cert.end()); + } + return ret; + } + + ResultVal<std::vector<u8>> ReadImpl(size_t size) { + ASSERT_OR_EXECUTE(did_handshake_, { return ResultInternalError; }); + std::vector<u8> res(size); + ResultVal<size_t> actual = backend_->Read(res); + if (actual.Failed()) { + return actual.Code(); + } + res.resize(*actual); + return res; + } + + ResultVal<size_t> WriteImpl(std::span<const u8> data) { + ASSERT_OR_EXECUTE(did_handshake_, { return ResultInternalError; }); + return backend_->Write(data); + } + + ResultVal<s32> PendingImpl() { + LOG_WARNING(Service_SSL, "(STUBBED) called."); + return 0; + } + + void SetSocketDescriptor(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const s32 fd = rp.Pop<s32>(); + const ResultVal<s32> res = SetSocketDescriptorImpl(fd); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + rb.Push<s32>(res.ValueOr(-1)); + } + + void SetHostName(HLERequestContext& ctx) { + const std::string hostname = Common::StringFromBuffer(ctx.ReadBuffer()); + const Result res = SetHostNameImpl(hostname); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void SetVerifyOption(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const u32 option = rp.Pop<u32>(); + const Result res = SetVerifyOptionImpl(option); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void SetIoMode(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const u32 mode = rp.Pop<u32>(); + const Result res = SetIOModeImpl(mode); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void DoHandshake(HLERequestContext& ctx) { + const Result res = DoHandshakeImpl(); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void DoHandshakeGetServerCert(HLERequestContext& ctx) { + Result res = DoHandshakeImpl(); + u32 certs_count = 0; + u32 certs_size = 0; + if (res == ResultSuccess) { + auto certs = backend_->GetServerCerts(); + if (certs.Succeeded()) { + std::vector<u8> certs_buf = SerializeServerCerts(*certs); + ctx.WriteBuffer(certs_buf); + certs_count = static_cast<u32>(certs->size()); + certs_size = static_cast<u32>(certs_buf.size()); + } + } + IPC::ResponseBuilder rb{ctx, 4}; + rb.Push(res); + rb.Push(certs_size); + rb.Push(certs_count); + } + + void Read(HLERequestContext& ctx) { + const ResultVal<std::vector<u8>> res = ReadImpl(ctx.GetWriteBufferSize()); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + if (res.Succeeded()) { + rb.Push(static_cast<u32>(res->size())); + ctx.WriteBuffer(*res); + } else { + rb.Push(static_cast<u32>(0)); + } + } + + void Write(HLERequestContext& ctx) { + const ResultVal<size_t> res = WriteImpl(ctx.ReadBuffer()); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + rb.Push(static_cast<u32>(res.ValueOr(0))); + } + + void Pending(HLERequestContext& ctx) { + const ResultVal<s32> res = PendingImpl(); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + rb.Push<s32>(res.ValueOr(0)); + } + + void SetSessionCacheMode(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const u32 mode = rp.Pop<u32>(); + const Result res = SetSessionCacheModeImpl(mode); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void SetOption(HLERequestContext& ctx) { + struct Parameters { + OptionType option; + s32 value; + }; + static_assert(sizeof(Parameters) == 0x8, "Parameters is an invalid size"); + + IPC::RequestParser rp{ctx}; + const auto parameters = rp.PopRaw<Parameters>(); + + switch (parameters.option) { + case OptionType::DoNotCloseSocket: + do_not_close_socket_ = static_cast<bool>(parameters.value); + break; + case OptionType::GetServerCertChain: + get_server_cert_chain_ = static_cast<bool>(parameters.value); + break; + default: + LOG_WARNING(Service_SSL, "unrecognized option={}, value={}", parameters.option, + parameters.value); + } + + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(ResultSuccess); + } }; class ISslContext final : public ServiceFramework<ISslContext> { public: explicit ISslContext(Core::System& system_, SslVersion version) - : ServiceFramework{system_, "ISslContext"}, ssl_version{version} { + : ServiceFramework{system_, "ISslContext"}, ssl_version{version}, + shared_data_{std::make_shared<SslContextSharedData>()} { static const FunctionInfo functions[] = { {0, &ISslContext::SetOption, "SetOption"}, {1, nullptr, "GetOption"}, {2, &ISslContext::CreateConnection, "CreateConnection"}, - {3, nullptr, "GetConnectionCount"}, + {3, &ISslContext::GetConnectionCount, "GetConnectionCount"}, {4, &ISslContext::ImportServerPki, "ImportServerPki"}, {5, &ISslContext::ImportClientPki, "ImportClientPki"}, {6, nullptr, "RemoveServerPki"}, @@ -111,6 +412,7 @@ public: private: SslVersion ssl_version; + std::shared_ptr<SslContextSharedData> shared_data_; void SetOption(HLERequestContext& ctx) { struct Parameters { @@ -130,11 +432,24 @@ private: } void CreateConnection(HLERequestContext& ctx) { - LOG_WARNING(Service_SSL, "(STUBBED) called"); + LOG_WARNING(Service_SSL, "called"); + + auto backend_res = CreateSSLConnectionBackend(); IPC::ResponseBuilder rb{ctx, 2, 0, 1}; + rb.Push(backend_res.Code()); + if (backend_res.Succeeded()) { + rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data_, + std::move(*backend_res)); + } + } + + void GetConnectionCount(HLERequestContext& ctx) { + LOG_WARNING(Service_SSL, "connection_count={}", shared_data_->connection_count); + + IPC::ResponseBuilder rb{ctx, 3}; rb.Push(ResultSuccess); - rb.PushIpcInterface<ISslConnection>(system, ssl_version); + rb.Push(shared_data_->connection_count); } void ImportServerPki(HLERequestContext& ctx) { diff --git a/src/core/hle/service/ssl/ssl_backend.h b/src/core/hle/service/ssl/ssl_backend.h new file mode 100644 index 000000000..624e07d41 --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend.h @@ -0,0 +1,44 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#pragma once + +#include "core/hle/result.h" + +#include "common/common_types.h" + +#include <memory> +#include <span> +#include <string> +#include <vector> + +namespace Network { +class SocketBase; +} + +namespace Service::SSL { + +constexpr Result ResultNoSocket{ErrorModule::SSLSrv, 103}; +constexpr Result ResultInvalidSocket{ErrorModule::SSLSrv, 106}; +constexpr Result ResultTimeout{ErrorModule::SSLSrv, 205}; +constexpr Result ResultInternalError{ErrorModule::SSLSrv, 999}; // made up + +constexpr Result ResultWouldBlock{ErrorModule::SSLSrv, 204}; +// ^ ResultWouldBlock is returned from Read and Write, and oddly, DoHandshake, +// with no way in the latter case to distinguish whether the client should poll +// for read or write. The one official client I've seen handles this by always +// polling for read (with a timeout). + +class SSLConnectionBackend { +public: + virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0; + virtual Result SetHostName(const std::string& hostname) = 0; + virtual Result DoHandshake() = 0; + virtual ResultVal<size_t> Read(std::span<u8> data) = 0; + virtual ResultVal<size_t> Write(std::span<const u8> data) = 0; + virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0; +}; + +ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend(); + +} // namespace Service::SSL diff --git a/src/core/hle/service/ssl/ssl_backend_none.cpp b/src/core/hle/service/ssl/ssl_backend_none.cpp new file mode 100644 index 000000000..eb01561e2 --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_none.cpp @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include "core/hle/service/ssl/ssl_backend.h" + +#include "common/logging/log.h" + +namespace Service::SSL { + +ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { + LOG_ERROR(Service_SSL, "No SSL backend on this platform"); + return ResultInternalError; +} + +} // namespace Service::SSL diff --git a/src/core/hle/service/ssl/ssl_backend_openssl.cpp b/src/core/hle/service/ssl/ssl_backend_openssl.cpp new file mode 100644 index 000000000..cf9b904ac --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_openssl.cpp @@ -0,0 +1,342 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" + +#include "common/fs/file.h" +#include "common/hex_util.h" +#include "common/string_util.h" + +#include <mutex> + +#include <openssl/bio.h> +#include <openssl/err.h> +#include <openssl/ssl.h> +#include <openssl/x509.h> + +using namespace Common::FS; + +namespace Service::SSL { + +// Import OpenSSL's `SSL` type into the namespace. This is needed because the +// namespace is also named `SSL`. +using ::SSL; + +namespace { + +std::once_flag one_time_init_flag; +bool one_time_init_success = false; + +SSL_CTX* ssl_ctx; +IOFile key_log_file; // only open if SSLKEYLOGFILE set in environment +BIO_METHOD* bio_meth; + +Result CheckOpenSSLErrors(); +void OneTimeInit(); +void OneTimeInitLogFile(); +bool OneTimeInitBIO(); + +} // namespace + +class SSLConnectionBackendOpenSSL final : public SSLConnectionBackend { +public: + Result Init() { + std::call_once(one_time_init_flag, OneTimeInit); + + if (!one_time_init_success) { + LOG_ERROR(Service_SSL, + "Can't create SSL connection because OpenSSL one-time initialization failed"); + return ResultInternalError; + } + + ssl_ = SSL_new(ssl_ctx); + if (!ssl_) { + LOG_ERROR(Service_SSL, "SSL_new failed"); + return CheckOpenSSLErrors(); + } + + SSL_set_connect_state(ssl_); + + bio_ = BIO_new(bio_meth); + if (!bio_) { + LOG_ERROR(Service_SSL, "BIO_new failed"); + return CheckOpenSSLErrors(); + } + + BIO_set_data(bio_, this); + BIO_set_init(bio_, 1); + SSL_set_bio(ssl_, bio_, bio_); + + return ResultSuccess; + } + + void SetSocket(std::shared_ptr<Network::SocketBase> socket) override { + socket_ = socket; + } + + Result SetHostName(const std::string& hostname) override { + if (!SSL_set1_host(ssl_, hostname.c_str())) { // hostname for verification + LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname); + return CheckOpenSSLErrors(); + } + if (!SSL_set_tlsext_host_name(ssl_, hostname.c_str())) { // hostname for SNI + LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname); + return CheckOpenSSLErrors(); + } + return ResultSuccess; + } + + Result DoHandshake() override { + SSL_set_verify_result(ssl_, X509_V_OK); + int ret = SSL_do_handshake(ssl_); + long verify_result = SSL_get_verify_result(ssl_); + if (verify_result != X509_V_OK) { + LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}", + X509_verify_cert_error_string(verify_result)); + return CheckOpenSSLErrors(); + } + if (ret <= 0) { + int ssl_err = SSL_get_error(ssl_, ret); + if (ssl_err == SSL_ERROR_ZERO_RETURN || + (ssl_err == SSL_ERROR_SYSCALL && got_read_eof_)) { + LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); + return ResultInternalError; + } + } + return HandleReturn("SSL_do_handshake", 0, ret).Code(); + } + + ResultVal<size_t> Read(std::span<u8> data) override { + size_t actual; + int ret = SSL_read_ex(ssl_, data.data(), data.size(), &actual); + return HandleReturn("SSL_read_ex", actual, ret); + } + + ResultVal<size_t> Write(std::span<const u8> data) override { + size_t actual; + int ret = SSL_write_ex(ssl_, data.data(), data.size(), &actual); + return HandleReturn("SSL_write_ex", actual, ret); + } + + ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) { + int ssl_err = SSL_get_error(ssl_, ret); + CheckOpenSSLErrors(); + switch (ssl_err) { + case SSL_ERROR_NONE: + return actual; + case SSL_ERROR_ZERO_RETURN: + LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what); + // DoHandshake special-cases this, but for Read and Write: + return size_t(0); + case SSL_ERROR_WANT_READ: + LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what); + return ResultWouldBlock; + case SSL_ERROR_WANT_WRITE: + LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what); + return ResultWouldBlock; + default: + if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof_) { + LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what); + return size_t(0); + } + LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err); + return ResultInternalError; + } + } + + ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { + STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl_); + if (!chain) { + LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr"); + return ResultInternalError; + } + std::vector<std::vector<u8>> ret; + int count = sk_X509_num(chain); + ASSERT(count >= 0); + for (int i = 0; i < count; i++) { + X509* x509 = sk_X509_value(chain, i); + ASSERT_OR_EXECUTE(x509 != nullptr, { continue; }); + unsigned char* buf = nullptr; + int len = i2d_X509(x509, &buf); + ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; }); + ret.emplace_back(buf, buf + len); + OPENSSL_free(buf); + } + return ret; + } + + ~SSLConnectionBackendOpenSSL() { + // these are null-tolerant: + SSL_free(ssl_); + BIO_free(bio_); + } + + static void KeyLogCallback(const SSL* ssl, const char* line) { + std::string str(line); + str.push_back('\n'); + // Do this in a single WriteString for atomicity if multiple instances + // are running on different threads (though that can't currently + // happen). + if (key_log_file.WriteString(str) != str.size() || !key_log_file.Flush()) { + LOG_CRITICAL(Service_SSL, "Failed to write to SSLKEYLOGFILE"); + } + LOG_DEBUG(Service_SSL, "Wrote to SSLKEYLOGFILE: {}", line); + } + + static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) { + auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio)); + ASSERT_OR_EXECUTE_MSG( + self->socket_, { return 0; }, "OpenSSL asked to send but we have no socket"); + BIO_clear_retry_flags(bio); + auto [actual, err] = self->socket_->Send({reinterpret_cast<const u8*>(buf), len}, 0); + switch (err) { + case Network::Errno::SUCCESS: + *actual_p = actual; + return 1; + case Network::Errno::AGAIN: + BIO_set_flags(bio, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY); + return 0; + default: + LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err); + return -1; + } + } + + static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) { + auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio)); + ASSERT_OR_EXECUTE_MSG( + self->socket_, { return 0; }, "OpenSSL asked to recv but we have no socket"); + BIO_clear_retry_flags(bio); + auto [actual, err] = self->socket_->Recv(0, {reinterpret_cast<u8*>(buf), len}); + switch (err) { + case Network::Errno::SUCCESS: + *actual_p = actual; + if (actual == 0) { + self->got_read_eof_ = true; + } + return actual ? 1 : 0; + case Network::Errno::AGAIN: + BIO_set_flags(bio, BIO_FLAGS_READ | BIO_FLAGS_SHOULD_RETRY); + return 0; + default: + LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); + return -1; + } + } + + static long CtrlCallback(BIO* bio, int cmd, long larg, void* parg) { + switch (cmd) { + case BIO_CTRL_FLUSH: + // Nothing to flush. + return 1; + case BIO_CTRL_PUSH: + case BIO_CTRL_POP: + case BIO_CTRL_GET_KTLS_SEND: + case BIO_CTRL_GET_KTLS_RECV: + // We don't support these operations, but don't bother logging them + // as they're nothing unusual. + return 0; + default: + LOG_DEBUG(Service_SSL, "OpenSSL BIO got ctrl({}, {}, {})", cmd, larg, parg); + return 0; + } + } + + SSL* ssl_ = nullptr; + BIO* bio_ = nullptr; + bool got_read_eof_ = false; + + std::shared_ptr<Network::SocketBase> socket_; +}; + +ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { + auto conn = std::make_unique<SSLConnectionBackendOpenSSL>(); + Result res = conn->Init(); + if (res.IsFailure()) { + return res; + } + return conn; +} + +namespace { + +Result CheckOpenSSLErrors() { + unsigned long rc; + const char* file; + int line; + const char* func; + const char* data; + int flags; + while ((rc = ERR_get_error_all(&file, &line, &func, &data, &flags))) { + std::string msg; + msg.resize(1024, '\0'); + ERR_error_string_n(rc, msg.data(), msg.size()); + msg.resize(strlen(msg.data()), '\0'); + if (flags & ERR_TXT_STRING) { + msg.append(" | "); + msg.append(data); + } + Common::Log::FmtLogMessage(Common::Log::Class::Service_SSL, Common::Log::Level::Error, + Common::Log::TrimSourcePath(file), line, func, "OpenSSL: {}", + msg); + } + return ResultInternalError; +} + +void OneTimeInit() { + ssl_ctx = SSL_CTX_new(TLS_client_method()); + if (!ssl_ctx) { + LOG_ERROR(Service_SSL, "SSL_CTX_new failed"); + CheckOpenSSLErrors(); + return; + } + + SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr); + + if (!SSL_CTX_set_default_verify_paths(ssl_ctx)) { + LOG_ERROR(Service_SSL, "SSL_CTX_set_default_verify_paths failed"); + CheckOpenSSLErrors(); + return; + } + + OneTimeInitLogFile(); + + if (!OneTimeInitBIO()) { + return; + } + + one_time_init_success = true; +} + +void OneTimeInitLogFile() { + const char* logfile = getenv("SSLKEYLOGFILE"); + if (logfile) { + key_log_file.Open(logfile, FileAccessMode::Append, FileType::TextFile, + FileShareFlag::ShareWriteOnly); + if (key_log_file.IsOpen()) { + SSL_CTX_set_keylog_callback(ssl_ctx, &SSLConnectionBackendOpenSSL::KeyLogCallback); + } else { + LOG_CRITICAL(Service_SSL, + "SSLKEYLOGFILE was set but file could not be opened; not logging keys!"); + } + } +} + +bool OneTimeInitBIO() { + bio_meth = + BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SSLConnectionBackendOpenSSL"); + if (!bio_meth || + !BIO_meth_set_write_ex(bio_meth, &SSLConnectionBackendOpenSSL::WriteCallback) || + !BIO_meth_set_read_ex(bio_meth, &SSLConnectionBackendOpenSSL::ReadCallback) || + !BIO_meth_set_ctrl(bio_meth, &SSLConnectionBackendOpenSSL::CtrlCallback)) { + LOG_ERROR(Service_SSL, "Failed to create BIO_METHOD"); + return false; + } + return true; +} + +} // namespace + +} // namespace Service::SSL diff --git a/src/core/hle/service/ssl/ssl_backend_schannel.cpp b/src/core/hle/service/ssl/ssl_backend_schannel.cpp new file mode 100644 index 000000000..0a326b536 --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_schannel.cpp @@ -0,0 +1,529 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" + +#include "common/error.h" +#include "common/fs/file.h" +#include "common/hex_util.h" +#include "common/string_util.h" + +#include <mutex> + +#define SECURITY_WIN32 +#include <Security.h> +#include <schnlsp.h> + +namespace { + +std::once_flag one_time_init_flag; +bool one_time_init_success = false; + +SCHANNEL_CRED schannel_cred{ + .dwVersion = SCHANNEL_CRED_VERSION, + .dwFlags = SCH_USE_STRONG_CRYPTO | // don't allow insecure protocols + SCH_CRED_AUTO_CRED_VALIDATION | // validate certs + SCH_CRED_NO_DEFAULT_CREDS, // don't automatically present a client certificate + // ^ I'm assuming that nobody would want to connect Yuzu to a + // service that requires some OS-provided corporate client + // certificate, and presenting one to some arbitrary server + // might be a privacy concern? Who knows, though. +}; + +CredHandle cred_handle; + +static void OneTimeInit() { + SECURITY_STATUS ret = + AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND, + nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr); + if (ret != SEC_E_OK) { + // SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString. + LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}", + Common::NativeErrorToString(ret)); + return; + } + + one_time_init_success = true; +} + +} // namespace + +namespace Service::SSL { + +class SSLConnectionBackendSchannel final : public SSLConnectionBackend { +public: + Result Init() { + std::call_once(one_time_init_flag, OneTimeInit); + + if (!one_time_init_success) { + LOG_ERROR( + Service_SSL, + "Can't create SSL connection because Schannel one-time initialization failed"); + return ResultInternalError; + } + + return ResultSuccess; + } + + void SetSocket(std::shared_ptr<Network::SocketBase> socket) override { + socket_ = socket; + } + + Result SetHostName(const std::string& hostname) override { + hostname_ = hostname; + return ResultSuccess; + } + + Result DoHandshake() override { + while (1) { + Result r; + switch (handshake_state_) { + case HandshakeState::Initial: + if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || + (r = CallInitializeSecurityContext()) != ResultSuccess) { + return r; + } + // CallInitializeSecurityContext updated `handshake_state_`. + continue; + case HandshakeState::ContinueNeeded: + case HandshakeState::IncompleteMessage: + if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || + (r = FillCiphertextReadBuf()) != ResultSuccess) { + return r; + } + if (ciphertext_read_buf_.empty()) { + LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); + return ResultInternalError; + } + if ((r = CallInitializeSecurityContext()) != ResultSuccess) { + return r; + } + // CallInitializeSecurityContext updated `handshake_state_`. + continue; + case HandshakeState::DoneAfterFlush: + if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) { + return r; + } + handshake_state_ = HandshakeState::Connected; + return ResultSuccess; + case HandshakeState::Connected: + LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook"); + return ResultInternalError; + case HandshakeState::Error: + return ResultInternalError; + } + } + } + + Result FillCiphertextReadBuf() { + size_t fill_size = read_buf_fill_size_ ? read_buf_fill_size_ : 4096; + read_buf_fill_size_ = 0; + // This unnecessarily zeroes the buffer; oh well. + size_t offset = ciphertext_read_buf_.size(); + ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; }); + ciphertext_read_buf_.resize(offset + fill_size, 0); + auto read_span = std::span(ciphertext_read_buf_).subspan(offset, fill_size); + auto [actual, err] = socket_->Recv(0, read_span); + switch (err) { + case Network::Errno::SUCCESS: + ASSERT(static_cast<size_t>(actual) <= fill_size); + ciphertext_read_buf_.resize(offset + actual); + return ResultSuccess; + case Network::Errno::AGAIN: + ciphertext_read_buf_.resize(offset); + return ResultWouldBlock; + default: + ciphertext_read_buf_.resize(offset); + LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); + return ResultInternalError; + } + } + + // Returns success if the write buffer has been completely emptied. + Result FlushCiphertextWriteBuf() { + while (!ciphertext_write_buf_.empty()) { + auto [actual, err] = socket_->Send(ciphertext_write_buf_, 0); + switch (err) { + case Network::Errno::SUCCESS: + ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf_.size()); + ciphertext_write_buf_.erase(ciphertext_write_buf_.begin(), + ciphertext_write_buf_.begin() + actual); + break; + case Network::Errno::AGAIN: + return ResultWouldBlock; + default: + LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err); + return ResultInternalError; + } + } + return ResultSuccess; + } + + Result CallInitializeSecurityContext() { + unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_INTEGRITY | + ISC_REQ_REPLAY_DETECT | ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM | + ISC_REQ_USE_SUPPLIED_CREDS; + unsigned long attr; + // https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel + std::array<SecBuffer, 2> input_buffers{{ + // only used if `initial_call_done` + { + // [0] + .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf_.size()), + .BufferType = SECBUFFER_TOKEN, + .pvBuffer = ciphertext_read_buf_.data(), + }, + { + // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is + // returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the + // whole buffer wasn't used) + .BufferType = SECBUFFER_EMPTY, + }, + }}; + std::array<SecBuffer, 2> output_buffers{{ + { + .BufferType = SECBUFFER_TOKEN, + }, // [0] + { + .BufferType = SECBUFFER_ALERT, + }, // [1] + }}; + SecBufferDesc input_desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast<unsigned long>(input_buffers.size()), + .pBuffers = input_buffers.data(), + }; + SecBufferDesc output_desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast<unsigned long>(output_buffers.size()), + .pBuffers = output_buffers.data(), + }; + ASSERT_OR_EXECUTE_MSG( + input_buffers[0].cbBuffer == ciphertext_read_buf_.size(), + { return ResultInternalError; }, "read buffer too large"); + + bool initial_call_done = handshake_state_ != HandshakeState::Initial; + if (initial_call_done) { + LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext", + ciphertext_read_buf_.size()); + } + + SECURITY_STATUS ret = + InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt_ : nullptr, + // Caller ensured we have set a hostname: + const_cast<char*>(hostname_.value().c_str()), req, + 0, // Reserved1 + 0, // TargetDataRep not used with Schannel + initial_call_done ? &input_desc : nullptr, + 0, // Reserved2 + initial_call_done ? nullptr : &ctxt_, &output_desc, &attr, + nullptr); // ptsExpiry + + if (output_buffers[0].pvBuffer) { + std::span span(static_cast<u8*>(output_buffers[0].pvBuffer), + output_buffers[0].cbBuffer); + ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), span.begin(), span.end()); + FreeContextBuffer(output_buffers[0].pvBuffer); + } + + if (output_buffers[1].pvBuffer) { + std::span span(static_cast<u8*>(output_buffers[1].pvBuffer), + output_buffers[1].cbBuffer); + // The documentation doesn't explain what format this data is in. + LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(), + Common::HexToString(span)); + } + + switch (ret) { + case SEC_I_CONTINUE_NEEDED: + LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED"); + if (input_buffers[1].BufferType == SECBUFFER_EXTRA) { + LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer); + ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf_.size()); + ciphertext_read_buf_.erase(ciphertext_read_buf_.begin(), + ciphertext_read_buf_.end() - input_buffers[1].cbBuffer); + } else { + ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY); + ciphertext_read_buf_.clear(); + } + handshake_state_ = HandshakeState::ContinueNeeded; + return ResultSuccess; + case SEC_E_INCOMPLETE_MESSAGE: + LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE"); + ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING); + read_buf_fill_size_ = input_buffers[1].cbBuffer; + handshake_state_ = HandshakeState::IncompleteMessage; + return ResultSuccess; + case SEC_E_OK: + LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK"); + ciphertext_read_buf_.clear(); + handshake_state_ = HandshakeState::DoneAfterFlush; + return GrabStreamSizes(); + default: + LOG_ERROR(Service_SSL, + "InitializeSecurityContext failed (probably certificate/protocol issue): {}", + Common::NativeErrorToString(ret)); + handshake_state_ = HandshakeState::Error; + return ResultInternalError; + } + } + + Result GrabStreamSizes() { + SECURITY_STATUS ret = + QueryContextAttributes(&ctxt_, SECPKG_ATTR_STREAM_SIZES, &stream_sizes_); + if (ret != SEC_E_OK) { + LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}", + Common::NativeErrorToString(ret)); + handshake_state_ = HandshakeState::Error; + return ResultInternalError; + } + return ResultSuccess; + } + + ResultVal<size_t> Read(std::span<u8> data) override { + if (handshake_state_ != HandshakeState::Connected) { + LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake"); + return ResultInternalError; + } + if (data.size() == 0 || got_read_eof_) { + return size_t(0); + } + while (1) { + if (!cleartext_read_buf_.empty()) { + size_t read_size = std::min(cleartext_read_buf_.size(), data.size()); + std::memcpy(data.data(), cleartext_read_buf_.data(), read_size); + cleartext_read_buf_.erase(cleartext_read_buf_.begin(), + cleartext_read_buf_.begin() + read_size); + return read_size; + } + if (!ciphertext_read_buf_.empty()) { + std::array<SecBuffer, 5> buffers{{ + { + .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf_.size()), + .BufferType = SECBUFFER_DATA, + .pvBuffer = ciphertext_read_buf_.data(), + }, + { + .BufferType = SECBUFFER_EMPTY, + }, + { + .BufferType = SECBUFFER_EMPTY, + }, + { + .BufferType = SECBUFFER_EMPTY, + }, + }}; + ASSERT_OR_EXECUTE_MSG( + buffers[0].cbBuffer == ciphertext_read_buf_.size(), + { return ResultInternalError; }, "read buffer too large"); + SecBufferDesc desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast<unsigned long>(buffers.size()), + .pBuffers = buffers.data(), + }; + SECURITY_STATUS ret = + DecryptMessage(&ctxt_, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr); + switch (ret) { + case SEC_E_OK: + ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER, + { return ResultInternalError; }); + ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA, + { return ResultInternalError; }); + ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER, + { return ResultInternalError; }); + cleartext_read_buf_.assign(static_cast<u8*>(buffers[1].pvBuffer), + static_cast<u8*>(buffers[1].pvBuffer) + + buffers[1].cbBuffer); + if (buffers[3].BufferType == SECBUFFER_EXTRA) { + ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf_.size()); + ciphertext_read_buf_.erase(ciphertext_read_buf_.begin(), + ciphertext_read_buf_.end() - + buffers[3].cbBuffer); + } else { + ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY); + ciphertext_read_buf_.clear(); + } + continue; + case SEC_E_INCOMPLETE_MESSAGE: + break; + case SEC_I_CONTEXT_EXPIRED: + // Server hung up by sending close_notify. + got_read_eof_ = true; + return size_t(0); + default: + LOG_ERROR(Service_SSL, "DecryptMessage failed: {}", + Common::NativeErrorToString(ret)); + return ResultInternalError; + } + } + Result r = FillCiphertextReadBuf(); + if (r != ResultSuccess) { + return r; + } + if (ciphertext_read_buf_.empty()) { + got_read_eof_ = true; + return size_t(0); + } + } + } + + ResultVal<size_t> Write(std::span<const u8> data) override { + if (handshake_state_ != HandshakeState::Connected) { + LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake"); + return ResultInternalError; + } + if (data.size() == 0) { + return size_t(0); + } + data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes_.cbMaximumMessage)); + if (!cleartext_write_buf_.empty()) { + // Already in the middle of a write. It wouldn't make sense to not + // finish sending the entire buffer since TLS has + // header/MAC/padding/etc. + if (data.size() != cleartext_write_buf_.size() || + std::memcmp(data.data(), cleartext_write_buf_.data(), data.size())) { + LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer"); + return ResultInternalError; + } + return WriteAlreadyEncryptedData(); + } else { + cleartext_write_buf_.assign(data.begin(), data.end()); + } + + std::vector<u8> header_buf(stream_sizes_.cbHeader, 0); + std::vector<u8> tmp_data_buf = cleartext_write_buf_; + std::vector<u8> trailer_buf(stream_sizes_.cbTrailer, 0); + + std::array<SecBuffer, 3> buffers{{ + { + .cbBuffer = stream_sizes_.cbHeader, + .BufferType = SECBUFFER_STREAM_HEADER, + .pvBuffer = header_buf.data(), + }, + { + .cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()), + .BufferType = SECBUFFER_DATA, + .pvBuffer = tmp_data_buf.data(), + }, + { + .cbBuffer = stream_sizes_.cbTrailer, + .BufferType = SECBUFFER_STREAM_TRAILER, + .pvBuffer = trailer_buf.data(), + }, + }}; + ASSERT_OR_EXECUTE_MSG( + buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; }, + "temp buffer too large"); + SecBufferDesc desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast<unsigned long>(buffers.size()), + .pBuffers = buffers.data(), + }; + + SECURITY_STATUS ret = EncryptMessage(&ctxt_, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0); + if (ret != SEC_E_OK) { + LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret)); + return ResultInternalError; + } + ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), header_buf.begin(), + header_buf.end()); + ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), tmp_data_buf.begin(), + tmp_data_buf.end()); + ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), trailer_buf.begin(), + trailer_buf.end()); + return WriteAlreadyEncryptedData(); + } + + ResultVal<size_t> WriteAlreadyEncryptedData() { + Result r = FlushCiphertextWriteBuf(); + if (r != ResultSuccess) { + return r; + } + // write buf is empty + size_t cleartext_bytes_written = cleartext_write_buf_.size(); + cleartext_write_buf_.clear(); + return cleartext_bytes_written; + } + + ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { + PCCERT_CONTEXT returned_cert = nullptr; + SECURITY_STATUS ret = + QueryContextAttributes(&ctxt_, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert); + if (ret != SEC_E_OK) { + LOG_ERROR(Service_SSL, + "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}", + Common::NativeErrorToString(ret)); + return ResultInternalError; + } + PCCERT_CONTEXT some_cert = nullptr; + std::vector<std::vector<u8>> certs; + while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) { + certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded), + static_cast<u8*>(some_cert->pbCertEncoded) + + some_cert->cbCertEncoded); + } + std::reverse(certs.begin(), + certs.end()); // Windows returns certs in reverse order from what we want + CertFreeCertificateContext(returned_cert); + return certs; + } + + ~SSLConnectionBackendSchannel() { + if (handshake_state_ != HandshakeState::Initial) { + DeleteSecurityContext(&ctxt_); + } + } + + enum class HandshakeState { + // Haven't called anything yet. + Initial, + // `SEC_I_CONTINUE_NEEDED` was returned by + // `InitializeSecurityContext`; must finish sending data (if any) in + // the write buffer, then read at least one byte before calling + // `InitializeSecurityContext` again. + ContinueNeeded, + // `SEC_E_INCOMPLETE_MESSAGE` was returned by + // `InitializeSecurityContext`; hopefully the write buffer is empty; + // must read at least one byte before calling + // `InitializeSecurityContext` again. + IncompleteMessage, + // `SEC_E_OK` was returned by `InitializeSecurityContext`; must + // finish sending data in the write buffer before having `DoHandshake` + // report success. + DoneAfterFlush, + // We finished the above and are now connected. At this point, writing + // and reading are separate 'state machines' represented by the + // nonemptiness of the ciphertext and cleartext read and write buffers. + Connected, + // Another error was returned and we shouldn't allow initialization + // to continue. + Error, + } handshake_state_ = HandshakeState::Initial; + + CtxtHandle ctxt_; + SecPkgContext_StreamSizes stream_sizes_; + + std::shared_ptr<Network::SocketBase> socket_; + std::optional<std::string> hostname_; + + std::vector<u8> ciphertext_read_buf_; + std::vector<u8> ciphertext_write_buf_; + std::vector<u8> cleartext_read_buf_; + std::vector<u8> cleartext_write_buf_; + + bool got_read_eof_ = false; + size_t read_buf_fill_size_ = 0; +}; + +ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { + auto conn = std::make_unique<SSLConnectionBackendSchannel>(); + Result res = conn->Init(); + if (res.IsFailure()) { + return res; + } + return conn; +} + +} // namespace Service::SSL |