summaryrefslogtreecommitdiffstats
path: root/src/core/hle/service/ssl
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/hle/service/ssl')
-rw-r--r--src/core/hle/service/ssl/ssl.cpp353
-rw-r--r--src/core/hle/service/ssl/ssl_backend.h45
-rw-r--r--src/core/hle/service/ssl/ssl_backend_none.cpp16
-rw-r--r--src/core/hle/service/ssl/ssl_backend_openssl.cpp351
-rw-r--r--src/core/hle/service/ssl/ssl_backend_schannel.cpp544
-rw-r--r--src/core/hle/service/ssl/ssl_backend_securetransport.cpp222
6 files changed, 1514 insertions, 17 deletions
diff --git a/src/core/hle/service/ssl/ssl.cpp b/src/core/hle/service/ssl/ssl.cpp
index 2b99dd7ac..9c96f9763 100644
--- a/src/core/hle/service/ssl/ssl.cpp
+++ b/src/core/hle/service/ssl/ssl.cpp
@@ -1,10 +1,18 @@
// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later
+#include "common/string_util.h"
+
+#include "core/core.h"
#include "core/hle/service/ipc_helpers.h"
#include "core/hle/service/server_manager.h"
#include "core/hle/service/service.h"
+#include "core/hle/service/sm/sm.h"
+#include "core/hle/service/sockets/bsd.h"
#include "core/hle/service/ssl/ssl.h"
+#include "core/hle/service/ssl/ssl_backend.h"
+#include "core/internal_network/network.h"
+#include "core/internal_network/sockets.h"
namespace Service::SSL {
@@ -20,6 +28,18 @@ enum class ContextOption : u32 {
CrlImportDateCheckEnable = 1,
};
+// This is nn::ssl::Connection::IoMode
+enum class IoMode : u32 {
+ Blocking = 1,
+ NonBlocking = 2,
+};
+
+// This is nn::ssl::sf::OptionType
+enum class OptionType : u32 {
+ DoNotCloseSocket = 0,
+ GetServerCertChain = 1,
+};
+
// This is nn::ssl::sf::SslVersion
struct SslVersion {
union {
@@ -34,35 +54,42 @@ struct SslVersion {
};
};
+struct SslContextSharedData {
+ u32 connection_count = 0;
+};
+
class ISslConnection final : public ServiceFramework<ISslConnection> {
public:
- explicit ISslConnection(Core::System& system_, SslVersion version)
- : ServiceFramework{system_, "ISslConnection"}, ssl_version{version} {
+ explicit ISslConnection(Core::System& system_in, SslVersion ssl_version_in,
+ std::shared_ptr<SslContextSharedData>& shared_data_in,
+ std::unique_ptr<SSLConnectionBackend>&& backend_in)
+ : ServiceFramework{system_in, "ISslConnection"}, ssl_version{ssl_version_in},
+ shared_data{shared_data_in}, backend{std::move(backend_in)} {
// clang-format off
static const FunctionInfo functions[] = {
- {0, nullptr, "SetSocketDescriptor"},
- {1, nullptr, "SetHostName"},
- {2, nullptr, "SetVerifyOption"},
- {3, nullptr, "SetIoMode"},
+ {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"},
+ {1, &ISslConnection::SetHostName, "SetHostName"},
+ {2, &ISslConnection::SetVerifyOption, "SetVerifyOption"},
+ {3, &ISslConnection::SetIoMode, "SetIoMode"},
{4, nullptr, "GetSocketDescriptor"},
{5, nullptr, "GetHostName"},
{6, nullptr, "GetVerifyOption"},
{7, nullptr, "GetIoMode"},
- {8, nullptr, "DoHandshake"},
- {9, nullptr, "DoHandshakeGetServerCert"},
- {10, nullptr, "Read"},
- {11, nullptr, "Write"},
- {12, nullptr, "Pending"},
+ {8, &ISslConnection::DoHandshake, "DoHandshake"},
+ {9, &ISslConnection::DoHandshakeGetServerCert, "DoHandshakeGetServerCert"},
+ {10, &ISslConnection::Read, "Read"},
+ {11, &ISslConnection::Write, "Write"},
+ {12, &ISslConnection::Pending, "Pending"},
{13, nullptr, "Peek"},
{14, nullptr, "Poll"},
{15, nullptr, "GetVerifyCertError"},
{16, nullptr, "GetNeededServerCertBufferSize"},
- {17, nullptr, "SetSessionCacheMode"},
+ {17, &ISslConnection::SetSessionCacheMode, "SetSessionCacheMode"},
{18, nullptr, "GetSessionCacheMode"},
{19, nullptr, "FlushSessionCache"},
{20, nullptr, "SetRenegotiationMode"},
{21, nullptr, "GetRenegotiationMode"},
- {22, nullptr, "SetOption"},
+ {22, &ISslConnection::SetOption, "SetOption"},
{23, nullptr, "GetOption"},
{24, nullptr, "GetVerifyCertErrors"},
{25, nullptr, "GetCipherInfo"},
@@ -80,21 +107,299 @@ public:
// clang-format on
RegisterHandlers(functions);
+
+ shared_data->connection_count++;
+ }
+
+ ~ISslConnection() {
+ shared_data->connection_count--;
+ if (fd_to_close.has_value()) {
+ const s32 fd = *fd_to_close;
+ if (!do_not_close_socket) {
+ LOG_ERROR(Service_SSL,
+ "do_not_close_socket was changed after setting socket; is this right?");
+ } else {
+ auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
+ if (bsd) {
+ auto err = bsd->CloseImpl(fd);
+ if (err != Service::Sockets::Errno::SUCCESS) {
+ LOG_ERROR(Service_SSL, "Failed to close duplicated socket: {}", err);
+ }
+ }
+ }
+ }
}
private:
SslVersion ssl_version;
+ std::shared_ptr<SslContextSharedData> shared_data;
+ std::unique_ptr<SSLConnectionBackend> backend;
+ std::optional<int> fd_to_close;
+ bool do_not_close_socket = false;
+ bool get_server_cert_chain = false;
+ std::shared_ptr<Network::SocketBase> socket;
+ bool did_set_host_name = false;
+ bool did_handshake = false;
+
+ ResultVal<s32> SetSocketDescriptorImpl(s32 fd) {
+ LOG_DEBUG(Service_SSL, "called, fd={}", fd);
+ ASSERT(!did_handshake);
+ auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
+ ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; });
+ s32 ret_fd;
+ // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor
+ if (do_not_close_socket) {
+ auto res = bsd->DuplicateSocketImpl(fd);
+ if (!res.has_value()) {
+ LOG_ERROR(Service_SSL, "Failed to duplicate socket with fd {}", fd);
+ return ResultInvalidSocket;
+ }
+ fd = *res;
+ fd_to_close = fd;
+ ret_fd = fd;
+ } else {
+ ret_fd = -1;
+ }
+ std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd);
+ if (!sock.has_value()) {
+ LOG_ERROR(Service_SSL, "invalid socket fd {}", fd);
+ return ResultInvalidSocket;
+ }
+ socket = std::move(*sock);
+ backend->SetSocket(socket);
+ return ret_fd;
+ }
+
+ Result SetHostNameImpl(const std::string& hostname) {
+ LOG_DEBUG(Service_SSL, "called. hostname={}", hostname);
+ ASSERT(!did_handshake);
+ Result res = backend->SetHostName(hostname);
+ if (res == ResultSuccess) {
+ did_set_host_name = true;
+ }
+ return res;
+ }
+
+ Result SetVerifyOptionImpl(u32 option) {
+ ASSERT(!did_handshake);
+ LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option);
+ return ResultSuccess;
+ }
+
+ Result SetIoModeImpl(u32 input_mode) {
+ auto mode = static_cast<IoMode>(input_mode);
+ ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking);
+ ASSERT_OR_EXECUTE(socket, { return ResultNoSocket; });
+
+ const bool non_block = mode == IoMode::NonBlocking;
+ const Network::Errno error = socket->SetNonBlock(non_block);
+ if (error != Network::Errno::SUCCESS) {
+ LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block);
+ }
+ return ResultSuccess;
+ }
+
+ Result SetSessionCacheModeImpl(u32 mode) {
+ ASSERT(!did_handshake);
+ LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode);
+ return ResultSuccess;
+ }
+
+ Result DoHandshakeImpl() {
+ ASSERT_OR_EXECUTE(!did_handshake && socket, { return ResultNoSocket; });
+ ASSERT_OR_EXECUTE_MSG(
+ did_set_host_name, { return ResultInternalError; },
+ "Expected SetHostName before DoHandshake");
+ Result res = backend->DoHandshake();
+ did_handshake = res.IsSuccess();
+ return res;
+ }
+
+ std::vector<u8> SerializeServerCerts(const std::vector<std::vector<u8>>& certs) {
+ struct Header {
+ u64 magic;
+ u32 count;
+ u32 pad;
+ };
+ struct EntryHeader {
+ u32 size;
+ u32 offset;
+ };
+ if (!get_server_cert_chain) {
+ // Just return the first one, unencoded.
+ ASSERT_OR_EXECUTE_MSG(
+ !certs.empty(), { return {}; }, "Should be at least one server cert");
+ return certs[0];
+ }
+ std::vector<u8> ret;
+ Header header{0x4E4D684374726543, static_cast<u32>(certs.size()), 0};
+ ret.insert(ret.end(), reinterpret_cast<u8*>(&header), reinterpret_cast<u8*>(&header + 1));
+ size_t data_offset = sizeof(Header) + certs.size() * sizeof(EntryHeader);
+ for (auto& cert : certs) {
+ EntryHeader entry_header{static_cast<u32>(cert.size()), static_cast<u32>(data_offset)};
+ data_offset += cert.size();
+ ret.insert(ret.end(), reinterpret_cast<u8*>(&entry_header),
+ reinterpret_cast<u8*>(&entry_header + 1));
+ }
+ for (auto& cert : certs) {
+ ret.insert(ret.end(), cert.begin(), cert.end());
+ }
+ return ret;
+ }
+
+ ResultVal<std::vector<u8>> ReadImpl(size_t size) {
+ ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
+ std::vector<u8> res(size);
+ ResultVal<size_t> actual = backend->Read(res);
+ if (actual.Failed()) {
+ return actual.Code();
+ }
+ res.resize(*actual);
+ return res;
+ }
+
+ ResultVal<size_t> WriteImpl(std::span<const u8> data) {
+ ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
+ return backend->Write(data);
+ }
+
+ ResultVal<s32> PendingImpl() {
+ LOG_WARNING(Service_SSL, "(STUBBED) called.");
+ return 0;
+ }
+
+ void SetSocketDescriptor(HLERequestContext& ctx) {
+ IPC::RequestParser rp{ctx};
+ const s32 fd = rp.Pop<s32>();
+ const ResultVal<s32> res = SetSocketDescriptorImpl(fd);
+ IPC::ResponseBuilder rb{ctx, 3};
+ rb.Push(res.Code());
+ rb.Push<s32>(res.ValueOr(-1));
+ }
+
+ void SetHostName(HLERequestContext& ctx) {
+ const std::string hostname = Common::StringFromBuffer(ctx.ReadBuffer());
+ const Result res = SetHostNameImpl(hostname);
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void SetVerifyOption(HLERequestContext& ctx) {
+ IPC::RequestParser rp{ctx};
+ const u32 option = rp.Pop<u32>();
+ const Result res = SetVerifyOptionImpl(option);
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void SetIoMode(HLERequestContext& ctx) {
+ IPC::RequestParser rp{ctx};
+ const u32 mode = rp.Pop<u32>();
+ const Result res = SetIoModeImpl(mode);
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void DoHandshake(HLERequestContext& ctx) {
+ const Result res = DoHandshakeImpl();
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void DoHandshakeGetServerCert(HLERequestContext& ctx) {
+ struct OutputParameters {
+ u32 certs_size;
+ u32 certs_count;
+ };
+ static_assert(sizeof(OutputParameters) == 0x8);
+
+ const Result res = DoHandshakeImpl();
+ OutputParameters out{};
+ if (res == ResultSuccess) {
+ auto certs = backend->GetServerCerts();
+ if (certs.Succeeded()) {
+ const std::vector<u8> certs_buf = SerializeServerCerts(*certs);
+ ctx.WriteBuffer(certs_buf);
+ out.certs_count = static_cast<u32>(certs->size());
+ out.certs_size = static_cast<u32>(certs_buf.size());
+ }
+ }
+ IPC::ResponseBuilder rb{ctx, 4};
+ rb.Push(res);
+ rb.PushRaw(out);
+ }
+
+ void Read(HLERequestContext& ctx) {
+ const ResultVal<std::vector<u8>> res = ReadImpl(ctx.GetWriteBufferSize());
+ IPC::ResponseBuilder rb{ctx, 3};
+ rb.Push(res.Code());
+ if (res.Succeeded()) {
+ rb.Push(static_cast<u32>(res->size()));
+ ctx.WriteBuffer(*res);
+ } else {
+ rb.Push(static_cast<u32>(0));
+ }
+ }
+
+ void Write(HLERequestContext& ctx) {
+ const ResultVal<size_t> res = WriteImpl(ctx.ReadBuffer());
+ IPC::ResponseBuilder rb{ctx, 3};
+ rb.Push(res.Code());
+ rb.Push(static_cast<u32>(res.ValueOr(0)));
+ }
+
+ void Pending(HLERequestContext& ctx) {
+ const ResultVal<s32> res = PendingImpl();
+ IPC::ResponseBuilder rb{ctx, 3};
+ rb.Push(res.Code());
+ rb.Push<s32>(res.ValueOr(0));
+ }
+
+ void SetSessionCacheMode(HLERequestContext& ctx) {
+ IPC::RequestParser rp{ctx};
+ const u32 mode = rp.Pop<u32>();
+ const Result res = SetSessionCacheModeImpl(mode);
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void SetOption(HLERequestContext& ctx) {
+ struct Parameters {
+ OptionType option;
+ s32 value;
+ };
+ static_assert(sizeof(Parameters) == 0x8, "Parameters is an invalid size");
+
+ IPC::RequestParser rp{ctx};
+ const auto parameters = rp.PopRaw<Parameters>();
+
+ switch (parameters.option) {
+ case OptionType::DoNotCloseSocket:
+ do_not_close_socket = static_cast<bool>(parameters.value);
+ break;
+ case OptionType::GetServerCertChain:
+ get_server_cert_chain = static_cast<bool>(parameters.value);
+ break;
+ default:
+ LOG_WARNING(Service_SSL, "Unknown option={}, value={}", parameters.option,
+ parameters.value);
+ }
+
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(ResultSuccess);
+ }
};
class ISslContext final : public ServiceFramework<ISslContext> {
public:
explicit ISslContext(Core::System& system_, SslVersion version)
- : ServiceFramework{system_, "ISslContext"}, ssl_version{version} {
+ : ServiceFramework{system_, "ISslContext"}, ssl_version{version},
+ shared_data{std::make_shared<SslContextSharedData>()} {
static const FunctionInfo functions[] = {
{0, &ISslContext::SetOption, "SetOption"},
{1, nullptr, "GetOption"},
{2, &ISslContext::CreateConnection, "CreateConnection"},
- {3, nullptr, "GetConnectionCount"},
+ {3, &ISslContext::GetConnectionCount, "GetConnectionCount"},
{4, &ISslContext::ImportServerPki, "ImportServerPki"},
{5, &ISslContext::ImportClientPki, "ImportClientPki"},
{6, nullptr, "RemoveServerPki"},
@@ -111,6 +416,7 @@ public:
private:
SslVersion ssl_version;
+ std::shared_ptr<SslContextSharedData> shared_data;
void SetOption(HLERequestContext& ctx) {
struct Parameters {
@@ -130,11 +436,24 @@ private:
}
void CreateConnection(HLERequestContext& ctx) {
- LOG_WARNING(Service_SSL, "(STUBBED) called");
+ LOG_WARNING(Service_SSL, "called");
+
+ auto backend_res = CreateSSLConnectionBackend();
IPC::ResponseBuilder rb{ctx, 2, 0, 1};
+ rb.Push(backend_res.Code());
+ if (backend_res.Succeeded()) {
+ rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data,
+ std::move(*backend_res));
+ }
+ }
+
+ void GetConnectionCount(HLERequestContext& ctx) {
+ LOG_DEBUG(Service_SSL, "connection_count={}", shared_data->connection_count);
+
+ IPC::ResponseBuilder rb{ctx, 3};
rb.Push(ResultSuccess);
- rb.PushIpcInterface<ISslConnection>(system, ssl_version);
+ rb.Push(shared_data->connection_count);
}
void ImportServerPki(HLERequestContext& ctx) {
diff --git a/src/core/hle/service/ssl/ssl_backend.h b/src/core/hle/service/ssl/ssl_backend.h
new file mode 100644
index 000000000..409f4367c
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend.h
@@ -0,0 +1,45 @@
+// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+#pragma once
+
+#include <memory>
+#include <span>
+#include <string>
+#include <vector>
+
+#include "common/common_types.h"
+
+#include "core/hle/result.h"
+
+namespace Network {
+class SocketBase;
+}
+
+namespace Service::SSL {
+
+constexpr Result ResultNoSocket{ErrorModule::SSLSrv, 103};
+constexpr Result ResultInvalidSocket{ErrorModule::SSLSrv, 106};
+constexpr Result ResultTimeout{ErrorModule::SSLSrv, 205};
+constexpr Result ResultInternalError{ErrorModule::SSLSrv, 999}; // made up
+
+// ResultWouldBlock is returned from Read and Write, and oddly, DoHandshake,
+// with no way in the latter case to distinguish whether the client should poll
+// for read or write. The one official client I've seen handles this by always
+// polling for read (with a timeout).
+constexpr Result ResultWouldBlock{ErrorModule::SSLSrv, 204};
+
+class SSLConnectionBackend {
+public:
+ virtual ~SSLConnectionBackend() {}
+ virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0;
+ virtual Result SetHostName(const std::string& hostname) = 0;
+ virtual Result DoHandshake() = 0;
+ virtual ResultVal<size_t> Read(std::span<u8> data) = 0;
+ virtual ResultVal<size_t> Write(std::span<const u8> data) = 0;
+ virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0;
+};
+
+ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend();
+
+} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_none.cpp b/src/core/hle/service/ssl/ssl_backend_none.cpp
new file mode 100644
index 000000000..2f4f23c42
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_none.cpp
@@ -0,0 +1,16 @@
+// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+#include "common/logging/log.h"
+
+#include "core/hle/service/ssl/ssl_backend.h"
+
+namespace Service::SSL {
+
+ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
+ LOG_ERROR(Service_SSL,
+ "Can't create SSL connection because no SSL backend is available on this platform");
+ return ResultInternalError;
+}
+
+} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_openssl.cpp b/src/core/hle/service/ssl/ssl_backend_openssl.cpp
new file mode 100644
index 000000000..6ca869dbf
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_openssl.cpp
@@ -0,0 +1,351 @@
+// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+#include <mutex>
+
+#include <openssl/bio.h>
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+#include <openssl/x509.h>
+
+#include "common/fs/file.h"
+#include "common/hex_util.h"
+#include "common/string_util.h"
+
+#include "core/hle/service/ssl/ssl_backend.h"
+#include "core/internal_network/network.h"
+#include "core/internal_network/sockets.h"
+
+using namespace Common::FS;
+
+namespace Service::SSL {
+
+// Import OpenSSL's `SSL` type into the namespace. This is needed because the
+// namespace is also named `SSL`.
+using ::SSL;
+
+namespace {
+
+std::once_flag one_time_init_flag;
+bool one_time_init_success = false;
+
+SSL_CTX* ssl_ctx;
+IOFile key_log_file; // only open if SSLKEYLOGFILE set in environment
+BIO_METHOD* bio_meth;
+
+Result CheckOpenSSLErrors();
+void OneTimeInit();
+void OneTimeInitLogFile();
+bool OneTimeInitBIO();
+
+} // namespace
+
+class SSLConnectionBackendOpenSSL final : public SSLConnectionBackend {
+public:
+ Result Init() {
+ std::call_once(one_time_init_flag, OneTimeInit);
+
+ if (!one_time_init_success) {
+ LOG_ERROR(Service_SSL,
+ "Can't create SSL connection because OpenSSL one-time initialization failed");
+ return ResultInternalError;
+ }
+
+ ssl = SSL_new(ssl_ctx);
+ if (!ssl) {
+ LOG_ERROR(Service_SSL, "SSL_new failed");
+ return CheckOpenSSLErrors();
+ }
+
+ SSL_set_connect_state(ssl);
+
+ bio = BIO_new(bio_meth);
+ if (!bio) {
+ LOG_ERROR(Service_SSL, "BIO_new failed");
+ return CheckOpenSSLErrors();
+ }
+
+ BIO_set_data(bio, this);
+ BIO_set_init(bio, 1);
+ SSL_set_bio(ssl, bio, bio);
+
+ return ResultSuccess;
+ }
+
+ void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override {
+ socket = std::move(socket_in);
+ }
+
+ Result SetHostName(const std::string& hostname) override {
+ if (!SSL_set1_host(ssl, hostname.c_str())) { // hostname for verification
+ LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname);
+ return CheckOpenSSLErrors();
+ }
+ if (!SSL_set_tlsext_host_name(ssl, hostname.c_str())) { // hostname for SNI
+ LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname);
+ return CheckOpenSSLErrors();
+ }
+ return ResultSuccess;
+ }
+
+ Result DoHandshake() override {
+ SSL_set_verify_result(ssl, X509_V_OK);
+ const int ret = SSL_do_handshake(ssl);
+ const long verify_result = SSL_get_verify_result(ssl);
+ if (verify_result != X509_V_OK) {
+ LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}",
+ X509_verify_cert_error_string(verify_result));
+ return CheckOpenSSLErrors();
+ }
+ if (ret <= 0) {
+ const int ssl_err = SSL_get_error(ssl, ret);
+ if (ssl_err == SSL_ERROR_ZERO_RETURN ||
+ (ssl_err == SSL_ERROR_SYSCALL && got_read_eof)) {
+ LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
+ return ResultInternalError;
+ }
+ }
+ return HandleReturn("SSL_do_handshake", 0, ret).Code();
+ }
+
+ ResultVal<size_t> Read(std::span<u8> data) override {
+ size_t actual;
+ const int ret = SSL_read_ex(ssl, data.data(), data.size(), &actual);
+ return HandleReturn("SSL_read_ex", actual, ret);
+ }
+
+ ResultVal<size_t> Write(std::span<const u8> data) override {
+ size_t actual;
+ const int ret = SSL_write_ex(ssl, data.data(), data.size(), &actual);
+ return HandleReturn("SSL_write_ex", actual, ret);
+ }
+
+ ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) {
+ const int ssl_err = SSL_get_error(ssl, ret);
+ CheckOpenSSLErrors();
+ switch (ssl_err) {
+ case SSL_ERROR_NONE:
+ return actual;
+ case SSL_ERROR_ZERO_RETURN:
+ LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what);
+ // DoHandshake special-cases this, but for Read and Write:
+ return size_t(0);
+ case SSL_ERROR_WANT_READ:
+ LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what);
+ return ResultWouldBlock;
+ case SSL_ERROR_WANT_WRITE:
+ LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what);
+ return ResultWouldBlock;
+ default:
+ if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) {
+ LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what);
+ return size_t(0);
+ }
+ LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err);
+ return ResultInternalError;
+ }
+ }
+
+ ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
+ STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl);
+ if (!chain) {
+ LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr");
+ return ResultInternalError;
+ }
+ std::vector<std::vector<u8>> ret;
+ int count = sk_X509_num(chain);
+ ASSERT(count >= 0);
+ for (int i = 0; i < count; i++) {
+ X509* x509 = sk_X509_value(chain, i);
+ ASSERT_OR_EXECUTE(x509 != nullptr, { continue; });
+ unsigned char* buf = nullptr;
+ int len = i2d_X509(x509, &buf);
+ ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; });
+ ret.emplace_back(buf, buf + len);
+ OPENSSL_free(buf);
+ }
+ return ret;
+ }
+
+ ~SSLConnectionBackendOpenSSL() {
+ // these are null-tolerant:
+ SSL_free(ssl);
+ BIO_free(bio);
+ }
+
+ static void KeyLogCallback(const SSL* ssl, const char* line) {
+ std::string str(line);
+ str.push_back('\n');
+ // Do this in a single WriteString for atomicity if multiple instances
+ // are running on different threads (though that can't currently
+ // happen).
+ if (key_log_file.WriteString(str) != str.size() || !key_log_file.Flush()) {
+ LOG_CRITICAL(Service_SSL, "Failed to write to SSLKEYLOGFILE");
+ }
+ LOG_DEBUG(Service_SSL, "Wrote to SSLKEYLOGFILE: {}", line);
+ }
+
+ static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) {
+ auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
+ ASSERT_OR_EXECUTE_MSG(
+ self->socket, { return 0; }, "OpenSSL asked to send but we have no socket");
+ BIO_clear_retry_flags(bio);
+ auto [actual, err] = self->socket->Send({reinterpret_cast<const u8*>(buf), len}, 0);
+ switch (err) {
+ case Network::Errno::SUCCESS:
+ *actual_p = actual;
+ return 1;
+ case Network::Errno::AGAIN:
+ BIO_set_flags(bio, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY);
+ return 0;
+ default:
+ LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
+ return -1;
+ }
+ }
+
+ static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) {
+ auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
+ ASSERT_OR_EXECUTE_MSG(
+ self->socket, { return 0; }, "OpenSSL asked to recv but we have no socket");
+ BIO_clear_retry_flags(bio);
+ auto [actual, err] = self->socket->Recv(0, {reinterpret_cast<u8*>(buf), len});
+ switch (err) {
+ case Network::Errno::SUCCESS:
+ *actual_p = actual;
+ if (actual == 0) {
+ self->got_read_eof = true;
+ }
+ return actual ? 1 : 0;
+ case Network::Errno::AGAIN:
+ BIO_set_flags(bio, BIO_FLAGS_READ | BIO_FLAGS_SHOULD_RETRY);
+ return 0;
+ default:
+ LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
+ return -1;
+ }
+ }
+
+ static long CtrlCallback(BIO* bio, int cmd, long l_arg, void* p_arg) {
+ switch (cmd) {
+ case BIO_CTRL_FLUSH:
+ // Nothing to flush.
+ return 1;
+ case BIO_CTRL_PUSH:
+ case BIO_CTRL_POP:
+#ifdef BIO_CTRL_GET_KTLS_SEND
+ case BIO_CTRL_GET_KTLS_SEND:
+ case BIO_CTRL_GET_KTLS_RECV:
+#endif
+ // We don't support these operations, but don't bother logging them
+ // as they're nothing unusual.
+ return 0;
+ default:
+ LOG_DEBUG(Service_SSL, "OpenSSL BIO got ctrl({}, {}, {})", cmd, l_arg, p_arg);
+ return 0;
+ }
+ }
+
+ SSL* ssl = nullptr;
+ BIO* bio = nullptr;
+ bool got_read_eof = false;
+
+ std::shared_ptr<Network::SocketBase> socket;
+};
+
+ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
+ auto conn = std::make_unique<SSLConnectionBackendOpenSSL>();
+ const Result res = conn->Init();
+ if (res.IsFailure()) {
+ return res;
+ }
+ return conn;
+}
+
+namespace {
+
+Result CheckOpenSSLErrors() {
+ unsigned long rc;
+ const char* file;
+ int line;
+ const char* func;
+ const char* data;
+ int flags;
+#if OPENSSL_VERSION_NUMBER >= 0x30000000L
+ while ((rc = ERR_get_error_all(&file, &line, &func, &data, &flags)))
+#else
+ // Can't get function names from OpenSSL on this version, so use mine:
+ func = __func__;
+ while ((rc = ERR_get_error_line_data(&file, &line, &data, &flags)))
+#endif
+ {
+ std::string msg;
+ msg.resize(1024, '\0');
+ ERR_error_string_n(rc, msg.data(), msg.size());
+ msg.resize(strlen(msg.data()), '\0');
+ if (flags & ERR_TXT_STRING) {
+ msg.append(" | ");
+ msg.append(data);
+ }
+ Common::Log::FmtLogMessage(Common::Log::Class::Service_SSL, Common::Log::Level::Error,
+ Common::Log::TrimSourcePath(file), line, func, "OpenSSL: {}",
+ msg);
+ }
+ return ResultInternalError;
+}
+
+void OneTimeInit() {
+ ssl_ctx = SSL_CTX_new(TLS_client_method());
+ if (!ssl_ctx) {
+ LOG_ERROR(Service_SSL, "SSL_CTX_new failed");
+ CheckOpenSSLErrors();
+ return;
+ }
+
+ SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr);
+
+ if (!SSL_CTX_set_default_verify_paths(ssl_ctx)) {
+ LOG_ERROR(Service_SSL, "SSL_CTX_set_default_verify_paths failed");
+ CheckOpenSSLErrors();
+ return;
+ }
+
+ OneTimeInitLogFile();
+
+ if (!OneTimeInitBIO()) {
+ return;
+ }
+
+ one_time_init_success = true;
+}
+
+void OneTimeInitLogFile() {
+ const char* logfile = getenv("SSLKEYLOGFILE");
+ if (logfile) {
+ key_log_file.Open(logfile, FileAccessMode::Append, FileType::TextFile,
+ FileShareFlag::ShareWriteOnly);
+ if (key_log_file.IsOpen()) {
+ SSL_CTX_set_keylog_callback(ssl_ctx, &SSLConnectionBackendOpenSSL::KeyLogCallback);
+ } else {
+ LOG_CRITICAL(Service_SSL,
+ "SSLKEYLOGFILE was set but file could not be opened; not logging keys!");
+ }
+ }
+}
+
+bool OneTimeInitBIO() {
+ bio_meth =
+ BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SSLConnectionBackendOpenSSL");
+ if (!bio_meth ||
+ !BIO_meth_set_write_ex(bio_meth, &SSLConnectionBackendOpenSSL::WriteCallback) ||
+ !BIO_meth_set_read_ex(bio_meth, &SSLConnectionBackendOpenSSL::ReadCallback) ||
+ !BIO_meth_set_ctrl(bio_meth, &SSLConnectionBackendOpenSSL::CtrlCallback)) {
+ LOG_ERROR(Service_SSL, "Failed to create BIO_METHOD");
+ return false;
+ }
+ return true;
+}
+
+} // namespace
+
+} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_schannel.cpp b/src/core/hle/service/ssl/ssl_backend_schannel.cpp
new file mode 100644
index 000000000..d8074339a
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_schannel.cpp
@@ -0,0 +1,544 @@
+// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+#include <mutex>
+
+#include "common/error.h"
+#include "common/fs/file.h"
+#include "common/hex_util.h"
+#include "common/string_util.h"
+
+#include "core/hle/service/ssl/ssl_backend.h"
+#include "core/internal_network/network.h"
+#include "core/internal_network/sockets.h"
+
+namespace {
+
+// These includes are inside the namespace to avoid a conflict on MinGW where
+// the headers define an enum containing Network and Service as enumerators
+// (which clash with the correspondingly named namespaces).
+#define SECURITY_WIN32
+#include <schnlsp.h>
+#include <security.h>
+#include <wincrypt.h>
+
+std::once_flag one_time_init_flag;
+bool one_time_init_success = false;
+
+SCHANNEL_CRED schannel_cred{};
+CredHandle cred_handle;
+
+static void OneTimeInit() {
+ schannel_cred.dwVersion = SCHANNEL_CRED_VERSION;
+ schannel_cred.dwFlags =
+ SCH_USE_STRONG_CRYPTO | // don't allow insecure protocols
+ SCH_CRED_AUTO_CRED_VALIDATION | // validate certs
+ SCH_CRED_NO_DEFAULT_CREDS; // don't automatically present a client certificate
+ // ^ I'm assuming that nobody would want to connect Yuzu to a
+ // service that requires some OS-provided corporate client
+ // certificate, and presenting one to some arbitrary server
+ // might be a privacy concern? Who knows, though.
+
+ const SECURITY_STATUS ret =
+ AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND,
+ nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr);
+ if (ret != SEC_E_OK) {
+ // SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString.
+ LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}",
+ Common::NativeErrorToString(ret));
+ return;
+ }
+
+ if (getenv("SSLKEYLOGFILE")) {
+ LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but Schannel does not support exporting "
+ "keys; not logging keys!");
+ // Not fatal.
+ }
+
+ one_time_init_success = true;
+}
+
+} // namespace
+
+namespace Service::SSL {
+
+class SSLConnectionBackendSchannel final : public SSLConnectionBackend {
+public:
+ Result Init() {
+ std::call_once(one_time_init_flag, OneTimeInit);
+
+ if (!one_time_init_success) {
+ LOG_ERROR(
+ Service_SSL,
+ "Can't create SSL connection because Schannel one-time initialization failed");
+ return ResultInternalError;
+ }
+
+ return ResultSuccess;
+ }
+
+ void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override {
+ socket = std::move(socket_in);
+ }
+
+ Result SetHostName(const std::string& hostname_in) override {
+ hostname = hostname_in;
+ return ResultSuccess;
+ }
+
+ Result DoHandshake() override {
+ while (1) {
+ Result r;
+ switch (handshake_state) {
+ case HandshakeState::Initial:
+ if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
+ (r = CallInitializeSecurityContext()) != ResultSuccess) {
+ return r;
+ }
+ // CallInitializeSecurityContext updated `handshake_state`.
+ continue;
+ case HandshakeState::ContinueNeeded:
+ case HandshakeState::IncompleteMessage:
+ if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
+ (r = FillCiphertextReadBuf()) != ResultSuccess) {
+ return r;
+ }
+ if (ciphertext_read_buf.empty()) {
+ LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
+ return ResultInternalError;
+ }
+ if ((r = CallInitializeSecurityContext()) != ResultSuccess) {
+ return r;
+ }
+ // CallInitializeSecurityContext updated `handshake_state`.
+ continue;
+ case HandshakeState::DoneAfterFlush:
+ if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) {
+ return r;
+ }
+ handshake_state = HandshakeState::Connected;
+ return ResultSuccess;
+ case HandshakeState::Connected:
+ LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook");
+ return ResultInternalError;
+ case HandshakeState::Error:
+ return ResultInternalError;
+ }
+ }
+ }
+
+ Result FillCiphertextReadBuf() {
+ const size_t fill_size = read_buf_fill_size ? read_buf_fill_size : 4096;
+ read_buf_fill_size = 0;
+ // This unnecessarily zeroes the buffer; oh well.
+ const size_t offset = ciphertext_read_buf.size();
+ ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; });
+ ciphertext_read_buf.resize(offset + fill_size, 0);
+ const auto read_span = std::span(ciphertext_read_buf).subspan(offset, fill_size);
+ const auto [actual, err] = socket->Recv(0, read_span);
+ switch (err) {
+ case Network::Errno::SUCCESS:
+ ASSERT(static_cast<size_t>(actual) <= fill_size);
+ ciphertext_read_buf.resize(offset + actual);
+ return ResultSuccess;
+ case Network::Errno::AGAIN:
+ ciphertext_read_buf.resize(offset);
+ return ResultWouldBlock;
+ default:
+ ciphertext_read_buf.resize(offset);
+ LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
+ return ResultInternalError;
+ }
+ }
+
+ // Returns success if the write buffer has been completely emptied.
+ Result FlushCiphertextWriteBuf() {
+ while (!ciphertext_write_buf.empty()) {
+ const auto [actual, err] = socket->Send(ciphertext_write_buf, 0);
+ switch (err) {
+ case Network::Errno::SUCCESS:
+ ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf.size());
+ ciphertext_write_buf.erase(ciphertext_write_buf.begin(),
+ ciphertext_write_buf.begin() + actual);
+ break;
+ case Network::Errno::AGAIN:
+ return ResultWouldBlock;
+ default:
+ LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
+ return ResultInternalError;
+ }
+ }
+ return ResultSuccess;
+ }
+
+ Result CallInitializeSecurityContext() {
+ const unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY |
+ ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT |
+ ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM |
+ ISC_REQ_USE_SUPPLIED_CREDS;
+ unsigned long attr;
+ // https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel
+ std::array<SecBuffer, 2> input_buffers{{
+ // only used if `initial_call_done`
+ {
+ // [0]
+ .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()),
+ .BufferType = SECBUFFER_TOKEN,
+ .pvBuffer = ciphertext_read_buf.data(),
+ },
+ {
+ // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is
+ // returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the
+ // whole buffer wasn't used)
+ .cbBuffer = 0,
+ .BufferType = SECBUFFER_EMPTY,
+ .pvBuffer = nullptr,
+ },
+ }};
+ std::array<SecBuffer, 2> output_buffers{{
+ {
+ .cbBuffer = 0,
+ .BufferType = SECBUFFER_TOKEN,
+ .pvBuffer = nullptr,
+ }, // [0]
+ {
+ .cbBuffer = 0,
+ .BufferType = SECBUFFER_ALERT,
+ .pvBuffer = nullptr,
+ }, // [1]
+ }};
+ SecBufferDesc input_desc{
+ .ulVersion = SECBUFFER_VERSION,
+ .cBuffers = static_cast<unsigned long>(input_buffers.size()),
+ .pBuffers = input_buffers.data(),
+ };
+ SecBufferDesc output_desc{
+ .ulVersion = SECBUFFER_VERSION,
+ .cBuffers = static_cast<unsigned long>(output_buffers.size()),
+ .pBuffers = output_buffers.data(),
+ };
+ ASSERT_OR_EXECUTE_MSG(
+ input_buffers[0].cbBuffer == ciphertext_read_buf.size(),
+ { return ResultInternalError; }, "read buffer too large");
+
+ bool initial_call_done = handshake_state != HandshakeState::Initial;
+ if (initial_call_done) {
+ LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext",
+ ciphertext_read_buf.size());
+ }
+
+ const SECURITY_STATUS ret =
+ InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt : nullptr,
+ // Caller ensured we have set a hostname:
+ const_cast<char*>(hostname.value().c_str()), req,
+ 0, // Reserved1
+ 0, // TargetDataRep not used with Schannel
+ initial_call_done ? &input_desc : nullptr,
+ 0, // Reserved2
+ initial_call_done ? nullptr : &ctxt, &output_desc, &attr,
+ nullptr); // ptsExpiry
+
+ if (output_buffers[0].pvBuffer) {
+ const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer),
+ output_buffers[0].cbBuffer);
+ ciphertext_write_buf.insert(ciphertext_write_buf.end(), span.begin(), span.end());
+ FreeContextBuffer(output_buffers[0].pvBuffer);
+ }
+
+ if (output_buffers[1].pvBuffer) {
+ const std::span span(static_cast<u8*>(output_buffers[1].pvBuffer),
+ output_buffers[1].cbBuffer);
+ // The documentation doesn't explain what format this data is in.
+ LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(),
+ Common::HexToString(span));
+ }
+
+ switch (ret) {
+ case SEC_I_CONTINUE_NEEDED:
+ LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED");
+ if (input_buffers[1].BufferType == SECBUFFER_EXTRA) {
+ LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer);
+ ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf.size());
+ ciphertext_read_buf.erase(ciphertext_read_buf.begin(),
+ ciphertext_read_buf.end() - input_buffers[1].cbBuffer);
+ } else {
+ ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY);
+ ciphertext_read_buf.clear();
+ }
+ handshake_state = HandshakeState::ContinueNeeded;
+ return ResultSuccess;
+ case SEC_E_INCOMPLETE_MESSAGE:
+ LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE");
+ ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING);
+ read_buf_fill_size = input_buffers[1].cbBuffer;
+ handshake_state = HandshakeState::IncompleteMessage;
+ return ResultSuccess;
+ case SEC_E_OK:
+ LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK");
+ ciphertext_read_buf.clear();
+ handshake_state = HandshakeState::DoneAfterFlush;
+ return GrabStreamSizes();
+ default:
+ LOG_ERROR(Service_SSL,
+ "InitializeSecurityContext failed (probably certificate/protocol issue): {}",
+ Common::NativeErrorToString(ret));
+ handshake_state = HandshakeState::Error;
+ return ResultInternalError;
+ }
+ }
+
+ Result GrabStreamSizes() {
+ const SECURITY_STATUS ret =
+ QueryContextAttributes(&ctxt, SECPKG_ATTR_STREAM_SIZES, &stream_sizes);
+ if (ret != SEC_E_OK) {
+ LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}",
+ Common::NativeErrorToString(ret));
+ handshake_state = HandshakeState::Error;
+ return ResultInternalError;
+ }
+ return ResultSuccess;
+ }
+
+ ResultVal<size_t> Read(std::span<u8> data) override {
+ if (handshake_state != HandshakeState::Connected) {
+ LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake");
+ return ResultInternalError;
+ }
+ if (data.size() == 0 || got_read_eof) {
+ return size_t(0);
+ }
+ while (1) {
+ if (!cleartext_read_buf.empty()) {
+ const size_t read_size = std::min(cleartext_read_buf.size(), data.size());
+ std::memcpy(data.data(), cleartext_read_buf.data(), read_size);
+ cleartext_read_buf.erase(cleartext_read_buf.begin(),
+ cleartext_read_buf.begin() + read_size);
+ return read_size;
+ }
+ if (!ciphertext_read_buf.empty()) {
+ SecBuffer empty{
+ .cbBuffer = 0,
+ .BufferType = SECBUFFER_EMPTY,
+ .pvBuffer = nullptr,
+ };
+ std::array<SecBuffer, 5> buffers{{
+ {
+ .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()),
+ .BufferType = SECBUFFER_DATA,
+ .pvBuffer = ciphertext_read_buf.data(),
+ },
+ empty,
+ empty,
+ empty,
+ }};
+ ASSERT_OR_EXECUTE_MSG(
+ buffers[0].cbBuffer == ciphertext_read_buf.size(),
+ { return ResultInternalError; }, "read buffer too large");
+ SecBufferDesc desc{
+ .ulVersion = SECBUFFER_VERSION,
+ .cBuffers = static_cast<unsigned long>(buffers.size()),
+ .pBuffers = buffers.data(),
+ };
+ SECURITY_STATUS ret =
+ DecryptMessage(&ctxt, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr);
+ switch (ret) {
+ case SEC_E_OK:
+ ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER,
+ { return ResultInternalError; });
+ ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA,
+ { return ResultInternalError; });
+ ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER,
+ { return ResultInternalError; });
+ cleartext_read_buf.assign(static_cast<u8*>(buffers[1].pvBuffer),
+ static_cast<u8*>(buffers[1].pvBuffer) +
+ buffers[1].cbBuffer);
+ if (buffers[3].BufferType == SECBUFFER_EXTRA) {
+ ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf.size());
+ ciphertext_read_buf.erase(ciphertext_read_buf.begin(),
+ ciphertext_read_buf.end() - buffers[3].cbBuffer);
+ } else {
+ ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY);
+ ciphertext_read_buf.clear();
+ }
+ continue;
+ case SEC_E_INCOMPLETE_MESSAGE:
+ break;
+ case SEC_I_CONTEXT_EXPIRED:
+ // Server hung up by sending close_notify.
+ got_read_eof = true;
+ return size_t(0);
+ default:
+ LOG_ERROR(Service_SSL, "DecryptMessage failed: {}",
+ Common::NativeErrorToString(ret));
+ return ResultInternalError;
+ }
+ }
+ const Result r = FillCiphertextReadBuf();
+ if (r != ResultSuccess) {
+ return r;
+ }
+ if (ciphertext_read_buf.empty()) {
+ got_read_eof = true;
+ return size_t(0);
+ }
+ }
+ }
+
+ ResultVal<size_t> Write(std::span<const u8> data) override {
+ if (handshake_state != HandshakeState::Connected) {
+ LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake");
+ return ResultInternalError;
+ }
+ if (data.size() == 0) {
+ return size_t(0);
+ }
+ data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage));
+ if (!cleartext_write_buf.empty()) {
+ // Already in the middle of a write. It wouldn't make sense to not
+ // finish sending the entire buffer since TLS has
+ // header/MAC/padding/etc.
+ if (data.size() != cleartext_write_buf.size() ||
+ std::memcmp(data.data(), cleartext_write_buf.data(), data.size())) {
+ LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer");
+ return ResultInternalError;
+ }
+ return WriteAlreadyEncryptedData();
+ } else {
+ cleartext_write_buf.assign(data.begin(), data.end());
+ }
+
+ std::vector<u8> header_buf(stream_sizes.cbHeader, 0);
+ std::vector<u8> tmp_data_buf = cleartext_write_buf;
+ std::vector<u8> trailer_buf(stream_sizes.cbTrailer, 0);
+
+ std::array<SecBuffer, 3> buffers{{
+ {
+ .cbBuffer = stream_sizes.cbHeader,
+ .BufferType = SECBUFFER_STREAM_HEADER,
+ .pvBuffer = header_buf.data(),
+ },
+ {
+ .cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()),
+ .BufferType = SECBUFFER_DATA,
+ .pvBuffer = tmp_data_buf.data(),
+ },
+ {
+ .cbBuffer = stream_sizes.cbTrailer,
+ .BufferType = SECBUFFER_STREAM_TRAILER,
+ .pvBuffer = trailer_buf.data(),
+ },
+ }};
+ ASSERT_OR_EXECUTE_MSG(
+ buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; },
+ "temp buffer too large");
+ SecBufferDesc desc{
+ .ulVersion = SECBUFFER_VERSION,
+ .cBuffers = static_cast<unsigned long>(buffers.size()),
+ .pBuffers = buffers.data(),
+ };
+
+ const SECURITY_STATUS ret = EncryptMessage(&ctxt, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0);
+ if (ret != SEC_E_OK) {
+ LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret));
+ return ResultInternalError;
+ }
+ ciphertext_write_buf.insert(ciphertext_write_buf.end(), header_buf.begin(),
+ header_buf.end());
+ ciphertext_write_buf.insert(ciphertext_write_buf.end(), tmp_data_buf.begin(),
+ tmp_data_buf.end());
+ ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(),
+ trailer_buf.end());
+ return WriteAlreadyEncryptedData();
+ }
+
+ ResultVal<size_t> WriteAlreadyEncryptedData() {
+ const Result r = FlushCiphertextWriteBuf();
+ if (r != ResultSuccess) {
+ return r;
+ }
+ // write buf is empty
+ const size_t cleartext_bytes_written = cleartext_write_buf.size();
+ cleartext_write_buf.clear();
+ return cleartext_bytes_written;
+ }
+
+ ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
+ PCCERT_CONTEXT returned_cert = nullptr;
+ const SECURITY_STATUS ret =
+ QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert);
+ if (ret != SEC_E_OK) {
+ LOG_ERROR(Service_SSL,
+ "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}",
+ Common::NativeErrorToString(ret));
+ return ResultInternalError;
+ }
+ PCCERT_CONTEXT some_cert = nullptr;
+ std::vector<std::vector<u8>> certs;
+ while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) {
+ certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded),
+ static_cast<u8*>(some_cert->pbCertEncoded) +
+ some_cert->cbCertEncoded);
+ }
+ std::reverse(certs.begin(),
+ certs.end()); // Windows returns certs in reverse order from what we want
+ CertFreeCertificateContext(returned_cert);
+ return certs;
+ }
+
+ ~SSLConnectionBackendSchannel() {
+ if (handshake_state != HandshakeState::Initial) {
+ DeleteSecurityContext(&ctxt);
+ }
+ }
+
+ enum class HandshakeState {
+ // Haven't called anything yet.
+ Initial,
+ // `SEC_I_CONTINUE_NEEDED` was returned by
+ // `InitializeSecurityContext`; must finish sending data (if any) in
+ // the write buffer, then read at least one byte before calling
+ // `InitializeSecurityContext` again.
+ ContinueNeeded,
+ // `SEC_E_INCOMPLETE_MESSAGE` was returned by
+ // `InitializeSecurityContext`; hopefully the write buffer is empty;
+ // must read at least one byte before calling
+ // `InitializeSecurityContext` again.
+ IncompleteMessage,
+ // `SEC_E_OK` was returned by `InitializeSecurityContext`; must
+ // finish sending data in the write buffer before having `DoHandshake`
+ // report success.
+ DoneAfterFlush,
+ // We finished the above and are now connected. At this point, writing
+ // and reading are separate 'state machines' represented by the
+ // nonemptiness of the ciphertext and cleartext read and write buffers.
+ Connected,
+ // Another error was returned and we shouldn't allow initialization
+ // to continue.
+ Error,
+ } handshake_state = HandshakeState::Initial;
+
+ CtxtHandle ctxt;
+ SecPkgContext_StreamSizes stream_sizes;
+
+ std::shared_ptr<Network::SocketBase> socket;
+ std::optional<std::string> hostname;
+
+ std::vector<u8> ciphertext_read_buf;
+ std::vector<u8> ciphertext_write_buf;
+ std::vector<u8> cleartext_read_buf;
+ std::vector<u8> cleartext_write_buf;
+
+ bool got_read_eof = false;
+ size_t read_buf_fill_size = 0;
+};
+
+ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
+ auto conn = std::make_unique<SSLConnectionBackendSchannel>();
+ const Result res = conn->Init();
+ if (res.IsFailure()) {
+ return res;
+ }
+ return conn;
+}
+
+} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_securetransport.cpp b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp
new file mode 100644
index 000000000..b3083cbad
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp
@@ -0,0 +1,222 @@
+// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+#include <mutex>
+
+// SecureTransport has been deprecated in its entirety in favor of
+// Network.framework, but that does not allow layering TLS on top of an
+// arbitrary socket.
+#if defined(__GNUC__) || defined(__clang__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+#include <Security/SecureTransport.h>
+#pragma GCC diagnostic pop
+#endif
+
+#include "core/hle/service/ssl/ssl_backend.h"
+#include "core/internal_network/network.h"
+#include "core/internal_network/sockets.h"
+
+namespace {
+
+template <typename T>
+struct CFReleaser {
+ T ptr;
+
+ YUZU_NON_COPYABLE(CFReleaser);
+ constexpr CFReleaser() : ptr(nullptr) {}
+ constexpr CFReleaser(T ptr) : ptr(ptr) {}
+ constexpr operator T() {
+ return ptr;
+ }
+ ~CFReleaser() {
+ if (ptr) {
+ CFRelease(ptr);
+ }
+ }
+};
+
+std::string CFStringToString(CFStringRef cfstr) {
+ CFReleaser<CFDataRef> cfdata(
+ CFStringCreateExternalRepresentation(nullptr, cfstr, kCFStringEncodingUTF8, 0));
+ ASSERT_OR_EXECUTE(cfdata, { return "???"; });
+ return std::string(reinterpret_cast<const char*>(CFDataGetBytePtr(cfdata)),
+ CFDataGetLength(cfdata));
+}
+
+std::string OSStatusToString(OSStatus status) {
+ CFReleaser<CFStringRef> cfstr(SecCopyErrorMessageString(status, nullptr));
+ if (!cfstr) {
+ return "[unknown error]";
+ }
+ return CFStringToString(cfstr);
+}
+
+} // namespace
+
+namespace Service::SSL {
+
+class SSLConnectionBackendSecureTransport final : public SSLConnectionBackend {
+public:
+ Result Init() {
+ static std::once_flag once_flag;
+ std::call_once(once_flag, []() {
+ if (getenv("SSLKEYLOGFILE")) {
+ LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but SecureTransport does not "
+ "support exporting keys; not logging keys!");
+ // Not fatal.
+ }
+ });
+
+ context.ptr = SSLCreateContext(nullptr, kSSLClientSide, kSSLStreamType);
+ if (!context) {
+ LOG_ERROR(Service_SSL, "SSLCreateContext failed");
+ return ResultInternalError;
+ }
+
+ OSStatus status;
+ if ((status = SSLSetIOFuncs(context, ReadCallback, WriteCallback)) ||
+ (status = SSLSetConnection(context, this))) {
+ LOG_ERROR(Service_SSL, "SSLContext initialization failed: {}",
+ OSStatusToString(status));
+ return ResultInternalError;
+ }
+
+ return ResultSuccess;
+ }
+
+ void SetSocket(std::shared_ptr<Network::SocketBase> in_socket) override {
+ socket = std::move(in_socket);
+ }
+
+ Result SetHostName(const std::string& hostname) override {
+ OSStatus status = SSLSetPeerDomainName(context, hostname.c_str(), hostname.size());
+ if (status) {
+ LOG_ERROR(Service_SSL, "SSLSetPeerDomainName failed: {}", OSStatusToString(status));
+ return ResultInternalError;
+ }
+ return ResultSuccess;
+ }
+
+ Result DoHandshake() override {
+ OSStatus status = SSLHandshake(context);
+ return HandleReturn("SSLHandshake", 0, status).Code();
+ }
+
+ ResultVal<size_t> Read(std::span<u8> data) override {
+ size_t actual;
+ OSStatus status = SSLRead(context, data.data(), data.size(), &actual);
+ ;
+ return HandleReturn("SSLRead", actual, status);
+ }
+
+ ResultVal<size_t> Write(std::span<const u8> data) override {
+ size_t actual;
+ OSStatus status = SSLWrite(context, data.data(), data.size(), &actual);
+ ;
+ return HandleReturn("SSLWrite", actual, status);
+ }
+
+ ResultVal<size_t> HandleReturn(const char* what, size_t actual, OSStatus status) {
+ switch (status) {
+ case 0:
+ return actual;
+ case errSSLWouldBlock:
+ return ResultWouldBlock;
+ default: {
+ std::string reason;
+ if (got_read_eof) {
+ reason = "server hung up";
+ } else {
+ reason = OSStatusToString(status);
+ }
+ LOG_ERROR(Service_SSL, "{} failed: {}", what, reason);
+ return ResultInternalError;
+ }
+ }
+ }
+
+ ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
+ CFReleaser<SecTrustRef> trust;
+ OSStatus status = SSLCopyPeerTrust(context, &trust.ptr);
+ if (status) {
+ LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status));
+ return ResultInternalError;
+ }
+ std::vector<std::vector<u8>> ret;
+ for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) {
+ SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i);
+ CFReleaser<CFDataRef> data(SecCertificateCopyData(cert));
+ ASSERT_OR_EXECUTE(data, { return ResultInternalError; });
+ const u8* ptr = CFDataGetBytePtr(data);
+ ret.emplace_back(ptr, ptr + CFDataGetLength(data));
+ }
+ return ret;
+ }
+
+ static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) {
+ return ReadOrWriteCallback(connection, data, dataLength, true);
+ }
+
+ static OSStatus WriteCallback(SSLConnectionRef connection, const void* data,
+ size_t* dataLength) {
+ return ReadOrWriteCallback(connection, const_cast<void*>(data), dataLength, false);
+ }
+
+ static OSStatus ReadOrWriteCallback(SSLConnectionRef connection, void* data, size_t* dataLength,
+ bool is_read) {
+ auto self =
+ static_cast<SSLConnectionBackendSecureTransport*>(const_cast<void*>(connection));
+ ASSERT_OR_EXECUTE_MSG(
+ self->socket, { return 0; }, "SecureTransport asked to {} but we have no socket",
+ is_read ? "read" : "write");
+
+ // SecureTransport callbacks (unlike OpenSSL BIO callbacks) are
+ // expected to read/write the full requested dataLength or return an
+ // error, so we have to add a loop ourselves.
+ size_t requested_len = *dataLength;
+ size_t offset = 0;
+ while (offset < requested_len) {
+ std::span cur(reinterpret_cast<u8*>(data) + offset, requested_len - offset);
+ auto [actual, err] = is_read ? self->socket->Recv(0, cur) : self->socket->Send(cur, 0);
+ LOG_CRITICAL(Service_SSL, "op={}, offset={} actual={}/{} err={}", is_read, offset,
+ actual, cur.size(), static_cast<s32>(err));
+ switch (err) {
+ case Network::Errno::SUCCESS:
+ offset += actual;
+ if (actual == 0) {
+ ASSERT(is_read);
+ self->got_read_eof = true;
+ return errSecEndOfData;
+ }
+ break;
+ case Network::Errno::AGAIN:
+ *dataLength = offset;
+ return errSSLWouldBlock;
+ default:
+ LOG_ERROR(Service_SSL, "Socket {} returned Network::Errno {}",
+ is_read ? "recv" : "send", err);
+ return errSecIO;
+ }
+ }
+ ASSERT(offset == requested_len);
+ return 0;
+ }
+
+private:
+ CFReleaser<SSLContextRef> context = nullptr;
+ bool got_read_eof = false;
+
+ std::shared_ptr<Network::SocketBase> socket;
+};
+
+ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
+ auto conn = std::make_unique<SSLConnectionBackendSecureTransport>();
+ const Result res = conn->Init();
+ if (res.IsFailure()) {
+ return res;
+ }
+ return conn;
+}
+
+} // namespace Service::SSL