diff options
Diffstat (limited to 'src/core/hle/service/ssl')
-rw-r--r-- | src/core/hle/service/ssl/ssl.cpp | 353 | ||||
-rw-r--r-- | src/core/hle/service/ssl/ssl_backend.h | 45 | ||||
-rw-r--r-- | src/core/hle/service/ssl/ssl_backend_none.cpp | 16 | ||||
-rw-r--r-- | src/core/hle/service/ssl/ssl_backend_openssl.cpp | 351 | ||||
-rw-r--r-- | src/core/hle/service/ssl/ssl_backend_schannel.cpp | 544 | ||||
-rw-r--r-- | src/core/hle/service/ssl/ssl_backend_securetransport.cpp | 222 |
6 files changed, 1514 insertions, 17 deletions
diff --git a/src/core/hle/service/ssl/ssl.cpp b/src/core/hle/service/ssl/ssl.cpp index 2b99dd7ac..9c96f9763 100644 --- a/src/core/hle/service/ssl/ssl.cpp +++ b/src/core/hle/service/ssl/ssl.cpp @@ -1,10 +1,18 @@ // SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project // SPDX-License-Identifier: GPL-2.0-or-later +#include "common/string_util.h" + +#include "core/core.h" #include "core/hle/service/ipc_helpers.h" #include "core/hle/service/server_manager.h" #include "core/hle/service/service.h" +#include "core/hle/service/sm/sm.h" +#include "core/hle/service/sockets/bsd.h" #include "core/hle/service/ssl/ssl.h" +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" namespace Service::SSL { @@ -20,6 +28,18 @@ enum class ContextOption : u32 { CrlImportDateCheckEnable = 1, }; +// This is nn::ssl::Connection::IoMode +enum class IoMode : u32 { + Blocking = 1, + NonBlocking = 2, +}; + +// This is nn::ssl::sf::OptionType +enum class OptionType : u32 { + DoNotCloseSocket = 0, + GetServerCertChain = 1, +}; + // This is nn::ssl::sf::SslVersion struct SslVersion { union { @@ -34,35 +54,42 @@ struct SslVersion { }; }; +struct SslContextSharedData { + u32 connection_count = 0; +}; + class ISslConnection final : public ServiceFramework<ISslConnection> { public: - explicit ISslConnection(Core::System& system_, SslVersion version) - : ServiceFramework{system_, "ISslConnection"}, ssl_version{version} { + explicit ISslConnection(Core::System& system_in, SslVersion ssl_version_in, + std::shared_ptr<SslContextSharedData>& shared_data_in, + std::unique_ptr<SSLConnectionBackend>&& backend_in) + : ServiceFramework{system_in, "ISslConnection"}, ssl_version{ssl_version_in}, + shared_data{shared_data_in}, backend{std::move(backend_in)} { // clang-format off static const FunctionInfo functions[] = { - {0, nullptr, "SetSocketDescriptor"}, - {1, nullptr, "SetHostName"}, - {2, nullptr, "SetVerifyOption"}, - {3, nullptr, "SetIoMode"}, + {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"}, + {1, &ISslConnection::SetHostName, "SetHostName"}, + {2, &ISslConnection::SetVerifyOption, "SetVerifyOption"}, + {3, &ISslConnection::SetIoMode, "SetIoMode"}, {4, nullptr, "GetSocketDescriptor"}, {5, nullptr, "GetHostName"}, {6, nullptr, "GetVerifyOption"}, {7, nullptr, "GetIoMode"}, - {8, nullptr, "DoHandshake"}, - {9, nullptr, "DoHandshakeGetServerCert"}, - {10, nullptr, "Read"}, - {11, nullptr, "Write"}, - {12, nullptr, "Pending"}, + {8, &ISslConnection::DoHandshake, "DoHandshake"}, + {9, &ISslConnection::DoHandshakeGetServerCert, "DoHandshakeGetServerCert"}, + {10, &ISslConnection::Read, "Read"}, + {11, &ISslConnection::Write, "Write"}, + {12, &ISslConnection::Pending, "Pending"}, {13, nullptr, "Peek"}, {14, nullptr, "Poll"}, {15, nullptr, "GetVerifyCertError"}, {16, nullptr, "GetNeededServerCertBufferSize"}, - {17, nullptr, "SetSessionCacheMode"}, + {17, &ISslConnection::SetSessionCacheMode, "SetSessionCacheMode"}, {18, nullptr, "GetSessionCacheMode"}, {19, nullptr, "FlushSessionCache"}, {20, nullptr, "SetRenegotiationMode"}, {21, nullptr, "GetRenegotiationMode"}, - {22, nullptr, "SetOption"}, + {22, &ISslConnection::SetOption, "SetOption"}, {23, nullptr, "GetOption"}, {24, nullptr, "GetVerifyCertErrors"}, {25, nullptr, "GetCipherInfo"}, @@ -80,21 +107,299 @@ public: // clang-format on RegisterHandlers(functions); + + shared_data->connection_count++; + } + + ~ISslConnection() { + shared_data->connection_count--; + if (fd_to_close.has_value()) { + const s32 fd = *fd_to_close; + if (!do_not_close_socket) { + LOG_ERROR(Service_SSL, + "do_not_close_socket was changed after setting socket; is this right?"); + } else { + auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u"); + if (bsd) { + auto err = bsd->CloseImpl(fd); + if (err != Service::Sockets::Errno::SUCCESS) { + LOG_ERROR(Service_SSL, "Failed to close duplicated socket: {}", err); + } + } + } + } } private: SslVersion ssl_version; + std::shared_ptr<SslContextSharedData> shared_data; + std::unique_ptr<SSLConnectionBackend> backend; + std::optional<int> fd_to_close; + bool do_not_close_socket = false; + bool get_server_cert_chain = false; + std::shared_ptr<Network::SocketBase> socket; + bool did_set_host_name = false; + bool did_handshake = false; + + ResultVal<s32> SetSocketDescriptorImpl(s32 fd) { + LOG_DEBUG(Service_SSL, "called, fd={}", fd); + ASSERT(!did_handshake); + auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u"); + ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; }); + s32 ret_fd; + // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor + if (do_not_close_socket) { + auto res = bsd->DuplicateSocketImpl(fd); + if (!res.has_value()) { + LOG_ERROR(Service_SSL, "Failed to duplicate socket with fd {}", fd); + return ResultInvalidSocket; + } + fd = *res; + fd_to_close = fd; + ret_fd = fd; + } else { + ret_fd = -1; + } + std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd); + if (!sock.has_value()) { + LOG_ERROR(Service_SSL, "invalid socket fd {}", fd); + return ResultInvalidSocket; + } + socket = std::move(*sock); + backend->SetSocket(socket); + return ret_fd; + } + + Result SetHostNameImpl(const std::string& hostname) { + LOG_DEBUG(Service_SSL, "called. hostname={}", hostname); + ASSERT(!did_handshake); + Result res = backend->SetHostName(hostname); + if (res == ResultSuccess) { + did_set_host_name = true; + } + return res; + } + + Result SetVerifyOptionImpl(u32 option) { + ASSERT(!did_handshake); + LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option); + return ResultSuccess; + } + + Result SetIoModeImpl(u32 input_mode) { + auto mode = static_cast<IoMode>(input_mode); + ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking); + ASSERT_OR_EXECUTE(socket, { return ResultNoSocket; }); + + const bool non_block = mode == IoMode::NonBlocking; + const Network::Errno error = socket->SetNonBlock(non_block); + if (error != Network::Errno::SUCCESS) { + LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block); + } + return ResultSuccess; + } + + Result SetSessionCacheModeImpl(u32 mode) { + ASSERT(!did_handshake); + LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode); + return ResultSuccess; + } + + Result DoHandshakeImpl() { + ASSERT_OR_EXECUTE(!did_handshake && socket, { return ResultNoSocket; }); + ASSERT_OR_EXECUTE_MSG( + did_set_host_name, { return ResultInternalError; }, + "Expected SetHostName before DoHandshake"); + Result res = backend->DoHandshake(); + did_handshake = res.IsSuccess(); + return res; + } + + std::vector<u8> SerializeServerCerts(const std::vector<std::vector<u8>>& certs) { + struct Header { + u64 magic; + u32 count; + u32 pad; + }; + struct EntryHeader { + u32 size; + u32 offset; + }; + if (!get_server_cert_chain) { + // Just return the first one, unencoded. + ASSERT_OR_EXECUTE_MSG( + !certs.empty(), { return {}; }, "Should be at least one server cert"); + return certs[0]; + } + std::vector<u8> ret; + Header header{0x4E4D684374726543, static_cast<u32>(certs.size()), 0}; + ret.insert(ret.end(), reinterpret_cast<u8*>(&header), reinterpret_cast<u8*>(&header + 1)); + size_t data_offset = sizeof(Header) + certs.size() * sizeof(EntryHeader); + for (auto& cert : certs) { + EntryHeader entry_header{static_cast<u32>(cert.size()), static_cast<u32>(data_offset)}; + data_offset += cert.size(); + ret.insert(ret.end(), reinterpret_cast<u8*>(&entry_header), + reinterpret_cast<u8*>(&entry_header + 1)); + } + for (auto& cert : certs) { + ret.insert(ret.end(), cert.begin(), cert.end()); + } + return ret; + } + + ResultVal<std::vector<u8>> ReadImpl(size_t size) { + ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; }); + std::vector<u8> res(size); + ResultVal<size_t> actual = backend->Read(res); + if (actual.Failed()) { + return actual.Code(); + } + res.resize(*actual); + return res; + } + + ResultVal<size_t> WriteImpl(std::span<const u8> data) { + ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; }); + return backend->Write(data); + } + + ResultVal<s32> PendingImpl() { + LOG_WARNING(Service_SSL, "(STUBBED) called."); + return 0; + } + + void SetSocketDescriptor(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const s32 fd = rp.Pop<s32>(); + const ResultVal<s32> res = SetSocketDescriptorImpl(fd); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + rb.Push<s32>(res.ValueOr(-1)); + } + + void SetHostName(HLERequestContext& ctx) { + const std::string hostname = Common::StringFromBuffer(ctx.ReadBuffer()); + const Result res = SetHostNameImpl(hostname); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void SetVerifyOption(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const u32 option = rp.Pop<u32>(); + const Result res = SetVerifyOptionImpl(option); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void SetIoMode(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const u32 mode = rp.Pop<u32>(); + const Result res = SetIoModeImpl(mode); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void DoHandshake(HLERequestContext& ctx) { + const Result res = DoHandshakeImpl(); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void DoHandshakeGetServerCert(HLERequestContext& ctx) { + struct OutputParameters { + u32 certs_size; + u32 certs_count; + }; + static_assert(sizeof(OutputParameters) == 0x8); + + const Result res = DoHandshakeImpl(); + OutputParameters out{}; + if (res == ResultSuccess) { + auto certs = backend->GetServerCerts(); + if (certs.Succeeded()) { + const std::vector<u8> certs_buf = SerializeServerCerts(*certs); + ctx.WriteBuffer(certs_buf); + out.certs_count = static_cast<u32>(certs->size()); + out.certs_size = static_cast<u32>(certs_buf.size()); + } + } + IPC::ResponseBuilder rb{ctx, 4}; + rb.Push(res); + rb.PushRaw(out); + } + + void Read(HLERequestContext& ctx) { + const ResultVal<std::vector<u8>> res = ReadImpl(ctx.GetWriteBufferSize()); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + if (res.Succeeded()) { + rb.Push(static_cast<u32>(res->size())); + ctx.WriteBuffer(*res); + } else { + rb.Push(static_cast<u32>(0)); + } + } + + void Write(HLERequestContext& ctx) { + const ResultVal<size_t> res = WriteImpl(ctx.ReadBuffer()); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + rb.Push(static_cast<u32>(res.ValueOr(0))); + } + + void Pending(HLERequestContext& ctx) { + const ResultVal<s32> res = PendingImpl(); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + rb.Push<s32>(res.ValueOr(0)); + } + + void SetSessionCacheMode(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const u32 mode = rp.Pop<u32>(); + const Result res = SetSessionCacheModeImpl(mode); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void SetOption(HLERequestContext& ctx) { + struct Parameters { + OptionType option; + s32 value; + }; + static_assert(sizeof(Parameters) == 0x8, "Parameters is an invalid size"); + + IPC::RequestParser rp{ctx}; + const auto parameters = rp.PopRaw<Parameters>(); + + switch (parameters.option) { + case OptionType::DoNotCloseSocket: + do_not_close_socket = static_cast<bool>(parameters.value); + break; + case OptionType::GetServerCertChain: + get_server_cert_chain = static_cast<bool>(parameters.value); + break; + default: + LOG_WARNING(Service_SSL, "Unknown option={}, value={}", parameters.option, + parameters.value); + } + + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(ResultSuccess); + } }; class ISslContext final : public ServiceFramework<ISslContext> { public: explicit ISslContext(Core::System& system_, SslVersion version) - : ServiceFramework{system_, "ISslContext"}, ssl_version{version} { + : ServiceFramework{system_, "ISslContext"}, ssl_version{version}, + shared_data{std::make_shared<SslContextSharedData>()} { static const FunctionInfo functions[] = { {0, &ISslContext::SetOption, "SetOption"}, {1, nullptr, "GetOption"}, {2, &ISslContext::CreateConnection, "CreateConnection"}, - {3, nullptr, "GetConnectionCount"}, + {3, &ISslContext::GetConnectionCount, "GetConnectionCount"}, {4, &ISslContext::ImportServerPki, "ImportServerPki"}, {5, &ISslContext::ImportClientPki, "ImportClientPki"}, {6, nullptr, "RemoveServerPki"}, @@ -111,6 +416,7 @@ public: private: SslVersion ssl_version; + std::shared_ptr<SslContextSharedData> shared_data; void SetOption(HLERequestContext& ctx) { struct Parameters { @@ -130,11 +436,24 @@ private: } void CreateConnection(HLERequestContext& ctx) { - LOG_WARNING(Service_SSL, "(STUBBED) called"); + LOG_WARNING(Service_SSL, "called"); + + auto backend_res = CreateSSLConnectionBackend(); IPC::ResponseBuilder rb{ctx, 2, 0, 1}; + rb.Push(backend_res.Code()); + if (backend_res.Succeeded()) { + rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data, + std::move(*backend_res)); + } + } + + void GetConnectionCount(HLERequestContext& ctx) { + LOG_DEBUG(Service_SSL, "connection_count={}", shared_data->connection_count); + + IPC::ResponseBuilder rb{ctx, 3}; rb.Push(ResultSuccess); - rb.PushIpcInterface<ISslConnection>(system, ssl_version); + rb.Push(shared_data->connection_count); } void ImportServerPki(HLERequestContext& ctx) { diff --git a/src/core/hle/service/ssl/ssl_backend.h b/src/core/hle/service/ssl/ssl_backend.h new file mode 100644 index 000000000..409f4367c --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend.h @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#pragma once + +#include <memory> +#include <span> +#include <string> +#include <vector> + +#include "common/common_types.h" + +#include "core/hle/result.h" + +namespace Network { +class SocketBase; +} + +namespace Service::SSL { + +constexpr Result ResultNoSocket{ErrorModule::SSLSrv, 103}; +constexpr Result ResultInvalidSocket{ErrorModule::SSLSrv, 106}; +constexpr Result ResultTimeout{ErrorModule::SSLSrv, 205}; +constexpr Result ResultInternalError{ErrorModule::SSLSrv, 999}; // made up + +// ResultWouldBlock is returned from Read and Write, and oddly, DoHandshake, +// with no way in the latter case to distinguish whether the client should poll +// for read or write. The one official client I've seen handles this by always +// polling for read (with a timeout). +constexpr Result ResultWouldBlock{ErrorModule::SSLSrv, 204}; + +class SSLConnectionBackend { +public: + virtual ~SSLConnectionBackend() {} + virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0; + virtual Result SetHostName(const std::string& hostname) = 0; + virtual Result DoHandshake() = 0; + virtual ResultVal<size_t> Read(std::span<u8> data) = 0; + virtual ResultVal<size_t> Write(std::span<const u8> data) = 0; + virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0; +}; + +ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend(); + +} // namespace Service::SSL diff --git a/src/core/hle/service/ssl/ssl_backend_none.cpp b/src/core/hle/service/ssl/ssl_backend_none.cpp new file mode 100644 index 000000000..2f4f23c42 --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_none.cpp @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include "common/logging/log.h" + +#include "core/hle/service/ssl/ssl_backend.h" + +namespace Service::SSL { + +ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { + LOG_ERROR(Service_SSL, + "Can't create SSL connection because no SSL backend is available on this platform"); + return ResultInternalError; +} + +} // namespace Service::SSL diff --git a/src/core/hle/service/ssl/ssl_backend_openssl.cpp b/src/core/hle/service/ssl/ssl_backend_openssl.cpp new file mode 100644 index 000000000..6ca869dbf --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_openssl.cpp @@ -0,0 +1,351 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include <mutex> + +#include <openssl/bio.h> +#include <openssl/err.h> +#include <openssl/ssl.h> +#include <openssl/x509.h> + +#include "common/fs/file.h" +#include "common/hex_util.h" +#include "common/string_util.h" + +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" + +using namespace Common::FS; + +namespace Service::SSL { + +// Import OpenSSL's `SSL` type into the namespace. This is needed because the +// namespace is also named `SSL`. +using ::SSL; + +namespace { + +std::once_flag one_time_init_flag; +bool one_time_init_success = false; + +SSL_CTX* ssl_ctx; +IOFile key_log_file; // only open if SSLKEYLOGFILE set in environment +BIO_METHOD* bio_meth; + +Result CheckOpenSSLErrors(); +void OneTimeInit(); +void OneTimeInitLogFile(); +bool OneTimeInitBIO(); + +} // namespace + +class SSLConnectionBackendOpenSSL final : public SSLConnectionBackend { +public: + Result Init() { + std::call_once(one_time_init_flag, OneTimeInit); + + if (!one_time_init_success) { + LOG_ERROR(Service_SSL, + "Can't create SSL connection because OpenSSL one-time initialization failed"); + return ResultInternalError; + } + + ssl = SSL_new(ssl_ctx); + if (!ssl) { + LOG_ERROR(Service_SSL, "SSL_new failed"); + return CheckOpenSSLErrors(); + } + + SSL_set_connect_state(ssl); + + bio = BIO_new(bio_meth); + if (!bio) { + LOG_ERROR(Service_SSL, "BIO_new failed"); + return CheckOpenSSLErrors(); + } + + BIO_set_data(bio, this); + BIO_set_init(bio, 1); + SSL_set_bio(ssl, bio, bio); + + return ResultSuccess; + } + + void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override { + socket = std::move(socket_in); + } + + Result SetHostName(const std::string& hostname) override { + if (!SSL_set1_host(ssl, hostname.c_str())) { // hostname for verification + LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname); + return CheckOpenSSLErrors(); + } + if (!SSL_set_tlsext_host_name(ssl, hostname.c_str())) { // hostname for SNI + LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname); + return CheckOpenSSLErrors(); + } + return ResultSuccess; + } + + Result DoHandshake() override { + SSL_set_verify_result(ssl, X509_V_OK); + const int ret = SSL_do_handshake(ssl); + const long verify_result = SSL_get_verify_result(ssl); + if (verify_result != X509_V_OK) { + LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}", + X509_verify_cert_error_string(verify_result)); + return CheckOpenSSLErrors(); + } + if (ret <= 0) { + const int ssl_err = SSL_get_error(ssl, ret); + if (ssl_err == SSL_ERROR_ZERO_RETURN || + (ssl_err == SSL_ERROR_SYSCALL && got_read_eof)) { + LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); + return ResultInternalError; + } + } + return HandleReturn("SSL_do_handshake", 0, ret).Code(); + } + + ResultVal<size_t> Read(std::span<u8> data) override { + size_t actual; + const int ret = SSL_read_ex(ssl, data.data(), data.size(), &actual); + return HandleReturn("SSL_read_ex", actual, ret); + } + + ResultVal<size_t> Write(std::span<const u8> data) override { + size_t actual; + const int ret = SSL_write_ex(ssl, data.data(), data.size(), &actual); + return HandleReturn("SSL_write_ex", actual, ret); + } + + ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) { + const int ssl_err = SSL_get_error(ssl, ret); + CheckOpenSSLErrors(); + switch (ssl_err) { + case SSL_ERROR_NONE: + return actual; + case SSL_ERROR_ZERO_RETURN: + LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what); + // DoHandshake special-cases this, but for Read and Write: + return size_t(0); + case SSL_ERROR_WANT_READ: + LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what); + return ResultWouldBlock; + case SSL_ERROR_WANT_WRITE: + LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what); + return ResultWouldBlock; + default: + if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) { + LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what); + return size_t(0); + } + LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err); + return ResultInternalError; + } + } + + ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { + STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl); + if (!chain) { + LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr"); + return ResultInternalError; + } + std::vector<std::vector<u8>> ret; + int count = sk_X509_num(chain); + ASSERT(count >= 0); + for (int i = 0; i < count; i++) { + X509* x509 = sk_X509_value(chain, i); + ASSERT_OR_EXECUTE(x509 != nullptr, { continue; }); + unsigned char* buf = nullptr; + int len = i2d_X509(x509, &buf); + ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; }); + ret.emplace_back(buf, buf + len); + OPENSSL_free(buf); + } + return ret; + } + + ~SSLConnectionBackendOpenSSL() { + // these are null-tolerant: + SSL_free(ssl); + BIO_free(bio); + } + + static void KeyLogCallback(const SSL* ssl, const char* line) { + std::string str(line); + str.push_back('\n'); + // Do this in a single WriteString for atomicity if multiple instances + // are running on different threads (though that can't currently + // happen). + if (key_log_file.WriteString(str) != str.size() || !key_log_file.Flush()) { + LOG_CRITICAL(Service_SSL, "Failed to write to SSLKEYLOGFILE"); + } + LOG_DEBUG(Service_SSL, "Wrote to SSLKEYLOGFILE: {}", line); + } + + static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) { + auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio)); + ASSERT_OR_EXECUTE_MSG( + self->socket, { return 0; }, "OpenSSL asked to send but we have no socket"); + BIO_clear_retry_flags(bio); + auto [actual, err] = self->socket->Send({reinterpret_cast<const u8*>(buf), len}, 0); + switch (err) { + case Network::Errno::SUCCESS: + *actual_p = actual; + return 1; + case Network::Errno::AGAIN: + BIO_set_flags(bio, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY); + return 0; + default: + LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err); + return -1; + } + } + + static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) { + auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio)); + ASSERT_OR_EXECUTE_MSG( + self->socket, { return 0; }, "OpenSSL asked to recv but we have no socket"); + BIO_clear_retry_flags(bio); + auto [actual, err] = self->socket->Recv(0, {reinterpret_cast<u8*>(buf), len}); + switch (err) { + case Network::Errno::SUCCESS: + *actual_p = actual; + if (actual == 0) { + self->got_read_eof = true; + } + return actual ? 1 : 0; + case Network::Errno::AGAIN: + BIO_set_flags(bio, BIO_FLAGS_READ | BIO_FLAGS_SHOULD_RETRY); + return 0; + default: + LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); + return -1; + } + } + + static long CtrlCallback(BIO* bio, int cmd, long l_arg, void* p_arg) { + switch (cmd) { + case BIO_CTRL_FLUSH: + // Nothing to flush. + return 1; + case BIO_CTRL_PUSH: + case BIO_CTRL_POP: +#ifdef BIO_CTRL_GET_KTLS_SEND + case BIO_CTRL_GET_KTLS_SEND: + case BIO_CTRL_GET_KTLS_RECV: +#endif + // We don't support these operations, but don't bother logging them + // as they're nothing unusual. + return 0; + default: + LOG_DEBUG(Service_SSL, "OpenSSL BIO got ctrl({}, {}, {})", cmd, l_arg, p_arg); + return 0; + } + } + + SSL* ssl = nullptr; + BIO* bio = nullptr; + bool got_read_eof = false; + + std::shared_ptr<Network::SocketBase> socket; +}; + +ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { + auto conn = std::make_unique<SSLConnectionBackendOpenSSL>(); + const Result res = conn->Init(); + if (res.IsFailure()) { + return res; + } + return conn; +} + +namespace { + +Result CheckOpenSSLErrors() { + unsigned long rc; + const char* file; + int line; + const char* func; + const char* data; + int flags; +#if OPENSSL_VERSION_NUMBER >= 0x30000000L + while ((rc = ERR_get_error_all(&file, &line, &func, &data, &flags))) +#else + // Can't get function names from OpenSSL on this version, so use mine: + func = __func__; + while ((rc = ERR_get_error_line_data(&file, &line, &data, &flags))) +#endif + { + std::string msg; + msg.resize(1024, '\0'); + ERR_error_string_n(rc, msg.data(), msg.size()); + msg.resize(strlen(msg.data()), '\0'); + if (flags & ERR_TXT_STRING) { + msg.append(" | "); + msg.append(data); + } + Common::Log::FmtLogMessage(Common::Log::Class::Service_SSL, Common::Log::Level::Error, + Common::Log::TrimSourcePath(file), line, func, "OpenSSL: {}", + msg); + } + return ResultInternalError; +} + +void OneTimeInit() { + ssl_ctx = SSL_CTX_new(TLS_client_method()); + if (!ssl_ctx) { + LOG_ERROR(Service_SSL, "SSL_CTX_new failed"); + CheckOpenSSLErrors(); + return; + } + + SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr); + + if (!SSL_CTX_set_default_verify_paths(ssl_ctx)) { + LOG_ERROR(Service_SSL, "SSL_CTX_set_default_verify_paths failed"); + CheckOpenSSLErrors(); + return; + } + + OneTimeInitLogFile(); + + if (!OneTimeInitBIO()) { + return; + } + + one_time_init_success = true; +} + +void OneTimeInitLogFile() { + const char* logfile = getenv("SSLKEYLOGFILE"); + if (logfile) { + key_log_file.Open(logfile, FileAccessMode::Append, FileType::TextFile, + FileShareFlag::ShareWriteOnly); + if (key_log_file.IsOpen()) { + SSL_CTX_set_keylog_callback(ssl_ctx, &SSLConnectionBackendOpenSSL::KeyLogCallback); + } else { + LOG_CRITICAL(Service_SSL, + "SSLKEYLOGFILE was set but file could not be opened; not logging keys!"); + } + } +} + +bool OneTimeInitBIO() { + bio_meth = + BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SSLConnectionBackendOpenSSL"); + if (!bio_meth || + !BIO_meth_set_write_ex(bio_meth, &SSLConnectionBackendOpenSSL::WriteCallback) || + !BIO_meth_set_read_ex(bio_meth, &SSLConnectionBackendOpenSSL::ReadCallback) || + !BIO_meth_set_ctrl(bio_meth, &SSLConnectionBackendOpenSSL::CtrlCallback)) { + LOG_ERROR(Service_SSL, "Failed to create BIO_METHOD"); + return false; + } + return true; +} + +} // namespace + +} // namespace Service::SSL diff --git a/src/core/hle/service/ssl/ssl_backend_schannel.cpp b/src/core/hle/service/ssl/ssl_backend_schannel.cpp new file mode 100644 index 000000000..d8074339a --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_schannel.cpp @@ -0,0 +1,544 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include <mutex> + +#include "common/error.h" +#include "common/fs/file.h" +#include "common/hex_util.h" +#include "common/string_util.h" + +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" + +namespace { + +// These includes are inside the namespace to avoid a conflict on MinGW where +// the headers define an enum containing Network and Service as enumerators +// (which clash with the correspondingly named namespaces). +#define SECURITY_WIN32 +#include <schnlsp.h> +#include <security.h> +#include <wincrypt.h> + +std::once_flag one_time_init_flag; +bool one_time_init_success = false; + +SCHANNEL_CRED schannel_cred{}; +CredHandle cred_handle; + +static void OneTimeInit() { + schannel_cred.dwVersion = SCHANNEL_CRED_VERSION; + schannel_cred.dwFlags = + SCH_USE_STRONG_CRYPTO | // don't allow insecure protocols + SCH_CRED_AUTO_CRED_VALIDATION | // validate certs + SCH_CRED_NO_DEFAULT_CREDS; // don't automatically present a client certificate + // ^ I'm assuming that nobody would want to connect Yuzu to a + // service that requires some OS-provided corporate client + // certificate, and presenting one to some arbitrary server + // might be a privacy concern? Who knows, though. + + const SECURITY_STATUS ret = + AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND, + nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr); + if (ret != SEC_E_OK) { + // SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString. + LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}", + Common::NativeErrorToString(ret)); + return; + } + + if (getenv("SSLKEYLOGFILE")) { + LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but Schannel does not support exporting " + "keys; not logging keys!"); + // Not fatal. + } + + one_time_init_success = true; +} + +} // namespace + +namespace Service::SSL { + +class SSLConnectionBackendSchannel final : public SSLConnectionBackend { +public: + Result Init() { + std::call_once(one_time_init_flag, OneTimeInit); + + if (!one_time_init_success) { + LOG_ERROR( + Service_SSL, + "Can't create SSL connection because Schannel one-time initialization failed"); + return ResultInternalError; + } + + return ResultSuccess; + } + + void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override { + socket = std::move(socket_in); + } + + Result SetHostName(const std::string& hostname_in) override { + hostname = hostname_in; + return ResultSuccess; + } + + Result DoHandshake() override { + while (1) { + Result r; + switch (handshake_state) { + case HandshakeState::Initial: + if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || + (r = CallInitializeSecurityContext()) != ResultSuccess) { + return r; + } + // CallInitializeSecurityContext updated `handshake_state`. + continue; + case HandshakeState::ContinueNeeded: + case HandshakeState::IncompleteMessage: + if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || + (r = FillCiphertextReadBuf()) != ResultSuccess) { + return r; + } + if (ciphertext_read_buf.empty()) { + LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); + return ResultInternalError; + } + if ((r = CallInitializeSecurityContext()) != ResultSuccess) { + return r; + } + // CallInitializeSecurityContext updated `handshake_state`. + continue; + case HandshakeState::DoneAfterFlush: + if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) { + return r; + } + handshake_state = HandshakeState::Connected; + return ResultSuccess; + case HandshakeState::Connected: + LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook"); + return ResultInternalError; + case HandshakeState::Error: + return ResultInternalError; + } + } + } + + Result FillCiphertextReadBuf() { + const size_t fill_size = read_buf_fill_size ? read_buf_fill_size : 4096; + read_buf_fill_size = 0; + // This unnecessarily zeroes the buffer; oh well. + const size_t offset = ciphertext_read_buf.size(); + ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; }); + ciphertext_read_buf.resize(offset + fill_size, 0); + const auto read_span = std::span(ciphertext_read_buf).subspan(offset, fill_size); + const auto [actual, err] = socket->Recv(0, read_span); + switch (err) { + case Network::Errno::SUCCESS: + ASSERT(static_cast<size_t>(actual) <= fill_size); + ciphertext_read_buf.resize(offset + actual); + return ResultSuccess; + case Network::Errno::AGAIN: + ciphertext_read_buf.resize(offset); + return ResultWouldBlock; + default: + ciphertext_read_buf.resize(offset); + LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); + return ResultInternalError; + } + } + + // Returns success if the write buffer has been completely emptied. + Result FlushCiphertextWriteBuf() { + while (!ciphertext_write_buf.empty()) { + const auto [actual, err] = socket->Send(ciphertext_write_buf, 0); + switch (err) { + case Network::Errno::SUCCESS: + ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf.size()); + ciphertext_write_buf.erase(ciphertext_write_buf.begin(), + ciphertext_write_buf.begin() + actual); + break; + case Network::Errno::AGAIN: + return ResultWouldBlock; + default: + LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err); + return ResultInternalError; + } + } + return ResultSuccess; + } + + Result CallInitializeSecurityContext() { + const unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | + ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT | + ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM | + ISC_REQ_USE_SUPPLIED_CREDS; + unsigned long attr; + // https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel + std::array<SecBuffer, 2> input_buffers{{ + // only used if `initial_call_done` + { + // [0] + .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()), + .BufferType = SECBUFFER_TOKEN, + .pvBuffer = ciphertext_read_buf.data(), + }, + { + // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is + // returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the + // whole buffer wasn't used) + .cbBuffer = 0, + .BufferType = SECBUFFER_EMPTY, + .pvBuffer = nullptr, + }, + }}; + std::array<SecBuffer, 2> output_buffers{{ + { + .cbBuffer = 0, + .BufferType = SECBUFFER_TOKEN, + .pvBuffer = nullptr, + }, // [0] + { + .cbBuffer = 0, + .BufferType = SECBUFFER_ALERT, + .pvBuffer = nullptr, + }, // [1] + }}; + SecBufferDesc input_desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast<unsigned long>(input_buffers.size()), + .pBuffers = input_buffers.data(), + }; + SecBufferDesc output_desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast<unsigned long>(output_buffers.size()), + .pBuffers = output_buffers.data(), + }; + ASSERT_OR_EXECUTE_MSG( + input_buffers[0].cbBuffer == ciphertext_read_buf.size(), + { return ResultInternalError; }, "read buffer too large"); + + bool initial_call_done = handshake_state != HandshakeState::Initial; + if (initial_call_done) { + LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext", + ciphertext_read_buf.size()); + } + + const SECURITY_STATUS ret = + InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt : nullptr, + // Caller ensured we have set a hostname: + const_cast<char*>(hostname.value().c_str()), req, + 0, // Reserved1 + 0, // TargetDataRep not used with Schannel + initial_call_done ? &input_desc : nullptr, + 0, // Reserved2 + initial_call_done ? nullptr : &ctxt, &output_desc, &attr, + nullptr); // ptsExpiry + + if (output_buffers[0].pvBuffer) { + const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer), + output_buffers[0].cbBuffer); + ciphertext_write_buf.insert(ciphertext_write_buf.end(), span.begin(), span.end()); + FreeContextBuffer(output_buffers[0].pvBuffer); + } + + if (output_buffers[1].pvBuffer) { + const std::span span(static_cast<u8*>(output_buffers[1].pvBuffer), + output_buffers[1].cbBuffer); + // The documentation doesn't explain what format this data is in. + LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(), + Common::HexToString(span)); + } + + switch (ret) { + case SEC_I_CONTINUE_NEEDED: + LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED"); + if (input_buffers[1].BufferType == SECBUFFER_EXTRA) { + LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer); + ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf.size()); + ciphertext_read_buf.erase(ciphertext_read_buf.begin(), + ciphertext_read_buf.end() - input_buffers[1].cbBuffer); + } else { + ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY); + ciphertext_read_buf.clear(); + } + handshake_state = HandshakeState::ContinueNeeded; + return ResultSuccess; + case SEC_E_INCOMPLETE_MESSAGE: + LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE"); + ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING); + read_buf_fill_size = input_buffers[1].cbBuffer; + handshake_state = HandshakeState::IncompleteMessage; + return ResultSuccess; + case SEC_E_OK: + LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK"); + ciphertext_read_buf.clear(); + handshake_state = HandshakeState::DoneAfterFlush; + return GrabStreamSizes(); + default: + LOG_ERROR(Service_SSL, + "InitializeSecurityContext failed (probably certificate/protocol issue): {}", + Common::NativeErrorToString(ret)); + handshake_state = HandshakeState::Error; + return ResultInternalError; + } + } + + Result GrabStreamSizes() { + const SECURITY_STATUS ret = + QueryContextAttributes(&ctxt, SECPKG_ATTR_STREAM_SIZES, &stream_sizes); + if (ret != SEC_E_OK) { + LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}", + Common::NativeErrorToString(ret)); + handshake_state = HandshakeState::Error; + return ResultInternalError; + } + return ResultSuccess; + } + + ResultVal<size_t> Read(std::span<u8> data) override { + if (handshake_state != HandshakeState::Connected) { + LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake"); + return ResultInternalError; + } + if (data.size() == 0 || got_read_eof) { + return size_t(0); + } + while (1) { + if (!cleartext_read_buf.empty()) { + const size_t read_size = std::min(cleartext_read_buf.size(), data.size()); + std::memcpy(data.data(), cleartext_read_buf.data(), read_size); + cleartext_read_buf.erase(cleartext_read_buf.begin(), + cleartext_read_buf.begin() + read_size); + return read_size; + } + if (!ciphertext_read_buf.empty()) { + SecBuffer empty{ + .cbBuffer = 0, + .BufferType = SECBUFFER_EMPTY, + .pvBuffer = nullptr, + }; + std::array<SecBuffer, 5> buffers{{ + { + .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()), + .BufferType = SECBUFFER_DATA, + .pvBuffer = ciphertext_read_buf.data(), + }, + empty, + empty, + empty, + }}; + ASSERT_OR_EXECUTE_MSG( + buffers[0].cbBuffer == ciphertext_read_buf.size(), + { return ResultInternalError; }, "read buffer too large"); + SecBufferDesc desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast<unsigned long>(buffers.size()), + .pBuffers = buffers.data(), + }; + SECURITY_STATUS ret = + DecryptMessage(&ctxt, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr); + switch (ret) { + case SEC_E_OK: + ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER, + { return ResultInternalError; }); + ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA, + { return ResultInternalError; }); + ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER, + { return ResultInternalError; }); + cleartext_read_buf.assign(static_cast<u8*>(buffers[1].pvBuffer), + static_cast<u8*>(buffers[1].pvBuffer) + + buffers[1].cbBuffer); + if (buffers[3].BufferType == SECBUFFER_EXTRA) { + ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf.size()); + ciphertext_read_buf.erase(ciphertext_read_buf.begin(), + ciphertext_read_buf.end() - buffers[3].cbBuffer); + } else { + ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY); + ciphertext_read_buf.clear(); + } + continue; + case SEC_E_INCOMPLETE_MESSAGE: + break; + case SEC_I_CONTEXT_EXPIRED: + // Server hung up by sending close_notify. + got_read_eof = true; + return size_t(0); + default: + LOG_ERROR(Service_SSL, "DecryptMessage failed: {}", + Common::NativeErrorToString(ret)); + return ResultInternalError; + } + } + const Result r = FillCiphertextReadBuf(); + if (r != ResultSuccess) { + return r; + } + if (ciphertext_read_buf.empty()) { + got_read_eof = true; + return size_t(0); + } + } + } + + ResultVal<size_t> Write(std::span<const u8> data) override { + if (handshake_state != HandshakeState::Connected) { + LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake"); + return ResultInternalError; + } + if (data.size() == 0) { + return size_t(0); + } + data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage)); + if (!cleartext_write_buf.empty()) { + // Already in the middle of a write. It wouldn't make sense to not + // finish sending the entire buffer since TLS has + // header/MAC/padding/etc. + if (data.size() != cleartext_write_buf.size() || + std::memcmp(data.data(), cleartext_write_buf.data(), data.size())) { + LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer"); + return ResultInternalError; + } + return WriteAlreadyEncryptedData(); + } else { + cleartext_write_buf.assign(data.begin(), data.end()); + } + + std::vector<u8> header_buf(stream_sizes.cbHeader, 0); + std::vector<u8> tmp_data_buf = cleartext_write_buf; + std::vector<u8> trailer_buf(stream_sizes.cbTrailer, 0); + + std::array<SecBuffer, 3> buffers{{ + { + .cbBuffer = stream_sizes.cbHeader, + .BufferType = SECBUFFER_STREAM_HEADER, + .pvBuffer = header_buf.data(), + }, + { + .cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()), + .BufferType = SECBUFFER_DATA, + .pvBuffer = tmp_data_buf.data(), + }, + { + .cbBuffer = stream_sizes.cbTrailer, + .BufferType = SECBUFFER_STREAM_TRAILER, + .pvBuffer = trailer_buf.data(), + }, + }}; + ASSERT_OR_EXECUTE_MSG( + buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; }, + "temp buffer too large"); + SecBufferDesc desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast<unsigned long>(buffers.size()), + .pBuffers = buffers.data(), + }; + + const SECURITY_STATUS ret = EncryptMessage(&ctxt, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0); + if (ret != SEC_E_OK) { + LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret)); + return ResultInternalError; + } + ciphertext_write_buf.insert(ciphertext_write_buf.end(), header_buf.begin(), + header_buf.end()); + ciphertext_write_buf.insert(ciphertext_write_buf.end(), tmp_data_buf.begin(), + tmp_data_buf.end()); + ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(), + trailer_buf.end()); + return WriteAlreadyEncryptedData(); + } + + ResultVal<size_t> WriteAlreadyEncryptedData() { + const Result r = FlushCiphertextWriteBuf(); + if (r != ResultSuccess) { + return r; + } + // write buf is empty + const size_t cleartext_bytes_written = cleartext_write_buf.size(); + cleartext_write_buf.clear(); + return cleartext_bytes_written; + } + + ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { + PCCERT_CONTEXT returned_cert = nullptr; + const SECURITY_STATUS ret = + QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert); + if (ret != SEC_E_OK) { + LOG_ERROR(Service_SSL, + "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}", + Common::NativeErrorToString(ret)); + return ResultInternalError; + } + PCCERT_CONTEXT some_cert = nullptr; + std::vector<std::vector<u8>> certs; + while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) { + certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded), + static_cast<u8*>(some_cert->pbCertEncoded) + + some_cert->cbCertEncoded); + } + std::reverse(certs.begin(), + certs.end()); // Windows returns certs in reverse order from what we want + CertFreeCertificateContext(returned_cert); + return certs; + } + + ~SSLConnectionBackendSchannel() { + if (handshake_state != HandshakeState::Initial) { + DeleteSecurityContext(&ctxt); + } + } + + enum class HandshakeState { + // Haven't called anything yet. + Initial, + // `SEC_I_CONTINUE_NEEDED` was returned by + // `InitializeSecurityContext`; must finish sending data (if any) in + // the write buffer, then read at least one byte before calling + // `InitializeSecurityContext` again. + ContinueNeeded, + // `SEC_E_INCOMPLETE_MESSAGE` was returned by + // `InitializeSecurityContext`; hopefully the write buffer is empty; + // must read at least one byte before calling + // `InitializeSecurityContext` again. + IncompleteMessage, + // `SEC_E_OK` was returned by `InitializeSecurityContext`; must + // finish sending data in the write buffer before having `DoHandshake` + // report success. + DoneAfterFlush, + // We finished the above and are now connected. At this point, writing + // and reading are separate 'state machines' represented by the + // nonemptiness of the ciphertext and cleartext read and write buffers. + Connected, + // Another error was returned and we shouldn't allow initialization + // to continue. + Error, + } handshake_state = HandshakeState::Initial; + + CtxtHandle ctxt; + SecPkgContext_StreamSizes stream_sizes; + + std::shared_ptr<Network::SocketBase> socket; + std::optional<std::string> hostname; + + std::vector<u8> ciphertext_read_buf; + std::vector<u8> ciphertext_write_buf; + std::vector<u8> cleartext_read_buf; + std::vector<u8> cleartext_write_buf; + + bool got_read_eof = false; + size_t read_buf_fill_size = 0; +}; + +ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { + auto conn = std::make_unique<SSLConnectionBackendSchannel>(); + const Result res = conn->Init(); + if (res.IsFailure()) { + return res; + } + return conn; +} + +} // namespace Service::SSL diff --git a/src/core/hle/service/ssl/ssl_backend_securetransport.cpp b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp new file mode 100644 index 000000000..b3083cbad --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp @@ -0,0 +1,222 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include <mutex> + +// SecureTransport has been deprecated in its entirety in favor of +// Network.framework, but that does not allow layering TLS on top of an +// arbitrary socket. +#if defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#include <Security/SecureTransport.h> +#pragma GCC diagnostic pop +#endif + +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" + +namespace { + +template <typename T> +struct CFReleaser { + T ptr; + + YUZU_NON_COPYABLE(CFReleaser); + constexpr CFReleaser() : ptr(nullptr) {} + constexpr CFReleaser(T ptr) : ptr(ptr) {} + constexpr operator T() { + return ptr; + } + ~CFReleaser() { + if (ptr) { + CFRelease(ptr); + } + } +}; + +std::string CFStringToString(CFStringRef cfstr) { + CFReleaser<CFDataRef> cfdata( + CFStringCreateExternalRepresentation(nullptr, cfstr, kCFStringEncodingUTF8, 0)); + ASSERT_OR_EXECUTE(cfdata, { return "???"; }); + return std::string(reinterpret_cast<const char*>(CFDataGetBytePtr(cfdata)), + CFDataGetLength(cfdata)); +} + +std::string OSStatusToString(OSStatus status) { + CFReleaser<CFStringRef> cfstr(SecCopyErrorMessageString(status, nullptr)); + if (!cfstr) { + return "[unknown error]"; + } + return CFStringToString(cfstr); +} + +} // namespace + +namespace Service::SSL { + +class SSLConnectionBackendSecureTransport final : public SSLConnectionBackend { +public: + Result Init() { + static std::once_flag once_flag; + std::call_once(once_flag, []() { + if (getenv("SSLKEYLOGFILE")) { + LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but SecureTransport does not " + "support exporting keys; not logging keys!"); + // Not fatal. + } + }); + + context.ptr = SSLCreateContext(nullptr, kSSLClientSide, kSSLStreamType); + if (!context) { + LOG_ERROR(Service_SSL, "SSLCreateContext failed"); + return ResultInternalError; + } + + OSStatus status; + if ((status = SSLSetIOFuncs(context, ReadCallback, WriteCallback)) || + (status = SSLSetConnection(context, this))) { + LOG_ERROR(Service_SSL, "SSLContext initialization failed: {}", + OSStatusToString(status)); + return ResultInternalError; + } + + return ResultSuccess; + } + + void SetSocket(std::shared_ptr<Network::SocketBase> in_socket) override { + socket = std::move(in_socket); + } + + Result SetHostName(const std::string& hostname) override { + OSStatus status = SSLSetPeerDomainName(context, hostname.c_str(), hostname.size()); + if (status) { + LOG_ERROR(Service_SSL, "SSLSetPeerDomainName failed: {}", OSStatusToString(status)); + return ResultInternalError; + } + return ResultSuccess; + } + + Result DoHandshake() override { + OSStatus status = SSLHandshake(context); + return HandleReturn("SSLHandshake", 0, status).Code(); + } + + ResultVal<size_t> Read(std::span<u8> data) override { + size_t actual; + OSStatus status = SSLRead(context, data.data(), data.size(), &actual); + ; + return HandleReturn("SSLRead", actual, status); + } + + ResultVal<size_t> Write(std::span<const u8> data) override { + size_t actual; + OSStatus status = SSLWrite(context, data.data(), data.size(), &actual); + ; + return HandleReturn("SSLWrite", actual, status); + } + + ResultVal<size_t> HandleReturn(const char* what, size_t actual, OSStatus status) { + switch (status) { + case 0: + return actual; + case errSSLWouldBlock: + return ResultWouldBlock; + default: { + std::string reason; + if (got_read_eof) { + reason = "server hung up"; + } else { + reason = OSStatusToString(status); + } + LOG_ERROR(Service_SSL, "{} failed: {}", what, reason); + return ResultInternalError; + } + } + } + + ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { + CFReleaser<SecTrustRef> trust; + OSStatus status = SSLCopyPeerTrust(context, &trust.ptr); + if (status) { + LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status)); + return ResultInternalError; + } + std::vector<std::vector<u8>> ret; + for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) { + SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i); + CFReleaser<CFDataRef> data(SecCertificateCopyData(cert)); + ASSERT_OR_EXECUTE(data, { return ResultInternalError; }); + const u8* ptr = CFDataGetBytePtr(data); + ret.emplace_back(ptr, ptr + CFDataGetLength(data)); + } + return ret; + } + + static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) { + return ReadOrWriteCallback(connection, data, dataLength, true); + } + + static OSStatus WriteCallback(SSLConnectionRef connection, const void* data, + size_t* dataLength) { + return ReadOrWriteCallback(connection, const_cast<void*>(data), dataLength, false); + } + + static OSStatus ReadOrWriteCallback(SSLConnectionRef connection, void* data, size_t* dataLength, + bool is_read) { + auto self = + static_cast<SSLConnectionBackendSecureTransport*>(const_cast<void*>(connection)); + ASSERT_OR_EXECUTE_MSG( + self->socket, { return 0; }, "SecureTransport asked to {} but we have no socket", + is_read ? "read" : "write"); + + // SecureTransport callbacks (unlike OpenSSL BIO callbacks) are + // expected to read/write the full requested dataLength or return an + // error, so we have to add a loop ourselves. + size_t requested_len = *dataLength; + size_t offset = 0; + while (offset < requested_len) { + std::span cur(reinterpret_cast<u8*>(data) + offset, requested_len - offset); + auto [actual, err] = is_read ? self->socket->Recv(0, cur) : self->socket->Send(cur, 0); + LOG_CRITICAL(Service_SSL, "op={}, offset={} actual={}/{} err={}", is_read, offset, + actual, cur.size(), static_cast<s32>(err)); + switch (err) { + case Network::Errno::SUCCESS: + offset += actual; + if (actual == 0) { + ASSERT(is_read); + self->got_read_eof = true; + return errSecEndOfData; + } + break; + case Network::Errno::AGAIN: + *dataLength = offset; + return errSSLWouldBlock; + default: + LOG_ERROR(Service_SSL, "Socket {} returned Network::Errno {}", + is_read ? "recv" : "send", err); + return errSecIO; + } + } + ASSERT(offset == requested_len); + return 0; + } + +private: + CFReleaser<SSLContextRef> context = nullptr; + bool got_read_eof = false; + + std::shared_ptr<Network::SocketBase> socket; +}; + +ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { + auto conn = std::make_unique<SSLConnectionBackendSecureTransport>(); + const Result res = conn->Init(); + if (res.IsFailure()) { + return res; + } + return conn; +} + +} // namespace Service::SSL |