summaryrefslogtreecommitdiffstats
path: root/src/core/hle/kernel/k_server_session.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/hle/kernel/k_server_session.cpp')
-rw-r--r--src/core/hle/kernel/k_server_session.cpp165
1 files changed, 158 insertions, 7 deletions
diff --git a/src/core/hle/kernel/k_server_session.cpp b/src/core/hle/kernel/k_server_session.cpp
index c66aff501..c64ceb530 100644
--- a/src/core/hle/kernel/k_server_session.cpp
+++ b/src/core/hle/kernel/k_server_session.cpp
@@ -20,12 +20,132 @@
#include "core/hle/kernel/k_thread.h"
#include "core/hle/kernel/k_thread_queue.h"
#include "core/hle/kernel/kernel.h"
+#include "core/hle/kernel/message_buffer.h"
#include "core/hle/service/hle_ipc.h"
#include "core/hle/service/ipc_helpers.h"
#include "core/memory.h"
namespace Kernel {
+namespace {
+
+template <bool MoveHandleAllowed>
+Result ProcessMessageSpecialData(KProcess& dst_process, KProcess& src_process, KThread& src_thread,
+ MessageBuffer& dst_msg, const MessageBuffer& src_msg,
+ MessageBuffer::SpecialHeader& src_special_header) {
+ // Copy the special header to the destination.
+ s32 offset = dst_msg.Set(src_special_header);
+
+ // Copy the process ID.
+ if (src_special_header.GetHasProcessId()) {
+ offset = dst_msg.SetProcessId(offset, src_process.GetProcessId());
+ }
+
+ // Prepare to process handles.
+ auto& dst_handle_table = dst_process.GetHandleTable();
+ auto& src_handle_table = src_process.GetHandleTable();
+ Result result = ResultSuccess;
+
+ // Process copy handles.
+ for (auto i = 0; i < src_special_header.GetCopyHandleCount(); ++i) {
+ // Get the handles.
+ const Handle src_handle = src_msg.GetHandle(offset);
+ Handle dst_handle = Svc::InvalidHandle;
+
+ // If we're in a success state, try to move the handle to the new table.
+ if (R_SUCCEEDED(result) && src_handle != Svc::InvalidHandle) {
+ KScopedAutoObject obj =
+ src_handle_table.GetObjectForIpc(src_handle, std::addressof(src_thread));
+ if (obj.IsNotNull()) {
+ Result add_result =
+ dst_handle_table.Add(std::addressof(dst_handle), obj.GetPointerUnsafe());
+ if (R_FAILED(add_result)) {
+ result = add_result;
+ dst_handle = Svc::InvalidHandle;
+ }
+ } else {
+ result = ResultInvalidHandle;
+ }
+ }
+
+ // Set the handle.
+ offset = dst_msg.SetHandle(offset, dst_handle);
+ }
+
+ // Process move handles.
+ if constexpr (MoveHandleAllowed) {
+ for (auto i = 0; i < src_special_header.GetMoveHandleCount(); ++i) {
+ // Get the handles.
+ const Handle src_handle = src_msg.GetHandle(offset);
+ Handle dst_handle = Svc::InvalidHandle;
+
+ // Whether or not we've succeeded, we need to remove the handles from the source table.
+ if (src_handle != Svc::InvalidHandle) {
+ if (R_SUCCEEDED(result)) {
+ KScopedAutoObject obj =
+ src_handle_table.GetObjectForIpcWithoutPseudoHandle(src_handle);
+ if (obj.IsNotNull()) {
+ Result add_result = dst_handle_table.Add(std::addressof(dst_handle),
+ obj.GetPointerUnsafe());
+
+ src_handle_table.Remove(src_handle);
+
+ if (R_FAILED(add_result)) {
+ result = add_result;
+ dst_handle = Svc::InvalidHandle;
+ }
+ } else {
+ result = ResultInvalidHandle;
+ }
+ } else {
+ src_handle_table.Remove(src_handle);
+ }
+ }
+
+ // Set the handle.
+ offset = dst_msg.SetHandle(offset, dst_handle);
+ }
+ }
+
+ R_RETURN(result);
+}
+
+void CleanupSpecialData(KProcess& dst_process, u32* dst_msg_ptr, size_t dst_buffer_size) {
+ // Parse the message.
+ const MessageBuffer dst_msg(dst_msg_ptr, dst_buffer_size);
+ const MessageBuffer::MessageHeader dst_header(dst_msg);
+ const MessageBuffer::SpecialHeader dst_special_header(dst_msg, dst_header);
+
+ // Check that the size is big enough.
+ if (MessageBuffer::GetMessageBufferSize(dst_header, dst_special_header) > dst_buffer_size) {
+ return;
+ }
+
+ // Set the special header.
+ int offset = dst_msg.Set(dst_special_header);
+
+ // Clear the process id, if needed.
+ if (dst_special_header.GetHasProcessId()) {
+ offset = dst_msg.SetProcessId(offset, 0);
+ }
+
+ // Clear handles, as relevant.
+ auto& dst_handle_table = dst_process.GetHandleTable();
+ for (auto i = 0;
+ i < (dst_special_header.GetCopyHandleCount() + dst_special_header.GetMoveHandleCount());
+ ++i) {
+ const Handle handle = dst_msg.GetHandle(offset);
+
+ if (handle != Svc::InvalidHandle) {
+ dst_handle_table.Remove(handle);
+ }
+
+ offset = dst_msg.SetHandle(offset, Svc::InvalidHandle);
+ }
+}
+
+} // namespace
+
using ThreadQueueImplForKServerSessionRequest = KThreadQueue;
KServerSession::KServerSession(KernelCore& kernel)
@@ -223,12 +343,27 @@ Result KServerSession::SendReply(bool is_hle) {
// the reply has already been written in this case.
} else {
Core::Memory::Memory& memory{client_thread->GetOwnerProcess()->GetMemory()};
- KThread* server_thread{GetCurrentThreadPointer(m_kernel)};
+ KThread* server_thread = GetCurrentThreadPointer(m_kernel);
+ KProcess& src_process = *client_thread->GetOwnerProcess();
+ KProcess& dst_process = *server_thread->GetOwnerProcess();
UNIMPLEMENTED_IF(server_thread->GetOwnerProcess() != client_thread->GetOwnerProcess());
- auto* src_msg_buffer = memory.GetPointer(server_thread->GetTlsAddress());
- auto* dst_msg_buffer = memory.GetPointer(client_message);
+ auto* src_msg_buffer = memory.GetPointer<u32>(server_thread->GetTlsAddress());
+ auto* dst_msg_buffer = memory.GetPointer<u32>(client_message);
std::memcpy(dst_msg_buffer, src_msg_buffer, client_buffer_size);
+
+ // Translate special header ad-hoc.
+ MessageBuffer src_msg(src_msg_buffer, client_buffer_size);
+ MessageBuffer::MessageHeader src_header(src_msg);
+ MessageBuffer::SpecialHeader src_special_header(src_msg, src_header);
+ if (src_header.GetHasSpecialHeader()) {
+ MessageBuffer dst_msg(dst_msg_buffer, client_buffer_size);
+ result = ProcessMessageSpecialData<true>(dst_process, src_process, *server_thread,
+ dst_msg, src_msg, src_special_header);
+ if (R_FAILED(result)) {
+ CleanupSpecialData(dst_process, dst_msg_buffer, client_buffer_size);
+ }
+ }
}
} else {
result = ResultSessionClosed;
@@ -330,12 +465,28 @@ Result KServerSession::ReceiveRequest(std::shared_ptr<Service::HLERequestContext
->PopulateFromIncomingCommandBuffer(client_thread->GetOwnerProcess()->GetHandleTable(),
cmd_buf);
} else {
- KThread* server_thread{GetCurrentThreadPointer(m_kernel)};
- UNIMPLEMENTED_IF(server_thread->GetOwnerProcess() != client_thread->GetOwnerProcess());
+ KThread* server_thread = GetCurrentThreadPointer(m_kernel);
+ KProcess& src_process = *client_thread->GetOwnerProcess();
+ KProcess& dst_process = *server_thread->GetOwnerProcess();
+ UNIMPLEMENTED_IF(client_thread->GetOwnerProcess() != server_thread->GetOwnerProcess());
- auto* src_msg_buffer = memory.GetPointer(client_message);
- auto* dst_msg_buffer = memory.GetPointer(server_thread->GetTlsAddress());
+ auto* src_msg_buffer = memory.GetPointer<u32>(client_message);
+ auto* dst_msg_buffer = memory.GetPointer<u32>(server_thread->GetTlsAddress());
std::memcpy(dst_msg_buffer, src_msg_buffer, client_buffer_size);
+
+ // Translate special header ad-hoc.
+ // TODO: fix this mess
+ MessageBuffer src_msg(src_msg_buffer, client_buffer_size);
+ MessageBuffer::MessageHeader src_header(src_msg);
+ MessageBuffer::SpecialHeader src_special_header(src_msg, src_header);
+ if (src_header.GetHasSpecialHeader()) {
+ MessageBuffer dst_msg(dst_msg_buffer, client_buffer_size);
+ Result res = ProcessMessageSpecialData<false>(dst_process, src_process, *client_thread,
+ dst_msg, src_msg, src_special_header);
+ if (R_FAILED(res)) {
+ CleanupSpecialData(dst_process, dst_msg_buffer, client_buffer_size);
+ }
+ }
}
// We succeeded.