diff options
Diffstat (limited to '')
-rw-r--r-- | src/core/hle/kernel/k_server_session.cpp | 165 |
1 files changed, 158 insertions, 7 deletions
diff --git a/src/core/hle/kernel/k_server_session.cpp b/src/core/hle/kernel/k_server_session.cpp index c66aff501..c64ceb530 100644 --- a/src/core/hle/kernel/k_server_session.cpp +++ b/src/core/hle/kernel/k_server_session.cpp @@ -20,12 +20,132 @@ #include "core/hle/kernel/k_thread.h" #include "core/hle/kernel/k_thread_queue.h" #include "core/hle/kernel/kernel.h" +#include "core/hle/kernel/message_buffer.h" #include "core/hle/service/hle_ipc.h" #include "core/hle/service/ipc_helpers.h" #include "core/memory.h" namespace Kernel { +namespace { + +template <bool MoveHandleAllowed> +Result ProcessMessageSpecialData(KProcess& dst_process, KProcess& src_process, KThread& src_thread, + MessageBuffer& dst_msg, const MessageBuffer& src_msg, + MessageBuffer::SpecialHeader& src_special_header) { + // Copy the special header to the destination. + s32 offset = dst_msg.Set(src_special_header); + + // Copy the process ID. + if (src_special_header.GetHasProcessId()) { + offset = dst_msg.SetProcessId(offset, src_process.GetProcessId()); + } + + // Prepare to process handles. + auto& dst_handle_table = dst_process.GetHandleTable(); + auto& src_handle_table = src_process.GetHandleTable(); + Result result = ResultSuccess; + + // Process copy handles. + for (auto i = 0; i < src_special_header.GetCopyHandleCount(); ++i) { + // Get the handles. + const Handle src_handle = src_msg.GetHandle(offset); + Handle dst_handle = Svc::InvalidHandle; + + // If we're in a success state, try to move the handle to the new table. + if (R_SUCCEEDED(result) && src_handle != Svc::InvalidHandle) { + KScopedAutoObject obj = + src_handle_table.GetObjectForIpc(src_handle, std::addressof(src_thread)); + if (obj.IsNotNull()) { + Result add_result = + dst_handle_table.Add(std::addressof(dst_handle), obj.GetPointerUnsafe()); + if (R_FAILED(add_result)) { + result = add_result; + dst_handle = Svc::InvalidHandle; + } + } else { + result = ResultInvalidHandle; + } + } + + // Set the handle. + offset = dst_msg.SetHandle(offset, dst_handle); + } + + // Process move handles. + if constexpr (MoveHandleAllowed) { + for (auto i = 0; i < src_special_header.GetMoveHandleCount(); ++i) { + // Get the handles. + const Handle src_handle = src_msg.GetHandle(offset); + Handle dst_handle = Svc::InvalidHandle; + + // Whether or not we've succeeded, we need to remove the handles from the source table. + if (src_handle != Svc::InvalidHandle) { + if (R_SUCCEEDED(result)) { + KScopedAutoObject obj = + src_handle_table.GetObjectForIpcWithoutPseudoHandle(src_handle); + if (obj.IsNotNull()) { + Result add_result = dst_handle_table.Add(std::addressof(dst_handle), + obj.GetPointerUnsafe()); + + src_handle_table.Remove(src_handle); + + if (R_FAILED(add_result)) { + result = add_result; + dst_handle = Svc::InvalidHandle; + } + } else { + result = ResultInvalidHandle; + } + } else { + src_handle_table.Remove(src_handle); + } + } + + // Set the handle. + offset = dst_msg.SetHandle(offset, dst_handle); + } + } + + R_RETURN(result); +} + +void CleanupSpecialData(KProcess& dst_process, u32* dst_msg_ptr, size_t dst_buffer_size) { + // Parse the message. + const MessageBuffer dst_msg(dst_msg_ptr, dst_buffer_size); + const MessageBuffer::MessageHeader dst_header(dst_msg); + const MessageBuffer::SpecialHeader dst_special_header(dst_msg, dst_header); + + // Check that the size is big enough. + if (MessageBuffer::GetMessageBufferSize(dst_header, dst_special_header) > dst_buffer_size) { + return; + } + + // Set the special header. + int offset = dst_msg.Set(dst_special_header); + + // Clear the process id, if needed. + if (dst_special_header.GetHasProcessId()) { + offset = dst_msg.SetProcessId(offset, 0); + } + + // Clear handles, as relevant. + auto& dst_handle_table = dst_process.GetHandleTable(); + for (auto i = 0; + i < (dst_special_header.GetCopyHandleCount() + dst_special_header.GetMoveHandleCount()); + ++i) { + const Handle handle = dst_msg.GetHandle(offset); + + if (handle != Svc::InvalidHandle) { + dst_handle_table.Remove(handle); + } + + offset = dst_msg.SetHandle(offset, Svc::InvalidHandle); + } +} + +} // namespace + using ThreadQueueImplForKServerSessionRequest = KThreadQueue; KServerSession::KServerSession(KernelCore& kernel) @@ -223,12 +343,27 @@ Result KServerSession::SendReply(bool is_hle) { // the reply has already been written in this case. } else { Core::Memory::Memory& memory{client_thread->GetOwnerProcess()->GetMemory()}; - KThread* server_thread{GetCurrentThreadPointer(m_kernel)}; + KThread* server_thread = GetCurrentThreadPointer(m_kernel); + KProcess& src_process = *client_thread->GetOwnerProcess(); + KProcess& dst_process = *server_thread->GetOwnerProcess(); UNIMPLEMENTED_IF(server_thread->GetOwnerProcess() != client_thread->GetOwnerProcess()); - auto* src_msg_buffer = memory.GetPointer(server_thread->GetTlsAddress()); - auto* dst_msg_buffer = memory.GetPointer(client_message); + auto* src_msg_buffer = memory.GetPointer<u32>(server_thread->GetTlsAddress()); + auto* dst_msg_buffer = memory.GetPointer<u32>(client_message); std::memcpy(dst_msg_buffer, src_msg_buffer, client_buffer_size); + + // Translate special header ad-hoc. + MessageBuffer src_msg(src_msg_buffer, client_buffer_size); + MessageBuffer::MessageHeader src_header(src_msg); + MessageBuffer::SpecialHeader src_special_header(src_msg, src_header); + if (src_header.GetHasSpecialHeader()) { + MessageBuffer dst_msg(dst_msg_buffer, client_buffer_size); + result = ProcessMessageSpecialData<true>(dst_process, src_process, *server_thread, + dst_msg, src_msg, src_special_header); + if (R_FAILED(result)) { + CleanupSpecialData(dst_process, dst_msg_buffer, client_buffer_size); + } + } } } else { result = ResultSessionClosed; @@ -330,12 +465,28 @@ Result KServerSession::ReceiveRequest(std::shared_ptr<Service::HLERequestContext ->PopulateFromIncomingCommandBuffer(client_thread->GetOwnerProcess()->GetHandleTable(), cmd_buf); } else { - KThread* server_thread{GetCurrentThreadPointer(m_kernel)}; - UNIMPLEMENTED_IF(server_thread->GetOwnerProcess() != client_thread->GetOwnerProcess()); + KThread* server_thread = GetCurrentThreadPointer(m_kernel); + KProcess& src_process = *client_thread->GetOwnerProcess(); + KProcess& dst_process = *server_thread->GetOwnerProcess(); + UNIMPLEMENTED_IF(client_thread->GetOwnerProcess() != server_thread->GetOwnerProcess()); - auto* src_msg_buffer = memory.GetPointer(client_message); - auto* dst_msg_buffer = memory.GetPointer(server_thread->GetTlsAddress()); + auto* src_msg_buffer = memory.GetPointer<u32>(client_message); + auto* dst_msg_buffer = memory.GetPointer<u32>(server_thread->GetTlsAddress()); std::memcpy(dst_msg_buffer, src_msg_buffer, client_buffer_size); + + // Translate special header ad-hoc. + // TODO: fix this mess + MessageBuffer src_msg(src_msg_buffer, client_buffer_size); + MessageBuffer::MessageHeader src_header(src_msg); + MessageBuffer::SpecialHeader src_special_header(src_msg, src_header); + if (src_header.GetHasSpecialHeader()) { + MessageBuffer dst_msg(dst_msg_buffer, client_buffer_size); + Result res = ProcessMessageSpecialData<false>(dst_process, src_process, *client_thread, + dst_msg, src_msg, src_special_header); + if (R_FAILED(res)) { + CleanupSpecialData(dst_process, dst_msg_buffer, client_buffer_size); + } + } } // We succeeded. |