diff options
224 files changed, 7046 insertions, 4891 deletions
diff --git a/.ci/scripts/windows/upload.ps1 b/.ci/scripts/windows/upload.ps1 index b9b8b4af8..62483607b 100644 --- a/.ci/scripts/windows/upload.ps1 +++ b/.ci/scripts/windows/upload.ps1 @@ -39,6 +39,7 @@ mkdir "artifacts" # Build a tar.xz for the source of the release Copy-Item .\license.txt -Destination $MSVC_SOURCE Copy-Item .\README.md -Destination $MSVC_SOURCE +Copy-Item .\CMakeLists.txt -Destination $MSVC_SOURCE Copy-Item .\src -Recurse -Destination $MSVC_SOURCE Copy-Item .\externals -Recurse -Destination $MSVC_SOURCE Copy-Item .\dist -Recurse -Destination $MSVC_SOURCE @@ -60,4 +61,4 @@ Get-ChildItem "$BUILD_DIR" -Recurse -Filter "QtWebEngineProcess*.exe" | Copy-Ite Get-ChildItem . -Filter "*.zip" | Copy-Item -destination "artifacts" Get-ChildItem . -Filter "*.7z" | Copy-Item -destination "artifacts" -Get-ChildItem . -Filter "*.tar.xz" | Copy-Item -destination "artifacts"
\ No newline at end of file +Get-ChildItem . -Filter "*.tar.xz" | Copy-Item -destination "artifacts" diff --git a/.ci/yuzu-patreon-step2.yml b/.ci/yuzu-patreon-step2.yml index 35c5fbe36..1b36f63e1 100644 --- a/.ci/yuzu-patreon-step2.yml +++ b/.ci/yuzu-patreon-step2.yml @@ -10,6 +10,7 @@ stages: jobs: - job: format displayName: 'clang' + continueOnError: true pool: vmImage: ubuntu-latest steps: diff --git a/CMakeLists.txt b/CMakeLists.txt index 9b3b0d6d5..118572c03 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,7 +29,7 @@ option(ENABLE_VULKAN "Enables Vulkan backend" ON) option(USE_DISCORD_PRESENCE "Enables Discord Rich Presence" OFF) -if(NOT EXISTS ${PROJECT_SOURCE_DIR}/.git/hooks/pre-commit) +if(EXISTS ${PROJECT_SOURCE_DIR}/hooks/pre-commit AND NOT EXISTS ${PROJECT_SOURCE_DIR}/.git/hooks/pre-commit) message(STATUS "Copying pre-commit hook") file(COPY hooks/pre-commit DESTINATION ${PROJECT_SOURCE_DIR}/.git/hooks) @@ -49,7 +49,10 @@ function(check_submodules_present) endif() endforeach() endfunction() -check_submodules_present() + +if(EXISTS ${PROJECT_SOURCE_DIR}/.gitmodules) + check_submodules_present() +endif() configure_file(${PROJECT_SOURCE_DIR}/dist/compatibility_list/compatibility_list.qrc ${PROJECT_BINARY_DIR}/dist/compatibility_list/compatibility_list.qrc diff --git a/externals/httplib/httplib.h b/externals/httplib/httplib.h index dd9afe693..fa2edcc94 100644 --- a/externals/httplib/httplib.h +++ b/externals/httplib/httplib.h @@ -1,357 +1,768 @@ // // httplib.h // -// Copyright (c) 2017 Yuji Hirose. All rights reserved. +// Copyright (c) 2019 Yuji Hirose. All rights reserved. // MIT License // -#ifndef _CPPHTTPLIB_HTTPLIB_H_ -#define _CPPHTTPLIB_HTTPLIB_H_ +#ifndef CPPHTTPLIB_HTTPLIB_H +#define CPPHTTPLIB_HTTPLIB_H + +/* + * Configuration + */ +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 5 +#endif + +#ifndef CPPHTTPLIB_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_READ_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH +#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT +#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20 +#endif + +#ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH (std::numeric_limits<size_t>::max)() +#endif + +#ifndef CPPHTTPLIB_RECV_BUFSIZ +#define CPPHTTPLIB_RECV_BUFSIZ size_t(4096u) +#endif + +#ifndef CPPHTTPLIB_THREAD_POOL_COUNT +#define CPPHTTPLIB_THREAD_POOL_COUNT 8 +#endif #ifdef _WIN32 #ifndef _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS -#endif +#endif //_CRT_SECURE_NO_WARNINGS + #ifndef _CRT_NONSTDC_NO_DEPRECATE #define _CRT_NONSTDC_NO_DEPRECATE +#endif //_CRT_NONSTDC_NO_DEPRECATE + +#if defined(_MSC_VER) +#ifdef _WIN64 +typedef __int64 ssize_t; +#else +typedef int ssize_t; #endif -#if defined(_MSC_VER) && _MSC_VER < 1900 +#if _MSC_VER < 1900 #define snprintf _snprintf_s #endif +#endif // _MSC_VER #ifndef S_ISREG -#define S_ISREG(m) (((m)&S_IFREG)==S_IFREG) -#endif +#define S_ISREG(m) (((m)&S_IFREG) == S_IFREG) +#endif // S_ISREG + #ifndef S_ISDIR -#define S_ISDIR(m) (((m)&S_IFDIR)==S_IFDIR) -#endif +#define S_ISDIR(m) (((m)&S_IFDIR) == S_IFDIR) +#endif // S_ISDIR + +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX #include <io.h> #include <winsock2.h> #include <ws2tcpip.h> -#undef min -#undef max +#ifndef WSA_FLAG_NO_HANDLE_INHERIT +#define WSA_FLAG_NO_HANDLE_INHERIT 0x80 +#endif + +#ifdef _MSC_VER +#pragma comment(lib, "ws2_32.lib") +#endif #ifndef strcasecmp #define strcasecmp _stricmp -#endif +#endif // strcasecmp typedef SOCKET socket_t; -#else -#include <pthread.h> -#include <unistd.h> -#include <netdb.h> +#ifdef CPPHTTPLIB_USE_POLL +#define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) +#endif + +#else // not _WIN32 + +#include <arpa/inet.h> #include <cstring> +#include <netdb.h> #include <netinet/in.h> -#include <arpa/inet.h> +#ifdef CPPHTTPLIB_USE_POLL +#include <poll.h> +#endif +#include <pthread.h> #include <signal.h> -#include <sys/socket.h> #include <sys/select.h> +#include <sys/socket.h> +#include <unistd.h> typedef int socket_t; #define INVALID_SOCKET (-1) -#endif +#endif //_WIN32 +#include <assert.h> +#include <atomic> +#include <condition_variable> +#include <errno.h> +#include <fcntl.h> #include <fstream> #include <functional> +#include <list> #include <map> #include <memory> #include <mutex> +#include <random> #include <regex> #include <string> -#include <thread> #include <sys/stat.h> -#include <fcntl.h> -#include <assert.h> +#include <thread> #ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#include <openssl/err.h> #include <openssl/ssl.h> +#include <openssl/x509v3.h> + +// #if OPENSSL_VERSION_NUMBER < 0x1010100fL +// #error Sorry, OpenSSL versions prior to 1.1.1 are not supported +// #endif + +#if OPENSSL_VERSION_NUMBER < 0x10100000L +#include <openssl/crypto.h> +inline const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *asn1) { + return M_ASN1_STRING_data(asn1); +} +#endif #endif #ifdef CPPHTTPLIB_ZLIB_SUPPORT #include <zlib.h> #endif -/* - * Configuration - */ -#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 -#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND 0 - -namespace httplib -{ +namespace httplib { namespace detail { struct ci { - bool operator() (const std::string & s1, const std::string & s2) const { - return std::lexicographical_compare( - s1.begin(), s1.end(), - s2.begin(), s2.end(), - [](char c1, char c2) { - return ::tolower(c1) < ::tolower(c2); - }); - } + bool operator()(const std::string &s1, const std::string &s2) const { + return std::lexicographical_compare( + s1.begin(), s1.end(), s2.begin(), s2.end(), + [](char c1, char c2) { return ::tolower(c1) < ::tolower(c2); }); + } }; } // namespace detail enum class HttpVersion { v1_0 = 0, v1_1 }; -typedef std::multimap<std::string, std::string, detail::ci> Headers; +typedef std::multimap<std::string, std::string, detail::ci> Headers; + +typedef std::multimap<std::string, std::string> Params; +typedef std::smatch Match; + +typedef std::function<void(const char *data, size_t data_len)> DataSink; -template<typename uint64_t, typename... Args> -std::pair<std::string, std::string> make_range_header(uint64_t value, Args... args); +typedef std::function<void()> Done; -typedef std::multimap<std::string, std::string> Params; -typedef std::smatch Match; -typedef std::function<void (uint64_t current, uint64_t total)> Progress; +typedef std::function<void(size_t offset, size_t length, DataSink sink, + Done done)> + ContentProvider; + +typedef std::function<bool(const char *data, size_t data_length, size_t offset, + uint64_t content_length)> + ContentReceiver; + +typedef std::function<bool(uint64_t current, uint64_t total)> Progress; + +struct Response; +typedef std::function<bool(const Response &response)> ResponseHandler; struct MultipartFile { - std::string filename; - std::string content_type; - size_t offset = 0; - size_t length = 0; + std::string filename; + std::string content_type; + size_t offset = 0; + size_t length = 0; }; typedef std::multimap<std::string, MultipartFile> MultipartFiles; +struct MultipartFormData { + std::string name; + std::string content; + std::string filename; + std::string content_type; +}; +typedef std::vector<MultipartFormData> MultipartFormDataItems; + +typedef std::pair<ssize_t, ssize_t> Range; +typedef std::vector<Range> Ranges; + struct Request { - std::string version; - std::string method; - std::string target; - std::string path; - Headers headers; - std::string body; - Params params; - MultipartFiles files; - Match matches; - - Progress progress; - - bool has_header(const char* key) const; - std::string get_header_value(const char* key) const; - void set_header(const char* key, const char* val); - - bool has_param(const char* key) const; - std::string get_param_value(const char* key) const; - - bool has_file(const char* key) const; - MultipartFile get_file_value(const char* key) const; + std::string method; + std::string path; + Headers headers; + std::string body; + + // for server + std::string version; + std::string target; + Params params; + MultipartFiles files; + Ranges ranges; + Match matches; + + // for client + size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT; + ResponseHandler response_handler; + ContentReceiver content_receiver; + Progress progress; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + const SSL *ssl; +#endif + + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); + + bool has_param(const char *key) const; + std::string get_param_value(const char *key, size_t id = 0) const; + size_t get_param_value_count(const char *key) const; + + bool has_file(const char *key) const; + MultipartFile get_file_value(const char *key) const; }; struct Response { - std::string version; - int status; - Headers headers; - std::string body; + std::string version; + int status; + Headers headers; + std::string body; - bool has_header(const char* key) const; - std::string get_header_value(const char* key) const; - void set_header(const char* key, const char* val); + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); - void set_redirect(const char* uri); - void set_content(const char* s, size_t n, const char* content_type); - void set_content(const std::string& s, const char* content_type); + void set_redirect(const char *uri); + void set_content(const char *s, size_t n, const char *content_type); + void set_content(const std::string &s, const char *content_type); - Response() : status(-1) {} + void set_content_provider( + size_t length, + std::function<void(size_t offset, size_t length, DataSink sink)> provider, + std::function<void()> resource_releaser = [] {}); + + void set_chunked_content_provider( + std::function<void(size_t offset, DataSink sink, Done done)> provider, + std::function<void()> resource_releaser = [] {}); + + Response() : status(-1), content_provider_resource_length(0) {} + + ~Response() { + if (content_provider_resource_releaser) { + content_provider_resource_releaser(); + } + } + + size_t content_provider_resource_length; + ContentProvider content_provider; + std::function<void()> content_provider_resource_releaser; }; class Stream { public: - virtual ~Stream() {} - virtual int read(char* ptr, size_t size) = 0; - virtual int write(const char* ptr, size_t size1) = 0; - virtual int write(const char* ptr) = 0; - virtual std::string get_remote_addr() = 0; - - template <typename ...Args> - void write_format(const char* fmt, const Args& ...args); + virtual ~Stream() {} + virtual int read(char *ptr, size_t size) = 0; + virtual int write(const char *ptr, size_t size1) = 0; + virtual int write(const char *ptr) = 0; + virtual int write(const std::string &s) = 0; + virtual std::string get_remote_addr() const = 0; + + template <typename... Args> + int write_format(const char *fmt, const Args &... args); }; class SocketStream : public Stream { public: - SocketStream(socket_t sock); - virtual ~SocketStream(); + SocketStream(socket_t sock); + virtual ~SocketStream(); + + virtual int read(char *ptr, size_t size); + virtual int write(const char *ptr, size_t size); + virtual int write(const char *ptr); + virtual int write(const std::string &s); + virtual std::string get_remote_addr() const; + +private: + socket_t sock_; +}; + +class BufferStream : public Stream { +public: + BufferStream() {} + virtual ~BufferStream() {} + + virtual int read(char *ptr, size_t size); + virtual int write(const char *ptr, size_t size); + virtual int write(const char *ptr); + virtual int write(const std::string &s); + virtual std::string get_remote_addr() const; + + const std::string &get_buffer() const; + +private: + std::string buffer; +}; + +class TaskQueue { +public: + TaskQueue() {} + virtual ~TaskQueue() {} + virtual void enqueue(std::function<void()> fn) = 0; + virtual void shutdown() = 0; +}; - virtual int read(char* ptr, size_t size); - virtual int write(const char* ptr, size_t size); - virtual int write(const char* ptr); - virtual std::string get_remote_addr(); +#if CPPHTTPLIB_THREAD_POOL_COUNT > 0 +class ThreadPool : public TaskQueue { +public: + ThreadPool(size_t n) : shutdown_(false) { + while (n) { + auto t = std::make_shared<std::thread>(worker(*this)); + threads_.push_back(t); + n--; + } + } + + ThreadPool(const ThreadPool &) = delete; + virtual ~ThreadPool() {} + + virtual void enqueue(std::function<void()> fn) override { + std::unique_lock<std::mutex> lock(mutex_); + jobs_.push_back(fn); + cond_.notify_one(); + } + + virtual void shutdown() override { + // Stop all worker threads... + { + std::unique_lock<std::mutex> lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... + for (auto t : threads_) { + t->join(); + } + } private: - socket_t sock_; + struct worker { + worker(ThreadPool &pool) : pool_(pool) {} + + void operator()() { + for (;;) { + std::function<void()> fn; + { + std::unique_lock<std::mutex> lock(pool_.mutex_); + + pool_.cond_.wait( + lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast<bool>(fn)); + fn(); + } + } + + ThreadPool &pool_; + }; + friend struct worker; + + std::vector<std::shared_ptr<std::thread>> threads_; + std::list<std::function<void()>> jobs_; + + bool shutdown_; + + std::condition_variable cond_; + std::mutex mutex_; +}; +#else +class Threads : public TaskQueue { +public: + Threads() : running_threads_(0) {} + virtual ~Threads() {} + + virtual void enqueue(std::function<void()> fn) override { + std::thread([=]() { + { + std::lock_guard<std::mutex> guard(running_threads_mutex_); + running_threads_++; + } + + fn(); + + { + std::lock_guard<std::mutex> guard(running_threads_mutex_); + running_threads_--; + } + }).detach(); + } + + virtual void shutdown() override { + for (;;) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + std::lock_guard<std::mutex> guard(running_threads_mutex_); + if (!running_threads_) { break; } + } + } + +private: + std::mutex running_threads_mutex_; + int running_threads_; }; +#endif class Server { public: - typedef std::function<void (const Request&, Response&)> Handler; - typedef std::function<void (const Request&, const Response&)> Logger; + typedef std::function<void(const Request &, Response &)> Handler; + typedef std::function<void(const Request &, const Response &)> Logger; + + Server(); - Server(); + virtual ~Server(); - virtual ~Server(); + virtual bool is_valid() const; - virtual bool is_valid() const; + Server &Get(const char *pattern, Handler handler); + Server &Post(const char *pattern, Handler handler); - Server& Get(const char* pattern, Handler handler); - Server& Post(const char* pattern, Handler handler); + Server &Put(const char *pattern, Handler handler); + Server &Patch(const char *pattern, Handler handler); + Server &Delete(const char *pattern, Handler handler); + Server &Options(const char *pattern, Handler handler); - Server& Put(const char* pattern, Handler handler); - Server& Delete(const char* pattern, Handler handler); - Server& Options(const char* pattern, Handler handler); + bool set_base_dir(const char *path); + void set_file_request_handler(Handler handler); - bool set_base_dir(const char* path); + void set_error_handler(Handler handler); + void set_logger(Logger logger); - void set_error_handler(Handler handler); - void set_logger(Logger logger); + void set_keep_alive_max_count(size_t count); + void set_payload_max_length(size_t length); - void set_keep_alive_max_count(size_t count); + int bind_to_any_port(const char *host, int socket_flags = 0); + bool listen_after_bind(); - int bind_to_any_port(const char* host, int socket_flags = 0); - bool listen_after_bind(); + bool listen(const char *host, int port, int socket_flags = 0); - bool listen(const char* host, int port, int socket_flags = 0); + bool is_running() const; + void stop(); - bool is_running() const; - void stop(); + std::function<TaskQueue *(void)> new_task_queue; protected: - bool process_request(Stream& strm, bool last_connection, bool& connection_close); + bool process_request(Stream &strm, bool last_connection, + bool &connection_close, + std::function<void(Request &)> setup_request); - size_t keep_alive_max_count_; + size_t keep_alive_max_count_; + size_t payload_max_length_; private: - typedef std::vector<std::pair<std::regex, Handler>> Handlers; - - socket_t create_server_socket(const char* host, int port, int socket_flags) const; - int bind_internal(const char* host, int port, int socket_flags); - bool listen_internal(); - - bool routing(Request& req, Response& res); - bool handle_file_request(Request& req, Response& res); - bool dispatch_request(Request& req, Response& res, Handlers& handlers); - - bool parse_request_line(const char* s, Request& req); - void write_response(Stream& strm, bool last_connection, const Request& req, Response& res); - - virtual bool read_and_close_socket(socket_t sock); - - bool is_running_; - socket_t svr_sock_; - std::string base_dir_; - Handlers get_handlers_; - Handlers post_handlers_; - Handlers put_handlers_; - Handlers delete_handlers_; - Handlers options_handlers_; - Handler error_handler_; - Logger logger_; - - // TODO: Use thread pool... - std::mutex running_threads_mutex_; - int running_threads_; + typedef std::vector<std::pair<std::regex, Handler>> Handlers; + + socket_t create_server_socket(const char *host, int port, + int socket_flags) const; + int bind_internal(const char *host, int port, int socket_flags); + bool listen_internal(); + + bool routing(Request &req, Response &res); + bool handle_file_request(Request &req, Response &res); + bool dispatch_request(Request &req, Response &res, Handlers &handlers); + + bool parse_request_line(const char *s, Request &req); + bool write_response(Stream &strm, bool last_connection, const Request &req, + Response &res); + bool write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type); + + virtual bool process_and_close_socket(socket_t sock); + + std::atomic<bool> is_running_; + std::atomic<socket_t> svr_sock_; + std::string base_dir_; + Handler file_request_handler_; + Handlers get_handlers_; + Handlers post_handlers_; + Handlers put_handlers_; + Handlers patch_handlers_; + Handlers delete_handlers_; + Handlers options_handlers_; + Handler error_handler_; + Logger logger_; }; class Client { public: - Client( - const char* host, - int port = 80, - size_t timeout_sec = 300); + Client(const char *host, int port = 80, time_t timeout_sec = 300); + + virtual ~Client(); + + virtual bool is_valid() const; + + std::shared_ptr<Response> Get(const char *path); + + std::shared_ptr<Response> Get(const char *path, const Headers &headers); + + std::shared_ptr<Response> Get(const char *path, Progress progress); + + std::shared_ptr<Response> Get(const char *path, const Headers &headers, + Progress progress); - virtual ~Client(); + std::shared_ptr<Response> Get(const char *path, + ContentReceiver content_receiver); - virtual bool is_valid() const; + std::shared_ptr<Response> Get(const char *path, const Headers &headers, + ContentReceiver content_receiver); - std::shared_ptr<Response> Get(const char* path, Progress progress = nullptr); - std::shared_ptr<Response> Get(const char* path, const Headers& headers, Progress progress = nullptr); + std::shared_ptr<Response> + Get(const char *path, ContentReceiver content_receiver, Progress progress); - std::shared_ptr<Response> Head(const char* path); - std::shared_ptr<Response> Head(const char* path, const Headers& headers); + std::shared_ptr<Response> Get(const char *path, const Headers &headers, + ContentReceiver content_receiver, + Progress progress); - std::shared_ptr<Response> Post(const char* path, const std::string& body, const char* content_type); - std::shared_ptr<Response> Post(const char* path, const Headers& headers, const std::string& body, const char* content_type); + std::shared_ptr<Response> Get(const char *path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); - std::shared_ptr<Response> Post(const char* path, const Params& params); - std::shared_ptr<Response> Post(const char* path, const Headers& headers, const Params& params); + std::shared_ptr<Response> Get(const char *path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress); - std::shared_ptr<Response> Put(const char* path, const std::string& body, const char* content_type); - std::shared_ptr<Response> Put(const char* path, const Headers& headers, const std::string& body, const char* content_type); + std::shared_ptr<Response> Head(const char *path); - std::shared_ptr<Response> Delete(const char* path); - std::shared_ptr<Response> Delete(const char* path, const Headers& headers); + std::shared_ptr<Response> Head(const char *path, const Headers &headers); - std::shared_ptr<Response> Options(const char* path); - std::shared_ptr<Response> Options(const char* path, const Headers& headers); + std::shared_ptr<Response> Post(const char *path, const std::string &body, + const char *content_type); - bool send(Request& req, Response& res); + std::shared_ptr<Response> Post(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr<Response> Post(const char *path, const Params ¶ms); + + std::shared_ptr<Response> Post(const char *path, const Headers &headers, + const Params ¶ms); + + std::shared_ptr<Response> Post(const char *path, + const MultipartFormDataItems &items); + + std::shared_ptr<Response> Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items); + + std::shared_ptr<Response> Put(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr<Response> Put(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr<Response> Patch(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr<Response> Patch(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr<Response> Delete(const char *path); + + std::shared_ptr<Response> Delete(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr<Response> Delete(const char *path, const Headers &headers); + + std::shared_ptr<Response> Delete(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr<Response> Options(const char *path); + + std::shared_ptr<Response> Options(const char *path, const Headers &headers); + + bool send(const Request &req, Response &res); + + bool send(const std::vector<Request> &requests, + std::vector<Response> &responses); + + void set_keep_alive_max_count(size_t count); + + void follow_location(bool on); protected: - bool process_request(Stream& strm, Request& req, Response& res, bool& connection_close); + bool process_request(Stream &strm, const Request &req, Response &res, + bool last_connection, bool &connection_close); - const std::string host_; - const int port_; - size_t timeout_sec_; - const std::string host_and_port_; + const std::string host_; + const int port_; + time_t timeout_sec_; + const std::string host_and_port_; + size_t keep_alive_max_count_; + size_t follow_location_; private: - socket_t create_client_socket() const; - bool read_response_line(Stream& strm, Response& res); - void write_request(Stream& strm, Request& req); - - virtual bool read_and_close_socket(socket_t sock, Request& req, Response& res); + socket_t create_client_socket() const; + bool read_response_line(Stream &strm, Response &res); + void write_request(Stream &strm, const Request &req, bool last_connection); + bool redirect(const Request &req, Response &res); + + virtual bool process_and_close_socket( + socket_t sock, size_t request_count, + std::function<bool(Stream &strm, bool last_connection, + bool &connection_close)> + callback); + + virtual bool is_ssl() const; }; +inline void Get(std::vector<Request> &requests, const char *path, + const Headers &headers) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + requests.emplace_back(std::move(req)); +} + +inline void Get(std::vector<Request> &requests, const char *path) { + Get(requests, path, Headers()); +} + +inline void Post(std::vector<Request> &requests, const char *path, + const Headers &headers, const std::string &body, + const char *content_type) { + Request req; + req.method = "POST"; + req.path = path; + req.headers = headers; + req.headers.emplace("Content-Type", content_type); + req.body = body; + requests.emplace_back(std::move(req)); +} + +inline void Post(std::vector<Request> &requests, const char *path, + const std::string &body, const char *content_type) { + Post(requests, path, Headers(), body, content_type); +} + #ifdef CPPHTTPLIB_OPENSSL_SUPPORT class SSLSocketStream : public Stream { public: - SSLSocketStream(socket_t sock, SSL* ssl); - virtual ~SSLSocketStream(); + SSLSocketStream(socket_t sock, SSL *ssl); + virtual ~SSLSocketStream(); - virtual int read(char* ptr, size_t size); - virtual int write(const char* ptr, size_t size); - virtual int write(const char* ptr); - virtual std::string get_remote_addr(); + virtual int read(char *ptr, size_t size); + virtual int write(const char *ptr, size_t size); + virtual int write(const char *ptr); + virtual int write(const std::string &s); + virtual std::string get_remote_addr() const; private: - socket_t sock_; - SSL* ssl_; + socket_t sock_; + SSL *ssl_; }; class SSLServer : public Server { public: - SSLServer( - const char* cert_path, const char* private_key_path); + SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path = nullptr, + const char *client_ca_cert_dir_path = nullptr); - virtual ~SSLServer(); + virtual ~SSLServer(); - virtual bool is_valid() const; + virtual bool is_valid() const; private: - virtual bool read_and_close_socket(socket_t sock); + virtual bool process_and_close_socket(socket_t sock); - SSL_CTX* ctx_; - std::mutex ctx_mutex_; + SSL_CTX *ctx_; + std::mutex ctx_mutex_; }; class SSLClient : public Client { public: - SSLClient( - const char* host, - int port = 80, - size_t timeout_sec = 300); + SSLClient(const char *host, int port = 443, time_t timeout_sec = 300, + const char *client_cert_path = nullptr, + const char *client_key_path = nullptr); - virtual ~SSLClient(); + virtual ~SSLClient(); - virtual bool is_valid() const; + virtual bool is_valid() const; -private: - virtual bool read_and_close_socket(socket_t sock, Request& req, Response& res); + void set_ca_cert_path(const char *ca_ceert_file_path, + const char *ca_cert_dir_path = nullptr); + void enable_server_certificate_verification(bool enabled); - SSL_CTX* ctx_; - std::mutex ctx_mutex_; + long get_openssl_verify_result() const; + + SSL_CTX* ssl_context() const noexcept; + +private: + virtual bool process_and_close_socket( + socket_t sock, size_t request_count, + std::function<bool(Stream &strm, bool last_connection, + bool &connection_close)> + callback); + virtual bool is_ssl() const; + + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; + bool check_host_name(const char *pattern, size_t pattern_len) const; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; + std::vector<std::string> host_components_; + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + bool server_certificate_verification_ = false; + long verify_result_ = 0; }; #endif @@ -360,913 +771,1237 @@ private: */ namespace detail { -template <class Fn> -void split(const char* b, const char* e, char d, Fn fn) -{ - int i = 0; - int beg = 0; +inline bool is_hex(char c, int &v) { + if (0x20 <= c && isdigit(c)) { + v = c - '0'; + return true; + } else if ('A' <= c && c <= 'F') { + v = c - 'A' + 10; + return true; + } else if ('a' <= c && c <= 'f') { + v = c - 'a' + 10; + return true; + } + return false; +} - while (e ? (b + i != e) : (b[i] != '\0')) { - if (b[i] == d) { - fn(&b[beg], &b[i]); - beg = i + 1; - } - i++; +inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, + int &val) { + if (i >= s.size()) { return false; } + + val = 0; + for (; cnt; i++, cnt--) { + if (!s[i]) { return false; } + int v = 0; + if (is_hex(s[i], v)) { + val = val * 16 + v; + } else { + return false; + } + } + return true; +} + +inline std::string from_i_to_hex(size_t n) { + const char *charset = "0123456789abcdef"; + std::string ret; + do { + ret = charset[n & 15] + ret; + n >>= 4; + } while (n > 0); + return ret; +} + +inline size_t to_utf8(int code, char *buff) { + if (code < 0x0080) { + buff[0] = (code & 0x7F); + return 1; + } else if (code < 0x0800) { + buff[0] = (0xC0 | ((code >> 6) & 0x1F)); + buff[1] = (0x80 | (code & 0x3F)); + return 2; + } else if (code < 0xD800) { + buff[0] = (0xE0 | ((code >> 12) & 0xF)); + buff[1] = (0x80 | ((code >> 6) & 0x3F)); + buff[2] = (0x80 | (code & 0x3F)); + return 3; + } else if (code < 0xE000) { // D800 - DFFF is invalid... + return 0; + } else if (code < 0x10000) { + buff[0] = (0xE0 | ((code >> 12) & 0xF)); + buff[1] = (0x80 | ((code >> 6) & 0x3F)); + buff[2] = (0x80 | (code & 0x3F)); + return 3; + } else if (code < 0x110000) { + buff[0] = (0xF0 | ((code >> 18) & 0x7)); + buff[1] = (0x80 | ((code >> 12) & 0x3F)); + buff[2] = (0x80 | ((code >> 6) & 0x3F)); + buff[3] = (0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED + return 0; +} + +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c +inline std::string base64_encode(const std::string &in) { + static const auto lookup = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + std::string out; + out.reserve(in.size()); + + int val = 0; + int valb = -6; + + for (uint8_t c : in) { + val = (val << 8) + c; + valb += 8; + while (valb >= 0) { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; } + } + + if (valb > -6) { out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); } + + while (out.size() % 4) { + out.push_back('='); + } + + return out; +} - if (i) { - fn(&b[beg], &b[i]); +inline bool is_file(const std::string &path) { + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); +} + +inline bool is_dir(const std::string &path) { + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); +} + +inline bool is_valid_path(const std::string &path) { + size_t level = 0; + size_t i = 0; + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + + while (i < path.size()) { + // Read component + auto beg = i; + while (i < path.size() && path[i] != '/') { + i++; } + + auto len = i - beg; + assert(len > 0); + + if (!path.compare(beg, len, ".")) { + ; + } else if (!path.compare(beg, len, "..")) { + if (level == 0) { return false; } + level--; + } else { + level++; + } + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + } + + return true; +} + +inline void read_file(const std::string &path, std::string &out) { + std::ifstream fs(path, std::ios_base::binary); + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + out.resize(static_cast<size_t>(size)); + fs.read(&out[0], size); +} + +inline std::string file_extension(const std::string &path) { + std::smatch m; + auto pat = std::regex("\\.([a-zA-Z0-9]+)$"); + if (std::regex_search(path, m, pat)) { return m[1].str(); } + return std::string(); +} + +template <class Fn> void split(const char *b, const char *e, char d, Fn fn) { + int i = 0; + int beg = 0; + + while (e ? (b + i != e) : (b[i] != '\0')) { + if (b[i] == d) { + fn(&b[beg], &b[i]); + beg = i + 1; + } + i++; + } + + if (i) { fn(&b[beg], &b[i]); } } // NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` // to store data. The call can set memory on stack for performance. class stream_line_reader { public: - stream_line_reader(Stream& strm, char* fixed_buffer, size_t fixed_buffer_size) - : strm_(strm) - , fixed_buffer_(fixed_buffer) - , fixed_buffer_size_(fixed_buffer_size) { + stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) + : strm_(strm), fixed_buffer_(fixed_buffer), + fixed_buffer_size_(fixed_buffer_size) {} + + const char *ptr() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_; + } else { + return glowable_buffer_.data(); } + } - const char* ptr() const { - if (glowable_buffer_.empty()) { - return fixed_buffer_; - } else { - return glowable_buffer_.data(); - } + size_t size() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_used_size_; + } else { + return glowable_buffer_.size(); } + } - bool getline() { - fixed_buffer_used_size_ = 0; - glowable_buffer_.clear(); - - for (size_t i = 0; ; i++) { - char byte; - auto n = strm_.read(&byte, 1); - - if (n < 0) { - return false; - } else if (n == 0) { - if (i == 0) { - return false; - } else { - break; - } - } + bool getline() { + fixed_buffer_used_size_ = 0; + glowable_buffer_.clear(); - append(byte); + for (size_t i = 0;; i++) { + char byte; + auto n = strm_.read(&byte, 1); - if (byte == '\n') { - break; - } + if (n < 0) { + return false; + } else if (n == 0) { + if (i == 0) { + return false; + } else { + break; } + } - return true; - } + append(byte); -private: - void append(char c) { - if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { - fixed_buffer_[fixed_buffer_used_size_++] = c; - fixed_buffer_[fixed_buffer_used_size_] = '\0'; - } else { - if (glowable_buffer_.empty()) { - assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); - glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); - } - glowable_buffer_ += c; - } + if (byte == '\n') { break; } } - Stream& strm_; - char* fixed_buffer_; - const size_t fixed_buffer_size_; - size_t fixed_buffer_used_size_; - std::string glowable_buffer_; + return true; + } + +private: + void append(char c) { + if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { + fixed_buffer_[fixed_buffer_used_size_++] = c; + fixed_buffer_[fixed_buffer_used_size_] = '\0'; + } else { + if (glowable_buffer_.empty()) { + assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); + glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + } + glowable_buffer_ += c; + } + } + + Stream &strm_; + char *fixed_buffer_; + const size_t fixed_buffer_size_; + size_t fixed_buffer_used_size_; + std::string glowable_buffer_; }; -inline int close_socket(socket_t sock) -{ +inline int close_socket(socket_t sock) { #ifdef _WIN32 - return closesocket(sock); + return closesocket(sock); #else - return close(sock); + return close(sock); #endif } -inline int select_read(socket_t sock, size_t sec, size_t usec) -{ - fd_set fds; - FD_ZERO(&fds); - FD_SET(sock, &fds); +inline int select_read(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN; - timeval tv; - tv.tv_sec = sec; - tv.tv_usec = usec; + auto timeout = static_cast<int>(sec * 1000 + usec / 1000); - return select(sock + 1, &fds, NULL, NULL, &tv); -} + return poll(&pfd_read, 1, timeout); +#else + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); -inline bool wait_until_socket_is_ready(socket_t sock, size_t sec, size_t usec) -{ - fd_set fdsr; - FD_ZERO(&fdsr); - FD_SET(sock, &fdsr); + timeval tv; + tv.tv_sec = static_cast<long>(sec); + tv.tv_usec = static_cast<long>(usec); - auto fdsw = fdsr; - auto fdse = fdsr; + return select(static_cast<int>(sock + 1), &fds, nullptr, nullptr, &tv); +#endif +} - timeval tv; - tv.tv_sec = sec; - tv.tv_usec = usec; +inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; - if (select(sock + 1, &fdsr, &fdsw, &fdse, &tv) < 0) { - return false; - } else if (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw)) { - int error = 0; - socklen_t len = sizeof(error); - if (getsockopt(sock, SOL_SOCKET, SO_ERROR, (char*)&error, &len) < 0 || error) { - return false; - } - } else { - return false; - } + auto timeout = static_cast<int>(sec * 1000 + usec / 1000); - return true; + if (poll(&pfd_read, 1, timeout) > 0 && + pfd_read.revents & (POLLIN | POLLOUT)) { + int error = 0; + socklen_t len = sizeof(error); + return getsockopt(sock, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) >= 0 && + !error; + } + return false; +#else + fd_set fdsr; + FD_ZERO(&fdsr); + FD_SET(sock, &fdsr); + + auto fdsw = fdsr; + auto fdse = fdsr; + + timeval tv; + tv.tv_sec = static_cast<long>(sec); + tv.tv_usec = static_cast<long>(usec); + + if (select(static_cast<int>(sock + 1), &fdsr, &fdsw, &fdse, &tv) > 0 && + (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { + int error = 0; + socklen_t len = sizeof(error); + return getsockopt(sock, SOL_SOCKET, SO_ERROR, (char *)&error, &len) >= 0 && + !error; + } + return false; +#endif } template <typename T> -inline bool read_and_close_socket(socket_t sock, size_t keep_alive_max_count, T callback) -{ - bool ret = false; - - if (keep_alive_max_count > 0) { - auto count = keep_alive_max_count; - while (count > 0 && - detail::select_read(sock, - CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, - CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) { - SocketStream strm(sock); - auto last_connection = count == 1; - auto connection_close = false; - - ret = callback(strm, last_connection, connection_close); - if (!ret || connection_close) { - break; - } +inline bool process_and_close_socket(bool is_client_request, socket_t sock, + size_t keep_alive_max_count, T callback) { + assert(keep_alive_max_count > 0); - count--; - } - } else { - SocketStream strm(sock); - auto dummy_connection_close = false; - ret = callback(strm, true, dummy_connection_close); + bool ret = false; + + if (keep_alive_max_count > 1) { + auto count = keep_alive_max_count; + while (count > 0 && + (is_client_request || + detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { + SocketStream strm(sock); + auto last_connection = count == 1; + auto connection_close = false; + + ret = callback(strm, last_connection, connection_close); + if (!ret || connection_close) { break; } + + count--; } + } else { + SocketStream strm(sock); + auto dummy_connection_close = false; + ret = callback(strm, true, dummy_connection_close); + } - close_socket(sock); - return ret; + close_socket(sock); + return ret; } -inline int shutdown_socket(socket_t sock) -{ +inline int shutdown_socket(socket_t sock) { #ifdef _WIN32 - return shutdown(sock, SD_BOTH); + return shutdown(sock, SD_BOTH); #else - return shutdown(sock, SHUT_RDWR); + return shutdown(sock, SHUT_RDWR); #endif } template <typename Fn> -socket_t create_socket(const char* host, int port, Fn fn, int socket_flags = 0) -{ +socket_t create_socket(const char *host, int port, Fn fn, + int socket_flags = 0) { #ifdef _WIN32 #define SO_SYNCHRONOUS_NONALERT 0x20 #define SO_OPENTYPE 0x7008 - int opt = SO_SYNCHRONOUS_NONALERT; - setsockopt(INVALID_SOCKET, SOL_SOCKET, SO_OPENTYPE, (char*)&opt, sizeof(opt)); + int opt = SO_SYNCHRONOUS_NONALERT; + setsockopt(INVALID_SOCKET, SOL_SOCKET, SO_OPENTYPE, (char *)&opt, + sizeof(opt)); #endif - // Get address info - struct addrinfo hints; - struct addrinfo *result; + // Get address info + struct addrinfo hints; + struct addrinfo *result; - memset(&hints, 0, sizeof(struct addrinfo)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_flags = socket_flags; - hints.ai_protocol = 0; + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = socket_flags; + hints.ai_protocol = 0; - auto service = std::to_string(port); + auto service = std::to_string(port); - if (getaddrinfo(host, service.c_str(), &hints, &result)) { - return INVALID_SOCKET; - } + if (getaddrinfo(host, service.c_str(), &hints, &result)) { + return INVALID_SOCKET; + } - for (auto rp = result; rp; rp = rp->ai_next) { - // Create a socket - auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); - if (sock == INVALID_SOCKET) { - continue; - } + for (auto rp = result; rp; rp = rp->ai_next) { + // Create a socket +#ifdef _WIN32 + auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, + nullptr, 0, WSA_FLAG_NO_HANDLE_INHERIT); +#else + auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +#endif + if (sock == INVALID_SOCKET) { continue; } - // Make 'reuse address' option available - int yes = 1; - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (char*)&yes, sizeof(yes)); +#ifndef _WIN32 + if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { continue; } +#endif - // bind or connect - if (fn(sock, *rp)) { - freeaddrinfo(result); - return sock; - } + // Make 'reuse address' option available + int yes = 1; + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&yes), sizeof(yes)); +#ifdef SO_REUSEPORT + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast<char*>(&yes), sizeof(yes)); +#endif - close_socket(sock); + // bind or connect + if (fn(sock, *rp)) { + freeaddrinfo(result); + return sock; } - freeaddrinfo(result); - return INVALID_SOCKET; + close_socket(sock); + } + + freeaddrinfo(result); + return INVALID_SOCKET; } -inline void set_nonblocking(socket_t sock, bool nonblocking) -{ +inline void set_nonblocking(socket_t sock, bool nonblocking) { #ifdef _WIN32 - auto flags = nonblocking ? 1UL : 0UL; - ioctlsocket(sock, FIONBIO, &flags); + auto flags = nonblocking ? 1UL : 0UL; + ioctlsocket(sock, FIONBIO, &flags); #else - auto flags = fcntl(sock, F_GETFL, 0); - fcntl(sock, F_SETFL, nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, + nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); #endif } -inline bool is_connection_error() -{ +inline bool is_connection_error() { #ifdef _WIN32 - return WSAGetLastError() != WSAEWOULDBLOCK; + return WSAGetLastError() != WSAEWOULDBLOCK; #else - return errno != EINPROGRESS; + return errno != EINPROGRESS; #endif } inline std::string get_remote_addr(socket_t sock) { - struct sockaddr_storage addr; - socklen_t len = sizeof(addr); + struct sockaddr_storage addr; + socklen_t len = sizeof(addr); + + if (!getpeername(sock, reinterpret_cast<struct sockaddr *>(&addr), &len)) { + char ipstr[NI_MAXHOST]; + + if (!getnameinfo(reinterpret_cast<struct sockaddr *>(&addr), len, ipstr, sizeof(ipstr), + nullptr, 0, NI_NUMERICHOST)) { + return ipstr; + } + } + + return std::string(); +} + +inline const char *find_content_type(const std::string &path) { + auto ext = file_extension(path); + if (ext == "txt") { + return "text/plain"; + } else if (ext == "html") { + return "text/html"; + } else if (ext == "css") { + return "text/css"; + } else if (ext == "jpeg" || ext == "jpg") { + return "image/jpg"; + } else if (ext == "png") { + return "image/png"; + } else if (ext == "gif") { + return "image/gif"; + } else if (ext == "svg") { + return "image/svg+xml"; + } else if (ext == "ico") { + return "image/x-icon"; + } else if (ext == "json") { + return "application/json"; + } else if (ext == "pdf") { + return "application/pdf"; + } else if (ext == "js") { + return "application/javascript"; + } else if (ext == "xml") { + return "application/xml"; + } else if (ext == "xhtml") { + return "application/xhtml+xml"; + } + return nullptr; +} + +inline const char *status_message(int status) { + switch (status) { + case 200: return "OK"; + case 206: return "Partial Content"; + case 301: return "Moved Permanently"; + case 302: return "Found"; + case 303: return "See Other"; + case 304: return "Not Modified"; + case 400: return "Bad Request"; + case 403: return "Forbidden"; + case 404: return "Not Found"; + case 413: return "Payload Too Large"; + case 414: return "Request-URI Too Long"; + case 415: return "Unsupported Media Type"; + case 416: return "Range Not Satisfiable"; + + default: + case 500: return "Internal Server Error"; + } +} - if (!getpeername(sock, (struct sockaddr*)&addr, &len)) { - char ipstr[NI_MAXHOST]; +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +inline bool can_compress(const std::string &content_type) { + return !content_type.find("text/") || content_type == "image/svg+xml" || + content_type == "application/javascript" || + content_type == "application/json" || + content_type == "application/xml" || + content_type == "application/xhtml+xml"; +} - if (!getnameinfo((struct sockaddr*)&addr, len, - ipstr, sizeof(ipstr), nullptr, 0, NI_NUMERICHOST)) { - return ipstr; - } - } +inline bool compress(std::string &content) { + z_stream strm; + strm.zalloc = Z_NULL; + strm.zfree = Z_NULL; + strm.opaque = Z_NULL; - return std::string(); -} + auto ret = deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, + Z_DEFAULT_STRATEGY); + if (ret != Z_OK) { return false; } -inline bool is_file(const std::string& path) -{ - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); -} + strm.avail_in = content.size(); + strm.next_in = const_cast<Bytef*>(reinterpret_cast<const Bytef*>(content.data())); + + std::string compressed; + + const auto bufsiz = 16384; + char buff[bufsiz]; + do { + strm.avail_out = bufsiz; + strm.next_out = reinterpret_cast<Bytef*>(buff); + ret = deflate(&strm, Z_FINISH); + assert(ret != Z_STREAM_ERROR); + compressed.append(buff, bufsiz - strm.avail_out); + } while (strm.avail_out == 0); + + assert(ret == Z_STREAM_END); + assert(strm.avail_in == 0); -inline bool is_dir(const std::string& path) -{ - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); + content.swap(compressed); + + deflateEnd(&strm); + return true; } -inline bool is_valid_path(const std::string& path) { - size_t level = 0; - size_t i = 0; +class decompressor { +public: + decompressor() { + strm.zalloc = Z_NULL; + strm.zfree = Z_NULL; + strm.opaque = Z_NULL; - // Skip slash - while (i < path.size() && path[i] == '/') { - i++; - } + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 16 specifies + // that the stream to decompress will be formatted with a gzip wrapper. + is_valid_ = inflateInit2(&strm, 16 + 15) == Z_OK; + } - while (i < path.size()) { - // Read component - auto beg = i; - while (i < path.size() && path[i] != '/') { - i++; - } + ~decompressor() { inflateEnd(&strm); } - auto len = i - beg; - assert(len > 0); + bool is_valid() const { return is_valid_; } - if (!path.compare(beg, len, ".")) { - ; - } else if (!path.compare(beg, len, "..")) { - if (level == 0) { - return false; - } - level--; - } else { - level++; - } + template <typename T> + bool decompress(const char *data, size_t data_length, T callback) { + int ret = Z_OK; - // Skip slash - while (i < path.size() && path[i] == '/') { - i++; - } - } + strm.avail_in = data_length; + strm.next_in = const_cast<Bytef*>(reinterpret_cast<const Bytef *>(data)); - return true; -} + const auto bufsiz = 16384; + char buff[bufsiz]; + do { + strm.avail_out = bufsiz; + strm.next_out = reinterpret_cast<Bytef*>(buff); + + ret = inflate(&strm, Z_NO_FLUSH); + assert(ret != Z_STREAM_ERROR); + switch (ret) { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: inflateEnd(&strm); return false; + } + + if (!callback(buff, bufsiz - strm.avail_out)) { return false; } + } while (strm.avail_out == 0); -inline void read_file(const std::string& path, std::string& out) -{ - std::ifstream fs(path, std::ios_base::binary); - fs.seekg(0, std::ios_base::end); - auto size = fs.tellg(); - fs.seekg(0); - out.resize(static_cast<size_t>(size)); - fs.read(&out[0], size); -} + return ret == Z_STREAM_END; + } -inline std::string file_extension(const std::string& path) -{ - std::smatch m; - auto pat = std::regex("\\.([a-zA-Z0-9]+)$"); - if (std::regex_search(path, m, pat)) { - return m[1].str(); - } - return std::string(); -} - -inline const char* find_content_type(const std::string& path) -{ - auto ext = file_extension(path); - if (ext == "txt") { - return "text/plain"; - } else if (ext == "html") { - return "text/html"; - } else if (ext == "css") { - return "text/css"; - } else if (ext == "jpeg" || ext == "jpg") { - return "image/jpg"; - } else if (ext == "png") { - return "image/png"; - } else if (ext == "gif") { - return "image/gif"; - } else if (ext == "svg") { - return "image/svg+xml"; - } else if (ext == "ico") { - return "image/x-icon"; - } else if (ext == "json") { - return "application/json"; - } else if (ext == "pdf") { - return "application/pdf"; - } else if (ext == "js") { - return "application/javascript"; - } else if (ext == "xml") { - return "application/xml"; - } else if (ext == "xhtml") { - return "application/xhtml+xml"; - } - return nullptr; -} - -inline const char* status_message(int status) -{ - switch (status) { - case 200: return "OK"; - case 301: return "Moved Permanently"; - case 302: return "Found"; - case 303: return "See Other"; - case 304: return "Not Modified"; - case 400: return "Bad Request"; - case 403: return "Forbidden"; - case 404: return "Not Found"; - case 415: return "Unsupported Media Type"; - default: - case 500: return "Internal Server Error"; - } +private: + bool is_valid_; + z_stream strm; +}; +#endif + +inline bool has_header(const Headers &headers, const char *key) { + return headers.find(key) != headers.end(); } -inline const char* get_header_value(const Headers& headers, const char* key, const char* def) -{ - auto it = headers.find(key); - if (it != headers.end()) { - return it->second.c_str(); - } - return def; +inline const char *get_header_value(const Headers &headers, const char *key, + size_t id = 0, const char *def = nullptr) { + auto it = headers.find(key); + std::advance(it, id); + if (it != headers.end()) { return it->second.c_str(); } + return def; } -inline int get_header_value_int(const Headers& headers, const char* key, int def) -{ - auto it = headers.find(key); - if (it != headers.end()) { - return std::stoi(it->second); - } - return def; +inline uint64_t get_header_value_uint64(const Headers &headers, const char *key, + int def = 0) { + auto it = headers.find(key); + if (it != headers.end()) { + return std::strtoull(it->second.data(), nullptr, 10); + } + return def; } -inline bool read_headers(Stream& strm, Headers& headers) -{ - static std::regex re(R"((.+?):\s*(.+?)\s*\r\n)"); +inline bool read_headers(Stream &strm, Headers &headers) { + static std::regex re(R"((.+?):\s*(.+?)\s*\r\n)"); - const auto bufsiz = 2048; - char buf[bufsiz]; + const auto bufsiz = 2048; + char buf[bufsiz]; - stream_line_reader reader(strm, buf, bufsiz); + stream_line_reader reader(strm, buf, bufsiz); - for (;;) { - if (!reader.getline()) { - return false; - } - if (!strcmp(reader.ptr(), "\r\n")) { - break; - } - std::cmatch m; - if (std::regex_match(reader.ptr(), m, re)) { - auto key = std::string(m[1]); - auto val = std::string(m[2]); - headers.emplace(key, val); - } + for (;;) { + if (!reader.getline()) { return false; } + if (!strcmp(reader.ptr(), "\r\n")) { break; } + std::cmatch m; + if (std::regex_match(reader.ptr(), m, re)) { + auto key = std::string(m[1]); + auto val = std::string(m[2]); + headers.emplace(key, val); } + } - return true; + return true; } -inline bool read_content_with_length(Stream& strm, std::string& out, size_t len, Progress progress) -{ - out.assign(len, 0); - size_t r = 0; - while (r < len){ - auto n = strm.read(&out[r], len - r); - if (n <= 0) { - return false; - } +typedef std::function<bool(const char *data, size_t data_length)> + ContentReceiverCore; - r += n; +inline bool read_content_with_length(Stream &strm, uint64_t len, + Progress progress, + ContentReceiverCore out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; - if (progress) { - progress(r, len); - } - } + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast<size_t>(len - r); + auto n = strm.read(buf, std::min(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return false; } - return true; -} + if (!out(buf, n)) { return false; } -inline bool read_content_without_length(Stream& strm, std::string& out) -{ - for (;;) { - char byte; - auto n = strm.read(&byte, 1); - if (n < 0) { - return false; - } else if (n == 0) { - return true; - } - out += byte; + r += n; + + if (progress) { + if (!progress(r, len)) { return false; } } + } - return true; + return true; } -inline bool read_content_chunked(Stream& strm, std::string& out) -{ - const auto bufsiz = 16; - char buf[bufsiz]; - - stream_line_reader reader(strm, buf, bufsiz); +inline void skip_content_with_length(Stream &strm, uint64_t len) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast<size_t>(len - r); + auto n = strm.read(buf, std::min(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return; } + r += n; + } +} - if (!reader.getline()) { - return false; +inline bool read_content_without_length(Stream &strm, ContentReceiverCore out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + for (;;) { + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); + if (n < 0) { + return false; + } else if (n == 0) { + return true; } + if (!out(buf, n)) { return false; } + } - auto chunk_len = std::stoi(reader.ptr(), 0, 16); + return true; +} - while (chunk_len > 0){ - std::string chunk; - if (!read_content_with_length(strm, chunk, chunk_len, nullptr)) { - return false; - } +inline bool read_content_chunked(Stream &strm, ContentReceiverCore out) { + const auto bufsiz = 16; + char buf[bufsiz]; - if (!reader.getline()) { - return false; - } + stream_line_reader reader(strm, buf, bufsiz); - if (strcmp(reader.ptr(), "\r\n")) { - break; - } + if (!reader.getline()) { return false; } - out += chunk; + auto chunk_len = std::stoi(reader.ptr(), 0, 16); - if (!reader.getline()) { - return false; - } - - chunk_len = std::stoi(reader.ptr(), 0, 16); + while (chunk_len > 0) { + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + return false; } - if (chunk_len == 0) { - // Reader terminator after chunks - if (!reader.getline() || strcmp(reader.ptr(), "\r\n")) - return false; - } + if (!reader.getline()) { return false; } - return true; -} + if (strcmp(reader.ptr(), "\r\n")) { break; } -template <typename T> -bool read_content(Stream& strm, T& x, Progress progress = Progress()) -{ - auto len = get_header_value_int(x.headers, "Content-Length", 0); + if (!reader.getline()) { return false; } - if (len) { - return read_content_with_length(strm, x.body, len, progress); - } else { - const auto& encoding = get_header_value(x.headers, "Transfer-Encoding", ""); + chunk_len = std::stoi(reader.ptr(), 0, 16); + } - if (!strcasecmp(encoding, "chunked")) { - return read_content_chunked(strm, x.body); - } else { - return read_content_without_length(strm, x.body); - } - } + if (chunk_len == 0) { + // Reader terminator after chunks + if (!reader.getline() || strcmp(reader.ptr(), "\r\n")) return false; + } - return true; + return true; +} + +inline bool is_chunked_transfer_encoding(const Headers &headers) { + return !strcasecmp(get_header_value(headers, "Transfer-Encoding", 0, ""), + "chunked"); } template <typename T> -inline void write_headers(Stream& strm, const T& info) -{ - for (const auto& x: info.headers) { - strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); - } - strm.write("\r\n"); -} - -inline std::string encode_url(const std::string& s) -{ - std::string result; - - for (auto i = 0; s[i]; i++) { - switch (s[i]) { - case ' ': result += "+"; break; - case '\'': result += "%27"; break; - case ',': result += "%2C"; break; - case ':': result += "%3A"; break; - case ';': result += "%3B"; break; - default: - if (s[i] < 0) { - result += '%'; - char hex[4]; - size_t len = snprintf(hex, sizeof(hex) - 1, "%02X", (unsigned char)s[i]); - assert(len == 2); - result.append(hex, len); - } else { - result += s[i]; - } - break; - } - } +bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, + Progress progress, ContentReceiverCore receiver) { - return result; -} + ContentReceiverCore out = [&](const char *buf, size_t n) { + return receiver(buf, n); + }; -inline bool is_hex(char c, int& v) -{ - if (0x20 <= c && isdigit(c)) { - v = c - '0'; - return true; - } else if ('A' <= c && c <= 'F') { - v = c - 'A' + 10; - return true; - } else if ('a' <= c && c <= 'f') { - v = c - 'a' + 10; - return true; - } +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + detail::decompressor decompressor; + + if (!decompressor.is_valid()) { + status = 500; return false; -} + } + + if (x.get_header_value("Content-Encoding") == "gzip") { + out = [&](const char *buf, size_t n) { + return decompressor.decompress( + buf, n, [&](const char *buf, size_t n) { return receiver(buf, n); }); + }; + } +#else + if (x.get_header_value("Content-Encoding") == "gzip") { + status = 415; + return false; + } +#endif -inline bool from_hex_to_i(const std::string& s, size_t i, size_t cnt, int& val) -{ - if (i >= s.size()) { - return false; - } + auto ret = true; + auto exceed_payload_max_length = false; - val = 0; - for (; cnt; i++, cnt--) { - if (!s[i]) { - return false; - } - int v = 0; - if (is_hex(s[i], v)) { - val = val * 16 + v; - } else { - return false; - } + if (is_chunked_transfer_encoding(x.headers)) { + ret = read_content_chunked(strm, out); + } else if (!has_header(x.headers, "Content-Length")) { + ret = read_content_without_length(strm, out); + } else { + auto len = get_header_value_uint64(x.headers, "Content-Length", 0); + if (len > payload_max_length) { + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } else if (len > 0) { + ret = read_content_with_length(strm, len, progress, out); } - return true; + } + + if (!ret) { status = exceed_payload_max_length ? 413 : 400; } + + return ret; } -inline size_t to_utf8(int code, char* buff) -{ - if (code < 0x0080) { - buff[0] = (code & 0x7F); - return 1; - } else if (code < 0x0800) { - buff[0] = (0xC0 | ((code >> 6) & 0x1F)); - buff[1] = (0x80 | (code & 0x3F)); - return 2; - } else if (code < 0xD800) { - buff[0] = (0xE0 | ((code >> 12) & 0xF)); - buff[1] = (0x80 | ((code >> 6) & 0x3F)); - buff[2] = (0x80 | (code & 0x3F)); - return 3; - } else if (code < 0xE000) { // D800 - DFFF is invalid... - return 0; - } else if (code < 0x10000) { - buff[0] = (0xE0 | ((code >> 12) & 0xF)); - buff[1] = (0x80 | ((code >> 6) & 0x3F)); - buff[2] = (0x80 | (code & 0x3F)); - return 3; - } else if (code < 0x110000) { - buff[0] = (0xF0 | ((code >> 18) & 0x7)); - buff[1] = (0x80 | ((code >> 12) & 0x3F)); - buff[2] = (0x80 | ((code >> 6) & 0x3F)); - buff[3] = (0x80 | (code & 0x3F)); - return 4; - } +template <typename T> +inline int write_headers(Stream &strm, const T &info, const Headers &headers) { + auto write_len = 0; + for (const auto &x : info.headers) { + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { return len; } + write_len += len; + } + for (const auto &x : headers) { + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { return len; } + write_len += len; + } + auto len = strm.write("\r\n"); + if (len < 0) { return len; } + write_len += len; + return write_len; +} + +inline ssize_t write_content(Stream &strm, ContentProvider content_provider, + size_t offset, size_t length) { + size_t begin_offset = offset; + size_t end_offset = offset + length; + while (offset < end_offset) { + ssize_t written_length = 0; + content_provider( + offset, end_offset - offset, + [&](const char *d, size_t l) { + offset += l; + written_length = strm.write(d, l); + }, + [&](void) { written_length = -1; }); + if (written_length < 0) { return written_length; } + } + return static_cast<ssize_t>(offset - begin_offset); +} + +inline ssize_t write_content_chunked(Stream &strm, + ContentProvider content_provider) { + size_t offset = 0; + auto data_available = true; + ssize_t total_written_length = 0; + while (data_available) { + ssize_t written_length = 0; + content_provider( + offset, 0, + [&](const char *d, size_t l) { + data_available = l > 0; + offset += l; + + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(l) + "\r\n" + std::string(d, l) + "\r\n"; + written_length = strm.write(chunk); + }, + [&](void) { + data_available = false; + written_length = strm.write("0\r\n\r\n"); + }); - // NOTREACHED - return 0; + if (written_length < 0) { return written_length; } + total_written_length += written_length; + } + return total_written_length; } -inline std::string decode_url(const std::string& s) -{ - std::string result; - - for (size_t i = 0; i < s.size(); i++) { - if (s[i] == '%' && i + 1 < s.size()) { - if (s[i + 1] == 'u') { - int val = 0; - if (from_hex_to_i(s, i + 2, 4, val)) { - // 4 digits Unicode codes - char buff[4]; - size_t len = to_utf8(val, buff); - if (len > 0) { - result.append(buff, len); - } - i += 5; // 'u0000' - } else { - result += s[i]; - } - } else { - int val = 0; - if (from_hex_to_i(s, i + 1, 2, val)) { - // 2 digits hex codes - result += val; - i += 2; // '00' - } else { - result += s[i]; - } - } - } else if (s[i] == '+') { - result += ' '; +template <typename T> +inline bool redirect(T &cli, const Request &req, Response &res, + const std::string &path) { + Request new_req; + new_req.method = req.method; + new_req.path = path; + new_req.headers = req.headers; + new_req.body = req.body; + new_req.redirect_count = req.redirect_count - 1; + new_req.response_handler = req.response_handler; + new_req.content_receiver = req.content_receiver; + new_req.progress = req.progress; + + Response new_res; + auto ret = cli.send(new_req, new_res); + if (ret) { res = new_res; } + return ret; +} + +inline std::string encode_url(const std::string &s) { + std::string result; + + for (auto i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': result += "%20"; break; + case '+': result += "%2B"; break; + case '\r': result += "%0D"; break; + case '\n': result += "%0A"; break; + case '\'': result += "%27"; break; + case ',': result += "%2C"; break; + case ':': result += "%3A"; break; + case ';': result += "%3B"; break; + default: + auto c = static_cast<uint8_t>(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + size_t len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, len); + } else { + result += s[i]; + } + break; + } + } + + return result; +} + +inline std::string decode_url(const std::string &s) { + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + int val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { result.append(buff, len); } + i += 5; // 'u0000' } else { - result += s[i]; + result += s[i]; } + } else { + int val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast<char>(val); + i += 2; // '00' + } else { + result += s[i]; + } + } + } else if (s[i] == '+') { + result += ' '; + } else { + result += s[i]; } + } - return result; + return result; } -inline void parse_query_text(const std::string& s, Params& params) -{ - split(&s[0], &s[s.size()], '&', [&](const char* b, const char* e) { - std::string key; - std::string val; - split(b, e, '=', [&](const char* b, const char* e) { - if (key.empty()) { - key.assign(b, e); - } else { - val.assign(b, e); - } - }); - params.emplace(key, decode_url(val)); +inline void parse_query_text(const std::string &s, Params ¶ms) { + split(&s[0], &s[s.size()], '&', [&](const char *b, const char *e) { + std::string key; + std::string val; + split(b, e, '=', [&](const char *b, const char *e) { + if (key.empty()) { + key.assign(b, e); + } else { + val.assign(b, e); + } }); + params.emplace(key, decode_url(val)); + }); } -inline bool parse_multipart_boundary(const std::string& content_type, std::string& boundary) -{ - auto pos = content_type.find("boundary="); - if (pos == std::string::npos) { - return false; - } +inline bool parse_multipart_boundary(const std::string &content_type, + std::string &boundary) { + auto pos = content_type.find("boundary="); + if (pos == std::string::npos) { return false; } - boundary = content_type.substr(pos + 9); - return true; + boundary = content_type.substr(pos + 9); + return true; } -inline bool parse_multipart_formdata( - const std::string& boundary, const std::string& body, MultipartFiles& files) -{ - static std::string dash = "--"; - static std::string crlf = "\r\n"; +inline bool parse_multipart_formdata(const std::string &boundary, + const std::string &body, + MultipartFiles &files) { + static std::string dash = "--"; + static std::string crlf = "\r\n"; - static std::regex re_content_type( - "Content-Type: (.*?)", std::regex_constants::icase); + static std::regex re_content_type("Content-Type: (.*?)", + std::regex_constants::icase); - static std::regex re_content_disposition( - "Content-Disposition: form-data; name=\"(.*?)\"(?:; filename=\"(.*?)\")?", - std::regex_constants::icase); + static std::regex re_content_disposition( + "Content-Disposition: form-data; name=\"(.*?)\"(?:; filename=\"(.*?)\")?", + std::regex_constants::icase); - auto dash_boundary = dash + boundary; + auto dash_boundary = dash + boundary; - auto pos = body.find(dash_boundary); - if (pos != 0) { - return false; - } + auto pos = body.find(dash_boundary); + if (pos != 0) { return false; } - pos += dash_boundary.size(); + pos += dash_boundary.size(); - auto next_pos = body.find(crlf, pos); - if (next_pos == std::string::npos) { - return false; - } + auto next_pos = body.find(crlf, pos); + if (next_pos == std::string::npos) { return false; } - pos = next_pos + crlf.size(); + pos = next_pos + crlf.size(); - while (pos < body.size()) { - next_pos = body.find(crlf, pos); - if (next_pos == std::string::npos) { - return false; - } + while (pos < body.size()) { + next_pos = body.find(crlf, pos); + if (next_pos == std::string::npos) { return false; } - std::string name; - MultipartFile file; + std::string name; + MultipartFile file; - auto header = body.substr(pos, (next_pos - pos)); + auto header = body.substr(pos, (next_pos - pos)); - while (pos != next_pos) { - std::smatch m; - if (std::regex_match(header, m, re_content_type)) { - file.content_type = m[1]; - } else if (std::regex_match(header, m, re_content_disposition)) { - name = m[1]; - file.filename = m[2]; - } + while (pos != next_pos) { + std::smatch m; + if (std::regex_match(header, m, re_content_type)) { + file.content_type = m[1]; + } else if (std::regex_match(header, m, re_content_disposition)) { + name = m[1]; + file.filename = m[2]; + } - pos = next_pos + crlf.size(); + pos = next_pos + crlf.size(); - next_pos = body.find(crlf, pos); - if (next_pos == std::string::npos) { - return false; - } + next_pos = body.find(crlf, pos); + if (next_pos == std::string::npos) { return false; } - header = body.substr(pos, (next_pos - pos)); - } + header = body.substr(pos, (next_pos - pos)); + } - pos = next_pos + crlf.size(); + pos = next_pos + crlf.size(); - next_pos = body.find(crlf + dash_boundary, pos); + next_pos = body.find(crlf + dash_boundary, pos); - if (next_pos == std::string::npos) { - return false; - } + if (next_pos == std::string::npos) { return false; } - file.offset = pos; - file.length = next_pos - pos; + file.offset = pos; + file.length = next_pos - pos; - pos = next_pos + crlf.size() + dash_boundary.size(); + pos = next_pos + crlf.size() + dash_boundary.size(); - next_pos = body.find(crlf, pos); - if (next_pos == std::string::npos) { - return false; - } + next_pos = body.find(crlf, pos); + if (next_pos == std::string::npos) { return false; } - files.emplace(name, file); + files.emplace(name, file); - pos = next_pos + crlf.size(); - } + pos = next_pos + crlf.size(); + } - return true; + return true; } -inline std::string to_lower(const char* beg, const char* end) -{ - std::string out; - auto it = beg; - while (it != end) { - out += ::tolower(*it); - it++; +inline bool parse_range_header(const std::string &s, Ranges &ranges) { + try { + static auto re = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); + std::smatch m; + if (std::regex_match(s, m, re)) { + auto pos = m.position(1); + auto len = m.length(1); + detail::split(&s[pos], &s[pos + len], ',', + [&](const char *b, const char *e) { + static auto re = std::regex(R"(\s*(\d*)-(\d*))"); + std::cmatch m; + if (std::regex_match(b, e, m, re)) { + ssize_t first = -1; + if (!m.str(1).empty()) { + first = static_cast<ssize_t>(std::stoll(m.str(1))); + } + + ssize_t last = -1; + if (!m.str(2).empty()) { + last = static_cast<ssize_t>(std::stoll(m.str(2))); + } + + if (first != -1 && last != -1 && first > last) { + throw std::runtime_error("invalid range error"); + } + ranges.emplace_back(std::make_pair(first, last)); + } + }); + return true; } - return out; + return false; + } catch (...) { return false; } +} + +inline std::string to_lower(const char *beg, const char *end) { + std::string out; + auto it = beg; + while (it != end) { + out += static_cast<char>(::tolower(*it)); + it++; + } + return out; } -inline void make_range_header_core(std::string&) {} +inline std::string make_multipart_data_boundary() { + static const char data[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; -template<typename uint64_t> -inline void make_range_header_core(std::string& field, uint64_t value) -{ - if (!field.empty()) { - field += ", "; - } - field += std::to_string(value) + "-"; + std::random_device seed_gen; + std::mt19937 engine(seed_gen()); + + std::string result = "--cpp-httplib-multipart-data-"; + + for (auto i = 0; i < 16; i++) { + result += data[engine() % (sizeof(data) - 1)]; + } + + return result; } -template<typename uint64_t, typename... Args> -inline void make_range_header_core(std::string& field, uint64_t value1, uint64_t value2, Args... args) -{ - if (!field.empty()) { - field += ", "; - } - field += std::to_string(value1) + "-" + std::to_string(value2); - make_range_header_core(field, args...); +inline std::pair<size_t, size_t> +get_range_offset_and_length(const Request &req, size_t content_length, + size_t index) { + auto r = req.ranges[index]; + + if (r.first == -1 && r.second == -1) { + return std::make_pair(0, content_length); + } + + if (r.first == -1) { + r.first = content_length - r.second; + r.second = content_length - 1; + } + + if (r.second == -1) { r.second = content_length - 1; } + + return std::make_pair(r.first, r.second - r.first + 1); } -#ifdef CPPHTTPLIB_ZLIB_SUPPORT -inline bool can_compress(const std::string& content_type) { - return !content_type.find("text/") || - content_type == "image/svg+xml" || - content_type == "application/javascript" || - content_type == "application/json" || - content_type == "application/xml" || - content_type == "application/xhtml+xml"; -} - -inline void compress(std::string& content) -{ - z_stream strm; - strm.zalloc = Z_NULL; - strm.zfree = Z_NULL; - strm.opaque = Z_NULL; +inline std::string make_content_range_header_field(size_t offset, size_t length, + size_t content_length) { + std::string field = "bytes "; + field += std::to_string(offset); + field += "-"; + field += std::to_string(offset + length - 1); + field += "/"; + field += std::to_string(content_length); + return field; +} - auto ret = deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, Z_DEFAULT_STRATEGY); - if (ret != Z_OK) { - return; +template <typename SToken, typename CToken, typename Content> +bool process_multipart_ranges_data(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type, + SToken stoken, CToken ctoken, + Content content) { + for (size_t i = 0; i < req.ranges.size(); i++) { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if (!content_type.empty()) { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); } - strm.avail_in = content.size(); - strm.next_in = (Bytef *)content.data(); + auto offsets = detail::get_range_offset_and_length(req, res.body.size(), i); + auto offset = offsets.first; + auto length = offsets.second; - std::string compressed; + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset, length, res.body.size())); + ctoken("\r\n"); + ctoken("\r\n"); + if (!content(offset, length)) { return false; } + ctoken("\r\n"); + } - const auto bufsiz = 16384; - char buff[bufsiz]; - do { - strm.avail_out = bufsiz; - strm.next_out = (Bytef *)buff; - deflate(&strm, Z_FINISH); - compressed.append(buff, bufsiz - strm.avail_out); - } while (strm.avail_out == 0); + ctoken("--"); + stoken(boundary); + ctoken("--\r\n"); + + return true; +} - content.swap(compressed); +inline std::string make_multipart_ranges_data(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type) { + std::string data; - deflateEnd(&strm); + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { data += token; }, + [&](const char *token) { data += token; }, + [&](size_t offset, size_t length) { + data += res.body.substr(offset, length); + return true; + }); + + return data; } -inline void decompress(std::string& content) -{ - z_stream strm; - strm.zalloc = Z_NULL; - strm.zfree = Z_NULL; - strm.opaque = Z_NULL; +inline size_t +get_multipart_ranges_data_length(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type) { + size_t data_length = 0; - // 15 is the value of wbits, which should be at the maximum possible value to ensure - // that any gzip stream can be decoded. The offset of 16 specifies that the stream - // to decompress will be formatted with a gzip wrapper. - auto ret = inflateInit2(&strm, 16 + 15); - if (ret != Z_OK) { - return; - } + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { data_length += token.size(); }, + [&](const char *token) { data_length += strlen(token); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); - strm.avail_in = content.size(); - strm.next_in = (Bytef *)content.data(); + return data_length; +} - std::string decompressed; +inline bool write_multipart_ranges_data(Stream &strm, const Request &req, + Response &res, + const std::string &boundary, + const std::string &content_type) { + return process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { strm.write(token); }, + [&](const char *token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return detail::write_content(strm, res.content_provider, offset, + length) >= 0; + }); +} - const auto bufsiz = 16384; - char buff[bufsiz]; - do { - strm.avail_out = bufsiz; - strm.next_out = (Bytef *)buff; - inflate(&strm, Z_NO_FLUSH); - decompressed.append(buff, bufsiz - strm.avail_out); - } while (strm.avail_out == 0); +inline std::pair<size_t, size_t> +get_range_offset_and_length(const Request &req, const Response &res, + size_t index) { + auto r = req.ranges[index]; - content.swap(decompressed); + if (r.second == -1) { r.second = res.content_provider_resource_length - 1; } - inflateEnd(&strm); + return std::make_pair(r.first, r.second - r.first + 1); } -#endif #ifdef _WIN32 class WSInit { public: - WSInit() { - WSADATA wsaData; - WSAStartup(0x0002, &wsaData); - } + WSInit() { + WSADATA wsaData; + WSAStartup(0x0002, &wsaData); + } - ~WSInit() { - WSACleanup(); - } + ~WSInit() { WSACleanup(); } }; static WSInit wsinit_; @@ -1275,876 +2010,1273 @@ static WSInit wsinit_; } // namespace detail // Header utilities -template<typename uint64_t, typename... Args> -inline std::pair<std::string, std::string> make_range_header(uint64_t value, Args... args) -{ - std::string field; - detail::make_range_header_core(field, value, args...); - field.insert(0, "bytes="); - return std::make_pair("Range", field); +inline std::pair<std::string, std::string> make_range_header(Ranges ranges) { + std::string field = "bytes="; + auto i = 0; + for (auto r : ranges) { + if (i != 0) { field += ", "; } + if (r.first != -1) { field += std::to_string(r.first); } + field += '-'; + if (r.second != -1) { field += std::to_string(r.second); } + i++; + } + return std::make_pair("Range", field); +} + +inline std::pair<std::string, std::string> +make_basic_authentication_header(const std::string &username, + const std::string &password) { + auto field = "Basic " + detail::base64_encode(username + ":" + password); + return std::make_pair("Authorization", field); } // Request implementation -inline bool Request::has_header(const char* key) const -{ - return headers.find(key) != headers.end(); +inline bool Request::has_header(const char *key) const { + return detail::has_header(headers, key); } -inline std::string Request::get_header_value(const char* key) const -{ - return detail::get_header_value(headers, key, ""); +inline std::string Request::get_header_value(const char *key, size_t id) const { + return detail::get_header_value(headers, key, id, ""); } -inline void Request::set_header(const char* key, const char* val) -{ - headers.emplace(key, val); +inline size_t Request::get_header_value_count(const char *key) const { + auto r = headers.equal_range(key); + return std::distance(r.first, r.second); } -inline bool Request::has_param(const char* key) const -{ - return params.find(key) != params.end(); +inline void Request::set_header(const char *key, const char *val) { + headers.emplace(key, val); } -inline std::string Request::get_param_value(const char* key) const -{ - auto it = params.find(key); - if (it != params.end()) { - return it->second; - } - return std::string(); +inline void Request::set_header(const char *key, const std::string &val) { + headers.emplace(key, val); } -inline bool Request::has_file(const char* key) const -{ - return files.find(key) != files.end(); +inline bool Request::has_param(const char *key) const { + return params.find(key) != params.end(); } -inline MultipartFile Request::get_file_value(const char* key) const -{ - auto it = files.find(key); - if (it != files.end()) { - return it->second; - } - return MultipartFile(); +inline std::string Request::get_param_value(const char *key, size_t id) const { + auto it = params.find(key); + std::advance(it, id); + if (it != params.end()) { return it->second; } + return std::string(); +} + +inline size_t Request::get_param_value_count(const char *key) const { + auto r = params.equal_range(key); + return std::distance(r.first, r.second); +} + +inline bool Request::has_file(const char *key) const { + return files.find(key) != files.end(); +} + +inline MultipartFile Request::get_file_value(const char *key) const { + auto it = files.find(key); + if (it != files.end()) { return it->second; } + return MultipartFile(); } // Response implementation -inline bool Response::has_header(const char* key) const -{ - return headers.find(key) != headers.end(); +inline bool Response::has_header(const char *key) const { + return headers.find(key) != headers.end(); +} + +inline std::string Response::get_header_value(const char *key, + size_t id) const { + return detail::get_header_value(headers, key, id, ""); +} + +inline size_t Response::get_header_value_count(const char *key) const { + auto r = headers.equal_range(key); + return std::distance(r.first, r.second); } -inline std::string Response::get_header_value(const char* key) const -{ - return detail::get_header_value(headers, key, ""); +inline void Response::set_header(const char *key, const char *val) { + headers.emplace(key, val); } -inline void Response::set_header(const char* key, const char* val) -{ - headers.emplace(key, val); +inline void Response::set_header(const char *key, const std::string &val) { + headers.emplace(key, val); } -inline void Response::set_redirect(const char* url) -{ - set_header("Location", url); - status = 302; +inline void Response::set_redirect(const char *url) { + set_header("Location", url); + status = 302; } -inline void Response::set_content(const char* s, size_t n, const char* content_type) -{ - body.assign(s, n); - set_header("Content-Type", content_type); +inline void Response::set_content(const char *s, size_t n, + const char *content_type) { + body.assign(s, n); + set_header("Content-Type", content_type); } -inline void Response::set_content(const std::string& s, const char* content_type) -{ - body = s; - set_header("Content-Type", content_type); +inline void Response::set_content(const std::string &s, + const char *content_type) { + body = s; + set_header("Content-Type", content_type); +} + +inline void Response::set_content_provider( + size_t length, + std::function<void(size_t offset, size_t length, DataSink sink)> provider, + std::function<void()> resource_releaser) { + assert(length > 0); + content_provider_resource_length = length; + content_provider = [provider](size_t offset, size_t length, DataSink sink, + Done) { provider(offset, length, sink); }; + content_provider_resource_releaser = resource_releaser; +} + +inline void Response::set_chunked_content_provider( + std::function<void(size_t offset, DataSink sink, Done done)> provider, + std::function<void()> resource_releaser) { + content_provider_resource_length = 0; + content_provider = [provider](size_t offset, size_t, DataSink sink, + Done done) { provider(offset, sink, done); }; + content_provider_resource_releaser = resource_releaser; } // Rstream implementation -template <typename ...Args> -inline void Stream::write_format(const char* fmt, const Args& ...args) -{ - const auto bufsiz = 2048; - char buf[bufsiz]; +template <typename... Args> +inline int Stream::write_format(const char *fmt, const Args &... args) { + const auto bufsiz = 2048; + char buf[bufsiz]; #if defined(_MSC_VER) && _MSC_VER < 1900 - auto n = _snprintf_s(buf, bufsiz, bufsiz - 1, fmt, args...); + auto n = _snprintf_s(buf, bufsiz, bufsiz - 1, fmt, args...); #else - auto n = snprintf(buf, bufsiz - 1, fmt, args...); + auto n = snprintf(buf, bufsiz - 1, fmt, args...); #endif - if (n > 0) { - if (n >= bufsiz - 1) { - std::vector<char> glowable_buf(bufsiz); + if (n <= 0) { return n; } + + if (n >= bufsiz - 1) { + std::vector<char> glowable_buf(bufsiz); - while (n >= static_cast<int>(glowable_buf.size() - 1)) { - glowable_buf.resize(glowable_buf.size() * 2); + while (n >= static_cast<int>(glowable_buf.size() - 1)) { + glowable_buf.resize(glowable_buf.size() * 2); #if defined(_MSC_VER) && _MSC_VER < 1900 - n = _snprintf_s(&glowable_buf[0], glowable_buf.size(), glowable_buf.size() - 1, fmt, args...); + n = _snprintf_s(&glowable_buf[0], glowable_buf.size(), + glowable_buf.size() - 1, fmt, args...); #else - n = snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...); + n = snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...); #endif - } - write(&glowable_buf[0], n); - } else { - write(buf, n); - } } + return write(&glowable_buf[0], n); + } else { + return write(buf, n); + } } // Socket stream implementation -inline SocketStream::SocketStream(socket_t sock): sock_(sock) -{ +inline SocketStream::SocketStream(socket_t sock) : sock_(sock) {} + +inline SocketStream::~SocketStream() {} + +inline int SocketStream::read(char *ptr, size_t size) { + if (detail::select_read(sock_, CPPHTTPLIB_READ_TIMEOUT_SECOND, + CPPHTTPLIB_READ_TIMEOUT_USECOND) > 0) { + return recv(sock_, ptr, static_cast<int>(size), 0); + } + return -1; +} + +inline int SocketStream::write(const char *ptr, size_t size) { + return send(sock_, ptr, static_cast<int>(size), 0); } -inline SocketStream::~SocketStream() -{ +inline int SocketStream::write(const char *ptr) { + return write(ptr, strlen(ptr)); } -inline int SocketStream::read(char* ptr, size_t size) -{ - return recv(sock_, ptr, size, 0); +inline int SocketStream::write(const std::string &s) { + return write(s.data(), s.size()); } -inline int SocketStream::write(const char* ptr, size_t size) -{ - return send(sock_, ptr, size, 0); +inline std::string SocketStream::get_remote_addr() const { + return detail::get_remote_addr(sock_); } -inline int SocketStream::write(const char* ptr) -{ - return write(ptr, strlen(ptr)); +// Buffer stream implementation +inline int BufferStream::read(char *ptr, size_t size) { +#if defined(_MSC_VER) && _MSC_VER < 1900 + return static_cast<int>(buffer._Copy_s(ptr, size, size)); +#else + return static_cast<int>(buffer.copy(ptr, size)); +#endif } -inline std::string SocketStream::get_remote_addr() { - return detail::get_remote_addr(sock_); +inline int BufferStream::write(const char *ptr, size_t size) { + buffer.append(ptr, size); + return static_cast<int>(size); } +inline int BufferStream::write(const char *ptr) { + return write(ptr, strlen(ptr)); +} + +inline int BufferStream::write(const std::string &s) { + return write(s.data(), s.size()); +} + +inline std::string BufferStream::get_remote_addr() const { return ""; } + +inline const std::string &BufferStream::get_buffer() const { return buffer; } + // HTTP server implementation inline Server::Server() - : keep_alive_max_count_(5) - , is_running_(false) - , svr_sock_(INVALID_SOCKET) - , running_threads_(0) -{ + : keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT), + payload_max_length_(CPPHTTPLIB_PAYLOAD_MAX_LENGTH), is_running_(false), + svr_sock_(INVALID_SOCKET) { #ifndef _WIN32 - signal(SIGPIPE, SIG_IGN); + signal(SIGPIPE, SIG_IGN); #endif + new_task_queue = [] { +#if CPPHTTPLIB_THREAD_POOL_COUNT > 0 + return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); +#else + return new Threads(); +#endif + }; } -inline Server::~Server() -{ -} +inline Server::~Server() {} -inline Server& Server::Get(const char* pattern, Handler handler) -{ - get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; +inline Server &Server::Get(const char *pattern, Handler handler) { + get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } -inline Server& Server::Post(const char* pattern, Handler handler) -{ - post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; +inline Server &Server::Post(const char *pattern, Handler handler) { + post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } -inline Server& Server::Put(const char* pattern, Handler handler) -{ - put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; +inline Server &Server::Put(const char *pattern, Handler handler) { + put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } -inline Server& Server::Delete(const char* pattern, Handler handler) -{ - delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; +inline Server &Server::Patch(const char *pattern, Handler handler) { + patch_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } -inline Server& Server::Options(const char* pattern, Handler handler) -{ - options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; +inline Server &Server::Delete(const char *pattern, Handler handler) { + delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } -inline bool Server::set_base_dir(const char* path) -{ - if (detail::is_dir(path)) { - base_dir_ = path; - return true; - } - return false; +inline Server &Server::Options(const char *pattern, Handler handler) { + options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } -inline void Server::set_error_handler(Handler handler) -{ - error_handler_ = handler; +inline bool Server::set_base_dir(const char *path) { + if (detail::is_dir(path)) { + base_dir_ = path; + return true; + } + return false; } -inline void Server::set_logger(Logger logger) -{ - logger_ = logger; +inline void Server::set_file_request_handler(Handler handler) { + file_request_handler_ = handler; } -inline void Server::set_keep_alive_max_count(size_t count) -{ - keep_alive_max_count_ = count; +inline void Server::set_error_handler(Handler handler) { + error_handler_ = handler; } -inline int Server::bind_to_any_port(const char* host, int socket_flags) -{ - return bind_internal(host, 0, socket_flags); +inline void Server::set_logger(Logger logger) { logger_ = logger; } + +inline void Server::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; } -inline bool Server::listen_after_bind() { - return listen_internal(); +inline void Server::set_payload_max_length(size_t length) { + payload_max_length_ = length; } -inline bool Server::listen(const char* host, int port, int socket_flags) -{ - if (bind_internal(host, port, socket_flags) < 0) - return false; - return listen_internal(); +inline int Server::bind_to_any_port(const char *host, int socket_flags) { + return bind_internal(host, 0, socket_flags); } -inline bool Server::is_running() const -{ - return is_running_; +inline bool Server::listen_after_bind() { return listen_internal(); } + +inline bool Server::listen(const char *host, int port, int socket_flags) { + if (bind_internal(host, port, socket_flags) < 0) return false; + return listen_internal(); } -inline void Server::stop() -{ - if (is_running_) { - assert(svr_sock_ != INVALID_SOCKET); - detail::shutdown_socket(svr_sock_); - detail::close_socket(svr_sock_); - svr_sock_ = INVALID_SOCKET; - } +inline bool Server::is_running() const { return is_running_; } + +inline void Server::stop() { + if (is_running_) { + assert(svr_sock_ != INVALID_SOCKET); + std::atomic<socket_t> sock(svr_sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } } -inline bool Server::parse_request_line(const char* s, Request& req) -{ - static std::regex re("(GET|HEAD|POST|PUT|DELETE|OPTIONS) (([^?]+)(?:\\?(.+?))?) (HTTP/1\\.[01])\r\n"); +inline bool Server::parse_request_line(const char *s, Request &req) { + static std::regex re("(GET|HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH|PRI) " + "(([^?]+)(?:\\?(.+?))?) (HTTP/1\\.[01])\r\n"); - std::cmatch m; - if (std::regex_match(s, m, re)) { - req.version = std::string(m[4]); - req.method = std::string(m[1]); - req.target = std::string(m[2]); - req.path = detail::decode_url(m[3]); - - // Parse query text - auto len = std::distance(m[4].first, m[4].second); - if (len > 0) { - detail::parse_query_text(m[4], req.params); - } + std::cmatch m; + if (std::regex_match(s, m, re)) { + req.version = std::string(m[5]); + req.method = std::string(m[1]); + req.target = std::string(m[2]); + req.path = detail::decode_url(m[3]); - return true; - } + // Parse query text + auto len = std::distance(m[4].first, m[4].second); + if (len > 0) { detail::parse_query_text(m[4], req.params); } - return false; -} + return true; + } -inline void Server::write_response(Stream& strm, bool last_connection, const Request& req, Response& res) -{ - assert(res.status != -1); + return false; +} - if (400 <= res.status && error_handler_) { - error_handler_(req, res); - } +inline bool Server::write_response(Stream &strm, bool last_connection, + const Request &req, Response &res) { + assert(res.status != -1); - // Response line - strm.write_format("HTTP/1.1 %d %s\r\n", - res.status, - detail::status_message(res.status)); + if (400 <= res.status && error_handler_) { error_handler_(req, res); } - // Headers - if (last_connection || - req.version == "HTTP/1.0" || - req.get_header_value("Connection") == "close") { - res.set_header("Connection", "close"); + // Response line + if (!strm.write_format("HTTP/1.1 %d %s\r\n", res.status, + detail::status_message(res.status))) { + return false; + } + + // Headers + if (last_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } + + if (!last_connection && req.get_header_value("Connection") == "Keep-Alive") { + res.set_header("Connection", "Keep-Alive"); + } + + if (!res.has_header("Content-Type")) { + res.set_header("Content-Type", "text/plain"); + } + + if (!res.has_header("Accept-Ranges")) { + res.set_header("Accept-Ranges", "bytes"); + } + + std::string content_type; + std::string boundary; + + if (req.ranges.size() > 1) { + boundary = detail::make_multipart_data_boundary(); + + auto it = res.headers.find("Content-Type"); + if (it != res.headers.end()) { + content_type = it->second; + res.headers.erase(it); + } + + res.headers.emplace("Content-Type", + "multipart/byteranges; boundary=" + boundary); + } + + if (res.body.empty()) { + if (res.content_provider_resource_length > 0) { + size_t length = 0; + if (req.ranges.empty()) { + length = res.content_provider_resource_length; + } else if (req.ranges.size() == 1) { + auto offsets = detail::get_range_offset_and_length( + req, res.content_provider_resource_length, 0); + auto offset = offsets.first; + length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.content_provider_resource_length); + res.set_header("Content-Range", content_range); + } else { + length = detail::get_multipart_ranges_data_length(req, res, boundary, + content_type); + } + res.set_header("Content-Length", std::to_string(length)); + } else { + if (res.content_provider) { + res.set_header("Transfer-Encoding", "chunked"); + } else { + res.set_header("Content-Length", "0"); + } + } + } else { + if (req.ranges.empty()) { + ; + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.body.size(), 0); + auto offset = offsets.first; + auto length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.body.size()); + res.set_header("Content-Range", content_range); + res.body = res.body.substr(offset, length); + } else { + res.body = + detail::make_multipart_ranges_data(req, res, boundary, content_type); } - if (!res.body.empty()) { #ifdef CPPHTTPLIB_ZLIB_SUPPORT - // TODO: 'Accpet-Encoding' has gzip, not gzip;q=0 - const auto& encodings = req.get_header_value("Accept-Encoding"); - if (encodings.find("gzip") != std::string::npos && - detail::can_compress(res.get_header_value("Content-Type"))) { - detail::compress(res.body); - res.set_header("Content-Encoding", "gzip"); - } + // TODO: 'Accpet-Encoding' has gzip, not gzip;q=0 + const auto &encodings = req.get_header_value("Accept-Encoding"); + if (encodings.find("gzip") != std::string::npos && + detail::can_compress(res.get_header_value("Content-Type"))) { + if (detail::compress(res.body)) { + res.set_header("Content-Encoding", "gzip"); + } + } #endif - if (!res.has_header("Content-Type")) { - res.set_header("Content-Type", "text/plain"); - } + auto length = std::to_string(res.body.size()); + res.set_header("Content-Length", length); + } + + if (!detail::write_headers(strm, res, Headers())) { return false; } - auto length = std::to_string(res.body.size()); - res.set_header("Content-Length", length.c_str()); + // Body + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!strm.write(res.body)) { return false; } + } else if (res.content_provider) { + if (!write_content_with_provider(strm, req, res, boundary, + content_type)) { + return false; + } } + } - detail::write_headers(strm, res); + // Log + if (logger_) { logger_(req, res); } - // Body - if (!res.body.empty() && req.method != "HEAD") { - strm.write(res.body.c_str(), res.body.size()); - } + return true; +} - // Log - if (logger_) { - logger_(req, res); +inline bool +Server::write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type) { + if (res.content_provider_resource_length) { + if (req.ranges.empty()) { + if (detail::write_content(strm, res.content_provider, 0, + res.content_provider_resource_length) < 0) { + return false; + } + } else if (req.ranges.size() == 1) { + auto offsets = detail::get_range_offset_and_length( + req, res.content_provider_resource_length, 0); + auto offset = offsets.first; + auto length = offsets.second; + if (detail::write_content(strm, res.content_provider, offset, length) < + 0) { + return false; + } + } else { + if (!detail::write_multipart_ranges_data(strm, req, res, boundary, + content_type)) { + return false; + } + } + } else { + if (detail::write_content_chunked(strm, res.content_provider) < 0) { + return false; } + } + return true; } -inline bool Server::handle_file_request(Request& req, Response& res) -{ - if (!base_dir_.empty() && detail::is_valid_path(req.path)) { - std::string path = base_dir_ + req.path; +inline bool Server::handle_file_request(Request &req, Response &res) { + if (!base_dir_.empty() && detail::is_valid_path(req.path)) { + std::string path = base_dir_ + req.path; - if (!path.empty() && path.back() == '/') { - path += "index.html"; - } + if (!path.empty() && path.back() == '/') { path += "index.html"; } - if (detail::is_file(path)) { - detail::read_file(path, res.body); - auto type = detail::find_content_type(path); - if (type) { - res.set_header("Content-Type", type); - } - res.status = 200; - return true; - } + if (detail::is_file(path)) { + detail::read_file(path, res.body); + auto type = detail::find_content_type(path); + if (type) { res.set_header("Content-Type", type); } + res.status = 200; + if (file_request_handler_) { file_request_handler_(req, res); } + return true; } + } - return false; + return false; } -inline socket_t Server::create_server_socket(const char* host, int port, int socket_flags) const -{ - return detail::create_socket(host, port, - [](socket_t sock, struct addrinfo& ai) -> bool { - if (::bind(sock, ai.ai_addr, ai.ai_addrlen)) { - return false; - } - if (::listen(sock, 5)) { // Listen through 5 channels - return false; - } - return true; - }, socket_flags); -} - -inline int Server::bind_internal(const char* host, int port, int socket_flags) -{ - if (!is_valid()) { - return -1; - } - - svr_sock_ = create_server_socket(host, port, socket_flags); - if (svr_sock_ == INVALID_SOCKET) { - return -1; - } - - if (port == 0) { - struct sockaddr_storage address; - socklen_t len = sizeof(address); - if (getsockname(svr_sock_, reinterpret_cast<struct sockaddr *>(&address), &len) == -1) { - return -1; +inline socket_t Server::create_server_socket(const char *host, int port, + int socket_flags) const { + return detail::create_socket( + host, port, + [](socket_t sock, struct addrinfo &ai) -> bool { + if (::bind(sock, ai.ai_addr, static_cast<int>(ai.ai_addrlen))) { + return false; } - if (address.ss_family == AF_INET) { - return ntohs(reinterpret_cast<struct sockaddr_in*>(&address)->sin_port); - } else if (address.ss_family == AF_INET6) { - return ntohs(reinterpret_cast<struct sockaddr_in6*>(&address)->sin6_port); - } else { - return -1; + if (::listen(sock, 5)) { // Listen through 5 channels + return false; } + return true; + }, + socket_flags); +} + +inline int Server::bind_internal(const char *host, int port, int socket_flags) { + if (!is_valid()) { return -1; } + + svr_sock_ = create_server_socket(host, port, socket_flags); + if (svr_sock_ == INVALID_SOCKET) { return -1; } + + if (port == 0) { + struct sockaddr_storage address; + socklen_t len = sizeof(address); + if (getsockname(svr_sock_, reinterpret_cast<struct sockaddr *>(&address), + &len) == -1) { + return -1; + } + if (address.ss_family == AF_INET) { + return ntohs(reinterpret_cast<struct sockaddr_in *>(&address)->sin_port); + } else if (address.ss_family == AF_INET6) { + return ntohs( + reinterpret_cast<struct sockaddr_in6 *>(&address)->sin6_port); } else { - return port; + return -1; } + } else { + return port; + } } -inline bool Server::listen_internal() -{ - auto ret = true; +inline bool Server::listen_internal() { + auto ret = true; + is_running_ = true; - is_running_ = true; + { + std::unique_ptr<TaskQueue> task_queue(new_task_queue()); for (;;) { - auto val = detail::select_read(svr_sock_, 0, 100000); + if (svr_sock_ == INVALID_SOCKET) { + // The server socket was closed by 'stop' method. + break; + } - if (val == 0) { // Timeout - if (svr_sock_ == INVALID_SOCKET) { - // The server socket was closed by 'stop' method. - break; - } - continue; - } + auto val = detail::select_read(svr_sock_, 0, 100000); - socket_t sock = accept(svr_sock_, NULL, NULL); + if (val == 0) { // Timeout + continue; + } - if (sock == INVALID_SOCKET) { - if (svr_sock_ != INVALID_SOCKET) { - detail::close_socket(svr_sock_); - ret = false; - } else { - ; // The server socket was closed by user. - } - break; + socket_t sock = accept(svr_sock_, nullptr, nullptr); + + if (sock == INVALID_SOCKET) { + if (errno == EMFILE) { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + if (svr_sock_ != INVALID_SOCKET) { + detail::close_socket(svr_sock_); + ret = false; + } else { + ; // The server socket was closed by user. } + break; + } - // TODO: Use thread pool... - std::thread([=]() { - { - std::lock_guard<std::mutex> guard(running_threads_mutex_); - running_threads_++; - } + task_queue->enqueue([=]() { process_and_close_socket(sock); }); + } - read_and_close_socket(sock); + task_queue->shutdown(); + } - { - std::lock_guard<std::mutex> guard(running_threads_mutex_); - running_threads_--; - } - }).detach(); + is_running_ = false; + return ret; +} + +inline bool Server::routing(Request &req, Response &res) { + if (req.method == "GET" && handle_file_request(req, res)) { return true; } + + if (req.method == "GET" || req.method == "HEAD") { + return dispatch_request(req, res, get_handlers_); + } else if (req.method == "POST") { + return dispatch_request(req, res, post_handlers_); + } else if (req.method == "PUT") { + return dispatch_request(req, res, put_handlers_); + } else if (req.method == "DELETE") { + return dispatch_request(req, res, delete_handlers_); + } else if (req.method == "OPTIONS") { + return dispatch_request(req, res, options_handlers_); + } else if (req.method == "PATCH") { + return dispatch_request(req, res, patch_handlers_); + } + + res.status = 400; + return false; +} + +inline bool Server::dispatch_request(Request &req, Response &res, + Handlers &handlers) { + for (const auto &x : handlers) { + const auto &pattern = x.first; + const auto &handler = x.second; + + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res); + return true; + } + } + return false; +} + +inline bool +Server::process_request(Stream &strm, bool last_connection, + bool &connection_close, + std::function<void(Request &)> setup_request) { + const auto bufsiz = 2048; + char buf[bufsiz]; + + detail::stream_line_reader reader(strm, buf, bufsiz); + + // Connection has been closed on client + if (!reader.getline()) { return false; } + + Request req; + Response res; + + res.version = "HTTP/1.1"; + + // Check if the request URI doesn't exceed the limit + if (reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = 414; + return write_response(strm, last_connection, req, res); + } + + // Request line and headers + if (!parse_request_line(reader.ptr(), req) || + !detail::read_headers(strm, req.headers)) { + res.status = 400; + return write_response(strm, last_connection, req, res); + } + + if (req.get_header_value("Connection") == "close") { + connection_close = true; + } + + if (req.version == "HTTP/1.0" && + req.get_header_value("Connection") != "Keep-Alive") { + connection_close = true; + } + + req.set_header("REMOTE_ADDR", strm.get_remote_addr()); + + // Body + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || req.method == "PRI") { + if (!detail::read_content(strm, req, payload_max_length_, res.status, + Progress(), [&](const char *buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { + return false; + } + req.body.append(buf, n); + return true; + })) { + return write_response(strm, last_connection, req, res); } - // TODO: Use thread pool... - for (;;) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - std::lock_guard<std::mutex> guard(running_threads_mutex_); - if (!running_threads_) { - break; - } + const auto &content_type = req.get_header_value("Content-Type"); + + if (!content_type.find("application/x-www-form-urlencoded")) { + detail::parse_query_text(req.body, req.params); + } else if (!content_type.find("multipart/form-data")) { + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary) || + !detail::parse_multipart_formdata(boundary, req.body, req.files)) { + res.status = 400; + return write_response(strm, last_connection, req, res); + } + } + } + + if (req.has_header("Range")) { + const auto &range_header_value = req.get_header_value("Range"); + if (!detail::parse_range_header(range_header_value, req.ranges)) { + // TODO: error } + } - is_running_ = false; + if (setup_request) { setup_request(req); } - return ret; + if (routing(req, res)) { + if (res.status == -1) { res.status = req.ranges.empty() ? 200 : 206; } + } else { + if (res.status == -1) { res.status = 404; } + } + + return write_response(strm, last_connection, req, res); } -inline bool Server::routing(Request& req, Response& res) -{ - if (req.method == "GET" && handle_file_request(req, res)) { - return true; - } +inline bool Server::is_valid() const { return true; } - if (req.method == "GET" || req.method == "HEAD") { - return dispatch_request(req, res, get_handlers_); - } else if (req.method == "POST") { - return dispatch_request(req, res, post_handlers_); - } else if (req.method == "PUT") { - return dispatch_request(req, res, put_handlers_); - } else if (req.method == "DELETE") { - return dispatch_request(req, res, delete_handlers_); - } else if (req.method == "OPTIONS") { - return dispatch_request(req, res, options_handlers_); - } - return false; +inline bool Server::process_and_close_socket(socket_t sock) { + return detail::process_and_close_socket( + false, sock, keep_alive_max_count_, + [this](Stream &strm, bool last_connection, bool &connection_close) { + return process_request(strm, last_connection, connection_close, + nullptr); + }); } -inline bool Server::dispatch_request(Request& req, Response& res, Handlers& handlers) -{ - for (const auto& x: handlers) { - const auto& pattern = x.first; - const auto& handler = x.second; - - if (std::regex_match(req.path, req.matches, pattern)) { - handler(req, res); - return true; +// HTTP client implementation +inline Client::Client(const char *host, int port, time_t timeout_sec) + : host_(host), port_(port), timeout_sec_(timeout_sec), + host_and_port_(host_ + ":" + std::to_string(port_)), + keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT), + follow_location_(false) {} + +inline Client::~Client() {} + +inline bool Client::is_valid() const { return true; } + +inline socket_t Client::create_client_socket() const { + return detail::create_socket( + host_.c_str(), port_, [=](socket_t sock, struct addrinfo &ai) -> bool { + detail::set_nonblocking(sock, true); + + auto ret = connect(sock, ai.ai_addr, static_cast<int>(ai.ai_addrlen)); + if (ret < 0) { + if (detail::is_connection_error() || + !detail::wait_until_socket_is_ready(sock, timeout_sec_, 0)) { + detail::close_socket(sock); + return false; + } } - } - return false; -} -inline bool Server::process_request(Stream& strm, bool last_connection, bool& connection_close) -{ - const auto bufsiz = 2048; - char buf[bufsiz]; + detail::set_nonblocking(sock, false); + return true; + }); +} - detail::stream_line_reader reader(strm, buf, bufsiz); +inline bool Client::read_response_line(Stream &strm, Response &res) { + const auto bufsiz = 2048; + char buf[bufsiz]; - // Connection has been closed on client - if (!reader.getline()) { - return false; - } + detail::stream_line_reader reader(strm, buf, bufsiz); - Request req; - Response res; + if (!reader.getline()) { return false; } - res.version = "HTTP/1.1"; + const static std::regex re("(HTTP/1\\.[01]) (\\d+?) .*\r\n"); - // Request line and headers - if (!parse_request_line(reader.ptr(), req) || !detail::read_headers(strm, req.headers)) { - res.status = 400; - write_response(strm, last_connection, req, res); - return true; - } + std::cmatch m; + if (std::regex_match(reader.ptr(), m, re)) { + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + } - auto ret = true; - if (req.get_header_value("Connection") == "close") { - // ret = false; - connection_close = true; - } + return true; +} - req.set_header("REMOTE_ADDR", strm.get_remote_addr().c_str()); +inline bool Client::send(const Request &req, Response &res) { + if (req.path.empty()) { return false; } - // Body - if (req.method == "POST" || req.method == "PUT") { - if (!detail::read_content(strm, req)) { - res.status = 400; - write_response(strm, last_connection, req, res); - return ret; - } + auto sock = create_client_socket(); + if (sock == INVALID_SOCKET) { return false; } - const auto& content_type = req.get_header_value("Content-Type"); + auto ret = process_and_close_socket( + sock, 1, [&](Stream &strm, bool last_connection, bool &connection_close) { + return process_request(strm, req, res, last_connection, + connection_close); + }); - if (req.get_header_value("Content-Encoding") == "gzip") { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - detail::decompress(req.body); -#else - res.status = 415; - write_response(strm, last_connection, req, res); - return ret; -#endif - } + if (ret && follow_location_ && (300 < res.status && res.status < 400)) { + ret = redirect(req, res); + } - if (!content_type.find("application/x-www-form-urlencoded")) { - detail::parse_query_text(req.body, req.params); - } else if(!content_type.find("multipart/form-data")) { - std::string boundary; - if (!detail::parse_multipart_boundary(content_type, boundary) || - !detail::parse_multipart_formdata(boundary, req.body, req.files)) { - res.status = 400; - write_response(strm, last_connection, req, res); - return ret; - } - } - } + return ret; +} - if (routing(req, res)) { - if (res.status == -1) { - res.status = 200; - } - } else { - res.status = 404; - } +inline bool Client::send(const std::vector<Request> &requests, + std::vector<Response> &responses) { + size_t i = 0; + while (i < requests.size()) { + auto sock = create_client_socket(); + if (sock == INVALID_SOCKET) { return false; } - write_response(strm, last_connection, req, res); - return ret; -} + if (!process_and_close_socket( + sock, requests.size() - i, + [&](Stream &strm, bool last_connection, bool &connection_close) -> bool { + auto &req = requests[i]; + auto res = Response(); + i++; -inline bool Server::is_valid() const -{ - return true; -} + if (req.path.empty()) { return false; } + auto ret = process_request(strm, req, res, last_connection, + connection_close); -inline bool Server::read_and_close_socket(socket_t sock) -{ - return detail::read_and_close_socket( - sock, - keep_alive_max_count_, - [this](Stream& strm, bool last_connection, bool& connection_close) { - return process_request(strm, last_connection, connection_close); - }); -} + if (ret && follow_location_ && + (300 < res.status && res.status < 400)) { + ret = redirect(req, res); + } -// HTTP client implementation -inline Client::Client( - const char* host, int port, size_t timeout_sec) - : host_(host) - , port_(port) - , timeout_sec_(timeout_sec) - , host_and_port_(host_ + ":" + std::to_string(port_)) -{ -} + if (ret) { responses.emplace_back(std::move(res)); } -inline Client::~Client() -{ -} + return ret; + })) { + return false; + } + } -inline bool Client::is_valid() const -{ - return true; + return true; } -inline socket_t Client::create_client_socket() const -{ - return detail::create_socket(host_.c_str(), port_, - [=](socket_t sock, struct addrinfo& ai) -> bool { - detail::set_nonblocking(sock, true); +inline bool Client::redirect(const Request &req, Response &res) { + if (req.redirect_count == 0) { return false; } - auto ret = connect(sock, ai.ai_addr, ai.ai_addrlen); - if (ret < 0) { - if (detail::is_connection_error() || - !detail::wait_until_socket_is_ready(sock, timeout_sec_, 0)) { - detail::close_socket(sock); - return false; - } - } + auto location = res.get_header_value("location"); + if (location.empty()) { return false; } - detail::set_nonblocking(sock, false); - return true; - }); -} + std::regex re( + R"(^(?:([^:/?#]+):)?(?://([^/?#]*))?([^?#]*(?:\?[^#]*)?)(?:#.*)?)"); -inline bool Client::read_response_line(Stream& strm, Response& res) -{ - const auto bufsiz = 2048; - char buf[bufsiz]; + auto scheme = is_ssl() ? "https" : "http"; - detail::stream_line_reader reader(strm, buf, bufsiz); + std::smatch m; + if (regex_match(location, m, re)) { + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + auto next_path = m[3].str(); + if (next_host.empty()) { next_host = host_; } + if (next_path.empty()) { next_path = "/"; } - if (!reader.getline()) { + if (next_scheme == scheme && next_host == host_) { + return detail::redirect(*this, req, res, next_path); + } else { + if (next_scheme == "https") { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLClient cli(next_host.c_str()); + cli.follow_location(true); + return detail::redirect(cli, req, res, next_path); +#else return false; +#endif + } else { + Client cli(next_host.c_str()); + cli.follow_location(true); + return detail::redirect(cli, req, res, next_path); + } } + } + return false; +} - const static std::regex re("(HTTP/1\\.[01]) (\\d+?) .+\r\n"); +inline void Client::write_request(Stream &strm, const Request &req, + bool last_connection) { + BufferStream bstrm; - std::cmatch m; - if (std::regex_match(reader.ptr(), m, re)) { - res.version = std::string(m[1]); - res.status = std::stoi(std::string(m[2])); + // Request line + auto path = detail::encode_url(req.path); + + bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); + + // Additonal headers + Headers headers; + if (last_connection) { headers.emplace("Connection", "close"); } + + if (!req.has_header("Host")) { + if (is_ssl()) { + if (port_ == 443) { + headers.emplace("Host", host_); + } else { + headers.emplace("Host", host_and_port_); + } + } else { + if (port_ == 80) { + headers.emplace("Host", host_); + } else { + headers.emplace("Host", host_and_port_); + } } + } - return true; -} + if (!req.has_header("Accept")) { headers.emplace("Accept", "*/*"); } -inline bool Client::send(Request& req, Response& res) -{ - if (req.path.empty()) { - return false; + if (!req.has_header("User-Agent")) { + headers.emplace("User-Agent", "cpp-httplib/0.2"); + } + + if (req.body.empty()) { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH") { + headers.emplace("Content-Length", "0"); + } + } else { + if (!req.has_header("Content-Type")) { + headers.emplace("Content-Type", "text/plain"); } - auto sock = create_client_socket(); - if (sock == INVALID_SOCKET) { - return false; + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.body.size()); + headers.emplace("Content-Length", length); } + } - return read_and_close_socket(sock, req, res); + detail::write_headers(bstrm, req, headers); + + // Body + if (!req.body.empty()) { bstrm.write(req.body); } + + // Flush buffer + auto &data = bstrm.get_buffer(); + strm.write(data.data(), data.size()); } -inline void Client::write_request(Stream& strm, Request& req) -{ - auto path = detail::encode_url(req.path); +inline bool Client::process_request(Stream &strm, const Request &req, + Response &res, bool last_connection, + bool &connection_close) { + // Send request + write_request(strm, req, last_connection); + + // Receive response and headers + if (!read_response_line(strm, res) || + !detail::read_headers(strm, res.headers)) { + return false; + } - // Request line - strm.write_format("%s %s HTTP/1.1\r\n", - req.method.c_str(), - path.c_str()); + if (res.get_header_value("Connection") == "close" || + res.version == "HTTP/1.0") { + connection_close = true; + } - // Headers - req.set_header("Host", host_and_port_.c_str()); + if (req.response_handler) { + if (!req.response_handler(res)) { return false; } + } - if (!req.has_header("Accept")) { - req.set_header("Accept", "*/*"); - } + // Body + if (req.method != "HEAD") { + detail::ContentReceiverCore out = [&](const char *buf, size_t n) { + if (res.body.size() + n > res.body.max_size()) { return false; } + res.body.append(buf, n); + return true; + }; - if (!req.has_header("User-Agent")) { - req.set_header("User-Agent", "cpp-httplib/0.2"); + if (req.content_receiver) { + auto offset = std::make_shared<size_t>(); + auto length = get_header_value_uint64(res.headers, "Content-Length", 0); + auto receiver = req.content_receiver; + out = [offset, length, receiver](const char *buf, size_t n) { + auto ret = receiver(buf, n, *offset, length); + (*offset) += n; + return ret; + }; } - // TODO: Support KeepAlive connection - // if (!req.has_header("Connection")) { - req.set_header("Connection", "close"); - // } + int dummy_status; + if (!detail::read_content(strm, res, std::numeric_limits<size_t>::max(), + dummy_status, req.progress, out)) { + return false; + } + } - if (!req.body.empty()) { - if (!req.has_header("Content-Type")) { - req.set_header("Content-Type", "text/plain"); - } + return true; +} - auto length = std::to_string(req.body.size()); - req.set_header("Content-Length", length.c_str()); - } +inline bool Client::process_and_close_socket( + socket_t sock, size_t request_count, + std::function<bool(Stream &strm, bool last_connection, + bool &connection_close)> + callback) { + request_count = std::min(request_count, keep_alive_max_count_); + return detail::process_and_close_socket(true, sock, request_count, callback); +} - detail::write_headers(strm, req); +inline bool Client::is_ssl() const { return false; } - // Body - if (!req.body.empty()) { - if (req.get_header_value("Content-Type") == "application/x-www-form-urlencoded") { - auto str = detail::encode_url(req.body); - strm.write(str.c_str(), str.size()); - } else { - strm.write(req.body.c_str(), req.body.size()); - } - } +inline std::shared_ptr<Response> Client::Get(const char *path) { + Progress dummy; + return Get(path, Headers(), dummy); } -inline bool Client::process_request(Stream& strm, Request& req, Response& res, bool& connection_close) -{ - // Send request - write_request(strm, req); +inline std::shared_ptr<Response> Client::Get(const char *path, + Progress progress) { + return Get(path, Headers(), progress); +} - // Receive response and headers - if (!read_response_line(strm, res) || !detail::read_headers(strm, res.headers)) { - return false; - } +inline std::shared_ptr<Response> Client::Get(const char *path, + const Headers &headers) { + Progress dummy; + return Get(path, headers, dummy); +} - if (res.get_header_value("Connection") == "close" || res.version == "HTTP/1.0") { - connection_close = true; - } +inline std::shared_ptr<Response> +Client::Get(const char *path, const Headers &headers, Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.progress = progress; - // Body - if (req.method != "HEAD") { - if (!detail::read_content(strm, res, req.progress)) { - return false; - } + auto res = std::make_shared<Response>(); + return send(req, *res) ? res : nullptr; +} - if (res.get_header_value("Content-Encoding") == "gzip") { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - detail::decompress(res.body); -#else - return false; -#endif - } - } +inline std::shared_ptr<Response> Client::Get(const char *path, + ContentReceiver content_receiver) { + Progress dummy; + return Get(path, Headers(), nullptr, content_receiver, dummy); +} - return true; +inline std::shared_ptr<Response> Client::Get(const char *path, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, Headers(), nullptr, content_receiver, progress); } -inline bool Client::read_and_close_socket(socket_t sock, Request& req, Response& res) -{ - return detail::read_and_close_socket( - sock, - 0, - [&](Stream& strm, bool /*last_connection*/, bool& connection_close) { - return process_request(strm, req, res, connection_close); - }); +inline std::shared_ptr<Response> Client::Get(const char *path, + const Headers &headers, + ContentReceiver content_receiver) { + Progress dummy; + return Get(path, headers, nullptr, content_receiver, dummy); } -inline std::shared_ptr<Response> Client::Get(const char* path, Progress progress) -{ - return Get(path, Headers(), progress); +inline std::shared_ptr<Response> Client::Get(const char *path, + const Headers &headers, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, headers, nullptr, content_receiver, progress); } -inline std::shared_ptr<Response> Client::Get(const char* path, const Headers& headers, Progress progress) -{ - Request req; - req.method = "GET"; - req.path = path; - req.headers = headers; - req.progress = progress; +inline std::shared_ptr<Response> Client::Get(const char *path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + Progress dummy; + return Get(path, headers, response_handler, content_receiver, dummy); +} - auto res = std::make_shared<Response>(); +inline std::shared_ptr<Response> Client::Get(const char *path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.response_handler = response_handler; + req.content_receiver = content_receiver; + req.progress = progress; - return send(req, *res) ? res : nullptr; + auto res = std::make_shared<Response>(); + return send(req, *res) ? res : nullptr; } -inline std::shared_ptr<Response> Client::Head(const char* path) -{ - return Head(path, Headers()); +inline std::shared_ptr<Response> Client::Head(const char *path) { + return Head(path, Headers()); } -inline std::shared_ptr<Response> Client::Head(const char* path, const Headers& headers) -{ - Request req; - req.method = "HEAD"; - req.headers = headers; - req.path = path; +inline std::shared_ptr<Response> Client::Head(const char *path, + const Headers &headers) { + Request req; + req.method = "HEAD"; + req.headers = headers; + req.path = path; - auto res = std::make_shared<Response>(); + auto res = std::make_shared<Response>(); - return send(req, *res) ? res : nullptr; + return send(req, *res) ? res : nullptr; } -inline std::shared_ptr<Response> Client::Post( - const char* path, const std::string& body, const char* content_type) -{ - return Post(path, Headers(), body, content_type); +inline std::shared_ptr<Response> Client::Post(const char *path, + const std::string &body, + const char *content_type) { + return Post(path, Headers(), body, content_type); } -inline std::shared_ptr<Response> Client::Post( - const char* path, const Headers& headers, const std::string& body, const char* content_type) -{ - Request req; - req.method = "POST"; - req.headers = headers; - req.path = path; +inline std::shared_ptr<Response> Client::Post(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + Request req; + req.method = "POST"; + req.headers = headers; + req.path = path; - req.headers.emplace("Content-Type", content_type); - req.body = body; + req.headers.emplace("Content-Type", content_type); + req.body = body; - auto res = std::make_shared<Response>(); + auto res = std::make_shared<Response>(); - return send(req, *res) ? res : nullptr; + return send(req, *res) ? res : nullptr; } -inline std::shared_ptr<Response> Client::Post(const char* path, const Params& params) -{ - return Post(path, Headers(), params); +inline std::shared_ptr<Response> Client::Post(const char *path, + const Params ¶ms) { + return Post(path, Headers(), params); } -inline std::shared_ptr<Response> Client::Post(const char* path, const Headers& headers, const Params& params) -{ - std::string query; - for (auto it = params.begin(); it != params.end(); ++it) { - if (it != params.begin()) { - query += "&"; - } - query += it->first; - query += "="; - query += it->second; +inline std::shared_ptr<Response> +Client::Post(const char *path, const Headers &headers, const Params ¶ms) { + std::string query; + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { query += "&"; } + query += it->first; + query += "="; + query += detail::encode_url(it->second); + } + + return Post(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline std::shared_ptr<Response> +Client::Post(const char *path, const MultipartFormDataItems &items) { + return Post(path, Headers(), items); +} + +inline std::shared_ptr<Response> +Client::Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items) { + Request req; + req.method = "POST"; + req.headers = headers; + req.path = path; + + auto boundary = detail::make_multipart_data_boundary(); + + req.headers.emplace("Content-Type", + "multipart/form-data; boundary=" + boundary); + + for (const auto &item : items) { + req.body += "--" + boundary + "\r\n"; + req.body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if (!item.filename.empty()) { + req.body += "; filename=\"" + item.filename + "\""; + } + req.body += "\r\n"; + if (!item.content_type.empty()) { + req.body += "Content-Type: " + item.content_type + "\r\n"; } + req.body += "\r\n"; + req.body += item.content + "\r\n"; + } + + req.body += "--" + boundary + "--\r\n"; + + auto res = std::make_shared<Response>(); + + return send(req, *res) ? res : nullptr; +} + +inline std::shared_ptr<Response> Client::Put(const char *path, + const std::string &body, + const char *content_type) { + return Put(path, Headers(), body, content_type); +} + +inline std::shared_ptr<Response> Client::Put(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + Request req; + req.method = "PUT"; + req.headers = headers; + req.path = path; + + req.headers.emplace("Content-Type", content_type); + req.body = body; + + auto res = std::make_shared<Response>(); - return Post(path, headers, query, "application/x-www-form-urlencoded"); + return send(req, *res) ? res : nullptr; } -inline std::shared_ptr<Response> Client::Put( - const char* path, const std::string& body, const char* content_type) -{ - return Put(path, Headers(), body, content_type); +inline std::shared_ptr<Response> Client::Patch(const char *path, + const std::string &body, + const char *content_type) { + return Patch(path, Headers(), body, content_type); } -inline std::shared_ptr<Response> Client::Put( - const char* path, const Headers& headers, const std::string& body, const char* content_type) -{ - Request req; - req.method = "PUT"; - req.headers = headers; - req.path = path; +inline std::shared_ptr<Response> Client::Patch(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + Request req; + req.method = "PATCH"; + req.headers = headers; + req.path = path; - req.headers.emplace("Content-Type", content_type); - req.body = body; + req.headers.emplace("Content-Type", content_type); + req.body = body; - auto res = std::make_shared<Response>(); + auto res = std::make_shared<Response>(); - return send(req, *res) ? res : nullptr; + return send(req, *res) ? res : nullptr; } -inline std::shared_ptr<Response> Client::Delete(const char* path) -{ - return Delete(path, Headers()); +inline std::shared_ptr<Response> Client::Delete(const char *path) { + return Delete(path, Headers(), std::string(), nullptr); } -inline std::shared_ptr<Response> Client::Delete(const char* path, const Headers& headers) -{ - Request req; - req.method = "DELETE"; - req.path = path; - req.headers = headers; +inline std::shared_ptr<Response> Client::Delete(const char *path, + const std::string &body, + const char *content_type) { + return Delete(path, Headers(), body, content_type); +} + +inline std::shared_ptr<Response> Client::Delete(const char *path, + const Headers &headers) { + return Delete(path, headers, std::string(), nullptr); +} - auto res = std::make_shared<Response>(); +inline std::shared_ptr<Response> Client::Delete(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + Request req; + req.method = "DELETE"; + req.headers = headers; + req.path = path; - return send(req, *res) ? res : nullptr; + if (content_type) { req.headers.emplace("Content-Type", content_type); } + req.body = body; + + auto res = std::make_shared<Response>(); + + return send(req, *res) ? res : nullptr; } -inline std::shared_ptr<Response> Client::Options(const char* path) -{ - return Options(path, Headers()); +inline std::shared_ptr<Response> Client::Options(const char *path) { + return Options(path, Headers()); } -inline std::shared_ptr<Response> Client::Options(const char* path, const Headers& headers) -{ - Request req; - req.method = "OPTIONS"; - req.path = path; - req.headers = headers; +inline std::shared_ptr<Response> Client::Options(const char *path, + const Headers &headers) { + Request req; + req.method = "OPTIONS"; + req.path = path; + req.headers = headers; + + auto res = std::make_shared<Response>(); - auto res = std::make_shared<Response>(); + return send(req, *res) ? res : nullptr; +} - return send(req, *res) ? res : nullptr; +inline void Client::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; } +inline void Client::follow_location(bool on) { follow_location_ = on; } + /* * SSL Implementation */ @@ -2152,74 +3284,114 @@ inline std::shared_ptr<Response> Client::Options(const char* path, const Headers namespace detail { template <typename U, typename V, typename T> -inline bool read_and_close_socket_ssl( - socket_t sock, size_t keep_alive_max_count, - // TODO: OpenSSL 1.0.2 occasionally crashes... - // The upcoming 1.1.0 is going to be thread safe. - SSL_CTX* ctx, std::mutex& ctx_mutex, - U SSL_connect_or_accept, V setup, - T callback) -{ - SSL* ssl = nullptr; - { - std::lock_guard<std::mutex> guard(ctx_mutex); - - ssl = SSL_new(ctx); - if (!ssl) { - return false; - } - } +inline bool process_and_close_socket_ssl(bool is_client_request, socket_t sock, + size_t keep_alive_max_count, + SSL_CTX *ctx, std::mutex &ctx_mutex, + U SSL_connect_or_accept, V setup, + T callback) { + assert(keep_alive_max_count > 0); + + SSL *ssl = nullptr; + { + std::lock_guard<std::mutex> guard(ctx_mutex); + ssl = SSL_new(ctx); + } + + if (!ssl) { + close_socket(sock); + return false; + } - auto bio = BIO_new_socket(sock, BIO_NOCLOSE); - SSL_set_bio(ssl, bio, bio); + auto bio = BIO_new_socket(static_cast<int>(sock), BIO_NOCLOSE); + SSL_set_bio(ssl, bio, bio); - setup(ssl); + if (!setup(ssl)) { + SSL_shutdown(ssl); + { + std::lock_guard<std::mutex> guard(ctx_mutex); + SSL_free(ssl); + } - SSL_connect_or_accept(ssl); + close_socket(sock); + return false; + } - bool ret = false; + bool ret = false; - if (keep_alive_max_count > 0) { - auto count = keep_alive_max_count; - while (count > 0 && - detail::select_read(sock, - CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, - CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) { - SSLSocketStream strm(sock, ssl); - auto last_connection = count == 1; - auto connection_close = false; + if (SSL_connect_or_accept(ssl) == 1) { + if (keep_alive_max_count > 1) { + auto count = keep_alive_max_count; + while (count > 0 && + (is_client_request || + detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { + SSLSocketStream strm(sock, ssl); + auto last_connection = count == 1; + auto connection_close = false; - ret = callback(strm, last_connection, connection_close); - if (!ret || connection_close) { - break; - } + ret = callback(ssl, strm, last_connection, connection_close); + if (!ret || connection_close) { break; } - count--; - } + count--; + } } else { - SSLSocketStream strm(sock, ssl); - auto dummy_connection_close = false; - ret = callback(strm, true, dummy_connection_close); + SSLSocketStream strm(sock, ssl); + auto dummy_connection_close = false; + ret = callback(ssl, strm, true, dummy_connection_close); } + } - SSL_shutdown(ssl); + SSL_shutdown(ssl); + { + std::lock_guard<std::mutex> guard(ctx_mutex); + SSL_free(ssl); + } - { - std::lock_guard<std::mutex> guard(ctx_mutex); - SSL_free(ssl); - } - - close_socket(sock); + close_socket(sock); - return ret; + return ret; } -class SSLInit { +#if OPENSSL_VERSION_NUMBER < 0x10100000L +static std::shared_ptr<std::vector<std::mutex>> openSSL_locks_; + +class SSLThreadLocks { public: - SSLInit() { - SSL_load_error_strings(); - SSL_library_init(); + SSLThreadLocks() { + openSSL_locks_ = + std::make_shared<std::vector<std::mutex>>(CRYPTO_num_locks()); + CRYPTO_set_locking_callback(locking_callback); + } + + ~SSLThreadLocks() { CRYPTO_set_locking_callback(nullptr); } + +private: + static void locking_callback(int mode, int type, const char * /*file*/, + int /*line*/) { + auto &locks = *openSSL_locks_; + if (mode & CRYPTO_LOCK) { + locks[type].lock(); + } else { + locks[type].unlock(); } + } +}; + +#endif + +class SSLInit { +public: + SSLInit() { + SSL_load_error_strings(); + SSL_library_init(); + } + + ~SSLInit() { ERR_free_strings(); } + +private: +#if OPENSSL_VERSION_NUMBER < 0x10100000L + SSLThreadLocks thread_init_; +#endif }; static SSLInit sslinit_; @@ -2227,118 +3399,319 @@ static SSLInit sslinit_; } // namespace detail // SSL socket stream implementation -inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL* ssl) - : sock_(sock), ssl_(ssl) -{ -} +inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl) + : sock_(sock), ssl_(ssl) {} -inline SSLSocketStream::~SSLSocketStream() -{ +inline SSLSocketStream::~SSLSocketStream() {} + +inline int SSLSocketStream::read(char *ptr, size_t size) { + if (SSL_pending(ssl_) > 0 || + detail::select_read(sock_, CPPHTTPLIB_READ_TIMEOUT_SECOND, + CPPHTTPLIB_READ_TIMEOUT_USECOND) > 0) { + return SSL_read(ssl_, ptr, static_cast<int>(size)); + } + return -1; } -inline int SSLSocketStream::read(char* ptr, size_t size) -{ - return SSL_read(ssl_, ptr, size); +inline int SSLSocketStream::write(const char *ptr, size_t size) { + return SSL_write(ssl_, ptr, static_cast<int>(size)); } -inline int SSLSocketStream::write(const char* ptr, size_t size) -{ - return SSL_write(ssl_, ptr, size); +inline int SSLSocketStream::write(const char *ptr) { + return write(ptr, strlen(ptr)); } -inline int SSLSocketStream::write(const char* ptr) -{ - return write(ptr, strlen(ptr)); +inline int SSLSocketStream::write(const std::string &s) { + return write(s.data(), s.size()); } -inline std::string SSLSocketStream::get_remote_addr() { - return detail::get_remote_addr(sock_); +inline std::string SSLSocketStream::get_remote_addr() const { + return detail::get_remote_addr(sock_); } // SSL HTTP server implementation -inline SSLServer::SSLServer(const char* cert_path, const char* private_key_path) -{ - ctx_ = SSL_CTX_new(SSLv23_server_method()); - - if (ctx_) { - SSL_CTX_set_options(ctx_, - SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | +inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path, + const char *client_ca_cert_dir_path) { + ctx_ = SSL_CTX_new(SSLv23_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); - // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); - // EC_KEY_free(ecdh); + // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); + // EC_KEY_free(ecdh); - if (SSL_CTX_use_certificate_file(ctx_, cert_path, SSL_FILETYPE_PEM) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } + if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != + 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { + // if (client_ca_cert_file_path) { + // auto list = SSL_load_client_CA_file(client_ca_cert_file_path); + // SSL_CTX_set_client_CA_list(ctx_, list); + // } + + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, + client_ca_cert_dir_path); + + SSL_CTX_set_verify( + ctx_, + SSL_VERIFY_PEER | + SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, + nullptr); } + } +} + +inline SSLServer::~SSLServer() { + if (ctx_) { SSL_CTX_free(ctx_); } +} + +inline bool SSLServer::is_valid() const { return ctx_; } + +inline bool SSLServer::process_and_close_socket(socket_t sock) { + return detail::process_and_close_socket_ssl( + false, sock, keep_alive_max_count_, ctx_, ctx_mutex_, SSL_accept, + [](SSL * /*ssl*/) { return true; }, + [this](SSL *ssl, Stream &strm, bool last_connection, + bool &connection_close) { + return process_request(strm, last_connection, connection_close, + [&](Request &req) { req.ssl = ssl; }); + }); } -inline SSLServer::~SSLServer() -{ - if (ctx_) { - SSL_CTX_free(ctx_); +// SSL HTTP client implementation +inline SSLClient::SSLClient(const char *host, int port, time_t timeout_sec, + const char *client_cert_path, + const char *client_key_path) + : Client(host, port, timeout_sec) { + ctx_ = SSL_CTX_new(SSLv23_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(std::string(b, e)); + }); + if (client_cert_path && client_key_path) { + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path, + SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path, SSL_FILETYPE_PEM) != + 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; } + } } -inline bool SSLServer::is_valid() const -{ - return ctx_; +inline SSLClient::~SSLClient() { + if (ctx_) { SSL_CTX_free(ctx_); } } -inline bool SSLServer::read_and_close_socket(socket_t sock) -{ - return detail::read_and_close_socket_ssl( - sock, - keep_alive_max_count_, - ctx_, ctx_mutex_, - SSL_accept, - [](SSL* /*ssl*/) {}, - [this](Stream& strm, bool last_connection, bool& connection_close) { - return process_request(strm, last_connection, connection_close); - }); +inline bool SSLClient::is_valid() const { return ctx_; } + +inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path, + const char *ca_cert_dir_path) { + if (ca_cert_file_path) { ca_cert_file_path_ = ca_cert_file_path; } + if (ca_cert_dir_path) { ca_cert_dir_path_ = ca_cert_dir_path; } } -// SSL HTTP client implementation -inline SSLClient::SSLClient(const char* host, int port, size_t timeout_sec) - : Client(host, port, timeout_sec) -{ - ctx_ = SSL_CTX_new(SSLv23_client_method()); +inline void SSLClient::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; } -inline SSLClient::~SSLClient() -{ - if (ctx_) { - SSL_CTX_free(ctx_); - } +inline long SSLClient::get_openssl_verify_result() const { + return verify_result_; } -inline bool SSLClient::is_valid() const -{ - return ctx_; +inline SSL_CTX* SSLClient::ssl_context() const noexcept { + return ctx_; } -inline bool SSLClient::read_and_close_socket(socket_t sock, Request& req, Response& res) -{ - return is_valid() && detail::read_and_close_socket_ssl( - sock, 0, - ctx_, ctx_mutex_, - SSL_connect, - [&](SSL* ssl) { - SSL_set_tlsext_host_name(ssl, host_.c_str()); - }, - [&](Stream& strm, bool /*last_connection*/, bool& connection_close) { - return process_request(strm, req, res, connection_close); - }); +inline bool SSLClient::process_and_close_socket( + socket_t sock, size_t request_count, + std::function<bool(Stream &strm, bool last_connection, + bool &connection_close)> + callback) { + + request_count = std::min(request_count, keep_alive_max_count_); + + return is_valid() && + detail::process_and_close_socket_ssl( + true, sock, request_count, ctx_, ctx_mutex_, + [&](SSL *ssl) { + if (ca_cert_file_path_.empty()) { + SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr); + } else { + if (!SSL_CTX_load_verify_locations( + ctx_, ca_cert_file_path_.c_str(), nullptr)) { + return false; + } + SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr); + } + + if (SSL_connect(ssl) != 1) { return false; } + + if (server_certificate_verification_) { + verify_result_ = SSL_get_verify_result(ssl); + + if (verify_result_ != X509_V_OK) { return false; } + + auto server_cert = SSL_get_peer_certificate(ssl); + + if (server_cert == nullptr) { return false; } + + if (!verify_host(server_cert)) { + X509_free(server_cert); + return false; + } + X509_free(server_cert); + } + + return true; + }, + [&](SSL *ssl) { + SSL_set_tlsext_host_name(ssl, host_.c_str()); + return true; + }, + [&](SSL * /*ssl*/, Stream &strm, bool last_connection, + bool &connection_close) { + return callback(strm, last_connection, connection_close); + }); +} + +inline bool SSLClient::is_ssl() const { return true; } + +inline bool SSLClient::verify_host(X509 *server_cert) const { + /* Quote from RFC2818 section 3.1 "Server Identity" + + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. + + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. + + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. + + */ + return verify_host_with_subject_alt_name(server_cert) || + verify_host_with_common_name(server_cert); } + +inline bool +SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { + auto ret = false; + + auto type = GEN_DNS; + + struct in6_addr addr6; + struct in_addr addr; + size_t addr_len = 0; + +#ifndef __MINGW32__ + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } #endif -} // namespace httplib + auto alt_names = static_cast<const struct stack_st_GENERAL_NAME *>( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_mached = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (auto i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (val->type == type) { + auto name = (const char *)ASN1_STRING_get0_data(val->d.ia5); + auto name_len = (size_t)ASN1_STRING_length(val->d.ia5); + if (strlen(name) == name_len) { + switch (type) { + case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || + !memcmp(&addr, name, addr_len)) { + ip_mached = true; + } + break; + } + } + } + } + + if (dsn_matched || ip_mached) { ret = true; } + } + + GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)alt_names); + + return ret; +} + +inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { + const auto subject_name = X509_get_subject_name(server_cert); + + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); + + if (name_len != -1) { return check_host_name(name, name_len); } + } + + return false; +} + +inline bool SSLClient::check_host_name(const char *pattern, + size_t pattern_len) const { + if (host_.size() == pattern_len && host_ == pattern) { return true; } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector<std::string> pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', + [&](const char *b, const char *e) { + pattern_components.emplace_back(std::string(b, e)); + }); + + if (host_components_.size() != pattern_components.size()) { return false; } + + auto itr = pattern_components.begin(); + for (const auto &h : host_components_) { + auto &p = *itr; + if (p != h && p != "*") { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && + !p.compare(0, p.size() - 1, h)); + if (!partial_match) { return false; } + } + ++itr; + } + + return true; +} #endif -// vim: et ts=4 sw=4 cin cino={1s ff=unix +} // namespace httplib + +#endif // CPPHTTPLIB_HTTPLIB_H diff --git a/src/audio_core/audio_renderer.cpp b/src/audio_core/audio_renderer.cpp index e6f38d600..c187d8ac5 100644 --- a/src/audio_core/audio_renderer.cpp +++ b/src/audio_core/audio_renderer.cpp @@ -36,9 +36,9 @@ public: } void SetWaveIndex(std::size_t index); - std::vector<s16> DequeueSamples(std::size_t sample_count); + std::vector<s16> DequeueSamples(std::size_t sample_count, Memory::Memory& memory); void UpdateState(); - void RefreshBuffer(); + void RefreshBuffer(Memory::Memory& memory); private: bool is_in_use{}; @@ -66,17 +66,18 @@ public: return info; } - void UpdateState(); + void UpdateState(Memory::Memory& memory); private: EffectOutStatus out_status{}; EffectInStatus info{}; }; -AudioRenderer::AudioRenderer(Core::Timing::CoreTiming& core_timing, AudioRendererParameter params, - Kernel::SharedPtr<Kernel::WritableEvent> buffer_event, +AudioRenderer::AudioRenderer(Core::Timing::CoreTiming& core_timing, Memory::Memory& memory_, + AudioRendererParameter params, + std::shared_ptr<Kernel::WritableEvent> buffer_event, std::size_t instance_number) : worker_params{params}, buffer_event{buffer_event}, voices(params.voice_count), - effects(params.effect_count) { + effects(params.effect_count), memory{memory_} { audio_out = std::make_unique<AudioCore::AudioOut>(); stream = audio_out->OpenStream(core_timing, STREAM_SAMPLE_RATE, STREAM_NUM_CHANNELS, @@ -162,7 +163,7 @@ std::vector<u8> AudioRenderer::UpdateAudioRenderer(const std::vector<u8>& input_ } for (auto& effect : effects) { - effect.UpdateState(); + effect.UpdateState(memory); } // Release previous buffers and queue next ones for playback @@ -206,13 +207,14 @@ void AudioRenderer::VoiceState::SetWaveIndex(std::size_t index) { is_refresh_pending = true; } -std::vector<s16> AudioRenderer::VoiceState::DequeueSamples(std::size_t sample_count) { +std::vector<s16> AudioRenderer::VoiceState::DequeueSamples(std::size_t sample_count, + Memory::Memory& memory) { if (!IsPlaying()) { return {}; } if (is_refresh_pending) { - RefreshBuffer(); + RefreshBuffer(memory); } const std::size_t max_size{samples.size() - offset}; @@ -256,10 +258,11 @@ void AudioRenderer::VoiceState::UpdateState() { is_in_use = info.is_in_use; } -void AudioRenderer::VoiceState::RefreshBuffer() { - std::vector<s16> new_samples(info.wave_buffer[wave_index].buffer_sz / sizeof(s16)); - Memory::ReadBlock(info.wave_buffer[wave_index].buffer_addr, new_samples.data(), - info.wave_buffer[wave_index].buffer_sz); +void AudioRenderer::VoiceState::RefreshBuffer(Memory::Memory& memory) { + const auto wave_buffer_address = info.wave_buffer[wave_index].buffer_addr; + const auto wave_buffer_size = info.wave_buffer[wave_index].buffer_sz; + std::vector<s16> new_samples(wave_buffer_size / sizeof(s16)); + memory.ReadBlock(wave_buffer_address, new_samples.data(), wave_buffer_size); switch (static_cast<Codec::PcmFormat>(info.sample_format)) { case Codec::PcmFormat::Int16: { @@ -269,7 +272,7 @@ void AudioRenderer::VoiceState::RefreshBuffer() { case Codec::PcmFormat::Adpcm: { // Decode ADPCM to PCM16 Codec::ADPCM_Coeff coeffs; - Memory::ReadBlock(info.additional_params_addr, coeffs.data(), sizeof(Codec::ADPCM_Coeff)); + memory.ReadBlock(info.additional_params_addr, coeffs.data(), sizeof(Codec::ADPCM_Coeff)); new_samples = Codec::DecodeADPCM(reinterpret_cast<u8*>(new_samples.data()), new_samples.size() * sizeof(s16), coeffs, adpcm_state); break; @@ -307,18 +310,18 @@ void AudioRenderer::VoiceState::RefreshBuffer() { is_refresh_pending = false; } -void AudioRenderer::EffectState::UpdateState() { +void AudioRenderer::EffectState::UpdateState(Memory::Memory& memory) { if (info.is_new) { out_status.state = EffectStatus::New; } else { if (info.type == Effect::Aux) { - ASSERT_MSG(Memory::Read32(info.aux_info.return_buffer_info) == 0, + ASSERT_MSG(memory.Read32(info.aux_info.return_buffer_info) == 0, "Aux buffers tried to update"); - ASSERT_MSG(Memory::Read32(info.aux_info.send_buffer_info) == 0, + ASSERT_MSG(memory.Read32(info.aux_info.send_buffer_info) == 0, "Aux buffers tried to update"); - ASSERT_MSG(Memory::Read32(info.aux_info.return_buffer_base) == 0, + ASSERT_MSG(memory.Read32(info.aux_info.return_buffer_base) == 0, "Aux buffers tried to update"); - ASSERT_MSG(Memory::Read32(info.aux_info.send_buffer_base) == 0, + ASSERT_MSG(memory.Read32(info.aux_info.send_buffer_base) == 0, "Aux buffers tried to update"); } } @@ -340,7 +343,7 @@ void AudioRenderer::QueueMixedBuffer(Buffer::Tag tag) { std::size_t offset{}; s64 samples_remaining{BUFFER_SIZE}; while (samples_remaining > 0) { - const std::vector<s16> samples{voice.DequeueSamples(samples_remaining)}; + const std::vector<s16> samples{voice.DequeueSamples(samples_remaining, memory)}; if (samples.empty()) { break; diff --git a/src/audio_core/audio_renderer.h b/src/audio_core/audio_renderer.h index 4f14b91cd..be1b019f1 100644 --- a/src/audio_core/audio_renderer.h +++ b/src/audio_core/audio_renderer.h @@ -22,6 +22,10 @@ namespace Kernel { class WritableEvent; } +namespace Memory { +class Memory; +} + namespace AudioCore { class AudioOut; @@ -217,9 +221,9 @@ static_assert(sizeof(UpdateDataHeader) == 0x40, "UpdateDataHeader has wrong size class AudioRenderer { public: - AudioRenderer(Core::Timing::CoreTiming& core_timing, AudioRendererParameter params, - Kernel::SharedPtr<Kernel::WritableEvent> buffer_event, - std::size_t instance_number); + AudioRenderer(Core::Timing::CoreTiming& core_timing, Memory::Memory& memory_, + AudioRendererParameter params, + std::shared_ptr<Kernel::WritableEvent> buffer_event, std::size_t instance_number); ~AudioRenderer(); std::vector<u8> UpdateAudioRenderer(const std::vector<u8>& input_params); @@ -235,11 +239,12 @@ private: class VoiceState; AudioRendererParameter worker_params; - Kernel::SharedPtr<Kernel::WritableEvent> buffer_event; + std::shared_ptr<Kernel::WritableEvent> buffer_event; std::vector<VoiceState> voices; std::vector<EffectState> effects; std::unique_ptr<AudioOut> audio_out; - AudioCore::StreamPtr stream; + StreamPtr stream; + Memory::Memory& memory; }; } // namespace AudioCore diff --git a/src/audio_core/stream.cpp b/src/audio_core/stream.cpp index 6a5f53a57..4ca98f8ea 100644 --- a/src/audio_core/stream.cpp +++ b/src/audio_core/stream.cpp @@ -37,7 +37,7 @@ Stream::Stream(Core::Timing::CoreTiming& core_timing, u32 sample_rate, Format fo : sample_rate{sample_rate}, format{format}, release_callback{std::move(release_callback)}, sink_stream{sink_stream}, core_timing{core_timing}, name{std::move(name_)} { - release_event = core_timing.RegisterEvent( + release_event = Core::Timing::CreateEvent( name, [this](u64 userdata, s64 cycles_late) { ReleaseActiveBuffer(); }); } diff --git a/src/audio_core/stream.h b/src/audio_core/stream.h index 8106cea43..1708a4d98 100644 --- a/src/audio_core/stream.h +++ b/src/audio_core/stream.h @@ -98,18 +98,19 @@ private: /// Gets the number of core cycles when the specified buffer will be released s64 GetBufferReleaseCycles(const Buffer& buffer) const; - u32 sample_rate; ///< Sample rate of the stream - Format format; ///< Format of the stream - float game_volume = 1.0f; ///< The volume the game currently has set - ReleaseCallback release_callback; ///< Buffer release callback for the stream - State state{State::Stopped}; ///< Playback state of the stream - Core::Timing::EventType* release_event{}; ///< Core timing release event for the stream - BufferPtr active_buffer; ///< Actively playing buffer in the stream - std::queue<BufferPtr> queued_buffers; ///< Buffers queued to be played in the stream - std::queue<BufferPtr> released_buffers; ///< Buffers recently released from the stream - SinkStream& sink_stream; ///< Output sink for the stream - Core::Timing::CoreTiming& core_timing; ///< Core timing instance. - std::string name; ///< Name of the stream, must be unique + u32 sample_rate; ///< Sample rate of the stream + Format format; ///< Format of the stream + float game_volume = 1.0f; ///< The volume the game currently has set + ReleaseCallback release_callback; ///< Buffer release callback for the stream + State state{State::Stopped}; ///< Playback state of the stream + std::shared_ptr<Core::Timing::EventType> + release_event; ///< Core timing release event for the stream + BufferPtr active_buffer; ///< Actively playing buffer in the stream + std::queue<BufferPtr> queued_buffers; ///< Buffers queued to be played in the stream + std::queue<BufferPtr> released_buffers; ///< Buffers recently released from the stream + SinkStream& sink_stream; ///< Output sink for the stream + Core::Timing::CoreTiming& core_timing; ///< Core timing instance. + std::string name; ///< Name of the stream, must be unique }; using StreamPtr = std::shared_ptr<Stream>; diff --git a/src/common/assert.h b/src/common/assert.h index 4b0e3f64e..5b67c5c52 100644 --- a/src/common/assert.h +++ b/src/common/assert.h @@ -41,8 +41,9 @@ __declspec(noinline, noreturn) } \ while (0) -#define UNREACHABLE() ASSERT_MSG(false, "Unreachable code!") -#define UNREACHABLE_MSG(...) ASSERT_MSG(false, __VA_ARGS__) +#define UNREACHABLE() assert_noinline_call([] { LOG_CRITICAL(Debug, "Unreachable code!"); }) +#define UNREACHABLE_MSG(...) \ + assert_noinline_call([&] { LOG_CRITICAL(Debug, "Unreachable code!\n" __VA_ARGS__); }) #ifdef _DEBUG #define DEBUG_ASSERT(_a_) ASSERT(_a_) diff --git a/src/common/common_funcs.h b/src/common/common_funcs.h index 6dc3e108f..052254678 100644 --- a/src/common/common_funcs.h +++ b/src/common/common_funcs.h @@ -19,13 +19,15 @@ /// Helper macros to insert unused bytes or words to properly align structs. These values will be /// zero-initialized. -#define INSERT_PADDING_BYTES(num_bytes) std::array<u8, num_bytes> CONCAT2(pad, __LINE__){}; -#define INSERT_PADDING_WORDS(num_words) std::array<u32, num_words> CONCAT2(pad, __LINE__){}; +#define INSERT_PADDING_BYTES(num_bytes) \ + std::array<u8, num_bytes> CONCAT2(pad, __LINE__) {} +#define INSERT_PADDING_WORDS(num_words) \ + std::array<u32, num_words> CONCAT2(pad, __LINE__) {} /// These are similar to the INSERT_PADDING_* macros, but are needed for padding unions. This is /// because unions can only be initialized by one member. -#define INSERT_UNION_PADDING_BYTES(num_bytes) std::array<u8, num_bytes> CONCAT2(pad, __LINE__); -#define INSERT_UNION_PADDING_WORDS(num_words) std::array<u32, num_words> CONCAT2(pad, __LINE__); +#define INSERT_UNION_PADDING_BYTES(num_bytes) std::array<u8, num_bytes> CONCAT2(pad, __LINE__) +#define INSERT_UNION_PADDING_WORDS(num_words) std::array<u32, num_words> CONCAT2(pad, __LINE__) #ifndef _MSC_VER diff --git a/src/common/logging/backend.cpp b/src/common/logging/backend.cpp index 1111cfbad..8f2591d53 100644 --- a/src/common/logging/backend.cpp +++ b/src/common/logging/backend.cpp @@ -272,8 +272,10 @@ const char* GetLogClassName(Class log_class) { #undef CLS #undef SUB case Class::Count: - UNREACHABLE(); + break; } + UNREACHABLE(); + return "Invalid"; } const char* GetLevelName(Level log_level) { @@ -288,9 +290,11 @@ const char* GetLevelName(Level log_level) { LVL(Error); LVL(Critical); case Level::Count: - UNREACHABLE(); + break; } #undef LVL + UNREACHABLE(); + return "Invalid"; } void SetGlobalFilter(const Filter& filter) { diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 4f6a87b0a..7fd226050 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -170,6 +170,7 @@ add_library(core STATIC hle/kernel/server_port.h hle/kernel/server_session.cpp hle/kernel/server_session.h + hle/kernel/session.cpp hle/kernel/session.h hle/kernel/shared_memory.cpp hle/kernel/shared_memory.h @@ -509,7 +510,6 @@ add_library(core STATIC memory/dmnt_cheat_vm.h memory.cpp memory.h - memory_setup.h perf_stats.cpp perf_stats.h reporter.cpp @@ -522,6 +522,23 @@ add_library(core STATIC tools/freezer.h ) +if (MSVC) + target_compile_options(core PRIVATE + # 'expression' : signed/unsigned mismatch + /we4018 + # 'argument' : conversion from 'type1' to 'type2', possible loss of data (floating-point) + /we4244 + # 'conversion' : conversion from 'type1' to 'type2', signed/unsigned mismatch + /we4245 + # 'operator': conversion from 'type1:field_bits' to 'type2:field_bits', possible loss of data + /we4254 + # 'var' : conversion from 'size_t' to 'type', possible loss of data + /we4267 + # 'context' : truncation from 'type1' to 'type2' + /we4305 + ) +endif() + create_target_directory_groups(core) target_link_libraries(core PUBLIC common PRIVATE audio_core video_core) diff --git a/src/core/arm/arm_interface.cpp b/src/core/arm/arm_interface.cpp index 372612c9b..7e846ddd5 100644 --- a/src/core/arm/arm_interface.cpp +++ b/src/core/arm/arm_interface.cpp @@ -13,7 +13,6 @@ #include "core/memory.h" namespace Core { - namespace { constexpr u64 ELF_DYNAMIC_TAG_NULL = 0; @@ -61,15 +60,15 @@ static_assert(sizeof(ELFSymbol) == 0x18, "ELFSymbol has incorrect size."); using Symbols = std::vector<std::pair<ELFSymbol, std::string>>; -Symbols GetSymbols(VAddr text_offset) { - const auto mod_offset = text_offset + Memory::Read32(text_offset + 4); +Symbols GetSymbols(VAddr text_offset, Memory::Memory& memory) { + const auto mod_offset = text_offset + memory.Read32(text_offset + 4); if (mod_offset < text_offset || (mod_offset & 0b11) != 0 || - Memory::Read32(mod_offset) != Common::MakeMagic('M', 'O', 'D', '0')) { + memory.Read32(mod_offset) != Common::MakeMagic('M', 'O', 'D', '0')) { return {}; } - const auto dynamic_offset = Memory::Read32(mod_offset + 0x4) + mod_offset; + const auto dynamic_offset = memory.Read32(mod_offset + 0x4) + mod_offset; VAddr string_table_offset{}; VAddr symbol_table_offset{}; @@ -77,8 +76,8 @@ Symbols GetSymbols(VAddr text_offset) { VAddr dynamic_index = dynamic_offset; while (true) { - const auto tag = Memory::Read64(dynamic_index); - const auto value = Memory::Read64(dynamic_index + 0x8); + const u64 tag = memory.Read64(dynamic_index); + const u64 value = memory.Read64(dynamic_index + 0x8); dynamic_index += 0x10; if (tag == ELF_DYNAMIC_TAG_NULL) { @@ -106,11 +105,11 @@ Symbols GetSymbols(VAddr text_offset) { VAddr symbol_index = symbol_table_address; while (symbol_index < string_table_address) { ELFSymbol symbol{}; - Memory::ReadBlock(symbol_index, &symbol, sizeof(ELFSymbol)); + memory.ReadBlock(symbol_index, &symbol, sizeof(ELFSymbol)); VAddr string_offset = string_table_address + symbol.name_index; std::string name; - for (u8 c = Memory::Read8(string_offset); c != 0; c = Memory::Read8(++string_offset)) { + for (u8 c = memory.Read8(string_offset); c != 0; c = memory.Read8(++string_offset)) { name += static_cast<char>(c); } @@ -142,28 +141,28 @@ constexpr u64 SEGMENT_BASE = 0x7100000000ull; std::vector<ARM_Interface::BacktraceEntry> ARM_Interface::GetBacktrace() const { std::vector<BacktraceEntry> out; + auto& memory = system.Memory(); auto fp = GetReg(29); auto lr = GetReg(30); - while (true) { out.push_back({"", 0, lr, 0}); if (!fp) { break; } - lr = Memory::Read64(fp + 8) - 4; - fp = Memory::Read64(fp); + lr = memory.Read64(fp + 8) - 4; + fp = memory.Read64(fp); } std::map<VAddr, std::string> modules; - auto& loader{System::GetInstance().GetAppLoader()}; + auto& loader{system.GetAppLoader()}; if (loader.ReadNSOModules(modules) != Loader::ResultStatus::Success) { return {}; } std::map<std::string, Symbols> symbols; for (const auto& module : modules) { - symbols.insert_or_assign(module.second, GetSymbols(module.first)); + symbols.insert_or_assign(module.second, GetSymbols(module.first, memory)); } for (auto& entry : out) { diff --git a/src/core/arm/arm_interface.h b/src/core/arm/arm_interface.h index 45e94e625..47b964eb7 100644 --- a/src/core/arm/arm_interface.h +++ b/src/core/arm/arm_interface.h @@ -17,11 +17,13 @@ enum class VMAPermission : u8; } namespace Core { +class System; /// Generic ARMv8 CPU interface class ARM_Interface : NonCopyable { public: - virtual ~ARM_Interface() {} + explicit ARM_Interface(System& system_) : system{system_} {} + virtual ~ARM_Interface() = default; struct ThreadContext { std::array<u64, 31> cpu_registers; @@ -163,6 +165,10 @@ public: /// fp+0 : pointer to previous frame record /// fp+8 : value of lr for frame void LogBacktrace() const; + +protected: + /// System context that this ARM interface is running under. + System& system; }; } // namespace Core diff --git a/src/core/arm/dynarmic/arm_dynarmic.cpp b/src/core/arm/dynarmic/arm_dynarmic.cpp index 700c4afff..f8c7f0efd 100644 --- a/src/core/arm/dynarmic/arm_dynarmic.cpp +++ b/src/core/arm/dynarmic/arm_dynarmic.cpp @@ -28,36 +28,38 @@ public: explicit ARM_Dynarmic_Callbacks(ARM_Dynarmic& parent) : parent(parent) {} u8 MemoryRead8(u64 vaddr) override { - return Memory::Read8(vaddr); + return parent.system.Memory().Read8(vaddr); } u16 MemoryRead16(u64 vaddr) override { - return Memory::Read16(vaddr); + return parent.system.Memory().Read16(vaddr); } u32 MemoryRead32(u64 vaddr) override { - return Memory::Read32(vaddr); + return parent.system.Memory().Read32(vaddr); } u64 MemoryRead64(u64 vaddr) override { - return Memory::Read64(vaddr); + return parent.system.Memory().Read64(vaddr); } Vector MemoryRead128(u64 vaddr) override { - return {Memory::Read64(vaddr), Memory::Read64(vaddr + 8)}; + auto& memory = parent.system.Memory(); + return {memory.Read64(vaddr), memory.Read64(vaddr + 8)}; } void MemoryWrite8(u64 vaddr, u8 value) override { - Memory::Write8(vaddr, value); + parent.system.Memory().Write8(vaddr, value); } void MemoryWrite16(u64 vaddr, u16 value) override { - Memory::Write16(vaddr, value); + parent.system.Memory().Write16(vaddr, value); } void MemoryWrite32(u64 vaddr, u32 value) override { - Memory::Write32(vaddr, value); + parent.system.Memory().Write32(vaddr, value); } void MemoryWrite64(u64 vaddr, u64 value) override { - Memory::Write64(vaddr, value); + parent.system.Memory().Write64(vaddr, value); } void MemoryWrite128(u64 vaddr, Vector value) override { - Memory::Write64(vaddr, value[0]); - Memory::Write64(vaddr + 8, value[1]); + auto& memory = parent.system.Memory(); + memory.Write64(vaddr, value[0]); + memory.Write64(vaddr + 8, value[1]); } void InterpreterFallback(u64 pc, std::size_t num_instructions) override { @@ -67,7 +69,7 @@ public: ARM_Interface::ThreadContext ctx; parent.SaveContext(ctx); parent.inner_unicorn.LoadContext(ctx); - parent.inner_unicorn.ExecuteInstructions(static_cast<int>(num_instructions)); + parent.inner_unicorn.ExecuteInstructions(num_instructions); parent.inner_unicorn.SaveContext(ctx); parent.LoadContext(ctx); num_interpreted_instructions += num_instructions; @@ -171,9 +173,10 @@ void ARM_Dynarmic::Step() { ARM_Dynarmic::ARM_Dynarmic(System& system, ExclusiveMonitor& exclusive_monitor, std::size_t core_index) - : cb(std::make_unique<ARM_Dynarmic_Callbacks>(*this)), inner_unicorn{system}, - core_index{core_index}, system{system}, - exclusive_monitor{dynamic_cast<DynarmicExclusiveMonitor&>(exclusive_monitor)} {} + : ARM_Interface{system}, + cb(std::make_unique<ARM_Dynarmic_Callbacks>(*this)), inner_unicorn{system}, + core_index{core_index}, exclusive_monitor{ + dynamic_cast<DynarmicExclusiveMonitor&>(exclusive_monitor)} {} ARM_Dynarmic::~ARM_Dynarmic() = default; @@ -264,7 +267,9 @@ void ARM_Dynarmic::PageTableChanged(Common::PageTable& page_table, jit = MakeJit(page_table, new_address_space_size_in_bits); } -DynarmicExclusiveMonitor::DynarmicExclusiveMonitor(std::size_t core_count) : monitor(core_count) {} +DynarmicExclusiveMonitor::DynarmicExclusiveMonitor(Memory::Memory& memory_, std::size_t core_count) + : monitor(core_count), memory{memory_} {} + DynarmicExclusiveMonitor::~DynarmicExclusiveMonitor() = default; void DynarmicExclusiveMonitor::SetExclusive(std::size_t core_index, VAddr addr) { @@ -277,29 +282,28 @@ void DynarmicExclusiveMonitor::ClearExclusive() { } bool DynarmicExclusiveMonitor::ExclusiveWrite8(std::size_t core_index, VAddr vaddr, u8 value) { - return monitor.DoExclusiveOperation(core_index, vaddr, 1, - [&] { Memory::Write8(vaddr, value); }); + return monitor.DoExclusiveOperation(core_index, vaddr, 1, [&] { memory.Write8(vaddr, value); }); } bool DynarmicExclusiveMonitor::ExclusiveWrite16(std::size_t core_index, VAddr vaddr, u16 value) { return monitor.DoExclusiveOperation(core_index, vaddr, 2, - [&] { Memory::Write16(vaddr, value); }); + [&] { memory.Write16(vaddr, value); }); } bool DynarmicExclusiveMonitor::ExclusiveWrite32(std::size_t core_index, VAddr vaddr, u32 value) { return monitor.DoExclusiveOperation(core_index, vaddr, 4, - [&] { Memory::Write32(vaddr, value); }); + [&] { memory.Write32(vaddr, value); }); } bool DynarmicExclusiveMonitor::ExclusiveWrite64(std::size_t core_index, VAddr vaddr, u64 value) { return monitor.DoExclusiveOperation(core_index, vaddr, 8, - [&] { Memory::Write64(vaddr, value); }); + [&] { memory.Write64(vaddr, value); }); } bool DynarmicExclusiveMonitor::ExclusiveWrite128(std::size_t core_index, VAddr vaddr, u128 value) { return monitor.DoExclusiveOperation(core_index, vaddr, 16, [&] { - Memory::Write64(vaddr + 0, value[0]); - Memory::Write64(vaddr + 8, value[1]); + memory.Write64(vaddr + 0, value[0]); + memory.Write64(vaddr + 8, value[1]); }); } diff --git a/src/core/arm/dynarmic/arm_dynarmic.h b/src/core/arm/dynarmic/arm_dynarmic.h index 504d46c68..9cd475cfb 100644 --- a/src/core/arm/dynarmic/arm_dynarmic.h +++ b/src/core/arm/dynarmic/arm_dynarmic.h @@ -12,6 +12,10 @@ #include "core/arm/exclusive_monitor.h" #include "core/arm/unicorn/arm_unicorn.h" +namespace Memory { +class Memory; +} + namespace Core { class ARM_Dynarmic_Callbacks; @@ -58,13 +62,12 @@ private: ARM_Unicorn inner_unicorn; std::size_t core_index; - System& system; DynarmicExclusiveMonitor& exclusive_monitor; }; class DynarmicExclusiveMonitor final : public ExclusiveMonitor { public: - explicit DynarmicExclusiveMonitor(std::size_t core_count); + explicit DynarmicExclusiveMonitor(Memory::Memory& memory_, std::size_t core_count); ~DynarmicExclusiveMonitor() override; void SetExclusive(std::size_t core_index, VAddr addr) override; @@ -79,6 +82,7 @@ public: private: friend class ARM_Dynarmic; Dynarmic::A64::ExclusiveMonitor monitor; + Memory::Memory& memory; }; } // namespace Core diff --git a/src/core/arm/unicorn/arm_unicorn.cpp b/src/core/arm/unicorn/arm_unicorn.cpp index d4f41bfc1..48182c99a 100644 --- a/src/core/arm/unicorn/arm_unicorn.cpp +++ b/src/core/arm/unicorn/arm_unicorn.cpp @@ -60,17 +60,18 @@ static bool UnmappedMemoryHook(uc_engine* uc, uc_mem_type type, u64 addr, int si return false; } -ARM_Unicorn::ARM_Unicorn(System& system) : system{system} { +ARM_Unicorn::ARM_Unicorn(System& system) : ARM_Interface{system} { CHECKED(uc_open(UC_ARCH_ARM64, UC_MODE_ARM, &uc)); auto fpv = 3 << 20; CHECKED(uc_reg_write(uc, UC_ARM64_REG_CPACR_EL1, &fpv)); uc_hook hook{}; - CHECKED(uc_hook_add(uc, &hook, UC_HOOK_INTR, (void*)InterruptHook, this, 0, -1)); - CHECKED(uc_hook_add(uc, &hook, UC_HOOK_MEM_INVALID, (void*)UnmappedMemoryHook, &system, 0, -1)); + CHECKED(uc_hook_add(uc, &hook, UC_HOOK_INTR, (void*)InterruptHook, this, 0, UINT64_MAX)); + CHECKED(uc_hook_add(uc, &hook, UC_HOOK_MEM_INVALID, (void*)UnmappedMemoryHook, &system, 0, + UINT64_MAX)); if (GDBStub::IsServerEnabled()) { - CHECKED(uc_hook_add(uc, &hook, UC_HOOK_CODE, (void*)CodeHook, this, 0, -1)); + CHECKED(uc_hook_add(uc, &hook, UC_HOOK_CODE, (void*)CodeHook, this, 0, UINT64_MAX)); last_bkpt_hit = false; } } @@ -154,9 +155,10 @@ void ARM_Unicorn::SetTPIDR_EL0(u64 value) { void ARM_Unicorn::Run() { if (GDBStub::IsServerEnabled()) { - ExecuteInstructions(std::max(4000000, 0)); + ExecuteInstructions(std::max(4000000U, 0U)); } else { - ExecuteInstructions(std::max(system.CoreTiming().GetDowncount(), s64{0})); + ExecuteInstructions( + std::max(std::size_t(system.CoreTiming().GetDowncount()), std::size_t{0})); } } @@ -166,7 +168,7 @@ void ARM_Unicorn::Step() { MICROPROFILE_DEFINE(ARM_Jit_Unicorn, "ARM JIT", "Unicorn", MP_RGB(255, 64, 64)); -void ARM_Unicorn::ExecuteInstructions(int num_instructions) { +void ARM_Unicorn::ExecuteInstructions(std::size_t num_instructions) { MICROPROFILE_SCOPE(ARM_Jit_Unicorn); CHECKED(uc_emu_start(uc, GetPC(), 1ULL << 63, 0, num_instructions)); system.CoreTiming().AddTicks(num_instructions); diff --git a/src/core/arm/unicorn/arm_unicorn.h b/src/core/arm/unicorn/arm_unicorn.h index fe2ffd70c..3c5b155f9 100644 --- a/src/core/arm/unicorn/arm_unicorn.h +++ b/src/core/arm/unicorn/arm_unicorn.h @@ -34,7 +34,7 @@ public: void LoadContext(const ThreadContext& ctx) override; void PrepareReschedule() override; void ClearExclusiveState() override; - void ExecuteInstructions(int num_instructions); + void ExecuteInstructions(std::size_t num_instructions); void Run() override; void Step() override; void ClearInstructionCache() override; @@ -45,7 +45,6 @@ private: static void InterruptHook(uc_engine* uc, u32 int_no, void* user_data); uc_engine* uc{}; - System& system; GDBStub::BreakpointAddress last_bkpt{}; bool last_bkpt_hit = false; }; diff --git a/src/core/core.cpp b/src/core/core.cpp index eba17218a..c45fb960c 100644 --- a/src/core/core.cpp +++ b/src/core/core.cpp @@ -39,6 +39,7 @@ #include "core/hle/service/service.h" #include "core/hle/service/sm/sm.h" #include "core/loader/loader.h" +#include "core/memory.h" #include "core/memory/cheat_engine.h" #include "core/perf_stats.h" #include "core/reporter.h" @@ -112,8 +113,8 @@ FileSys::VirtualFile GetGameFileFromPath(const FileSys::VirtualFilesystem& vfs, } struct System::Impl { explicit Impl(System& system) - : kernel{system}, fs_controller{system}, cpu_core_manager{system}, reporter{system}, - applet_manager{system} {} + : kernel{system}, fs_controller{system}, memory{system}, + cpu_core_manager{system}, reporter{system}, applet_manager{system} {} Cpu& CurrentCpuCore() { return cpu_core_manager.GetCurrentCore(); @@ -341,7 +342,8 @@ struct System::Impl { std::unique_ptr<VideoCore::RendererBase> renderer; std::unique_ptr<Tegra::GPU> gpu_core; std::shared_ptr<Tegra::DebugContext> debug_context; - std::unique_ptr<Core::Hardware::InterruptManager> interrupt_manager; + std::unique_ptr<Hardware::InterruptManager> interrupt_manager; + Memory::Memory memory; CpuCoreManager cpu_core_manager; bool is_powered_on = false; bool exit_lock = false; @@ -498,6 +500,14 @@ const ExclusiveMonitor& System::Monitor() const { return impl->cpu_core_manager.GetExclusiveMonitor(); } +Memory::Memory& System::Memory() { + return impl->memory; +} + +const Memory::Memory& System::Memory() const { + return impl->memory; +} + Tegra::GPU& System::GPU() { return *impl->gpu_core; } diff --git a/src/core/core.h b/src/core/core.h index 984074ce3..91184e433 100644 --- a/src/core/core.h +++ b/src/core/core.h @@ -7,6 +7,7 @@ #include <cstddef> #include <memory> #include <string> +#include <vector> #include "common/common_types.h" #include "core/file_sys/vfs_types.h" @@ -85,6 +86,10 @@ namespace Core::Hardware { class InterruptManager; } +namespace Memory { +class Memory; +} + namespace Core { class ARM_Interface; @@ -224,6 +229,12 @@ public: /// Gets a constant reference to the exclusive monitor const ExclusiveMonitor& Monitor() const; + /// Gets a mutable reference to the system memory instance. + Memory::Memory& Memory(); + + /// Gets a constant reference to the system memory instance. + const Memory::Memory& Memory() const; + /// Gets a mutable reference to the GPU interface Tegra::GPU& GPU(); diff --git a/src/core/core_cpu.cpp b/src/core/core_cpu.cpp index 233ea572c..cf3fe0b0b 100644 --- a/src/core/core_cpu.cpp +++ b/src/core/core_cpu.cpp @@ -66,9 +66,10 @@ Cpu::Cpu(System& system, ExclusiveMonitor& exclusive_monitor, CpuBarrier& cpu_ba Cpu::~Cpu() = default; -std::unique_ptr<ExclusiveMonitor> Cpu::MakeExclusiveMonitor(std::size_t num_cores) { +std::unique_ptr<ExclusiveMonitor> Cpu::MakeExclusiveMonitor( + [[maybe_unused]] Memory::Memory& memory, [[maybe_unused]] std::size_t num_cores) { #ifdef ARCHITECTURE_x86_64 - return std::make_unique<DynarmicExclusiveMonitor>(num_cores); + return std::make_unique<DynarmicExclusiveMonitor>(memory, num_cores); #else // TODO(merry): Passthrough exclusive monitor return nullptr; diff --git a/src/core/core_cpu.h b/src/core/core_cpu.h index cafca8df7..78f5021a2 100644 --- a/src/core/core_cpu.h +++ b/src/core/core_cpu.h @@ -24,6 +24,10 @@ namespace Core::Timing { class CoreTiming; } +namespace Memory { +class Memory; +} + namespace Core { class ARM_Interface; @@ -86,7 +90,19 @@ public: void Shutdown(); - static std::unique_ptr<ExclusiveMonitor> MakeExclusiveMonitor(std::size_t num_cores); + /** + * Creates an exclusive monitor to handle exclusive reads/writes. + * + * @param memory The current memory subsystem that the monitor may wish + * to keep track of. + * + * @param num_cores The number of cores to assume about the CPU. + * + * @returns The constructed exclusive monitor instance, or nullptr if the current + * CPU backend is unable to use an exclusive monitor. + */ + static std::unique_ptr<ExclusiveMonitor> MakeExclusiveMonitor(Memory::Memory& memory, + std::size_t num_cores); private: void Reschedule(); diff --git a/src/core/core_timing.cpp b/src/core/core_timing.cpp index 0e9570685..aa09fa453 100644 --- a/src/core/core_timing.cpp +++ b/src/core/core_timing.cpp @@ -17,11 +17,15 @@ namespace Core::Timing { constexpr int MAX_SLICE_LENGTH = 10000; +std::shared_ptr<EventType> CreateEvent(std::string name, TimedCallback&& callback) { + return std::make_shared<EventType>(std::move(callback), std::move(name)); +} + struct CoreTiming::Event { s64 time; u64 fifo_order; u64 userdata; - const EventType* type; + std::weak_ptr<EventType> type; // Sort by time, unless the times are the same, in which case sort by // the order added to the queue @@ -54,36 +58,15 @@ void CoreTiming::Initialize() { event_fifo_id = 0; const auto empty_timed_callback = [](u64, s64) {}; - ev_lost = RegisterEvent("_lost_event", empty_timed_callback); + ev_lost = CreateEvent("_lost_event", empty_timed_callback); } void CoreTiming::Shutdown() { ClearPendingEvents(); - UnregisterAllEvents(); -} - -EventType* CoreTiming::RegisterEvent(const std::string& name, TimedCallback callback) { - std::lock_guard guard{inner_mutex}; - // check for existing type with same name. - // we want event type names to remain unique so that we can use them for serialization. - ASSERT_MSG(event_types.find(name) == event_types.end(), - "CoreTiming Event \"{}\" is already registered. Events should only be registered " - "during Init to avoid breaking save states.", - name.c_str()); - - auto info = event_types.emplace(name, EventType{callback, nullptr}); - EventType* event_type = &info.first->second; - event_type->name = &info.first->first; - return event_type; -} - -void CoreTiming::UnregisterAllEvents() { - ASSERT_MSG(event_queue.empty(), "Cannot unregister events with events pending"); - event_types.clear(); } -void CoreTiming::ScheduleEvent(s64 cycles_into_future, const EventType* event_type, u64 userdata) { - ASSERT(event_type != nullptr); +void CoreTiming::ScheduleEvent(s64 cycles_into_future, const std::shared_ptr<EventType>& event_type, + u64 userdata) { std::lock_guard guard{inner_mutex}; const s64 timeout = GetTicks() + cycles_into_future; @@ -93,13 +76,15 @@ void CoreTiming::ScheduleEvent(s64 cycles_into_future, const EventType* event_ty } event_queue.emplace_back(Event{timeout, event_fifo_id++, userdata, event_type}); + std::push_heap(event_queue.begin(), event_queue.end(), std::greater<>()); } -void CoreTiming::UnscheduleEvent(const EventType* event_type, u64 userdata) { +void CoreTiming::UnscheduleEvent(const std::shared_ptr<EventType>& event_type, u64 userdata) { std::lock_guard guard{inner_mutex}; + const auto itr = std::remove_if(event_queue.begin(), event_queue.end(), [&](const Event& e) { - return e.type == event_type && e.userdata == userdata; + return e.type.lock().get() == event_type.get() && e.userdata == userdata; }); // Removing random items breaks the invariant so we have to re-establish it. @@ -130,10 +115,12 @@ void CoreTiming::ClearPendingEvents() { event_queue.clear(); } -void CoreTiming::RemoveEvent(const EventType* event_type) { +void CoreTiming::RemoveEvent(const std::shared_ptr<EventType>& event_type) { std::lock_guard guard{inner_mutex}; - const auto itr = std::remove_if(event_queue.begin(), event_queue.end(), - [&](const Event& e) { return e.type == event_type; }); + + const auto itr = std::remove_if(event_queue.begin(), event_queue.end(), [&](const Event& e) { + return e.type.lock().get() == event_type.get(); + }); // Removing random items breaks the invariant so we have to re-establish it. if (itr != event_queue.end()) { @@ -181,7 +168,11 @@ void CoreTiming::Advance() { std::pop_heap(event_queue.begin(), event_queue.end(), std::greater<>()); event_queue.pop_back(); inner_mutex.unlock(); - evt.type->callback(evt.userdata, global_timer - evt.time); + + if (auto event_type{evt.type.lock()}) { + event_type->callback(evt.userdata, global_timer - evt.time); + } + inner_mutex.lock(); } diff --git a/src/core/core_timing.h b/src/core/core_timing.h index 3bb88c810..d50f4eb8a 100644 --- a/src/core/core_timing.h +++ b/src/core/core_timing.h @@ -6,11 +6,12 @@ #include <chrono> #include <functional> +#include <memory> #include <mutex> #include <optional> #include <string> -#include <unordered_map> #include <vector> + #include "common/common_types.h" #include "common/threadsafe_queue.h" @@ -21,10 +22,13 @@ using TimedCallback = std::function<void(u64 userdata, s64 cycles_late)>; /// Contains the characteristics of a particular event. struct EventType { + EventType(TimedCallback&& callback, std::string&& name) + : callback{std::move(callback)}, name{std::move(name)} {} + /// The event's callback function. TimedCallback callback; /// A pointer to the name of the event. - const std::string* name; + const std::string name; }; /** @@ -57,31 +61,17 @@ public: /// Tears down all timing related functionality. void Shutdown(); - /// Registers a core timing event with the given name and callback. - /// - /// @param name The name of the core timing event to register. - /// @param callback The callback to execute for the event. - /// - /// @returns An EventType instance representing the registered event. - /// - /// @pre The name of the event being registered must be unique among all - /// registered events. - /// - EventType* RegisterEvent(const std::string& name, TimedCallback callback); - - /// Unregisters all registered events thus far. Note: not thread unsafe - void UnregisterAllEvents(); - /// After the first Advance, the slice lengths and the downcount will be reduced whenever an /// event is scheduled earlier than the current values. /// /// Scheduling from a callback will not update the downcount until the Advance() completes. - void ScheduleEvent(s64 cycles_into_future, const EventType* event_type, u64 userdata = 0); + void ScheduleEvent(s64 cycles_into_future, const std::shared_ptr<EventType>& event_type, + u64 userdata = 0); - void UnscheduleEvent(const EventType* event_type, u64 userdata); + void UnscheduleEvent(const std::shared_ptr<EventType>& event_type, u64 userdata); /// We only permit one event of each type in the queue at a time. - void RemoveEvent(const EventType* event_type); + void RemoveEvent(const std::shared_ptr<EventType>& event_type); void ForceExceptionCheck(s64 cycles); @@ -148,13 +138,18 @@ private: std::vector<Event> event_queue; u64 event_fifo_id = 0; - // Stores each element separately as a linked list node so pointers to elements - // remain stable regardless of rehashes/resizing. - std::unordered_map<std::string, EventType> event_types; - - EventType* ev_lost = nullptr; + std::shared_ptr<EventType> ev_lost; std::mutex inner_mutex; }; +/// Creates a core timing event with the given name and callback. +/// +/// @param name The name of the core timing event to create. +/// @param callback The callback to execute for the event. +/// +/// @returns An EventType instance representing the created event. +/// +std::shared_ptr<EventType> CreateEvent(std::string name, TimedCallback&& callback); + } // namespace Core::Timing diff --git a/src/core/cpu_core_manager.cpp b/src/core/cpu_core_manager.cpp index 8efd410bb..f04a34133 100644 --- a/src/core/cpu_core_manager.cpp +++ b/src/core/cpu_core_manager.cpp @@ -25,7 +25,7 @@ CpuCoreManager::~CpuCoreManager() = default; void CpuCoreManager::Initialize() { barrier = std::make_unique<CpuBarrier>(); - exclusive_monitor = Cpu::MakeExclusiveMonitor(cores.size()); + exclusive_monitor = Cpu::MakeExclusiveMonitor(system.Memory(), cores.size()); for (std::size_t index = 0; index < cores.size(); ++index) { cores[index] = std::make_unique<Cpu>(system, *exclusive_monitor, *barrier, index); diff --git a/src/core/crypto/key_manager.cpp b/src/core/crypto/key_manager.cpp index 222fc95ba..87e6a1fd3 100644 --- a/src/core/crypto/key_manager.cpp +++ b/src/core/crypto/key_manager.cpp @@ -22,6 +22,7 @@ #include "common/file_util.h" #include "common/hex_util.h" #include "common/logging/log.h" +#include "common/string_util.h" #include "core/core.h" #include "core/crypto/aes_util.h" #include "core/crypto/key_manager.h" @@ -378,8 +379,9 @@ std::vector<Ticket> GetTicketblob(const FileUtil::IOFile& ticket_save) { template <size_t size> static std::array<u8, size> operator^(const std::array<u8, size>& lhs, const std::array<u8, size>& rhs) { - std::array<u8, size> out{}; - std::transform(lhs.begin(), lhs.end(), rhs.begin(), out.begin(), std::bit_xor<>()); + std::array<u8, size> out; + std::transform(lhs.begin(), lhs.end(), rhs.begin(), out.begin(), + [](u8 lhs, u8 rhs) { return u8(lhs ^ rhs); }); return out; } @@ -396,7 +398,7 @@ static std::array<u8, target_size> MGF1(const std::array<u8, in_size>& seed) { while (out.size() < target_size) { out.resize(out.size() + 0x20); seed_exp[in_size + 3] = static_cast<u8>(i); - mbedtls_sha256(seed_exp.data(), seed_exp.size(), out.data() + out.size() - 0x20, 0); + mbedtls_sha256_ret(seed_exp.data(), seed_exp.size(), out.data() + out.size() - 0x20, 0); ++i; } @@ -538,7 +540,7 @@ void KeyManager::LoadFromFile(const std::string& filename, bool is_title_keys) { Key128 key = Common::HexStringToArray<16>(out[1]); s128_keys[{S128KeyType::Titlekey, rights_id[1], rights_id[0]}] = key; } else { - std::transform(out[0].begin(), out[0].end(), out[0].begin(), ::tolower); + out[0] = Common::ToLower(out[0]); if (s128_file_id.find(out[0]) != s128_file_id.end()) { const auto index = s128_file_id.at(out[0]); Key128 key = Common::HexStringToArray<16>(out[1]); @@ -668,23 +670,27 @@ void KeyManager::WriteKeyToFile(KeyCategory category, std::string_view keyname, const std::array<u8, Size>& key) { const std::string yuzu_keys_dir = FileUtil::GetUserPath(FileUtil::UserPath::KeysDir); std::string filename = "title.keys_autogenerated"; - if (category == KeyCategory::Standard) + if (category == KeyCategory::Standard) { filename = dev_mode ? "dev.keys_autogenerated" : "prod.keys_autogenerated"; - else if (category == KeyCategory::Console) + } else if (category == KeyCategory::Console) { filename = "console.keys_autogenerated"; - const auto add_info_text = !FileUtil::Exists(yuzu_keys_dir + DIR_SEP + filename); - FileUtil::CreateFullPath(yuzu_keys_dir + DIR_SEP + filename); - std::ofstream file(yuzu_keys_dir + DIR_SEP + filename, std::ios::app); - if (!file.is_open()) + } + + const auto path = yuzu_keys_dir + DIR_SEP + filename; + const auto add_info_text = !FileUtil::Exists(path); + FileUtil::CreateFullPath(path); + FileUtil::IOFile file{path, "a"}; + if (!file.IsOpen()) { return; + } if (add_info_text) { - file - << "# This file is autogenerated by Yuzu\n" - << "# It serves to store keys that were automatically generated from the normal keys\n" - << "# If you are experiencing issues involving keys, it may help to delete this file\n"; + file.WriteString( + "# This file is autogenerated by Yuzu\n" + "# It serves to store keys that were automatically generated from the normal keys\n" + "# If you are experiencing issues involving keys, it may help to delete this file\n"); } - file << fmt::format("\n{} = {}", keyname, Common::HexToString(key)); + file.WriteString(fmt::format("\n{} = {}", keyname, Common::HexToString(key))); AttemptLoadKeyFile(yuzu_keys_dir, yuzu_keys_dir, filename, category == KeyCategory::Title); } @@ -944,12 +950,10 @@ void KeyManager::DeriveETicket(PartitionDataManager& data) { return; } - Key128 rsa_oaep_kek{}; - std::transform(seed3.begin(), seed3.end(), mask0.begin(), rsa_oaep_kek.begin(), - std::bit_xor<>()); - - if (rsa_oaep_kek == Key128{}) + const Key128 rsa_oaep_kek = seed3 ^ mask0; + if (rsa_oaep_kek == Key128{}) { return; + } SetKey(S128KeyType::Source, rsa_oaep_kek, static_cast<u64>(SourceKeyType::RSAOaepKekGeneration)); diff --git a/src/core/crypto/partition_data_manager.cpp b/src/core/crypto/partition_data_manager.cpp index 594cd82c5..d64302f2e 100644 --- a/src/core/crypto/partition_data_manager.cpp +++ b/src/core/crypto/partition_data_manager.cpp @@ -161,7 +161,7 @@ std::array<u8, key_size> FindKeyFromHex(const std::vector<u8>& binary, std::array<u8, 0x20> temp{}; for (size_t i = 0; i < binary.size() - key_size; ++i) { - mbedtls_sha256(binary.data() + i, key_size, temp.data(), 0); + mbedtls_sha256_ret(binary.data() + i, key_size, temp.data(), 0); if (temp != hash) continue; @@ -189,7 +189,7 @@ static std::array<Key128, 0x20> FindEncryptedMasterKeyFromHex(const std::vector< AESCipher<Key128> cipher(key, Mode::ECB); for (size_t i = 0; i < binary.size() - 0x10; ++i) { cipher.Transcode(binary.data() + i, dec_temp.size(), dec_temp.data(), Op::Decrypt); - mbedtls_sha256(dec_temp.data(), dec_temp.size(), temp.data(), 0); + mbedtls_sha256_ret(dec_temp.data(), dec_temp.size(), temp.data(), 0); for (size_t k = 0; k < out.size(); ++k) { if (temp == master_key_hashes[k]) { @@ -204,11 +204,12 @@ static std::array<Key128, 0x20> FindEncryptedMasterKeyFromHex(const std::vector< FileSys::VirtualFile FindFileInDirWithNames(const FileSys::VirtualDir& dir, const std::string& name) { - auto upper = name; - std::transform(upper.begin(), upper.end(), upper.begin(), [](u8 c) { return std::toupper(c); }); + const auto upper = Common::ToUpper(name); + for (const auto& fname : {name, name + ".bin", upper, upper + ".BIN"}) { - if (dir->GetFile(fname) != nullptr) + if (dir->GetFile(fname) != nullptr) { return dir->GetFile(fname); + } } return nullptr; diff --git a/src/core/file_sys/directory.h b/src/core/file_sys/directory.h index 7b5c509fb..0d73eecc9 100644 --- a/src/core/file_sys/directory.h +++ b/src/core/file_sys/directory.h @@ -15,7 +15,7 @@ namespace FileSys { -enum EntryType : u8 { +enum class EntryType : u8 { Directory = 0, File = 1, }; diff --git a/src/core/file_sys/kernel_executable.cpp b/src/core/file_sys/kernel_executable.cpp index 371300684..76313679d 100644 --- a/src/core/file_sys/kernel_executable.cpp +++ b/src/core/file_sys/kernel_executable.cpp @@ -147,7 +147,7 @@ std::vector<u32> KIP::GetKernelCapabilities() const { } s32 KIP::GetMainThreadPriority() const { - return header.main_thread_priority; + return static_cast<s32>(header.main_thread_priority); } u32 KIP::GetMainThreadStackSize() const { diff --git a/src/core/file_sys/patch_manager.cpp b/src/core/file_sys/patch_manager.cpp index df0ecb15c..e226e9711 100644 --- a/src/core/file_sys/patch_manager.cpp +++ b/src/core/file_sys/patch_manager.cpp @@ -76,7 +76,7 @@ VirtualDir PatchManager::PatchExeFS(VirtualDir exefs) const { const auto& disabled = Settings::values.disabled_addons[title_id]; const auto update_disabled = - std::find(disabled.begin(), disabled.end(), "Update") != disabled.end(); + std::find(disabled.cbegin(), disabled.cend(), "Update") != disabled.cend(); // Game Updates const auto update_tid = GetUpdateTitleID(title_id); @@ -127,7 +127,7 @@ std::vector<VirtualFile> PatchManager::CollectPatches(const std::vector<VirtualD std::vector<VirtualFile> out; out.reserve(patch_dirs.size()); for (const auto& subdir : patch_dirs) { - if (std::find(disabled.begin(), disabled.end(), subdir->GetName()) != disabled.end()) + if (std::find(disabled.cbegin(), disabled.cend(), subdir->GetName()) != disabled.cend()) continue; auto exefs_dir = subdir->GetSubdirectory("exefs"); @@ -284,12 +284,17 @@ std::vector<Memory::CheatEntry> PatchManager::CreateCheatList( return {}; } + const auto& disabled = Settings::values.disabled_addons[title_id]; auto patch_dirs = load_dir->GetSubdirectories(); std::sort(patch_dirs.begin(), patch_dirs.end(), [](const VirtualDir& l, const VirtualDir& r) { return l->GetName() < r->GetName(); }); std::vector<Memory::CheatEntry> out; for (const auto& subdir : patch_dirs) { + if (std::find(disabled.cbegin(), disabled.cend(), subdir->GetName()) != disabled.cend()) { + continue; + } + auto cheats_dir = subdir->GetSubdirectory("cheats"); if (cheats_dir != nullptr) { auto res = ReadCheatFileFromFolder(system, title_id, build_id_, cheats_dir, true); @@ -331,8 +336,9 @@ static void ApplyLayeredFS(VirtualFile& romfs, u64 title_id, ContentRecordType t layers.reserve(patch_dirs.size() + 1); layers_ext.reserve(patch_dirs.size() + 1); for (const auto& subdir : patch_dirs) { - if (std::find(disabled.begin(), disabled.end(), subdir->GetName()) != disabled.end()) + if (std::find(disabled.cbegin(), disabled.cend(), subdir->GetName()) != disabled.cend()) { continue; + } auto romfs_dir = subdir->GetSubdirectory("romfs"); if (romfs_dir != nullptr) @@ -381,7 +387,7 @@ VirtualFile PatchManager::PatchRomFS(VirtualFile romfs, u64 ivfc_offset, Content const auto& disabled = Settings::values.disabled_addons[title_id]; const auto update_disabled = - std::find(disabled.begin(), disabled.end(), "Update") != disabled.end(); + std::find(disabled.cbegin(), disabled.cend(), "Update") != disabled.cend(); if (!update_disabled && update != nullptr) { const auto new_nca = std::make_shared<NCA>(update, romfs, ivfc_offset); @@ -431,7 +437,7 @@ std::map<std::string, std::string, std::less<>> PatchManager::GetPatchVersionNam auto [nacp, discard_icon_file] = update.GetControlMetadata(); const auto update_disabled = - std::find(disabled.begin(), disabled.end(), "Update") != disabled.end(); + std::find(disabled.cbegin(), disabled.cend(), "Update") != disabled.cend(); const auto update_label = update_disabled ? "[D] Update" : "Update"; if (nacp != nullptr) { diff --git a/src/core/file_sys/program_metadata.cpp b/src/core/file_sys/program_metadata.cpp index 7310b3602..1d6c30962 100644 --- a/src/core/file_sys/program_metadata.cpp +++ b/src/core/file_sys/program_metadata.cpp @@ -52,14 +52,14 @@ Loader::ResultStatus ProgramMetadata::Load(VirtualFile file) { } void ProgramMetadata::LoadManual(bool is_64_bit, ProgramAddressSpaceType address_space, - u8 main_thread_prio, u8 main_thread_core, + s32 main_thread_prio, u32 main_thread_core, u32 main_thread_stack_size, u64 title_id, u64 filesystem_permissions, KernelCapabilityDescriptors capabilities) { npdm_header.has_64_bit_instructions.Assign(is_64_bit); npdm_header.address_space_type.Assign(address_space); - npdm_header.main_thread_priority = main_thread_prio; - npdm_header.main_thread_cpu = main_thread_core; + npdm_header.main_thread_priority = static_cast<u8>(main_thread_prio); + npdm_header.main_thread_cpu = static_cast<u8>(main_thread_core); npdm_header.main_stack_size = main_thread_stack_size; aci_header.title_id = title_id; aci_file_access.permissions = filesystem_permissions; diff --git a/src/core/file_sys/program_metadata.h b/src/core/file_sys/program_metadata.h index 88ec97d85..f8759a396 100644 --- a/src/core/file_sys/program_metadata.h +++ b/src/core/file_sys/program_metadata.h @@ -47,8 +47,8 @@ public: Loader::ResultStatus Load(VirtualFile file); // Load from parameters instead of NPDM file, used for KIP - void LoadManual(bool is_64_bit, ProgramAddressSpaceType address_space, u8 main_thread_prio, - u8 main_thread_core, u32 main_thread_stack_size, u64 title_id, + void LoadManual(bool is_64_bit, ProgramAddressSpaceType address_space, s32 main_thread_prio, + u32 main_thread_core, u32 main_thread_stack_size, u64 title_id, u64 filesystem_permissions, KernelCapabilityDescriptors capabilities); bool Is64BitProgram() const; diff --git a/src/core/file_sys/registered_cache.cpp b/src/core/file_sys/registered_cache.cpp index ac3fbd849..6e9cf67ef 100644 --- a/src/core/file_sys/registered_cache.cpp +++ b/src/core/file_sys/registered_cache.cpp @@ -62,7 +62,7 @@ static std::string GetRelativePathFromNcaID(const std::array<u8, 16>& nca_id, bo Common::HexToString(nca_id, second_hex_upper)); Core::Crypto::SHA256Hash hash{}; - mbedtls_sha256(nca_id.data(), nca_id.size(), hash.data(), 0); + mbedtls_sha256_ret(nca_id.data(), nca_id.size(), hash.data(), 0); return fmt::format(cnmt_suffix ? "/000000{:02X}/{}.cnmt.nca" : "/000000{:02X}/{}.nca", hash[0], Common::HexToString(nca_id, second_hex_upper)); } @@ -141,7 +141,7 @@ bool PlaceholderCache::Create(const NcaID& id, u64 size) const { } Core::Crypto::SHA256Hash hash{}; - mbedtls_sha256(id.data(), id.size(), hash.data(), 0); + mbedtls_sha256_ret(id.data(), id.size(), hash.data(), 0); const auto dirname = fmt::format("000000{:02X}", hash[0]); const auto dir2 = GetOrCreateDirectoryRelative(dir, dirname); @@ -165,7 +165,7 @@ bool PlaceholderCache::Delete(const NcaID& id) const { } Core::Crypto::SHA256Hash hash{}; - mbedtls_sha256(id.data(), id.size(), hash.data(), 0); + mbedtls_sha256_ret(id.data(), id.size(), hash.data(), 0); const auto dirname = fmt::format("000000{:02X}", hash[0]); const auto dir2 = GetOrCreateDirectoryRelative(dir, dirname); @@ -603,7 +603,7 @@ InstallResult RegisteredCache::InstallEntry(const NCA& nca, TitleType type, OptionalHeader opt_header{0, 0}; ContentRecord c_rec{{}, {}, {}, GetCRTypeFromNCAType(nca.GetType()), {}}; const auto& data = nca.GetBaseFile()->ReadBytes(0x100000); - mbedtls_sha256(data.data(), data.size(), c_rec.hash.data(), 0); + mbedtls_sha256_ret(data.data(), data.size(), c_rec.hash.data(), 0); memcpy(&c_rec.nca_id, &c_rec.hash, 16); const CNMT new_cnmt(header, opt_header, {c_rec}, {}); if (!RawInstallYuzuMeta(new_cnmt)) @@ -626,7 +626,7 @@ InstallResult RegisteredCache::RawInstallNCA(const NCA& nca, const VfsCopyFuncti id = *override_id; } else { const auto& data = in->ReadBytes(0x100000); - mbedtls_sha256(data.data(), data.size(), hash.data(), 0); + mbedtls_sha256_ret(data.data(), data.size(), hash.data(), 0); memcpy(id.data(), hash.data(), 16); } diff --git a/src/core/file_sys/romfs.cpp b/src/core/file_sys/romfs.cpp index ebbdf081e..c909d1ce4 100644 --- a/src/core/file_sys/romfs.cpp +++ b/src/core/file_sys/romfs.cpp @@ -2,6 +2,8 @@ // Licensed under GPLv2 or any later version // Refer to the license.txt file included. +#include <memory> + #include "common/common_types.h" #include "common/swap.h" #include "core/file_sys/fsmitm_romfsbuild.h" @@ -12,7 +14,7 @@ #include "core/file_sys/vfs_vector.h" namespace FileSys { - +namespace { constexpr u32 ROMFS_ENTRY_EMPTY = 0xFFFFFFFF; struct TableLocation { @@ -51,7 +53,7 @@ struct FileEntry { static_assert(sizeof(FileEntry) == 0x20, "FileEntry has incorrect size."); template <typename Entry> -static std::pair<Entry, std::string> GetEntry(const VirtualFile& file, std::size_t offset) { +std::pair<Entry, std::string> GetEntry(const VirtualFile& file, std::size_t offset) { Entry entry{}; if (file->ReadObject(&entry, offset) != sizeof(Entry)) return {}; @@ -99,6 +101,7 @@ void ProcessDirectory(VirtualFile file, std::size_t dir_offset, std::size_t file this_dir_offset = entry.first.sibling; } } +} // Anonymous namespace VirtualDir ExtractRomFS(VirtualFile file, RomFSExtractionType type) { RomFSHeader header{}; diff --git a/src/core/file_sys/romfs.h b/src/core/file_sys/romfs.h index 1c89be8a4..2fd07ed04 100644 --- a/src/core/file_sys/romfs.h +++ b/src/core/file_sys/romfs.h @@ -5,10 +5,6 @@ #pragma once #include <array> -#include <map> -#include "common/common_funcs.h" -#include "common/common_types.h" -#include "common/swap.h" #include "core/file_sys/vfs.h" namespace FileSys { diff --git a/src/core/file_sys/romfs_factory.cpp b/src/core/file_sys/romfs_factory.cpp index 4bd2e6183..418a39a7e 100644 --- a/src/core/file_sys/romfs_factory.cpp +++ b/src/core/file_sys/romfs_factory.cpp @@ -71,12 +71,12 @@ ResultVal<VirtualFile> RomFSFactory::Open(u64 title_id, StorageId storage, if (res == nullptr) { // TODO(DarkLordZach): Find the right error code to use here - return ResultCode(-1); + return RESULT_UNKNOWN; } const auto romfs = res->GetRomFS(); if (romfs == nullptr) { // TODO(DarkLordZach): Find the right error code to use here - return ResultCode(-1); + return RESULT_UNKNOWN; } return MakeResult<VirtualFile>(romfs); } diff --git a/src/core/file_sys/savedata_factory.cpp b/src/core/file_sys/savedata_factory.cpp index e2a7eaf7b..f3def93ab 100644 --- a/src/core/file_sys/savedata_factory.cpp +++ b/src/core/file_sys/savedata_factory.cpp @@ -90,7 +90,7 @@ ResultVal<VirtualDir> SaveDataFactory::Create(SaveDataSpaceId space, // Return an error if the save data doesn't actually exist. if (out == nullptr) { // TODO(DarkLordZach): Find out correct error code. - return ResultCode(-1); + return RESULT_UNKNOWN; } return MakeResult<VirtualDir>(std::move(out)); @@ -111,7 +111,7 @@ ResultVal<VirtualDir> SaveDataFactory::Open(SaveDataSpaceId space, // Return an error if the save data doesn't actually exist. if (out == nullptr) { // TODO(Subv): Find out correct error code. - return ResultCode(-1); + return RESULT_UNKNOWN; } return MakeResult<VirtualDir>(std::move(out)); diff --git a/src/core/file_sys/vfs_libzip.cpp b/src/core/file_sys/vfs_libzip.cpp index 8bdaa7e4a..11d1978ea 100644 --- a/src/core/file_sys/vfs_libzip.cpp +++ b/src/core/file_sys/vfs_libzip.cpp @@ -27,7 +27,7 @@ VirtualDir ExtractZIP(VirtualFile file) { std::shared_ptr<VectorVfsDirectory> out = std::make_shared<VectorVfsDirectory>(); - const auto num_entries = zip_get_num_entries(zip.get(), 0); + const auto num_entries = static_cast<std::size_t>(zip_get_num_entries(zip.get(), 0)); zip_stat_t stat{}; zip_stat_init(&stat); diff --git a/src/core/file_sys/xts_archive.cpp b/src/core/file_sys/xts_archive.cpp index f5f8b91c9..86e06ccb9 100644 --- a/src/core/file_sys/xts_archive.cpp +++ b/src/core/file_sys/xts_archive.cpp @@ -7,12 +7,13 @@ #include <cstring> #include <regex> #include <string> + #include <mbedtls/md.h> #include <mbedtls/sha256.h> -#include "common/assert.h" + #include "common/file_util.h" #include "common/hex_util.h" -#include "common/logging/log.h" +#include "common/string_util.h" #include "core/crypto/aes_util.h" #include "core/crypto/xts_encryption_layer.h" #include "core/file_sys/partition_filesystem.h" @@ -53,18 +54,15 @@ NAX::NAX(VirtualFile file_) : header(std::make_unique<NAXHeader>()), file(std::m return; } - std::string two_dir = match[1]; - std::string nca_id = match[2]; - std::transform(two_dir.begin(), two_dir.end(), two_dir.begin(), ::toupper); - std::transform(nca_id.begin(), nca_id.end(), nca_id.begin(), ::tolower); - + const std::string two_dir = Common::ToUpper(match[1]); + const std::string nca_id = Common::ToLower(match[2]); status = Parse(fmt::format("/registered/{}/{}.nca", two_dir, nca_id)); } NAX::NAX(VirtualFile file_, std::array<u8, 0x10> nca_id) : header(std::make_unique<NAXHeader>()), file(std::move(file_)) { Core::Crypto::SHA256Hash hash{}; - mbedtls_sha256(nca_id.data(), nca_id.size(), hash.data(), 0); + mbedtls_sha256_ret(nca_id.data(), nca_id.size(), hash.data(), 0); status = Parse(fmt::format("/registered/000000{:02X}/{}.nca", hash[0], Common::HexToString(nca_id, false))); } diff --git a/src/core/gdbstub/gdbstub.cpp b/src/core/gdbstub/gdbstub.cpp index 20bb50868..37cb28848 100644 --- a/src/core/gdbstub/gdbstub.cpp +++ b/src/core/gdbstub/gdbstub.cpp @@ -468,7 +468,8 @@ static u8 ReadByte() { /// Calculate the checksum of the current command buffer. static u8 CalculateChecksum(const u8* buffer, std::size_t length) { - return static_cast<u8>(std::accumulate(buffer, buffer + length, 0, std::plus<u8>())); + return static_cast<u8>(std::accumulate(buffer, buffer + length, u8{0}, + [](u8 lhs, u8 rhs) { return u8(lhs + rhs); })); } /** @@ -507,8 +508,9 @@ static void RemoveBreakpoint(BreakpointType type, VAddr addr) { bp->second.len, bp->second.addr, static_cast<int>(type)); if (type == BreakpointType::Execute) { - Memory::WriteBlock(bp->second.addr, bp->second.inst.data(), bp->second.inst.size()); - Core::System::GetInstance().InvalidateCpuInstructionCaches(); + auto& system = Core::System::GetInstance(); + system.Memory().WriteBlock(bp->second.addr, bp->second.inst.data(), bp->second.inst.size()); + system.InvalidateCpuInstructionCaches(); } p.erase(addr); } @@ -968,12 +970,13 @@ static void ReadMemory() { SendReply("E01"); } - if (!Memory::IsValidVirtualAddress(addr)) { + auto& memory = Core::System::GetInstance().Memory(); + if (!memory.IsValidVirtualAddress(addr)) { return SendReply("E00"); } std::vector<u8> data(len); - Memory::ReadBlock(addr, data.data(), len); + memory.ReadBlock(addr, data.data(), len); MemToGdbHex(reply, data.data(), len); reply[len * 2] = '\0'; @@ -983,22 +986,23 @@ static void ReadMemory() { /// Modify location in memory with data received from the gdb client. static void WriteMemory() { auto start_offset = command_buffer + 1; - auto addr_pos = std::find(start_offset, command_buffer + command_length, ','); - VAddr addr = HexToLong(start_offset, static_cast<u64>(addr_pos - start_offset)); + const auto addr_pos = std::find(start_offset, command_buffer + command_length, ','); + const VAddr addr = HexToLong(start_offset, static_cast<u64>(addr_pos - start_offset)); start_offset = addr_pos + 1; - auto len_pos = std::find(start_offset, command_buffer + command_length, ':'); - u64 len = HexToLong(start_offset, static_cast<u64>(len_pos - start_offset)); + const auto len_pos = std::find(start_offset, command_buffer + command_length, ':'); + const u64 len = HexToLong(start_offset, static_cast<u64>(len_pos - start_offset)); - if (!Memory::IsValidVirtualAddress(addr)) { + auto& system = Core::System::GetInstance(); + auto& memory = system.Memory(); + if (!memory.IsValidVirtualAddress(addr)) { return SendReply("E00"); } std::vector<u8> data(len); - GdbHexToMem(data.data(), len_pos + 1, len); - Memory::WriteBlock(addr, data.data(), len); - Core::System::GetInstance().InvalidateCpuInstructionCaches(); + memory.WriteBlock(addr, data.data(), len); + system.InvalidateCpuInstructionCaches(); SendReply("OK"); } @@ -1054,12 +1058,15 @@ static bool CommitBreakpoint(BreakpointType type, VAddr addr, u64 len) { breakpoint.active = true; breakpoint.addr = addr; breakpoint.len = len; - Memory::ReadBlock(addr, breakpoint.inst.data(), breakpoint.inst.size()); + + auto& system = Core::System::GetInstance(); + auto& memory = system.Memory(); + memory.ReadBlock(addr, breakpoint.inst.data(), breakpoint.inst.size()); static constexpr std::array<u8, 4> btrap{0x00, 0x7d, 0x20, 0xd4}; if (type == BreakpointType::Execute) { - Memory::WriteBlock(addr, btrap.data(), btrap.size()); - Core::System::GetInstance().InvalidateCpuInstructionCaches(); + memory.WriteBlock(addr, btrap.data(), btrap.size()); + system.InvalidateCpuInstructionCaches(); } p.insert({addr, breakpoint}); diff --git a/src/core/hardware_interrupt_manager.cpp b/src/core/hardware_interrupt_manager.cpp index c2115db2d..c629d9fa1 100644 --- a/src/core/hardware_interrupt_manager.cpp +++ b/src/core/hardware_interrupt_manager.cpp @@ -11,13 +11,12 @@ namespace Core::Hardware { InterruptManager::InterruptManager(Core::System& system_in) : system(system_in) { - gpu_interrupt_event = - system.CoreTiming().RegisterEvent("GPUInterrupt", [this](u64 message, s64) { - auto nvdrv = system.ServiceManager().GetService<Service::Nvidia::NVDRV>("nvdrv"); - const u32 syncpt = static_cast<u32>(message >> 32); - const u32 value = static_cast<u32>(message); - nvdrv->SignalGPUInterruptSyncpt(syncpt, value); - }); + gpu_interrupt_event = Core::Timing::CreateEvent("GPUInterrupt", [this](u64 message, s64) { + auto nvdrv = system.ServiceManager().GetService<Service::Nvidia::NVDRV>("nvdrv"); + const u32 syncpt = static_cast<u32>(message >> 32); + const u32 value = static_cast<u32>(message); + nvdrv->SignalGPUInterruptSyncpt(syncpt, value); + }); } InterruptManager::~InterruptManager() = default; diff --git a/src/core/hardware_interrupt_manager.h b/src/core/hardware_interrupt_manager.h index 494db883a..5fa306ae0 100644 --- a/src/core/hardware_interrupt_manager.h +++ b/src/core/hardware_interrupt_manager.h @@ -4,6 +4,8 @@ #pragma once +#include <memory> + #include "common/common_types.h" namespace Core { @@ -25,7 +27,7 @@ public: private: Core::System& system; - Core::Timing::EventType* gpu_interrupt_event{}; + std::shared_ptr<Core::Timing::EventType> gpu_interrupt_event; }; } // namespace Core::Hardware diff --git a/src/core/hle/ipc_helpers.h b/src/core/hle/ipc_helpers.h index 5bb139483..0dc6a4a43 100644 --- a/src/core/hle/ipc_helpers.h +++ b/src/core/hle/ipc_helpers.h @@ -19,6 +19,7 @@ #include "core/hle/kernel/hle_ipc.h" #include "core/hle/kernel/object.h" #include "core/hle/kernel/server_session.h" +#include "core/hle/kernel/session.h" #include "core/hle/result.h" namespace IPC { @@ -139,10 +140,9 @@ public: context->AddDomainObject(std::move(iface)); } else { auto& kernel = Core::System::GetInstance().Kernel(); - auto [server, client] = - Kernel::ServerSession::CreateSessionPair(kernel, iface->GetServiceName()); - iface->ClientConnected(server); + auto [client, server] = Kernel::Session::Create(kernel, iface->GetServiceName()); context->AddMoveObject(std::move(client)); + iface->ClientConnected(std::move(server)); } } @@ -203,10 +203,10 @@ public: void PushRaw(const T& value); template <typename... O> - void PushMoveObjects(Kernel::SharedPtr<O>... pointers); + void PushMoveObjects(std::shared_ptr<O>... pointers); template <typename... O> - void PushCopyObjects(Kernel::SharedPtr<O>... pointers); + void PushCopyObjects(std::shared_ptr<O>... pointers); private: u32 normal_params_size{}; @@ -298,7 +298,7 @@ void ResponseBuilder::Push(const First& first_value, const Other&... other_value } template <typename... O> -inline void ResponseBuilder::PushCopyObjects(Kernel::SharedPtr<O>... pointers) { +inline void ResponseBuilder::PushCopyObjects(std::shared_ptr<O>... pointers) { auto objects = {pointers...}; for (auto& object : objects) { context->AddCopyObject(std::move(object)); @@ -306,7 +306,7 @@ inline void ResponseBuilder::PushCopyObjects(Kernel::SharedPtr<O>... pointers) { } template <typename... O> -inline void ResponseBuilder::PushMoveObjects(Kernel::SharedPtr<O>... pointers) { +inline void ResponseBuilder::PushMoveObjects(std::shared_ptr<O>... pointers) { auto objects = {pointers...}; for (auto& object : objects) { context->AddMoveObject(std::move(object)); @@ -357,10 +357,10 @@ public: T PopRaw(); template <typename T> - Kernel::SharedPtr<T> GetMoveObject(std::size_t index); + std::shared_ptr<T> GetMoveObject(std::size_t index); template <typename T> - Kernel::SharedPtr<T> GetCopyObject(std::size_t index); + std::shared_ptr<T> GetCopyObject(std::size_t index); template <class T> std::shared_ptr<T> PopIpcInterface() { @@ -465,12 +465,12 @@ void RequestParser::Pop(First& first_value, Other&... other_values) { } template <typename T> -Kernel::SharedPtr<T> RequestParser::GetMoveObject(std::size_t index) { +std::shared_ptr<T> RequestParser::GetMoveObject(std::size_t index) { return context->GetMoveObject<T>(index); } template <typename T> -Kernel::SharedPtr<T> RequestParser::GetCopyObject(std::size_t index) { +std::shared_ptr<T> RequestParser::GetCopyObject(std::size_t index) { return context->GetCopyObject<T>(index); } diff --git a/src/core/hle/kernel/address_arbiter.cpp b/src/core/hle/kernel/address_arbiter.cpp index de0a9064e..98d07fa5b 100644 --- a/src/core/hle/kernel/address_arbiter.cpp +++ b/src/core/hle/kernel/address_arbiter.cpp @@ -21,7 +21,7 @@ namespace Kernel { namespace { // Wake up num_to_wake (or all) threads in a vector. -void WakeThreads(const std::vector<SharedPtr<Thread>>& waiting_threads, s32 num_to_wake) { +void WakeThreads(const std::vector<std::shared_ptr<Thread>>& waiting_threads, s32 num_to_wake) { auto& system = Core::System::GetInstance(); // Only process up to 'target' threads, unless 'target' is <= 0, in which case process // them all. @@ -59,35 +59,41 @@ ResultCode AddressArbiter::SignalToAddress(VAddr address, SignalType type, s32 v } ResultCode AddressArbiter::SignalToAddressOnly(VAddr address, s32 num_to_wake) { - const std::vector<SharedPtr<Thread>> waiting_threads = GetThreadsWaitingOnAddress(address); + const std::vector<std::shared_ptr<Thread>> waiting_threads = + GetThreadsWaitingOnAddress(address); WakeThreads(waiting_threads, num_to_wake); return RESULT_SUCCESS; } ResultCode AddressArbiter::IncrementAndSignalToAddressIfEqual(VAddr address, s32 value, s32 num_to_wake) { + auto& memory = system.Memory(); + // Ensure that we can write to the address. - if (!Memory::IsValidVirtualAddress(address)) { + if (!memory.IsValidVirtualAddress(address)) { return ERR_INVALID_ADDRESS_STATE; } - if (static_cast<s32>(Memory::Read32(address)) != value) { + if (static_cast<s32>(memory.Read32(address)) != value) { return ERR_INVALID_STATE; } - Memory::Write32(address, static_cast<u32>(value + 1)); + memory.Write32(address, static_cast<u32>(value + 1)); return SignalToAddressOnly(address, num_to_wake); } ResultCode AddressArbiter::ModifyByWaitingCountAndSignalToAddressIfEqual(VAddr address, s32 value, s32 num_to_wake) { + auto& memory = system.Memory(); + // Ensure that we can write to the address. - if (!Memory::IsValidVirtualAddress(address)) { + if (!memory.IsValidVirtualAddress(address)) { return ERR_INVALID_ADDRESS_STATE; } // Get threads waiting on the address. - const std::vector<SharedPtr<Thread>> waiting_threads = GetThreadsWaitingOnAddress(address); + const std::vector<std::shared_ptr<Thread>> waiting_threads = + GetThreadsWaitingOnAddress(address); // Determine the modified value depending on the waiting count. s32 updated_value; @@ -107,11 +113,11 @@ ResultCode AddressArbiter::ModifyByWaitingCountAndSignalToAddressIfEqual(VAddr a } } - if (static_cast<s32>(Memory::Read32(address)) != value) { + if (static_cast<s32>(memory.Read32(address)) != value) { return ERR_INVALID_STATE; } - Memory::Write32(address, static_cast<u32>(updated_value)); + memory.Write32(address, static_cast<u32>(updated_value)); WakeThreads(waiting_threads, num_to_wake); return RESULT_SUCCESS; } @@ -132,18 +138,20 @@ ResultCode AddressArbiter::WaitForAddress(VAddr address, ArbitrationType type, s ResultCode AddressArbiter::WaitForAddressIfLessThan(VAddr address, s32 value, s64 timeout, bool should_decrement) { + auto& memory = system.Memory(); + // Ensure that we can read the address. - if (!Memory::IsValidVirtualAddress(address)) { + if (!memory.IsValidVirtualAddress(address)) { return ERR_INVALID_ADDRESS_STATE; } - const s32 cur_value = static_cast<s32>(Memory::Read32(address)); + const s32 cur_value = static_cast<s32>(memory.Read32(address)); if (cur_value >= value) { return ERR_INVALID_STATE; } if (should_decrement) { - Memory::Write32(address, static_cast<u32>(cur_value - 1)); + memory.Write32(address, static_cast<u32>(cur_value - 1)); } // Short-circuit without rescheduling, if timeout is zero. @@ -155,15 +163,19 @@ ResultCode AddressArbiter::WaitForAddressIfLessThan(VAddr address, s32 value, s6 } ResultCode AddressArbiter::WaitForAddressIfEqual(VAddr address, s32 value, s64 timeout) { + auto& memory = system.Memory(); + // Ensure that we can read the address. - if (!Memory::IsValidVirtualAddress(address)) { + if (!memory.IsValidVirtualAddress(address)) { return ERR_INVALID_ADDRESS_STATE; } + // Only wait for the address if equal. - if (static_cast<s32>(Memory::Read32(address)) != value) { + if (static_cast<s32>(memory.Read32(address)) != value) { return ERR_INVALID_STATE; } - // Short-circuit without rescheduling, if timeout is zero. + + // Short-circuit without rescheduling if timeout is zero. if (timeout == 0) { return RESULT_TIMEOUT; } @@ -172,21 +184,21 @@ ResultCode AddressArbiter::WaitForAddressIfEqual(VAddr address, s32 value, s64 t } ResultCode AddressArbiter::WaitForAddressImpl(VAddr address, s64 timeout) { - SharedPtr<Thread> current_thread = system.CurrentScheduler().GetCurrentThread(); + Thread* current_thread = system.CurrentScheduler().GetCurrentThread(); current_thread->SetArbiterWaitAddress(address); current_thread->SetStatus(ThreadStatus::WaitArb); current_thread->InvalidateWakeupCallback(); - current_thread->WakeAfterDelay(timeout); system.PrepareReschedule(current_thread->GetProcessorID()); return RESULT_TIMEOUT; } -std::vector<SharedPtr<Thread>> AddressArbiter::GetThreadsWaitingOnAddress(VAddr address) const { +std::vector<std::shared_ptr<Thread>> AddressArbiter::GetThreadsWaitingOnAddress( + VAddr address) const { // Retrieve all threads that are waiting for this address. - std::vector<SharedPtr<Thread>> threads; + std::vector<std::shared_ptr<Thread>> threads; const auto& scheduler = system.GlobalScheduler(); const auto& thread_list = scheduler.GetThreadList(); @@ -198,7 +210,7 @@ std::vector<SharedPtr<Thread>> AddressArbiter::GetThreadsWaitingOnAddress(VAddr // Sort them by priority, such that the highest priority ones come first. std::sort(threads.begin(), threads.end(), - [](const SharedPtr<Thread>& lhs, const SharedPtr<Thread>& rhs) { + [](const std::shared_ptr<Thread>& lhs, const std::shared_ptr<Thread>& rhs) { return lhs->GetPriority() < rhs->GetPriority(); }); diff --git a/src/core/hle/kernel/address_arbiter.h b/src/core/hle/kernel/address_arbiter.h index ed0d0e69f..608918de5 100644 --- a/src/core/hle/kernel/address_arbiter.h +++ b/src/core/hle/kernel/address_arbiter.h @@ -72,7 +72,7 @@ private: ResultCode WaitForAddressImpl(VAddr address, s64 timeout); // Gets the threads waiting on an address. - std::vector<SharedPtr<Thread>> GetThreadsWaitingOnAddress(VAddr address) const; + std::vector<std::shared_ptr<Thread>> GetThreadsWaitingOnAddress(VAddr address) const; Core::System& system; }; diff --git a/src/core/hle/kernel/client_port.cpp b/src/core/hle/kernel/client_port.cpp index 744b1697d..00bb939a0 100644 --- a/src/core/hle/kernel/client_port.cpp +++ b/src/core/hle/kernel/client_port.cpp @@ -9,38 +9,35 @@ #include "core/hle/kernel/object.h" #include "core/hle/kernel/server_port.h" #include "core/hle/kernel/server_session.h" +#include "core/hle/kernel/session.h" namespace Kernel { ClientPort::ClientPort(KernelCore& kernel) : Object{kernel} {} ClientPort::~ClientPort() = default; -SharedPtr<ServerPort> ClientPort::GetServerPort() const { +std::shared_ptr<ServerPort> ClientPort::GetServerPort() const { return server_port; } -ResultVal<SharedPtr<ClientSession>> ClientPort::Connect() { - // Note: Threads do not wait for the server endpoint to call - // AcceptSession before returning from this call. - +ResultVal<std::shared_ptr<ClientSession>> ClientPort::Connect() { if (active_sessions >= max_sessions) { return ERR_MAX_CONNECTIONS_REACHED; } active_sessions++; - // Create a new session pair, let the created sessions inherit the parent port's HLE handler. - auto [server, client] = ServerSession::CreateSessionPair(kernel, server_port->GetName(), this); + auto [client, server] = Kernel::Session::Create(kernel, name); if (server_port->HasHLEHandler()) { - server_port->GetHLEHandler()->ClientConnected(server); + server_port->GetHLEHandler()->ClientConnected(std::move(server)); } else { - server_port->AppendPendingSession(server); + server_port->AppendPendingSession(std::move(server)); } // Wake the threads waiting on the ServerPort server_port->WakeupAllWaitingThreads(); - return MakeResult(client); + return MakeResult(std::move(client)); } void ClientPort::ConnectionClosed() { diff --git a/src/core/hle/kernel/client_port.h b/src/core/hle/kernel/client_port.h index 4921ad4f0..715edd18c 100644 --- a/src/core/hle/kernel/client_port.h +++ b/src/core/hle/kernel/client_port.h @@ -17,6 +17,9 @@ class ServerPort; class ClientPort final : public Object { public: + explicit ClientPort(KernelCore& kernel); + ~ClientPort() override; + friend class ServerPort; std::string GetTypeName() const override { return "ClientPort"; @@ -30,7 +33,7 @@ public: return HANDLE_TYPE; } - SharedPtr<ServerPort> GetServerPort() const; + std::shared_ptr<ServerPort> GetServerPort() const; /** * Creates a new Session pair, adds the created ServerSession to the associated ServerPort's @@ -38,7 +41,7 @@ public: * waiting on it to awake. * @returns ClientSession The client endpoint of the created Session pair, or error code. */ - ResultVal<SharedPtr<ClientSession>> Connect(); + ResultVal<std::shared_ptr<ClientSession>> Connect(); /** * Signifies that a previously active connection has been closed, @@ -47,10 +50,7 @@ public: void ConnectionClosed(); private: - explicit ClientPort(KernelCore& kernel); - ~ClientPort() override; - - SharedPtr<ServerPort> server_port; ///< ServerPort associated with this client port. + std::shared_ptr<ServerPort> server_port; ///< ServerPort associated with this client port. u32 max_sessions = 0; ///< Maximum number of simultaneous sessions the port can have u32 active_sessions = 0; ///< Number of currently open sessions to this port std::string name; ///< Name of client port (optional) diff --git a/src/core/hle/kernel/client_session.cpp b/src/core/hle/kernel/client_session.cpp index c17baa50a..4669a14ad 100644 --- a/src/core/hle/kernel/client_session.cpp +++ b/src/core/hle/kernel/client_session.cpp @@ -1,4 +1,4 @@ -// Copyright 2016 Citra Emulator Project +// Copyright 2019 yuzu emulator team // Licensed under GPLv2 or any later version // Refer to the license.txt file included. @@ -12,29 +12,44 @@ namespace Kernel { -ClientSession::ClientSession(KernelCore& kernel) : Object{kernel} {} +ClientSession::ClientSession(KernelCore& kernel) : WaitObject{kernel} {} + ClientSession::~ClientSession() { // This destructor will be called automatically when the last ClientSession handle is closed by // the emulated application. - - // A local reference to the ServerSession is necessary to guarantee it - // will be kept alive until after ClientDisconnected() returns. - SharedPtr<ServerSession> server = parent->server; - if (server) { - server->ClientDisconnected(); + if (parent->Server()) { + parent->Server()->ClientDisconnected(); } +} + +bool ClientSession::ShouldWait(const Thread* thread) const { + UNIMPLEMENTED(); + return {}; +} - parent->client = nullptr; +void ClientSession::Acquire(Thread* thread) { + UNIMPLEMENTED(); } -ResultCode ClientSession::SendSyncRequest(SharedPtr<Thread> thread) { +ResultVal<std::shared_ptr<ClientSession>> ClientSession::Create(KernelCore& kernel, + std::shared_ptr<Session> parent, + std::string name) { + std::shared_ptr<ClientSession> client_session{std::make_shared<ClientSession>(kernel)}; + + client_session->name = std::move(name); + client_session->parent = std::move(parent); + + return MakeResult(std::move(client_session)); +} + +ResultCode ClientSession::SendSyncRequest(std::shared_ptr<Thread> thread, Memory::Memory& memory) { // Keep ServerSession alive until we're done working with it. - SharedPtr<ServerSession> server = parent->server; - if (server == nullptr) + if (!parent->Server()) { return ERR_SESSION_CLOSED_BY_REMOTE; + } // Signal the server session that new data is available - return server->HandleSyncRequest(std::move(thread)); + return parent->Server()->HandleSyncRequest(std::move(thread), memory); } } // namespace Kernel diff --git a/src/core/hle/kernel/client_session.h b/src/core/hle/kernel/client_session.h index 09cdff588..b4289a9a8 100644 --- a/src/core/hle/kernel/client_session.h +++ b/src/core/hle/kernel/client_session.h @@ -1,4 +1,4 @@ -// Copyright 2016 Citra Emulator Project +// Copyright 2019 yuzu emulator team // Licensed under GPLv2 or any later version // Refer to the license.txt file included. @@ -6,20 +6,28 @@ #include <memory> #include <string> -#include "core/hle/kernel/object.h" + +#include "core/hle/kernel/wait_object.h" +#include "core/hle/result.h" union ResultCode; +namespace Memory { +class Memory; +} + namespace Kernel { class KernelCore; class Session; -class ServerSession; class Thread; -class ClientSession final : public Object { +class ClientSession final : public WaitObject { public: - friend class ServerSession; + explicit ClientSession(KernelCore& kernel); + ~ClientSession() override; + + friend class Session; std::string GetTypeName() const override { return "ClientSession"; @@ -34,11 +42,16 @@ public: return HANDLE_TYPE; } - ResultCode SendSyncRequest(SharedPtr<Thread> thread); + ResultCode SendSyncRequest(std::shared_ptr<Thread> thread, Memory::Memory& memory); + + bool ShouldWait(const Thread* thread) const override; + + void Acquire(Thread* thread) override; private: - explicit ClientSession(KernelCore& kernel); - ~ClientSession() override; + static ResultVal<std::shared_ptr<ClientSession>> Create(KernelCore& kernel, + std::shared_ptr<Session> parent, + std::string name = "Unknown"); /// The parent session, which links to the server endpoint. std::shared_ptr<Session> parent; diff --git a/src/core/hle/kernel/handle_table.cpp b/src/core/hle/kernel/handle_table.cpp index 2cc5d536b..e441a27fc 100644 --- a/src/core/hle/kernel/handle_table.cpp +++ b/src/core/hle/kernel/handle_table.cpp @@ -44,7 +44,7 @@ ResultCode HandleTable::SetSize(s32 handle_table_size) { return RESULT_SUCCESS; } -ResultVal<Handle> HandleTable::Create(SharedPtr<Object> obj) { +ResultVal<Handle> HandleTable::Create(std::shared_ptr<Object> obj) { DEBUG_ASSERT(obj != nullptr); const u16 slot = next_free_slot; @@ -70,7 +70,7 @@ ResultVal<Handle> HandleTable::Create(SharedPtr<Object> obj) { } ResultVal<Handle> HandleTable::Duplicate(Handle handle) { - SharedPtr<Object> object = GetGeneric(handle); + std::shared_ptr<Object> object = GetGeneric(handle); if (object == nullptr) { LOG_ERROR(Kernel, "Tried to duplicate invalid handle: {:08X}", handle); return ERR_INVALID_HANDLE; @@ -99,11 +99,11 @@ bool HandleTable::IsValid(Handle handle) const { return slot < table_size && objects[slot] != nullptr && generations[slot] == generation; } -SharedPtr<Object> HandleTable::GetGeneric(Handle handle) const { +std::shared_ptr<Object> HandleTable::GetGeneric(Handle handle) const { if (handle == CurrentThread) { - return GetCurrentThread(); + return SharedFrom(GetCurrentThread()); } else if (handle == CurrentProcess) { - return Core::System::GetInstance().CurrentProcess(); + return SharedFrom(Core::System::GetInstance().CurrentProcess()); } if (!IsValid(handle)) { diff --git a/src/core/hle/kernel/handle_table.h b/src/core/hle/kernel/handle_table.h index 44901391b..9fcb4cc15 100644 --- a/src/core/hle/kernel/handle_table.h +++ b/src/core/hle/kernel/handle_table.h @@ -68,7 +68,7 @@ public: * @return The created Handle or one of the following errors: * - `ERR_HANDLE_TABLE_FULL`: the maximum number of handles has been exceeded. */ - ResultVal<Handle> Create(SharedPtr<Object> obj); + ResultVal<Handle> Create(std::shared_ptr<Object> obj); /** * Returns a new handle that points to the same object as the passed in handle. @@ -92,7 +92,7 @@ public: * Looks up a handle. * @return Pointer to the looked-up object, or `nullptr` if the handle is not valid. */ - SharedPtr<Object> GetGeneric(Handle handle) const; + std::shared_ptr<Object> GetGeneric(Handle handle) const; /** * Looks up a handle while verifying its type. @@ -100,7 +100,7 @@ public: * type differs from the requested one. */ template <class T> - SharedPtr<T> Get(Handle handle) const { + std::shared_ptr<T> Get(Handle handle) const { return DynamicObjectCast<T>(GetGeneric(handle)); } @@ -109,7 +109,7 @@ public: private: /// Stores the Object referenced by the handle or null if the slot is empty. - std::array<SharedPtr<Object>, MAX_COUNT> objects; + std::array<std::shared_ptr<Object>, MAX_COUNT> objects; /** * The value of `next_generation` when the handle was created, used to check for validity. For diff --git a/src/core/hle/kernel/hle_ipc.cpp b/src/core/hle/kernel/hle_ipc.cpp index a7b5849b0..2db28dcf0 100644 --- a/src/core/hle/kernel/hle_ipc.cpp +++ b/src/core/hle/kernel/hle_ipc.cpp @@ -32,23 +32,25 @@ SessionRequestHandler::SessionRequestHandler() = default; SessionRequestHandler::~SessionRequestHandler() = default; -void SessionRequestHandler::ClientConnected(SharedPtr<ServerSession> server_session) { +void SessionRequestHandler::ClientConnected(std::shared_ptr<ServerSession> server_session) { server_session->SetHleHandler(shared_from_this()); connected_sessions.push_back(std::move(server_session)); } -void SessionRequestHandler::ClientDisconnected(const SharedPtr<ServerSession>& server_session) { +void SessionRequestHandler::ClientDisconnected( + const std::shared_ptr<ServerSession>& server_session) { server_session->SetHleHandler(nullptr); boost::range::remove_erase(connected_sessions, server_session); } -SharedPtr<WritableEvent> HLERequestContext::SleepClientThread( +std::shared_ptr<WritableEvent> HLERequestContext::SleepClientThread( const std::string& reason, u64 timeout, WakeupCallback&& callback, - SharedPtr<WritableEvent> writable_event) { + std::shared_ptr<WritableEvent> writable_event) { // Put the client thread to sleep until the wait event is signaled or the timeout expires. - thread->SetWakeupCallback([context = *this, callback]( - ThreadWakeupReason reason, SharedPtr<Thread> thread, - SharedPtr<WaitObject> object, std::size_t index) mutable -> bool { + thread->SetWakeupCallback([context = *this, callback](ThreadWakeupReason reason, + std::shared_ptr<Thread> thread, + std::shared_ptr<WaitObject> object, + std::size_t index) mutable -> bool { ASSERT(thread->GetStatus() == ThreadStatus::WaitHLEEvent); callback(thread, context, reason); context.WriteToOutgoingCommandBuffer(*thread); @@ -72,11 +74,13 @@ SharedPtr<WritableEvent> HLERequestContext::SleepClientThread( thread->WakeAfterDelay(timeout); } + is_thread_waiting = true; + return writable_event; } -HLERequestContext::HLERequestContext(SharedPtr<Kernel::ServerSession> server_session, - SharedPtr<Thread> thread) +HLERequestContext::HLERequestContext(std::shared_ptr<Kernel::ServerSession> server_session, + std::shared_ptr<Thread> thread) : server_session(std::move(server_session)), thread(std::move(thread)) { cmd_buf[0] = 0; } @@ -212,10 +216,11 @@ ResultCode HLERequestContext::PopulateFromIncomingCommandBuffer(const HandleTabl ResultCode HLERequestContext::WriteToOutgoingCommandBuffer(Thread& thread) { auto& owner_process = *thread.GetOwnerProcess(); auto& handle_table = owner_process.GetHandleTable(); + auto& memory = Core::System::GetInstance().Memory(); std::array<u32, IPC::COMMAND_BUFFER_LENGTH> dst_cmdbuf; - Memory::ReadBlock(owner_process, thread.GetTLSAddress(), dst_cmdbuf.data(), - dst_cmdbuf.size() * sizeof(u32)); + memory.ReadBlock(owner_process, thread.GetTLSAddress(), dst_cmdbuf.data(), + dst_cmdbuf.size() * sizeof(u32)); // The header was already built in the internal command buffer. Attempt to parse it to verify // the integrity and then copy it over to the target command buffer. @@ -271,8 +276,8 @@ ResultCode HLERequestContext::WriteToOutgoingCommandBuffer(Thread& thread) { } // Copy the translated command buffer back into the thread's command buffer area. - Memory::WriteBlock(owner_process, thread.GetTLSAddress(), dst_cmdbuf.data(), - dst_cmdbuf.size() * sizeof(u32)); + memory.WriteBlock(owner_process, thread.GetTLSAddress(), dst_cmdbuf.data(), + dst_cmdbuf.size() * sizeof(u32)); return RESULT_SUCCESS; } @@ -280,15 +285,14 @@ ResultCode HLERequestContext::WriteToOutgoingCommandBuffer(Thread& thread) { std::vector<u8> HLERequestContext::ReadBuffer(int buffer_index) const { std::vector<u8> buffer; const bool is_buffer_a{BufferDescriptorA().size() && BufferDescriptorA()[buffer_index].Size()}; + auto& memory = Core::System::GetInstance().Memory(); if (is_buffer_a) { buffer.resize(BufferDescriptorA()[buffer_index].Size()); - Memory::ReadBlock(BufferDescriptorA()[buffer_index].Address(), buffer.data(), - buffer.size()); + memory.ReadBlock(BufferDescriptorA()[buffer_index].Address(), buffer.data(), buffer.size()); } else { buffer.resize(BufferDescriptorX()[buffer_index].Size()); - Memory::ReadBlock(BufferDescriptorX()[buffer_index].Address(), buffer.data(), - buffer.size()); + memory.ReadBlock(BufferDescriptorX()[buffer_index].Address(), buffer.data(), buffer.size()); } return buffer; @@ -309,10 +313,11 @@ std::size_t HLERequestContext::WriteBuffer(const void* buffer, std::size_t size, size = buffer_size; // TODO(bunnei): This needs to be HW tested } + auto& memory = Core::System::GetInstance().Memory(); if (is_buffer_b) { - Memory::WriteBlock(BufferDescriptorB()[buffer_index].Address(), buffer, size); + memory.WriteBlock(BufferDescriptorB()[buffer_index].Address(), buffer, size); } else { - Memory::WriteBlock(BufferDescriptorC()[buffer_index].Address(), buffer, size); + memory.WriteBlock(BufferDescriptorC()[buffer_index].Address(), buffer, size); } return size; diff --git a/src/core/hle/kernel/hle_ipc.h b/src/core/hle/kernel/hle_ipc.h index ccf5e56aa..050ad8fd7 100644 --- a/src/core/hle/kernel/hle_ipc.h +++ b/src/core/hle/kernel/hle_ipc.h @@ -5,6 +5,7 @@ #pragma once #include <array> +#include <functional> #include <memory> #include <optional> #include <string> @@ -60,20 +61,20 @@ public: * associated ServerSession alive for the duration of the connection. * @param server_session Owning pointer to the ServerSession associated with the connection. */ - void ClientConnected(SharedPtr<ServerSession> server_session); + void ClientConnected(std::shared_ptr<ServerSession> server_session); /** * Signals that a client has just disconnected from this HLE handler and releases the * associated ServerSession. * @param server_session ServerSession associated with the connection. */ - void ClientDisconnected(const SharedPtr<ServerSession>& server_session); + void ClientDisconnected(const std::shared_ptr<ServerSession>& server_session); protected: /// List of sessions that are connected to this handler. /// A ServerSession whose server endpoint is an HLE implementation is kept alive by this list /// for the duration of the connection. - std::vector<SharedPtr<ServerSession>> connected_sessions; + std::vector<std::shared_ptr<ServerSession>> connected_sessions; }; /** @@ -97,7 +98,8 @@ protected: */ class HLERequestContext { public: - explicit HLERequestContext(SharedPtr<ServerSession> session, SharedPtr<Thread> thread); + explicit HLERequestContext(std::shared_ptr<ServerSession> session, + std::shared_ptr<Thread> thread); ~HLERequestContext(); /// Returns a pointer to the IPC command buffer for this request. @@ -109,12 +111,12 @@ public: * Returns the session through which this request was made. This can be used as a map key to * access per-client data on services. */ - const SharedPtr<Kernel::ServerSession>& Session() const { + const std::shared_ptr<Kernel::ServerSession>& Session() const { return server_session; } - using WakeupCallback = std::function<void(SharedPtr<Thread> thread, HLERequestContext& context, - ThreadWakeupReason reason)>; + using WakeupCallback = std::function<void( + std::shared_ptr<Thread> thread, HLERequestContext& context, ThreadWakeupReason reason)>; /** * Puts the specified guest thread to sleep until the returned event is signaled or until the @@ -129,9 +131,9 @@ public: * created. * @returns Event that when signaled will resume the thread and call the callback function. */ - SharedPtr<WritableEvent> SleepClientThread(const std::string& reason, u64 timeout, - WakeupCallback&& callback, - SharedPtr<WritableEvent> writable_event = nullptr); + std::shared_ptr<WritableEvent> SleepClientThread( + const std::string& reason, u64 timeout, WakeupCallback&& callback, + std::shared_ptr<WritableEvent> writable_event = nullptr); /// Populates this context with data from the requesting process/thread. ResultCode PopulateFromIncomingCommandBuffer(const HandleTable& handle_table, @@ -209,20 +211,20 @@ public: std::size_t GetWriteBufferSize(int buffer_index = 0) const; template <typename T> - SharedPtr<T> GetCopyObject(std::size_t index) { + std::shared_ptr<T> GetCopyObject(std::size_t index) { return DynamicObjectCast<T>(copy_objects.at(index)); } template <typename T> - SharedPtr<T> GetMoveObject(std::size_t index) { + std::shared_ptr<T> GetMoveObject(std::size_t index) { return DynamicObjectCast<T>(move_objects.at(index)); } - void AddMoveObject(SharedPtr<Object> object) { + void AddMoveObject(std::shared_ptr<Object> object) { move_objects.emplace_back(std::move(object)); } - void AddCopyObject(SharedPtr<Object> object) { + void AddCopyObject(std::shared_ptr<Object> object) { copy_objects.emplace_back(std::move(object)); } @@ -262,15 +264,27 @@ public: std::string Description() const; + Thread& GetThread() { + return *thread; + } + + const Thread& GetThread() const { + return *thread; + } + + bool IsThreadWaiting() const { + return is_thread_waiting; + } + private: void ParseCommandBuffer(const HandleTable& handle_table, u32_le* src_cmdbuf, bool incoming); std::array<u32, IPC::COMMAND_BUFFER_LENGTH> cmd_buf; - SharedPtr<Kernel::ServerSession> server_session; - SharedPtr<Thread> thread; + std::shared_ptr<Kernel::ServerSession> server_session; + std::shared_ptr<Thread> thread; // TODO(yuriks): Check common usage of this and optimize size accordingly - boost::container::small_vector<SharedPtr<Object>, 8> move_objects; - boost::container::small_vector<SharedPtr<Object>, 8> copy_objects; + boost::container::small_vector<std::shared_ptr<Object>, 8> move_objects; + boost::container::small_vector<std::shared_ptr<Object>, 8> copy_objects; boost::container::small_vector<std::shared_ptr<SessionRequestHandler>, 8> domain_objects; std::optional<IPC::CommandHeader> command_header; @@ -288,6 +302,7 @@ private: u32_le command{}; std::vector<std::shared_ptr<SessionRequestHandler>> domain_request_handlers; + bool is_thread_waiting{}; }; } // namespace Kernel diff --git a/src/core/hle/kernel/kernel.cpp b/src/core/hle/kernel/kernel.cpp index f94ac150d..1c90546a4 100644 --- a/src/core/hle/kernel/kernel.cpp +++ b/src/core/hle/kernel/kernel.cpp @@ -40,7 +40,7 @@ static void ThreadWakeupCallback(u64 thread_handle, [[maybe_unused]] s64 cycles_ // Lock the global kernel mutex when we enter the kernel HLE. std::lock_guard lock{HLE::g_hle_lock}; - SharedPtr<Thread> thread = + std::shared_ptr<Thread> thread = system.Kernel().RetrieveThreadFromWakeupCallbackHandleTable(proper_handle); if (thread == nullptr) { LOG_CRITICAL(Kernel, "Callback fired for invalid thread {:08X}", proper_handle); @@ -53,7 +53,7 @@ static void ThreadWakeupCallback(u64 thread_handle, [[maybe_unused]] s64 cycles_ thread->GetStatus() == ThreadStatus::WaitHLEEvent) { // Remove the thread from each of its waiting objects' waitlists for (const auto& object : thread->GetWaitObjects()) { - object->RemoveWaitingThread(thread.get()); + object->RemoveWaitingThread(thread); } thread->ClearWaitObjects(); @@ -64,8 +64,11 @@ static void ThreadWakeupCallback(u64 thread_handle, [[maybe_unused]] s64 cycles_ } else if (thread->GetStatus() == ThreadStatus::WaitMutex || thread->GetStatus() == ThreadStatus::WaitCondVar) { thread->SetMutexWaitAddress(0); - thread->SetCondVarWaitAddress(0); thread->SetWaitHandle(0); + if (thread->GetStatus() == ThreadStatus::WaitCondVar) { + thread->GetOwnerProcess()->RemoveConditionVariableThread(thread); + thread->SetCondVarWaitAddress(0); + } auto* const lock_owner = thread->GetLockOwner(); // Threads waking up by timeout from WaitProcessWideKey do not perform priority inheritance @@ -136,12 +139,12 @@ struct KernelCore::Impl { void InitializeThreads() { thread_wakeup_event_type = - system.CoreTiming().RegisterEvent("ThreadWakeupCallback", ThreadWakeupCallback); + Core::Timing::CreateEvent("ThreadWakeupCallback", ThreadWakeupCallback); } void InitializePreemption() { - preemption_event = system.CoreTiming().RegisterEvent( - "PreemptionCallback", [this](u64 userdata, s64 cycles_late) { + preemption_event = + Core::Timing::CreateEvent("PreemptionCallback", [this](u64 userdata, s64 cycles_late) { global_scheduler.PreemptThreads(); s64 time_interval = Core::Timing::msToCycles(std::chrono::milliseconds(10)); system.CoreTiming().ScheduleEvent(time_interval, preemption_event); @@ -151,20 +154,31 @@ struct KernelCore::Impl { system.CoreTiming().ScheduleEvent(time_interval, preemption_event); } + void MakeCurrentProcess(Process* process) { + current_process = process; + + if (process == nullptr) { + return; + } + + system.Memory().SetCurrentPageTable(*process); + } + std::atomic<u32> next_object_id{0}; std::atomic<u64> next_kernel_process_id{Process::InitialKIPIDMin}; std::atomic<u64> next_user_process_id{Process::ProcessIDMin}; std::atomic<u64> next_thread_id{1}; // Lists all processes that exist in the current session. - std::vector<SharedPtr<Process>> process_list; + std::vector<std::shared_ptr<Process>> process_list; Process* current_process = nullptr; Kernel::GlobalScheduler global_scheduler; - SharedPtr<ResourceLimit> system_resource_limit; + std::shared_ptr<ResourceLimit> system_resource_limit; + + std::shared_ptr<Core::Timing::EventType> thread_wakeup_event_type; + std::shared_ptr<Core::Timing::EventType> preemption_event; - Core::Timing::EventType* thread_wakeup_event_type = nullptr; - Core::Timing::EventType* preemption_event = nullptr; // TODO(yuriks): This can be removed if Thread objects are explicitly pooled in the future, // allowing us to simply use a pool index or similar. Kernel::HandleTable thread_wakeup_callback_handle_table; @@ -190,26 +204,21 @@ void KernelCore::Shutdown() { impl->Shutdown(); } -SharedPtr<ResourceLimit> KernelCore::GetSystemResourceLimit() const { +std::shared_ptr<ResourceLimit> KernelCore::GetSystemResourceLimit() const { return impl->system_resource_limit; } -SharedPtr<Thread> KernelCore::RetrieveThreadFromWakeupCallbackHandleTable(Handle handle) const { +std::shared_ptr<Thread> KernelCore::RetrieveThreadFromWakeupCallbackHandleTable( + Handle handle) const { return impl->thread_wakeup_callback_handle_table.Get<Thread>(handle); } -void KernelCore::AppendNewProcess(SharedPtr<Process> process) { +void KernelCore::AppendNewProcess(std::shared_ptr<Process> process) { impl->process_list.push_back(std::move(process)); } void KernelCore::MakeCurrentProcess(Process* process) { - impl->current_process = process; - - if (process == nullptr) { - return; - } - - Memory::SetCurrentPageTable(*process); + impl->MakeCurrentProcess(process); } Process* KernelCore::CurrentProcess() { @@ -220,7 +229,7 @@ const Process* KernelCore::CurrentProcess() const { return impl->current_process; } -const std::vector<SharedPtr<Process>>& KernelCore::GetProcessList() const { +const std::vector<std::shared_ptr<Process>>& KernelCore::GetProcessList() const { return impl->process_list; } @@ -232,7 +241,7 @@ const Kernel::GlobalScheduler& KernelCore::GlobalScheduler() const { return impl->global_scheduler; } -void KernelCore::AddNamedPort(std::string name, SharedPtr<ClientPort> port) { +void KernelCore::AddNamedPort(std::string name, std::shared_ptr<ClientPort> port) { impl->named_ports.emplace(std::move(name), std::move(port)); } @@ -265,7 +274,7 @@ u64 KernelCore::CreateNewUserProcessID() { return impl->next_user_process_id++; } -Core::Timing::EventType* KernelCore::ThreadWakeupCallbackEventType() const { +const std::shared_ptr<Core::Timing::EventType>& KernelCore::ThreadWakeupCallbackEventType() const { return impl->thread_wakeup_event_type; } diff --git a/src/core/hle/kernel/kernel.h b/src/core/hle/kernel/kernel.h index c4397fc77..babb531c6 100644 --- a/src/core/hle/kernel/kernel.h +++ b/src/core/hle/kernel/kernel.h @@ -6,6 +6,7 @@ #include <string> #include <unordered_map> +#include <vector> #include "core/hle/kernel/object.h" namespace Core { @@ -30,7 +31,7 @@ class Thread; /// Represents a single instance of the kernel. class KernelCore { private: - using NamedPortTable = std::unordered_map<std::string, SharedPtr<ClientPort>>; + using NamedPortTable = std::unordered_map<std::string, std::shared_ptr<ClientPort>>; public: /// Constructs an instance of the kernel using the given System @@ -56,13 +57,13 @@ public: void Shutdown(); /// Retrieves a shared pointer to the system resource limit instance. - SharedPtr<ResourceLimit> GetSystemResourceLimit() const; + std::shared_ptr<ResourceLimit> GetSystemResourceLimit() const; /// Retrieves a shared pointer to a Thread instance within the thread wakeup handle table. - SharedPtr<Thread> RetrieveThreadFromWakeupCallbackHandleTable(Handle handle) const; + std::shared_ptr<Thread> RetrieveThreadFromWakeupCallbackHandleTable(Handle handle) const; /// Adds the given shared pointer to an internal list of active processes. - void AppendNewProcess(SharedPtr<Process> process); + void AppendNewProcess(std::shared_ptr<Process> process); /// Makes the given process the new current process. void MakeCurrentProcess(Process* process); @@ -74,7 +75,7 @@ public: const Process* CurrentProcess() const; /// Retrieves the list of processes. - const std::vector<SharedPtr<Process>>& GetProcessList() const; + const std::vector<std::shared_ptr<Process>>& GetProcessList() const; /// Gets the sole instance of the global scheduler Kernel::GlobalScheduler& GlobalScheduler(); @@ -83,7 +84,7 @@ public: const Kernel::GlobalScheduler& GlobalScheduler() const; /// Adds a port to the named port table - void AddNamedPort(std::string name, SharedPtr<ClientPort> port); + void AddNamedPort(std::string name, std::shared_ptr<ClientPort> port); /// Finds a port within the named port table with the given name. NamedPortTable::iterator FindNamedPort(const std::string& name); @@ -112,7 +113,7 @@ private: u64 CreateNewThreadID(); /// Retrieves the event type used for thread wakeup callbacks. - Core::Timing::EventType* ThreadWakeupCallbackEventType() const; + const std::shared_ptr<Core::Timing::EventType>& ThreadWakeupCallbackEventType() const; /// Provides a reference to the thread wakeup callback handle table. Kernel::HandleTable& ThreadWakeupCallbackHandleTable(); diff --git a/src/core/hle/kernel/mutex.cpp b/src/core/hle/kernel/mutex.cpp index 663d0f4b6..061e9bcb0 100644 --- a/src/core/hle/kernel/mutex.cpp +++ b/src/core/hle/kernel/mutex.cpp @@ -22,10 +22,10 @@ namespace Kernel { /// Returns the number of threads that are waiting for a mutex, and the highest priority one among /// those. -static std::pair<SharedPtr<Thread>, u32> GetHighestPriorityMutexWaitingThread( - const SharedPtr<Thread>& current_thread, VAddr mutex_addr) { +static std::pair<std::shared_ptr<Thread>, u32> GetHighestPriorityMutexWaitingThread( + const std::shared_ptr<Thread>& current_thread, VAddr mutex_addr) { - SharedPtr<Thread> highest_priority_thread; + std::shared_ptr<Thread> highest_priority_thread; u32 num_waiters = 0; for (const auto& thread : current_thread->GetMutexWaitingThreads()) { @@ -45,14 +45,14 @@ static std::pair<SharedPtr<Thread>, u32> GetHighestPriorityMutexWaitingThread( } /// Update the mutex owner field of all threads waiting on the mutex to point to the new owner. -static void TransferMutexOwnership(VAddr mutex_addr, SharedPtr<Thread> current_thread, - SharedPtr<Thread> new_owner) { +static void TransferMutexOwnership(VAddr mutex_addr, std::shared_ptr<Thread> current_thread, + std::shared_ptr<Thread> new_owner) { const auto threads = current_thread->GetMutexWaitingThreads(); for (const auto& thread : threads) { if (thread->GetMutexWaitAddress() != mutex_addr) continue; - ASSERT(thread->GetLockOwner() == current_thread); + ASSERT(thread->GetLockOwner() == current_thread.get()); current_thread->RemoveMutexWaiter(thread); if (new_owner != thread) new_owner->AddMutexWaiter(thread); @@ -70,15 +70,16 @@ ResultCode Mutex::TryAcquire(VAddr address, Handle holding_thread_handle, } const auto& handle_table = system.Kernel().CurrentProcess()->GetHandleTable(); - Thread* const current_thread = system.CurrentScheduler().GetCurrentThread(); - SharedPtr<Thread> holding_thread = handle_table.Get<Thread>(holding_thread_handle); - SharedPtr<Thread> requesting_thread = handle_table.Get<Thread>(requesting_thread_handle); + std::shared_ptr<Thread> current_thread = + SharedFrom(system.CurrentScheduler().GetCurrentThread()); + std::shared_ptr<Thread> holding_thread = handle_table.Get<Thread>(holding_thread_handle); + std::shared_ptr<Thread> requesting_thread = handle_table.Get<Thread>(requesting_thread_handle); // TODO(Subv): It is currently unknown if it is possible to lock a mutex in behalf of another // thread. ASSERT(requesting_thread == current_thread); - const u32 addr_value = Memory::Read32(address); + const u32 addr_value = system.Memory().Read32(address); // If the mutex isn't being held, just return success. if (addr_value != (holding_thread_handle | Mutex::MutexHasWaitersFlag)) { @@ -110,12 +111,13 @@ ResultCode Mutex::Release(VAddr address) { return ERR_INVALID_ADDRESS; } - auto* const current_thread = system.CurrentScheduler().GetCurrentThread(); + std::shared_ptr<Thread> current_thread = + SharedFrom(system.CurrentScheduler().GetCurrentThread()); auto [thread, num_waiters] = GetHighestPriorityMutexWaitingThread(current_thread, address); // There are no more threads waiting for the mutex, release it completely. if (thread == nullptr) { - Memory::Write32(address, 0); + system.Memory().Write32(address, 0); return RESULT_SUCCESS; } @@ -130,7 +132,7 @@ ResultCode Mutex::Release(VAddr address) { } // Grant the mutex to the next waiting thread and resume it. - Memory::Write32(address, mutex_value); + system.Memory().Write32(address, mutex_value); ASSERT(thread->GetStatus() == ThreadStatus::WaitMutex); thread->ResumeFromWait(); diff --git a/src/core/hle/kernel/object.cpp b/src/core/hle/kernel/object.cpp index 10431e94c..2c571792b 100644 --- a/src/core/hle/kernel/object.cpp +++ b/src/core/hle/kernel/object.cpp @@ -27,6 +27,7 @@ bool Object::IsWaitable() const { case HandleType::ResourceLimit: case HandleType::ClientPort: case HandleType::ClientSession: + case HandleType::Session: return false; } diff --git a/src/core/hle/kernel/object.h b/src/core/hle/kernel/object.h index a6faeb83b..e3391e2af 100644 --- a/src/core/hle/kernel/object.h +++ b/src/core/hle/kernel/object.h @@ -5,10 +5,9 @@ #pragma once #include <atomic> +#include <memory> #include <string> -#include <boost/smart_ptr/intrusive_ptr.hpp> - #include "common/common_types.h" namespace Kernel { @@ -30,9 +29,10 @@ enum class HandleType : u32 { ServerPort, ClientSession, ServerSession, + Session, }; -class Object : NonCopyable { +class Object : NonCopyable, public std::enable_shared_from_this<Object> { public: explicit Object(KernelCore& kernel); virtual ~Object(); @@ -61,35 +61,24 @@ protected: KernelCore& kernel; private: - friend void intrusive_ptr_add_ref(Object*); - friend void intrusive_ptr_release(Object*); - - std::atomic<u32> ref_count{0}; std::atomic<u32> object_id{0}; }; -// Special functions used by boost::instrusive_ptr to do automatic ref-counting -inline void intrusive_ptr_add_ref(Object* object) { - object->ref_count.fetch_add(1, std::memory_order_relaxed); -} - -inline void intrusive_ptr_release(Object* object) { - if (object->ref_count.fetch_sub(1, std::memory_order_acq_rel) == 1) { - delete object; - } -} - template <typename T> -using SharedPtr = boost::intrusive_ptr<T>; +std::shared_ptr<T> SharedFrom(T* raw) { + if (raw == nullptr) + return nullptr; + return std::static_pointer_cast<T>(raw->shared_from_this()); +} /** * Attempts to downcast the given Object pointer to a pointer to T. * @return Derived pointer to the object, or `nullptr` if `object` isn't of type T. */ template <typename T> -inline SharedPtr<T> DynamicObjectCast(SharedPtr<Object> object) { +inline std::shared_ptr<T> DynamicObjectCast(std::shared_ptr<Object> object) { if (object != nullptr && object->GetHandleType() == T::HANDLE_TYPE) { - return boost::static_pointer_cast<T>(object); + return std::static_pointer_cast<T>(object); } return nullptr; } diff --git a/src/core/hle/kernel/process.cpp b/src/core/hle/kernel/process.cpp index 12a900bcc..12ea4ebe3 100644 --- a/src/core/hle/kernel/process.cpp +++ b/src/core/hle/kernel/process.cpp @@ -38,7 +38,7 @@ void SetupMainThread(Process& owner_process, KernelCore& kernel, u32 priority) { auto thread_res = Thread::Create(kernel, "main", entry_point, priority, 0, owner_process.GetIdealCore(), stack_top, owner_process); - SharedPtr<Thread> thread = std::move(thread_res).Unwrap(); + std::shared_ptr<Thread> thread = std::move(thread_res).Unwrap(); // Register 1 must be a handle to the main thread const Handle thread_handle = owner_process.GetHandleTable().Create(thread).Unwrap(); @@ -100,10 +100,10 @@ private: std::bitset<num_slot_entries> is_slot_used; }; -SharedPtr<Process> Process::Create(Core::System& system, std::string name, ProcessType type) { +std::shared_ptr<Process> Process::Create(Core::System& system, std::string name, ProcessType type) { auto& kernel = system.Kernel(); - SharedPtr<Process> process(new Process(system)); + std::shared_ptr<Process> process = std::make_shared<Process>(system); process->name = std::move(name); process->resource_limit = kernel.GetSystemResourceLimit(); process->status = ProcessStatus::Created; @@ -121,7 +121,7 @@ SharedPtr<Process> Process::Create(Core::System& system, std::string name, Proce return process; } -SharedPtr<ResourceLimit> Process::GetResourceLimit() const { +std::shared_ptr<ResourceLimit> Process::GetResourceLimit() const { return resource_limit; } @@ -142,6 +142,49 @@ u64 Process::GetTotalPhysicalMemoryUsedWithoutSystemResource() const { return GetTotalPhysicalMemoryUsed() - GetSystemResourceUsage(); } +void Process::InsertConditionVariableThread(std::shared_ptr<Thread> thread) { + VAddr cond_var_addr = thread->GetCondVarWaitAddress(); + std::list<std::shared_ptr<Thread>>& thread_list = cond_var_threads[cond_var_addr]; + auto it = thread_list.begin(); + while (it != thread_list.end()) { + const std::shared_ptr<Thread> current_thread = *it; + if (current_thread->GetPriority() > thread->GetPriority()) { + thread_list.insert(it, thread); + return; + } + ++it; + } + thread_list.push_back(thread); +} + +void Process::RemoveConditionVariableThread(std::shared_ptr<Thread> thread) { + VAddr cond_var_addr = thread->GetCondVarWaitAddress(); + std::list<std::shared_ptr<Thread>>& thread_list = cond_var_threads[cond_var_addr]; + auto it = thread_list.begin(); + while (it != thread_list.end()) { + const std::shared_ptr<Thread> current_thread = *it; + if (current_thread.get() == thread.get()) { + thread_list.erase(it); + return; + } + ++it; + } + UNREACHABLE(); +} + +std::vector<std::shared_ptr<Thread>> Process::GetConditionVariableThreads( + const VAddr cond_var_addr) { + std::vector<std::shared_ptr<Thread>> result{}; + std::list<std::shared_ptr<Thread>>& thread_list = cond_var_threads[cond_var_addr]; + auto it = thread_list.begin(); + while (it != thread_list.end()) { + std::shared_ptr<Thread> current_thread = *it; + result.push_back(current_thread); + ++it; + } + return result; +} + void Process::RegisterThread(const Thread* thread) { thread_list.push_back(thread); } @@ -197,12 +240,12 @@ void Process::Run(s32 main_thread_priority, u64 stack_size) { void Process::PrepareForTermination() { ChangeStatus(ProcessStatus::Exiting); - const auto stop_threads = [this](const std::vector<SharedPtr<Thread>>& thread_list) { + const auto stop_threads = [this](const std::vector<std::shared_ptr<Thread>>& thread_list) { for (auto& thread : thread_list) { if (thread->GetOwnerProcess() != this) continue; - if (thread == system.CurrentScheduler().GetCurrentThread()) + if (thread.get() == system.CurrentScheduler().GetCurrentThread()) continue; // TODO(Subv): When are the other running/ready threads terminated? diff --git a/src/core/hle/kernel/process.h b/src/core/hle/kernel/process.h index c2df451f3..3483fa19d 100644 --- a/src/core/hle/kernel/process.h +++ b/src/core/hle/kernel/process.h @@ -8,6 +8,7 @@ #include <cstddef> #include <list> #include <string> +#include <unordered_map> #include <vector> #include "common/common_types.h" #include "core/hle/kernel/address_arbiter.h" @@ -61,6 +62,9 @@ enum class ProcessStatus { class Process final : public WaitObject { public: + explicit Process(Core::System& system); + ~Process() override; + enum : u64 { /// Lowest allowed process ID for a kernel initial process. InitialKIPIDMin = 1, @@ -81,7 +85,8 @@ public: static constexpr std::size_t RANDOM_ENTROPY_SIZE = 4; - static SharedPtr<Process> Create(Core::System& system, std::string name, ProcessType type); + static std::shared_ptr<Process> Create(Core::System& system, std::string name, + ProcessType type); std::string GetTypeName() const override { return "Process"; @@ -156,7 +161,7 @@ public: } /// Gets the resource limit descriptor for this process - SharedPtr<ResourceLimit> GetResourceLimit() const; + std::shared_ptr<ResourceLimit> GetResourceLimit() const; /// Gets the ideal CPU core ID for this process u8 GetIdealCore() const { @@ -232,6 +237,15 @@ public: return thread_list; } + /// Insert a thread into the condition variable wait container + void InsertConditionVariableThread(std::shared_ptr<Thread> thread); + + /// Remove a thread from the condition variable wait container + void RemoveConditionVariableThread(std::shared_ptr<Thread> thread); + + /// Obtain all condition variable threads waiting for some address + std::vector<std::shared_ptr<Thread>> GetConditionVariableThreads(VAddr cond_var_addr); + /// Registers a thread as being created under this process, /// adding it to this process' thread list. void RegisterThread(const Thread* thread); @@ -287,9 +301,6 @@ public: void FreeTLSRegion(VAddr tls_address); private: - explicit Process(Core::System& system); - ~Process() override; - /// Checks if the specified thread should wait until this process is available. bool ShouldWait(const Thread* thread) const override; @@ -328,7 +339,7 @@ private: u32 system_resource_size = 0; /// Resource limit descriptor for this process - SharedPtr<ResourceLimit> resource_limit; + std::shared_ptr<ResourceLimit> resource_limit; /// The ideal CPU core for this process, threads are scheduled on this core by default. u8 ideal_core = 0; @@ -375,6 +386,9 @@ private: /// List of threads that are running with this process as their owner. std::list<const Thread*> thread_list; + /// List of threads waiting for a condition variable + std::unordered_map<VAddr, std::list<std::shared_ptr<Thread>>> cond_var_threads; + /// System context Core::System& system; diff --git a/src/core/hle/kernel/resource_limit.cpp b/src/core/hle/kernel/resource_limit.cpp index 173f69915..b53423462 100644 --- a/src/core/hle/kernel/resource_limit.cpp +++ b/src/core/hle/kernel/resource_limit.cpp @@ -16,8 +16,8 @@ constexpr std::size_t ResourceTypeToIndex(ResourceType type) { ResourceLimit::ResourceLimit(KernelCore& kernel) : Object{kernel} {} ResourceLimit::~ResourceLimit() = default; -SharedPtr<ResourceLimit> ResourceLimit::Create(KernelCore& kernel) { - return new ResourceLimit(kernel); +std::shared_ptr<ResourceLimit> ResourceLimit::Create(KernelCore& kernel) { + return std::make_shared<ResourceLimit>(kernel); } s64 ResourceLimit::GetCurrentResourceValue(ResourceType resource) const { diff --git a/src/core/hle/kernel/resource_limit.h b/src/core/hle/kernel/resource_limit.h index 2613a6bb5..b5534620d 100644 --- a/src/core/hle/kernel/resource_limit.h +++ b/src/core/hle/kernel/resource_limit.h @@ -31,8 +31,11 @@ constexpr bool IsValidResourceType(ResourceType type) { class ResourceLimit final : public Object { public: + explicit ResourceLimit(KernelCore& kernel); + ~ResourceLimit() override; + /// Creates a resource limit object. - static SharedPtr<ResourceLimit> Create(KernelCore& kernel); + static std::shared_ptr<ResourceLimit> Create(KernelCore& kernel); std::string GetTypeName() const override { return "ResourceLimit"; @@ -76,9 +79,6 @@ public: ResultCode SetLimitValue(ResourceType resource, s64 value); private: - explicit ResourceLimit(KernelCore& kernel); - ~ResourceLimit() override; - // TODO(Subv): Increment resource limit current values in their respective Kernel::T::Create // functions // diff --git a/src/core/hle/kernel/scheduler.cpp b/src/core/hle/kernel/scheduler.cpp index 0e2dbf13e..3f5192087 100644 --- a/src/core/hle/kernel/scheduler.cpp +++ b/src/core/hle/kernel/scheduler.cpp @@ -26,27 +26,27 @@ GlobalScheduler::GlobalScheduler(Core::System& system) : system{system} {} GlobalScheduler::~GlobalScheduler() = default; -void GlobalScheduler::AddThread(SharedPtr<Thread> thread) { +void GlobalScheduler::AddThread(std::shared_ptr<Thread> thread) { thread_list.push_back(std::move(thread)); } -void GlobalScheduler::RemoveThread(const Thread* thread) { +void GlobalScheduler::RemoveThread(std::shared_ptr<Thread> thread) { thread_list.erase(std::remove(thread_list.begin(), thread_list.end(), thread), thread_list.end()); } -void GlobalScheduler::UnloadThread(s32 core) { +void GlobalScheduler::UnloadThread(std::size_t core) { Scheduler& sched = system.Scheduler(core); sched.UnloadThread(); } -void GlobalScheduler::SelectThread(u32 core) { +void GlobalScheduler::SelectThread(std::size_t core) { const auto update_thread = [](Thread* thread, Scheduler& sched) { - if (thread != sched.selected_thread) { + if (thread != sched.selected_thread.get()) { if (thread == nullptr) { ++sched.idle_selection_count; } - sched.selected_thread = thread; + sched.selected_thread = SharedFrom(thread); } sched.is_context_switch_pending = sched.selected_thread != sched.current_thread; std::atomic_thread_fence(std::memory_order_seq_cst); @@ -77,9 +77,9 @@ void GlobalScheduler::SelectThread(u32 core) { // if we got a suggested thread, select it, else do a second pass. if (winner && winner->GetPriority() > 2) { if (winner->IsRunning()) { - UnloadThread(winner->GetProcessorID()); + UnloadThread(static_cast<u32>(winner->GetProcessorID())); } - TransferToCore(winner->GetPriority(), core, winner); + TransferToCore(winner->GetPriority(), static_cast<s32>(core), winner); update_thread(winner, sched); return; } @@ -91,9 +91,9 @@ void GlobalScheduler::SelectThread(u32 core) { Thread* thread_on_core = scheduled_queue[src_core].front(); Thread* to_change = *it; if (thread_on_core->IsRunning() || to_change->IsRunning()) { - UnloadThread(src_core); + UnloadThread(static_cast<u32>(src_core)); } - TransferToCore(thread_on_core->GetPriority(), core, thread_on_core); + TransferToCore(thread_on_core->GetPriority(), static_cast<s32>(core), thread_on_core); current_thread = thread_on_core; break; } @@ -154,9 +154,9 @@ bool GlobalScheduler::YieldThreadAndBalanceLoad(Thread* yielding_thread) { if (winner != nullptr) { if (winner != yielding_thread) { if (winner->IsRunning()) { - UnloadThread(winner->GetProcessorID()); + UnloadThread(static_cast<u32>(winner->GetProcessorID())); } - TransferToCore(winner->GetPriority(), core_id, winner); + TransferToCore(winner->GetPriority(), s32(core_id), winner); } } else { winner = next_thread; @@ -196,9 +196,9 @@ bool GlobalScheduler::YieldThreadAndWaitForLoadBalancing(Thread* yielding_thread if (winner != nullptr) { if (winner != yielding_thread) { if (winner->IsRunning()) { - UnloadThread(winner->GetProcessorID()); + UnloadThread(static_cast<u32>(winner->GetProcessorID())); } - TransferToCore(winner->GetPriority(), core_id, winner); + TransferToCore(winner->GetPriority(), static_cast<s32>(core_id), winner); } } else { winner = yielding_thread; @@ -248,7 +248,7 @@ void GlobalScheduler::PreemptThreads() { if (winner != nullptr) { if (winner->IsRunning()) { - UnloadThread(winner->GetProcessorID()); + UnloadThread(static_cast<u32>(winner->GetProcessorID())); } TransferToCore(winner->GetPriority(), s32(core_id), winner); current_thread = @@ -281,7 +281,7 @@ void GlobalScheduler::PreemptThreads() { if (winner != nullptr) { if (winner->IsRunning()) { - UnloadThread(winner->GetProcessorID()); + UnloadThread(static_cast<u32>(winner->GetProcessorID())); } TransferToCore(winner->GetPriority(), s32(core_id), winner); current_thread = winner; @@ -292,30 +292,30 @@ void GlobalScheduler::PreemptThreads() { } } -void GlobalScheduler::Suggest(u32 priority, u32 core, Thread* thread) { +void GlobalScheduler::Suggest(u32 priority, std::size_t core, Thread* thread) { suggested_queue[core].add(thread, priority); } -void GlobalScheduler::Unsuggest(u32 priority, u32 core, Thread* thread) { +void GlobalScheduler::Unsuggest(u32 priority, std::size_t core, Thread* thread) { suggested_queue[core].remove(thread, priority); } -void GlobalScheduler::Schedule(u32 priority, u32 core, Thread* thread) { +void GlobalScheduler::Schedule(u32 priority, std::size_t core, Thread* thread) { ASSERT_MSG(thread->GetProcessorID() == s32(core), "Thread must be assigned to this core."); scheduled_queue[core].add(thread, priority); } -void GlobalScheduler::SchedulePrepend(u32 priority, u32 core, Thread* thread) { +void GlobalScheduler::SchedulePrepend(u32 priority, std::size_t core, Thread* thread) { ASSERT_MSG(thread->GetProcessorID() == s32(core), "Thread must be assigned to this core."); scheduled_queue[core].add(thread, priority, false); } -void GlobalScheduler::Reschedule(u32 priority, u32 core, Thread* thread) { +void GlobalScheduler::Reschedule(u32 priority, std::size_t core, Thread* thread) { scheduled_queue[core].remove(thread, priority); scheduled_queue[core].add(thread, priority); } -void GlobalScheduler::Unschedule(u32 priority, u32 core, Thread* thread) { +void GlobalScheduler::Unschedule(u32 priority, std::size_t core, Thread* thread) { scheduled_queue[core].remove(thread, priority); } @@ -327,14 +327,14 @@ void GlobalScheduler::TransferToCore(u32 priority, s32 destination_core, Thread* } thread->SetProcessorID(destination_core); if (source_core >= 0) { - Unschedule(priority, source_core, thread); + Unschedule(priority, static_cast<u32>(source_core), thread); } if (destination_core >= 0) { - Unsuggest(priority, destination_core, thread); - Schedule(priority, destination_core, thread); + Unsuggest(priority, static_cast<u32>(destination_core), thread); + Schedule(priority, static_cast<u32>(destination_core), thread); } if (source_core >= 0) { - Suggest(priority, source_core, thread); + Suggest(priority, static_cast<u32>(source_core), thread); } } @@ -357,7 +357,7 @@ void GlobalScheduler::Shutdown() { thread_list.clear(); } -Scheduler::Scheduler(Core::System& system, Core::ARM_Interface& cpu_core, u32 core_id) +Scheduler::Scheduler(Core::System& system, Core::ARM_Interface& cpu_core, std::size_t core_id) : system(system), cpu_core(cpu_core), core_id(core_id) {} Scheduler::~Scheduler() = default; @@ -446,7 +446,7 @@ void Scheduler::SwitchContext() { // Cancel any outstanding wakeup events for this thread new_thread->CancelWakeupTimer(); - current_thread = new_thread; + current_thread = SharedFrom(new_thread); new_thread->SetStatus(ThreadStatus::Running); new_thread->SetIsRunning(true); diff --git a/src/core/hle/kernel/scheduler.h b/src/core/hle/kernel/scheduler.h index f2d6311b8..3c5c21346 100644 --- a/src/core/hle/kernel/scheduler.h +++ b/src/core/hle/kernel/scheduler.h @@ -28,13 +28,13 @@ public: ~GlobalScheduler(); /// Adds a new thread to the scheduler - void AddThread(SharedPtr<Thread> thread); + void AddThread(std::shared_ptr<Thread> thread); /// Removes a thread from the scheduler - void RemoveThread(const Thread* thread); + void RemoveThread(std::shared_ptr<Thread> thread); /// Returns a list of all threads managed by the scheduler - const std::vector<SharedPtr<Thread>>& GetThreadList() const { + const std::vector<std::shared_ptr<Thread>>& GetThreadList() const { return thread_list; } @@ -42,41 +42,34 @@ public: * Add a thread to the suggested queue of a cpu core. Suggested threads may be * picked if no thread is scheduled to run on the core. */ - void Suggest(u32 priority, u32 core, Thread* thread); + void Suggest(u32 priority, std::size_t core, Thread* thread); /** * Remove a thread to the suggested queue of a cpu core. Suggested threads may be * picked if no thread is scheduled to run on the core. */ - void Unsuggest(u32 priority, u32 core, Thread* thread); + void Unsuggest(u32 priority, std::size_t core, Thread* thread); /** * Add a thread to the scheduling queue of a cpu core. The thread is added at the * back the queue in its priority level. */ - void Schedule(u32 priority, u32 core, Thread* thread); + void Schedule(u32 priority, std::size_t core, Thread* thread); /** * Add a thread to the scheduling queue of a cpu core. The thread is added at the * front the queue in its priority level. */ - void SchedulePrepend(u32 priority, u32 core, Thread* thread); + void SchedulePrepend(u32 priority, std::size_t core, Thread* thread); /// Reschedule an already scheduled thread based on a new priority - void Reschedule(u32 priority, u32 core, Thread* thread); + void Reschedule(u32 priority, std::size_t core, Thread* thread); /// Unschedules a thread. - void Unschedule(u32 priority, u32 core, Thread* thread); - - /** - * Transfers a thread into an specific core. If the destination_core is -1 - * it will be unscheduled from its source code and added into its suggested - * queue. - */ - void TransferToCore(u32 priority, s32 destination_core, Thread* thread); + void Unschedule(u32 priority, std::size_t core, Thread* thread); /// Selects a core and forces it to unload its current thread's context - void UnloadThread(s32 core); + void UnloadThread(std::size_t core); /** * Takes care of selecting the new scheduled thread in three steps: @@ -90,9 +83,9 @@ public: * 3. Third is no suggested thread is found, we do a second pass and pick a running * thread in another core and swap it with its current thread. */ - void SelectThread(u32 core); + void SelectThread(std::size_t core); - bool HaveReadyThreads(u32 core_id) const { + bool HaveReadyThreads(std::size_t core_id) const { return !scheduled_queue[core_id].empty(); } @@ -145,6 +138,13 @@ public: void Shutdown(); private: + /** + * Transfers a thread into an specific core. If the destination_core is -1 + * it will be unscheduled from its source code and added into its suggested + * queue. + */ + void TransferToCore(u32 priority, s32 destination_core, Thread* thread); + bool AskForReselectionOrMarkRedundant(Thread* current_thread, const Thread* winner); static constexpr u32 min_regular_priority = 2; @@ -157,13 +157,13 @@ private: std::array<u32, NUM_CPU_CORES> preemption_priorities = {59, 59, 59, 62}; /// Lists all thread ids that aren't deleted/etc. - std::vector<SharedPtr<Thread>> thread_list; + std::vector<std::shared_ptr<Thread>> thread_list; Core::System& system; }; class Scheduler final { public: - explicit Scheduler(Core::System& system, Core::ARM_Interface& cpu_core, u32 core_id); + explicit Scheduler(Core::System& system, Core::ARM_Interface& cpu_core, std::size_t core_id); ~Scheduler(); /// Returns whether there are any threads that are ready to run. @@ -213,14 +213,14 @@ private: */ void UpdateLastContextSwitchTime(Thread* thread, Process* process); - SharedPtr<Thread> current_thread = nullptr; - SharedPtr<Thread> selected_thread = nullptr; + std::shared_ptr<Thread> current_thread = nullptr; + std::shared_ptr<Thread> selected_thread = nullptr; Core::System& system; Core::ARM_Interface& cpu_core; u64 last_context_switch_time = 0; u64 idle_selection_count = 0; - const u32 core_id; + const std::size_t core_id; bool is_context_switch_pending = false; }; diff --git a/src/core/hle/kernel/server_port.cpp b/src/core/hle/kernel/server_port.cpp index 02e7c60e6..a4ccfa35e 100644 --- a/src/core/hle/kernel/server_port.cpp +++ b/src/core/hle/kernel/server_port.cpp @@ -16,7 +16,7 @@ namespace Kernel { ServerPort::ServerPort(KernelCore& kernel) : WaitObject{kernel} {} ServerPort::~ServerPort() = default; -ResultVal<SharedPtr<ServerSession>> ServerPort::Accept() { +ResultVal<std::shared_ptr<ServerSession>> ServerPort::Accept() { if (pending_sessions.empty()) { return ERR_NOT_FOUND; } @@ -26,7 +26,7 @@ ResultVal<SharedPtr<ServerSession>> ServerPort::Accept() { return MakeResult(std::move(session)); } -void ServerPort::AppendPendingSession(SharedPtr<ServerSession> pending_session) { +void ServerPort::AppendPendingSession(std::shared_ptr<ServerSession> pending_session) { pending_sessions.push_back(std::move(pending_session)); } @@ -41,8 +41,8 @@ void ServerPort::Acquire(Thread* thread) { ServerPort::PortPair ServerPort::CreatePortPair(KernelCore& kernel, u32 max_sessions, std::string name) { - SharedPtr<ServerPort> server_port(new ServerPort(kernel)); - SharedPtr<ClientPort> client_port(new ClientPort(kernel)); + std::shared_ptr<ServerPort> server_port = std::make_shared<ServerPort>(kernel); + std::shared_ptr<ClientPort> client_port = std::make_shared<ClientPort>(kernel); server_port->name = name + "_Server"; client_port->name = name + "_Client"; diff --git a/src/core/hle/kernel/server_port.h b/src/core/hle/kernel/server_port.h index dc88a1ebd..8be8a75ea 100644 --- a/src/core/hle/kernel/server_port.h +++ b/src/core/hle/kernel/server_port.h @@ -22,8 +22,11 @@ class SessionRequestHandler; class ServerPort final : public WaitObject { public: + explicit ServerPort(KernelCore& kernel); + ~ServerPort() override; + using HLEHandler = std::shared_ptr<SessionRequestHandler>; - using PortPair = std::pair<SharedPtr<ServerPort>, SharedPtr<ClientPort>>; + using PortPair = std::pair<std::shared_ptr<ServerPort>, std::shared_ptr<ClientPort>>; /** * Creates a pair of ServerPort and an associated ClientPort. @@ -52,7 +55,7 @@ public: * Accepts a pending incoming connection on this port. If there are no pending sessions, will * return ERR_NO_PENDING_SESSIONS. */ - ResultVal<SharedPtr<ServerSession>> Accept(); + ResultVal<std::shared_ptr<ServerSession>> Accept(); /// Whether or not this server port has an HLE handler available. bool HasHLEHandler() const { @@ -74,17 +77,14 @@ public: /// Appends a ServerSession to the collection of ServerSessions /// waiting to be accepted by this port. - void AppendPendingSession(SharedPtr<ServerSession> pending_session); + void AppendPendingSession(std::shared_ptr<ServerSession> pending_session); bool ShouldWait(const Thread* thread) const override; void Acquire(Thread* thread) override; private: - explicit ServerPort(KernelCore& kernel); - ~ServerPort() override; - /// ServerSessions waiting to be accepted by the port - std::vector<SharedPtr<ServerSession>> pending_sessions; + std::vector<std::shared_ptr<ServerSession>> pending_sessions; /// This session's HLE request handler template (optional) /// ServerSessions created from this port inherit a reference to this handler. diff --git a/src/core/hle/kernel/server_session.cpp b/src/core/hle/kernel/server_session.cpp index 30b2bfb5a..7825e1ec4 100644 --- a/src/core/hle/kernel/server_session.cpp +++ b/src/core/hle/kernel/server_session.cpp @@ -1,4 +1,4 @@ -// Copyright 2016 Citra Emulator Project +// Copyright 2019 yuzu emulator team // Licensed under GPLv2 or any later version // Refer to the license.txt file included. @@ -9,6 +9,7 @@ #include "common/common_types.h" #include "common/logging/log.h" #include "core/core.h" +#include "core/core_timing.h" #include "core/hle/ipc_helpers.h" #include "core/hle/kernel/client_port.h" #include "core/hle/kernel/client_session.h" @@ -19,35 +20,32 @@ #include "core/hle/kernel/server_session.h" #include "core/hle/kernel/session.h" #include "core/hle/kernel/thread.h" +#include "core/memory.h" namespace Kernel { ServerSession::ServerSession(KernelCore& kernel) : WaitObject{kernel} {} -ServerSession::~ServerSession() { - // This destructor will be called automatically when the last ServerSession handle is closed by - // the emulated application. +ServerSession::~ServerSession() = default; - // Decrease the port's connection count. - if (parent->port) { - parent->port->ConnectionClosed(); - } - - parent->server = nullptr; -} +ResultVal<std::shared_ptr<ServerSession>> ServerSession::Create(KernelCore& kernel, + std::shared_ptr<Session> parent, + std::string name) { + std::shared_ptr<ServerSession> session{std::make_shared<ServerSession>(kernel)}; -ResultVal<SharedPtr<ServerSession>> ServerSession::Create(KernelCore& kernel, std::string name) { - SharedPtr<ServerSession> server_session(new ServerSession(kernel)); + session->request_event = Core::Timing::CreateEvent( + name, [session](u64 userdata, s64 cycles_late) { session->CompleteSyncRequest(); }); + session->name = std::move(name); + session->parent = std::move(parent); - server_session->name = std::move(name); - server_session->parent = nullptr; - - return MakeResult(std::move(server_session)); + return MakeResult(std::move(session)); } bool ServerSession::ShouldWait(const Thread* thread) const { // Closed sessions should never wait, an error will be returned from svcReplyAndReceive. - if (parent->client == nullptr) + if (!parent->Client()) { return false; + } + // Wait if we have no pending requests, or if we're currently handling a request. return pending_requesting_threads.empty() || currently_handling != nullptr; } @@ -69,7 +67,7 @@ void ServerSession::ClientDisconnected() { if (handler) { // Note that after this returns, this server session's hle_handler is // invalidated (set to null). - handler->ClientDisconnected(this); + handler->ClientDisconnected(SharedFrom(this)); } // Clean up the list of client threads with pending requests, they are unneeded now that the @@ -126,13 +124,21 @@ ResultCode ServerSession::HandleDomainSyncRequest(Kernel::HLERequestContext& con return RESULT_SUCCESS; } -ResultCode ServerSession::HandleSyncRequest(SharedPtr<Thread> thread) { - // The ServerSession received a sync request, this means that there's new data available - // from its ClientSession, so wake up any threads that may be waiting on a svcReplyAndReceive or - // similar. - Kernel::HLERequestContext context(this, thread); - u32* cmd_buf = (u32*)Memory::GetPointer(thread->GetTLSAddress()); - context.PopulateFromIncomingCommandBuffer(kernel.CurrentProcess()->GetHandleTable(), cmd_buf); +ResultCode ServerSession::QueueSyncRequest(std::shared_ptr<Thread> thread, Memory::Memory& memory) { + u32* cmd_buf{reinterpret_cast<u32*>(memory.GetPointer(thread->GetTLSAddress()))}; + std::shared_ptr<Kernel::HLERequestContext> context{ + std::make_shared<Kernel::HLERequestContext>(SharedFrom(this), std::move(thread))}; + + context->PopulateFromIncomingCommandBuffer(kernel.CurrentProcess()->GetHandleTable(), cmd_buf); + request_queue.Push(std::move(context)); + + return RESULT_SUCCESS; +} + +ResultCode ServerSession::CompleteSyncRequest() { + ASSERT(!request_queue.Empty()); + + auto& context = *request_queue.Front(); ResultCode result = RESULT_SUCCESS; // If the session has been converted to a domain, handle the domain request @@ -144,61 +150,27 @@ ResultCode ServerSession::HandleSyncRequest(SharedPtr<Thread> thread) { result = hle_handler->HandleSyncRequest(context); } - if (thread->GetStatus() == ThreadStatus::Running) { - // Put the thread to sleep until the server replies, it will be awoken in - // svcReplyAndReceive for LLE servers. - thread->SetStatus(ThreadStatus::WaitIPC); - - if (hle_handler != nullptr) { - // For HLE services, we put the request threads to sleep for a short duration to - // simulate IPC overhead, but only if the HLE handler didn't put the thread to sleep for - // other reasons like an async callback. The IPC overhead is needed to prevent - // starvation when a thread only does sync requests to HLE services while a - // lower-priority thread is waiting to run. - - // This delay was approximated in a homebrew application by measuring the average time - // it takes for svcSendSyncRequest to return when performing the SetLcdForceBlack IPC - // request to the GSP:GPU service in a n3DS with firmware 11.6. The measured values have - // a high variance and vary between models. - static constexpr u64 IPCDelayNanoseconds = 39000; - thread->WakeAfterDelay(IPCDelayNanoseconds); - } else { - // Add the thread to the list of threads that have issued a sync request with this - // server. - pending_requesting_threads.push_back(std::move(thread)); - } - } - - // If this ServerSession does not have an HLE implementation, just wake up the threads waiting - // on it. - WakeupAllWaitingThreads(); - - // Handle scenario when ConvertToDomain command was issued, as we must do the conversion at the - // end of the command such that only commands following this one are handled as domains if (convert_to_domain) { ASSERT_MSG(IsSession(), "ServerSession is already a domain instance."); domain_request_handlers = {hle_handler}; convert_to_domain = false; } - return result; -} - -ServerSession::SessionPair ServerSession::CreateSessionPair(KernelCore& kernel, - const std::string& name, - SharedPtr<ClientPort> port) { - auto server_session = ServerSession::Create(kernel, name + "_Server").Unwrap(); - SharedPtr<ClientSession> client_session(new ClientSession(kernel)); - client_session->name = name + "_Client"; + // Some service requests require the thread to block + if (!context.IsThreadWaiting()) { + context.GetThread().ResumeFromWait(); + context.GetThread().SetWaitSynchronizationResult(result); + } - std::shared_ptr<Session> parent(new Session); - parent->client = client_session.get(); - parent->server = server_session.get(); - parent->port = std::move(port); + request_queue.Pop(); - client_session->parent = parent; - server_session->parent = parent; + return result; +} - return std::make_pair(std::move(server_session), std::move(client_session)); +ResultCode ServerSession::HandleSyncRequest(std::shared_ptr<Thread> thread, + Memory::Memory& memory) { + Core::System::GetInstance().CoreTiming().ScheduleEvent(20000, request_event, {}); + return QueueSyncRequest(std::move(thread), memory); } + } // namespace Kernel diff --git a/src/core/hle/kernel/server_session.h b/src/core/hle/kernel/server_session.h index 738df30f8..d6e48109e 100644 --- a/src/core/hle/kernel/server_session.h +++ b/src/core/hle/kernel/server_session.h @@ -1,4 +1,4 @@ -// Copyright 2014 Citra Emulator Project +// Copyright 2019 yuzu emulator team // Licensed under GPLv2 or any later version // Refer to the license.txt file included. @@ -9,17 +9,22 @@ #include <utility> #include <vector> -#include "core/hle/kernel/object.h" +#include "common/threadsafe_queue.h" #include "core/hle/kernel/wait_object.h" #include "core/hle/result.h" +namespace Memory { +class Memory; +} + +namespace Core::Timing { +struct EventType; +} + namespace Kernel { -class ClientPort; -class ClientSession; class HLERequestContext; class KernelCore; -class ServerSession; class Session; class SessionRequestHandler; class Thread; @@ -38,6 +43,15 @@ class Thread; */ class ServerSession final : public WaitObject { public: + explicit ServerSession(KernelCore& kernel); + ~ServerSession() override; + + friend class Session; + + static ResultVal<std::shared_ptr<ServerSession>> Create(KernelCore& kernel, + std::shared_ptr<Session> parent, + std::string name = "Unknown"); + std::string GetTypeName() const override { return "ServerSession"; } @@ -59,18 +73,6 @@ public: return parent.get(); } - using SessionPair = std::pair<SharedPtr<ServerSession>, SharedPtr<ClientSession>>; - - /** - * Creates a pair of ServerSession and an associated ClientSession. - * @param kernel The kernal instance to create the session pair under. - * @param name Optional name of the ports. - * @param client_port Optional The ClientPort that spawned this session. - * @return The created session tuple - */ - static SessionPair CreateSessionPair(KernelCore& kernel, const std::string& name = "Unknown", - SharedPtr<ClientPort> client_port = nullptr); - /** * Sets the HLE handler for the session. This handler will be called to service IPC requests * instead of the regular IPC machinery. (The regular IPC machinery is currently not @@ -82,10 +84,13 @@ public: /** * Handle a sync request from the emulated application. + * * @param thread Thread that initiated the request. + * @param memory Memory context to handle the sync request under. + * * @returns ResultCode from the operation. */ - ResultCode HandleSyncRequest(SharedPtr<Thread> thread); + ResultCode HandleSyncRequest(std::shared_ptr<Thread> thread, Memory::Memory& memory); bool ShouldWait(const Thread* thread) const override; @@ -118,18 +123,11 @@ public: } private: - explicit ServerSession(KernelCore& kernel); - ~ServerSession() override; + /// Queues a sync request from the emulated application. + ResultCode QueueSyncRequest(std::shared_ptr<Thread> thread, Memory::Memory& memory); - /** - * Creates a server session. The server session can have an optional HLE handler, - * which will be invoked to handle the IPC requests that this session receives. - * @param kernel The kernel instance to create this server session under. - * @param name Optional name of the server session. - * @return The created server session - */ - static ResultVal<SharedPtr<ServerSession>> Create(KernelCore& kernel, - std::string name = "Unknown"); + /// Completes a sync request from the emulated application. + ResultCode CompleteSyncRequest(); /// Handles a SyncRequest to a domain, forwarding the request to the proper object or closing an /// object handle. @@ -147,18 +145,24 @@ private: /// List of threads that are pending a response after a sync request. This list is processed in /// a LIFO manner, thus, the last request will be dispatched first. /// TODO(Subv): Verify if this is indeed processed in LIFO using a hardware test. - std::vector<SharedPtr<Thread>> pending_requesting_threads; + std::vector<std::shared_ptr<Thread>> pending_requesting_threads; /// Thread whose request is currently being handled. A request is considered "handled" when a /// response is sent via svcReplyAndReceive. /// TODO(Subv): Find a better name for this. - SharedPtr<Thread> currently_handling; + std::shared_ptr<Thread> currently_handling; /// When set to True, converts the session to a domain at the end of the command bool convert_to_domain{}; /// The name of this session (optional) std::string name; + + /// Core timing event used to schedule the service request at some point in the future + std::shared_ptr<Core::Timing::EventType> request_event; + + /// Queue of scheduled service requests + Common::MPSCQueue<std::shared_ptr<Kernel::HLERequestContext>> request_queue; }; } // namespace Kernel diff --git a/src/core/hle/kernel/session.cpp b/src/core/hle/kernel/session.cpp index 642914744..dee6e2b72 100644 --- a/src/core/hle/kernel/session.cpp +++ b/src/core/hle/kernel/session.cpp @@ -1,12 +1,36 @@ -// Copyright 2015 Citra Emulator Project +// Copyright 2019 yuzu emulator team // Licensed under GPLv2 or any later version // Refer to the license.txt file included. +#include "common/assert.h" +#include "core/hle/kernel/client_session.h" +#include "core/hle/kernel/server_session.h" #include "core/hle/kernel/session.h" -#include "core/hle/kernel/thread.h" namespace Kernel { -Session::Session() {} -Session::~Session() {} +Session::Session(KernelCore& kernel) : WaitObject{kernel} {} +Session::~Session() = default; + +Session::SessionPair Session::Create(KernelCore& kernel, std::string name) { + auto session{std::make_shared<Session>(kernel)}; + auto client_session{Kernel::ClientSession::Create(kernel, session, name + "_Client").Unwrap()}; + auto server_session{Kernel::ServerSession::Create(kernel, session, name + "_Server").Unwrap()}; + + session->name = std::move(name); + session->client = client_session; + session->server = server_session; + + return std::make_pair(std::move(client_session), std::move(server_session)); +} + +bool Session::ShouldWait(const Thread* thread) const { + UNIMPLEMENTED(); + return {}; +} + +void Session::Acquire(Thread* thread) { + UNIMPLEMENTED(); +} + } // namespace Kernel diff --git a/src/core/hle/kernel/session.h b/src/core/hle/kernel/session.h index 7a551f5e4..5a9d4e9ad 100644 --- a/src/core/hle/kernel/session.h +++ b/src/core/hle/kernel/session.h @@ -1,27 +1,64 @@ -// Copyright 2018 yuzu emulator team +// Copyright 2019 yuzu emulator team // Licensed under GPLv2 or any later version // Refer to the license.txt file included. #pragma once -#include "core/hle/kernel/object.h" +#include <memory> +#include <string> + +#include "core/hle/kernel/wait_object.h" +#include "core/hle/result.h" namespace Kernel { class ClientSession; -class ClientPort; class ServerSession; /** * Parent structure to link the client and server endpoints of a session with their associated - * client port. The client port need not exist, as is the case for portless sessions like the - * FS File and Directory sessions. When one of the endpoints of a session is destroyed, its - * corresponding field in this structure will be set to nullptr. + * client port. */ -class Session final { +class Session final : public WaitObject { public: - ClientSession* client = nullptr; ///< The client endpoint of the session. - ServerSession* server = nullptr; ///< The server endpoint of the session. - SharedPtr<ClientPort> port; ///< The port that this session is associated with (optional). + explicit Session(KernelCore& kernel); + ~Session() override; + + using SessionPair = std::pair<std::shared_ptr<ClientSession>, std::shared_ptr<ServerSession>>; + + static SessionPair Create(KernelCore& kernel, std::string name = "Unknown"); + + std::string GetName() const override { + return name; + } + + static constexpr HandleType HANDLE_TYPE = HandleType::Session; + HandleType GetHandleType() const override { + return HANDLE_TYPE; + } + + bool ShouldWait(const Thread* thread) const override; + + void Acquire(Thread* thread) override; + + std::shared_ptr<ClientSession> Client() { + if (auto result{client.lock()}) { + return result; + } + return {}; + } + + std::shared_ptr<ServerSession> Server() { + if (auto result{server.lock()}) { + return result; + } + return {}; + } + +private: + std::string name; + std::weak_ptr<ClientSession> client; + std::weak_ptr<ServerSession> server; }; + } // namespace Kernel diff --git a/src/core/hle/kernel/shared_memory.cpp b/src/core/hle/kernel/shared_memory.cpp index a815c4eea..afb2e3fc2 100644 --- a/src/core/hle/kernel/shared_memory.cpp +++ b/src/core/hle/kernel/shared_memory.cpp @@ -15,11 +15,12 @@ namespace Kernel { SharedMemory::SharedMemory(KernelCore& kernel) : Object{kernel} {} SharedMemory::~SharedMemory() = default; -SharedPtr<SharedMemory> SharedMemory::Create(KernelCore& kernel, Process* owner_process, u64 size, - MemoryPermission permissions, - MemoryPermission other_permissions, VAddr address, - MemoryRegion region, std::string name) { - SharedPtr<SharedMemory> shared_memory(new SharedMemory(kernel)); +std::shared_ptr<SharedMemory> SharedMemory::Create(KernelCore& kernel, Process* owner_process, + u64 size, MemoryPermission permissions, + MemoryPermission other_permissions, + VAddr address, MemoryRegion region, + std::string name) { + std::shared_ptr<SharedMemory> shared_memory = std::make_shared<SharedMemory>(kernel); shared_memory->owner_process = owner_process; shared_memory->name = std::move(name); @@ -58,10 +59,10 @@ SharedPtr<SharedMemory> SharedMemory::Create(KernelCore& kernel, Process* owner_ return shared_memory; } -SharedPtr<SharedMemory> SharedMemory::CreateForApplet( +std::shared_ptr<SharedMemory> SharedMemory::CreateForApplet( KernelCore& kernel, std::shared_ptr<Kernel::PhysicalMemory> heap_block, std::size_t offset, u64 size, MemoryPermission permissions, MemoryPermission other_permissions, std::string name) { - SharedPtr<SharedMemory> shared_memory(new SharedMemory(kernel)); + std::shared_ptr<SharedMemory> shared_memory = std::make_shared<SharedMemory>(kernel); shared_memory->owner_process = nullptr; shared_memory->name = std::move(name); diff --git a/src/core/hle/kernel/shared_memory.h b/src/core/hle/kernel/shared_memory.h index 01ca6dcd2..18400a5ad 100644 --- a/src/core/hle/kernel/shared_memory.h +++ b/src/core/hle/kernel/shared_memory.h @@ -33,6 +33,9 @@ enum class MemoryPermission : u32 { class SharedMemory final : public Object { public: + explicit SharedMemory(KernelCore& kernel); + ~SharedMemory() override; + /** * Creates a shared memory object. * @param kernel The kernel instance to create a shared memory instance under. @@ -46,11 +49,12 @@ public: * linear heap. * @param name Optional object name, used for debugging purposes. */ - static SharedPtr<SharedMemory> Create(KernelCore& kernel, Process* owner_process, u64 size, - MemoryPermission permissions, - MemoryPermission other_permissions, VAddr address = 0, - MemoryRegion region = MemoryRegion::BASE, - std::string name = "Unknown"); + static std::shared_ptr<SharedMemory> Create(KernelCore& kernel, Process* owner_process, + u64 size, MemoryPermission permissions, + MemoryPermission other_permissions, + VAddr address = 0, + MemoryRegion region = MemoryRegion::BASE, + std::string name = "Unknown"); /** * Creates a shared memory object from a block of memory managed by an HLE applet. @@ -63,7 +67,7 @@ public: * block. * @param name Optional object name, used for debugging purposes. */ - static SharedPtr<SharedMemory> CreateForApplet( + static std::shared_ptr<SharedMemory> CreateForApplet( KernelCore& kernel, std::shared_ptr<Kernel::PhysicalMemory> heap_block, std::size_t offset, u64 size, MemoryPermission permissions, MemoryPermission other_permissions, std::string name = "Unknown Applet"); @@ -130,9 +134,6 @@ public: const u8* GetPointer(std::size_t offset = 0) const; private: - explicit SharedMemory(KernelCore& kernel); - ~SharedMemory() override; - /// Backing memory for this shared memory block. std::shared_ptr<PhysicalMemory> backing_block; /// Offset into the backing block for this shared memory. diff --git a/src/core/hle/kernel/svc.cpp b/src/core/hle/kernel/svc.cpp index c63a9ba8b..bd25de478 100644 --- a/src/core/hle/kernel/svc.cpp +++ b/src/core/hle/kernel/svc.cpp @@ -17,6 +17,7 @@ #include "core/core.h" #include "core/core_cpu.h" #include "core/core_timing.h" +#include "core/core_timing_util.h" #include "core/hle/kernel/address_arbiter.h" #include "core/hle/kernel/client_port.h" #include "core/hle/kernel/client_session.h" @@ -331,7 +332,9 @@ static ResultCode UnmapMemory(Core::System& system, VAddr dst_addr, VAddr src_ad /// Connect to an OS service given the port name, returns the handle to the port to out static ResultCode ConnectToNamedPort(Core::System& system, Handle* out_handle, VAddr port_name_address) { - if (!Memory::IsValidVirtualAddress(port_name_address)) { + auto& memory = system.Memory(); + + if (!memory.IsValidVirtualAddress(port_name_address)) { LOG_ERROR(Kernel_SVC, "Port Name Address is not a valid virtual address, port_name_address=0x{:016X}", port_name_address); @@ -340,7 +343,7 @@ static ResultCode ConnectToNamedPort(Core::System& system, Handle* out_handle, static constexpr std::size_t PortNameMaxLength = 11; // Read 1 char beyond the max allowed port name to detect names that are too long. - std::string port_name = Memory::ReadCString(port_name_address, PortNameMaxLength + 1); + const std::string port_name = memory.ReadCString(port_name_address, PortNameMaxLength + 1); if (port_name.size() > PortNameMaxLength) { LOG_ERROR(Kernel_SVC, "Port name is too long, expected {} but got {}", PortNameMaxLength, port_name.size()); @@ -358,7 +361,7 @@ static ResultCode ConnectToNamedPort(Core::System& system, Handle* out_handle, auto client_port = it->second; - SharedPtr<ClientSession> client_session; + std::shared_ptr<ClientSession> client_session; CASCADE_RESULT(client_session, client_port->Connect()); // Return the client session @@ -370,7 +373,7 @@ static ResultCode ConnectToNamedPort(Core::System& system, Handle* out_handle, /// Makes a blocking IPC call to an OS service. static ResultCode SendSyncRequest(Core::System& system, Handle handle) { const auto& handle_table = system.Kernel().CurrentProcess()->GetHandleTable(); - SharedPtr<ClientSession> session = handle_table.Get<ClientSession>(handle); + std::shared_ptr<ClientSession> session = handle_table.Get<ClientSession>(handle); if (!session) { LOG_ERROR(Kernel_SVC, "called with invalid handle=0x{:08X}", handle); return ERR_INVALID_HANDLE; @@ -378,11 +381,12 @@ static ResultCode SendSyncRequest(Core::System& system, Handle handle) { LOG_TRACE(Kernel_SVC, "called handle=0x{:08X}({})", handle, session->GetName()); - system.PrepareReschedule(); + auto thread = system.CurrentScheduler().GetCurrentThread(); + thread->InvalidateWakeupCallback(); + thread->SetStatus(ThreadStatus::WaitIPC); + system.PrepareReschedule(thread->GetProcessorID()); - // TODO(Subv): svcSendSyncRequest should put the caller thread to sleep while the server - // responds and cause a reschedule. - return session->SendSyncRequest(system.CurrentScheduler().GetCurrentThread()); + return session->SendSyncRequest(SharedFrom(thread), system.Memory()); } /// Get the ID for the specified thread. @@ -390,7 +394,7 @@ static ResultCode GetThreadId(Core::System& system, u64* thread_id, Handle threa LOG_TRACE(Kernel_SVC, "called thread=0x{:08X}", thread_handle); const auto& handle_table = system.Kernel().CurrentProcess()->GetHandleTable(); - const SharedPtr<Thread> thread = handle_table.Get<Thread>(thread_handle); + const std::shared_ptr<Thread> thread = handle_table.Get<Thread>(thread_handle); if (!thread) { LOG_ERROR(Kernel_SVC, "Thread handle does not exist, handle=0x{:08X}", thread_handle); return ERR_INVALID_HANDLE; @@ -405,13 +409,13 @@ static ResultCode GetProcessId(Core::System& system, u64* process_id, Handle han LOG_DEBUG(Kernel_SVC, "called handle=0x{:08X}", handle); const auto& handle_table = system.Kernel().CurrentProcess()->GetHandleTable(); - const SharedPtr<Process> process = handle_table.Get<Process>(handle); + const std::shared_ptr<Process> process = handle_table.Get<Process>(handle); if (process) { *process_id = process->GetProcessID(); return RESULT_SUCCESS; } - const SharedPtr<Thread> thread = handle_table.Get<Thread>(handle); + const std::shared_ptr<Thread> thread = handle_table.Get<Thread>(handle); if (thread) { const Process* const owner_process = thread->GetOwnerProcess(); if (!owner_process) { @@ -430,8 +434,8 @@ static ResultCode GetProcessId(Core::System& system, u64* process_id, Handle han } /// Default thread wakeup callback for WaitSynchronization -static bool DefaultThreadWakeupCallback(ThreadWakeupReason reason, SharedPtr<Thread> thread, - SharedPtr<WaitObject> object, std::size_t index) { +static bool DefaultThreadWakeupCallback(ThreadWakeupReason reason, std::shared_ptr<Thread> thread, + std::shared_ptr<WaitObject> object, std::size_t index) { ASSERT(thread->GetStatus() == ThreadStatus::WaitSynch); if (reason == ThreadWakeupReason::Timeout) { @@ -451,7 +455,8 @@ static ResultCode WaitSynchronization(Core::System& system, Handle* index, VAddr LOG_TRACE(Kernel_SVC, "called handles_address=0x{:X}, handle_count={}, nano_seconds={}", handles_address, handle_count, nano_seconds); - if (!Memory::IsValidVirtualAddress(handles_address)) { + auto& memory = system.Memory(); + if (!memory.IsValidVirtualAddress(handles_address)) { LOG_ERROR(Kernel_SVC, "Handle address is not a valid virtual address, handle_address=0x{:016X}", handles_address); @@ -473,7 +478,7 @@ static ResultCode WaitSynchronization(Core::System& system, Handle* index, VAddr const auto& handle_table = system.Kernel().CurrentProcess()->GetHandleTable(); for (u64 i = 0; i < handle_count; ++i) { - const Handle handle = Memory::Read32(handles_address + i * sizeof(Handle)); + const Handle handle = memory.Read32(handles_address + i * sizeof(Handle)); const auto object = handle_table.Get<WaitObject>(handle); if (object == nullptr) { @@ -505,8 +510,13 @@ static ResultCode WaitSynchronization(Core::System& system, Handle* index, VAddr return RESULT_TIMEOUT; } + if (thread->IsSyncCancelled()) { + thread->SetSyncCancelled(false); + return ERR_SYNCHRONIZATION_CANCELED; + } + for (auto& object : objects) { - object->AddWaitingThread(thread); + object->AddWaitingThread(SharedFrom(thread)); } thread->SetWaitObjects(std::move(objects)); @@ -526,7 +536,7 @@ static ResultCode CancelSynchronization(Core::System& system, Handle thread_hand LOG_TRACE(Kernel_SVC, "called thread=0x{:X}", thread_handle); const auto& handle_table = system.Kernel().CurrentProcess()->GetHandleTable(); - SharedPtr<Thread> thread = handle_table.Get<Thread>(thread_handle); + std::shared_ptr<Thread> thread = handle_table.Get<Thread>(thread_handle); if (!thread) { LOG_ERROR(Kernel_SVC, "Thread handle does not exist, thread_handle=0x{:08X}", thread_handle); @@ -610,13 +620,15 @@ static void Break(Core::System& system, u32 reason, u64 info1, u64 info2) { return; } + auto& memory = system.Memory(); + // This typically is an error code so we're going to assume this is the case if (sz == sizeof(u32)) { - LOG_CRITICAL(Debug_Emulated, "debug_buffer_err_code={:X}", Memory::Read32(addr)); + LOG_CRITICAL(Debug_Emulated, "debug_buffer_err_code={:X}", memory.Read32(addr)); } else { // We don't know what's in here so we'll hexdump it debug_buffer.resize(sz); - Memory::ReadBlock(addr, debug_buffer.data(), sz); + memory.ReadBlock(addr, debug_buffer.data(), sz); std::string hexdump; for (std::size_t i = 0; i < debug_buffer.size(); i++) { hexdump += fmt::format("{:02X} ", debug_buffer[i]); @@ -706,7 +718,7 @@ static void OutputDebugString([[maybe_unused]] Core::System& system, VAddr addre } std::string str(len, '\0'); - Memory::ReadBlock(address, str.data(), str.size()); + system.Memory().ReadBlock(address, str.data(), str.size()); LOG_DEBUG(Debug_Emulated, "{}", str); } @@ -935,7 +947,7 @@ static ResultCode GetInfo(Core::System& system, u64* result, u64 info_id, u64 ha const auto& core_timing = system.CoreTiming(); const auto& scheduler = system.CurrentScheduler(); const auto* const current_thread = scheduler.GetCurrentThread(); - const bool same_thread = current_thread == thread; + const bool same_thread = current_thread == thread.get(); const u64 prev_ctx_ticks = scheduler.GetLastContextSwitchTicks(); u64 out_ticks = 0; @@ -1045,7 +1057,7 @@ static ResultCode SetThreadActivity(Core::System& system, Handle handle, u32 act } const auto* current_process = system.Kernel().CurrentProcess(); - const SharedPtr<Thread> thread = current_process->GetHandleTable().Get<Thread>(handle); + const std::shared_ptr<Thread> thread = current_process->GetHandleTable().Get<Thread>(handle); if (!thread) { LOG_ERROR(Kernel_SVC, "Thread handle does not exist, handle=0x{:08X}", handle); return ERR_INVALID_HANDLE; @@ -1061,7 +1073,7 @@ static ResultCode SetThreadActivity(Core::System& system, Handle handle, u32 act return ERR_INVALID_HANDLE; } - if (thread == system.CurrentScheduler().GetCurrentThread()) { + if (thread.get() == system.CurrentScheduler().GetCurrentThread()) { LOG_ERROR(Kernel_SVC, "The thread handle specified is the current running thread"); return ERR_BUSY; } @@ -1077,7 +1089,7 @@ static ResultCode GetThreadContext(Core::System& system, VAddr thread_context, H LOG_DEBUG(Kernel_SVC, "called, context=0x{:08X}, thread=0x{:X}", thread_context, handle); const auto* current_process = system.Kernel().CurrentProcess(); - const SharedPtr<Thread> thread = current_process->GetHandleTable().Get<Thread>(handle); + const std::shared_ptr<Thread> thread = current_process->GetHandleTable().Get<Thread>(handle); if (!thread) { LOG_ERROR(Kernel_SVC, "Thread handle does not exist, handle=0x{:08X}", handle); return ERR_INVALID_HANDLE; @@ -1093,7 +1105,7 @@ static ResultCode GetThreadContext(Core::System& system, VAddr thread_context, H return ERR_INVALID_HANDLE; } - if (thread == system.CurrentScheduler().GetCurrentThread()) { + if (thread.get() == system.CurrentScheduler().GetCurrentThread()) { LOG_ERROR(Kernel_SVC, "The thread handle specified is the current running thread"); return ERR_BUSY; } @@ -1109,7 +1121,7 @@ static ResultCode GetThreadContext(Core::System& system, VAddr thread_context, H std::fill(ctx.vector_registers.begin() + 16, ctx.vector_registers.end(), u128{}); } - Memory::WriteBlock(thread_context, &ctx, sizeof(ctx)); + system.Memory().WriteBlock(thread_context, &ctx, sizeof(ctx)); return RESULT_SUCCESS; } @@ -1118,7 +1130,7 @@ static ResultCode GetThreadPriority(Core::System& system, u32* priority, Handle LOG_TRACE(Kernel_SVC, "called"); const auto& handle_table = system.Kernel().CurrentProcess()->GetHandleTable(); - const SharedPtr<Thread> thread = handle_table.Get<Thread>(handle); + const std::shared_ptr<Thread> thread = handle_table.Get<Thread>(handle); if (!thread) { LOG_ERROR(Kernel_SVC, "Thread handle does not exist, handle=0x{:08X}", handle); return ERR_INVALID_HANDLE; @@ -1142,7 +1154,7 @@ static ResultCode SetThreadPriority(Core::System& system, Handle handle, u32 pri const auto* const current_process = system.Kernel().CurrentProcess(); - SharedPtr<Thread> thread = current_process->GetHandleTable().Get<Thread>(handle); + std::shared_ptr<Thread> thread = current_process->GetHandleTable().Get<Thread>(handle); if (!thread) { LOG_ERROR(Kernel_SVC, "Thread handle does not exist, handle=0x{:08X}", handle); return ERR_INVALID_HANDLE; @@ -1262,27 +1274,28 @@ static ResultCode QueryProcessMemory(Core::System& system, VAddr memory_info_add VAddr address) { LOG_TRACE(Kernel_SVC, "called process=0x{:08X} address={:X}", process_handle, address); const auto& handle_table = system.Kernel().CurrentProcess()->GetHandleTable(); - SharedPtr<Process> process = handle_table.Get<Process>(process_handle); + std::shared_ptr<Process> process = handle_table.Get<Process>(process_handle); if (!process) { LOG_ERROR(Kernel_SVC, "Process handle does not exist, process_handle=0x{:08X}", process_handle); return ERR_INVALID_HANDLE; } + auto& memory = system.Memory(); const auto& vm_manager = process->VMManager(); const MemoryInfo memory_info = vm_manager.QueryMemory(address); - Memory::Write64(memory_info_address, memory_info.base_address); - Memory::Write64(memory_info_address + 8, memory_info.size); - Memory::Write32(memory_info_address + 16, memory_info.state); - Memory::Write32(memory_info_address + 20, memory_info.attributes); - Memory::Write32(memory_info_address + 24, memory_info.permission); - Memory::Write32(memory_info_address + 32, memory_info.ipc_ref_count); - Memory::Write32(memory_info_address + 28, memory_info.device_ref_count); - Memory::Write32(memory_info_address + 36, 0); + memory.Write64(memory_info_address, memory_info.base_address); + memory.Write64(memory_info_address + 8, memory_info.size); + memory.Write32(memory_info_address + 16, memory_info.state); + memory.Write32(memory_info_address + 20, memory_info.attributes); + memory.Write32(memory_info_address + 24, memory_info.permission); + memory.Write32(memory_info_address + 32, memory_info.ipc_ref_count); + memory.Write32(memory_info_address + 28, memory_info.device_ref_count); + memory.Write32(memory_info_address + 36, 0); // Page info appears to be currently unused by the kernel and is always set to zero. - Memory::Write32(page_info_address, 0); + memory.Write32(page_info_address, 0); return RESULT_SUCCESS; } @@ -1490,7 +1503,7 @@ static ResultCode CreateThread(Core::System& system, Handle* out_handle, VAddr e } auto& kernel = system.Kernel(); - CASCADE_RESULT(SharedPtr<Thread> thread, + CASCADE_RESULT(std::shared_ptr<Thread> thread, Thread::Create(kernel, "", entry_point, priority, arg, processor_id, stack_top, *current_process)); @@ -1516,7 +1529,7 @@ static ResultCode StartThread(Core::System& system, Handle thread_handle) { LOG_DEBUG(Kernel_SVC, "called thread=0x{:08X}", thread_handle); const auto& handle_table = system.Kernel().CurrentProcess()->GetHandleTable(); - const SharedPtr<Thread> thread = handle_table.Get<Thread>(thread_handle); + const std::shared_ptr<Thread> thread = handle_table.Get<Thread>(thread_handle); if (!thread) { LOG_ERROR(Kernel_SVC, "Thread handle does not exist, thread_handle=0x{:08X}", thread_handle); @@ -1540,7 +1553,7 @@ static void ExitThread(Core::System& system) { auto* const current_thread = system.CurrentScheduler().GetCurrentThread(); current_thread->Stop(); - system.GlobalScheduler().RemoveThread(current_thread); + system.GlobalScheduler().RemoveThread(SharedFrom(current_thread)); system.PrepareReschedule(); } @@ -1612,7 +1625,7 @@ static ResultCode WaitProcessWideKeyAtomic(Core::System& system, VAddr mutex_add auto* const current_process = system.Kernel().CurrentProcess(); const auto& handle_table = current_process->GetHandleTable(); - SharedPtr<Thread> thread = handle_table.Get<Thread>(thread_handle); + std::shared_ptr<Thread> thread = handle_table.Get<Thread>(thread_handle); ASSERT(thread); const auto release_result = current_process->GetMutex().Release(mutex_addr); @@ -1620,12 +1633,13 @@ static ResultCode WaitProcessWideKeyAtomic(Core::System& system, VAddr mutex_add return release_result; } - SharedPtr<Thread> current_thread = system.CurrentScheduler().GetCurrentThread(); + Thread* current_thread = system.CurrentScheduler().GetCurrentThread(); current_thread->SetCondVarWaitAddress(condition_variable_addr); current_thread->SetMutexWaitAddress(mutex_addr); current_thread->SetWaitHandle(thread_handle); current_thread->SetStatus(ThreadStatus::WaitCondVar); current_thread->InvalidateWakeupCallback(); + current_process->InsertConditionVariableThread(SharedFrom(current_thread)); current_thread->WakeAfterDelay(nano_seconds); @@ -1644,42 +1658,28 @@ static ResultCode SignalProcessWideKey(Core::System& system, VAddr condition_var ASSERT(condition_variable_addr == Common::AlignDown(condition_variable_addr, 4)); // Retrieve a list of all threads that are waiting for this condition variable. - std::vector<SharedPtr<Thread>> waiting_threads; - const auto& scheduler = system.GlobalScheduler(); - const auto& thread_list = scheduler.GetThreadList(); - - for (const auto& thread : thread_list) { - if (thread->GetCondVarWaitAddress() == condition_variable_addr) { - waiting_threads.push_back(thread); - } - } - - // Sort them by priority, such that the highest priority ones come first. - std::sort(waiting_threads.begin(), waiting_threads.end(), - [](const SharedPtr<Thread>& lhs, const SharedPtr<Thread>& rhs) { - return lhs->GetPriority() < rhs->GetPriority(); - }); + auto* const current_process = system.Kernel().CurrentProcess(); + std::vector<std::shared_ptr<Thread>> waiting_threads = + current_process->GetConditionVariableThreads(condition_variable_addr); - // Only process up to 'target' threads, unless 'target' is -1, in which case process + // Only process up to 'target' threads, unless 'target' is less equal 0, in which case process // them all. std::size_t last = waiting_threads.size(); - if (target != -1) + if (target > 0) last = std::min(waiting_threads.size(), static_cast<std::size_t>(target)); - // If there are no threads waiting on this condition variable, just exit - if (last == 0) - return RESULT_SUCCESS; - for (std::size_t index = 0; index < last; ++index) { auto& thread = waiting_threads[index]; ASSERT(thread->GetCondVarWaitAddress() == condition_variable_addr); // liberate Cond Var Thread. + current_process->RemoveConditionVariableThread(thread); thread->SetCondVarWaitAddress(0); const std::size_t current_core = system.CurrentCoreIndex(); auto& monitor = system.Monitor(); + auto& memory = system.Memory(); // Atomically read the value of the mutex. u32 mutex_val = 0; @@ -1689,7 +1689,7 @@ static ResultCode SignalProcessWideKey(Core::System& system, VAddr condition_var monitor.SetExclusive(current_core, mutex_address); // If the mutex is not yet acquired, acquire it. - mutex_val = Memory::Read32(mutex_address); + mutex_val = memory.Read32(mutex_address); if (mutex_val != 0) { update_val = mutex_val | Mutex::MutexHasWaitersFlag; @@ -1786,7 +1786,9 @@ static u64 GetSystemTick(Core::System& system) { LOG_TRACE(Kernel_SVC, "called"); auto& core_timing = system.CoreTiming(); - const u64 result{core_timing.GetTicks()}; + + // Returns the value of cntpct_el0 (https://switchbrew.org/wiki/SVC#svcGetSystemTick) + const u64 result{Core::Timing::CpuCyclesToClockCycles(system.CoreTiming().GetTicks())}; // Advance time to defeat dumb games that busy-wait for the frame to end. core_timing.AddTicks(400); @@ -1975,7 +1977,7 @@ static ResultCode GetThreadCoreMask(Core::System& system, Handle thread_handle, LOG_TRACE(Kernel_SVC, "called, handle=0x{:08X}", thread_handle); const auto& handle_table = system.Kernel().CurrentProcess()->GetHandleTable(); - const SharedPtr<Thread> thread = handle_table.Get<Thread>(thread_handle); + const std::shared_ptr<Thread> thread = handle_table.Get<Thread>(thread_handle); if (!thread) { LOG_ERROR(Kernel_SVC, "Thread handle does not exist, thread_handle=0x{:08X}", thread_handle); @@ -2034,7 +2036,7 @@ static ResultCode SetThreadCoreMask(Core::System& system, Handle thread_handle, } const auto& handle_table = current_process->GetHandleTable(); - const SharedPtr<Thread> thread = handle_table.Get<Thread>(thread_handle); + const std::shared_ptr<Thread> thread = handle_table.Get<Thread>(thread_handle); if (!thread) { LOG_ERROR(Kernel_SVC, "Thread handle does not exist, thread_handle=0x{:08X}", thread_handle); @@ -2290,12 +2292,13 @@ static ResultCode GetProcessList(Core::System& system, u32* out_num_processes, return ERR_INVALID_ADDRESS_STATE; } + auto& memory = system.Memory(); const auto& process_list = kernel.GetProcessList(); const auto num_processes = process_list.size(); const auto copy_amount = std::min(std::size_t{out_process_ids_size}, num_processes); for (std::size_t i = 0; i < copy_amount; ++i) { - Memory::Write64(out_process_ids, process_list[i]->GetProcessID()); + memory.Write64(out_process_ids, process_list[i]->GetProcessID()); out_process_ids += sizeof(u64); } @@ -2329,13 +2332,14 @@ static ResultCode GetThreadList(Core::System& system, u32* out_num_threads, VAdd return ERR_INVALID_ADDRESS_STATE; } + auto& memory = system.Memory(); const auto& thread_list = current_process->GetThreadList(); const auto num_threads = thread_list.size(); const auto copy_amount = std::min(std::size_t{out_thread_ids_size}, num_threads); auto list_iter = thread_list.cbegin(); for (std::size_t i = 0; i < copy_amount; ++i, ++list_iter) { - Memory::Write64(out_thread_ids, (*list_iter)->GetThreadID()); + memory.Write64(out_thread_ids, (*list_iter)->GetThreadID()); out_thread_ids += sizeof(u64); } diff --git a/src/core/hle/kernel/thread.cpp b/src/core/hle/kernel/thread.cpp index 962530d2d..e84e5ce0d 100644 --- a/src/core/hle/kernel/thread.cpp +++ b/src/core/hle/kernel/thread.cpp @@ -50,7 +50,7 @@ void Thread::Stop() { // Clean up any dangling references in objects that this thread was waiting for for (auto& wait_object : wait_objects) { - wait_object->RemoveWaitingThread(this); + wait_object->RemoveWaitingThread(SharedFrom(this)); } wait_objects.clear(); @@ -77,18 +77,6 @@ void Thread::CancelWakeupTimer() { callback_handle); } -static std::optional<s32> GetNextProcessorId(u64 mask) { - for (s32 index = 0; index < Core::NUM_CPU_CORES; ++index) { - if (mask & (1ULL << index)) { - if (!Core::System::GetInstance().Scheduler(index).GetCurrentThread()) { - // Core is enabled and not running any threads, use this one - return index; - } - } - } - return {}; -} - void Thread::ResumeFromWait() { ASSERT_MSG(wait_objects.empty(), "Thread is waking up while waiting for objects"); @@ -132,8 +120,11 @@ void Thread::ResumeFromWait() { } void Thread::CancelWait() { - ASSERT(GetStatus() == ThreadStatus::WaitSynch); - ClearWaitObjects(); + if (GetSchedulingStatus() != ThreadSchedStatus::Paused) { + is_sync_cancelled = true; + return; + } + is_sync_cancelled = false; SetWaitSynchronizationResult(ERR_SYNCHRONIZATION_CANCELED); ResumeFromWait(); } @@ -156,9 +147,10 @@ static void ResetThreadContext(Core::ARM_Interface::ThreadContext& context, VAdd context.fpcr = 0x03C00000; } -ResultVal<SharedPtr<Thread>> Thread::Create(KernelCore& kernel, std::string name, VAddr entry_point, - u32 priority, u64 arg, s32 processor_id, - VAddr stack_top, Process& owner_process) { +ResultVal<std::shared_ptr<Thread>> Thread::Create(KernelCore& kernel, std::string name, + VAddr entry_point, u32 priority, u64 arg, + s32 processor_id, VAddr stack_top, + Process& owner_process) { // Check if priority is in ranged. Lowest priority -> highest priority id. if (priority > THREADPRIO_LOWEST) { LOG_ERROR(Kernel_SVC, "Invalid thread priority: {}", priority); @@ -170,14 +162,14 @@ ResultVal<SharedPtr<Thread>> Thread::Create(KernelCore& kernel, std::string name return ERR_INVALID_PROCESSOR_ID; } - if (!Memory::IsValidVirtualAddress(owner_process, entry_point)) { + auto& system = Core::System::GetInstance(); + if (!system.Memory().IsValidVirtualAddress(owner_process, entry_point)) { LOG_ERROR(Kernel_SVC, "(name={}): invalid entry {:016X}", name, entry_point); // TODO (bunnei): Find the correct error code to use here - return ResultCode(-1); + return RESULT_UNKNOWN; } - auto& system = Core::System::GetInstance(); - SharedPtr<Thread> thread(new Thread(kernel)); + std::shared_ptr<Thread> thread = std::make_shared<Thread>(kernel); thread->thread_id = kernel.CreateNewThreadID(); thread->status = ThreadStatus::Dormant; @@ -206,7 +198,7 @@ ResultVal<SharedPtr<Thread>> Thread::Create(KernelCore& kernel, std::string name // to initialize the context ResetThreadContext(thread->context, stack_top, entry_point, arg); - return MakeResult<SharedPtr<Thread>>(std::move(thread)); + return MakeResult<std::shared_ptr<Thread>>(std::move(thread)); } void Thread::SetPriority(u32 priority) { @@ -224,7 +216,7 @@ void Thread::SetWaitSynchronizationOutput(s32 output) { context.cpu_registers[1] = output; } -s32 Thread::GetWaitObjectIndex(const WaitObject* object) const { +s32 Thread::GetWaitObjectIndex(std::shared_ptr<WaitObject> object) const { ASSERT_MSG(!wait_objects.empty(), "Thread is not waiting for anything"); const auto match = std::find(wait_objects.rbegin(), wait_objects.rend(), object); return static_cast<s32>(std::distance(match, wait_objects.rend()) - 1); @@ -264,8 +256,8 @@ void Thread::SetStatus(ThreadStatus new_status) { status = new_status; } -void Thread::AddMutexWaiter(SharedPtr<Thread> thread) { - if (thread->lock_owner == this) { +void Thread::AddMutexWaiter(std::shared_ptr<Thread> thread) { + if (thread->lock_owner.get() == this) { // If the thread is already waiting for this thread to release the mutex, ensure that the // waiters list is consistent and return without doing anything. const auto iter = std::find(wait_mutex_threads.begin(), wait_mutex_threads.end(), thread); @@ -285,13 +277,13 @@ void Thread::AddMutexWaiter(SharedPtr<Thread> thread) { wait_mutex_threads.begin(), wait_mutex_threads.end(), [&thread](const auto& entry) { return entry->GetPriority() > thread->GetPriority(); }); wait_mutex_threads.insert(insertion_point, thread); - thread->lock_owner = this; + thread->lock_owner = SharedFrom(this); UpdatePriority(); } -void Thread::RemoveMutexWaiter(SharedPtr<Thread> thread) { - ASSERT(thread->lock_owner == this); +void Thread::RemoveMutexWaiter(std::shared_ptr<Thread> thread) { + ASSERT(thread->lock_owner.get() == this); // Ensure that the thread is in the list of mutex waiters const auto iter = std::find(wait_mutex_threads.begin(), wait_mutex_threads.end(), thread); @@ -318,16 +310,24 @@ void Thread::UpdatePriority() { return; } + if (GetStatus() == ThreadStatus::WaitCondVar) { + owner_process->RemoveConditionVariableThread(SharedFrom(this)); + } + SetCurrentPriority(new_priority); + if (GetStatus() == ThreadStatus::WaitCondVar) { + owner_process->InsertConditionVariableThread(SharedFrom(this)); + } + if (!lock_owner) { return; } // Ensure that the thread is within the correct location in the waiting list. auto old_owner = lock_owner; - lock_owner->RemoveMutexWaiter(this); - old_owner->AddMutexWaiter(this); + lock_owner->RemoveMutexWaiter(SharedFrom(this)); + old_owner->AddMutexWaiter(SharedFrom(this)); // Recursively update the priority of the thread that depends on the priority of this one. lock_owner->UpdatePriority(); @@ -340,11 +340,11 @@ void Thread::ChangeCore(u32 core, u64 mask) { bool Thread::AllWaitObjectsReady() const { return std::none_of( wait_objects.begin(), wait_objects.end(), - [this](const SharedPtr<WaitObject>& object) { return object->ShouldWait(this); }); + [this](const std::shared_ptr<WaitObject>& object) { return object->ShouldWait(this); }); } -bool Thread::InvokeWakeupCallback(ThreadWakeupReason reason, SharedPtr<Thread> thread, - SharedPtr<WaitObject> object, std::size_t index) { +bool Thread::InvokeWakeupCallback(ThreadWakeupReason reason, std::shared_ptr<Thread> thread, + std::shared_ptr<WaitObject> object, std::size_t index) { ASSERT(wakeup_callback); return wakeup_callback(reason, std::move(thread), std::move(object), index); } @@ -401,7 +401,7 @@ void Thread::SetCurrentPriority(u32 new_priority) { ResultCode Thread::SetCoreAndAffinityMask(s32 new_core, u64 new_affinity_mask) { const auto HighestSetCore = [](u64 mask, u32 max_cores) { - for (s32 core = max_cores - 1; core >= 0; core--) { + for (s32 core = static_cast<s32>(max_cores - 1); core >= 0; core--) { if (((mask >> core) & 1) != 0) { return core; } @@ -425,7 +425,7 @@ ResultCode Thread::SetCoreAndAffinityMask(s32 new_core, u64 new_affinity_mask) { if (old_affinity_mask != new_affinity_mask) { const s32 old_core = processor_id; if (processor_id >= 0 && ((affinity_mask >> processor_id) & 1) == 0) { - if (ideal_core < 0) { + if (static_cast<s32>(ideal_core) < 0) { processor_id = HighestSetCore(affinity_mask, GlobalScheduler::NUM_CPU_CORES); } else { processor_id = ideal_core; @@ -447,23 +447,23 @@ void Thread::AdjustSchedulingOnStatus(u32 old_flags) { ThreadSchedStatus::Runnable) { // In this case the thread was running, now it's pausing/exitting if (processor_id >= 0) { - scheduler.Unschedule(current_priority, processor_id, this); + scheduler.Unschedule(current_priority, static_cast<u32>(processor_id), this); } - for (s32 core = 0; core < GlobalScheduler::NUM_CPU_CORES; core++) { - if (core != processor_id && ((affinity_mask >> core) & 1) != 0) { - scheduler.Unsuggest(current_priority, static_cast<u32>(core), this); + for (u32 core = 0; core < GlobalScheduler::NUM_CPU_CORES; core++) { + if (core != static_cast<u32>(processor_id) && ((affinity_mask >> core) & 1) != 0) { + scheduler.Unsuggest(current_priority, core, this); } } } else if (GetSchedulingStatus() == ThreadSchedStatus::Runnable) { // The thread is now set to running from being stopped if (processor_id >= 0) { - scheduler.Schedule(current_priority, processor_id, this); + scheduler.Schedule(current_priority, static_cast<u32>(processor_id), this); } - for (s32 core = 0; core < GlobalScheduler::NUM_CPU_CORES; core++) { - if (core != processor_id && ((affinity_mask >> core) & 1) != 0) { - scheduler.Suggest(current_priority, static_cast<u32>(core), this); + for (u32 core = 0; core < GlobalScheduler::NUM_CPU_CORES; core++) { + if (core != static_cast<u32>(processor_id) && ((affinity_mask >> core) & 1) != 0) { + scheduler.Suggest(current_priority, core, this); } } } @@ -477,11 +477,11 @@ void Thread::AdjustSchedulingOnPriority(u32 old_priority) { } auto& scheduler = Core::System::GetInstance().GlobalScheduler(); if (processor_id >= 0) { - scheduler.Unschedule(old_priority, processor_id, this); + scheduler.Unschedule(old_priority, static_cast<u32>(processor_id), this); } for (u32 core = 0; core < GlobalScheduler::NUM_CPU_CORES; core++) { - if (core != processor_id && ((affinity_mask >> core) & 1) != 0) { + if (core != static_cast<u32>(processor_id) && ((affinity_mask >> core) & 1) != 0) { scheduler.Unsuggest(old_priority, core, this); } } @@ -491,14 +491,14 @@ void Thread::AdjustSchedulingOnPriority(u32 old_priority) { if (processor_id >= 0) { if (current_thread == this) { - scheduler.SchedulePrepend(current_priority, processor_id, this); + scheduler.SchedulePrepend(current_priority, static_cast<u32>(processor_id), this); } else { - scheduler.Schedule(current_priority, processor_id, this); + scheduler.Schedule(current_priority, static_cast<u32>(processor_id), this); } } for (u32 core = 0; core < GlobalScheduler::NUM_CPU_CORES; core++) { - if (core != processor_id && ((affinity_mask >> core) & 1) != 0) { + if (core != static_cast<u32>(processor_id) && ((affinity_mask >> core) & 1) != 0) { scheduler.Suggest(current_priority, core, this); } } @@ -515,7 +515,7 @@ void Thread::AdjustSchedulingOnAffinity(u64 old_affinity_mask, s32 old_core) { for (u32 core = 0; core < GlobalScheduler::NUM_CPU_CORES; core++) { if (((old_affinity_mask >> core) & 1) != 0) { - if (core == old_core) { + if (core == static_cast<u32>(old_core)) { scheduler.Unschedule(current_priority, core, this); } else { scheduler.Unsuggest(current_priority, core, this); @@ -525,7 +525,7 @@ void Thread::AdjustSchedulingOnAffinity(u64 old_affinity_mask, s32 old_core) { for (u32 core = 0; core < GlobalScheduler::NUM_CPU_CORES; core++) { if (((affinity_mask >> core) & 1) != 0) { - if (core == processor_id) { + if (core == static_cast<u32>(processor_id)) { scheduler.Schedule(current_priority, core, this); } else { scheduler.Suggest(current_priority, core, this); diff --git a/src/core/hle/kernel/thread.h b/src/core/hle/kernel/thread.h index c9870873d..3bcf9e137 100644 --- a/src/core/hle/kernel/thread.h +++ b/src/core/hle/kernel/thread.h @@ -97,14 +97,18 @@ enum class ThreadSchedMasks : u32 { class Thread final : public WaitObject { public: - using MutexWaitingThreads = std::vector<SharedPtr<Thread>>; + explicit Thread(KernelCore& kernel); + ~Thread() override; + + using MutexWaitingThreads = std::vector<std::shared_ptr<Thread>>; using ThreadContext = Core::ARM_Interface::ThreadContext; - using ThreadWaitObjects = std::vector<SharedPtr<WaitObject>>; + using ThreadWaitObjects = std::vector<std::shared_ptr<WaitObject>>; - using WakeupCallback = std::function<bool(ThreadWakeupReason reason, SharedPtr<Thread> thread, - SharedPtr<WaitObject> object, std::size_t index)>; + using WakeupCallback = + std::function<bool(ThreadWakeupReason reason, std::shared_ptr<Thread> thread, + std::shared_ptr<WaitObject> object, std::size_t index)>; /** * Creates and returns a new thread. The new thread is immediately scheduled @@ -118,10 +122,10 @@ public: * @param owner_process The parent process for the thread * @return A shared pointer to the newly created thread */ - static ResultVal<SharedPtr<Thread>> Create(KernelCore& kernel, std::string name, - VAddr entry_point, u32 priority, u64 arg, - s32 processor_id, VAddr stack_top, - Process& owner_process); + static ResultVal<std::shared_ptr<Thread>> Create(KernelCore& kernel, std::string name, + VAddr entry_point, u32 priority, u64 arg, + s32 processor_id, VAddr stack_top, + Process& owner_process); std::string GetName() const override { return name; @@ -166,10 +170,10 @@ public: void SetPriority(u32 priority); /// Adds a thread to the list of threads that are waiting for a lock held by this thread. - void AddMutexWaiter(SharedPtr<Thread> thread); + void AddMutexWaiter(std::shared_ptr<Thread> thread); /// Removes a thread from the list of threads that are waiting for a lock held by this thread. - void RemoveMutexWaiter(SharedPtr<Thread> thread); + void RemoveMutexWaiter(std::shared_ptr<Thread> thread); /// Recalculates the current priority taking into account priority inheritance. void UpdatePriority(); @@ -229,7 +233,7 @@ public: * * @param object Object to query the index of. */ - s32 GetWaitObjectIndex(const WaitObject* object) const; + s32 GetWaitObjectIndex(std::shared_ptr<WaitObject> object) const; /** * Stops a thread, invalidating it from further use @@ -320,7 +324,7 @@ public: void ClearWaitObjects() { for (const auto& waiting_object : wait_objects) { - waiting_object->RemoveWaitingThread(this); + waiting_object->RemoveWaitingThread(SharedFrom(this)); } wait_objects.clear(); } @@ -336,7 +340,7 @@ public: return lock_owner.get(); } - void SetLockOwner(SharedPtr<Thread> owner) { + void SetLockOwner(std::shared_ptr<Thread> owner) { lock_owner = std::move(owner); } @@ -390,8 +394,8 @@ public: * @pre A valid wakeup callback has been set. Violating this precondition * will cause an assertion to trigger. */ - bool InvokeWakeupCallback(ThreadWakeupReason reason, SharedPtr<Thread> thread, - SharedPtr<WaitObject> object, std::size_t index); + bool InvokeWakeupCallback(ThreadWakeupReason reason, std::shared_ptr<Thread> thread, + std::shared_ptr<WaitObject> object, std::size_t index); u32 GetIdealCore() const { return ideal_core; @@ -440,10 +444,15 @@ public: is_running = value; } -private: - explicit Thread(KernelCore& kernel); - ~Thread() override; + bool IsSyncCancelled() const { + return is_sync_cancelled; + } + void SetSyncCancelled(bool value) { + is_sync_cancelled = value; + } + +private: void SetSchedulingStatus(ThreadSchedStatus new_status); void SetCurrentPriority(u32 new_priority); ResultCode SetCoreAndAffinityMask(s32 new_core, u64 new_affinity_mask); @@ -491,7 +500,7 @@ private: MutexWaitingThreads wait_mutex_threads; /// Thread that owns the lock that this thread is waiting for. - SharedPtr<Thread> lock_owner; + std::shared_ptr<Thread> lock_owner; /// If waiting on a ConditionVariable, this is the ConditionVariable address VAddr condvar_wait_address = 0; @@ -524,6 +533,7 @@ private: u32 scheduling_state = 0; bool is_running = false; + bool is_sync_cancelled = false; std::string name; }; diff --git a/src/core/hle/kernel/transfer_memory.cpp b/src/core/hle/kernel/transfer_memory.cpp index 1113c815e..f0e73f57b 100644 --- a/src/core/hle/kernel/transfer_memory.cpp +++ b/src/core/hle/kernel/transfer_memory.cpp @@ -14,9 +14,9 @@ namespace Kernel { TransferMemory::TransferMemory(KernelCore& kernel) : Object{kernel} {} TransferMemory::~TransferMemory() = default; -SharedPtr<TransferMemory> TransferMemory::Create(KernelCore& kernel, VAddr base_address, u64 size, - MemoryPermission permissions) { - SharedPtr<TransferMemory> transfer_memory{new TransferMemory(kernel)}; +std::shared_ptr<TransferMemory> TransferMemory::Create(KernelCore& kernel, VAddr base_address, + u64 size, MemoryPermission permissions) { + std::shared_ptr<TransferMemory> transfer_memory{std::make_shared<TransferMemory>(kernel)}; transfer_memory->base_address = base_address; transfer_memory->memory_size = size; diff --git a/src/core/hle/kernel/transfer_memory.h b/src/core/hle/kernel/transfer_memory.h index 6be9dc094..556e6c62b 100644 --- a/src/core/hle/kernel/transfer_memory.h +++ b/src/core/hle/kernel/transfer_memory.h @@ -27,10 +27,13 @@ enum class MemoryPermission : u32; /// class TransferMemory final : public Object { public: + explicit TransferMemory(KernelCore& kernel); + ~TransferMemory() override; + static constexpr HandleType HANDLE_TYPE = HandleType::TransferMemory; - static SharedPtr<TransferMemory> Create(KernelCore& kernel, VAddr base_address, u64 size, - MemoryPermission permissions); + static std::shared_ptr<TransferMemory> Create(KernelCore& kernel, VAddr base_address, u64 size, + MemoryPermission permissions); TransferMemory(const TransferMemory&) = delete; TransferMemory& operator=(const TransferMemory&) = delete; @@ -79,9 +82,6 @@ public: ResultCode UnmapMemory(VAddr address, u64 size); private: - explicit TransferMemory(KernelCore& kernel); - ~TransferMemory() override; - /// Memory block backing this instance. std::shared_ptr<PhysicalMemory> backing_block; diff --git a/src/core/hle/kernel/vm_manager.cpp b/src/core/hle/kernel/vm_manager.cpp index c7af87073..a9a20ef76 100644 --- a/src/core/hle/kernel/vm_manager.cpp +++ b/src/core/hle/kernel/vm_manager.cpp @@ -16,7 +16,6 @@ #include "core/hle/kernel/resource_limit.h" #include "core/hle/kernel/vm_manager.h" #include "core/memory.h" -#include "core/memory_setup.h" namespace Kernel { namespace { @@ -167,7 +166,7 @@ ResultVal<VAddr> VMManager::FindFreeRegion(VAddr begin, VAddr end, u64 size) con if (vma_handle == vma_map.cend()) { // TODO(Subv): Find the correct error code here. - return ResultCode(-1); + return RESULT_UNKNOWN; } const VAddr target = std::max(begin, vma_handle->second.base); @@ -786,19 +785,21 @@ void VMManager::MergeAdjacentVMA(VirtualMemoryArea& left, const VirtualMemoryAre } void VMManager::UpdatePageTableForVMA(const VirtualMemoryArea& vma) { + auto& memory = system.Memory(); + switch (vma.type) { case VMAType::Free: - Memory::UnmapRegion(page_table, vma.base, vma.size); + memory.UnmapRegion(page_table, vma.base, vma.size); break; case VMAType::AllocatedMemoryBlock: - Memory::MapMemoryRegion(page_table, vma.base, vma.size, - vma.backing_block->data() + vma.offset); + memory.MapMemoryRegion(page_table, vma.base, vma.size, + vma.backing_block->data() + vma.offset); break; case VMAType::BackingMemory: - Memory::MapMemoryRegion(page_table, vma.base, vma.size, vma.backing_memory); + memory.MapMemoryRegion(page_table, vma.base, vma.size, vma.backing_memory); break; case VMAType::MMIO: - Memory::MapIoRegion(page_table, vma.base, vma.size, vma.mmio_handler); + memory.MapIoRegion(page_table, vma.base, vma.size, vma.mmio_handler); break; } } diff --git a/src/core/hle/kernel/wait_object.cpp b/src/core/hle/kernel/wait_object.cpp index c00cef062..745f2c4e8 100644 --- a/src/core/hle/kernel/wait_object.cpp +++ b/src/core/hle/kernel/wait_object.cpp @@ -18,13 +18,13 @@ namespace Kernel { WaitObject::WaitObject(KernelCore& kernel) : Object{kernel} {} WaitObject::~WaitObject() = default; -void WaitObject::AddWaitingThread(SharedPtr<Thread> thread) { +void WaitObject::AddWaitingThread(std::shared_ptr<Thread> thread) { auto itr = std::find(waiting_threads.begin(), waiting_threads.end(), thread); if (itr == waiting_threads.end()) waiting_threads.push_back(std::move(thread)); } -void WaitObject::RemoveWaitingThread(Thread* thread) { +void WaitObject::RemoveWaitingThread(std::shared_ptr<Thread> thread) { auto itr = std::find(waiting_threads.begin(), waiting_threads.end(), thread); // If a thread passed multiple handles to the same object, // the kernel might attempt to remove the thread from the object's @@ -33,7 +33,7 @@ void WaitObject::RemoveWaitingThread(Thread* thread) { waiting_threads.erase(itr); } -SharedPtr<Thread> WaitObject::GetHighestPriorityReadyThread() const { +std::shared_ptr<Thread> WaitObject::GetHighestPriorityReadyThread() const { Thread* candidate = nullptr; u32 candidate_priority = THREADPRIO_LOWEST + 1; @@ -64,10 +64,10 @@ SharedPtr<Thread> WaitObject::GetHighestPriorityReadyThread() const { } } - return candidate; + return SharedFrom(candidate); } -void WaitObject::WakeupWaitingThread(SharedPtr<Thread> thread) { +void WaitObject::WakeupWaitingThread(std::shared_ptr<Thread> thread) { ASSERT(!ShouldWait(thread.get())); if (!thread) { @@ -83,7 +83,7 @@ void WaitObject::WakeupWaitingThread(SharedPtr<Thread> thread) { Acquire(thread.get()); } - const std::size_t index = thread->GetWaitObjectIndex(this); + const std::size_t index = thread->GetWaitObjectIndex(SharedFrom(this)); thread->ClearWaitObjects(); @@ -91,7 +91,8 @@ void WaitObject::WakeupWaitingThread(SharedPtr<Thread> thread) { bool resume = true; if (thread->HasWakeupCallback()) { - resume = thread->InvokeWakeupCallback(ThreadWakeupReason::Signal, thread, this, index); + resume = thread->InvokeWakeupCallback(ThreadWakeupReason::Signal, thread, SharedFrom(this), + index); } if (resume) { thread->ResumeFromWait(); @@ -105,7 +106,7 @@ void WaitObject::WakeupAllWaitingThreads() { } } -const std::vector<SharedPtr<Thread>>& WaitObject::GetWaitingThreads() const { +const std::vector<std::shared_ptr<Thread>>& WaitObject::GetWaitingThreads() const { return waiting_threads; } diff --git a/src/core/hle/kernel/wait_object.h b/src/core/hle/kernel/wait_object.h index 3271a30a7..f9d596db9 100644 --- a/src/core/hle/kernel/wait_object.h +++ b/src/core/hle/kernel/wait_object.h @@ -33,13 +33,13 @@ public: * Add a thread to wait on this object * @param thread Pointer to thread to add */ - void AddWaitingThread(SharedPtr<Thread> thread); + void AddWaitingThread(std::shared_ptr<Thread> thread); /** * Removes a thread from waiting on this object (e.g. if it was resumed already) * @param thread Pointer to thread to remove */ - void RemoveWaitingThread(Thread* thread); + void RemoveWaitingThread(std::shared_ptr<Thread> thread); /** * Wake up all threads waiting on this object that can be awoken, in priority order, @@ -51,24 +51,24 @@ public: * Wakes up a single thread waiting on this object. * @param thread Thread that is waiting on this object to wakeup. */ - void WakeupWaitingThread(SharedPtr<Thread> thread); + void WakeupWaitingThread(std::shared_ptr<Thread> thread); /// Obtains the highest priority thread that is ready to run from this object's waiting list. - SharedPtr<Thread> GetHighestPriorityReadyThread() const; + std::shared_ptr<Thread> GetHighestPriorityReadyThread() const; /// Get a const reference to the waiting threads list for debug use - const std::vector<SharedPtr<Thread>>& GetWaitingThreads() const; + const std::vector<std::shared_ptr<Thread>>& GetWaitingThreads() const; private: /// Threads waiting for this object to become available - std::vector<SharedPtr<Thread>> waiting_threads; + std::vector<std::shared_ptr<Thread>> waiting_threads; }; // Specialization of DynamicObjectCast for WaitObjects template <> -inline SharedPtr<WaitObject> DynamicObjectCast<WaitObject>(SharedPtr<Object> object) { +inline std::shared_ptr<WaitObject> DynamicObjectCast<WaitObject>(std::shared_ptr<Object> object) { if (object != nullptr && object->IsWaitable()) { - return boost::static_pointer_cast<WaitObject>(object); + return std::static_pointer_cast<WaitObject>(object); } return nullptr; } diff --git a/src/core/hle/kernel/writable_event.cpp b/src/core/hle/kernel/writable_event.cpp index c783a34ee..c9332e3e1 100644 --- a/src/core/hle/kernel/writable_event.cpp +++ b/src/core/hle/kernel/writable_event.cpp @@ -16,8 +16,8 @@ WritableEvent::WritableEvent(KernelCore& kernel) : Object{kernel} {} WritableEvent::~WritableEvent() = default; EventPair WritableEvent::CreateEventPair(KernelCore& kernel, std::string name) { - SharedPtr<WritableEvent> writable_event(new WritableEvent(kernel)); - SharedPtr<ReadableEvent> readable_event(new ReadableEvent(kernel)); + std::shared_ptr<WritableEvent> writable_event(new WritableEvent(kernel)); + std::shared_ptr<ReadableEvent> readable_event(new ReadableEvent(kernel)); writable_event->name = name + ":Writable"; writable_event->readable = readable_event; @@ -27,7 +27,7 @@ EventPair WritableEvent::CreateEventPair(KernelCore& kernel, std::string name) { return {std::move(readable_event), std::move(writable_event)}; } -SharedPtr<ReadableEvent> WritableEvent::GetReadableEvent() const { +std::shared_ptr<ReadableEvent> WritableEvent::GetReadableEvent() const { return readable; } diff --git a/src/core/hle/kernel/writable_event.h b/src/core/hle/kernel/writable_event.h index f46cf1dd8..afe97f3a9 100644 --- a/src/core/hle/kernel/writable_event.h +++ b/src/core/hle/kernel/writable_event.h @@ -13,8 +13,8 @@ class ReadableEvent; class WritableEvent; struct EventPair { - SharedPtr<ReadableEvent> readable; - SharedPtr<WritableEvent> writable; + std::shared_ptr<ReadableEvent> readable; + std::shared_ptr<WritableEvent> writable; }; class WritableEvent final : public Object { @@ -40,7 +40,7 @@ public: return HANDLE_TYPE; } - SharedPtr<ReadableEvent> GetReadableEvent() const; + std::shared_ptr<ReadableEvent> GetReadableEvent() const; void Signal(); void Clear(); @@ -49,7 +49,7 @@ public: private: explicit WritableEvent(KernelCore& kernel); - SharedPtr<ReadableEvent> readable; + std::shared_ptr<ReadableEvent> readable; std::string name; ///< Name of event (optional) }; diff --git a/src/core/hle/result.h b/src/core/hle/result.h index 8a3701151..450f61fea 100644 --- a/src/core/hle/result.h +++ b/src/core/hle/result.h @@ -147,6 +147,14 @@ constexpr bool operator!=(const ResultCode& a, const ResultCode& b) { constexpr ResultCode RESULT_SUCCESS(0); /** + * Placeholder result code used for unknown error codes. + * + * @note This should only be used when a particular error code + * is not known yet. + */ +constexpr ResultCode RESULT_UNKNOWN(UINT32_MAX); + +/** * This is an optional value type. It holds a `ResultCode` and, if that code is a success code, * also holds a result of type `T`. If the code is an error code then trying to access the inner * value fails, thus ensuring that the ResultCode of functions is always checked properly before @@ -183,7 +191,7 @@ class ResultVal { public: /// Constructs an empty `ResultVal` with the given error code. The code must not be a success /// code. - ResultVal(ResultCode error_code = ResultCode(-1)) : result_code(error_code) { + ResultVal(ResultCode error_code = RESULT_UNKNOWN) : result_code(error_code) { ASSERT(error_code.IsError()); } diff --git a/src/core/hle/service/acc/acc.cpp b/src/core/hle/service/acc/acc.cpp index 0c0f7ed6e..7e3e311fb 100644 --- a/src/core/hle/service/acc/acc.cpp +++ b/src/core/hle/service/acc/acc.cpp @@ -84,7 +84,7 @@ protected: LOG_ERROR(Service_ACC, "Failed to get profile base and data for user={}", user_id.Format()); IPC::ResponseBuilder rb{ctx, 2}; - rb.Push(ResultCode(-1)); // TODO(ogniK): Get actual error code + rb.Push(RESULT_UNKNOWN); // TODO(ogniK): Get actual error code } } @@ -98,7 +98,7 @@ protected: } else { LOG_ERROR(Service_ACC, "Failed to get profile base for user={}", user_id.Format()); IPC::ResponseBuilder rb{ctx, 2}; - rb.Push(ResultCode(-1)); // TODO(ogniK): Get actual error code + rb.Push(RESULT_UNKNOWN); // TODO(ogniK): Get actual error code } } @@ -442,7 +442,7 @@ void Module::Interface::TrySelectUserWithoutInteraction(Kernel::HLERequestContex const auto user_list = profile_manager->GetAllUsers(); if (std::all_of(user_list.begin(), user_list.end(), [](const auto& user) { return user.uuid == Common::INVALID_UUID; })) { - rb.Push(ResultCode(-1)); // TODO(ogniK): Find the correct error code + rb.Push(RESULT_UNKNOWN); // TODO(ogniK): Find the correct error code rb.PushRaw<u128>(Common::INVALID_UUID); return; } diff --git a/src/core/hle/service/acc/acc_su.cpp b/src/core/hle/service/acc/acc_su.cpp index 0d1663657..b941c260b 100644 --- a/src/core/hle/service/acc/acc_su.cpp +++ b/src/core/hle/service/acc/acc_su.cpp @@ -28,6 +28,7 @@ ACC_SU::ACC_SU(std::shared_ptr<Module> module, std::shared_ptr<ProfileManager> p {103, nullptr, "GetBaasUserAvailabilityChangeNotifier"}, {104, nullptr, "GetProfileUpdateNotifier"}, {105, nullptr, "CheckNetworkServiceAvailabilityAsync"}, + {106, nullptr, "GetProfileSyncNotifier"}, {110, nullptr, "StoreSaveDataThumbnail"}, {111, nullptr, "ClearSaveDataThumbnail"}, {112, nullptr, "LoadSaveDataThumbnail"}, @@ -44,6 +45,8 @@ ACC_SU::ACC_SU(std::shared_ptr<Module> module, std::shared_ptr<ProfileManager> p {205, &ACC_SU::GetProfileEditor, "GetProfileEditor"}, {206, nullptr, "CompleteUserRegistrationForcibly"}, {210, nullptr, "CreateFloatingRegistrationRequest"}, + {211, nullptr, "CreateProcedureToRegisterUserWithNintendoAccount"}, + {212, nullptr, "ResumeProcedureToRegisterUserWithNintendoAccount"}, {230, nullptr, "AuthenticateServiceAsync"}, {250, nullptr, "GetBaasAccountAdministrator"}, {290, nullptr, "ProxyProcedureForGuestLoginWithNintendoAccount"}, diff --git a/src/core/hle/service/acc/acc_u1.cpp b/src/core/hle/service/acc/acc_u1.cpp index 6520b3968..858e91dde 100644 --- a/src/core/hle/service/acc/acc_u1.cpp +++ b/src/core/hle/service/acc/acc_u1.cpp @@ -28,6 +28,7 @@ ACC_U1::ACC_U1(std::shared_ptr<Module> module, std::shared_ptr<ProfileManager> p {103, nullptr, "GetProfileUpdateNotifier"}, {104, nullptr, "CheckNetworkServiceAvailabilityAsync"}, {105, nullptr, "GetBaasUserAvailabilityChangeNotifier"}, + {106, nullptr, "GetProfileSyncNotifier"}, {110, nullptr, "StoreSaveDataThumbnail"}, {111, nullptr, "ClearSaveDataThumbnail"}, {112, nullptr, "LoadSaveDataThumbnail"}, diff --git a/src/core/hle/service/acc/profile_manager.cpp b/src/core/hle/service/acc/profile_manager.cpp index 8f9986326..3e756e59e 100644 --- a/src/core/hle/service/acc/profile_manager.cpp +++ b/src/core/hle/service/acc/profile_manager.cpp @@ -31,8 +31,8 @@ struct ProfileDataRaw { static_assert(sizeof(ProfileDataRaw) == 0x650, "ProfileDataRaw has incorrect size."); // TODO(ogniK): Get actual error codes -constexpr ResultCode ERROR_TOO_MANY_USERS(ErrorModule::Account, -1); -constexpr ResultCode ERROR_USER_ALREADY_EXISTS(ErrorModule::Account, -2); +constexpr ResultCode ERROR_TOO_MANY_USERS(ErrorModule::Account, u32(-1)); +constexpr ResultCode ERROR_USER_ALREADY_EXISTS(ErrorModule::Account, u32(-2)); constexpr ResultCode ERROR_ARGUMENT_IS_NULL(ErrorModule::Account, 20); constexpr char ACC_SAVE_AVATORS_BASE_PATH[] = "/system/save/8000000000000010/su/avators/"; diff --git a/src/core/hle/service/am/am.cpp b/src/core/hle/service/am/am.cpp index ba54b3040..95aa5d23d 100644 --- a/src/core/hle/service/am/am.cpp +++ b/src/core/hle/service/am/am.cpp @@ -229,7 +229,15 @@ IDebugFunctions::IDebugFunctions() : ServiceFramework{"IDebugFunctions"} { {20, nullptr, "InvalidateTransitionLayer"}, {30, nullptr, "RequestLaunchApplicationWithUserAndArgumentForDebug"}, {40, nullptr, "GetAppletResourceUsageInfo"}, - {41, nullptr, "SetCpuBoostModeForApplet"}, + {100, nullptr, "SetCpuBoostModeForApplet"}, + {110, nullptr, "PushToAppletBoundChannelForDebug"}, + {111, nullptr, "TryPopFromAppletBoundChannelForDebug"}, + {120, nullptr, "AlarmSettingNotificationEnableAppEventReserve"}, + {121, nullptr, "AlarmSettingNotificationDisableAppEventReserve"}, + {122, nullptr, "AlarmSettingNotificationPushAppEventNotify"}, + {130, nullptr, "FriendInvitationSetApplicationParameter"}, + {131, nullptr, "FriendInvitationClearApplicationParameter"}, + {132, nullptr, "FriendInvitationPushApplicationParameter"}, }; // clang-format on @@ -278,10 +286,12 @@ ISelfController::ISelfController(Core::System& system, {69, &ISelfController::IsAutoSleepDisabled, "IsAutoSleepDisabled"}, {70, nullptr, "ReportMultimediaError"}, {71, nullptr, "GetCurrentIlluminanceEx"}, + {72, nullptr, "SetInputDetectionPolicy"}, {80, nullptr, "SetWirelessPriorityMode"}, {90, &ISelfController::GetAccumulatedSuspendedTickValue, "GetAccumulatedSuspendedTickValue"}, {91, &ISelfController::GetAccumulatedSuspendedTickChangedEvent, "GetAccumulatedSuspendedTickChangedEvent"}, {100, nullptr, "SetAlbumImageTakenNotificationEnabled"}, + {110, nullptr, "SetApplicationAlbumUserData"}, {1000, nullptr, "GetDebugStorageChannel"}, }; // clang-format on @@ -531,12 +541,11 @@ AppletMessageQueue::AppletMessageQueue(Kernel::KernelCore& kernel) { AppletMessageQueue::~AppletMessageQueue() = default; -const Kernel::SharedPtr<Kernel::ReadableEvent>& AppletMessageQueue::GetMesssageRecieveEvent() - const { +const std::shared_ptr<Kernel::ReadableEvent>& AppletMessageQueue::GetMesssageRecieveEvent() const { return on_new_message.readable; } -const Kernel::SharedPtr<Kernel::ReadableEvent>& AppletMessageQueue::GetOperationModeChangedEvent() +const std::shared_ptr<Kernel::ReadableEvent>& AppletMessageQueue::GetOperationModeChangedEvent() const { return on_operation_mode_changed.readable; } @@ -613,6 +622,7 @@ ICommonStateGetter::ICommonStateGetter(Core::System& system, {90, nullptr, "SetPerformanceConfigurationChangedNotification"}, {91, nullptr, "GetCurrentPerformanceConfiguration"}, {200, nullptr, "GetOperationModeSystemInfo"}, + {300, nullptr, "GetSettingsPlatformRegion"}, }; // clang-format on @@ -991,7 +1001,7 @@ void ILibraryAppletCreator::CreateLibraryApplet(Kernel::HLERequestContext& ctx) LOG_ERROR(Service_AM, "Applet doesn't exist! applet_id={}", static_cast<u32>(applet_id)); IPC::ResponseBuilder rb{ctx, 2}; - rb.Push(ResultCode(-1)); + rb.Push(RESULT_UNKNOWN); return; } @@ -1027,7 +1037,7 @@ void ILibraryAppletCreator::CreateTransferMemoryStorage(Kernel::HLERequestContex if (transfer_mem == nullptr) { LOG_ERROR(Service_AM, "shared_mem is a nullpr for handle={:08X}", handle); IPC::ResponseBuilder rb{ctx, 2}; - rb.Push(ResultCode(-1)); + rb.Push(RESULT_UNKNOWN); return; } @@ -1076,11 +1086,18 @@ IApplicationFunctions::IApplicationFunctions(Core::System& system_) {100, &IApplicationFunctions::InitializeApplicationCopyrightFrameBuffer, "InitializeApplicationCopyrightFrameBuffer"}, {101, &IApplicationFunctions::SetApplicationCopyrightImage, "SetApplicationCopyrightImage"}, {102, &IApplicationFunctions::SetApplicationCopyrightVisibility, "SetApplicationCopyrightVisibility"}, - {110, nullptr, "QueryApplicationPlayStatistics"}, + {110, &IApplicationFunctions::QueryApplicationPlayStatistics, "QueryApplicationPlayStatistics"}, + {111, &IApplicationFunctions::QueryApplicationPlayStatisticsByUid, "QueryApplicationPlayStatisticsByUid"}, {120, nullptr, "ExecuteProgram"}, {121, nullptr, "ClearUserChannel"}, {122, nullptr, "UnpopToUserChannel"}, {130, &IApplicationFunctions::GetGpuErrorDetectedSystemEvent, "GetGpuErrorDetectedSystemEvent"}, + {140, nullptr, "GetFriendInvitationStorageChannelEvent"}, + {141, nullptr, "TryPopFromFriendInvitationStorageChannel"}, + {150, nullptr, "GetNotificationStorageChannelEvent"}, + {151, nullptr, "TryPopFromNotificationStorageChannel"}, + {160, nullptr, "GetHealthWarningDisappearedSystemEvent"}, + {170, nullptr, "SetHdcpAuthenticationActivated"}, {500, nullptr, "StartContinuousRecordingFlushForDebug"}, {1000, nullptr, "CreateMovieMaker"}, {1001, nullptr, "PrepareForJit"}, @@ -1335,12 +1352,16 @@ void IApplicationFunctions::GetPseudoDeviceId(Kernel::HLERequestContext& ctx) { } void IApplicationFunctions::ExtendSaveData(Kernel::HLERequestContext& ctx) { + struct Parameters { + FileSys::SaveDataType type; + u128 user_id; + u64 new_normal_size; + u64 new_journal_size; + }; + static_assert(sizeof(Parameters) == 40); + IPC::RequestParser rp{ctx}; - const auto type{rp.PopRaw<FileSys::SaveDataType>()}; - rp.Skip(1, false); - const auto user_id{rp.PopRaw<u128>()}; - const auto new_normal_size{rp.PopRaw<u64>()}; - const auto new_journal_size{rp.PopRaw<u64>()}; + const auto [type, user_id, new_normal_size, new_journal_size] = rp.PopRaw<Parameters>(); LOG_DEBUG(Service_AM, "called with type={:02X}, user_id={:016X}{:016X}, new_normal={:016X}, " @@ -1359,10 +1380,14 @@ void IApplicationFunctions::ExtendSaveData(Kernel::HLERequestContext& ctx) { } void IApplicationFunctions::GetSaveDataSize(Kernel::HLERequestContext& ctx) { + struct Parameters { + FileSys::SaveDataType type; + u128 user_id; + }; + static_assert(sizeof(Parameters) == 24); + IPC::RequestParser rp{ctx}; - const auto type{rp.PopRaw<FileSys::SaveDataType>()}; - rp.Skip(1, false); - const auto user_id{rp.PopRaw<u128>()}; + const auto [type, user_id] = rp.PopRaw<Parameters>(); LOG_DEBUG(Service_AM, "called with type={:02X}, user_id={:016X}{:016X}", static_cast<u8>(type), user_id[1], user_id[0]); @@ -1376,6 +1401,22 @@ void IApplicationFunctions::GetSaveDataSize(Kernel::HLERequestContext& ctx) { rb.Push(size.journal); } +void IApplicationFunctions::QueryApplicationPlayStatistics(Kernel::HLERequestContext& ctx) { + LOG_WARNING(Service_AM, "(STUBBED) called"); + + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(RESULT_SUCCESS); + rb.Push<u32>(0); +} + +void IApplicationFunctions::QueryApplicationPlayStatisticsByUid(Kernel::HLERequestContext& ctx) { + LOG_WARNING(Service_AM, "(STUBBED) called"); + + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(RESULT_SUCCESS); + rb.Push<u32>(0); +} + void IApplicationFunctions::GetGpuErrorDetectedSystemEvent(Kernel::HLERequestContext& ctx) { LOG_WARNING(Service_AM, "(STUBBED) called"); @@ -1409,6 +1450,8 @@ IHomeMenuFunctions::IHomeMenuFunctions() : ServiceFramework("IHomeMenuFunctions" {30, nullptr, "GetHomeButtonWriterLockAccessor"}, {31, nullptr, "GetWriterLockAccessorEx"}, {100, nullptr, "PopRequestLaunchApplicationForDebug"}, + {110, nullptr, "IsForceTerminateApplicationDisabledForDebug"}, + {200, nullptr, "LaunchDevMenu"}, }; // clang-format on diff --git a/src/core/hle/service/am/am.h b/src/core/hle/service/am/am.h index 2ae9402a8..448817be9 100644 --- a/src/core/hle/service/am/am.h +++ b/src/core/hle/service/am/am.h @@ -54,8 +54,8 @@ public: explicit AppletMessageQueue(Kernel::KernelCore& kernel); ~AppletMessageQueue(); - const Kernel::SharedPtr<Kernel::ReadableEvent>& GetMesssageRecieveEvent() const; - const Kernel::SharedPtr<Kernel::ReadableEvent>& GetOperationModeChangedEvent() const; + const std::shared_ptr<Kernel::ReadableEvent>& GetMesssageRecieveEvent() const; + const std::shared_ptr<Kernel::ReadableEvent>& GetOperationModeChangedEvent() const; void PushMessage(AppletMessage msg); AppletMessage PopMessage(); std::size_t GetMessageCount() const; @@ -255,6 +255,8 @@ private: void InitializeApplicationCopyrightFrameBuffer(Kernel::HLERequestContext& ctx); void SetApplicationCopyrightImage(Kernel::HLERequestContext& ctx); void SetApplicationCopyrightVisibility(Kernel::HLERequestContext& ctx); + void QueryApplicationPlayStatistics(Kernel::HLERequestContext& ctx); + void QueryApplicationPlayStatisticsByUid(Kernel::HLERequestContext& ctx); void GetGpuErrorDetectedSystemEvent(Kernel::HLERequestContext& ctx); bool launch_popped_application_specific = false; diff --git a/src/core/hle/service/am/applets/applets.cpp b/src/core/hle/service/am/applets/applets.cpp index 673ad1f7f..92f995f8f 100644 --- a/src/core/hle/service/am/applets/applets.cpp +++ b/src/core/hle/service/am/applets/applets.cpp @@ -108,15 +108,15 @@ void AppletDataBroker::SignalStateChanged() const { state_changed_event.writable->Signal(); } -Kernel::SharedPtr<Kernel::ReadableEvent> AppletDataBroker::GetNormalDataEvent() const { +std::shared_ptr<Kernel::ReadableEvent> AppletDataBroker::GetNormalDataEvent() const { return pop_out_data_event.readable; } -Kernel::SharedPtr<Kernel::ReadableEvent> AppletDataBroker::GetInteractiveDataEvent() const { +std::shared_ptr<Kernel::ReadableEvent> AppletDataBroker::GetInteractiveDataEvent() const { return pop_interactive_out_data_event.readable; } -Kernel::SharedPtr<Kernel::ReadableEvent> AppletDataBroker::GetStateChangedEvent() const { +std::shared_ptr<Kernel::ReadableEvent> AppletDataBroker::GetStateChangedEvent() const { return state_changed_event.readable; } diff --git a/src/core/hle/service/am/applets/applets.h b/src/core/hle/service/am/applets/applets.h index 226be88b1..16e61fc6f 100644 --- a/src/core/hle/service/am/applets/applets.h +++ b/src/core/hle/service/am/applets/applets.h @@ -86,9 +86,9 @@ public: void SignalStateChanged() const; - Kernel::SharedPtr<Kernel::ReadableEvent> GetNormalDataEvent() const; - Kernel::SharedPtr<Kernel::ReadableEvent> GetInteractiveDataEvent() const; - Kernel::SharedPtr<Kernel::ReadableEvent> GetStateChangedEvent() const; + std::shared_ptr<Kernel::ReadableEvent> GetNormalDataEvent() const; + std::shared_ptr<Kernel::ReadableEvent> GetInteractiveDataEvent() const; + std::shared_ptr<Kernel::ReadableEvent> GetStateChangedEvent() const; private: // Queues are named from applet's perspective diff --git a/src/core/hle/service/am/applets/web_browser.cpp b/src/core/hle/service/am/applets/web_browser.cpp index 32283e819..5546ef6e8 100644 --- a/src/core/hle/service/am/applets/web_browser.cpp +++ b/src/core/hle/service/am/applets/web_browser.cpp @@ -337,7 +337,7 @@ void WebBrowser::ExecuteInternal() { void WebBrowser::InitializeShop() { if (frontend_e_commerce == nullptr) { LOG_ERROR(Service_AM, "Missing ECommerce Applet frontend!"); - status = ResultCode(-1); + status = RESULT_UNKNOWN; return; } @@ -353,7 +353,7 @@ void WebBrowser::InitializeShop() { if (url == args.end()) { LOG_ERROR(Service_AM, "Missing EShop Arguments URL for initialization!"); - status = ResultCode(-1); + status = RESULT_UNKNOWN; return; } @@ -366,7 +366,7 @@ void WebBrowser::InitializeShop() { // Less is missing info, More is malformed if (split_query.size() != 2) { LOG_ERROR(Service_AM, "EShop Arguments has more than one question mark, malformed"); - status = ResultCode(-1); + status = RESULT_UNKNOWN; return; } @@ -390,7 +390,7 @@ void WebBrowser::InitializeShop() { if (scene == shop_query.end()) { LOG_ERROR(Service_AM, "No scene parameter was passed via shop query!"); - status = ResultCode(-1); + status = RESULT_UNKNOWN; return; } @@ -406,7 +406,7 @@ void WebBrowser::InitializeShop() { const auto target = target_map.find(scene->second); if (target == target_map.end()) { LOG_ERROR(Service_AM, "Scene for shop query is invalid! (scene={})", scene->second); - status = ResultCode(-1); + status = RESULT_UNKNOWN; return; } @@ -427,7 +427,7 @@ void WebBrowser::InitializeOffline() { if (args.find(WebArgTLVType::DocumentPath) == args.end() || args.find(WebArgTLVType::DocumentKind) == args.end() || args.find(WebArgTLVType::ApplicationID) == args.end()) { - status = ResultCode(-1); + status = RESULT_UNKNOWN; LOG_ERROR(Service_AM, "Missing necessary parameters for initialization!"); } @@ -476,7 +476,7 @@ void WebBrowser::InitializeOffline() { offline_romfs = GetApplicationRomFS(system, title_id, type); if (offline_romfs == nullptr) { - status = ResultCode(-1); + status = RESULT_UNKNOWN; LOG_ERROR(Service_AM, "Failed to find offline data for request!"); } @@ -496,7 +496,7 @@ void WebBrowser::ExecuteShop() { const auto check_optional_parameter = [this](const auto& p) { if (!p.has_value()) { LOG_ERROR(Service_AM, "Missing one or more necessary parameters for execution!"); - status = ResultCode(-1); + status = RESULT_UNKNOWN; return false; } diff --git a/src/core/hle/service/am/idle.cpp b/src/core/hle/service/am/idle.cpp index f814fe2c0..d256d57c8 100644 --- a/src/core/hle/service/am/idle.cpp +++ b/src/core/hle/service/am/idle.cpp @@ -10,7 +10,7 @@ IdleSys::IdleSys() : ServiceFramework{"idle:sys"} { // clang-format off static const FunctionInfo functions[] = { {0, nullptr, "GetAutoPowerDownEvent"}, - {1, nullptr, "Unknown1"}, + {1, nullptr, "IsAutoPowerDownRequested"}, {2, nullptr, "Unknown2"}, {3, nullptr, "SetHandlingContext"}, {4, nullptr, "LoadAndApplySettings"}, diff --git a/src/core/hle/service/am/omm.cpp b/src/core/hle/service/am/omm.cpp index 6ab3fb906..37389ccda 100644 --- a/src/core/hle/service/am/omm.cpp +++ b/src/core/hle/service/am/omm.cpp @@ -35,6 +35,8 @@ OMM::OMM() : ServiceFramework{"omm"} { {23, nullptr, "GetHdcpState"}, {24, nullptr, "ShowCardUpdateProcessing"}, {25, nullptr, "SetApplicationCecSettingsAndNotifyChanged"}, + {26, nullptr, "GetOperationModeSystemInfo"}, + {27, nullptr, "GetAppletFullAwakingSystemEvent"}, }; // clang-format on diff --git a/src/core/hle/service/aoc/aoc_u.cpp b/src/core/hle/service/aoc/aoc_u.cpp index f36ccbc49..4227a4adf 100644 --- a/src/core/hle/service/aoc/aoc_u.cpp +++ b/src/core/hle/service/aoc/aoc_u.cpp @@ -61,6 +61,7 @@ AOC_U::AOC_U(Core::System& system) {7, &AOC_U::PrepareAddOnContent, "PrepareAddOnContent"}, {8, &AOC_U::GetAddOnContentListChangedEvent, "GetAddOnContentListChangedEvent"}, {100, nullptr, "CreateEcPurchasedEventManager"}, + {101, nullptr, "CreatePermanentEcPurchasedEventManager"}, }; // clang-format on @@ -131,7 +132,7 @@ void AOC_U::ListAddOnContent(Kernel::HLERequestContext& ctx) { if (out.size() < offset) { IPC::ResponseBuilder rb{ctx, 2}; // TODO(DarkLordZach): Find the correct error code. - rb.Push(ResultCode(-1)); + rb.Push(RESULT_UNKNOWN); return; } diff --git a/src/core/hle/service/audio/audctl.cpp b/src/core/hle/service/audio/audctl.cpp index 6a01d4d29..9e08e5346 100644 --- a/src/core/hle/service/audio/audctl.cpp +++ b/src/core/hle/service/audio/audctl.cpp @@ -38,6 +38,7 @@ AudCtl::AudCtl() : ServiceFramework{"audctl"} { {24, nullptr, "GetSystemOutputMasterVolume"}, {25, nullptr, "GetAudioVolumeDataForPlayReport"}, {26, nullptr, "UpdateHeadphoneSettings"}, + {27, nullptr, "SetVolumeMappingTableForDev"}, }; // clang-format on diff --git a/src/core/hle/service/audio/audout_u.cpp b/src/core/hle/service/audio/audout_u.cpp index 6a29377e3..4fb2cbc4b 100644 --- a/src/core/hle/service/audio/audout_u.cpp +++ b/src/core/hle/service/audio/audout_u.cpp @@ -43,7 +43,8 @@ public: IAudioOut(Core::System& system, AudoutParams audio_params, AudioCore::AudioOut& audio_core, std::string&& device_name, std::string&& unique_name) : ServiceFramework("IAudioOut"), audio_core(audio_core), - device_name(std::move(device_name)), audio_params(audio_params) { + device_name(std::move(device_name)), + audio_params(audio_params), main_memory{system.Memory()} { // clang-format off static const FunctionInfo functions[] = { {0, &IAudioOut::GetAudioOutState, "GetAudioOutState"}, @@ -137,7 +138,7 @@ private: const u64 tag{rp.Pop<u64>()}; std::vector<s16> samples(audio_buffer.buffer_size / sizeof(s16)); - Memory::ReadBlock(audio_buffer.buffer, samples.data(), audio_buffer.buffer_size); + main_memory.ReadBlock(audio_buffer.buffer, samples.data(), audio_buffer.buffer_size); if (!audio_core.QueueBuffer(stream, tag, std::move(samples))) { IPC::ResponseBuilder rb{ctx, 2}; @@ -209,6 +210,7 @@ private: /// This is the event handle used to check if the audio buffer was released Kernel::EventPair buffer_event; + Memory::Memory& main_memory; }; AudOutU::AudOutU(Core::System& system_) : ServiceFramework("audout:u"), system{system_} { diff --git a/src/core/hle/service/audio/audren_u.cpp b/src/core/hle/service/audio/audren_u.cpp index 4ea7ade6e..82a5dbf14 100644 --- a/src/core/hle/service/audio/audren_u.cpp +++ b/src/core/hle/service/audio/audren_u.cpp @@ -49,8 +49,9 @@ public: system_event = Kernel::WritableEvent::CreateEventPair(system.Kernel(), "IAudioRenderer:SystemEvent"); - renderer = std::make_unique<AudioCore::AudioRenderer>( - system.CoreTiming(), audren_params, system_event.writable, instance_number); + renderer = std::make_unique<AudioCore::AudioRenderer>(system.CoreTiming(), system.Memory(), + audren_params, system_event.writable, + instance_number); } private: diff --git a/src/core/hle/service/audio/hwopus.cpp b/src/core/hle/service/audio/hwopus.cpp index cb4a1160d..cb839e4a2 100644 --- a/src/core/hle/service/audio/hwopus.cpp +++ b/src/core/hle/service/audio/hwopus.cpp @@ -80,7 +80,7 @@ private: LOG_ERROR(Audio, "Failed to decode opus data"); IPC::ResponseBuilder rb{ctx, 2}; // TODO(ogniK): Use correct error code - rb.Push(ResultCode(-1)); + rb.Push(RESULT_UNKNOWN); return; } @@ -278,7 +278,7 @@ void HwOpus::OpenOpusDecoder(Kernel::HLERequestContext& ctx) { LOG_ERROR(Audio, "Failed to create Opus decoder (error={}).", error); IPC::ResponseBuilder rb{ctx, 2}; // TODO(ogniK): Use correct error code - rb.Push(ResultCode(-1)); + rb.Push(RESULT_UNKNOWN); return; } diff --git a/src/core/hle/service/bcat/backend/backend.cpp b/src/core/hle/service/bcat/backend/backend.cpp index dec0849b8..6f5ea095a 100644 --- a/src/core/hle/service/bcat/backend/backend.cpp +++ b/src/core/hle/service/bcat/backend/backend.cpp @@ -16,7 +16,7 @@ ProgressServiceBackend::ProgressServiceBackend(Kernel::KernelCore& kernel, kernel, std::string("ProgressServiceBackend:UpdateEvent:").append(event_name)); } -Kernel::SharedPtr<Kernel::ReadableEvent> ProgressServiceBackend::GetEvent() const { +std::shared_ptr<Kernel::ReadableEvent> ProgressServiceBackend::GetEvent() const { return event.readable; } diff --git a/src/core/hle/service/bcat/backend/backend.h b/src/core/hle/service/bcat/backend/backend.h index ea4b16ad0..48bbbe66f 100644 --- a/src/core/hle/service/bcat/backend/backend.h +++ b/src/core/hle/service/bcat/backend/backend.h @@ -98,7 +98,7 @@ public: private: explicit ProgressServiceBackend(Kernel::KernelCore& kernel, std::string_view event_name); - Kernel::SharedPtr<Kernel::ReadableEvent> GetEvent() const; + std::shared_ptr<Kernel::ReadableEvent> GetEvent() const; DeliveryCacheProgressImpl& GetImpl(); void SignalUpdate() const; diff --git a/src/core/hle/service/bcat/backend/boxcat.cpp b/src/core/hle/service/bcat/backend/boxcat.cpp index 918159e11..67e39a5c4 100644 --- a/src/core/hle/service/bcat/backend/boxcat.cpp +++ b/src/core/hle/service/bcat/backend/boxcat.cpp @@ -114,7 +114,7 @@ void HandleDownloadDisplayResult(const AM::Applets::AppletManager& applet_manage const auto& frontend{applet_manager.GetAppletFrontendSet()}; frontend.error->ShowCustomErrorText( - ResultCode(-1), "There was an error while attempting to use Boxcat.", + RESULT_UNKNOWN, "There was an error while attempting to use Boxcat.", DOWNLOAD_RESULT_LOG_MESSAGES[static_cast<std::size_t>(res)], [] {}); } @@ -255,7 +255,7 @@ private: using Digest = std::array<u8, 0x20>; static Digest DigestFile(std::vector<u8> bytes) { Digest out{}; - mbedtls_sha256(bytes.data(), bytes.size(), out.data(), 0); + mbedtls_sha256_ret(bytes.data(), bytes.size(), out.data(), 0); return out; } diff --git a/src/core/hle/service/bcat/module.cpp b/src/core/hle/service/bcat/module.cpp index 6d9d1527d..7ada67130 100644 --- a/src/core/hle/service/bcat/module.cpp +++ b/src/core/hle/service/bcat/module.cpp @@ -46,7 +46,7 @@ u64 GetCurrentBuildID(const Core::System::CurrentBuildProcessID& id) { BCATDigest DigestFile(const FileSys::VirtualFile& file) { BCATDigest out{}; const auto bytes = file->ReadAllBytes(); - mbedtls_md5(bytes.data(), bytes.size(), out.data()); + mbedtls_md5_ret(bytes.data(), bytes.size(), out.data()); return out; } @@ -87,7 +87,7 @@ struct DeliveryCacheDirectoryEntry { class IDeliveryCacheProgressService final : public ServiceFramework<IDeliveryCacheProgressService> { public: - IDeliveryCacheProgressService(Kernel::SharedPtr<Kernel::ReadableEvent> event, + IDeliveryCacheProgressService(std::shared_ptr<Kernel::ReadableEvent> event, const DeliveryCacheProgressImpl& impl) : ServiceFramework{"IDeliveryCacheProgressService"}, event(std::move(event)), impl(impl) { // clang-format off @@ -118,7 +118,7 @@ private: rb.Push(RESULT_SUCCESS); } - Kernel::SharedPtr<Kernel::ReadableEvent> event; + std::shared_ptr<Kernel::ReadableEvent> event; const DeliveryCacheProgressImpl& impl; }; @@ -137,14 +137,20 @@ public: {10200, nullptr, "CancelSyncDeliveryCacheRequest"}, {20100, nullptr, "RequestSyncDeliveryCacheWithApplicationId"}, {20101, nullptr, "RequestSyncDeliveryCacheWithApplicationIdAndDirectoryName"}, + {20300, nullptr, "GetDeliveryCacheStorageUpdateNotifier"}, + {20301, nullptr, "RequestSuspendDeliveryTask"}, + {20400, nullptr, "RegisterSystemApplicationDeliveryTask"}, + {20401, nullptr, "UnregisterSystemApplicationDeliveryTask"}, {30100, &IBcatService::SetPassphrase, "SetPassphrase"}, {30200, nullptr, "RegisterBackgroundDeliveryTask"}, {30201, nullptr, "UnregisterBackgroundDeliveryTask"}, {30202, nullptr, "BlockDeliveryTask"}, {30203, nullptr, "UnblockDeliveryTask"}, + {30300, nullptr, "RegisterSystemApplicationDeliveryTasks"}, {90100, nullptr, "EnumerateBackgroundDeliveryTask"}, {90200, nullptr, "GetDeliveryList"}, {90201, &IBcatService::ClearDeliveryCacheStorage, "ClearDeliveryCacheStorage"}, + {90202, nullptr, "ClearDeliveryTaskSubscriptionStatus"}, {90300, nullptr, "GetPushNotificationLog"}, }; // clang-format on diff --git a/src/core/hle/service/btdrv/btdrv.cpp b/src/core/hle/service/btdrv/btdrv.cpp index 4574d9572..40a06c9fd 100644 --- a/src/core/hle/service/btdrv/btdrv.cpp +++ b/src/core/hle/service/btdrv/btdrv.cpp @@ -155,6 +155,7 @@ public: {98, nullptr, "SetLeScanParameter"}, {256, nullptr, "GetIsManufacturingMode"}, {257, nullptr, "EmulateBluetoothCrash"}, + {258, nullptr, "GetBleChannelMap"}, }; // clang-format on diff --git a/src/core/hle/service/erpt/erpt.cpp b/src/core/hle/service/erpt/erpt.cpp index d9b32954e..4ec8c3093 100644 --- a/src/core/hle/service/erpt/erpt.cpp +++ b/src/core/hle/service/erpt/erpt.cpp @@ -24,6 +24,8 @@ public: {6, nullptr, "SubmitMultipleCategoryContext"}, {7, nullptr, "UpdateApplicationLaunchTime"}, {8, nullptr, "ClearApplicationLaunchTime"}, + {9, nullptr, "SubmitAttachment"}, + {10, nullptr, "CreateReportWithAttachments"}, }; // clang-format on @@ -38,6 +40,7 @@ public: static const FunctionInfo functions[] = { {0, nullptr, "OpenReport"}, {1, nullptr, "OpenManager"}, + {2, nullptr, "OpenAttachment"}, }; // clang-format on diff --git a/src/core/hle/service/es/es.cpp b/src/core/hle/service/es/es.cpp index f77ddd739..df00ae625 100644 --- a/src/core/hle/service/es/es.cpp +++ b/src/core/hle/service/es/es.cpp @@ -52,6 +52,8 @@ public: {34, nullptr, "GetEncryptedTicketSize"}, {35, nullptr, "GetEncryptedTicketData"}, {36, nullptr, "DeleteAllInactiveELicenseRequiredPersonalizedTicket"}, + {37, nullptr, "OwnTicket2"}, + {38, nullptr, "OwnTicket3"}, {503, nullptr, "GetTitleKey"}, }; // clang-format on diff --git a/src/core/hle/service/filesystem/filesystem.cpp b/src/core/hle/service/filesystem/filesystem.cpp index 11e5c56b7..102017d73 100644 --- a/src/core/hle/service/filesystem/filesystem.cpp +++ b/src/core/hle/service/filesystem/filesystem.cpp @@ -58,11 +58,11 @@ ResultCode VfsDirectoryServiceWrapper::CreateFile(const std::string& path_, u64 auto file = dir->CreateFile(FileUtil::GetFilename(path)); if (file == nullptr) { // TODO(DarkLordZach): Find a better error code for this - return ResultCode(-1); + return RESULT_UNKNOWN; } if (!file->Resize(size)) { // TODO(DarkLordZach): Find a better error code for this - return ResultCode(-1); + return RESULT_UNKNOWN; } return RESULT_SUCCESS; } @@ -80,7 +80,7 @@ ResultCode VfsDirectoryServiceWrapper::DeleteFile(const std::string& path_) cons } if (!dir->DeleteFile(FileUtil::GetFilename(path))) { // TODO(DarkLordZach): Find a better error code for this - return ResultCode(-1); + return RESULT_UNKNOWN; } return RESULT_SUCCESS; @@ -94,7 +94,7 @@ ResultCode VfsDirectoryServiceWrapper::CreateDirectory(const std::string& path_) auto new_dir = dir->CreateSubdirectory(FileUtil::GetFilename(path)); if (new_dir == nullptr) { // TODO(DarkLordZach): Find a better error code for this - return ResultCode(-1); + return RESULT_UNKNOWN; } return RESULT_SUCCESS; } @@ -104,7 +104,7 @@ ResultCode VfsDirectoryServiceWrapper::DeleteDirectory(const std::string& path_) auto dir = GetDirectoryRelativeWrapped(backing, FileUtil::GetParentPath(path)); if (!dir->DeleteSubdirectory(FileUtil::GetFilename(path))) { // TODO(DarkLordZach): Find a better error code for this - return ResultCode(-1); + return RESULT_UNKNOWN; } return RESULT_SUCCESS; } @@ -114,7 +114,7 @@ ResultCode VfsDirectoryServiceWrapper::DeleteDirectoryRecursively(const std::str auto dir = GetDirectoryRelativeWrapped(backing, FileUtil::GetParentPath(path)); if (!dir->DeleteSubdirectoryRecursive(FileUtil::GetFilename(path))) { // TODO(DarkLordZach): Find a better error code for this - return ResultCode(-1); + return RESULT_UNKNOWN; } return RESULT_SUCCESS; } @@ -125,7 +125,7 @@ ResultCode VfsDirectoryServiceWrapper::CleanDirectoryRecursively(const std::stri if (!dir->CleanSubdirectoryRecursive(FileUtil::GetFilename(sanitized_path))) { // TODO(DarkLordZach): Find a better error code for this - return ResultCode(-1); + return RESULT_UNKNOWN; } return RESULT_SUCCESS; @@ -142,7 +142,7 @@ ResultCode VfsDirectoryServiceWrapper::RenameFile(const std::string& src_path_, return FileSys::ERROR_PATH_NOT_FOUND; if (!src->Rename(FileUtil::GetFilename(dest_path))) { // TODO(DarkLordZach): Find a better error code for this - return ResultCode(-1); + return RESULT_UNKNOWN; } return RESULT_SUCCESS; } @@ -160,7 +160,7 @@ ResultCode VfsDirectoryServiceWrapper::RenameFile(const std::string& src_path_, if (!src->GetContainingDirectory()->DeleteFile(FileUtil::GetFilename(src_path))) { // TODO(DarkLordZach): Find a better error code for this - return ResultCode(-1); + return RESULT_UNKNOWN; } return RESULT_SUCCESS; @@ -177,7 +177,7 @@ ResultCode VfsDirectoryServiceWrapper::RenameDirectory(const std::string& src_pa return FileSys::ERROR_PATH_NOT_FOUND; if (!src->Rename(FileUtil::GetFilename(dest_path))) { // TODO(DarkLordZach): Find a better error code for this - return ResultCode(-1); + return RESULT_UNKNOWN; } return RESULT_SUCCESS; } @@ -189,7 +189,7 @@ ResultCode VfsDirectoryServiceWrapper::RenameDirectory(const std::string& src_pa src_path, dest_path); // TODO(DarkLordZach): Find a better error code for this - return ResultCode(-1); + return RESULT_UNKNOWN; } ResultVal<FileSys::VirtualFile> VfsDirectoryServiceWrapper::OpenFile(const std::string& path_, @@ -287,7 +287,7 @@ ResultVal<FileSys::VirtualFile> FileSystemController::OpenRomFSCurrentProcess() if (romfs_factory == nullptr) { // TODO(bunnei): Find a better error code for this - return ResultCode(-1); + return RESULT_UNKNOWN; } return romfs_factory->OpenCurrentProcess(system.CurrentProcess()->GetTitleID()); @@ -300,7 +300,7 @@ ResultVal<FileSys::VirtualFile> FileSystemController::OpenRomFS( if (romfs_factory == nullptr) { // TODO(bunnei): Find a better error code for this - return ResultCode(-1); + return RESULT_UNKNOWN; } return romfs_factory->Open(title_id, storage_id, type); diff --git a/src/core/hle/service/filesystem/fsp_srv.cpp b/src/core/hle/service/filesystem/fsp_srv.cpp index cbd5466c1..55d62fc5e 100644 --- a/src/core/hle/service/filesystem/fsp_srv.cpp +++ b/src/core/hle/service/filesystem/fsp_srv.cpp @@ -256,8 +256,8 @@ public: // TODO(DarkLordZach): Verify that this is the correct behavior. // Build entry index now to save time later. - BuildEntryIndex(entries, backend->GetFiles(), FileSys::File); - BuildEntryIndex(entries, backend->GetSubdirectories(), FileSys::Directory); + BuildEntryIndex(entries, backend->GetFiles(), FileSys::EntryType::File); + BuildEntryIndex(entries, backend->GetSubdirectories(), FileSys::EntryType::Directory); } private: @@ -391,13 +391,10 @@ public: } void RenameFile(Kernel::HLERequestContext& ctx) { - std::vector<u8> buffer; - buffer.resize(ctx.BufferDescriptorX()[0].Size()); - Memory::ReadBlock(ctx.BufferDescriptorX()[0].Address(), buffer.data(), buffer.size()); + std::vector<u8> buffer = ctx.ReadBuffer(0); const std::string src_name = Common::StringFromBuffer(buffer); - buffer.resize(ctx.BufferDescriptorX()[1].Size()); - Memory::ReadBlock(ctx.BufferDescriptorX()[1].Address(), buffer.data(), buffer.size()); + buffer = ctx.ReadBuffer(1); const std::string dst_name = Common::StringFromBuffer(buffer); LOG_DEBUG(Service_FS, "called. file '{}' to file '{}'", src_name, dst_name); @@ -680,6 +677,7 @@ FSP_SRV::FSP_SRV(FileSystemController& fsc, const Core::Reporter& reporter) {33, nullptr, "DeleteCacheStorage"}, {34, nullptr, "GetCacheStorageSize"}, {35, nullptr, "CreateSaveDataFileSystemByHashSalt"}, + {36, nullptr, "OpenHostFileSystemWithOption"}, {51, &FSP_SRV::OpenSaveDataFileSystem, "OpenSaveDataFileSystem"}, {52, nullptr, "OpenSaveDataFileSystemBySystemSaveDataId"}, {53, &FSP_SRV::OpenReadOnlySaveDataFileSystem, "OpenReadOnlySaveDataFileSystem"}, @@ -694,11 +692,14 @@ FSP_SRV::FSP_SRV(FileSystemController& fsc, const Core::Reporter& reporter) {66, nullptr, "WriteSaveDataFileSystemExtraData2"}, {67, nullptr, "FindSaveDataWithFilter"}, {68, nullptr, "OpenSaveDataInfoReaderBySaveDataFilter"}, + {69, nullptr, "ReadSaveDataFileSystemExtraDataBySaveDataAttribute"}, + {70, nullptr, "WriteSaveDataFileSystemExtraDataBySaveDataAttribute"}, {80, nullptr, "OpenSaveDataMetaFile"}, {81, nullptr, "OpenSaveDataTransferManager"}, {82, nullptr, "OpenSaveDataTransferManagerVersion2"}, {83, nullptr, "OpenSaveDataTransferProhibiterForCloudBackUp"}, {84, nullptr, "ListApplicationAccessibleSaveDataOwnerId"}, + {85, nullptr, "OpenSaveDataTransferManagerForSaveDataRepair"}, {100, nullptr, "OpenImageDirectoryFileSystem"}, {110, nullptr, "OpenContentStorageFileSystem"}, {120, nullptr, "OpenCloudBackupWorkStorageFileSystem"}, @@ -756,6 +757,8 @@ FSP_SRV::FSP_SRV(FileSystemController& fsc, const Core::Reporter& reporter) {1009, nullptr, "GetAndClearMemoryReportInfo"}, {1010, nullptr, "SetDataStorageRedirectTarget"}, {1011, &FSP_SRV::GetAccessLogVersionInfo, "GetAccessLogVersionInfo"}, + {1012, nullptr, "GetFsStackUsage"}, + {1013, nullptr, "UnsetSaveDataRootPath"}, {1100, nullptr, "OverrideSaveDataTransferTokenSignVerificationKey"}, {1110, nullptr, "CorruptSaveDataFileSystemBySaveDataSpaceId2"}, {1200, nullptr, "OpenMultiCommitManager"}, @@ -785,7 +788,7 @@ void FSP_SRV::OpenFileSystemWithPatch(Kernel::HLERequestContext& ctx) { static_cast<u8>(type), title_id); IPC::ResponseBuilder rb{ctx, 2, 0, 0}; - rb.Push(ResultCode(-1)); + rb.Push(RESULT_UNKNOWN); } void FSP_SRV::OpenSdCardFileSystem(Kernel::HLERequestContext& ctx) { @@ -891,7 +894,7 @@ void FSP_SRV::OpenDataStorageByCurrentProcess(Kernel::HLERequestContext& ctx) { // TODO (bunnei): Find the right error code to use here LOG_CRITICAL(Service_FS, "no file system interface available!"); IPC::ResponseBuilder rb{ctx, 2}; - rb.Push(ResultCode(-1)); + rb.Push(RESULT_UNKNOWN); return; } @@ -928,7 +931,7 @@ void FSP_SRV::OpenDataStorageByDataId(Kernel::HLERequestContext& ctx) { "could not open data storage with title_id={:016X}, storage_id={:02X}", title_id, static_cast<u8>(storage_id)); IPC::ResponseBuilder rb{ctx, 2}; - rb.Push(ResultCode(-1)); + rb.Push(RESULT_UNKNOWN); return; } diff --git a/src/core/hle/service/friend/friend.cpp b/src/core/hle/service/friend/friend.cpp index 1a0214f08..219176c31 100644 --- a/src/core/hle/service/friend/friend.cpp +++ b/src/core/hle/service/friend/friend.cpp @@ -60,6 +60,9 @@ public: {20801, nullptr, "SyncUserSetting"}, {20900, nullptr, "RequestListSummaryOverlayNotification"}, {21000, nullptr, "GetExternalApplicationCatalog"}, + {22000, nullptr, "GetReceivedFriendInvitationList"}, + {22001, nullptr, "GetReceivedFriendInvitationDetailedInfo"}, + {22010, nullptr, "GetReceivedFriendInvitationCountCache"}, {30100, nullptr, "DropFriendNewlyFlags"}, {30101, nullptr, "DeleteFriend"}, {30110, nullptr, "DropFriendNewlyFlag"}, @@ -91,6 +94,8 @@ public: {30812, nullptr, "ChangePlayLogPermission"}, {30820, nullptr, "IssueFriendCode"}, {30830, nullptr, "ClearPlayLog"}, + {30900, nullptr, "SendFriendInvitation"}, + {30910, nullptr, "ReadFriendInvitation"}, {49900, nullptr, "DeleteNetworkServiceAccountCache"}, }; // clang-format on diff --git a/src/core/hle/service/hid/controllers/npad.cpp b/src/core/hle/service/hid/controllers/npad.cpp index 79fff517e..4d952adc0 100644 --- a/src/core/hle/service/hid/controllers/npad.cpp +++ b/src/core/hle/service/hid/controllers/npad.cpp @@ -501,8 +501,7 @@ void Controller_NPad::VibrateController(const std::vector<u32>& controller_ids, last_processed_vibration = vibrations.back(); } -Kernel::SharedPtr<Kernel::ReadableEvent> Controller_NPad::GetStyleSetChangedEvent( - u32 npad_id) const { +std::shared_ptr<Kernel::ReadableEvent> Controller_NPad::GetStyleSetChangedEvent(u32 npad_id) const { // TODO(ogniK): Figure out the best time to signal this event. This event seems that it should // be signalled at least once, and signaled after a new controller is connected? const auto& styleset_event = styleset_changed_events[NPadIdToIndex(npad_id)]; diff --git a/src/core/hle/service/hid/controllers/npad.h b/src/core/hle/service/hid/controllers/npad.h index 16c4caa1f..931f03430 100644 --- a/src/core/hle/service/hid/controllers/npad.h +++ b/src/core/hle/service/hid/controllers/npad.h @@ -109,7 +109,7 @@ public: void VibrateController(const std::vector<u32>& controller_ids, const std::vector<Vibration>& vibrations); - Kernel::SharedPtr<Kernel::ReadableEvent> GetStyleSetChangedEvent(u32 npad_id) const; + std::shared_ptr<Kernel::ReadableEvent> GetStyleSetChangedEvent(u32 npad_id) const; Vibration GetLastVibration() const; void AddNewController(NPadControllerType controller); diff --git a/src/core/hle/service/hid/hid.cpp b/src/core/hle/service/hid/hid.cpp index ecc130f6c..89bf8b815 100644 --- a/src/core/hle/service/hid/hid.cpp +++ b/src/core/hle/service/hid/hid.cpp @@ -77,15 +77,14 @@ IAppletResource::IAppletResource(Core::System& system) GetController<Controller_Stubbed>(HidController::Unknown3).SetCommonHeaderOffset(0x5000); // Register update callbacks - auto& core_timing = system.CoreTiming(); pad_update_event = - core_timing.RegisterEvent("HID::UpdatePadCallback", [this](u64 userdata, s64 cycles_late) { + Core::Timing::CreateEvent("HID::UpdatePadCallback", [this](u64 userdata, s64 cycles_late) { UpdateControllers(userdata, cycles_late); }); // TODO(shinyquagsire23): Other update callbacks? (accel, gyro?) - core_timing.ScheduleEvent(pad_update_ticks, pad_update_event); + system.CoreTiming().ScheduleEvent(pad_update_ticks, pad_update_event); ReloadInputDevices(); } @@ -215,6 +214,8 @@ Hid::Hid(Core::System& system) : ServiceFramework("hid"), system(system) { {132, nullptr, "EnableUnintendedHomeButtonInputProtection"}, {133, nullptr, "SetNpadJoyAssignmentModeSingleWithDestination"}, {134, nullptr, "SetNpadAnalogStickUseCenterClamp"}, + {135, nullptr, "SetNpadCaptureButtonAssignment"}, + {136, nullptr, "ClearNpadCaptureButtonAssignment"}, {200, &Hid::GetVibrationDeviceInfo, "GetVibrationDeviceInfo"}, {201, &Hid::SendVibrationValue, "SendVibrationValue"}, {202, &Hid::GetActualVibrationValue, "GetActualVibrationValue"}, @@ -245,6 +246,8 @@ Hid::Hid(Core::System& system) : ServiceFramework("hid"), system(system) { {404, nullptr, "HasLeftRightBattery"}, {405, nullptr, "GetNpadInterfaceType"}, {406, nullptr, "GetNpadLeftRightInterfaceType"}, + {407, nullptr, "GetNpadOfHighestBatteryLevelForJoyLeft"}, + {408, nullptr, "GetNpadOfHighestBatteryLevelForJoyRight"}, {500, nullptr, "GetPalmaConnectionHandle"}, {501, nullptr, "InitializePalma"}, {502, nullptr, "AcquirePalmaOperationCompleteEvent"}, @@ -272,8 +275,13 @@ Hid::Hid(Core::System& system) : ServiceFramework("hid"), system(system) { {524, nullptr, "PairPalma"}, {525, &Hid::SetPalmaBoostMode, "SetPalmaBoostMode"}, {526, nullptr, "CancelWritePalmaWaveEntry"}, + {527, nullptr, "EnablePalmaBoostMode"}, + {528, nullptr, "GetPalmaBluetoothAddress"}, + {529, nullptr, "SetDisallowedPalmaConnection"}, {1000, nullptr, "SetNpadCommunicationMode"}, {1001, nullptr, "GetNpadCommunicationMode"}, + {1002, nullptr, "SetTouchScreenConfiguration"}, + {1003, nullptr, "IsFirmwareUpdateNeededForNotification"}, }; // clang-format on @@ -969,6 +977,9 @@ public: {310, nullptr, "GetMaskedSupportedNpadStyleSet"}, {311, nullptr, "SetNpadPlayerLedBlinkingDevice"}, {312, nullptr, "SetSupportedNpadStyleSetAll"}, + {313, nullptr, "GetNpadCaptureButtonAssignment"}, + {314, nullptr, "GetAppletFooterUiType"}, + {315, nullptr, "GetAppletDetailedUiType"}, {321, nullptr, "GetUniquePadsFromNpad"}, {322, nullptr, "GetIrSensorState"}, {323, nullptr, "GetXcdHandleForNpadWithIrSensor"}, @@ -984,6 +995,8 @@ public: {513, nullptr, "EndPermitVibrationSession"}, {520, nullptr, "EnableHandheldHids"}, {521, nullptr, "DisableHandheldHids"}, + {522, nullptr, "SetJoyConRailEnabled"}, + {523, nullptr, "IsJoyConRailEnabled"}, {540, nullptr, "AcquirePlayReportControllerUsageUpdateEvent"}, {541, nullptr, "GetPlayReportControllerUsages"}, {542, nullptr, "AcquirePlayReportRegisteredDeviceUpdateEvent"}, @@ -1010,6 +1023,7 @@ public: {809, nullptr, "GetUniquePadSerialNumber"}, {810, nullptr, "GetUniquePadControllerNumber"}, {811, nullptr, "GetSixAxisSensorUserCalibrationStage"}, + {812, nullptr, "GetConsoleUniqueSixAxisSensorHandle"}, {821, nullptr, "StartAnalogStickManualCalibration"}, {822, nullptr, "RetryCurrentAnalogStickManualCalibrationStage"}, {823, nullptr, "CancelAnalogStickManualCalibration"}, @@ -1020,6 +1034,8 @@ public: {828, nullptr, "IsAnalogStickInReleasePosition"}, {829, nullptr, "IsAnalogStickInCircumference"}, {830, nullptr, "SetNotificationLedPattern"}, + {831, nullptr, "SetNotificationLedPatternWithTimeout"}, + {832, nullptr, "PrepareHidsForNotificationWake"}, {850, nullptr, "IsUsbFullKeyControllerEnabled"}, {851, nullptr, "EnableUsbFullKeyController"}, {852, nullptr, "IsUsbConnected"}, @@ -1049,6 +1065,13 @@ public: {1132, nullptr, "CheckUsbFirmwareUpdateRequired"}, {1133, nullptr, "StartUsbFirmwareUpdate"}, {1134, nullptr, "GetUsbFirmwareUpdateState"}, + {1150, nullptr, "SetTouchScreenMagnification"}, + {1151, nullptr, "GetTouchScreenFirmwareVersion"}, + {1152, nullptr, "SetTouchScreenDefaultConfiguration"}, + {1153, nullptr, "GetTouchScreenDefaultConfiguration"}, + {1154, nullptr, "IsFirmwareAvailableForNotification"}, + {1155, nullptr, "SetForceHandheldStyleVibration"}, + {1156, nullptr, "SendConnectionTriggerWithoutTimeoutEvent"}, }; // clang-format on diff --git a/src/core/hle/service/hid/hid.h b/src/core/hle/service/hid/hid.h index f08e036a3..ad20f147c 100644 --- a/src/core/hle/service/hid/hid.h +++ b/src/core/hle/service/hid/hid.h @@ -67,9 +67,9 @@ private: void GetSharedMemoryHandle(Kernel::HLERequestContext& ctx); void UpdateControllers(u64 userdata, s64 cycles_late); - Kernel::SharedPtr<Kernel::SharedMemory> shared_mem; + std::shared_ptr<Kernel::SharedMemory> shared_mem; - Core::Timing::EventType* pad_update_event; + std::shared_ptr<Core::Timing::EventType> pad_update_event; Core::System& system; std::array<std::unique_ptr<ControllerBase>, static_cast<size_t>(HidController::MaxControllers)> diff --git a/src/core/hle/service/hid/irs.h b/src/core/hle/service/hid/irs.h index eb4e898dd..8918ad6ca 100644 --- a/src/core/hle/service/hid/irs.h +++ b/src/core/hle/service/hid/irs.h @@ -37,7 +37,7 @@ private: void RunIrLedProcessor(Kernel::HLERequestContext& ctx); void StopImageProcessorAsync(Kernel::HLERequestContext& ctx); void ActivateIrsensorWithFunctionLevel(Kernel::HLERequestContext& ctx); - Kernel::SharedPtr<Kernel::SharedMemory> shared_mem; + std::shared_ptr<Kernel::SharedMemory> shared_mem; const u32 device_handle{0xABCD}; Core::System& system; }; diff --git a/src/core/hle/service/ldr/ldr.cpp b/src/core/hle/service/ldr/ldr.cpp index 499376bfc..157aeec88 100644 --- a/src/core/hle/service/ldr/ldr.cpp +++ b/src/core/hle/service/ldr/ldr.cpp @@ -140,9 +140,10 @@ public: rb.Push(ERROR_INVALID_SIZE); return; } + // Read NRR data from memory std::vector<u8> nrr_data(nrr_size); - Memory::ReadBlock(nrr_address, nrr_data.data(), nrr_size); + system.Memory().ReadBlock(nrr_address, nrr_data.data(), nrr_size); NRRHeader header; std::memcpy(&header, nrr_data.data(), sizeof(NRRHeader)); @@ -291,10 +292,10 @@ public: // Read NRO data from memory std::vector<u8> nro_data(nro_size); - Memory::ReadBlock(nro_address, nro_data.data(), nro_size); + system.Memory().ReadBlock(nro_address, nro_data.data(), nro_size); SHA256Hash hash{}; - mbedtls_sha256(nro_data.data(), nro_data.size(), hash.data(), 0); + mbedtls_sha256_ret(nro_data.data(), nro_data.size(), hash.data(), 0); // NRO Hash is already loaded if (std::any_of(nro.begin(), nro.end(), [&hash](const std::pair<VAddr, NROInfo>& info) { diff --git a/src/core/hle/service/lm/lm.cpp b/src/core/hle/service/lm/lm.cpp index 435f2d286..346c8f899 100644 --- a/src/core/hle/service/lm/lm.cpp +++ b/src/core/hle/service/lm/lm.cpp @@ -17,7 +17,8 @@ namespace Service::LM { class ILogger final : public ServiceFramework<ILogger> { public: - ILogger(Manager& manager) : ServiceFramework("ILogger"), manager(manager) { + explicit ILogger(Manager& manager_, Memory::Memory& memory_) + : ServiceFramework("ILogger"), manager{manager_}, memory{memory_} { static const FunctionInfo functions[] = { {0, &ILogger::Log, "Log"}, {1, &ILogger::SetDestination, "SetDestination"}, @@ -35,15 +36,15 @@ private: MessageHeader header{}; VAddr addr{ctx.BufferDescriptorX()[0].Address()}; const VAddr end_addr{addr + ctx.BufferDescriptorX()[0].size}; - Memory::ReadBlock(addr, &header, sizeof(MessageHeader)); + memory.ReadBlock(addr, &header, sizeof(MessageHeader)); addr += sizeof(MessageHeader); FieldMap fields; while (addr < end_addr) { - const auto field = static_cast<Field>(Memory::Read8(addr++)); - const auto length = Memory::Read8(addr++); + const auto field = static_cast<Field>(memory.Read8(addr++)); + const auto length = memory.Read8(addr++); - if (static_cast<Field>(Memory::Read8(addr)) == Field::Skip) { + if (static_cast<Field>(memory.Read8(addr)) == Field::Skip) { ++addr; } @@ -54,7 +55,7 @@ private: } std::vector<u8> data(length); - Memory::ReadBlock(addr, data.data(), length); + memory.ReadBlock(addr, data.data(), length); fields.emplace(field, std::move(data)); } @@ -74,11 +75,13 @@ private: } Manager& manager; + Memory::Memory& memory; }; class LM final : public ServiceFramework<LM> { public: - explicit LM(Manager& manager) : ServiceFramework{"lm"}, manager(manager) { + explicit LM(Manager& manager_, Memory::Memory& memory_) + : ServiceFramework{"lm"}, manager{manager_}, memory{memory_} { // clang-format off static const FunctionInfo functions[] = { {0, &LM::OpenLogger, "OpenLogger"}, @@ -94,14 +97,16 @@ private: IPC::ResponseBuilder rb{ctx, 2, 0, 1}; rb.Push(RESULT_SUCCESS); - rb.PushIpcInterface<ILogger>(manager); + rb.PushIpcInterface<ILogger>(manager, memory); } Manager& manager; + Memory::Memory& memory; }; void InstallInterfaces(Core::System& system) { - std::make_shared<LM>(system.GetLogManager())->InstallAsService(system.ServiceManager()); + std::make_shared<LM>(system.GetLogManager(), system.Memory()) + ->InstallAsService(system.ServiceManager()); } } // namespace Service::LM diff --git a/src/core/hle/service/mii/mii.cpp b/src/core/hle/service/mii/mii.cpp index 0b3923ad9..a128edb43 100644 --- a/src/core/hle/service/mii/mii.cpp +++ b/src/core/hle/service/mii/mii.cpp @@ -50,6 +50,8 @@ public: {21, &IDatabaseService::GetIndex, "GetIndex"}, {22, &IDatabaseService::SetInterfaceVersion, "SetInterfaceVersion"}, {23, nullptr, "Convert"}, + {24, nullptr, "ConvertCoreDataToCharInfo"}, + {25, nullptr, "ConvertCharInfoToCoreData"}, }; // clang-format on @@ -242,7 +244,7 @@ private: const auto index = db.IndexOf(uuid); if (index > MAX_MIIS) { // TODO(DarkLordZach): Find a better error code - rb.Push(ResultCode(-1)); + rb.Push(RESULT_UNKNOWN); rb.Push(index); } else { rb.Push(RESULT_SUCCESS); @@ -268,7 +270,7 @@ private: IPC::ResponseBuilder rb{ctx, 2}; // TODO(DarkLordZach): Find a better error code - rb.Push(success ? RESULT_SUCCESS : ResultCode(-1)); + rb.Push(success ? RESULT_SUCCESS : RESULT_UNKNOWN); } void AddOrReplace(Kernel::HLERequestContext& ctx) { @@ -282,7 +284,7 @@ private: IPC::ResponseBuilder rb{ctx, 2}; // TODO(DarkLordZach): Find a better error code - rb.Push(success ? RESULT_SUCCESS : ResultCode(-1)); + rb.Push(success ? RESULT_SUCCESS : RESULT_UNKNOWN); } void Delete(Kernel::HLERequestContext& ctx) { diff --git a/src/core/hle/service/ncm/ncm.cpp b/src/core/hle/service/ncm/ncm.cpp index b405a4b66..89e283ca5 100644 --- a/src/core/hle/service/ncm/ncm.cpp +++ b/src/core/hle/service/ncm/ncm.cpp @@ -61,7 +61,8 @@ public: {5, nullptr, "RegisterHtmlDocumentPath"}, {6, nullptr, "UnregisterHtmlDocumentPath"}, {7, nullptr, "RedirectHtmlDocumentPath"}, - {8, nullptr, ""}, + {8, nullptr, "Refresh"}, + {9, nullptr, "RefreshExcluding"}, }; // clang-format on @@ -77,6 +78,8 @@ public: {0, nullptr, "ResolveAddOnContentPath"}, {1, nullptr, "RegisterAddOnContentStorage"}, {2, nullptr, "UnregisterAllAddOnContentPath"}, + {3, nullptr, "RefreshApplicationAddOnContent"}, + {4, nullptr, "UnregisterApplicationAddOnContent"}, }; // clang-format on @@ -118,6 +121,7 @@ public: {10, nullptr, "InactivateContentStorage"}, {11, nullptr, "ActivateContentMetaDatabase"}, {12, nullptr, "InactivateContentMetaDatabase"}, + {13, nullptr, "InvalidateRightsIdCache"}, }; // clang-format on diff --git a/src/core/hle/service/nfc/nfc.cpp b/src/core/hle/service/nfc/nfc.cpp index ca88bf97f..b7b34ce7e 100644 --- a/src/core/hle/service/nfc/nfc.cpp +++ b/src/core/hle/service/nfc/nfc.cpp @@ -215,6 +215,7 @@ public: {411, nullptr, "AttachActivateEvent"}, {412, nullptr, "AttachDeactivateEvent"}, {500, nullptr, "SetNfcEnabled"}, + {510, nullptr, "OutputTestWave"}, {1000, nullptr, "ReadMifare"}, {1001, nullptr, "WriteMifare"}, {1300, nullptr, "SendCommandByPassThrough"}, diff --git a/src/core/hle/service/nfp/nfp.cpp b/src/core/hle/service/nfp/nfp.cpp index 795d7b716..4b79eb81d 100644 --- a/src/core/hle/service/nfp/nfp.cpp +++ b/src/core/hle/service/nfp/nfp.cpp @@ -16,10 +16,7 @@ #include "core/hle/service/nfp/nfp_user.h" namespace Service::NFP { - namespace ErrCodes { -[[maybe_unused]] constexpr ResultCode ERR_TAG_FAILED(ErrorModule::NFP, - -1); // TODO(ogniK): Find the actual error code constexpr ResultCode ERR_NO_APPLICATION_AREA(ErrorModule::NFP, 152); } // namespace ErrCodes @@ -192,7 +189,7 @@ private: LOG_DEBUG(Service_NFP, "called"); auto nfc_event = nfp_interface.GetNFCEvent(); - if (!nfc_event->ShouldWait(Kernel::GetCurrentThread()) && !has_attached_handle) { + if (!nfc_event->ShouldWait(&ctx.GetThread()) && !has_attached_handle) { device_state = DeviceState::TagFound; nfc_event->Clear(); } @@ -345,7 +342,7 @@ bool Module::Interface::LoadAmiibo(const std::vector<u8>& buffer) { return true; } -const Kernel::SharedPtr<Kernel::ReadableEvent>& Module::Interface::GetNFCEvent() const { +const std::shared_ptr<Kernel::ReadableEvent>& Module::Interface::GetNFCEvent() const { return nfc_tag_load.readable; } diff --git a/src/core/hle/service/nfp/nfp.h b/src/core/hle/service/nfp/nfp.h index 9718ef745..200013795 100644 --- a/src/core/hle/service/nfp/nfp.h +++ b/src/core/hle/service/nfp/nfp.h @@ -34,7 +34,7 @@ public: void CreateUserInterface(Kernel::HLERequestContext& ctx); bool LoadAmiibo(const std::vector<u8>& buffer); - const Kernel::SharedPtr<Kernel::ReadableEvent>& GetNFCEvent() const; + const std::shared_ptr<Kernel::ReadableEvent>& GetNFCEvent() const; const AmiiboFile& GetAmiiboBuffer() const; private: diff --git a/src/core/hle/service/nifm/nifm.cpp b/src/core/hle/service/nifm/nifm.cpp index 01d557c7a..2e53b3221 100644 --- a/src/core/hle/service/nifm/nifm.cpp +++ b/src/core/hle/service/nifm/nifm.cpp @@ -208,6 +208,7 @@ private: IGeneralService::IGeneralService(Core::System& system) : ServiceFramework("IGeneralService"), system(system) { + // clang-format off static const FunctionInfo functions[] = { {1, &IGeneralService::GetClientId, "GetClientId"}, {2, &IGeneralService::CreateScanRequest, "CreateScanRequest"}, @@ -246,7 +247,14 @@ IGeneralService::IGeneralService(Core::System& system) {36, nullptr, "GetCurrentAccessPoint"}, {37, nullptr, "Shutdown"}, {38, nullptr, "GetAllowedChannels"}, + {39, nullptr, "NotifyApplicationSuspended"}, + {40, nullptr, "SetAcceptableNetworkTypeFlag"}, + {41, nullptr, "GetAcceptableNetworkTypeFlag"}, + {42, nullptr, "NotifyConnectionStateChanged"}, + {43, nullptr, "SetWowlDelayedWakeTime"}, }; + // clang-format on + RegisterHandlers(functions); } diff --git a/src/core/hle/service/nim/nim.cpp b/src/core/hle/service/nim/nim.cpp index 7d6cf2070..e85f123e2 100644 --- a/src/core/hle/service/nim/nim.cpp +++ b/src/core/hle/service/nim/nim.cpp @@ -116,6 +116,8 @@ public: {500, nullptr, "RequestSyncTicket"}, {501, nullptr, "RequestDownloadTicket"}, {502, nullptr, "RequestDownloadTicketForPrepurchasedContents"}, + {503, nullptr, "RequestSyncTicket"}, + {504, nullptr, "RequestDownloadTicketForPrepurchasedContents2"}, }; // clang-format on diff --git a/src/core/hle/service/npns/npns.cpp b/src/core/hle/service/npns/npns.cpp index 8751522ca..aa171473b 100644 --- a/src/core/hle/service/npns/npns.cpp +++ b/src/core/hle/service/npns/npns.cpp @@ -44,6 +44,10 @@ public: {113, nullptr, "DestroyJid"}, {114, nullptr, "AttachJid"}, {115, nullptr, "DetachJid"}, + {120, nullptr, "CreateNotificationReceiver"}, + {151, nullptr, "GetStateWithHandover"}, + {152, nullptr, "GetStateChangeEventWithHandover"}, + {153, nullptr, "GetDropEventWithHandover"}, {201, nullptr, "RequestChangeStateForceTimed"}, {202, nullptr, "RequestChangeStateForceAsync"}, }; @@ -74,6 +78,9 @@ public: {104, nullptr, "GetStatistics"}, {111, nullptr, "GetJid"}, {120, nullptr, "CreateNotificationReceiver"}, + {151, nullptr, "GetStateWithHandover"}, + {152, nullptr, "GetStateChangeEventWithHandover"}, + {153, nullptr, "GetDropEventWithHandover"}, }; // clang-format on diff --git a/src/core/hle/service/ns/ns.cpp b/src/core/hle/service/ns/ns.cpp index 15c156ce1..fdab3cf78 100644 --- a/src/core/hle/service/ns/ns.cpp +++ b/src/core/hle/service/ns/ns.cpp @@ -106,6 +106,7 @@ IApplicationManagerInterface::IApplicationManagerInterface() {96, nullptr, "AcquireApplicationLaunchInfo"}, {97, nullptr, "GetMainApplicationProgramIndex2"}, {98, nullptr, "EnableApplicationAllThreadDumpOnCrash"}, + {99, nullptr, "LaunchDevMenu"}, {100, nullptr, "ResetToFactorySettings"}, {101, nullptr, "ResetToFactorySettingsWithoutUserSaveData"}, {102, nullptr, "ResetToFactorySettingsForRefurbishment"}, @@ -130,6 +131,8 @@ IApplicationManagerInterface::IApplicationManagerInterface() {404, nullptr, "InvalidateApplicationControlCache"}, {405, nullptr, "ListApplicationControlCacheEntryInfo"}, {406, nullptr, "GetApplicationControlProperty"}, + {407, nullptr, "ListApplicationTitle"}, + {408, nullptr, "ListApplicationIcon"}, {502, nullptr, "RequestCheckGameCardRegistration"}, {503, nullptr, "RequestGameCardRegistrationGoldPoint"}, {504, nullptr, "RequestRegisterGameCard"}, @@ -138,6 +141,7 @@ IApplicationManagerInterface::IApplicationManagerInterface() {507, nullptr, "EnsureGameCardAccess"}, {508, nullptr, "GetLastGameCardMountFailureResult"}, {509, nullptr, "ListApplicationIdOnGameCard"}, + {510, nullptr, "GetGameCardPlatformRegion"}, {600, nullptr, "CountApplicationContentMeta"}, {601, nullptr, "ListApplicationContentMetaStatus"}, {602, nullptr, "ListAvailableAddOnContent"}, @@ -168,6 +172,9 @@ IApplicationManagerInterface::IApplicationManagerInterface() {910, nullptr, "HasApplicationRecord"}, {911, nullptr, "SetPreInstalledApplication"}, {912, nullptr, "ClearPreInstalledApplicationFlag"}, + {913, nullptr, "ListAllApplicationRecord"}, + {914, nullptr, "HideApplicationRecord"}, + {915, nullptr, "ShowApplicationRecord"}, {1000, nullptr, "RequestVerifyApplicationDeprecated"}, {1001, nullptr, "CorruptApplicationForDebug"}, {1002, nullptr, "RequestVerifyAddOnContentsRights"}, @@ -190,12 +197,14 @@ IApplicationManagerInterface::IApplicationManagerInterface() {1502, nullptr, "GetLastSdCardFormatUnexpectedResult"}, {1504, nullptr, "InsertSdCard"}, {1505, nullptr, "RemoveSdCard"}, + {1506, nullptr, "GetSdCardStartupStatus"}, {1600, nullptr, "GetSystemSeedForPseudoDeviceId"}, {1601, nullptr, "ResetSystemSeedForPseudoDeviceId"}, {1700, nullptr, "ListApplicationDownloadingContentMeta"}, {1701, nullptr, "GetApplicationView"}, {1702, nullptr, "GetApplicationDownloadTaskStatus"}, {1703, nullptr, "GetApplicationViewDownloadErrorContext"}, + {1704, nullptr, "GetApplicationViewWithPromotionInfo"}, {1800, nullptr, "IsNotificationSetupCompleted"}, {1801, nullptr, "GetLastNotificationInfoCount"}, {1802, nullptr, "ListLastNotificationInfo"}, @@ -223,6 +232,7 @@ IApplicationManagerInterface::IApplicationManagerInterface() {2017, nullptr, "CreateDownloadTask"}, {2018, nullptr, "GetApplicationDeliveryInfoHash"}, {2050, nullptr, "GetApplicationRightsOnClient"}, + {2051, nullptr, "InvalidateRightsIdCache"}, {2100, nullptr, "GetApplicationTerminateResult"}, {2101, nullptr, "GetRawApplicationTerminateResult"}, {2150, nullptr, "CreateRightsEnvironment"}, @@ -230,6 +240,8 @@ IApplicationManagerInterface::IApplicationManagerInterface() {2152, nullptr, "ActivateRightsEnvironment"}, {2153, nullptr, "DeactivateRightsEnvironment"}, {2154, nullptr, "ForceActivateRightsContextForExit"}, + {2155, nullptr, "UpdateRightsEnvironmentStatus"}, + {2156, nullptr, "CreateRightsEnvironmentForPreomia"}, {2160, nullptr, "AddTargetApplicationToRightsEnvironment"}, {2161, nullptr, "SetUsersToRightsEnvironment"}, {2170, nullptr, "GetRightsEnvironmentStatus"}, @@ -243,6 +255,20 @@ IApplicationManagerInterface::IApplicationManagerInterface() {2201, nullptr, "GetInstalledApplicationCopyIdentifier"}, {2250, nullptr, "RequestReportActiveELicence"}, {2300, nullptr, "ListEventLog"}, + {2350, nullptr, "PerformAutoUpdateByApplicationId"}, + {2351, nullptr, "RequestNoDownloadRightsErrorResolution"}, + {2352, nullptr, "RequestResolveNoDownloadRightsError"}, + {2400, nullptr, "GetPromotionInfo"}, + {2401, nullptr, "CountPromotionInfo"}, + {2402, nullptr, "ListPromotionInfo"}, + {2403, nullptr, "ImportPromotionJsonForDebug"}, + {2404, nullptr, "ClearPromotionInfoForDebug"}, + {2500, nullptr, "ConfirmAvailableTime"}, + {2510, nullptr, "CreateApplicationResource"}, + {2511, nullptr, "GetApplicationResource"}, + {2513, nullptr, "LaunchPreomia"}, + {2514, nullptr, "ClearTaskOfAsyncTaskManager"}, + {2800, nullptr, "GetApplicationIdOfPreomia"}, }; // clang-format on @@ -271,7 +297,7 @@ void IApplicationManagerInterface::GetApplicationControlData(Kernel::HLERequestC "output buffer is too small! (actual={:016X}, expected_min=0x4000)", size); IPC::ResponseBuilder rb{ctx, 2}; // TODO(DarkLordZach): Find a better error code for this. - rb.Push(ResultCode(-1)); + rb.Push(RESULT_UNKNOWN); return; } @@ -291,7 +317,7 @@ void IApplicationManagerInterface::GetApplicationControlData(Kernel::HLERequestC 0x4000 + control.second->GetSize()); IPC::ResponseBuilder rb{ctx, 2}; // TODO(DarkLordZach): Find a better error code for this. - rb.Push(ResultCode(-1)); + rb.Push(RESULT_UNKNOWN); return; } @@ -463,6 +489,7 @@ IECommerceInterface::IECommerceInterface() : ServiceFramework{"IECommerceInterfa {3, nullptr, "RequestSyncRights"}, {4, nullptr, "RequestUnlinkDevice"}, {5, nullptr, "RequestRevokeAllELicense"}, + {6, nullptr, "RequestSyncRightsBasedOnAssignedELicenses"}, }; // clang-format on diff --git a/src/core/hle/service/ns/pl_u.cpp b/src/core/hle/service/ns/pl_u.cpp index 23477315f..8da4e52c5 100644 --- a/src/core/hle/service/ns/pl_u.cpp +++ b/src/core/hle/service/ns/pl_u.cpp @@ -97,7 +97,7 @@ void EncryptSharedFont(const std::vector<u32>& input, std::vector<u8>& output, const auto key = Common::swap32(EXPECTED_RESULT ^ EXPECTED_MAGIC); std::vector<u32> transformed_font(input.size() + 2); transformed_font[0] = Common::swap32(EXPECTED_MAGIC); - transformed_font[1] = Common::swap32(input.size() * sizeof(u32)) ^ key; + transformed_font[1] = Common::swap32(static_cast<u32>(input.size() * sizeof(u32))) ^ key; std::transform(input.begin(), input.end(), transformed_font.begin() + 2, [key](u32 in) { return in ^ key; }); std::memcpy(output.data() + offset, transformed_font.data(), @@ -141,7 +141,7 @@ struct PL_U::Impl { } /// Handle to shared memory region designated for a shared font - Kernel::SharedPtr<Kernel::SharedMemory> shared_font_mem; + std::shared_ptr<Kernel::SharedMemory> shared_font_mem; /// Backing memory for the shared font data std::shared_ptr<Kernel::PhysicalMemory> shared_font; @@ -152,7 +152,7 @@ struct PL_U::Impl { PL_U::PL_U(Core::System& system) : ServiceFramework("pl:u"), impl{std::make_unique<Impl>()}, system(system) { - + // clang-format off static const FunctionInfo functions[] = { {0, &PL_U::RequestLoad, "RequestLoad"}, {1, &PL_U::GetLoadState, "GetLoadState"}, @@ -160,7 +160,13 @@ PL_U::PL_U(Core::System& system) {3, &PL_U::GetSharedMemoryAddressOffset, "GetSharedMemoryAddressOffset"}, {4, &PL_U::GetSharedMemoryNativeHandle, "GetSharedMemoryNativeHandle"}, {5, &PL_U::GetSharedFontInOrderOfPriority, "GetSharedFontInOrderOfPriority"}, + {6, nullptr, "GetSharedFontInOrderOfPriorityForSystem"}, + {100, nullptr, "RequestApplicationFunctionAuthorization"}, + {101, nullptr, "RequestApplicationFunctionAuthorizationForSystem"}, + {1000, nullptr, "LoadNgWordDataForPlatformRegionChina"}, + {1001, nullptr, "GetNgWordDataSizeForPlatformRegionChina"}, }; + // clang-format on RegisterHandlers(functions); auto& fsc = system.GetFileSystemController(); diff --git a/src/core/hle/service/nvdrv/devices/nvhost_gpu.cpp b/src/core/hle/service/nvdrv/devices/nvhost_gpu.cpp index 9de0ace22..6d8bca8bb 100644 --- a/src/core/hle/service/nvdrv/devices/nvhost_gpu.cpp +++ b/src/core/hle/service/nvdrv/devices/nvhost_gpu.cpp @@ -191,8 +191,8 @@ u32 nvhost_gpu::KickoffPB(const std::vector<u8>& input, std::vector<u8>& output, std::memcpy(entries.data(), input2.data(), params.num_entries * sizeof(Tegra::CommandListHeader)); } else { - Memory::ReadBlock(params.address, entries.data(), - params.num_entries * sizeof(Tegra::CommandListHeader)); + system.Memory().ReadBlock(params.address, entries.data(), + params.num_entries * sizeof(Tegra::CommandListHeader)); } UNIMPLEMENTED_IF(params.flags.add_wait.Value() != 0); UNIMPLEMENTED_IF(params.flags.add_increment.Value() != 0); diff --git a/src/core/hle/service/nvdrv/interface.cpp b/src/core/hle/service/nvdrv/interface.cpp index 68d139cfb..c8ea6c661 100644 --- a/src/core/hle/service/nvdrv/interface.cpp +++ b/src/core/hle/service/nvdrv/interface.cpp @@ -61,7 +61,7 @@ void NVDRV::IoctlBase(Kernel::HLERequestContext& ctx, IoctlVersion version) { if (ctrl.must_delay) { ctrl.fresh_call = false; ctx.SleepClientThread("NVServices::DelayedResponse", ctrl.timeout, - [=](Kernel::SharedPtr<Kernel::Thread> thread, + [=](std::shared_ptr<Kernel::Thread> thread, Kernel::HLERequestContext& ctx, Kernel::ThreadWakeupReason reason) { IoctlCtrl ctrl2{ctrl}; diff --git a/src/core/hle/service/nvdrv/nvdrv.cpp b/src/core/hle/service/nvdrv/nvdrv.cpp index cc9cd3fd1..197c77db0 100644 --- a/src/core/hle/service/nvdrv/nvdrv.cpp +++ b/src/core/hle/service/nvdrv/nvdrv.cpp @@ -100,11 +100,11 @@ void Module::SignalSyncpt(const u32 syncpoint_id, const u32 value) { } } -Kernel::SharedPtr<Kernel::ReadableEvent> Module::GetEvent(const u32 event_id) const { +std::shared_ptr<Kernel::ReadableEvent> Module::GetEvent(const u32 event_id) const { return events_interface.events[event_id].readable; } -Kernel::SharedPtr<Kernel::WritableEvent> Module::GetEventWriteable(const u32 event_id) const { +std::shared_ptr<Kernel::WritableEvent> Module::GetEventWriteable(const u32 event_id) const { return events_interface.events[event_id].writable; } diff --git a/src/core/hle/service/nvdrv/nvdrv.h b/src/core/hle/service/nvdrv/nvdrv.h index f8bb28969..d7a1bef91 100644 --- a/src/core/hle/service/nvdrv/nvdrv.h +++ b/src/core/hle/service/nvdrv/nvdrv.h @@ -114,9 +114,9 @@ public: void SignalSyncpt(const u32 syncpoint_id, const u32 value); - Kernel::SharedPtr<Kernel::ReadableEvent> GetEvent(u32 event_id) const; + std::shared_ptr<Kernel::ReadableEvent> GetEvent(u32 event_id) const; - Kernel::SharedPtr<Kernel::WritableEvent> GetEventWriteable(u32 event_id) const; + std::shared_ptr<Kernel::WritableEvent> GetEventWriteable(u32 event_id) const; private: /// Id to use for the next open file descriptor. diff --git a/src/core/hle/service/nvflinger/buffer_queue.cpp b/src/core/hle/service/nvflinger/buffer_queue.cpp index 1af11e80c..32b6f4b27 100644 --- a/src/core/hle/service/nvflinger/buffer_queue.cpp +++ b/src/core/hle/service/nvflinger/buffer_queue.cpp @@ -117,11 +117,11 @@ u32 BufferQueue::Query(QueryType type) { return 0; } -Kernel::SharedPtr<Kernel::WritableEvent> BufferQueue::GetWritableBufferWaitEvent() const { +std::shared_ptr<Kernel::WritableEvent> BufferQueue::GetWritableBufferWaitEvent() const { return buffer_wait_event.writable; } -Kernel::SharedPtr<Kernel::ReadableEvent> BufferQueue::GetBufferWaitEvent() const { +std::shared_ptr<Kernel::ReadableEvent> BufferQueue::GetBufferWaitEvent() const { return buffer_wait_event.readable; } diff --git a/src/core/hle/service/nvflinger/buffer_queue.h b/src/core/hle/service/nvflinger/buffer_queue.h index 8f9b18547..f4bbfd945 100644 --- a/src/core/hle/service/nvflinger/buffer_queue.h +++ b/src/core/hle/service/nvflinger/buffer_queue.h @@ -93,9 +93,9 @@ public: return id; } - Kernel::SharedPtr<Kernel::WritableEvent> GetWritableBufferWaitEvent() const; + std::shared_ptr<Kernel::WritableEvent> GetWritableBufferWaitEvent() const; - Kernel::SharedPtr<Kernel::ReadableEvent> GetBufferWaitEvent() const; + std::shared_ptr<Kernel::ReadableEvent> GetBufferWaitEvent() const; private: u32 id; diff --git a/src/core/hle/service/nvflinger/nvflinger.cpp b/src/core/hle/service/nvflinger/nvflinger.cpp index cc9522aad..52623cf89 100644 --- a/src/core/hle/service/nvflinger/nvflinger.cpp +++ b/src/core/hle/service/nvflinger/nvflinger.cpp @@ -37,8 +37,8 @@ NVFlinger::NVFlinger(Core::System& system) : system(system) { displays.emplace_back(4, "Null", system); // Schedule the screen composition events - composition_event = system.CoreTiming().RegisterEvent( - "ScreenComposition", [this](u64 userdata, s64 cycles_late) { + composition_event = + Core::Timing::CreateEvent("ScreenComposition", [this](u64 userdata, s64 cycles_late) { Compose(); const auto ticks = Settings::values.force_30fps_mode ? frame_ticks_30fps : GetNextTicks(); @@ -98,7 +98,7 @@ std::optional<u32> NVFlinger::FindBufferQueueId(u64 display_id, u64 layer_id) co return layer->GetBufferQueue().GetId(); } -Kernel::SharedPtr<Kernel::ReadableEvent> NVFlinger::FindVsyncEvent(u64 display_id) const { +std::shared_ptr<Kernel::ReadableEvent> NVFlinger::FindVsyncEvent(u64 display_id) const { auto* const display = FindDisplay(display_id); if (display == nullptr) { diff --git a/src/core/hle/service/nvflinger/nvflinger.h b/src/core/hle/service/nvflinger/nvflinger.h index 5d7e3bfb8..e3cc14bdc 100644 --- a/src/core/hle/service/nvflinger/nvflinger.h +++ b/src/core/hle/service/nvflinger/nvflinger.h @@ -62,7 +62,7 @@ public: /// Gets the vsync event for the specified display. /// /// If an invalid display ID is provided, then nullptr is returned. - Kernel::SharedPtr<Kernel::ReadableEvent> FindVsyncEvent(u64 display_id) const; + std::shared_ptr<Kernel::ReadableEvent> FindVsyncEvent(u64 display_id) const; /// Obtains a buffer queue identified by the ID. BufferQueue& FindBufferQueue(u32 id); @@ -103,7 +103,7 @@ private: u32 swap_interval = 1; /// Event that handles screen composition. - Core::Timing::EventType* composition_event; + std::shared_ptr<Core::Timing::EventType> composition_event; Core::System& system; }; diff --git a/src/core/hle/service/pm/pm.cpp b/src/core/hle/service/pm/pm.cpp index fe6b5f798..809eca0ab 100644 --- a/src/core/hle/service/pm/pm.cpp +++ b/src/core/hle/service/pm/pm.cpp @@ -16,9 +16,9 @@ constexpr ResultCode ERROR_PROCESS_NOT_FOUND{ErrorModule::PM, 1}; constexpr u64 NO_PROCESS_FOUND_PID{0}; -std::optional<Kernel::SharedPtr<Kernel::Process>> SearchProcessList( - const std::vector<Kernel::SharedPtr<Kernel::Process>>& process_list, - std::function<bool(const Kernel::SharedPtr<Kernel::Process>&)> predicate) { +std::optional<std::shared_ptr<Kernel::Process>> SearchProcessList( + const std::vector<std::shared_ptr<Kernel::Process>>& process_list, + std::function<bool(const std::shared_ptr<Kernel::Process>&)> predicate) { const auto iter = std::find_if(process_list.begin(), process_list.end(), predicate); if (iter == process_list.end()) { @@ -29,7 +29,7 @@ std::optional<Kernel::SharedPtr<Kernel::Process>> SearchProcessList( } void GetApplicationPidGeneric(Kernel::HLERequestContext& ctx, - const std::vector<Kernel::SharedPtr<Kernel::Process>>& process_list) { + const std::vector<std::shared_ptr<Kernel::Process>>& process_list) { const auto process = SearchProcessList(process_list, [](const auto& process) { return process->GetProcessID() == Kernel::Process::ProcessIDMin; }); @@ -124,7 +124,7 @@ private: class Info final : public ServiceFramework<Info> { public: - explicit Info(const std::vector<Kernel::SharedPtr<Kernel::Process>>& process_list) + explicit Info(const std::vector<std::shared_ptr<Kernel::Process>>& process_list) : ServiceFramework{"pm:info"}, process_list(process_list) { static const FunctionInfo functions[] = { {0, &Info::GetTitleId, "GetTitleId"}, @@ -154,7 +154,7 @@ private: rb.Push((*process)->GetTitleID()); } - const std::vector<Kernel::SharedPtr<Kernel::Process>>& process_list; + const std::vector<std::shared_ptr<Kernel::Process>>& process_list; }; class Shell final : public ServiceFramework<Shell> { @@ -172,7 +172,7 @@ public: {6, &Shell::GetApplicationPid, "GetApplicationPid"}, {7, nullptr, "BoostSystemMemoryResourceLimit"}, {8, nullptr, "EnableAdditionalSystemThreads"}, - {9, nullptr, "GetUnimplementedEventHandle"}, + {9, nullptr, "GetBootFinishedEventHandle"}, }; // clang-format on diff --git a/src/core/hle/service/prepo/prepo.cpp b/src/core/hle/service/prepo/prepo.cpp index 18d895263..5eb26caf8 100644 --- a/src/core/hle/service/prepo/prepo.cpp +++ b/src/core/hle/service/prepo/prepo.cpp @@ -25,6 +25,7 @@ public: {10103, &PlayReport::SaveReportWithUser<Core::Reporter::PlayReportType::New>, "SaveReportWithUser"}, {10200, nullptr, "RequestImmediateTransmission"}, {10300, nullptr, "GetTransmissionStatus"}, + {10400, nullptr, "GetSystemSessionId"}, {20100, &PlayReport::SaveSystemReport, "SaveSystemReport"}, {20101, &PlayReport::SaveSystemReportWithUser, "SaveSystemReportWithUser"}, {20200, nullptr, "SetOperationMode"}, diff --git a/src/core/hle/service/service.cpp b/src/core/hle/service/service.cpp index 7c5302017..fa5347af9 100644 --- a/src/core/hle/service/service.cpp +++ b/src/core/hle/service/service.cpp @@ -116,7 +116,7 @@ void ServiceFrameworkBase::InstallAsNamedPort() { port_installed = true; } -Kernel::SharedPtr<Kernel::ClientPort> ServiceFrameworkBase::CreatePort() { +std::shared_ptr<Kernel::ClientPort> ServiceFrameworkBase::CreatePort() { ASSERT(!port_installed); auto& kernel = Core::System::GetInstance().Kernel(); @@ -186,7 +186,7 @@ ResultCode ServiceFrameworkBase::HandleSyncRequest(Kernel::HLERequestContext& co UNIMPLEMENTED_MSG("command_type={}", static_cast<int>(context.GetCommandType())); } - context.WriteToOutgoingCommandBuffer(*Kernel::GetCurrentThread()); + context.WriteToOutgoingCommandBuffer(context.GetThread()); return RESULT_SUCCESS; } @@ -201,7 +201,7 @@ void Init(std::shared_ptr<SM::ServiceManager>& sm, Core::System& system) { auto nv_flinger = std::make_shared<NVFlinger::NVFlinger>(system); system.GetFileSystemController().CreateFactories(*system.GetFilesystem(), false); - SM::ServiceManager::InstallInterfaces(sm); + SM::ServiceManager::InstallInterfaces(sm, system.Kernel()); Account::InstallInterfaces(system); AM::InstallInterfaces(*sm, nv_flinger, system); diff --git a/src/core/hle/service/service.h b/src/core/hle/service/service.h index aef964861..022d885b6 100644 --- a/src/core/hle/service/service.h +++ b/src/core/hle/service/service.h @@ -65,7 +65,7 @@ public: /// Creates a port pair and registers it on the kernel's global port registry. void InstallAsNamedPort(); /// Creates and returns an unregistered port for the service. - Kernel::SharedPtr<Kernel::ClientPort> CreatePort(); + std::shared_ptr<Kernel::ClientPort> CreatePort(); void InvokeRequest(Kernel::HLERequestContext& ctx); diff --git a/src/core/hle/service/set/set.cpp b/src/core/hle/service/set/set.cpp index b54214421..5bcc0b588 100644 --- a/src/core/hle/service/set/set.cpp +++ b/src/core/hle/service/set/set.cpp @@ -124,6 +124,7 @@ SET::SET() : ServiceFramework("set") { {7, nullptr, "GetKeyCodeMap"}, {8, &SET::GetQuestFlag, "GetQuestFlag"}, {9, nullptr, "GetKeyCodeMap2"}, + {10, nullptr, "GetFirmwareVersionForDebug"}, }; // clang-format on diff --git a/src/core/hle/service/set/set_cal.cpp b/src/core/hle/service/set/set_cal.cpp index 5981c575c..1398a4a48 100644 --- a/src/core/hle/service/set/set_cal.cpp +++ b/src/core/hle/service/set/set_cal.cpp @@ -7,6 +7,7 @@ namespace Service::Set { SET_CAL::SET_CAL() : ServiceFramework("set:cal") { + // clang-format off static const FunctionInfo functions[] = { {0, nullptr, "GetBluetoothBdAddress"}, {1, nullptr, "GetConfigurationId1"}, @@ -40,8 +41,18 @@ SET_CAL::SET_CAL() : ServiceFramework("set:cal") { {30, nullptr, "GetAmiiboEcqvBlsCertificate"}, {31, nullptr, "GetAmiiboEcqvBlsRootCertificate"}, {32, nullptr, "GetUsbTypeCPowerSourceCircuitVersion"}, + {33, nullptr, "GetAnalogStickModuleTypeL"}, + {34, nullptr, "GetAnalogStickModelParameterL"}, + {35, nullptr, "GetAnalogStickFactoryCalibrationL"}, + {36, nullptr, "GetAnalogStickModuleTypeR"}, + {37, nullptr, "GetAnalogStickModelParameterR"}, + {38, nullptr, "GetAnalogStickFactoryCalibrationR"}, + {39, nullptr, "GetConsoleSixAxisSensorModuleType"}, + {40, nullptr, "GetConsoleSixAxisSensorHorizontalOffset"}, {41, nullptr, "GetBatteryVersion"}, }; + // clang-format on + RegisterHandlers(functions); } diff --git a/src/core/hle/service/set/set_fd.cpp b/src/core/hle/service/set/set_fd.cpp index cac6af86d..565882a31 100644 --- a/src/core/hle/service/set/set_fd.cpp +++ b/src/core/hle/service/set/set_fd.cpp @@ -7,6 +7,7 @@ namespace Service::Set { SET_FD::SET_FD() : ServiceFramework("set:fd") { + // clang-format off static const FunctionInfo functions[] = { {2, nullptr, "SetSettingsItemValue"}, {3, nullptr, "ResetSettingsItemValue"}, @@ -16,7 +17,10 @@ SET_FD::SET_FD() : ServiceFramework("set:fd") { {20, nullptr, "SetWebInspectorFlag"}, {21, nullptr, "SetAllowedSslHosts"}, {22, nullptr, "SetHostFsMountPoint"}, + {23, nullptr, "SetMemoryUsageRateFlag"}, }; + // clang-format on + RegisterHandlers(functions); } diff --git a/src/core/hle/service/set/set_sys.cpp b/src/core/hle/service/set/set_sys.cpp index 98d0cfdfd..b7c9ea74b 100644 --- a/src/core/hle/service/set/set_sys.cpp +++ b/src/core/hle/service/set/set_sys.cpp @@ -273,10 +273,21 @@ SET_SYS::SET_SYS() : ServiceFramework("set:sys") { {171, nullptr, "SetChineseTraditionalInputMethod"}, {172, nullptr, "GetPtmCycleCountReliability"}, {173, nullptr, "SetPtmCycleCountReliability"}, + {174, nullptr, "GetHomeMenuScheme"}, {175, nullptr, "GetThemeSettings"}, {176, nullptr, "SetThemeSettings"}, {177, nullptr, "GetThemeKey"}, {178, nullptr, "SetThemeKey"}, + {179, nullptr, "GetZoomFlag"}, + {180, nullptr, "SetZoomFlag"}, + {181, nullptr, "GetT"}, + {182, nullptr, "SetT"}, + {183, nullptr, "GetPlatformRegion"}, + {184, nullptr, "SetPlatformRegion"}, + {185, nullptr, "GetHomeMenuSchemeModel"}, + {186, nullptr, "GetMemoryUsageRateFlag"}, + {187, nullptr, "GetTouchScreenMode"}, + {188, nullptr, "SetTouchScreenMode"}, }; // clang-format on diff --git a/src/core/hle/service/sm/controller.cpp b/src/core/hle/service/sm/controller.cpp index e9ee73710..c45b285f8 100644 --- a/src/core/hle/service/sm/controller.cpp +++ b/src/core/hle/service/sm/controller.cpp @@ -30,10 +30,7 @@ void Controller::DuplicateSession(Kernel::HLERequestContext& ctx) { IPC::ResponseBuilder rb{ctx, 2, 0, 1, IPC::ResponseBuilder::Flags::AlwaysMoveHandles}; rb.Push(RESULT_SUCCESS); - Kernel::SharedPtr<Kernel::ClientSession> session{ctx.Session()->GetParent()->client}; - rb.PushMoveObjects(session); - - LOG_DEBUG(Service, "session={}", session->GetObjectId()); + rb.PushMoveObjects(ctx.Session()->GetParent()->Client()); } void Controller::DuplicateSessionEx(Kernel::HLERequestContext& ctx) { diff --git a/src/core/hle/service/sm/sm.cpp b/src/core/hle/service/sm/sm.cpp index 142929124..88909504d 100644 --- a/src/core/hle/service/sm/sm.cpp +++ b/src/core/hle/service/sm/sm.cpp @@ -36,16 +36,17 @@ static ResultCode ValidateServiceName(const std::string& name) { return RESULT_SUCCESS; } -void ServiceManager::InstallInterfaces(std::shared_ptr<ServiceManager> self) { +void ServiceManager::InstallInterfaces(std::shared_ptr<ServiceManager> self, + Kernel::KernelCore& kernel) { ASSERT(self->sm_interface.expired()); - auto sm = std::make_shared<SM>(self); + auto sm = std::make_shared<SM>(self, kernel); sm->InstallAsNamedPort(); self->sm_interface = sm; self->controller_interface = std::make_unique<Controller>(); } -ResultVal<Kernel::SharedPtr<Kernel::ServerPort>> ServiceManager::RegisterService( +ResultVal<std::shared_ptr<Kernel::ServerPort>> ServiceManager::RegisterService( std::string name, unsigned int max_sessions) { CASCADE_CODE(ValidateServiceName(name)); @@ -72,7 +73,7 @@ ResultCode ServiceManager::UnregisterService(const std::string& name) { return RESULT_SUCCESS; } -ResultVal<Kernel::SharedPtr<Kernel::ClientPort>> ServiceManager::GetServicePort( +ResultVal<std::shared_ptr<Kernel::ClientPort>> ServiceManager::GetServicePort( const std::string& name) { CASCADE_CODE(ValidateServiceName(name)); @@ -84,7 +85,7 @@ ResultVal<Kernel::SharedPtr<Kernel::ClientPort>> ServiceManager::GetServicePort( return MakeResult(it->second); } -ResultVal<Kernel::SharedPtr<Kernel::ClientSession>> ServiceManager::ConnectToService( +ResultVal<std::shared_ptr<Kernel::ClientSession>> ServiceManager::ConnectToService( const std::string& name) { CASCADE_RESULT(auto client_port, GetServicePort(name)); @@ -114,8 +115,6 @@ void SM::GetService(Kernel::HLERequestContext& ctx) { std::string name(name_buf.begin(), end); - // TODO(yuriks): Permission checks go here - auto client_port = service_manager->GetServicePort(name); if (client_port.Failed()) { IPC::ResponseBuilder rb{ctx, 2}; @@ -127,14 +126,22 @@ void SM::GetService(Kernel::HLERequestContext& ctx) { return; } - auto session = client_port.Unwrap()->Connect(); - ASSERT(session.Succeeded()); - if (session.Succeeded()) { - LOG_DEBUG(Service_SM, "called service={} -> session={}", name, (*session)->GetObjectId()); - IPC::ResponseBuilder rb{ctx, 2, 0, 1, IPC::ResponseBuilder::Flags::AlwaysMoveHandles}; - rb.Push(session.Code()); - rb.PushMoveObjects(std::move(session).Unwrap()); + auto [client, server] = Kernel::Session::Create(kernel, name); + + const auto& server_port = client_port.Unwrap()->GetServerPort(); + if (server_port->GetHLEHandler()) { + server_port->GetHLEHandler()->ClientConnected(server); + } else { + server_port->AppendPendingSession(server); } + + // Wake the threads waiting on the ServerPort + server_port->WakeupAllWaitingThreads(); + + LOG_DEBUG(Service_SM, "called service={} -> session={}", name, client->GetObjectId()); + IPC::ResponseBuilder rb{ctx, 2, 0, 1, IPC::ResponseBuilder::Flags::AlwaysMoveHandles}; + rb.Push(RESULT_SUCCESS); + rb.PushMoveObjects(std::move(client)); } void SM::RegisterService(Kernel::HLERequestContext& ctx) { @@ -178,8 +185,8 @@ void SM::UnregisterService(Kernel::HLERequestContext& ctx) { rb.Push(service_manager->UnregisterService(name)); } -SM::SM(std::shared_ptr<ServiceManager> service_manager) - : ServiceFramework("sm:", 4), service_manager(std::move(service_manager)) { +SM::SM(std::shared_ptr<ServiceManager> service_manager, Kernel::KernelCore& kernel) + : ServiceFramework{"sm:", 4}, service_manager{std::move(service_manager)}, kernel{kernel} { static const FunctionInfo functions[] = { {0x00000000, &SM::Initialize, "Initialize"}, {0x00000001, &SM::GetService, "GetService"}, diff --git a/src/core/hle/service/sm/sm.h b/src/core/hle/service/sm/sm.h index b9d6381b4..b06d2f103 100644 --- a/src/core/hle/service/sm/sm.h +++ b/src/core/hle/service/sm/sm.h @@ -18,6 +18,7 @@ namespace Kernel { class ClientPort; class ClientSession; +class KernelCore; class ServerPort; class SessionRequestHandler; } // namespace Kernel @@ -29,7 +30,7 @@ class Controller; /// Interface to "sm:" service class SM final : public ServiceFramework<SM> { public: - explicit SM(std::shared_ptr<ServiceManager> service_manager); + explicit SM(std::shared_ptr<ServiceManager> service_manager, Kernel::KernelCore& kernel); ~SM() override; private: @@ -39,20 +40,21 @@ private: void UnregisterService(Kernel::HLERequestContext& ctx); std::shared_ptr<ServiceManager> service_manager; + Kernel::KernelCore& kernel; }; class ServiceManager { public: - static void InstallInterfaces(std::shared_ptr<ServiceManager> self); + static void InstallInterfaces(std::shared_ptr<ServiceManager> self, Kernel::KernelCore& kernel); ServiceManager(); ~ServiceManager(); - ResultVal<Kernel::SharedPtr<Kernel::ServerPort>> RegisterService(std::string name, - unsigned int max_sessions); + ResultVal<std::shared_ptr<Kernel::ServerPort>> RegisterService(std::string name, + unsigned int max_sessions); ResultCode UnregisterService(const std::string& name); - ResultVal<Kernel::SharedPtr<Kernel::ClientPort>> GetServicePort(const std::string& name); - ResultVal<Kernel::SharedPtr<Kernel::ClientSession>> ConnectToService(const std::string& name); + ResultVal<std::shared_ptr<Kernel::ClientPort>> GetServicePort(const std::string& name); + ResultVal<std::shared_ptr<Kernel::ClientSession>> ConnectToService(const std::string& name); template <typename T> std::shared_ptr<T> GetService(const std::string& service_name) const { @@ -77,7 +79,7 @@ private: std::unique_ptr<Controller> controller_interface; /// Map of registered services, retrieved using GetServicePort or ConnectToService. - std::unordered_map<std::string, Kernel::SharedPtr<Kernel::ClientPort>> registered_services; + std::unordered_map<std::string, std::shared_ptr<Kernel::ClientPort>> registered_services; }; } // namespace Service::SM diff --git a/src/core/hle/service/sockets/nsd.cpp b/src/core/hle/service/sockets/nsd.cpp index e6d73065e..dc70fd6fe 100644 --- a/src/core/hle/service/sockets/nsd.cpp +++ b/src/core/hle/service/sockets/nsd.cpp @@ -7,6 +7,7 @@ namespace Service::Sockets { NSD::NSD(const char* name) : ServiceFramework(name) { + // clang-format off static const FunctionInfo functions[] = { {10, nullptr, "GetSettingName"}, {11, nullptr, "GetEnvironmentIdentifier"}, @@ -22,10 +23,14 @@ NSD::NSD(const char* name) : ServiceFramework(name) { {42, nullptr, "GetNasApiFqdn"}, {43, nullptr, "GetNasApiFqdnEx"}, {50, nullptr, "GetCurrentSetting"}, + {51, nullptr, "WriteTestParameter"}, + {52, nullptr, "ReadTestParameter"}, {60, nullptr, "ReadSaveDataFromFsForTest"}, {61, nullptr, "WriteSaveDataToFsForTest"}, {62, nullptr, "DeleteSaveDataOfFsForTest"}, }; + // clang-format on + RegisterHandlers(functions); } diff --git a/src/core/hle/service/ssl/ssl.cpp b/src/core/hle/service/ssl/ssl.cpp index 65040c077..1ba8c19a0 100644 --- a/src/core/hle/service/ssl/ssl.cpp +++ b/src/core/hle/service/ssl/ssl.cpp @@ -13,6 +13,7 @@ namespace Service::SSL { class ISslConnection final : public ServiceFramework<ISslConnection> { public: ISslConnection() : ServiceFramework("ISslConnection") { + // clang-format off static const FunctionInfo functions[] = { {0, nullptr, "SetSocketDescriptor"}, {1, nullptr, "SetHostName"}, @@ -40,7 +41,11 @@ public: {23, nullptr, "GetOption"}, {24, nullptr, "GetVerifyCertErrors"}, {25, nullptr, "GetCipherInfo"}, + {26, nullptr, "SetNextAlpnProto"}, + {27, nullptr, "GetNextAlpnProto"}, }; + // clang-format on + RegisterHandlers(functions); } }; diff --git a/src/core/hle/service/time/interface.cpp b/src/core/hle/service/time/interface.cpp index 9565e7de5..bc74f1e1d 100644 --- a/src/core/hle/service/time/interface.cpp +++ b/src/core/hle/service/time/interface.cpp @@ -21,6 +21,7 @@ Time::Time(std::shared_ptr<Module> time, std::shared_ptr<SharedMemory> shared_me {30, nullptr, "GetStandardNetworkClockOperationEventReadableHandle"}, {31, nullptr, "GetEphemeralNetworkClockOperationEventReadableHandle"}, {50, nullptr, "SetStandardSteadyClockInternalOffset"}, + {51, nullptr, "GetStandardSteadyClockRtcValue"}, {100, &Time::IsStandardUserSystemClockAutomaticCorrectionEnabled, "IsStandardUserSystemClockAutomaticCorrectionEnabled"}, {101, &Time::SetStandardUserSystemClockAutomaticCorrectionEnabled, "SetStandardUserSystemClockAutomaticCorrectionEnabled"}, {102, nullptr, "GetStandardUserSystemClockInitialYear"}, diff --git a/src/core/hle/service/time/time.cpp b/src/core/hle/service/time/time.cpp index 1b9ab8401..6ee77c5f9 100644 --- a/src/core/hle/service/time/time.cpp +++ b/src/core/hle/service/time/time.cpp @@ -34,12 +34,12 @@ static void PosixToCalendar(u64 posix_time, CalendarTime& calendar_time, additional_info = {}; return; } - calendar_time.year = tm->tm_year + 1900; - calendar_time.month = tm->tm_mon + 1; - calendar_time.day = tm->tm_mday; - calendar_time.hour = tm->tm_hour; - calendar_time.minute = tm->tm_min; - calendar_time.second = tm->tm_sec; + calendar_time.year = static_cast<u16_le>(tm->tm_year + 1900); + calendar_time.month = static_cast<u8>(tm->tm_mon + 1); + calendar_time.day = static_cast<u8>(tm->tm_mday); + calendar_time.hour = static_cast<u8>(tm->tm_hour); + calendar_time.minute = static_cast<u8>(tm->tm_min); + calendar_time.second = static_cast<u8>(tm->tm_sec); additional_info.day_of_week = tm->tm_wday; additional_info.day_of_year = tm->tm_yday; @@ -74,15 +74,17 @@ public: ISystemClock(std::shared_ptr<Service::Time::SharedMemory> shared_memory, ClockContextType clock_type) : ServiceFramework("ISystemClock"), shared_memory(shared_memory), clock_type(clock_type) { + // clang-format off static const FunctionInfo functions[] = { {0, &ISystemClock::GetCurrentTime, "GetCurrentTime"}, {1, nullptr, "SetCurrentTime"}, {2, &ISystemClock::GetSystemClockContext, "GetSystemClockContext"}, {3, nullptr, "SetSystemClockContext"}, - + {4, nullptr, "GetOperationEventReadableHandle"}, }; - RegisterHandlers(functions); + // clang-format on + RegisterHandlers(functions); UpdateSharedMemoryContext(system_clock_context); } @@ -162,6 +164,7 @@ private: class ITimeZoneService final : public ServiceFramework<ITimeZoneService> { public: ITimeZoneService() : ServiceFramework("ITimeZoneService") { + // clang-format off static const FunctionInfo functions[] = { {0, &ITimeZoneService::GetDeviceLocationName, "GetDeviceLocationName"}, {1, nullptr, "SetDeviceLocationName"}, @@ -169,11 +172,17 @@ public: {3, nullptr, "LoadLocationNameList"}, {4, &ITimeZoneService::LoadTimeZoneRule, "LoadTimeZoneRule"}, {5, nullptr, "GetTimeZoneRuleVersion"}, + {6, nullptr, "GetDeviceLocationNameAndUpdatedTime"}, + {7, nullptr, "SetDeviceLocationNameWithTimeZoneRule"}, + {8, nullptr, "ParseTimeZoneBinary"}, + {20, nullptr, "GetDeviceLocationNameOperationEventReadableHandle"}, {100, &ITimeZoneService::ToCalendarTime, "ToCalendarTime"}, {101, &ITimeZoneService::ToCalendarTimeWithMyRule, "ToCalendarTimeWithMyRule"}, {201, &ITimeZoneService::ToPosixTime, "ToPosixTime"}, {202, &ITimeZoneService::ToPosixTimeWithMyRule, "ToPosixTimeWithMyRule"}, }; + // clang-format on + RegisterHandlers(functions); } @@ -322,7 +331,7 @@ void Module::Interface::GetClockSnapshot(Kernel::HLERequestContext& ctx) { if (tm == nullptr) { LOG_ERROR(Service_Time, "tm is a nullptr"); IPC::ResponseBuilder rb{ctx, 2}; - rb.Push(ResultCode(-1)); // TODO(ogniK): Find appropriate error code + rb.Push(RESULT_UNKNOWN); // TODO(ogniK): Find appropriate error code return; } @@ -331,12 +340,12 @@ void Module::Interface::GetClockSnapshot(Kernel::HLERequestContext& ctx) { const SteadyClockTimePoint steady_clock_time_point{static_cast<u64_le>(ms.count() / 1000), {}}; CalendarTime calendar_time{}; - calendar_time.year = tm->tm_year + 1900; - calendar_time.month = tm->tm_mon + 1; - calendar_time.day = tm->tm_mday; - calendar_time.hour = tm->tm_hour; - calendar_time.minute = tm->tm_min; - calendar_time.second = tm->tm_sec; + calendar_time.year = static_cast<u16_le>(tm->tm_year + 1900); + calendar_time.month = static_cast<u8>(tm->tm_mon + 1); + calendar_time.day = static_cast<u8>(tm->tm_mday); + calendar_time.hour = static_cast<u8>(tm->tm_hour); + calendar_time.minute = static_cast<u8>(tm->tm_min); + calendar_time.second = static_cast<u8>(tm->tm_sec); ClockSnapshot clock_snapshot{}; clock_snapshot.system_posix_time = time_since_epoch; diff --git a/src/core/hle/service/time/time_sharedmemory.cpp b/src/core/hle/service/time/time_sharedmemory.cpp index bfc81b83c..4035f5072 100644 --- a/src/core/hle/service/time/time_sharedmemory.cpp +++ b/src/core/hle/service/time/time_sharedmemory.cpp @@ -21,7 +21,7 @@ SharedMemory::SharedMemory(Core::System& system) : system(system) { SharedMemory::~SharedMemory() = default; -Kernel::SharedPtr<Kernel::SharedMemory> SharedMemory::GetSharedMemoryHolder() const { +std::shared_ptr<Kernel::SharedMemory> SharedMemory::GetSharedMemoryHolder() const { return shared_memory_holder; } diff --git a/src/core/hle/service/time/time_sharedmemory.h b/src/core/hle/service/time/time_sharedmemory.h index cb8253541..904a96430 100644 --- a/src/core/hle/service/time/time_sharedmemory.h +++ b/src/core/hle/service/time/time_sharedmemory.h @@ -15,7 +15,7 @@ public: ~SharedMemory(); // Return the shared memory handle - Kernel::SharedPtr<Kernel::SharedMemory> GetSharedMemoryHolder() const; + std::shared_ptr<Kernel::SharedMemory> GetSharedMemoryHolder() const; // Set memory barriers in shared memory and update them void SetStandardSteadyClockTimepoint(const SteadyClockTimePoint& timepoint); @@ -66,7 +66,7 @@ public: static_assert(sizeof(Format) == 0xd8, "Format is an invalid size"); private: - Kernel::SharedPtr<Kernel::SharedMemory> shared_memory_holder{}; + std::shared_ptr<Kernel::SharedMemory> shared_memory_holder{}; Core::System& system; Format shared_memory_format{}; }; diff --git a/src/core/hle/service/vi/display/vi_display.cpp b/src/core/hle/service/vi/display/vi_display.cpp index 07033fb98..cd18c1610 100644 --- a/src/core/hle/service/vi/display/vi_display.cpp +++ b/src/core/hle/service/vi/display/vi_display.cpp @@ -31,7 +31,7 @@ const Layer& Display::GetLayer(std::size_t index) const { return layers.at(index); } -Kernel::SharedPtr<Kernel::ReadableEvent> Display::GetVSyncEvent() const { +std::shared_ptr<Kernel::ReadableEvent> Display::GetVSyncEvent() const { return vsync_event.readable; } diff --git a/src/core/hle/service/vi/display/vi_display.h b/src/core/hle/service/vi/display/vi_display.h index f56b5badc..8bb966a85 100644 --- a/src/core/hle/service/vi/display/vi_display.h +++ b/src/core/hle/service/vi/display/vi_display.h @@ -57,7 +57,7 @@ public: const Layer& GetLayer(std::size_t index) const; /// Gets the readable vsync event. - Kernel::SharedPtr<Kernel::ReadableEvent> GetVSyncEvent() const; + std::shared_ptr<Kernel::ReadableEvent> GetVSyncEvent() const; /// Signals the internal vsync event. void SignalVSyncEvent(); diff --git a/src/core/hle/service/vi/vi.cpp b/src/core/hle/service/vi/vi.cpp index 611cecc20..651c89dc0 100644 --- a/src/core/hle/service/vi/vi.cpp +++ b/src/core/hle/service/vi/vi.cpp @@ -541,8 +541,8 @@ private: } else { // Wait the current thread until a buffer becomes available ctx.SleepClientThread( - "IHOSBinderDriver::DequeueBuffer", -1, - [=](Kernel::SharedPtr<Kernel::Thread> thread, Kernel::HLERequestContext& ctx, + "IHOSBinderDriver::DequeueBuffer", UINT64_MAX, + [=](std::shared_ptr<Kernel::Thread> thread, Kernel::HLERequestContext& ctx, Kernel::ThreadWakeupReason reason) { // Repeat TransactParcel DequeueBuffer when a buffer is available auto& buffer_queue = nv_flinger->FindBufferQueue(id); @@ -731,6 +731,7 @@ class IManagerDisplayService final : public ServiceFramework<IManagerDisplayServ public: explicit IManagerDisplayService(std::shared_ptr<NVFlinger::NVFlinger> nv_flinger) : ServiceFramework("IManagerDisplayService"), nv_flinger(std::move(nv_flinger)) { + // clang-format off static const FunctionInfo functions[] = { {200, nullptr, "AllocateProcessHeapBlock"}, {201, nullptr, "FreeProcessHeapBlock"}, @@ -766,8 +767,11 @@ public: {6008, nullptr, "StartLayerPresentationFenceWait"}, {6009, nullptr, "StopLayerPresentationFenceWait"}, {6010, nullptr, "GetLayerPresentationAllFencesExpiredEvent"}, + {6011, nullptr, "EnableLayerAutoClearTransitionBuffer"}, + {6012, nullptr, "DisableLayerAutoClearTransitionBuffer"}, {7000, nullptr, "SetContentVisibility"}, {8000, nullptr, "SetConductorLayer"}, + {8001, nullptr, "SetTimestampTracking"}, {8100, nullptr, "SetIndirectProducerFlipOffset"}, {8200, nullptr, "CreateSharedBufferStaticStorage"}, {8201, nullptr, "CreateSharedBufferTransferMemory"}, @@ -800,6 +804,8 @@ public: {8297, nullptr, "GetSharedFrameBufferContentParameter"}, {8298, nullptr, "ExpandStartupLogoOnSharedFrameBuffer"}, }; + // clang-format on + RegisterHandlers(functions); } diff --git a/src/core/memory.cpp b/src/core/memory.cpp index fa49f3dd0..91bf07a92 100644 --- a/src/core/memory.cpp +++ b/src/core/memory.cpp @@ -17,529 +17,699 @@ #include "core/hle/kernel/process.h" #include "core/hle/kernel/vm_manager.h" #include "core/memory.h" -#include "core/memory_setup.h" #include "video_core/gpu.h" namespace Memory { -static Common::PageTable* current_page_table = nullptr; +// Implementation class used to keep the specifics of the memory subsystem hidden +// from outside classes. This also allows modification to the internals of the memory +// subsystem without needing to rebuild all files that make use of the memory interface. +struct Memory::Impl { + explicit Impl(Core::System& system_) : system{system_} {} -void SetCurrentPageTable(Kernel::Process& process) { - current_page_table = &process.VMManager().page_table; + void SetCurrentPageTable(Kernel::Process& process) { + current_page_table = &process.VMManager().page_table; - const std::size_t address_space_width = process.VMManager().GetAddressSpaceWidth(); + const std::size_t address_space_width = process.VMManager().GetAddressSpaceWidth(); - auto& system = Core::System::GetInstance(); - system.ArmInterface(0).PageTableChanged(*current_page_table, address_space_width); - system.ArmInterface(1).PageTableChanged(*current_page_table, address_space_width); - system.ArmInterface(2).PageTableChanged(*current_page_table, address_space_width); - system.ArmInterface(3).PageTableChanged(*current_page_table, address_space_width); -} + system.ArmInterface(0).PageTableChanged(*current_page_table, address_space_width); + system.ArmInterface(1).PageTableChanged(*current_page_table, address_space_width); + system.ArmInterface(2).PageTableChanged(*current_page_table, address_space_width); + system.ArmInterface(3).PageTableChanged(*current_page_table, address_space_width); + } -static void MapPages(Common::PageTable& page_table, VAddr base, u64 size, u8* memory, - Common::PageType type) { - LOG_DEBUG(HW_Memory, "Mapping {} onto {:016X}-{:016X}", fmt::ptr(memory), base * PAGE_SIZE, - (base + size) * PAGE_SIZE); - - // During boot, current_page_table might not be set yet, in which case we need not flush - if (Core::System::GetInstance().IsPoweredOn()) { - auto& gpu = Core::System::GetInstance().GPU(); - for (u64 i = 0; i < size; i++) { - const auto page = base + i; - if (page_table.attributes[page] == Common::PageType::RasterizerCachedMemory) { - gpu.FlushAndInvalidateRegion(page << PAGE_BITS, PAGE_SIZE); - } - } + void MapMemoryRegion(Common::PageTable& page_table, VAddr base, u64 size, u8* target) { + ASSERT_MSG((size & PAGE_MASK) == 0, "non-page aligned size: {:016X}", size); + ASSERT_MSG((base & PAGE_MASK) == 0, "non-page aligned base: {:016X}", base); + MapPages(page_table, base / PAGE_SIZE, size / PAGE_SIZE, target, Common::PageType::Memory); } - VAddr end = base + size; - ASSERT_MSG(end <= page_table.pointers.size(), "out of range mapping at {:016X}", - base + page_table.pointers.size()); + void MapIoRegion(Common::PageTable& page_table, VAddr base, u64 size, + Common::MemoryHookPointer mmio_handler) { + ASSERT_MSG((size & PAGE_MASK) == 0, "non-page aligned size: {:016X}", size); + ASSERT_MSG((base & PAGE_MASK) == 0, "non-page aligned base: {:016X}", base); + MapPages(page_table, base / PAGE_SIZE, size / PAGE_SIZE, nullptr, + Common::PageType::Special); + + const auto interval = boost::icl::discrete_interval<VAddr>::closed(base, base + size - 1); + const Common::SpecialRegion region{Common::SpecialRegion::Type::IODevice, + std::move(mmio_handler)}; + page_table.special_regions.add( + std::make_pair(interval, std::set<Common::SpecialRegion>{region})); + } - std::fill(page_table.attributes.begin() + base, page_table.attributes.begin() + end, type); + void UnmapRegion(Common::PageTable& page_table, VAddr base, u64 size) { + ASSERT_MSG((size & PAGE_MASK) == 0, "non-page aligned size: {:016X}", size); + ASSERT_MSG((base & PAGE_MASK) == 0, "non-page aligned base: {:016X}", base); + MapPages(page_table, base / PAGE_SIZE, size / PAGE_SIZE, nullptr, + Common::PageType::Unmapped); - if (memory == nullptr) { - std::fill(page_table.pointers.begin() + base, page_table.pointers.begin() + end, memory); - } else { - while (base != end) { - page_table.pointers[base] = memory; + const auto interval = boost::icl::discrete_interval<VAddr>::closed(base, base + size - 1); + page_table.special_regions.erase(interval); + } - base += 1; - memory += PAGE_SIZE; - } + void AddDebugHook(Common::PageTable& page_table, VAddr base, u64 size, + Common::MemoryHookPointer hook) { + const auto interval = boost::icl::discrete_interval<VAddr>::closed(base, base + size - 1); + const Common::SpecialRegion region{Common::SpecialRegion::Type::DebugHook, std::move(hook)}; + page_table.special_regions.add( + std::make_pair(interval, std::set<Common::SpecialRegion>{region})); } -} -void MapMemoryRegion(Common::PageTable& page_table, VAddr base, u64 size, u8* target) { - ASSERT_MSG((size & PAGE_MASK) == 0, "non-page aligned size: {:016X}", size); - ASSERT_MSG((base & PAGE_MASK) == 0, "non-page aligned base: {:016X}", base); - MapPages(page_table, base / PAGE_SIZE, size / PAGE_SIZE, target, Common::PageType::Memory); -} + void RemoveDebugHook(Common::PageTable& page_table, VAddr base, u64 size, + Common::MemoryHookPointer hook) { + const auto interval = boost::icl::discrete_interval<VAddr>::closed(base, base + size - 1); + const Common::SpecialRegion region{Common::SpecialRegion::Type::DebugHook, std::move(hook)}; + page_table.special_regions.subtract( + std::make_pair(interval, std::set<Common::SpecialRegion>{region})); + } -void MapIoRegion(Common::PageTable& page_table, VAddr base, u64 size, - Common::MemoryHookPointer mmio_handler) { - ASSERT_MSG((size & PAGE_MASK) == 0, "non-page aligned size: {:016X}", size); - ASSERT_MSG((base & PAGE_MASK) == 0, "non-page aligned base: {:016X}", base); - MapPages(page_table, base / PAGE_SIZE, size / PAGE_SIZE, nullptr, Common::PageType::Special); + bool IsValidVirtualAddress(const Kernel::Process& process, const VAddr vaddr) const { + const auto& page_table = process.VMManager().page_table; - auto interval = boost::icl::discrete_interval<VAddr>::closed(base, base + size - 1); - Common::SpecialRegion region{Common::SpecialRegion::Type::IODevice, std::move(mmio_handler)}; - page_table.special_regions.add( - std::make_pair(interval, std::set<Common::SpecialRegion>{region})); -} + const u8* const page_pointer = page_table.pointers[vaddr >> PAGE_BITS]; + if (page_pointer != nullptr) { + return true; + } -void UnmapRegion(Common::PageTable& page_table, VAddr base, u64 size) { - ASSERT_MSG((size & PAGE_MASK) == 0, "non-page aligned size: {:016X}", size); - ASSERT_MSG((base & PAGE_MASK) == 0, "non-page aligned base: {:016X}", base); - MapPages(page_table, base / PAGE_SIZE, size / PAGE_SIZE, nullptr, Common::PageType::Unmapped); + if (page_table.attributes[vaddr >> PAGE_BITS] == Common::PageType::RasterizerCachedMemory) { + return true; + } - auto interval = boost::icl::discrete_interval<VAddr>::closed(base, base + size - 1); - page_table.special_regions.erase(interval); -} + if (page_table.attributes[vaddr >> PAGE_BITS] != Common::PageType::Special) { + return false; + } -void AddDebugHook(Common::PageTable& page_table, VAddr base, u64 size, - Common::MemoryHookPointer hook) { - auto interval = boost::icl::discrete_interval<VAddr>::closed(base, base + size - 1); - Common::SpecialRegion region{Common::SpecialRegion::Type::DebugHook, std::move(hook)}; - page_table.special_regions.add( - std::make_pair(interval, std::set<Common::SpecialRegion>{region})); -} + return false; + } -void RemoveDebugHook(Common::PageTable& page_table, VAddr base, u64 size, - Common::MemoryHookPointer hook) { - auto interval = boost::icl::discrete_interval<VAddr>::closed(base, base + size - 1); - Common::SpecialRegion region{Common::SpecialRegion::Type::DebugHook, std::move(hook)}; - page_table.special_regions.subtract( - std::make_pair(interval, std::set<Common::SpecialRegion>{region})); -} + bool IsValidVirtualAddress(VAddr vaddr) const { + return IsValidVirtualAddress(*system.CurrentProcess(), vaddr); + } -/** - * Gets a pointer to the exact memory at the virtual address (i.e. not page aligned) - * using a VMA from the current process - */ -static u8* GetPointerFromVMA(const Kernel::Process& process, VAddr vaddr) { - const auto& vm_manager = process.VMManager(); - - const auto it = vm_manager.FindVMA(vaddr); - DEBUG_ASSERT(vm_manager.IsValidHandle(it)); - - u8* direct_pointer = nullptr; - const auto& vma = it->second; - switch (vma.type) { - case Kernel::VMAType::AllocatedMemoryBlock: - direct_pointer = vma.backing_block->data() + vma.offset; - break; - case Kernel::VMAType::BackingMemory: - direct_pointer = vma.backing_memory; - break; - case Kernel::VMAType::Free: - return nullptr; - default: - UNREACHABLE(); + /** + * Gets a pointer to the exact memory at the virtual address (i.e. not page aligned) + * using a VMA from the current process + */ + u8* GetPointerFromVMA(const Kernel::Process& process, VAddr vaddr) { + const auto& vm_manager = process.VMManager(); + + const auto it = vm_manager.FindVMA(vaddr); + DEBUG_ASSERT(vm_manager.IsValidHandle(it)); + + u8* direct_pointer = nullptr; + const auto& vma = it->second; + switch (vma.type) { + case Kernel::VMAType::AllocatedMemoryBlock: + direct_pointer = vma.backing_block->data() + vma.offset; + break; + case Kernel::VMAType::BackingMemory: + direct_pointer = vma.backing_memory; + break; + case Kernel::VMAType::Free: + return nullptr; + default: + UNREACHABLE(); + } + + return direct_pointer + (vaddr - vma.base); } - return direct_pointer + (vaddr - vma.base); -} + /** + * Gets a pointer to the exact memory at the virtual address (i.e. not page aligned) + * using a VMA from the current process. + */ + u8* GetPointerFromVMA(VAddr vaddr) { + return GetPointerFromVMA(*system.CurrentProcess(), vaddr); + } -/** - * Gets a pointer to the exact memory at the virtual address (i.e. not page aligned) - * using a VMA from the current process. - */ -static u8* GetPointerFromVMA(VAddr vaddr) { - return GetPointerFromVMA(*Core::System::GetInstance().CurrentProcess(), vaddr); -} + u8* GetPointer(const VAddr vaddr) { + u8* const page_pointer = current_page_table->pointers[vaddr >> PAGE_BITS]; + if (page_pointer != nullptr) { + return page_pointer + (vaddr & PAGE_MASK); + } -template <typename T> -T Read(const VAddr vaddr) { - const u8* page_pointer = current_page_table->pointers[vaddr >> PAGE_BITS]; - if (page_pointer) { - // NOTE: Avoid adding any extra logic to this fast-path block - T value; - std::memcpy(&value, &page_pointer[vaddr & PAGE_MASK], sizeof(T)); - return value; - } - - Common::PageType type = current_page_table->attributes[vaddr >> PAGE_BITS]; - switch (type) { - case Common::PageType::Unmapped: - LOG_ERROR(HW_Memory, "Unmapped Read{} @ 0x{:08X}", sizeof(T) * 8, vaddr); - return 0; - case Common::PageType::Memory: - ASSERT_MSG(false, "Mapped memory page without a pointer @ {:016X}", vaddr); - break; - case Common::PageType::RasterizerCachedMemory: { - auto host_ptr{GetPointerFromVMA(vaddr)}; - Core::System::GetInstance().GPU().FlushRegion(ToCacheAddr(host_ptr), sizeof(T)); - T value; - std::memcpy(&value, host_ptr, sizeof(T)); - return value; - } - default: - UNREACHABLE(); - } - return {}; -} + if (current_page_table->attributes[vaddr >> PAGE_BITS] == + Common::PageType::RasterizerCachedMemory) { + return GetPointerFromVMA(vaddr); + } -template <typename T> -void Write(const VAddr vaddr, const T data) { - u8* page_pointer = current_page_table->pointers[vaddr >> PAGE_BITS]; - if (page_pointer) { - // NOTE: Avoid adding any extra logic to this fast-path block - std::memcpy(&page_pointer[vaddr & PAGE_MASK], &data, sizeof(T)); - return; - } - - Common::PageType type = current_page_table->attributes[vaddr >> PAGE_BITS]; - switch (type) { - case Common::PageType::Unmapped: - LOG_ERROR(HW_Memory, "Unmapped Write{} 0x{:08X} @ 0x{:016X}", sizeof(data) * 8, - static_cast<u32>(data), vaddr); - return; - case Common::PageType::Memory: - ASSERT_MSG(false, "Mapped memory page without a pointer @ {:016X}", vaddr); - break; - case Common::PageType::RasterizerCachedMemory: { - auto host_ptr{GetPointerFromVMA(vaddr)}; - Core::System::GetInstance().GPU().InvalidateRegion(ToCacheAddr(host_ptr), sizeof(T)); - std::memcpy(host_ptr, &data, sizeof(T)); - break; - } - default: - UNREACHABLE(); + LOG_ERROR(HW_Memory, "Unknown GetPointer @ 0x{:016X}", vaddr); + return nullptr; } -} -bool IsValidVirtualAddress(const Kernel::Process& process, const VAddr vaddr) { - const auto& page_table = process.VMManager().page_table; + u8 Read8(const VAddr addr) { + return Read<u8>(addr); + } - const u8* page_pointer = page_table.pointers[vaddr >> PAGE_BITS]; - if (page_pointer) - return true; + u16 Read16(const VAddr addr) { + return Read<u16_le>(addr); + } - if (page_table.attributes[vaddr >> PAGE_BITS] == Common::PageType::RasterizerCachedMemory) - return true; + u32 Read32(const VAddr addr) { + return Read<u32_le>(addr); + } - if (page_table.attributes[vaddr >> PAGE_BITS] != Common::PageType::Special) - return false; + u64 Read64(const VAddr addr) { + return Read<u64_le>(addr); + } - return false; -} + void Write8(const VAddr addr, const u8 data) { + Write<u8>(addr, data); + } -bool IsValidVirtualAddress(const VAddr vaddr) { - return IsValidVirtualAddress(*Core::System::GetInstance().CurrentProcess(), vaddr); -} + void Write16(const VAddr addr, const u16 data) { + Write<u16_le>(addr, data); + } -bool IsKernelVirtualAddress(const VAddr vaddr) { - return KERNEL_REGION_VADDR <= vaddr && vaddr < KERNEL_REGION_END; -} + void Write32(const VAddr addr, const u32 data) { + Write<u32_le>(addr, data); + } -u8* GetPointer(const VAddr vaddr) { - u8* page_pointer = current_page_table->pointers[vaddr >> PAGE_BITS]; - if (page_pointer) { - return page_pointer + (vaddr & PAGE_MASK); + void Write64(const VAddr addr, const u64 data) { + Write<u64_le>(addr, data); } - if (current_page_table->attributes[vaddr >> PAGE_BITS] == - Common::PageType::RasterizerCachedMemory) { - return GetPointerFromVMA(vaddr); + std::string ReadCString(VAddr vaddr, std::size_t max_length) { + std::string string; + string.reserve(max_length); + for (std::size_t i = 0; i < max_length; ++i) { + const char c = Read8(vaddr); + if (c == '\0') { + break; + } + string.push_back(c); + ++vaddr; + } + string.shrink_to_fit(); + return string; } - LOG_ERROR(HW_Memory, "Unknown GetPointer @ 0x{:016X}", vaddr); - return nullptr; -} + void ReadBlock(const Kernel::Process& process, const VAddr src_addr, void* dest_buffer, + const std::size_t size) { + const auto& page_table = process.VMManager().page_table; + + std::size_t remaining_size = size; + std::size_t page_index = src_addr >> PAGE_BITS; + std::size_t page_offset = src_addr & PAGE_MASK; + + while (remaining_size > 0) { + const std::size_t copy_amount = + std::min(static_cast<std::size_t>(PAGE_SIZE) - page_offset, remaining_size); + const auto current_vaddr = static_cast<VAddr>((page_index << PAGE_BITS) + page_offset); + + switch (page_table.attributes[page_index]) { + case Common::PageType::Unmapped: { + LOG_ERROR(HW_Memory, + "Unmapped ReadBlock @ 0x{:016X} (start address = 0x{:016X}, size = {})", + current_vaddr, src_addr, size); + std::memset(dest_buffer, 0, copy_amount); + break; + } + case Common::PageType::Memory: { + DEBUG_ASSERT(page_table.pointers[page_index]); -std::string ReadCString(VAddr vaddr, std::size_t max_length) { - std::string string; - string.reserve(max_length); - for (std::size_t i = 0; i < max_length; ++i) { - char c = Read8(vaddr); - if (c == '\0') - break; - string.push_back(c); - ++vaddr; + const u8* const src_ptr = page_table.pointers[page_index] + page_offset; + std::memcpy(dest_buffer, src_ptr, copy_amount); + break; + } + case Common::PageType::RasterizerCachedMemory: { + const u8* const host_ptr = GetPointerFromVMA(process, current_vaddr); + system.GPU().FlushRegion(ToCacheAddr(host_ptr), copy_amount); + std::memcpy(dest_buffer, host_ptr, copy_amount); + break; + } + default: + UNREACHABLE(); + } + + page_index++; + page_offset = 0; + dest_buffer = static_cast<u8*>(dest_buffer) + copy_amount; + remaining_size -= copy_amount; + } } - string.shrink_to_fit(); - return string; -} -void RasterizerMarkRegionCached(VAddr vaddr, u64 size, bool cached) { - if (vaddr == 0) { - return; + void ReadBlock(const VAddr src_addr, void* dest_buffer, const std::size_t size) { + ReadBlock(*system.CurrentProcess(), src_addr, dest_buffer, size); } - // Iterate over a contiguous CPU address space, which corresponds to the specified GPU address - // space, marking the region as un/cached. The region is marked un/cached at a granularity of - // CPU pages, hence why we iterate on a CPU page basis (note: GPU page size is different). This - // assumes the specified GPU address region is contiguous as well. + void WriteBlock(const Kernel::Process& process, const VAddr dest_addr, const void* src_buffer, + const std::size_t size) { + const auto& page_table = process.VMManager().page_table; + std::size_t remaining_size = size; + std::size_t page_index = dest_addr >> PAGE_BITS; + std::size_t page_offset = dest_addr & PAGE_MASK; + + while (remaining_size > 0) { + const std::size_t copy_amount = + std::min(static_cast<std::size_t>(PAGE_SIZE) - page_offset, remaining_size); + const auto current_vaddr = static_cast<VAddr>((page_index << PAGE_BITS) + page_offset); + + switch (page_table.attributes[page_index]) { + case Common::PageType::Unmapped: { + LOG_ERROR(HW_Memory, + "Unmapped WriteBlock @ 0x{:016X} (start address = 0x{:016X}, size = {})", + current_vaddr, dest_addr, size); + break; + } + case Common::PageType::Memory: { + DEBUG_ASSERT(page_table.pointers[page_index]); + + u8* const dest_ptr = page_table.pointers[page_index] + page_offset; + std::memcpy(dest_ptr, src_buffer, copy_amount); + break; + } + case Common::PageType::RasterizerCachedMemory: { + u8* const host_ptr = GetPointerFromVMA(process, current_vaddr); + system.GPU().InvalidateRegion(ToCacheAddr(host_ptr), copy_amount); + std::memcpy(host_ptr, src_buffer, copy_amount); + break; + } + default: + UNREACHABLE(); + } + + page_index++; + page_offset = 0; + src_buffer = static_cast<const u8*>(src_buffer) + copy_amount; + remaining_size -= copy_amount; + } + } - u64 num_pages = ((vaddr + size - 1) >> PAGE_BITS) - (vaddr >> PAGE_BITS) + 1; - for (unsigned i = 0; i < num_pages; ++i, vaddr += PAGE_SIZE) { - Common::PageType& page_type = current_page_table->attributes[vaddr >> PAGE_BITS]; + void WriteBlock(const VAddr dest_addr, const void* src_buffer, const std::size_t size) { + WriteBlock(*system.CurrentProcess(), dest_addr, src_buffer, size); + } - if (cached) { - // Switch page type to cached if now cached - switch (page_type) { - case Common::PageType::Unmapped: - // It is not necessary for a process to have this region mapped into its address - // space, for example, a system module need not have a VRAM mapping. + void ZeroBlock(const Kernel::Process& process, const VAddr dest_addr, const std::size_t size) { + const auto& page_table = process.VMManager().page_table; + std::size_t remaining_size = size; + std::size_t page_index = dest_addr >> PAGE_BITS; + std::size_t page_offset = dest_addr & PAGE_MASK; + + while (remaining_size > 0) { + const std::size_t copy_amount = + std::min(static_cast<std::size_t>(PAGE_SIZE) - page_offset, remaining_size); + const auto current_vaddr = static_cast<VAddr>((page_index << PAGE_BITS) + page_offset); + + switch (page_table.attributes[page_index]) { + case Common::PageType::Unmapped: { + LOG_ERROR(HW_Memory, + "Unmapped ZeroBlock @ 0x{:016X} (start address = 0x{:016X}, size = {})", + current_vaddr, dest_addr, size); break; - case Common::PageType::Memory: - page_type = Common::PageType::RasterizerCachedMemory; - current_page_table->pointers[vaddr >> PAGE_BITS] = nullptr; + } + case Common::PageType::Memory: { + DEBUG_ASSERT(page_table.pointers[page_index]); + + u8* dest_ptr = page_table.pointers[page_index] + page_offset; + std::memset(dest_ptr, 0, copy_amount); break; - case Common::PageType::RasterizerCachedMemory: - // There can be more than one GPU region mapped per CPU region, so it's common that - // this area is already marked as cached. + } + case Common::PageType::RasterizerCachedMemory: { + u8* const host_ptr = GetPointerFromVMA(process, current_vaddr); + system.GPU().InvalidateRegion(ToCacheAddr(host_ptr), copy_amount); + std::memset(host_ptr, 0, copy_amount); break; + } default: UNREACHABLE(); } - } else { - // Switch page type to uncached if now uncached - switch (page_type) { - case Common::PageType::Unmapped: - // It is not necessary for a process to have this region mapped into its address - // space, for example, a system module need not have a VRAM mapping. + + page_index++; + page_offset = 0; + remaining_size -= copy_amount; + } + } + + void ZeroBlock(const VAddr dest_addr, const std::size_t size) { + ZeroBlock(*system.CurrentProcess(), dest_addr, size); + } + + void CopyBlock(const Kernel::Process& process, VAddr dest_addr, VAddr src_addr, + const std::size_t size) { + const auto& page_table = process.VMManager().page_table; + std::size_t remaining_size = size; + std::size_t page_index = src_addr >> PAGE_BITS; + std::size_t page_offset = src_addr & PAGE_MASK; + + while (remaining_size > 0) { + const std::size_t copy_amount = + std::min(static_cast<std::size_t>(PAGE_SIZE) - page_offset, remaining_size); + const auto current_vaddr = static_cast<VAddr>((page_index << PAGE_BITS) + page_offset); + + switch (page_table.attributes[page_index]) { + case Common::PageType::Unmapped: { + LOG_ERROR(HW_Memory, + "Unmapped CopyBlock @ 0x{:016X} (start address = 0x{:016X}, size = {})", + current_vaddr, src_addr, size); + ZeroBlock(process, dest_addr, copy_amount); break; - case Common::PageType::Memory: - // There can be more than one GPU region mapped per CPU region, so it's common that - // this area is already unmarked as cached. + } + case Common::PageType::Memory: { + DEBUG_ASSERT(page_table.pointers[page_index]); + const u8* src_ptr = page_table.pointers[page_index] + page_offset; + WriteBlock(process, dest_addr, src_ptr, copy_amount); break; + } case Common::PageType::RasterizerCachedMemory: { - u8* pointer = GetPointerFromVMA(vaddr & ~PAGE_MASK); - if (pointer == nullptr) { - // It's possible that this function has been called while updating the pagetable - // after unmapping a VMA. In that case the underlying VMA will no longer exist, - // and we should just leave the pagetable entry blank. - page_type = Common::PageType::Unmapped; - } else { - page_type = Common::PageType::Memory; - current_page_table->pointers[vaddr >> PAGE_BITS] = pointer; - } + const u8* const host_ptr = GetPointerFromVMA(process, current_vaddr); + system.GPU().FlushRegion(ToCacheAddr(host_ptr), copy_amount); + WriteBlock(process, dest_addr, host_ptr, copy_amount); break; } default: UNREACHABLE(); } + + page_index++; + page_offset = 0; + dest_addr += static_cast<VAddr>(copy_amount); + src_addr += static_cast<VAddr>(copy_amount); + remaining_size -= copy_amount; } } -} -u8 Read8(const VAddr addr) { - return Read<u8>(addr); -} + void CopyBlock(VAddr dest_addr, VAddr src_addr, std::size_t size) { + return CopyBlock(*system.CurrentProcess(), dest_addr, src_addr, size); + } -u16 Read16(const VAddr addr) { - return Read<u16_le>(addr); -} + void RasterizerMarkRegionCached(VAddr vaddr, u64 size, bool cached) { + if (vaddr == 0) { + return; + } -u32 Read32(const VAddr addr) { - return Read<u32_le>(addr); -} + // Iterate over a contiguous CPU address space, which corresponds to the specified GPU + // address space, marking the region as un/cached. The region is marked un/cached at a + // granularity of CPU pages, hence why we iterate on a CPU page basis (note: GPU page size + // is different). This assumes the specified GPU address region is contiguous as well. + + u64 num_pages = ((vaddr + size - 1) >> PAGE_BITS) - (vaddr >> PAGE_BITS) + 1; + for (unsigned i = 0; i < num_pages; ++i, vaddr += PAGE_SIZE) { + Common::PageType& page_type = current_page_table->attributes[vaddr >> PAGE_BITS]; + + if (cached) { + // Switch page type to cached if now cached + switch (page_type) { + case Common::PageType::Unmapped: + // It is not necessary for a process to have this region mapped into its address + // space, for example, a system module need not have a VRAM mapping. + break; + case Common::PageType::Memory: + page_type = Common::PageType::RasterizerCachedMemory; + current_page_table->pointers[vaddr >> PAGE_BITS] = nullptr; + break; + case Common::PageType::RasterizerCachedMemory: + // There can be more than one GPU region mapped per CPU region, so it's common + // that this area is already marked as cached. + break; + default: + UNREACHABLE(); + } + } else { + // Switch page type to uncached if now uncached + switch (page_type) { + case Common::PageType::Unmapped: + // It is not necessary for a process to have this region mapped into its address + // space, for example, a system module need not have a VRAM mapping. + break; + case Common::PageType::Memory: + // There can be more than one GPU region mapped per CPU region, so it's common + // that this area is already unmarked as cached. + break; + case Common::PageType::RasterizerCachedMemory: { + u8* pointer = GetPointerFromVMA(vaddr & ~PAGE_MASK); + if (pointer == nullptr) { + // It's possible that this function has been called while updating the + // pagetable after unmapping a VMA. In that case the underlying VMA will no + // longer exist, and we should just leave the pagetable entry blank. + page_type = Common::PageType::Unmapped; + } else { + page_type = Common::PageType::Memory; + current_page_table->pointers[vaddr >> PAGE_BITS] = pointer; + } + break; + } + default: + UNREACHABLE(); + } + } + } + } -u64 Read64(const VAddr addr) { - return Read<u64_le>(addr); -} + /** + * Maps a region of pages as a specific type. + * + * @param page_table The page table to use to perform the mapping. + * @param base The base address to begin mapping at. + * @param size The total size of the range in bytes. + * @param memory The memory to map. + * @param type The page type to map the memory as. + */ + void MapPages(Common::PageTable& page_table, VAddr base, u64 size, u8* memory, + Common::PageType type) { + LOG_DEBUG(HW_Memory, "Mapping {} onto {:016X}-{:016X}", fmt::ptr(memory), base * PAGE_SIZE, + (base + size) * PAGE_SIZE); + + // During boot, current_page_table might not be set yet, in which case we need not flush + if (system.IsPoweredOn()) { + auto& gpu = system.GPU(); + for (u64 i = 0; i < size; i++) { + const auto page = base + i; + if (page_table.attributes[page] == Common::PageType::RasterizerCachedMemory) { + gpu.FlushAndInvalidateRegion(page << PAGE_BITS, PAGE_SIZE); + } + } + } -void ReadBlock(const Kernel::Process& process, const VAddr src_addr, void* dest_buffer, - const std::size_t size) { - const auto& page_table = process.VMManager().page_table; - - std::size_t remaining_size = size; - std::size_t page_index = src_addr >> PAGE_BITS; - std::size_t page_offset = src_addr & PAGE_MASK; - - while (remaining_size > 0) { - const std::size_t copy_amount = - std::min(static_cast<std::size_t>(PAGE_SIZE) - page_offset, remaining_size); - const VAddr current_vaddr = static_cast<VAddr>((page_index << PAGE_BITS) + page_offset); - - switch (page_table.attributes[page_index]) { - case Common::PageType::Unmapped: { - LOG_ERROR(HW_Memory, - "Unmapped ReadBlock @ 0x{:016X} (start address = 0x{:016X}, size = {})", - current_vaddr, src_addr, size); - std::memset(dest_buffer, 0, copy_amount); - break; + const VAddr end = base + size; + ASSERT_MSG(end <= page_table.pointers.size(), "out of range mapping at {:016X}", + base + page_table.pointers.size()); + + std::fill(page_table.attributes.begin() + base, page_table.attributes.begin() + end, type); + + if (memory == nullptr) { + std::fill(page_table.pointers.begin() + base, page_table.pointers.begin() + end, + memory); + } else { + while (base != end) { + page_table.pointers[base] = memory; + + base += 1; + memory += PAGE_SIZE; + } + } + } + + /** + * Reads a particular data type out of memory at the given virtual address. + * + * @param vaddr The virtual address to read the data type from. + * + * @tparam T The data type to read out of memory. This type *must* be + * trivially copyable, otherwise the behavior of this function + * is undefined. + * + * @returns The instance of T read from the specified virtual address. + */ + template <typename T> + T Read(const VAddr vaddr) { + const u8* const page_pointer = current_page_table->pointers[vaddr >> PAGE_BITS]; + if (page_pointer != nullptr) { + // NOTE: Avoid adding any extra logic to this fast-path block + T value; + std::memcpy(&value, &page_pointer[vaddr & PAGE_MASK], sizeof(T)); + return value; } - case Common::PageType::Memory: { - DEBUG_ASSERT(page_table.pointers[page_index]); - const u8* src_ptr = page_table.pointers[page_index] + page_offset; - std::memcpy(dest_buffer, src_ptr, copy_amount); + const Common::PageType type = current_page_table->attributes[vaddr >> PAGE_BITS]; + switch (type) { + case Common::PageType::Unmapped: + LOG_ERROR(HW_Memory, "Unmapped Read{} @ 0x{:08X}", sizeof(T) * 8, vaddr); + return 0; + case Common::PageType::Memory: + ASSERT_MSG(false, "Mapped memory page without a pointer @ {:016X}", vaddr); break; + case Common::PageType::RasterizerCachedMemory: { + const u8* const host_ptr = GetPointerFromVMA(vaddr); + system.GPU().FlushRegion(ToCacheAddr(host_ptr), sizeof(T)); + T value; + std::memcpy(&value, host_ptr, sizeof(T)); + return value; + } + default: + UNREACHABLE(); } + return {}; + } + + /** + * Writes a particular data type to memory at the given virtual address. + * + * @param vaddr The virtual address to write the data type to. + * + * @tparam T The data type to write to memory. This type *must* be + * trivially copyable, otherwise the behavior of this function + * is undefined. + * + * @returns The instance of T write to the specified virtual address. + */ + template <typename T> + void Write(const VAddr vaddr, const T data) { + u8* const page_pointer = current_page_table->pointers[vaddr >> PAGE_BITS]; + if (page_pointer != nullptr) { + // NOTE: Avoid adding any extra logic to this fast-path block + std::memcpy(&page_pointer[vaddr & PAGE_MASK], &data, sizeof(T)); + return; + } + + const Common::PageType type = current_page_table->attributes[vaddr >> PAGE_BITS]; + switch (type) { + case Common::PageType::Unmapped: + LOG_ERROR(HW_Memory, "Unmapped Write{} 0x{:08X} @ 0x{:016X}", sizeof(data) * 8, + static_cast<u32>(data), vaddr); + return; + case Common::PageType::Memory: + ASSERT_MSG(false, "Mapped memory page without a pointer @ {:016X}", vaddr); + break; case Common::PageType::RasterizerCachedMemory: { - const auto& host_ptr{GetPointerFromVMA(process, current_vaddr)}; - Core::System::GetInstance().GPU().FlushRegion(ToCacheAddr(host_ptr), copy_amount); - std::memcpy(dest_buffer, host_ptr, copy_amount); + u8* const host_ptr{GetPointerFromVMA(vaddr)}; + system.GPU().InvalidateRegion(ToCacheAddr(host_ptr), sizeof(T)); + std::memcpy(host_ptr, &data, sizeof(T)); break; } default: UNREACHABLE(); } - - page_index++; - page_offset = 0; - dest_buffer = static_cast<u8*>(dest_buffer) + copy_amount; - remaining_size -= copy_amount; } + + Common::PageTable* current_page_table = nullptr; + Core::System& system; +}; + +Memory::Memory(Core::System& system) : impl{std::make_unique<Impl>(system)} {} +Memory::~Memory() = default; + +void Memory::SetCurrentPageTable(Kernel::Process& process) { + impl->SetCurrentPageTable(process); } -void ReadBlock(const VAddr src_addr, void* dest_buffer, const std::size_t size) { - ReadBlock(*Core::System::GetInstance().CurrentProcess(), src_addr, dest_buffer, size); +void Memory::MapMemoryRegion(Common::PageTable& page_table, VAddr base, u64 size, u8* target) { + impl->MapMemoryRegion(page_table, base, size, target); } -void Write8(const VAddr addr, const u8 data) { - Write<u8>(addr, data); +void Memory::MapIoRegion(Common::PageTable& page_table, VAddr base, u64 size, + Common::MemoryHookPointer mmio_handler) { + impl->MapIoRegion(page_table, base, size, std::move(mmio_handler)); } -void Write16(const VAddr addr, const u16 data) { - Write<u16_le>(addr, data); +void Memory::UnmapRegion(Common::PageTable& page_table, VAddr base, u64 size) { + impl->UnmapRegion(page_table, base, size); } -void Write32(const VAddr addr, const u32 data) { - Write<u32_le>(addr, data); +void Memory::AddDebugHook(Common::PageTable& page_table, VAddr base, u64 size, + Common::MemoryHookPointer hook) { + impl->AddDebugHook(page_table, base, size, std::move(hook)); } -void Write64(const VAddr addr, const u64 data) { - Write<u64_le>(addr, data); +void Memory::RemoveDebugHook(Common::PageTable& page_table, VAddr base, u64 size, + Common::MemoryHookPointer hook) { + impl->RemoveDebugHook(page_table, base, size, std::move(hook)); } -void WriteBlock(const Kernel::Process& process, const VAddr dest_addr, const void* src_buffer, - const std::size_t size) { - const auto& page_table = process.VMManager().page_table; - std::size_t remaining_size = size; - std::size_t page_index = dest_addr >> PAGE_BITS; - std::size_t page_offset = dest_addr & PAGE_MASK; - - while (remaining_size > 0) { - const std::size_t copy_amount = - std::min(static_cast<std::size_t>(PAGE_SIZE) - page_offset, remaining_size); - const VAddr current_vaddr = static_cast<VAddr>((page_index << PAGE_BITS) + page_offset); - - switch (page_table.attributes[page_index]) { - case Common::PageType::Unmapped: { - LOG_ERROR(HW_Memory, - "Unmapped WriteBlock @ 0x{:016X} (start address = 0x{:016X}, size = {})", - current_vaddr, dest_addr, size); - break; - } - case Common::PageType::Memory: { - DEBUG_ASSERT(page_table.pointers[page_index]); +bool Memory::IsValidVirtualAddress(const Kernel::Process& process, const VAddr vaddr) const { + return impl->IsValidVirtualAddress(process, vaddr); +} - u8* dest_ptr = page_table.pointers[page_index] + page_offset; - std::memcpy(dest_ptr, src_buffer, copy_amount); - break; - } - case Common::PageType::RasterizerCachedMemory: { - const auto& host_ptr{GetPointerFromVMA(process, current_vaddr)}; - Core::System::GetInstance().GPU().InvalidateRegion(ToCacheAddr(host_ptr), copy_amount); - std::memcpy(host_ptr, src_buffer, copy_amount); - break; - } - default: - UNREACHABLE(); - } +bool Memory::IsValidVirtualAddress(const VAddr vaddr) const { + return impl->IsValidVirtualAddress(vaddr); +} - page_index++; - page_offset = 0; - src_buffer = static_cast<const u8*>(src_buffer) + copy_amount; - remaining_size -= copy_amount; - } +u8* Memory::GetPointer(VAddr vaddr) { + return impl->GetPointer(vaddr); } -void WriteBlock(const VAddr dest_addr, const void* src_buffer, const std::size_t size) { - WriteBlock(*Core::System::GetInstance().CurrentProcess(), dest_addr, src_buffer, size); +const u8* Memory::GetPointer(VAddr vaddr) const { + return impl->GetPointer(vaddr); } -void ZeroBlock(const Kernel::Process& process, const VAddr dest_addr, const std::size_t size) { - const auto& page_table = process.VMManager().page_table; - std::size_t remaining_size = size; - std::size_t page_index = dest_addr >> PAGE_BITS; - std::size_t page_offset = dest_addr & PAGE_MASK; - - while (remaining_size > 0) { - const std::size_t copy_amount = - std::min(static_cast<std::size_t>(PAGE_SIZE) - page_offset, remaining_size); - const VAddr current_vaddr = static_cast<VAddr>((page_index << PAGE_BITS) + page_offset); - - switch (page_table.attributes[page_index]) { - case Common::PageType::Unmapped: { - LOG_ERROR(HW_Memory, - "Unmapped ZeroBlock @ 0x{:016X} (start address = 0x{:016X}, size = {})", - current_vaddr, dest_addr, size); - break; - } - case Common::PageType::Memory: { - DEBUG_ASSERT(page_table.pointers[page_index]); +u8 Memory::Read8(const VAddr addr) { + return impl->Read8(addr); +} - u8* dest_ptr = page_table.pointers[page_index] + page_offset; - std::memset(dest_ptr, 0, copy_amount); - break; - } - case Common::PageType::RasterizerCachedMemory: { - const auto& host_ptr{GetPointerFromVMA(process, current_vaddr)}; - Core::System::GetInstance().GPU().InvalidateRegion(ToCacheAddr(host_ptr), copy_amount); - std::memset(host_ptr, 0, copy_amount); - break; - } - default: - UNREACHABLE(); - } +u16 Memory::Read16(const VAddr addr) { + return impl->Read16(addr); +} - page_index++; - page_offset = 0; - remaining_size -= copy_amount; - } +u32 Memory::Read32(const VAddr addr) { + return impl->Read32(addr); } -void CopyBlock(const Kernel::Process& process, VAddr dest_addr, VAddr src_addr, - const std::size_t size) { - const auto& page_table = process.VMManager().page_table; - std::size_t remaining_size = size; - std::size_t page_index = src_addr >> PAGE_BITS; - std::size_t page_offset = src_addr & PAGE_MASK; - - while (remaining_size > 0) { - const std::size_t copy_amount = - std::min(static_cast<std::size_t>(PAGE_SIZE) - page_offset, remaining_size); - const VAddr current_vaddr = static_cast<VAddr>((page_index << PAGE_BITS) + page_offset); - - switch (page_table.attributes[page_index]) { - case Common::PageType::Unmapped: { - LOG_ERROR(HW_Memory, - "Unmapped CopyBlock @ 0x{:016X} (start address = 0x{:016X}, size = {})", - current_vaddr, src_addr, size); - ZeroBlock(process, dest_addr, copy_amount); - break; - } - case Common::PageType::Memory: { - DEBUG_ASSERT(page_table.pointers[page_index]); - const u8* src_ptr = page_table.pointers[page_index] + page_offset; - WriteBlock(process, dest_addr, src_ptr, copy_amount); - break; - } - case Common::PageType::RasterizerCachedMemory: { - const auto& host_ptr{GetPointerFromVMA(process, current_vaddr)}; - Core::System::GetInstance().GPU().FlushRegion(ToCacheAddr(host_ptr), copy_amount); - WriteBlock(process, dest_addr, host_ptr, copy_amount); - break; - } - default: - UNREACHABLE(); - } +u64 Memory::Read64(const VAddr addr) { + return impl->Read64(addr); +} - page_index++; - page_offset = 0; - dest_addr += static_cast<VAddr>(copy_amount); - src_addr += static_cast<VAddr>(copy_amount); - remaining_size -= copy_amount; - } +void Memory::Write8(VAddr addr, u8 data) { + impl->Write8(addr, data); +} + +void Memory::Write16(VAddr addr, u16 data) { + impl->Write16(addr, data); +} + +void Memory::Write32(VAddr addr, u32 data) { + impl->Write32(addr, data); +} + +void Memory::Write64(VAddr addr, u64 data) { + impl->Write64(addr, data); +} + +std::string Memory::ReadCString(VAddr vaddr, std::size_t max_length) { + return impl->ReadCString(vaddr, max_length); +} + +void Memory::ReadBlock(const Kernel::Process& process, const VAddr src_addr, void* dest_buffer, + const std::size_t size) { + impl->ReadBlock(process, src_addr, dest_buffer, size); +} + +void Memory::ReadBlock(const VAddr src_addr, void* dest_buffer, const std::size_t size) { + impl->ReadBlock(src_addr, dest_buffer, size); +} + +void Memory::WriteBlock(const Kernel::Process& process, VAddr dest_addr, const void* src_buffer, + std::size_t size) { + impl->WriteBlock(process, dest_addr, src_buffer, size); +} + +void Memory::WriteBlock(const VAddr dest_addr, const void* src_buffer, const std::size_t size) { + impl->WriteBlock(dest_addr, src_buffer, size); } -void CopyBlock(VAddr dest_addr, VAddr src_addr, std::size_t size) { - CopyBlock(*Core::System::GetInstance().CurrentProcess(), dest_addr, src_addr, size); +void Memory::ZeroBlock(const Kernel::Process& process, VAddr dest_addr, std::size_t size) { + impl->ZeroBlock(process, dest_addr, size); +} + +void Memory::ZeroBlock(VAddr dest_addr, std::size_t size) { + impl->ZeroBlock(dest_addr, size); +} + +void Memory::CopyBlock(const Kernel::Process& process, VAddr dest_addr, VAddr src_addr, + const std::size_t size) { + impl->CopyBlock(process, dest_addr, src_addr, size); +} + +void Memory::CopyBlock(VAddr dest_addr, VAddr src_addr, std::size_t size) { + impl->CopyBlock(dest_addr, src_addr, size); +} + +void Memory::RasterizerMarkRegionCached(VAddr vaddr, u64 size, bool cached) { + impl->RasterizerMarkRegionCached(vaddr, size, cached); +} + +bool IsKernelVirtualAddress(const VAddr vaddr) { + return KERNEL_REGION_VADDR <= vaddr && vaddr < KERNEL_REGION_END; } } // namespace Memory diff --git a/src/core/memory.h b/src/core/memory.h index 09008e1dd..1428a6d60 100644 --- a/src/core/memory.h +++ b/src/core/memory.h @@ -5,8 +5,18 @@ #pragma once #include <cstddef> +#include <memory> #include <string> #include "common/common_types.h" +#include "common/memory_hook.h" + +namespace Common { +struct PageTable; +} + +namespace Core { +class System; +} namespace Kernel { class Process; @@ -36,41 +46,369 @@ enum : VAddr { KERNEL_REGION_END = KERNEL_REGION_VADDR + KERNEL_REGION_SIZE, }; -/// Changes the currently active page table to that of -/// the given process instance. -void SetCurrentPageTable(Kernel::Process& process); +/// Central class that handles all memory operations and state. +class Memory { +public: + explicit Memory(Core::System& system); + ~Memory(); -/// Determines if the given VAddr is valid for the specified process. -bool IsValidVirtualAddress(const Kernel::Process& process, VAddr vaddr); -bool IsValidVirtualAddress(VAddr vaddr); -/// Determines if the given VAddr is a kernel address -bool IsKernelVirtualAddress(VAddr vaddr); + Memory(const Memory&) = delete; + Memory& operator=(const Memory&) = delete; -u8 Read8(VAddr addr); -u16 Read16(VAddr addr); -u32 Read32(VAddr addr); -u64 Read64(VAddr addr); + Memory(Memory&&) = default; + Memory& operator=(Memory&&) = default; -void Write8(VAddr addr, u8 data); -void Write16(VAddr addr, u16 data); -void Write32(VAddr addr, u32 data); -void Write64(VAddr addr, u64 data); + /** + * Changes the currently active page table to that of the given process instance. + * + * @param process The process to use the page table of. + */ + void SetCurrentPageTable(Kernel::Process& process); -void ReadBlock(const Kernel::Process& process, VAddr src_addr, void* dest_buffer, std::size_t size); -void ReadBlock(VAddr src_addr, void* dest_buffer, std::size_t size); -void WriteBlock(const Kernel::Process& process, VAddr dest_addr, const void* src_buffer, - std::size_t size); -void WriteBlock(VAddr dest_addr, const void* src_buffer, std::size_t size); -void ZeroBlock(const Kernel::Process& process, VAddr dest_addr, std::size_t size); -void CopyBlock(VAddr dest_addr, VAddr src_addr, std::size_t size); + /** + * Maps an allocated buffer onto a region of the emulated process address space. + * + * @param page_table The page table of the emulated process. + * @param base The address to start mapping at. Must be page-aligned. + * @param size The amount of bytes to map. Must be page-aligned. + * @param target Buffer with the memory backing the mapping. Must be of length at least + * `size`. + */ + void MapMemoryRegion(Common::PageTable& page_table, VAddr base, u64 size, u8* target); -u8* GetPointer(VAddr vaddr); + /** + * Maps a region of the emulated process address space as a IO region. + * + * @param page_table The page table of the emulated process. + * @param base The address to start mapping at. Must be page-aligned. + * @param size The amount of bytes to map. Must be page-aligned. + * @param mmio_handler The handler that backs the mapping. + */ + void MapIoRegion(Common::PageTable& page_table, VAddr base, u64 size, + Common::MemoryHookPointer mmio_handler); -std::string ReadCString(VAddr vaddr, std::size_t max_length); + /** + * Unmaps a region of the emulated process address space. + * + * @param page_table The page table of the emulated process. + * @param base The address to begin unmapping at. + * @param size The amount of bytes to unmap. + */ + void UnmapRegion(Common::PageTable& page_table, VAddr base, u64 size); -/** - * Mark each page touching the region as cached. - */ -void RasterizerMarkRegionCached(VAddr vaddr, u64 size, bool cached); + /** + * Adds a memory hook to intercept reads and writes to given region of memory. + * + * @param page_table The page table of the emulated process + * @param base The starting address to apply the hook to. + * @param size The size of the memory region to apply the hook to, in bytes. + * @param hook The hook to apply to the region of memory. + */ + void AddDebugHook(Common::PageTable& page_table, VAddr base, u64 size, + Common::MemoryHookPointer hook); + + /** + * Removes a memory hook from a given range of memory. + * + * @param page_table The page table of the emulated process. + * @param base The starting address to remove the hook from. + * @param size The size of the memory region to remove the hook from, in bytes. + * @param hook The hook to remove from the specified region of memory. + */ + void RemoveDebugHook(Common::PageTable& page_table, VAddr base, u64 size, + Common::MemoryHookPointer hook); + + /** + * Checks whether or not the supplied address is a valid virtual + * address for the given process. + * + * @param process The emulated process to check the address against. + * @param vaddr The virtual address to check the validity of. + * + * @returns True if the given virtual address is valid, false otherwise. + */ + bool IsValidVirtualAddress(const Kernel::Process& process, VAddr vaddr) const; + + /** + * Checks whether or not the supplied address is a valid virtual + * address for the current process. + * + * @param vaddr The virtual address to check the validity of. + * + * @returns True if the given virtual address is valid, false otherwise. + */ + bool IsValidVirtualAddress(VAddr vaddr) const; + + /** + * Gets a pointer to the given address. + * + * @param vaddr Virtual address to retrieve a pointer to. + * + * @returns The pointer to the given address, if the address is valid. + * If the address is not valid, nullptr will be returned. + */ + u8* GetPointer(VAddr vaddr); + + /** + * Gets a pointer to the given address. + * + * @param vaddr Virtual address to retrieve a pointer to. + * + * @returns The pointer to the given address, if the address is valid. + * If the address is not valid, nullptr will be returned. + */ + const u8* GetPointer(VAddr vaddr) const; + + /** + * Reads an 8-bit unsigned value from the current process' address space + * at the given virtual address. + * + * @param addr The virtual address to read the 8-bit value from. + * + * @returns the read 8-bit unsigned value. + */ + u8 Read8(VAddr addr); + + /** + * Reads a 16-bit unsigned value from the current process' address space + * at the given virtual address. + * + * @param addr The virtual address to read the 16-bit value from. + * + * @returns the read 16-bit unsigned value. + */ + u16 Read16(VAddr addr); + + /** + * Reads a 32-bit unsigned value from the current process' address space + * at the given virtual address. + * + * @param addr The virtual address to read the 32-bit value from. + * + * @returns the read 32-bit unsigned value. + */ + u32 Read32(VAddr addr); + + /** + * Reads a 64-bit unsigned value from the current process' address space + * at the given virtual address. + * + * @param addr The virtual address to read the 64-bit value from. + * + * @returns the read 64-bit value. + */ + u64 Read64(VAddr addr); + + /** + * Writes an 8-bit unsigned integer to the given virtual address in + * the current process' address space. + * + * @param addr The virtual address to write the 8-bit unsigned integer to. + * @param data The 8-bit unsigned integer to write to the given virtual address. + * + * @post The memory at the given virtual address contains the specified data value. + */ + void Write8(VAddr addr, u8 data); + + /** + * Writes a 16-bit unsigned integer to the given virtual address in + * the current process' address space. + * + * @param addr The virtual address to write the 16-bit unsigned integer to. + * @param data The 16-bit unsigned integer to write to the given virtual address. + * + * @post The memory range [addr, sizeof(data)) contains the given data value. + */ + void Write16(VAddr addr, u16 data); + + /** + * Writes a 32-bit unsigned integer to the given virtual address in + * the current process' address space. + * + * @param addr The virtual address to write the 32-bit unsigned integer to. + * @param data The 32-bit unsigned integer to write to the given virtual address. + * + * @post The memory range [addr, sizeof(data)) contains the given data value. + */ + void Write32(VAddr addr, u32 data); + + /** + * Writes a 64-bit unsigned integer to the given virtual address in + * the current process' address space. + * + * @param addr The virtual address to write the 64-bit unsigned integer to. + * @param data The 64-bit unsigned integer to write to the given virtual address. + * + * @post The memory range [addr, sizeof(data)) contains the given data value. + */ + void Write64(VAddr addr, u64 data); + + /** + * Reads a null-terminated string from the given virtual address. + * This function will continually read characters until either: + * + * - A null character ('\0') is reached. + * - max_length characters have been read. + * + * @note The final null-terminating character (if found) is not included + * in the returned string. + * + * @param vaddr The address to begin reading the string from. + * @param max_length The maximum length of the string to read in characters. + * + * @returns The read string. + */ + std::string ReadCString(VAddr vaddr, std::size_t max_length); + + /** + * Reads a contiguous block of bytes from a specified process' address space. + * + * @param process The process to read the data from. + * @param src_addr The virtual address to begin reading from. + * @param dest_buffer The buffer to place the read bytes into. + * @param size The amount of data to read, in bytes. + * + * @note If a size of 0 is specified, then this function reads nothing and + * no attempts to access memory are made at all. + * + * @pre dest_buffer must be at least size bytes in length, otherwise a + * buffer overrun will occur. + * + * @post The range [dest_buffer, size) contains the read bytes from the + * process' address space. + */ + void ReadBlock(const Kernel::Process& process, VAddr src_addr, void* dest_buffer, + std::size_t size); + + /** + * Reads a contiguous block of bytes from the current process' address space. + * + * @param src_addr The virtual address to begin reading from. + * @param dest_buffer The buffer to place the read bytes into. + * @param size The amount of data to read, in bytes. + * + * @note If a size of 0 is specified, then this function reads nothing and + * no attempts to access memory are made at all. + * + * @pre dest_buffer must be at least size bytes in length, otherwise a + * buffer overrun will occur. + * + * @post The range [dest_buffer, size) contains the read bytes from the + * current process' address space. + */ + void ReadBlock(VAddr src_addr, void* dest_buffer, std::size_t size); + + /** + * Writes a range of bytes into a given process' address space at the specified + * virtual address. + * + * @param process The process to write data into the address space of. + * @param dest_addr The destination virtual address to begin writing the data at. + * @param src_buffer The data to write into the process' address space. + * @param size The size of the data to write, in bytes. + * + * @post The address range [dest_addr, size) in the process' address space + * contains the data that was within src_buffer. + * + * @post If an attempt is made to write into an unmapped region of memory, the writes + * will be ignored and an error will be logged. + * + * @post If a write is performed into a region of memory that is considered cached + * rasterizer memory, will cause the currently active rasterizer to be notified + * and will mark that region as invalidated to caches that the active + * graphics backend may be maintaining over the course of execution. + */ + void WriteBlock(const Kernel::Process& process, VAddr dest_addr, const void* src_buffer, + std::size_t size); + + /** + * Writes a range of bytes into the current process' address space at the specified + * virtual address. + * + * @param dest_addr The destination virtual address to begin writing the data at. + * @param src_buffer The data to write into the current process' address space. + * @param size The size of the data to write, in bytes. + * + * @post The address range [dest_addr, size) in the current process' address space + * contains the data that was within src_buffer. + * + * @post If an attempt is made to write into an unmapped region of memory, the writes + * will be ignored and an error will be logged. + * + * @post If a write is performed into a region of memory that is considered cached + * rasterizer memory, will cause the currently active rasterizer to be notified + * and will mark that region as invalidated to caches that the active + * graphics backend may be maintaining over the course of execution. + */ + void WriteBlock(VAddr dest_addr, const void* src_buffer, std::size_t size); + + /** + * Fills the specified address range within a process' address space with zeroes. + * + * @param process The process that will have a portion of its memory zeroed out. + * @param dest_addr The starting virtual address of the range to zero out. + * @param size The size of the address range to zero out, in bytes. + * + * @post The range [dest_addr, size) within the process' address space is + * filled with zeroes. + */ + void ZeroBlock(const Kernel::Process& process, VAddr dest_addr, std::size_t size); + + /** + * Fills the specified address range within the current process' address space with zeroes. + * + * @param dest_addr The starting virtual address of the range to zero out. + * @param size The size of the address range to zero out, in bytes. + * + * @post The range [dest_addr, size) within the current process' address space is + * filled with zeroes. + */ + void ZeroBlock(VAddr dest_addr, std::size_t size); + + /** + * Copies data within a process' address space to another location within the + * same address space. + * + * @param process The process that will have data copied within its address space. + * @param dest_addr The destination virtual address to begin copying the data into. + * @param src_addr The source virtual address to begin copying the data from. + * @param size The size of the data to copy, in bytes. + * + * @post The range [dest_addr, size) within the process' address space contains the + * same data within the range [src_addr, size). + */ + void CopyBlock(const Kernel::Process& process, VAddr dest_addr, VAddr src_addr, + std::size_t size); + + /** + * Copies data within the current process' address space to another location within the + * same address space. + * + * @param dest_addr The destination virtual address to begin copying the data into. + * @param src_addr The source virtual address to begin copying the data from. + * @param size The size of the data to copy, in bytes. + * + * @post The range [dest_addr, size) within the current process' address space + * contains the same data within the range [src_addr, size). + */ + void CopyBlock(VAddr dest_addr, VAddr src_addr, std::size_t size); + + /** + * Marks each page within the specified address range as cached or uncached. + * + * @param vaddr The virtual address indicating the start of the address range. + * @param size The size of the address range in bytes. + * @param cached Whether or not any pages within the address range should be + * marked as cached or uncached. + */ + void RasterizerMarkRegionCached(VAddr vaddr, u64 size, bool cached); + +private: + struct Impl; + std::unique_ptr<Impl> impl; +}; + +/// Determines if the given VAddr is a kernel address +bool IsKernelVirtualAddress(VAddr vaddr); } // namespace Memory diff --git a/src/core/memory/cheat_engine.cpp b/src/core/memory/cheat_engine.cpp index 10821d452..d1e6bed93 100644 --- a/src/core/memory/cheat_engine.cpp +++ b/src/core/memory/cheat_engine.cpp @@ -20,18 +20,17 @@ namespace Memory { constexpr s64 CHEAT_ENGINE_TICKS = static_cast<s64>(Core::Timing::BASE_CLOCK_RATE / 12); constexpr u32 KEYPAD_BITMASK = 0x3FFFFFF; -StandardVmCallbacks::StandardVmCallbacks(const Core::System& system, - const CheatProcessMetadata& metadata) +StandardVmCallbacks::StandardVmCallbacks(Core::System& system, const CheatProcessMetadata& metadata) : metadata(metadata), system(system) {} StandardVmCallbacks::~StandardVmCallbacks() = default; void StandardVmCallbacks::MemoryRead(VAddr address, void* data, u64 size) { - ReadBlock(SanitizeAddress(address), data, size); + system.Memory().ReadBlock(SanitizeAddress(address), data, size); } void StandardVmCallbacks::MemoryWrite(VAddr address, const void* data, u64 size) { - WriteBlock(SanitizeAddress(address), data, size); + system.Memory().WriteBlock(SanitizeAddress(address), data, size); } u64 StandardVmCallbacks::HidKeysDown() { @@ -186,7 +185,7 @@ CheatEngine::~CheatEngine() { } void CheatEngine::Initialize() { - event = core_timing.RegisterEvent( + event = Core::Timing::CreateEvent( "CheatEngine::FrameCallback::" + Common::HexToString(metadata.main_nso_build_id), [this](u64 userdata, s64 cycles_late) { FrameCallback(userdata, cycles_late); }); core_timing.ScheduleEvent(CHEAT_ENGINE_TICKS, event); diff --git a/src/core/memory/cheat_engine.h b/src/core/memory/cheat_engine.h index 0f012e9b5..3d6b2298a 100644 --- a/src/core/memory/cheat_engine.h +++ b/src/core/memory/cheat_engine.h @@ -5,6 +5,7 @@ #pragma once #include <atomic> +#include <memory> #include <vector> #include "common/common_types.h" #include "core/memory/dmnt_cheat_types.h" @@ -23,7 +24,7 @@ namespace Memory { class StandardVmCallbacks : public DmntCheatVm::Callbacks { public: - StandardVmCallbacks(const Core::System& system, const CheatProcessMetadata& metadata); + StandardVmCallbacks(Core::System& system, const CheatProcessMetadata& metadata); ~StandardVmCallbacks() override; void MemoryRead(VAddr address, void* data, u64 size) override; @@ -36,7 +37,7 @@ private: VAddr SanitizeAddress(VAddr address) const; const CheatProcessMetadata& metadata; - const Core::System& system; + Core::System& system; }; // Intermediary class that parses a text file or other disk format for storing cheats into a @@ -78,7 +79,7 @@ private: std::vector<CheatEntry> cheats; std::atomic_bool is_pending_reload{false}; - Core::Timing::EventType* event{}; + std::shared_ptr<Core::Timing::EventType> event; Core::Timing::CoreTiming& core_timing; Core::System& system; }; diff --git a/src/core/memory_setup.h b/src/core/memory_setup.h deleted file mode 100644 index 5225ee8e2..000000000 --- a/src/core/memory_setup.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2015 Citra Emulator Project -// Licensed under GPLv2 or any later version -// Refer to the license.txt file included. - -#pragma once - -#include "common/common_types.h" -#include "common/memory_hook.h" - -namespace Common { -struct PageTable; -} - -namespace Memory { - -/** - * Maps an allocated buffer onto a region of the emulated process address space. - * - * @param page_table The page table of the emulated process. - * @param base The address to start mapping at. Must be page-aligned. - * @param size The amount of bytes to map. Must be page-aligned. - * @param target Buffer with the memory backing the mapping. Must be of length at least `size`. - */ -void MapMemoryRegion(Common::PageTable& page_table, VAddr base, u64 size, u8* target); - -/** - * Maps a region of the emulated process address space as a IO region. - * @param page_table The page table of the emulated process. - * @param base The address to start mapping at. Must be page-aligned. - * @param size The amount of bytes to map. Must be page-aligned. - * @param mmio_handler The handler that backs the mapping. - */ -void MapIoRegion(Common::PageTable& page_table, VAddr base, u64 size, - Common::MemoryHookPointer mmio_handler); - -void UnmapRegion(Common::PageTable& page_table, VAddr base, u64 size); - -void AddDebugHook(Common::PageTable& page_table, VAddr base, u64 size, - Common::MemoryHookPointer hook); -void RemoveDebugHook(Common::PageTable& page_table, VAddr base, u64 size, - Common::MemoryHookPointer hook); - -} // namespace Memory diff --git a/src/core/perf_stats.cpp b/src/core/perf_stats.cpp index d2c69d1a0..f1ae9d4df 100644 --- a/src/core/perf_stats.cpp +++ b/src/core/perf_stats.cpp @@ -81,7 +81,7 @@ double PerfStats::GetMeanFrametime() { return 0; } const double sum = std::accumulate(perf_history.begin() + IgnoreFrames, - perf_history.begin() + current_index, 0); + perf_history.begin() + current_index, 0.0); return sum / (current_index - IgnoreFrames); } diff --git a/src/core/reporter.cpp b/src/core/reporter.cpp index 6f4af77fd..f95eee3b1 100644 --- a/src/core/reporter.cpp +++ b/src/core/reporter.cpp @@ -147,7 +147,7 @@ json GetFullDataAuto(const std::string& timestamp, u64 title_id, Core::System& s } template <bool read_value, typename DescriptorType> -json GetHLEBufferDescriptorData(const std::vector<DescriptorType>& buffer) { +json GetHLEBufferDescriptorData(const std::vector<DescriptorType>& buffer, Memory::Memory& memory) { auto buffer_out = json::array(); for (const auto& desc : buffer) { auto entry = json{ @@ -157,7 +157,7 @@ json GetHLEBufferDescriptorData(const std::vector<DescriptorType>& buffer) { if constexpr (read_value) { std::vector<u8> data(desc.Size()); - Memory::ReadBlock(desc.Address(), data.data(), desc.Size()); + memory.ReadBlock(desc.Address(), data.data(), desc.Size()); entry["data"] = Common::HexToString(data); } @@ -167,7 +167,7 @@ json GetHLEBufferDescriptorData(const std::vector<DescriptorType>& buffer) { return buffer_out; } -json GetHLERequestContextData(Kernel::HLERequestContext& ctx) { +json GetHLERequestContextData(Kernel::HLERequestContext& ctx, Memory::Memory& memory) { json out; auto cmd_buf = json::array(); @@ -177,10 +177,10 @@ json GetHLERequestContextData(Kernel::HLERequestContext& ctx) { out["command_buffer"] = std::move(cmd_buf); - out["buffer_descriptor_a"] = GetHLEBufferDescriptorData<true>(ctx.BufferDescriptorA()); - out["buffer_descriptor_b"] = GetHLEBufferDescriptorData<false>(ctx.BufferDescriptorB()); - out["buffer_descriptor_c"] = GetHLEBufferDescriptorData<false>(ctx.BufferDescriptorC()); - out["buffer_descriptor_x"] = GetHLEBufferDescriptorData<true>(ctx.BufferDescriptorX()); + out["buffer_descriptor_a"] = GetHLEBufferDescriptorData<true>(ctx.BufferDescriptorA(), memory); + out["buffer_descriptor_b"] = GetHLEBufferDescriptorData<false>(ctx.BufferDescriptorB(), memory); + out["buffer_descriptor_c"] = GetHLEBufferDescriptorData<false>(ctx.BufferDescriptorC(), memory); + out["buffer_descriptor_x"] = GetHLEBufferDescriptorData<true>(ctx.BufferDescriptorX(), memory); return out; } @@ -259,7 +259,7 @@ void Reporter::SaveUnimplementedFunctionReport(Kernel::HLERequestContext& ctx, u const auto title_id = system.CurrentProcess()->GetTitleID(); auto out = GetFullDataAuto(timestamp, title_id, system); - auto function_out = GetHLERequestContextData(ctx); + auto function_out = GetHLERequestContextData(ctx, system.Memory()); function_out["command_id"] = command_id; function_out["function_name"] = name; function_out["service_name"] = service_name; diff --git a/src/core/telemetry_session.cpp b/src/core/telemetry_session.cpp index 793d102d3..320e8ad73 100644 --- a/src/core/telemetry_session.cpp +++ b/src/core/telemetry_session.cpp @@ -165,24 +165,20 @@ void TelemetrySession::AddInitialInfo(Loader::AppLoader& app_loader) { Telemetry::AppendOSInfo(field_collection); // Log user configuration information - AddField(Telemetry::FieldType::UserConfig, "Audio_SinkId", Settings::values.sink_id); - AddField(Telemetry::FieldType::UserConfig, "Audio_EnableAudioStretching", - Settings::values.enable_audio_stretching); - AddField(Telemetry::FieldType::UserConfig, "Core_UseMultiCore", - Settings::values.use_multi_core); - AddField(Telemetry::FieldType::UserConfig, "Renderer_ResolutionFactor", - Settings::values.resolution_factor); - AddField(Telemetry::FieldType::UserConfig, "Renderer_UseFrameLimit", - Settings::values.use_frame_limit); - AddField(Telemetry::FieldType::UserConfig, "Renderer_FrameLimit", Settings::values.frame_limit); - AddField(Telemetry::FieldType::UserConfig, "Renderer_UseDiskShaderCache", - Settings::values.use_disk_shader_cache); - AddField(Telemetry::FieldType::UserConfig, "Renderer_UseAccurateGpuEmulation", + constexpr auto field_type = Telemetry::FieldType::UserConfig; + AddField(field_type, "Audio_SinkId", Settings::values.sink_id); + AddField(field_type, "Audio_EnableAudioStretching", Settings::values.enable_audio_stretching); + AddField(field_type, "Core_UseMultiCore", Settings::values.use_multi_core); + AddField(field_type, "Renderer_Backend", "OpenGL"); + AddField(field_type, "Renderer_ResolutionFactor", Settings::values.resolution_factor); + AddField(field_type, "Renderer_UseFrameLimit", Settings::values.use_frame_limit); + AddField(field_type, "Renderer_FrameLimit", Settings::values.frame_limit); + AddField(field_type, "Renderer_UseDiskShaderCache", Settings::values.use_disk_shader_cache); + AddField(field_type, "Renderer_UseAccurateGpuEmulation", Settings::values.use_accurate_gpu_emulation); - AddField(Telemetry::FieldType::UserConfig, "Renderer_UseAsynchronousGpuEmulation", + AddField(field_type, "Renderer_UseAsynchronousGpuEmulation", Settings::values.use_asynchronous_gpu_emulation); - AddField(Telemetry::FieldType::UserConfig, "System_UseDockedMode", - Settings::values.use_docked_mode); + AddField(field_type, "System_UseDockedMode", Settings::values.use_docked_mode); } bool TelemetrySession::SubmitTestcase() { diff --git a/src/core/tools/freezer.cpp b/src/core/tools/freezer.cpp index 17f050068..55e0dbc49 100644 --- a/src/core/tools/freezer.cpp +++ b/src/core/tools/freezer.cpp @@ -11,40 +11,39 @@ #include "core/tools/freezer.h" namespace Tools { - namespace { constexpr s64 MEMORY_FREEZER_TICKS = static_cast<s64>(Core::Timing::BASE_CLOCK_RATE / 60); -u64 MemoryReadWidth(u32 width, VAddr addr) { +u64 MemoryReadWidth(Memory::Memory& memory, u32 width, VAddr addr) { switch (width) { case 1: - return Memory::Read8(addr); + return memory.Read8(addr); case 2: - return Memory::Read16(addr); + return memory.Read16(addr); case 4: - return Memory::Read32(addr); + return memory.Read32(addr); case 8: - return Memory::Read64(addr); + return memory.Read64(addr); default: UNREACHABLE(); return 0; } } -void MemoryWriteWidth(u32 width, VAddr addr, u64 value) { +void MemoryWriteWidth(Memory::Memory& memory, u32 width, VAddr addr, u64 value) { switch (width) { case 1: - Memory::Write8(addr, static_cast<u8>(value)); + memory.Write8(addr, static_cast<u8>(value)); break; case 2: - Memory::Write16(addr, static_cast<u16>(value)); + memory.Write16(addr, static_cast<u16>(value)); break; case 4: - Memory::Write32(addr, static_cast<u32>(value)); + memory.Write32(addr, static_cast<u32>(value)); break; case 8: - Memory::Write64(addr, value); + memory.Write64(addr, value); break; default: UNREACHABLE(); @@ -53,8 +52,9 @@ void MemoryWriteWidth(u32 width, VAddr addr, u64 value) { } // Anonymous namespace -Freezer::Freezer(Core::Timing::CoreTiming& core_timing) : core_timing(core_timing) { - event = core_timing.RegisterEvent( +Freezer::Freezer(Core::Timing::CoreTiming& core_timing_, Memory::Memory& memory_) + : core_timing{core_timing_}, memory{memory_} { + event = Core::Timing::CreateEvent( "MemoryFreezer::FrameCallback", [this](u64 userdata, s64 cycles_late) { FrameCallback(userdata, cycles_late); }); core_timing.ScheduleEvent(MEMORY_FREEZER_TICKS, event); @@ -89,7 +89,7 @@ void Freezer::Clear() { u64 Freezer::Freeze(VAddr address, u32 width) { std::lock_guard lock{entries_mutex}; - const auto current_value = MemoryReadWidth(width, address); + const auto current_value = MemoryReadWidth(memory, width, address); entries.push_back({address, width, current_value}); LOG_DEBUG(Common_Memory, @@ -169,7 +169,7 @@ void Freezer::FrameCallback(u64 userdata, s64 cycles_late) { LOG_DEBUG(Common_Memory, "Enforcing memory freeze at address={:016X}, value={:016X}, width={:02X}", entry.address, entry.value, entry.width); - MemoryWriteWidth(entry.width, entry.address, entry.value); + MemoryWriteWidth(memory, entry.width, entry.address, entry.value); } core_timing.ScheduleEvent(MEMORY_FREEZER_TICKS - cycles_late, event); @@ -181,7 +181,7 @@ void Freezer::FillEntryReads() { LOG_DEBUG(Common_Memory, "Updating memory freeze entries to current values."); for (auto& entry : entries) { - entry.value = MemoryReadWidth(entry.width, entry.address); + entry.value = MemoryReadWidth(memory, entry.width, entry.address); } } diff --git a/src/core/tools/freezer.h b/src/core/tools/freezer.h index b58de5472..916339c6c 100644 --- a/src/core/tools/freezer.h +++ b/src/core/tools/freezer.h @@ -5,6 +5,7 @@ #pragma once #include <atomic> +#include <memory> #include <mutex> #include <optional> #include <vector> @@ -15,6 +16,10 @@ class CoreTiming; struct EventType; } // namespace Core::Timing +namespace Memory { +class Memory; +} + namespace Tools { /** @@ -33,7 +38,7 @@ public: u64 value; }; - explicit Freezer(Core::Timing::CoreTiming& core_timing); + explicit Freezer(Core::Timing::CoreTiming& core_timing_, Memory::Memory& memory_); ~Freezer(); // Enables or disables the entire memory freezer. @@ -75,8 +80,9 @@ private: mutable std::mutex entries_mutex; std::vector<Entry> entries; - Core::Timing::EventType* event; + std::shared_ptr<Core::Timing::EventType> event; Core::Timing::CoreTiming& core_timing; + Memory::Memory& memory; }; } // namespace Tools diff --git a/src/tests/core/arm/arm_test_common.cpp b/src/tests/core/arm/arm_test_common.cpp index ac7ae3e52..17043346b 100644 --- a/src/tests/core/arm/arm_test_common.cpp +++ b/src/tests/core/arm/arm_test_common.cpp @@ -8,7 +8,6 @@ #include "core/core.h" #include "core/hle/kernel/process.h" #include "core/memory.h" -#include "core/memory_setup.h" #include "tests/core/arm/arm_test_common.h" namespace ArmTests { @@ -16,8 +15,9 @@ namespace ArmTests { TestEnvironment::TestEnvironment(bool mutable_memory_) : mutable_memory(mutable_memory_), test_memory(std::make_shared<TestMemory>(this)), kernel{Core::System::GetInstance()} { - auto process = Kernel::Process::Create(Core::System::GetInstance(), "", - Kernel::Process::ProcessType::Userland); + auto& system = Core::System::GetInstance(); + + auto process = Kernel::Process::Create(system, "", Kernel::Process::ProcessType::Userland); page_table = &process->VMManager().page_table; std::fill(page_table->pointers.begin(), page_table->pointers.end(), nullptr); @@ -25,15 +25,16 @@ TestEnvironment::TestEnvironment(bool mutable_memory_) std::fill(page_table->attributes.begin(), page_table->attributes.end(), Common::PageType::Unmapped); - Memory::MapIoRegion(*page_table, 0x00000000, 0x80000000, test_memory); - Memory::MapIoRegion(*page_table, 0x80000000, 0x80000000, test_memory); + system.Memory().MapIoRegion(*page_table, 0x00000000, 0x80000000, test_memory); + system.Memory().MapIoRegion(*page_table, 0x80000000, 0x80000000, test_memory); kernel.MakeCurrentProcess(process.get()); } TestEnvironment::~TestEnvironment() { - Memory::UnmapRegion(*page_table, 0x80000000, 0x80000000); - Memory::UnmapRegion(*page_table, 0x00000000, 0x80000000); + auto& system = Core::System::GetInstance(); + system.Memory().UnmapRegion(*page_table, 0x80000000, 0x80000000); + system.Memory().UnmapRegion(*page_table, 0x00000000, 0x80000000); } void TestEnvironment::SetMemory64(VAddr vaddr, u64 value) { diff --git a/src/tests/core/core_timing.cpp b/src/tests/core/core_timing.cpp index 3443bf05e..1e3940801 100644 --- a/src/tests/core/core_timing.cpp +++ b/src/tests/core/core_timing.cpp @@ -7,7 +7,9 @@ #include <array> #include <bitset> #include <cstdlib> +#include <memory> #include <string> + #include "common/file_util.h" #include "core/core.h" #include "core/core_timing.h" @@ -65,11 +67,16 @@ TEST_CASE("CoreTiming[BasicOrder]", "[core]") { ScopeInit guard; auto& core_timing = guard.core_timing; - Core::Timing::EventType* cb_a = core_timing.RegisterEvent("callbackA", CallbackTemplate<0>); - Core::Timing::EventType* cb_b = core_timing.RegisterEvent("callbackB", CallbackTemplate<1>); - Core::Timing::EventType* cb_c = core_timing.RegisterEvent("callbackC", CallbackTemplate<2>); - Core::Timing::EventType* cb_d = core_timing.RegisterEvent("callbackD", CallbackTemplate<3>); - Core::Timing::EventType* cb_e = core_timing.RegisterEvent("callbackE", CallbackTemplate<4>); + std::shared_ptr<Core::Timing::EventType> cb_a = + Core::Timing::CreateEvent("callbackA", CallbackTemplate<0>); + std::shared_ptr<Core::Timing::EventType> cb_b = + Core::Timing::CreateEvent("callbackB", CallbackTemplate<1>); + std::shared_ptr<Core::Timing::EventType> cb_c = + Core::Timing::CreateEvent("callbackC", CallbackTemplate<2>); + std::shared_ptr<Core::Timing::EventType> cb_d = + Core::Timing::CreateEvent("callbackD", CallbackTemplate<3>); + std::shared_ptr<Core::Timing::EventType> cb_e = + Core::Timing::CreateEvent("callbackE", CallbackTemplate<4>); // Enter slice 0 core_timing.ResetRun(); @@ -99,8 +106,8 @@ TEST_CASE("CoreTiming[FairSharing]", "[core]") { ScopeInit guard; auto& core_timing = guard.core_timing; - Core::Timing::EventType* empty_callback = - core_timing.RegisterEvent("empty_callback", EmptyCallback); + std::shared_ptr<Core::Timing::EventType> empty_callback = + Core::Timing::CreateEvent("empty_callback", EmptyCallback); callbacks_done = 0; u64 MAX_CALLBACKS = 10; @@ -133,8 +140,10 @@ TEST_CASE("Core::Timing[PredictableLateness]", "[core]") { ScopeInit guard; auto& core_timing = guard.core_timing; - Core::Timing::EventType* cb_a = core_timing.RegisterEvent("callbackA", CallbackTemplate<0>); - Core::Timing::EventType* cb_b = core_timing.RegisterEvent("callbackB", CallbackTemplate<1>); + std::shared_ptr<Core::Timing::EventType> cb_a = + Core::Timing::CreateEvent("callbackA", CallbackTemplate<0>); + std::shared_ptr<Core::Timing::EventType> cb_b = + Core::Timing::CreateEvent("callbackB", CallbackTemplate<1>); // Enter slice 0 core_timing.ResetRun(); @@ -145,60 +154,3 @@ TEST_CASE("Core::Timing[PredictableLateness]", "[core]") { AdvanceAndCheck(core_timing, 0, 0, 10, -10); // (100 - 10) AdvanceAndCheck(core_timing, 1, 1, 50, -50); } - -namespace ChainSchedulingTest { -static int reschedules = 0; - -static void RescheduleCallback(Core::Timing::CoreTiming& core_timing, u64 userdata, - s64 cycles_late) { - --reschedules; - REQUIRE(reschedules >= 0); - REQUIRE(lateness == cycles_late); - - if (reschedules > 0) { - core_timing.ScheduleEvent(1000, reinterpret_cast<Core::Timing::EventType*>(userdata), - userdata); - } -} -} // namespace ChainSchedulingTest - -TEST_CASE("CoreTiming[ChainScheduling]", "[core]") { - using namespace ChainSchedulingTest; - - ScopeInit guard; - auto& core_timing = guard.core_timing; - - Core::Timing::EventType* cb_a = core_timing.RegisterEvent("callbackA", CallbackTemplate<0>); - Core::Timing::EventType* cb_b = core_timing.RegisterEvent("callbackB", CallbackTemplate<1>); - Core::Timing::EventType* cb_c = core_timing.RegisterEvent("callbackC", CallbackTemplate<2>); - Core::Timing::EventType* cb_rs = core_timing.RegisterEvent( - "callbackReschedule", [&core_timing](u64 userdata, s64 cycles_late) { - RescheduleCallback(core_timing, userdata, cycles_late); - }); - - // Enter slice 0 - core_timing.ResetRun(); - - core_timing.ScheduleEvent(800, cb_a, CB_IDS[0]); - core_timing.ScheduleEvent(1000, cb_b, CB_IDS[1]); - core_timing.ScheduleEvent(2200, cb_c, CB_IDS[2]); - core_timing.ScheduleEvent(1000, cb_rs, reinterpret_cast<u64>(cb_rs)); - REQUIRE(800 == core_timing.GetDowncount()); - - reschedules = 3; - AdvanceAndCheck(core_timing, 0, 0); // cb_a - AdvanceAndCheck(core_timing, 1, 1); // cb_b, cb_rs - REQUIRE(2 == reschedules); - - core_timing.AddTicks(core_timing.GetDowncount()); - core_timing.Advance(); // cb_rs - core_timing.SwitchContext(3); - REQUIRE(1 == reschedules); - REQUIRE(200 == core_timing.GetDowncount()); - - AdvanceAndCheck(core_timing, 2, 3); // cb_c - - core_timing.AddTicks(core_timing.GetDowncount()); - core_timing.Advance(); // cb_rs - REQUIRE(0 == reschedules); -} diff --git a/src/video_core/CMakeLists.txt b/src/video_core/CMakeLists.txt index 45d8eaf23..3b20c7d34 100644 --- a/src/video_core/CMakeLists.txt +++ b/src/video_core/CMakeLists.txt @@ -22,6 +22,7 @@ add_library(video_core STATIC engines/maxwell_dma.h engines/shader_bytecode.h engines/shader_header.h + engines/shader_type.h gpu.cpp gpu.h gpu_asynch.cpp @@ -127,6 +128,8 @@ add_library(video_core STATIC shader/track.cpp surface.cpp surface.h + texture_cache/format_lookup_table.cpp + texture_cache/format_lookup_table.h texture_cache/surface_base.cpp texture_cache/surface_base.h texture_cache/surface_params.cpp diff --git a/src/video_core/buffer_cache/buffer_cache.h b/src/video_core/buffer_cache/buffer_cache.h index 4408b5001..0510ed777 100644 --- a/src/video_core/buffer_cache/buffer_cache.h +++ b/src/video_core/buffer_cache/buffer_cache.h @@ -427,8 +427,8 @@ private: VideoCore::RasterizerInterface& rasterizer; Core::System& system; - std::unique_ptr<StreamBuffer> stream_buffer; + std::unique_ptr<StreamBuffer> stream_buffer; TBufferType stream_buffer_handle{}; bool invalidated = false; @@ -440,18 +440,18 @@ private: using IntervalSet = boost::icl::interval_set<CacheAddr>; using IntervalCache = boost::icl::interval_map<CacheAddr, MapInterval>; using IntervalType = typename IntervalCache::interval_type; - IntervalCache mapped_addresses{}; + IntervalCache mapped_addresses; - static constexpr u64 write_page_bit{11}; - std::unordered_map<u64, u32> written_pages{}; + static constexpr u64 write_page_bit = 11; + std::unordered_map<u64, u32> written_pages; - static constexpr u64 block_page_bits{21}; - static constexpr u64 block_page_size{1 << block_page_bits}; - std::unordered_map<u64, TBuffer> blocks{}; + static constexpr u64 block_page_bits = 21; + static constexpr u64 block_page_size = 1ULL << block_page_bits; + std::unordered_map<u64, TBuffer> blocks; - std::list<TBuffer> pending_destruction{}; - u64 epoch{}; - u64 modified_ticks{}; + std::list<TBuffer> pending_destruction; + u64 epoch = 0; + u64 modified_ticks = 0; std::recursive_mutex mutex; }; diff --git a/src/video_core/engines/const_buffer_engine_interface.h b/src/video_core/engines/const_buffer_engine_interface.h index ac27b6cbe..44b8b8d22 100644 --- a/src/video_core/engines/const_buffer_engine_interface.h +++ b/src/video_core/engines/const_buffer_engine_interface.h @@ -8,19 +8,11 @@ #include "common/bit_field.h" #include "common/common_types.h" #include "video_core/engines/shader_bytecode.h" +#include "video_core/engines/shader_type.h" #include "video_core/textures/texture.h" namespace Tegra::Engines { -enum class ShaderType : u32 { - Vertex = 0, - TesselationControl = 1, - TesselationEval = 2, - Geometry = 3, - Fragment = 4, - Compute = 5, -}; - struct SamplerDescriptor { union { BitField<0, 20, Tegra::Shader::TextureType> texture_type; diff --git a/src/video_core/engines/kepler_compute.cpp b/src/video_core/engines/kepler_compute.cpp index 3a39aeabe..110406f2f 100644 --- a/src/video_core/engines/kepler_compute.cpp +++ b/src/video_core/engines/kepler_compute.cpp @@ -8,6 +8,7 @@ #include "core/core.h" #include "video_core/engines/kepler_compute.h" #include "video_core/engines/maxwell_3d.h" +#include "video_core/engines/shader_type.h" #include "video_core/memory_manager.h" #include "video_core/rasterizer_interface.h" #include "video_core/renderer_base.h" diff --git a/src/video_core/engines/kepler_compute.h b/src/video_core/engines/kepler_compute.h index 5259d92bd..4ef3e0613 100644 --- a/src/video_core/engines/kepler_compute.h +++ b/src/video_core/engines/kepler_compute.h @@ -12,6 +12,7 @@ #include "common/common_types.h" #include "video_core/engines/const_buffer_engine_interface.h" #include "video_core/engines/engine_upload.h" +#include "video_core/engines/shader_type.h" #include "video_core/gpu.h" #include "video_core/textures/texture.h" @@ -140,7 +141,7 @@ public: INSERT_PADDING_WORDS(0x3); - BitField<0, 16, u32> shared_alloc; + BitField<0, 18, u32> shared_alloc; BitField<16, 16, u32> block_dim_x; union { @@ -178,7 +179,12 @@ public: BitField<24, 5, u32> gpr_alloc; }; - INSERT_PADDING_WORDS(0x11); + union { + BitField<0, 20, u32> local_crs_alloc; + BitField<24, 5, u32> sass_version; + }; + + INSERT_PADDING_WORDS(0x10); } launch_description{}; struct { diff --git a/src/video_core/engines/maxwell_3d.cpp b/src/video_core/engines/maxwell_3d.cpp index 42ce49a4d..15a7a9d6a 100644 --- a/src/video_core/engines/maxwell_3d.cpp +++ b/src/video_core/engines/maxwell_3d.cpp @@ -9,6 +9,7 @@ #include "core/core_timing.h" #include "video_core/debug_utils/debug_utils.h" #include "video_core/engines/maxwell_3d.h" +#include "video_core/engines/shader_type.h" #include "video_core/memory_manager.h" #include "video_core/rasterizer_interface.h" #include "video_core/textures/texture.h" @@ -368,24 +369,24 @@ void Maxwell3D::CallMethod(const GPU::MethodCall& method_call) { StartCBData(method); break; } - case MAXWELL3D_REG_INDEX(cb_bind[0].raw_config): { - ProcessCBBind(Regs::ShaderStage::Vertex); + case MAXWELL3D_REG_INDEX(cb_bind[0]): { + ProcessCBBind(0); break; } - case MAXWELL3D_REG_INDEX(cb_bind[1].raw_config): { - ProcessCBBind(Regs::ShaderStage::TesselationControl); + case MAXWELL3D_REG_INDEX(cb_bind[1]): { + ProcessCBBind(1); break; } - case MAXWELL3D_REG_INDEX(cb_bind[2].raw_config): { - ProcessCBBind(Regs::ShaderStage::TesselationEval); + case MAXWELL3D_REG_INDEX(cb_bind[2]): { + ProcessCBBind(2); break; } - case MAXWELL3D_REG_INDEX(cb_bind[3].raw_config): { - ProcessCBBind(Regs::ShaderStage::Geometry); + case MAXWELL3D_REG_INDEX(cb_bind[3]): { + ProcessCBBind(3); break; } - case MAXWELL3D_REG_INDEX(cb_bind[4].raw_config): { - ProcessCBBind(Regs::ShaderStage::Fragment); + case MAXWELL3D_REG_INDEX(cb_bind[4]): { + ProcessCBBind(4); break; } case MAXWELL3D_REG_INDEX(draw.vertex_end_gl): { @@ -687,10 +688,10 @@ void Maxwell3D::DrawArrays() { } } -void Maxwell3D::ProcessCBBind(Regs::ShaderStage stage) { +void Maxwell3D::ProcessCBBind(std::size_t stage_index) { // Bind the buffer currently in CB_ADDRESS to the specified index in the desired shader stage. - auto& shader = state.shader_stages[static_cast<std::size_t>(stage)]; - auto& bind_data = regs.cb_bind[static_cast<std::size_t>(stage)]; + auto& shader = state.shader_stages[stage_index]; + auto& bind_data = regs.cb_bind[stage_index]; ASSERT(bind_data.index < Regs::MaxConstBuffers); auto& buffer = shader.const_buffers[bind_data.index]; @@ -742,14 +743,6 @@ Texture::TICEntry Maxwell3D::GetTICEntry(u32 tic_index) const { Texture::TICEntry tic_entry; memory_manager.ReadBlockUnsafe(tic_address_gpu, &tic_entry, sizeof(Texture::TICEntry)); - [[maybe_unused]] const auto r_type{tic_entry.r_type.Value()}; - [[maybe_unused]] const auto g_type{tic_entry.g_type.Value()}; - [[maybe_unused]] const auto b_type{tic_entry.b_type.Value()}; - [[maybe_unused]] const auto a_type{tic_entry.a_type.Value()}; - - // TODO(Subv): Different data types for separate components are not supported - DEBUG_ASSERT(r_type == g_type && r_type == b_type && r_type == a_type); - return tic_entry; } @@ -765,9 +758,9 @@ Texture::FullTextureInfo Maxwell3D::GetTextureInfo(Texture::TextureHandle tex_ha return Texture::FullTextureInfo{GetTICEntry(tex_handle.tic_id), GetTSCEntry(tex_handle.tsc_id)}; } -Texture::FullTextureInfo Maxwell3D::GetStageTexture(Regs::ShaderStage stage, - std::size_t offset) const { - const auto& shader = state.shader_stages[static_cast<std::size_t>(stage)]; +Texture::FullTextureInfo Maxwell3D::GetStageTexture(ShaderType stage, std::size_t offset) const { + const auto stage_index = static_cast<std::size_t>(stage); + const auto& shader = state.shader_stages[stage_index]; const auto& tex_info_buffer = shader.const_buffers[regs.tex_cb_index]; ASSERT(tex_info_buffer.enabled && tex_info_buffer.address != 0); diff --git a/src/video_core/engines/maxwell_3d.h b/src/video_core/engines/maxwell_3d.h index 1aa7c274f..4cb7339b5 100644 --- a/src/video_core/engines/maxwell_3d.h +++ b/src/video_core/engines/maxwell_3d.h @@ -18,6 +18,7 @@ #include "video_core/engines/const_buffer_engine_interface.h" #include "video_core/engines/const_buffer_info.h" #include "video_core/engines/engine_upload.h" +#include "video_core/engines/shader_type.h" #include "video_core/gpu.h" #include "video_core/macro_interpreter.h" #include "video_core/textures/texture.h" @@ -62,7 +63,6 @@ public: static constexpr std::size_t NumVertexArrays = 32; static constexpr std::size_t NumVertexAttributes = 32; static constexpr std::size_t NumVaryings = 31; - static constexpr std::size_t NumTextureSamplers = 32; static constexpr std::size_t NumImages = 8; // TODO(Rodrigo): Investigate this number static constexpr std::size_t NumClipDistances = 8; static constexpr std::size_t MaxShaderProgram = 6; @@ -130,14 +130,6 @@ public: Fragment = 5, }; - enum class ShaderStage : u32 { - Vertex = 0, - TesselationControl = 1, - TesselationEval = 2, - Geometry = 3, - Fragment = 4, - }; - struct VertexAttribute { enum class Size : u32 { Invalid = 0x0, @@ -677,8 +669,8 @@ public: INSERT_UNION_PADDING_WORDS(0x15); s32 stencil_back_func_ref; - u32 stencil_back_mask; u32 stencil_back_func_mask; + u32 stencil_back_mask; INSERT_UNION_PADDING_WORDS(0xC); @@ -1254,7 +1246,7 @@ public: Texture::FullTextureInfo GetTextureInfo(Texture::TextureHandle tex_handle) const; /// Returns the texture information for a specific texture in a specific shader stage. - Texture::FullTextureInfo GetStageTexture(Regs::ShaderStage stage, std::size_t offset) const; + Texture::FullTextureInfo GetStageTexture(ShaderType stage, std::size_t offset) const; u32 AccessConstBuffer32(ShaderType stage, u64 const_buffer, u64 offset) const override; @@ -1376,7 +1368,7 @@ private: void FinishCBData(); /// Handles a write to the CB_BIND register. - void ProcessCBBind(Regs::ShaderStage stage); + void ProcessCBBind(std::size_t stage_index); /// Handles a write to the VERTEX_END_GL register, triggering a draw. void DrawArrays(); @@ -1407,8 +1399,8 @@ ASSERT_REG_POSITION(polygon_offset_line_enable, 0x371); ASSERT_REG_POSITION(polygon_offset_fill_enable, 0x372); ASSERT_REG_POSITION(scissor_test, 0x380); ASSERT_REG_POSITION(stencil_back_func_ref, 0x3D5); -ASSERT_REG_POSITION(stencil_back_mask, 0x3D6); -ASSERT_REG_POSITION(stencil_back_func_mask, 0x3D7); +ASSERT_REG_POSITION(stencil_back_func_mask, 0x3D6); +ASSERT_REG_POSITION(stencil_back_mask, 0x3D7); ASSERT_REG_POSITION(color_mask_common, 0x3E4); ASSERT_REG_POSITION(rt_separate_frag_data, 0x3EB); ASSERT_REG_POSITION(depth_bounds, 0x3EC); diff --git a/src/video_core/engines/shader_type.h b/src/video_core/engines/shader_type.h new file mode 100644 index 000000000..49ce5cde5 --- /dev/null +++ b/src/video_core/engines/shader_type.h @@ -0,0 +1,21 @@ +// Copyright 2019 yuzu Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once + +#include "common/common_types.h" + +namespace Tegra::Engines { + +enum class ShaderType : u32 { + Vertex = 0, + TesselationControl = 1, + TesselationEval = 2, + Geometry = 3, + Fragment = 4, + Compute = 5, +}; +static constexpr std::size_t MaxShaderTypes = 6; + +} // namespace Tegra::Engines diff --git a/src/video_core/gpu_thread.cpp b/src/video_core/gpu_thread.cpp index 758a37f14..2cdf1aa7f 100644 --- a/src/video_core/gpu_thread.cpp +++ b/src/video_core/gpu_thread.cpp @@ -31,24 +31,22 @@ static void RunThread(VideoCore::RendererBase& renderer, Tegra::DmaPusher& dma_p CommandDataContainer next; while (state.is_running) { - while (!state.queue.Empty()) { - state.queue.Pop(next); - if (const auto submit_list = std::get_if<SubmitListCommand>(&next.data)) { - dma_pusher.Push(std::move(submit_list->entries)); - dma_pusher.DispatchCalls(); - } else if (const auto data = std::get_if<SwapBuffersCommand>(&next.data)) { - renderer.SwapBuffers(data->framebuffer ? &*data->framebuffer : nullptr); - } else if (const auto data = std::get_if<FlushRegionCommand>(&next.data)) { - renderer.Rasterizer().FlushRegion(data->addr, data->size); - } else if (const auto data = std::get_if<InvalidateRegionCommand>(&next.data)) { - renderer.Rasterizer().InvalidateRegion(data->addr, data->size); - } else if (std::holds_alternative<EndProcessingCommand>(next.data)) { - return; - } else { - UNREACHABLE(); - } - state.signaled_fence.store(next.fence); + next = state.queue.PopWait(); + if (const auto submit_list = std::get_if<SubmitListCommand>(&next.data)) { + dma_pusher.Push(std::move(submit_list->entries)); + dma_pusher.DispatchCalls(); + } else if (const auto data = std::get_if<SwapBuffersCommand>(&next.data)) { + renderer.SwapBuffers(data->framebuffer ? &*data->framebuffer : nullptr); + } else if (const auto data = std::get_if<FlushRegionCommand>(&next.data)) { + renderer.Rasterizer().FlushRegion(data->addr, data->size); + } else if (const auto data = std::get_if<InvalidateRegionCommand>(&next.data)) { + renderer.Rasterizer().InvalidateRegion(data->addr, data->size); + } else if (std::holds_alternative<EndProcessingCommand>(next.data)) { + return; + } else { + UNREACHABLE(); } + state.signaled_fence.store(next.fence); } } @@ -73,8 +71,7 @@ void ThreadManager::SubmitList(Tegra::CommandList&& entries) { } void ThreadManager::SwapBuffers(const Tegra::FramebufferConfig* framebuffer) { - PushCommand(SwapBuffersCommand(framebuffer ? *framebuffer - : std::optional<const Tegra::FramebufferConfig>{})); + PushCommand(SwapBuffersCommand(framebuffer ? std::make_optional(*framebuffer) : std::nullopt)); } void ThreadManager::FlushRegion(CacheAddr addr, u64 size) { diff --git a/src/video_core/memory_manager.cpp b/src/video_core/memory_manager.cpp index bffae940c..11848fbce 100644 --- a/src/video_core/memory_manager.cpp +++ b/src/video_core/memory_manager.cpp @@ -52,7 +52,7 @@ GPUVAddr MemoryManager::MapBufferEx(VAddr cpu_addr, u64 size) { const u64 aligned_size{Common::AlignUp(size, page_size)}; const GPUVAddr gpu_addr{FindFreeRegion(address_space_base, aligned_size)}; - MapBackingMemory(gpu_addr, Memory::GetPointer(cpu_addr), aligned_size, cpu_addr); + MapBackingMemory(gpu_addr, system.Memory().GetPointer(cpu_addr), aligned_size, cpu_addr); ASSERT(system.CurrentProcess() ->VMManager() .SetMemoryAttribute(cpu_addr, size, Kernel::MemoryAttribute::DeviceMapped, @@ -67,7 +67,7 @@ GPUVAddr MemoryManager::MapBufferEx(VAddr cpu_addr, GPUVAddr gpu_addr, u64 size) const u64 aligned_size{Common::AlignUp(size, page_size)}; - MapBackingMemory(gpu_addr, Memory::GetPointer(cpu_addr), aligned_size, cpu_addr); + MapBackingMemory(gpu_addr, system.Memory().GetPointer(cpu_addr), aligned_size, cpu_addr); ASSERT(system.CurrentProcess() ->VMManager() .SetMemoryAttribute(cpu_addr, size, Kernel::MemoryAttribute::DeviceMapped, diff --git a/src/video_core/rasterizer_accelerated.cpp b/src/video_core/rasterizer_accelerated.cpp index b230dcc18..fc6ecb899 100644 --- a/src/video_core/rasterizer_accelerated.cpp +++ b/src/video_core/rasterizer_accelerated.cpp @@ -22,7 +22,8 @@ constexpr auto RangeFromInterval(Map& map, const Interval& interval) { } // Anonymous namespace -RasterizerAccelerated::RasterizerAccelerated() = default; +RasterizerAccelerated::RasterizerAccelerated(Memory::Memory& cpu_memory_) + : cpu_memory{cpu_memory_} {} RasterizerAccelerated::~RasterizerAccelerated() = default; @@ -47,9 +48,9 @@ void RasterizerAccelerated::UpdatePagesCachedCount(VAddr addr, u64 size, int del const u64 interval_size = interval_end_addr - interval_start_addr; if (delta > 0 && count == delta) { - Memory::RasterizerMarkRegionCached(interval_start_addr, interval_size, true); + cpu_memory.RasterizerMarkRegionCached(interval_start_addr, interval_size, true); } else if (delta < 0 && count == -delta) { - Memory::RasterizerMarkRegionCached(interval_start_addr, interval_size, false); + cpu_memory.RasterizerMarkRegionCached(interval_start_addr, interval_size, false); } else { ASSERT(count >= 0); } diff --git a/src/video_core/rasterizer_accelerated.h b/src/video_core/rasterizer_accelerated.h index 8f7e3547e..315798e7c 100644 --- a/src/video_core/rasterizer_accelerated.h +++ b/src/video_core/rasterizer_accelerated.h @@ -11,12 +11,16 @@ #include "common/common_types.h" #include "video_core/rasterizer_interface.h" +namespace Memory { +class Memory; +} + namespace VideoCore { /// Implements the shared part in GPU accelerated rasterizers in RasterizerInterface. class RasterizerAccelerated : public RasterizerInterface { public: - explicit RasterizerAccelerated(); + explicit RasterizerAccelerated(Memory::Memory& cpu_memory_); ~RasterizerAccelerated() override; void UpdatePagesCachedCount(VAddr addr, u64 size, int delta) override; @@ -24,8 +28,9 @@ public: private: using CachedPageMap = boost::icl::interval_map<u64, int>; CachedPageMap cached_pages; - std::mutex pages_mutex; + + Memory::Memory& cpu_memory; }; } // namespace VideoCore diff --git a/src/video_core/renderer_opengl/gl_device.cpp b/src/video_core/renderer_opengl/gl_device.cpp index b30d5be74..413d8546b 100644 --- a/src/video_core/renderer_opengl/gl_device.cpp +++ b/src/video_core/renderer_opengl/gl_device.cpp @@ -5,7 +5,9 @@ #include <algorithm> #include <array> #include <cstddef> +#include <optional> #include <vector> + #include <glad/glad.h> #include "common/logging/log.h" @@ -17,6 +19,30 @@ namespace OpenGL { namespace { +// One uniform block is reserved for emulation purposes +constexpr u32 ReservedUniformBlocks = 1; + +constexpr u32 NumStages = 5; + +constexpr std::array LimitUBOs = {GL_MAX_VERTEX_UNIFORM_BLOCKS, GL_MAX_TESS_CONTROL_UNIFORM_BLOCKS, + GL_MAX_TESS_EVALUATION_UNIFORM_BLOCKS, + GL_MAX_GEOMETRY_UNIFORM_BLOCKS, GL_MAX_FRAGMENT_UNIFORM_BLOCKS}; + +constexpr std::array LimitSSBOs = { + GL_MAX_VERTEX_SHADER_STORAGE_BLOCKS, GL_MAX_TESS_CONTROL_SHADER_STORAGE_BLOCKS, + GL_MAX_TESS_EVALUATION_SHADER_STORAGE_BLOCKS, GL_MAX_GEOMETRY_SHADER_STORAGE_BLOCKS, + GL_MAX_FRAGMENT_SHADER_STORAGE_BLOCKS}; + +constexpr std::array LimitSamplers = { + GL_MAX_VERTEX_TEXTURE_IMAGE_UNITS, GL_MAX_TESS_CONTROL_TEXTURE_IMAGE_UNITS, + GL_MAX_TESS_EVALUATION_TEXTURE_IMAGE_UNITS, GL_MAX_GEOMETRY_TEXTURE_IMAGE_UNITS, + GL_MAX_TEXTURE_IMAGE_UNITS}; + +constexpr std::array LimitImages = {GL_MAX_VERTEX_IMAGE_UNIFORMS, + GL_MAX_TESS_CONTROL_IMAGE_UNIFORMS, + GL_MAX_TESS_EVALUATION_IMAGE_UNIFORMS, + GL_MAX_GEOMETRY_IMAGE_UNIFORMS, GL_MAX_FRAGMENT_IMAGE_UNIFORMS}; + template <typename T> T GetInteger(GLenum pname) { GLint temporary; @@ -48,13 +74,71 @@ bool HasExtension(const std::vector<std::string_view>& images, std::string_view return std::find(images.begin(), images.end(), extension) != images.end(); } +u32 Extract(u32& base, u32& num, u32 amount, std::optional<GLenum> limit = {}) { + ASSERT(num >= amount); + if (limit) { + amount = std::min(amount, GetInteger<u32>(*limit)); + } + num -= amount; + return std::exchange(base, base + amount); +} + +std::array<Device::BaseBindings, Tegra::Engines::MaxShaderTypes> BuildBaseBindings() noexcept { + std::array<Device::BaseBindings, Tegra::Engines::MaxShaderTypes> bindings; + + static std::array<std::size_t, 5> stage_swizzle = {0, 1, 2, 3, 4}; + const u32 total_ubos = GetInteger<u32>(GL_MAX_UNIFORM_BUFFER_BINDINGS); + const u32 total_ssbos = GetInteger<u32>(GL_MAX_SHADER_STORAGE_BUFFER_BINDINGS); + const u32 total_samplers = GetInteger<u32>(GL_MAX_COMBINED_TEXTURE_IMAGE_UNITS); + + u32 num_ubos = total_ubos - ReservedUniformBlocks; + u32 num_ssbos = total_ssbos; + u32 num_samplers = total_samplers; + + u32 base_ubo = ReservedUniformBlocks; + u32 base_ssbo = 0; + u32 base_samplers = 0; + + for (std::size_t i = 0; i < NumStages; ++i) { + const std::size_t stage = stage_swizzle[i]; + bindings[stage] = { + Extract(base_ubo, num_ubos, total_ubos / NumStages, LimitUBOs[stage]), + Extract(base_ssbo, num_ssbos, total_ssbos / NumStages, LimitSSBOs[stage]), + Extract(base_samplers, num_samplers, total_samplers / NumStages, LimitSamplers[stage])}; + } + + u32 num_images = GetInteger<u32>(GL_MAX_IMAGE_UNITS); + u32 base_images = 0; + + // Reserve more image bindings on fragment and vertex stages. + bindings[4].image = + Extract(base_images, num_images, num_images / NumStages + 2, LimitImages[4]); + bindings[0].image = + Extract(base_images, num_images, num_images / NumStages + 1, LimitImages[0]); + + // Reserve the other image bindings. + const u32 total_extracted_images = num_images / (NumStages - 2); + for (std::size_t i = 2; i < NumStages; ++i) { + const std::size_t stage = stage_swizzle[i]; + bindings[stage].image = + Extract(base_images, num_images, total_extracted_images, LimitImages[stage]); + } + + // Compute doesn't care about any of this. + bindings[5] = {0, 0, 0, 0}; + + return bindings; +} + } // Anonymous namespace -Device::Device() { +Device::Device() : base_bindings{BuildBaseBindings()} { const std::string_view vendor = reinterpret_cast<const char*>(glGetString(GL_VENDOR)); const std::vector extensions = GetExtensions(); const bool is_nvidia = vendor == "NVIDIA Corporation"; + const bool is_amd = vendor == "ATI Technologies Inc."; + const bool is_intel = vendor == "Intel"; uniform_buffer_alignment = GetInteger<std::size_t>(GL_UNIFORM_BUFFER_OFFSET_ALIGNMENT); shader_storage_alignment = GetInteger<std::size_t>(GL_SHADER_STORAGE_BUFFER_OFFSET_ALIGNMENT); @@ -66,8 +150,9 @@ Device::Device() { has_vertex_viewport_layer = GLAD_GL_ARB_shader_viewport_layer_array; has_image_load_formatted = HasExtension(extensions, "GL_EXT_shader_image_load_formatted"); has_variable_aoffi = TestVariableAoffi(); - has_component_indexing_bug = TestComponentIndexingBug(); + has_component_indexing_bug = is_amd; has_precise_bug = TestPreciseBug(); + has_broken_compute = is_intel; has_fast_buffer_sub_data = is_nvidia; LOG_INFO(Render_OpenGL, "Renderer_VariableAOFFI: {}", has_variable_aoffi); @@ -85,6 +170,7 @@ Device::Device(std::nullptr_t) { has_image_load_formatted = true; has_variable_aoffi = true; has_component_indexing_bug = false; + has_broken_compute = false; has_precise_bug = false; } @@ -99,52 +185,6 @@ void main() { })"); } -bool Device::TestComponentIndexingBug() { - const GLchar* COMPONENT_TEST = R"(#version 430 core -layout (std430, binding = 0) buffer OutputBuffer { - uint output_value; -}; -layout (std140, binding = 0) uniform InputBuffer { - uvec4 input_value[4096]; -}; -layout (location = 0) uniform uint idx; -void main() { - output_value = input_value[idx >> 2][idx & 3]; -})"; - const GLuint shader{glCreateShaderProgramv(GL_VERTEX_SHADER, 1, &COMPONENT_TEST)}; - SCOPE_EXIT({ glDeleteProgram(shader); }); - glUseProgram(shader); - - OGLVertexArray vao; - vao.Create(); - glBindVertexArray(vao.handle); - - constexpr std::array<GLuint, 8> values{0, 0, 0, 0, 0x1236327, 0x985482, 0x872753, 0x2378432}; - OGLBuffer ubo; - ubo.Create(); - glNamedBufferData(ubo.handle, sizeof(values), values.data(), GL_STATIC_DRAW); - glBindBufferBase(GL_UNIFORM_BUFFER, 0, ubo.handle); - - OGLBuffer ssbo; - ssbo.Create(); - glNamedBufferStorage(ssbo.handle, sizeof(GLuint), nullptr, GL_CLIENT_STORAGE_BIT); - - for (GLuint index = 4; index < 8; ++index) { - glInvalidateBufferData(ssbo.handle); - glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, ssbo.handle); - - glProgramUniform1ui(shader, 0, index); - glDrawArrays(GL_POINTS, 0, 1); - - GLuint result; - glGetNamedBufferSubData(ssbo.handle, 0, sizeof(result), &result); - if (result != values.at(index)) { - return true; - } - } - return false; -} - bool Device::TestPreciseBug() { return !TestProgram(R"(#version 430 core in vec3 coords; diff --git a/src/video_core/renderer_opengl/gl_device.h b/src/video_core/renderer_opengl/gl_device.h index 6c86fe207..d73b099d0 100644 --- a/src/video_core/renderer_opengl/gl_device.h +++ b/src/video_core/renderer_opengl/gl_device.h @@ -6,14 +6,32 @@ #include <cstddef> #include "common/common_types.h" +#include "video_core/engines/shader_type.h" namespace OpenGL { -class Device { +static constexpr u32 EmulationUniformBlockBinding = 0; + +class Device final { public: + struct BaseBindings final { + u32 uniform_buffer{}; + u32 shader_storage_buffer{}; + u32 sampler{}; + u32 image{}; + }; + explicit Device(); explicit Device(std::nullptr_t); + const BaseBindings& GetBaseBindings(std::size_t stage_index) const noexcept { + return base_bindings[stage_index]; + } + + const BaseBindings& GetBaseBindings(Tegra::Engines::ShaderType shader_type) const noexcept { + return GetBaseBindings(static_cast<std::size_t>(shader_type)); + } + std::size_t GetUniformBufferAlignment() const { return uniform_buffer_alignment; } @@ -58,15 +76,19 @@ public: return has_precise_bug; } + bool HasBrokenCompute() const { + return has_broken_compute; + } + bool HasFastBufferSubData() const { return has_fast_buffer_sub_data; } private: static bool TestVariableAoffi(); - static bool TestComponentIndexingBug(); static bool TestPreciseBug(); + std::array<BaseBindings, Tegra::Engines::MaxShaderTypes> base_bindings; std::size_t uniform_buffer_alignment{}; std::size_t shader_storage_alignment{}; u32 max_vertex_attributes{}; @@ -78,6 +100,7 @@ private: bool has_variable_aoffi{}; bool has_component_indexing_bug{}; bool has_precise_bug{}; + bool has_broken_compute{}; bool has_fast_buffer_sub_data{}; }; diff --git a/src/video_core/renderer_opengl/gl_framebuffer_cache.cpp b/src/video_core/renderer_opengl/gl_framebuffer_cache.cpp index a5d69d78d..874ed3c6e 100644 --- a/src/video_core/renderer_opengl/gl_framebuffer_cache.cpp +++ b/src/video_core/renderer_opengl/gl_framebuffer_cache.cpp @@ -3,9 +3,12 @@ // Refer to the license.txt file included. #include <tuple> +#include <unordered_map> +#include <utility> -#include "common/cityhash.h" -#include "common/scope_exit.h" +#include <glad/glad.h> + +#include "common/common_types.h" #include "video_core/engines/maxwell_3d.h" #include "video_core/renderer_opengl/gl_framebuffer_cache.h" #include "video_core/renderer_opengl/gl_state.h" @@ -13,6 +16,7 @@ namespace OpenGL { using Maxwell = Tegra::Engines::Maxwell3D::Regs; +using VideoCore::Surface::SurfaceType; FramebufferCacheOpenGL::FramebufferCacheOpenGL() = default; @@ -35,36 +39,49 @@ OGLFramebuffer FramebufferCacheOpenGL::CreateFramebuffer(const FramebufferCacheK local_state.draw.draw_framebuffer = framebuffer.handle; local_state.ApplyFramebufferState(); + if (key.zeta) { + const bool stencil = key.zeta->GetSurfaceParams().type == SurfaceType::DepthStencil; + const GLenum attach_target = stencil ? GL_DEPTH_STENCIL_ATTACHMENT : GL_DEPTH_ATTACHMENT; + key.zeta->Attach(attach_target, GL_DRAW_FRAMEBUFFER); + } + + std::size_t num_buffers = 0; + std::array<GLenum, Maxwell::NumRenderTargets> targets; + for (std::size_t index = 0; index < Maxwell::NumRenderTargets; ++index) { - if (key.colors[index]) { - key.colors[index]->Attach(GL_COLOR_ATTACHMENT0 + static_cast<GLenum>(index), - GL_DRAW_FRAMEBUFFER); + if (!key.colors[index]) { + targets[index] = GL_NONE; + continue; } + const GLenum attach_target = GL_COLOR_ATTACHMENT0 + static_cast<GLenum>(index); + key.colors[index]->Attach(attach_target, GL_DRAW_FRAMEBUFFER); + + const u32 attachment = (key.color_attachments >> (BitsPerAttachment * index)) & 0b1111; + targets[index] = GL_COLOR_ATTACHMENT0 + attachment; + num_buffers = index + 1; } - if (key.colors_count) { - glDrawBuffers(key.colors_count, key.color_attachments.data()); + + if (num_buffers > 0) { + glDrawBuffers(static_cast<GLsizei>(num_buffers), std::data(targets)); } else { glDrawBuffer(GL_NONE); } - if (key.zeta) { - key.zeta->Attach(key.stencil_enable ? GL_DEPTH_STENCIL_ATTACHMENT : GL_DEPTH_ATTACHMENT, - GL_DRAW_FRAMEBUFFER); - } - return framebuffer; } -std::size_t FramebufferCacheKey::Hash() const { - static_assert(sizeof(*this) % sizeof(u64) == 0, "Unaligned struct"); - return static_cast<std::size_t>( - Common::CityHash64(reinterpret_cast<const char*>(this), sizeof(*this))); +std::size_t FramebufferCacheKey::Hash() const noexcept { + std::size_t hash = std::hash<View>{}(zeta); + for (const auto& color : colors) { + hash ^= std::hash<View>{}(color); + } + hash ^= static_cast<std::size_t>(color_attachments) << 16; + return hash; } -bool FramebufferCacheKey::operator==(const FramebufferCacheKey& rhs) const { - return std::tie(stencil_enable, colors_count, color_attachments, colors, zeta) == - std::tie(rhs.stencil_enable, rhs.colors_count, rhs.color_attachments, rhs.colors, - rhs.zeta); +bool FramebufferCacheKey::operator==(const FramebufferCacheKey& rhs) const noexcept { + return std::tie(colors, zeta, color_attachments) == + std::tie(rhs.colors, rhs.zeta, rhs.color_attachments); } } // namespace OpenGL diff --git a/src/video_core/renderer_opengl/gl_framebuffer_cache.h b/src/video_core/renderer_opengl/gl_framebuffer_cache.h index 424344c48..02ec80ae9 100644 --- a/src/video_core/renderer_opengl/gl_framebuffer_cache.h +++ b/src/video_core/renderer_opengl/gl_framebuffer_cache.h @@ -18,21 +18,24 @@ namespace OpenGL { -struct alignas(sizeof(u64)) FramebufferCacheKey { - bool stencil_enable = false; - u16 colors_count = 0; +constexpr std::size_t BitsPerAttachment = 4; - std::array<GLenum, Tegra::Engines::Maxwell3D::Regs::NumRenderTargets> color_attachments{}; - std::array<View, Tegra::Engines::Maxwell3D::Regs::NumRenderTargets> colors; +struct FramebufferCacheKey { View zeta; + std::array<View, Tegra::Engines::Maxwell3D::Regs::NumRenderTargets> colors; + u32 color_attachments = 0; - std::size_t Hash() const; + std::size_t Hash() const noexcept; - bool operator==(const FramebufferCacheKey& rhs) const; + bool operator==(const FramebufferCacheKey& rhs) const noexcept; - bool operator!=(const FramebufferCacheKey& rhs) const { + bool operator!=(const FramebufferCacheKey& rhs) const noexcept { return !operator==(rhs); } + + void SetAttachment(std::size_t index, u32 attachment) { + color_attachments |= attachment << (BitsPerAttachment * index); + } }; } // namespace OpenGL diff --git a/src/video_core/renderer_opengl/gl_rasterizer.cpp b/src/video_core/renderer_opengl/gl_rasterizer.cpp index e43ba9d6b..9eef7fcd2 100644 --- a/src/video_core/renderer_opengl/gl_rasterizer.cpp +++ b/src/video_core/renderer_opengl/gl_rasterizer.cpp @@ -19,9 +19,11 @@ #include "common/scope_exit.h" #include "core/core.h" #include "core/hle/kernel/process.h" +#include "core/memory.h" #include "core/settings.h" #include "video_core/engines/kepler_compute.h" #include "video_core/engines/maxwell_3d.h" +#include "video_core/engines/shader_type.h" #include "video_core/memory_manager.h" #include "video_core/renderer_opengl/gl_rasterizer.h" #include "video_core/renderer_opengl/gl_shader_cache.h" @@ -49,8 +51,25 @@ MICROPROFILE_DEFINE(OpenGL_Blits, "OpenGL", "Blits", MP_RGB(128, 128, 192)); MICROPROFILE_DEFINE(OpenGL_CacheManagement, "OpenGL", "Cache Mgmt", MP_RGB(100, 255, 100)); MICROPROFILE_DEFINE(OpenGL_PrimitiveAssembly, "OpenGL", "Prim Asmbl", MP_RGB(255, 100, 100)); -static std::size_t GetConstBufferSize(const Tegra::Engines::ConstBufferInfo& buffer, - const GLShader::ConstBufferEntry& entry) { +namespace { + +template <typename Engine, typename Entry> +Tegra::Texture::FullTextureInfo GetTextureInfo(const Engine& engine, const Entry& entry, + Tegra::Engines::ShaderType shader_type) { + if (entry.IsBindless()) { + const Tegra::Texture::TextureHandle tex_handle = + engine.AccessConstBuffer32(shader_type, entry.GetBuffer(), entry.GetOffset()); + return engine.GetTextureInfo(tex_handle); + } + if constexpr (std::is_same_v<Engine, Tegra::Engines::Maxwell3D>) { + return engine.GetStageTexture(shader_type, entry.GetOffset()); + } else { + return engine.GetTexture(entry.GetOffset()); + } +} + +std::size_t GetConstBufferSize(const Tegra::Engines::ConstBufferInfo& buffer, + const GLShader::ConstBufferEntry& entry) { if (!entry.IsIndirect()) { return entry.GetSize(); } @@ -64,14 +83,16 @@ static std::size_t GetConstBufferSize(const Tegra::Engines::ConstBufferInfo& buf return buffer.size; } +} // Anonymous namespace + RasterizerOpenGL::RasterizerOpenGL(Core::System& system, Core::Frontend::EmuWindow& emu_window, ScreenInfo& info) - : texture_cache{system, *this, device}, shader_cache{*this, system, emu_window, device}, - system{system}, screen_info{info}, buffer_cache{*this, system, device, STREAM_BUFFER_SIZE} { + : RasterizerAccelerated{system.Memory()}, texture_cache{system, *this, device}, + shader_cache{*this, system, emu_window, device}, system{system}, screen_info{info}, + buffer_cache{*this, system, device, STREAM_BUFFER_SIZE} { shader_program_manager = std::make_unique<GLShader::ProgramManager>(); state.draw.shader_program = 0; state.Apply(); - clear_framebuffer.Create(); LOG_DEBUG(Render_OpenGL, "Sync fixed function OpenGL state here"); CheckExtensions(); @@ -238,12 +259,11 @@ void RasterizerOpenGL::SetupShaders(GLenum primitive_mode) { MICROPROFILE_SCOPE(OpenGL_Shader); auto& gpu = system.GPU().Maxwell3D(); - BaseBindings base_bindings; std::array<bool, Maxwell::NumClipDistances> clip_distances{}; for (std::size_t index = 0; index < Maxwell::MaxShaderProgram; ++index) { const auto& shader_config = gpu.regs.shader_config[index]; - const Maxwell::ShaderProgram program{static_cast<Maxwell::ShaderProgram>(index)}; + const auto program{static_cast<Maxwell::ShaderProgram>(index)}; // Skip stages that are not enabled if (!gpu.regs.IsShaderConfigEnabled(index)) { @@ -257,25 +277,17 @@ void RasterizerOpenGL::SetupShaders(GLenum primitive_mode) { continue; } - const std::size_t stage{index == 0 ? 0 : index - 1}; // Stage indices are 0 - 5 - - GLShader::MaxwellUniformData ubo{}; - ubo.SetFromRegs(gpu, stage); - const auto [buffer, offset] = - buffer_cache.UploadHostMemory(&ubo, sizeof(ubo), device.GetUniformBufferAlignment()); - - // Bind the emulation info buffer - bind_ubo_pushbuffer.Push(buffer, offset, static_cast<GLsizeiptr>(sizeof(ubo))); - Shader shader{shader_cache.GetStageProgram(program)}; - const auto stage_enum = static_cast<Maxwell::ShaderStage>(stage); - SetupDrawConstBuffers(stage_enum, shader); - SetupDrawGlobalMemory(stage_enum, shader); - const auto texture_buffer_usage{SetupDrawTextures(stage_enum, shader, base_bindings)}; + // Stage indices are 0 - 5 + const std::size_t stage = index == 0 ? 0 : index - 1; + SetupDrawConstBuffers(stage, shader); + SetupDrawGlobalMemory(stage, shader); + SetupDrawTextures(stage, shader); + SetupDrawImages(stage, shader); - const ProgramVariant variant{base_bindings, primitive_mode, texture_buffer_usage}; - const auto [program_handle, next_bindings] = shader->GetProgramHandle(variant); + const ProgramVariant variant(primitive_mode); + const auto program_handle = shader->GetHandle(variant); switch (program) { case Maxwell::ShaderProgram::VertexA: @@ -304,10 +316,8 @@ void RasterizerOpenGL::SetupShaders(GLenum primitive_mode) { // When VertexA is enabled, we have dual vertex shaders if (program == Maxwell::ShaderProgram::VertexA) { // VertexB was combined with VertexA, so we skip the VertexB iteration - index++; + ++index; } - - base_bindings = next_bindings; } SyncClipEnabled(clip_distances); @@ -362,78 +372,58 @@ void RasterizerOpenGL::ConfigureFramebuffers() { UNIMPLEMENTED_IF(regs.rt_separate_frag_data == 0); // Bind the framebuffer surfaces - FramebufferCacheKey fbkey; - for (std::size_t index = 0; index < Maxwell::NumRenderTargets; ++index) { + FramebufferCacheKey key; + const auto colors_count = static_cast<std::size_t>(regs.rt_control.count); + for (std::size_t index = 0; index < colors_count; ++index) { View color_surface{texture_cache.GetColorBufferSurface(index, true)}; - - if (color_surface) { - // Assume that a surface will be written to if it is used as a framebuffer, even - // if the shader doesn't actually write to it. - texture_cache.MarkColorBufferInUse(index); + if (!color_surface) { + continue; } + // Assume that a surface will be written to if it is used as a framebuffer, even + // if the shader doesn't actually write to it. + texture_cache.MarkColorBufferInUse(index); - fbkey.color_attachments[index] = GL_COLOR_ATTACHMENT0 + regs.rt_control.GetMap(index); - fbkey.colors[index] = std::move(color_surface); + key.SetAttachment(index, regs.rt_control.GetMap(index)); + key.colors[index] = std::move(color_surface); } - fbkey.colors_count = static_cast<u16>(regs.rt_control.count); if (depth_surface) { // Assume that a surface will be written to if it is used as a framebuffer, even if // the shader doesn't actually write to it. texture_cache.MarkDepthBufferInUse(); - - fbkey.stencil_enable = depth_surface->GetSurfaceParams().type == SurfaceType::DepthStencil; - fbkey.zeta = std::move(depth_surface); + key.zeta = std::move(depth_surface); } texture_cache.GuardRenderTargets(false); - state.draw.draw_framebuffer = framebuffer_cache.GetFramebuffer(fbkey); + state.draw.draw_framebuffer = framebuffer_cache.GetFramebuffer(key); SyncViewport(state); } void RasterizerOpenGL::ConfigureClearFramebuffer(OpenGLState& current_state, bool using_color_fb, bool using_depth_fb, bool using_stencil_fb) { + using VideoCore::Surface::SurfaceType; + auto& gpu = system.GPU().Maxwell3D(); const auto& regs = gpu.regs; texture_cache.GuardRenderTargets(true); - View color_surface{}; + View color_surface; if (using_color_fb) { color_surface = texture_cache.GetColorBufferSurface(regs.clear_buffers.RT, false); } - View depth_surface{}; + View depth_surface; if (using_depth_fb || using_stencil_fb) { depth_surface = texture_cache.GetDepthBufferSurface(false); } texture_cache.GuardRenderTargets(false); - current_state.draw.draw_framebuffer = clear_framebuffer.handle; - current_state.ApplyFramebufferState(); + FramebufferCacheKey key; + key.colors[0] = color_surface; + key.zeta = depth_surface; - if (color_surface) { - color_surface->Attach(GL_COLOR_ATTACHMENT0, GL_DRAW_FRAMEBUFFER); - } else { - glFramebufferTexture2D(GL_DRAW_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, 0, 0); - } - - if (depth_surface) { - const auto& params = depth_surface->GetSurfaceParams(); - switch (params.type) { - case VideoCore::Surface::SurfaceType::Depth: - depth_surface->Attach(GL_DEPTH_ATTACHMENT, GL_DRAW_FRAMEBUFFER); - glFramebufferTexture2D(GL_DRAW_FRAMEBUFFER, GL_STENCIL_ATTACHMENT, GL_TEXTURE_2D, 0, 0); - break; - case VideoCore::Surface::SurfaceType::DepthStencil: - depth_surface->Attach(GL_DEPTH_STENCIL_ATTACHMENT, GL_DRAW_FRAMEBUFFER); - break; - default: - UNIMPLEMENTED(); - } - } else { - glFramebufferTexture2D(GL_DRAW_FRAMEBUFFER, GL_DEPTH_STENCIL_ATTACHMENT, GL_TEXTURE_2D, 0, - 0); - } + current_state.draw.draw_framebuffer = framebuffer_cache.GetFramebuffer(key); + current_state.ApplyFramebufferState(); } void RasterizerOpenGL::Clear() { @@ -592,8 +582,16 @@ void RasterizerOpenGL::DrawPrelude() { index_buffer_offset = SetupIndexBuffer(); // Prepare packed bindings. - bind_ubo_pushbuffer.Setup(0); - bind_ssbo_pushbuffer.Setup(0); + bind_ubo_pushbuffer.Setup(); + bind_ssbo_pushbuffer.Setup(); + + // Setup emulation uniform buffer. + GLShader::MaxwellUniformData ubo; + ubo.SetFromRegs(gpu); + const auto [buffer, offset] = + buffer_cache.UploadHostMemory(&ubo, sizeof(ubo), device.GetUniformBufferAlignment()); + bind_ubo_pushbuffer.Push(EmulationUniformBlockBinding, buffer, offset, + static_cast<GLsizeiptr>(sizeof(ubo))); // Setup shaders and their used resources. texture_cache.GuardSamplers(true); @@ -726,19 +724,21 @@ bool RasterizerOpenGL::DrawMultiBatch(bool is_indexed) { } void RasterizerOpenGL::DispatchCompute(GPUVAddr code_addr) { - if (!GLAD_GL_ARB_compute_variable_group_size) { - LOG_ERROR(Render_OpenGL, "Compute is currently not supported on this device due to the " - "lack of GL_ARB_compute_variable_group_size"); + if (device.HasBrokenCompute()) { return; } + buffer_cache.Acquire(); + auto kernel = shader_cache.GetComputeKernel(code_addr); - ProgramVariant variant; - variant.texture_buffer_usage = SetupComputeTextures(kernel); + SetupComputeTextures(kernel); SetupComputeImages(kernel); - const auto [program, next_bindings] = kernel->GetProgramHandle(variant); - state.draw.shader_program = program; + const auto& launch_desc = system.GPU().KeplerCompute().launch_description; + const ProgramVariant variant(launch_desc.block_dim_x, launch_desc.block_dim_y, + launch_desc.block_dim_z, launch_desc.shared_alloc, + launch_desc.local_pos_alloc); + state.draw.shader_program = kernel->GetHandle(variant); state.draw.program_pipeline = 0; const std::size_t buffer_size = @@ -746,8 +746,8 @@ void RasterizerOpenGL::DispatchCompute(GPUVAddr code_addr) { (Maxwell::MaxConstBufferSize + device.GetUniformBufferAlignment()); buffer_cache.Map(buffer_size); - bind_ubo_pushbuffer.Setup(0); - bind_ssbo_pushbuffer.Setup(0); + bind_ubo_pushbuffer.Setup(); + bind_ssbo_pushbuffer.Setup(); SetupComputeConstBuffers(kernel); SetupComputeGlobalMemory(kernel); @@ -762,10 +762,7 @@ void RasterizerOpenGL::DispatchCompute(GPUVAddr code_addr) { state.ApplyShaderProgram(); state.ApplyProgramPipeline(); - const auto& launch_desc = system.GPU().KeplerCompute().launch_description; - glDispatchComputeGroupSizeARB(launch_desc.grid_dim_x, launch_desc.grid_dim_y, - launch_desc.grid_dim_z, launch_desc.block_dim_x, - launch_desc.block_dim_y, launch_desc.block_dim_z); + glDispatchCompute(launch_desc.grid_dim_x, launch_desc.grid_dim_y, launch_desc.grid_dim_z); } void RasterizerOpenGL::FlushAll() {} @@ -821,7 +818,7 @@ bool RasterizerOpenGL::AccelerateDisplay(const Tegra::FramebufferConfig& config, MICROPROFILE_SCOPE(OpenGL_CacheManagement); const auto surface{ - texture_cache.TryFindFramebufferSurface(Memory::GetPointer(framebuffer_addr))}; + texture_cache.TryFindFramebufferSurface(system.Memory().GetPointer(framebuffer_addr))}; if (!surface) { return {}; } @@ -834,7 +831,7 @@ bool RasterizerOpenGL::AccelerateDisplay(const Tegra::FramebufferConfig& config, ASSERT_MSG(params.height == config.height, "Framebuffer height is different"); if (params.pixel_format != pixel_format) { - LOG_WARNING(Render_OpenGL, "Framebuffer pixel_format is different"); + LOG_DEBUG(Render_OpenGL, "Framebuffer pixel_format is different"); } screen_info.display_texture = surface->GetTexture(); @@ -843,20 +840,23 @@ bool RasterizerOpenGL::AccelerateDisplay(const Tegra::FramebufferConfig& config, return true; } -void RasterizerOpenGL::SetupDrawConstBuffers(Tegra::Engines::Maxwell3D::Regs::ShaderStage stage, - const Shader& shader) { +void RasterizerOpenGL::SetupDrawConstBuffers(std::size_t stage_index, const Shader& shader) { MICROPROFILE_SCOPE(OpenGL_UBO); const auto& stages = system.GPU().Maxwell3D().state.shader_stages; - const auto& shader_stage = stages[static_cast<std::size_t>(stage)]; + const auto& shader_stage = stages[stage_index]; + + u32 binding = device.GetBaseBindings(stage_index).uniform_buffer; for (const auto& entry : shader->GetShaderEntries().const_buffers) { const auto& buffer = shader_stage.const_buffers[entry.GetIndex()]; - SetupConstBuffer(buffer, entry); + SetupConstBuffer(binding++, buffer, entry); } } void RasterizerOpenGL::SetupComputeConstBuffers(const Shader& kernel) { MICROPROFILE_SCOPE(OpenGL_UBO); const auto& launch_desc = system.GPU().KeplerCompute().launch_description; + + u32 binding = 0; for (const auto& entry : kernel->GetShaderEntries().const_buffers) { const auto& config = launch_desc.const_buffer_config[entry.GetIndex()]; const std::bitset<8> mask = launch_desc.const_buffer_enable_mask.Value(); @@ -864,15 +864,16 @@ void RasterizerOpenGL::SetupComputeConstBuffers(const Shader& kernel) { buffer.address = config.Address(); buffer.size = config.size; buffer.enabled = mask[entry.GetIndex()]; - SetupConstBuffer(buffer, entry); + SetupConstBuffer(binding++, buffer, entry); } } -void RasterizerOpenGL::SetupConstBuffer(const Tegra::Engines::ConstBufferInfo& buffer, +void RasterizerOpenGL::SetupConstBuffer(u32 binding, const Tegra::Engines::ConstBufferInfo& buffer, const GLShader::ConstBufferEntry& entry) { if (!buffer.enabled) { // Set values to zero to unbind buffers - bind_ubo_pushbuffer.Push(buffer_cache.GetEmptyBuffer(sizeof(float)), 0, sizeof(float)); + bind_ubo_pushbuffer.Push(binding, buffer_cache.GetEmptyBuffer(sizeof(float)), 0, + sizeof(float)); return; } @@ -883,19 +884,20 @@ void RasterizerOpenGL::SetupConstBuffer(const Tegra::Engines::ConstBufferInfo& b const auto alignment = device.GetUniformBufferAlignment(); const auto [cbuf, offset] = buffer_cache.UploadMemory(buffer.address, size, alignment, false, device.HasFastBufferSubData()); - bind_ubo_pushbuffer.Push(cbuf, offset, size); + bind_ubo_pushbuffer.Push(binding, cbuf, offset, size); } -void RasterizerOpenGL::SetupDrawGlobalMemory(Tegra::Engines::Maxwell3D::Regs::ShaderStage stage, - const Shader& shader) { +void RasterizerOpenGL::SetupDrawGlobalMemory(std::size_t stage_index, const Shader& shader) { auto& gpu{system.GPU()}; auto& memory_manager{gpu.MemoryManager()}; - const auto cbufs{gpu.Maxwell3D().state.shader_stages[static_cast<std::size_t>(stage)]}; + const auto cbufs{gpu.Maxwell3D().state.shader_stages[stage_index]}; + + u32 binding = device.GetBaseBindings(stage_index).shader_storage_buffer; for (const auto& entry : shader->GetShaderEntries().global_memory_entries) { const auto addr{cbufs.const_buffers[entry.GetCbufIndex()].address + entry.GetCbufOffset()}; const auto gpu_addr{memory_manager.Read<u64>(addr)}; const auto size{memory_manager.Read<u32>(addr + 8)}; - SetupGlobalMemory(entry, gpu_addr, size); + SetupGlobalMemory(binding++, entry, gpu_addr, size); } } @@ -903,120 +905,82 @@ void RasterizerOpenGL::SetupComputeGlobalMemory(const Shader& kernel) { auto& gpu{system.GPU()}; auto& memory_manager{gpu.MemoryManager()}; const auto cbufs{gpu.KeplerCompute().launch_description.const_buffer_config}; + + u32 binding = 0; for (const auto& entry : kernel->GetShaderEntries().global_memory_entries) { const auto addr{cbufs[entry.GetCbufIndex()].Address() + entry.GetCbufOffset()}; const auto gpu_addr{memory_manager.Read<u64>(addr)}; const auto size{memory_manager.Read<u32>(addr + 8)}; - SetupGlobalMemory(entry, gpu_addr, size); + SetupGlobalMemory(binding++, entry, gpu_addr, size); } } -void RasterizerOpenGL::SetupGlobalMemory(const GLShader::GlobalMemoryEntry& entry, +void RasterizerOpenGL::SetupGlobalMemory(u32 binding, const GLShader::GlobalMemoryEntry& entry, GPUVAddr gpu_addr, std::size_t size) { const auto alignment{device.GetShaderStorageBufferAlignment()}; const auto [ssbo, buffer_offset] = buffer_cache.UploadMemory(gpu_addr, size, alignment, entry.IsWritten()); - bind_ssbo_pushbuffer.Push(ssbo, buffer_offset, static_cast<GLsizeiptr>(size)); + bind_ssbo_pushbuffer.Push(binding, ssbo, buffer_offset, static_cast<GLsizeiptr>(size)); } -TextureBufferUsage RasterizerOpenGL::SetupDrawTextures(Maxwell::ShaderStage stage, - const Shader& shader, - BaseBindings base_bindings) { +void RasterizerOpenGL::SetupDrawTextures(std::size_t stage_index, const Shader& shader) { MICROPROFILE_SCOPE(OpenGL_Texture); - const auto& gpu = system.GPU(); - const auto& maxwell3d = gpu.Maxwell3D(); - const auto& entries = shader->GetShaderEntries().samplers; - - ASSERT_MSG(base_bindings.sampler + entries.size() <= std::size(state.textures), - "Exceeded the number of active textures."); - - TextureBufferUsage texture_buffer_usage{0}; - - for (u32 bindpoint = 0; bindpoint < entries.size(); ++bindpoint) { - const auto& entry = entries[bindpoint]; - const auto texture = [&] { - if (!entry.IsBindless()) { - return maxwell3d.GetStageTexture(stage, entry.GetOffset()); - } - const auto shader_type = static_cast<Tegra::Engines::ShaderType>(stage); - const Tegra::Texture::TextureHandle tex_handle = - maxwell3d.AccessConstBuffer32(shader_type, entry.GetBuffer(), entry.GetOffset()); - return maxwell3d.GetTextureInfo(tex_handle); - }(); - - if (SetupTexture(base_bindings.sampler + bindpoint, texture, entry)) { - texture_buffer_usage.set(bindpoint); - } + const auto& maxwell3d = system.GPU().Maxwell3D(); + u32 binding = device.GetBaseBindings(stage_index).sampler; + for (const auto& entry : shader->GetShaderEntries().samplers) { + const auto shader_type = static_cast<Tegra::Engines::ShaderType>(stage_index); + const auto texture = GetTextureInfo(maxwell3d, entry, shader_type); + SetupTexture(binding++, texture, entry); } - - return texture_buffer_usage; } -TextureBufferUsage RasterizerOpenGL::SetupComputeTextures(const Shader& kernel) { +void RasterizerOpenGL::SetupComputeTextures(const Shader& kernel) { MICROPROFILE_SCOPE(OpenGL_Texture); const auto& compute = system.GPU().KeplerCompute(); - const auto& entries = kernel->GetShaderEntries().samplers; - - ASSERT_MSG(entries.size() <= std::size(state.textures), - "Exceeded the number of active textures."); - - TextureBufferUsage texture_buffer_usage{0}; - - for (u32 bindpoint = 0; bindpoint < entries.size(); ++bindpoint) { - const auto& entry = entries[bindpoint]; - const auto texture = [&] { - if (!entry.IsBindless()) { - return compute.GetTexture(entry.GetOffset()); - } - const Tegra::Texture::TextureHandle tex_handle = compute.AccessConstBuffer32( - Tegra::Engines::ShaderType::Compute, entry.GetBuffer(), entry.GetOffset()); - return compute.GetTextureInfo(tex_handle); - }(); - - if (SetupTexture(bindpoint, texture, entry)) { - texture_buffer_usage.set(bindpoint); - } + u32 binding = 0; + for (const auto& entry : kernel->GetShaderEntries().samplers) { + const auto texture = GetTextureInfo(compute, entry, Tegra::Engines::ShaderType::Compute); + SetupTexture(binding++, texture, entry); } - - return texture_buffer_usage; } -bool RasterizerOpenGL::SetupTexture(u32 binding, const Tegra::Texture::FullTextureInfo& texture, +void RasterizerOpenGL::SetupTexture(u32 binding, const Tegra::Texture::FullTextureInfo& texture, const GLShader::SamplerEntry& entry) { - state.samplers[binding] = sampler_cache.GetSampler(texture.tsc); - const auto view = texture_cache.GetTextureSurface(texture.tic, entry); if (!view) { // Can occur when texture addr is null or its memory is unmapped/invalid + state.samplers[binding] = 0; state.textures[binding] = 0; - return false; + return; } state.textures[binding] = view->GetTexture(); if (view->GetSurfaceParams().IsBuffer()) { - return true; + return; } + state.samplers[binding] = sampler_cache.GetSampler(texture.tsc); // Apply swizzle to textures that are not buffers. view->ApplySwizzle(texture.tic.x_source, texture.tic.y_source, texture.tic.z_source, texture.tic.w_source); - return false; +} + +void RasterizerOpenGL::SetupDrawImages(std::size_t stage_index, const Shader& shader) { + const auto& maxwell3d = system.GPU().Maxwell3D(); + u32 binding = device.GetBaseBindings(stage_index).image; + for (const auto& entry : shader->GetShaderEntries().images) { + const auto shader_type = static_cast<Tegra::Engines::ShaderType>(stage_index); + const auto tic = GetTextureInfo(maxwell3d, entry, shader_type).tic; + SetupImage(binding++, tic, entry); + } } void RasterizerOpenGL::SetupComputeImages(const Shader& shader) { const auto& compute = system.GPU().KeplerCompute(); - const auto& entries = shader->GetShaderEntries().images; - for (u32 bindpoint = 0; bindpoint < entries.size(); ++bindpoint) { - const auto& entry = entries[bindpoint]; - const auto tic = [&] { - if (!entry.IsBindless()) { - return compute.GetTexture(entry.GetOffset()).tic; - } - const Tegra::Texture::TextureHandle tex_handle = compute.AccessConstBuffer32( - Tegra::Engines::ShaderType::Compute, entry.GetBuffer(), entry.GetOffset()); - return compute.GetTextureInfo(tex_handle).tic; - }(); - SetupImage(bindpoint, tic, entry); + u32 binding = 0; + for (const auto& entry : shader->GetShaderEntries().images) { + const auto tic = GetTextureInfo(compute, entry, Tegra::Engines::ShaderType::Compute).tic; + SetupImage(binding++, tic, entry); } } @@ -1055,6 +1019,15 @@ void RasterizerOpenGL::SyncViewport(OpenGLState& current_state) { } state.depth_clamp.far_plane = regs.view_volume_clip_control.depth_clamp_far != 0; state.depth_clamp.near_plane = regs.view_volume_clip_control.depth_clamp_near != 0; + + bool flip_y = false; + if (regs.viewport_transform[0].scale_y < 0.0) { + flip_y = !flip_y; + } + if (regs.screen_y_control.y_negate != 0) { + flip_y = !flip_y; + } + state.clip_control.origin = flip_y ? GL_UPPER_LEFT : GL_LOWER_LEFT; } void RasterizerOpenGL::SyncClipEnabled( @@ -1077,28 +1050,14 @@ void RasterizerOpenGL::SyncClipCoef() { } void RasterizerOpenGL::SyncCullMode() { - auto& maxwell3d = system.GPU().Maxwell3D(); - - const auto& regs = maxwell3d.regs; + const auto& regs = system.GPU().Maxwell3D().regs; state.cull.enabled = regs.cull.enabled != 0; if (state.cull.enabled) { - state.cull.front_face = MaxwellToGL::FrontFace(regs.cull.front_face); state.cull.mode = MaxwellToGL::CullFace(regs.cull.cull_face); - - const bool flip_triangles{regs.screen_y_control.triangle_rast_flip == 0 || - regs.viewport_transform[0].scale_y < 0.0f}; - - // If the GPU is configured to flip the rasterized triangles, then we need to flip the - // notion of front and back. Note: We flip the triangles when the value of the register is 0 - // because OpenGL already does it for us. - if (flip_triangles) { - if (state.cull.front_face == GL_CCW) - state.cull.front_face = GL_CW; - else if (state.cull.front_face == GL_CW) - state.cull.front_face = GL_CCW; - } } + + state.cull.front_face = MaxwellToGL::FrontFace(regs.cull.front_face); } void RasterizerOpenGL::SyncPrimitiveRestart() { diff --git a/src/video_core/renderer_opengl/gl_rasterizer.h b/src/video_core/renderer_opengl/gl_rasterizer.h index bd6fe5c3a..04c1ca551 100644 --- a/src/video_core/renderer_opengl/gl_rasterizer.h +++ b/src/video_core/renderer_opengl/gl_rasterizer.h @@ -83,42 +83,41 @@ private: bool using_depth_fb, bool using_stencil_fb); /// Configures the current constbuffers to use for the draw command. - void SetupDrawConstBuffers(Tegra::Engines::Maxwell3D::Regs::ShaderStage stage, - const Shader& shader); + void SetupDrawConstBuffers(std::size_t stage_index, const Shader& shader); /// Configures the current constbuffers to use for the kernel invocation. void SetupComputeConstBuffers(const Shader& kernel); /// Configures a constant buffer. - void SetupConstBuffer(const Tegra::Engines::ConstBufferInfo& buffer, + void SetupConstBuffer(u32 binding, const Tegra::Engines::ConstBufferInfo& buffer, const GLShader::ConstBufferEntry& entry); /// Configures the current global memory entries to use for the draw command. - void SetupDrawGlobalMemory(Tegra::Engines::Maxwell3D::Regs::ShaderStage stage, - const Shader& shader); + void SetupDrawGlobalMemory(std::size_t stage_index, const Shader& shader); /// Configures the current global memory entries to use for the kernel invocation. void SetupComputeGlobalMemory(const Shader& kernel); /// Configures a constant buffer. - void SetupGlobalMemory(const GLShader::GlobalMemoryEntry& entry, GPUVAddr gpu_addr, + void SetupGlobalMemory(u32 binding, const GLShader::GlobalMemoryEntry& entry, GPUVAddr gpu_addr, std::size_t size); /// Syncs all the state, shaders, render targets and textures setting before a draw call. void DrawPrelude(); - /// Configures the current textures to use for the draw command. Returns shaders texture buffer - /// usage. - TextureBufferUsage SetupDrawTextures(Tegra::Engines::Maxwell3D::Regs::ShaderStage stage, - const Shader& shader, BaseBindings base_bindings); + /// Configures the current textures to use for the draw command. + void SetupDrawTextures(std::size_t stage_index, const Shader& shader); - /// Configures the textures used in a compute shader. Returns texture buffer usage. - TextureBufferUsage SetupComputeTextures(const Shader& kernel); + /// Configures the textures used in a compute shader. + void SetupComputeTextures(const Shader& kernel); - /// Configures a texture. Returns true when the texture is a texture buffer. - bool SetupTexture(u32 binding, const Tegra::Texture::FullTextureInfo& texture, + /// Configures a texture. + void SetupTexture(u32 binding, const Tegra::Texture::FullTextureInfo& texture, const GLShader::SamplerEntry& entry); + /// Configures images in a graphics shader. + void SetupDrawImages(std::size_t stage_index, const Shader& shader); + /// Configures images in a compute shader. void SetupComputeImages(const Shader& shader); @@ -224,8 +223,6 @@ private: enum class AccelDraw { Disabled, Arrays, Indexed }; AccelDraw accelerate_draw = AccelDraw::Disabled; - - OGLFramebuffer clear_framebuffer; }; } // namespace OpenGL diff --git a/src/video_core/renderer_opengl/gl_shader_cache.cpp b/src/video_core/renderer_opengl/gl_shader_cache.cpp index 04a239a39..370bdf052 100644 --- a/src/video_core/renderer_opengl/gl_shader_cache.cpp +++ b/src/video_core/renderer_opengl/gl_shader_cache.cpp @@ -8,12 +8,15 @@ #include <thread> #include <unordered_set> #include <boost/functional/hash.hpp> +#include "common/alignment.h" #include "common/assert.h" +#include "common/logging/log.h" #include "common/scope_exit.h" #include "core/core.h" #include "core/frontend/emu_window.h" #include "video_core/engines/kepler_compute.h" #include "video_core/engines/maxwell_3d.h" +#include "video_core/engines/shader_type.h" #include "video_core/memory_manager.h" #include "video_core/renderer_opengl/gl_rasterizer.h" #include "video_core/renderer_opengl/gl_shader_cache.h" @@ -82,28 +85,26 @@ std::size_t CalculateProgramSize(const GLShader::ProgramCode& program) { /// Gets the shader program code from memory for the specified address ProgramCode GetShaderCode(Tegra::MemoryManager& memory_manager, const GPUVAddr gpu_addr, const u8* host_ptr) { - ProgramCode program_code(VideoCommon::Shader::MAX_PROGRAM_LENGTH); + ProgramCode code(VideoCommon::Shader::MAX_PROGRAM_LENGTH); ASSERT_OR_EXECUTE(host_ptr != nullptr, { - std::fill(program_code.begin(), program_code.end(), 0); - return program_code; + std::fill(code.begin(), code.end(), 0); + return code; }); - memory_manager.ReadBlockUnsafe(gpu_addr, program_code.data(), - program_code.size() * sizeof(u64)); - program_code.resize(CalculateProgramSize(program_code)); - return program_code; + memory_manager.ReadBlockUnsafe(gpu_addr, code.data(), code.size() * sizeof(u64)); + code.resize(CalculateProgramSize(code)); + return code; } /// Gets the shader type from a Maxwell program type -constexpr GLenum GetShaderType(ProgramType program_type) { - switch (program_type) { - case ProgramType::VertexA: - case ProgramType::VertexB: +constexpr GLenum GetGLShaderType(ShaderType shader_type) { + switch (shader_type) { + case ShaderType::Vertex: return GL_VERTEX_SHADER; - case ProgramType::Geometry: + case ShaderType::Geometry: return GL_GEOMETRY_SHADER; - case ProgramType::Fragment: + case ShaderType::Fragment: return GL_FRAGMENT_SHADER; - case ProgramType::Compute: + case ShaderType::Compute: return GL_COMPUTE_SHADER; default: return GL_NONE; @@ -133,30 +134,11 @@ constexpr std::tuple<const char*, const char*, u32> GetPrimitiveDescription(GLen } } -ProgramType GetProgramType(Maxwell::ShaderProgram program) { - switch (program) { - case Maxwell::ShaderProgram::VertexA: - return ProgramType::VertexA; - case Maxwell::ShaderProgram::VertexB: - return ProgramType::VertexB; - case Maxwell::ShaderProgram::TesselationControl: - return ProgramType::TessellationControl; - case Maxwell::ShaderProgram::TesselationEval: - return ProgramType::TessellationEval; - case Maxwell::ShaderProgram::Geometry: - return ProgramType::Geometry; - case Maxwell::ShaderProgram::Fragment: - return ProgramType::Fragment; - } - UNREACHABLE(); - return {}; -} - /// Hashes one (or two) program streams -u64 GetUniqueIdentifier(ProgramType program_type, const ProgramCode& code, +u64 GetUniqueIdentifier(ShaderType shader_type, bool is_a, const ProgramCode& code, const ProgramCode& code_b) { u64 unique_identifier = boost::hash_value(code); - if (program_type == ProgramType::VertexA) { + if (is_a) { // VertexA programs include two programs boost::hash_combine(unique_identifier, boost::hash_value(code_b)); } @@ -164,79 +146,74 @@ u64 GetUniqueIdentifier(ProgramType program_type, const ProgramCode& code, } /// Creates an unspecialized program from code streams -std::string GenerateGLSL(const Device& device, ProgramType program_type, const ShaderIR& ir, +std::string GenerateGLSL(const Device& device, ShaderType shader_type, const ShaderIR& ir, const std::optional<ShaderIR>& ir_b) { - switch (program_type) { - case ProgramType::VertexA: - case ProgramType::VertexB: + switch (shader_type) { + case ShaderType::Vertex: return GLShader::GenerateVertexShader(device, ir, ir_b ? &*ir_b : nullptr); - case ProgramType::Geometry: + case ShaderType::Geometry: return GLShader::GenerateGeometryShader(device, ir); - case ProgramType::Fragment: + case ShaderType::Fragment: return GLShader::GenerateFragmentShader(device, ir); - case ProgramType::Compute: + case ShaderType::Compute: return GLShader::GenerateComputeShader(device, ir); default: - UNIMPLEMENTED_MSG("Unimplemented program_type={}", static_cast<u32>(program_type)); + UNIMPLEMENTED_MSG("Unimplemented shader_type={}", static_cast<u32>(shader_type)); return {}; } } -constexpr const char* GetProgramTypeName(ProgramType program_type) { - switch (program_type) { - case ProgramType::VertexA: - case ProgramType::VertexB: +constexpr const char* GetShaderTypeName(ShaderType shader_type) { + switch (shader_type) { + case ShaderType::Vertex: return "VS"; - case ProgramType::TessellationControl: - return "TCS"; - case ProgramType::TessellationEval: - return "TES"; - case ProgramType::Geometry: + case ShaderType::TesselationControl: + return "HS"; + case ShaderType::TesselationEval: + return "DS"; + case ShaderType::Geometry: return "GS"; - case ProgramType::Fragment: + case ShaderType::Fragment: return "FS"; - case ProgramType::Compute: + case ShaderType::Compute: return "CS"; } return "UNK"; } -Tegra::Engines::ShaderType GetEnginesShaderType(ProgramType program_type) { +constexpr ShaderType GetShaderType(Maxwell::ShaderProgram program_type) { switch (program_type) { - case ProgramType::VertexA: - case ProgramType::VertexB: - return Tegra::Engines::ShaderType::Vertex; - case ProgramType::TessellationControl: - return Tegra::Engines::ShaderType::TesselationControl; - case ProgramType::TessellationEval: - return Tegra::Engines::ShaderType::TesselationEval; - case ProgramType::Geometry: - return Tegra::Engines::ShaderType::Geometry; - case ProgramType::Fragment: - return Tegra::Engines::ShaderType::Fragment; - case ProgramType::Compute: - return Tegra::Engines::ShaderType::Compute; - } - UNREACHABLE(); + case Maxwell::ShaderProgram::VertexA: + case Maxwell::ShaderProgram::VertexB: + return ShaderType::Vertex; + case Maxwell::ShaderProgram::TesselationControl: + return ShaderType::TesselationControl; + case Maxwell::ShaderProgram::TesselationEval: + return ShaderType::TesselationEval; + case Maxwell::ShaderProgram::Geometry: + return ShaderType::Geometry; + case Maxwell::ShaderProgram::Fragment: + return ShaderType::Fragment; + } return {}; } -std::string GetShaderId(u64 unique_identifier, ProgramType program_type) { - return fmt::format("{}{:016X}", GetProgramTypeName(program_type), unique_identifier); +std::string GetShaderId(u64 unique_identifier, ShaderType shader_type) { + return fmt::format("{}{:016X}", GetShaderTypeName(shader_type), unique_identifier); } -Tegra::Engines::ConstBufferEngineInterface& GetConstBufferEngineInterface( - Core::System& system, ProgramType program_type) { - if (program_type == ProgramType::Compute) { +Tegra::Engines::ConstBufferEngineInterface& GetConstBufferEngineInterface(Core::System& system, + ShaderType shader_type) { + if (shader_type == ShaderType::Compute) { return system.GPU().KeplerCompute(); } else { return system.GPU().Maxwell3D(); } } -std::unique_ptr<ConstBufferLocker> MakeLocker(Core::System& system, ProgramType program_type) { - return std::make_unique<ConstBufferLocker>(GetEnginesShaderType(program_type), - GetConstBufferEngineInterface(system, program_type)); +std::unique_ptr<ConstBufferLocker> MakeLocker(Core::System& system, ShaderType shader_type) { + return std::make_unique<ConstBufferLocker>(shader_type, + GetConstBufferEngineInterface(system, shader_type)); } void FillLocker(ConstBufferLocker& locker, const ShaderDiskCacheUsage& usage) { @@ -253,33 +230,26 @@ void FillLocker(ConstBufferLocker& locker, const ShaderDiskCacheUsage& usage) { } } -CachedProgram BuildShader(const Device& device, u64 unique_identifier, ProgramType program_type, - const ProgramCode& program_code, const ProgramCode& program_code_b, - const ProgramVariant& variant, ConstBufferLocker& locker, +CachedProgram BuildShader(const Device& device, u64 unique_identifier, ShaderType shader_type, + const ProgramCode& code, const ProgramCode& code_b, + ConstBufferLocker& locker, const ProgramVariant& variant, bool hint_retrievable = false) { - LOG_INFO(Render_OpenGL, "called. {}", GetShaderId(unique_identifier, program_type)); + LOG_INFO(Render_OpenGL, "called. {}", GetShaderId(unique_identifier, shader_type)); - const bool is_compute = program_type == ProgramType::Compute; + const bool is_compute = shader_type == ShaderType::Compute; const u32 main_offset = is_compute ? KERNEL_MAIN_OFFSET : STAGE_MAIN_OFFSET; - const ShaderIR ir(program_code, main_offset, COMPILER_SETTINGS, locker); + const ShaderIR ir(code, main_offset, COMPILER_SETTINGS, locker); std::optional<ShaderIR> ir_b; - if (!program_code_b.empty()) { - ir_b.emplace(program_code_b, main_offset, COMPILER_SETTINGS, locker); + if (!code_b.empty()) { + ir_b.emplace(code_b, main_offset, COMPILER_SETTINGS, locker); } const auto entries = GLShader::GetEntries(ir); - auto base_bindings{variant.base_bindings}; - const auto primitive_mode{variant.primitive_mode}; - const auto texture_buffer_usage{variant.texture_buffer_usage}; - std::string source = fmt::format(R"(// {} #version 430 core #extension GL_ARB_separate_shader_objects : enable )", - GetShaderId(unique_identifier, program_type)); - if (is_compute) { - source += "#extension GL_ARB_compute_variable_group_size : require\n"; - } + GetShaderId(unique_identifier, shader_type)); if (device.HasShaderBallot()) { source += "#extension GL_ARB_shader_ballot : require\n"; } @@ -296,54 +266,35 @@ CachedProgram BuildShader(const Device& device, u64 unique_identifier, ProgramTy } source += '\n'; - if (!is_compute) { - source += fmt::format("#define EMULATION_UBO_BINDING {}\n", base_bindings.cbuf++); - } + if (shader_type == ShaderType::Geometry) { + const auto [glsl_topology, debug_name, max_vertices] = + GetPrimitiveDescription(variant.primitive_mode); - for (const auto& cbuf : entries.const_buffers) { - source += - fmt::format("#define CBUF_BINDING_{} {}\n", cbuf.GetIndex(), base_bindings.cbuf++); + source += fmt::format("layout ({}) in;\n\n", glsl_topology); + source += fmt::format("#define MAX_VERTEX_INPUT {}\n", max_vertices); } - for (const auto& gmem : entries.global_memory_entries) { - source += fmt::format("#define GMEM_BINDING_{}_{} {}\n", gmem.GetCbufIndex(), - gmem.GetCbufOffset(), base_bindings.gmem++); - } - for (const auto& sampler : entries.samplers) { - source += fmt::format("#define SAMPLER_BINDING_{} {}\n", sampler.GetIndex(), - base_bindings.sampler++); - } - for (const auto& image : entries.images) { + if (shader_type == ShaderType::Compute) { source += - fmt::format("#define IMAGE_BINDING_{} {}\n", image.GetIndex(), base_bindings.image++); - } + fmt::format("layout (local_size_x = {}, local_size_y = {}, local_size_z = {}) in;\n", + variant.block_x, variant.block_y, variant.block_z); - // Transform 1D textures to texture samplers by declaring its preprocessor macros. - for (std::size_t i = 0; i < texture_buffer_usage.size(); ++i) { - if (!texture_buffer_usage.test(i)) { - continue; + if (variant.shared_memory_size > 0) { + // TODO(Rodrigo): We should divide by four here, but having a larger shared memory pool + // avoids out of bound stores. Find out why shared memory size is being invalid. + source += fmt::format("shared uint smem[{}];", variant.shared_memory_size); } - source += fmt::format("#define SAMPLER_{}_IS_BUFFER\n", i); - } - if (texture_buffer_usage.any()) { - source += '\n'; - } - if (program_type == ProgramType::Geometry) { - const auto [glsl_topology, debug_name, max_vertices] = - GetPrimitiveDescription(primitive_mode); - - source += "layout (" + std::string(glsl_topology) + ") in;\n\n"; - source += "#define MAX_VERTEX_INPUT " + std::to_string(max_vertices) + '\n'; - } - if (program_type == ProgramType::Compute) { - source += "layout (local_size_variable) in;\n"; + if (variant.local_memory_size > 0) { + source += fmt::format("#define LOCAL_MEMORY_SIZE {}", + Common::AlignUp(variant.local_memory_size, 4) / 4); + } } source += '\n'; - source += GenerateGLSL(device, program_type, ir, ir_b); + source += GenerateGLSL(device, shader_type, ir, ir_b); OGLShader shader; - shader.Create(source.c_str(), GetShaderType(program_type)); + shader.Create(source.c_str(), GetGLShaderType(shader_type)); auto program = std::make_shared<OGLProgram>(); program->Create(true, hint_retrievable, shader.handle); @@ -366,18 +317,16 @@ std::unordered_set<GLenum> GetSupportedFormats() { } // Anonymous namespace -CachedShader::CachedShader(const ShaderParameters& params, ProgramType program_type, - GLShader::ShaderEntries entries, ProgramCode program_code, - ProgramCode program_code_b) - : RasterizerCacheObject{params.host_ptr}, system{params.system}, - disk_cache{params.disk_cache}, device{params.device}, cpu_addr{params.cpu_addr}, - unique_identifier{params.unique_identifier}, program_type{program_type}, entries{entries}, - program_code{std::move(program_code)}, program_code_b{std::move(program_code_b)} { +CachedShader::CachedShader(const ShaderParameters& params, ShaderType shader_type, + GLShader::ShaderEntries entries, ProgramCode code, ProgramCode code_b) + : RasterizerCacheObject{params.host_ptr}, system{params.system}, disk_cache{params.disk_cache}, + device{params.device}, cpu_addr{params.cpu_addr}, unique_identifier{params.unique_identifier}, + shader_type{shader_type}, entries{entries}, code{std::move(code)}, code_b{std::move(code_b)} { if (!params.precompiled_variants) { return; } for (const auto& pair : *params.precompiled_variants) { - auto locker = MakeLocker(system, program_type); + auto locker = MakeLocker(system, shader_type); const auto& usage = pair->first; FillLocker(*locker, usage); @@ -398,94 +347,83 @@ CachedShader::CachedShader(const ShaderParameters& params, ProgramType program_t } Shader CachedShader::CreateStageFromMemory(const ShaderParameters& params, - Maxwell::ShaderProgram program_type, - ProgramCode program_code, ProgramCode program_code_b) { - params.disk_cache.SaveRaw(ShaderDiskCacheRaw( - params.unique_identifier, GetProgramType(program_type), program_code, program_code_b)); - - ConstBufferLocker locker(GetEnginesShaderType(GetProgramType(program_type)), - params.system.GPU().Maxwell3D()); - const ShaderIR ir(program_code, STAGE_MAIN_OFFSET, COMPILER_SETTINGS, locker); + Maxwell::ShaderProgram program_type, ProgramCode code, + ProgramCode code_b) { + const auto shader_type = GetShaderType(program_type); + params.disk_cache.SaveRaw( + ShaderDiskCacheRaw(params.unique_identifier, shader_type, code, code_b)); + + ConstBufferLocker locker(shader_type, params.system.GPU().Maxwell3D()); + const ShaderIR ir(code, STAGE_MAIN_OFFSET, COMPILER_SETTINGS, locker); // TODO(Rodrigo): Handle VertexA shaders // std::optional<ShaderIR> ir_b; - // if (!program_code_b.empty()) { - // ir_b.emplace(program_code_b, STAGE_MAIN_OFFSET); + // if (!code_b.empty()) { + // ir_b.emplace(code_b, STAGE_MAIN_OFFSET); // } - return std::shared_ptr<CachedShader>( - new CachedShader(params, GetProgramType(program_type), GLShader::GetEntries(ir), - std::move(program_code), std::move(program_code_b))); + return std::shared_ptr<CachedShader>(new CachedShader( + params, shader_type, GLShader::GetEntries(ir), std::move(code), std::move(code_b))); } Shader CachedShader::CreateKernelFromMemory(const ShaderParameters& params, ProgramCode code) { params.disk_cache.SaveRaw( - ShaderDiskCacheRaw(params.unique_identifier, ProgramType::Compute, code)); + ShaderDiskCacheRaw(params.unique_identifier, ShaderType::Compute, code)); ConstBufferLocker locker(Tegra::Engines::ShaderType::Compute, params.system.GPU().KeplerCompute()); const ShaderIR ir(code, KERNEL_MAIN_OFFSET, COMPILER_SETTINGS, locker); return std::shared_ptr<CachedShader>(new CachedShader( - params, ProgramType::Compute, GLShader::GetEntries(ir), std::move(code), {})); + params, ShaderType::Compute, GLShader::GetEntries(ir), std::move(code), {})); } Shader CachedShader::CreateFromCache(const ShaderParameters& params, const UnspecializedShader& unspecialized) { - return std::shared_ptr<CachedShader>(new CachedShader(params, unspecialized.program_type, + return std::shared_ptr<CachedShader>(new CachedShader(params, unspecialized.type, unspecialized.entries, unspecialized.code, unspecialized.code_b)); } -std::tuple<GLuint, BaseBindings> CachedShader::GetProgramHandle(const ProgramVariant& variant) { - UpdateVariant(); +GLuint CachedShader::GetHandle(const ProgramVariant& variant) { + EnsureValidLockerVariant(); - const auto [entry, is_cache_miss] = curr_variant->programs.try_emplace(variant); + const auto [entry, is_cache_miss] = curr_locker_variant->programs.try_emplace(variant); auto& program = entry->second; - if (is_cache_miss) { - program = BuildShader(device, unique_identifier, program_type, program_code, program_code_b, - variant, *curr_variant->locker); - disk_cache.SaveUsage(GetUsage(variant, *curr_variant->locker)); - - LabelGLObject(GL_PROGRAM, program->handle, cpu_addr); + if (!is_cache_miss) { + return program->handle; } - auto base_bindings = variant.base_bindings; - base_bindings.cbuf += static_cast<u32>(entries.const_buffers.size()); - if (program_type != ProgramType::Compute) { - base_bindings.cbuf += STAGE_RESERVED_UBOS; - } - base_bindings.gmem += static_cast<u32>(entries.global_memory_entries.size()); - base_bindings.sampler += static_cast<u32>(entries.samplers.size()); + program = BuildShader(device, unique_identifier, shader_type, code, code_b, + *curr_locker_variant->locker, variant); + disk_cache.SaveUsage(GetUsage(variant, *curr_locker_variant->locker)); - return {program->handle, base_bindings}; + LabelGLObject(GL_PROGRAM, program->handle, cpu_addr); + return program->handle; } -void CachedShader::UpdateVariant() { - if (curr_variant && !curr_variant->locker->IsConsistent()) { - curr_variant = nullptr; +bool CachedShader::EnsureValidLockerVariant() { + const auto previous_variant = curr_locker_variant; + if (curr_locker_variant && !curr_locker_variant->locker->IsConsistent()) { + curr_locker_variant = nullptr; } - if (!curr_variant) { + if (!curr_locker_variant) { for (auto& variant : locker_variants) { if (variant->locker->IsConsistent()) { - curr_variant = variant.get(); + curr_locker_variant = variant.get(); } } } - if (!curr_variant) { + if (!curr_locker_variant) { auto& new_variant = locker_variants.emplace_back(); new_variant = std::make_unique<LockerVariant>(); - new_variant->locker = MakeLocker(system, program_type); - curr_variant = new_variant.get(); + new_variant->locker = MakeLocker(system, shader_type); + curr_locker_variant = new_variant.get(); } + return previous_variant == curr_locker_variant; } ShaderDiskCacheUsage CachedShader::GetUsage(const ProgramVariant& variant, const ConstBufferLocker& locker) const { - ShaderDiskCacheUsage usage; - usage.unique_identifier = unique_identifier; - usage.variant = variant; - usage.keys = locker.GetKeys(); - usage.bound_samplers = locker.GetBoundSamplers(); - usage.bindless_samplers = locker.GetBindlessSamplers(); - return usage; + return ShaderDiskCacheUsage{unique_identifier, variant, locker.GetKeys(), + locker.GetBoundSamplers(), locker.GetBindlessSamplers()}; } ShaderCacheOpenGL::ShaderCacheOpenGL(RasterizerOpenGL& rasterizer, Core::System& system, @@ -544,11 +482,12 @@ void ShaderCacheOpenGL::LoadDiskCache(const std::atomic_bool& stop_loading, } } if (!shader) { - auto locker{MakeLocker(system, unspecialized.program_type)}; + auto locker{MakeLocker(system, unspecialized.type)}; FillLocker(*locker, usage); - shader = BuildShader(device, usage.unique_identifier, unspecialized.program_type, - unspecialized.code, unspecialized.code_b, usage.variant, - *locker, true); + + shader = BuildShader(device, usage.unique_identifier, unspecialized.type, + unspecialized.code, unspecialized.code_b, *locker, + usage.variant, true); } std::scoped_lock lock{mutex}; @@ -651,7 +590,7 @@ bool ShaderCacheOpenGL::GenerateUnspecializedShaders( const auto& raw{raws[i]}; const u64 unique_identifier{raw.GetUniqueIdentifier()}; const u64 calculated_hash{ - GetUniqueIdentifier(raw.GetProgramType(), raw.GetProgramCode(), raw.GetProgramCodeB())}; + GetUniqueIdentifier(raw.GetType(), raw.HasProgramA(), raw.GetCode(), raw.GetCodeB())}; if (unique_identifier != calculated_hash) { LOG_ERROR(Render_OpenGL, "Invalid hash in entry={:016x} (obtained hash={:016x}) - " @@ -662,9 +601,9 @@ bool ShaderCacheOpenGL::GenerateUnspecializedShaders( } const u32 main_offset = - raw.GetProgramType() == ProgramType::Compute ? KERNEL_MAIN_OFFSET : STAGE_MAIN_OFFSET; - ConstBufferLocker locker(GetEnginesShaderType(raw.GetProgramType())); - const ShaderIR ir(raw.GetProgramCode(), main_offset, COMPILER_SETTINGS, locker); + raw.GetType() == ShaderType::Compute ? KERNEL_MAIN_OFFSET : STAGE_MAIN_OFFSET; + ConstBufferLocker locker(raw.GetType()); + const ShaderIR ir(raw.GetCode(), main_offset, COMPILER_SETTINGS, locker); // TODO(Rodrigo): Handle VertexA shaders // std::optional<ShaderIR> ir_b; // if (raw.HasProgramA()) { @@ -673,9 +612,9 @@ bool ShaderCacheOpenGL::GenerateUnspecializedShaders( UnspecializedShader unspecialized; unspecialized.entries = GLShader::GetEntries(ir); - unspecialized.program_type = raw.GetProgramType(); - unspecialized.code = raw.GetProgramCode(); - unspecialized.code_b = raw.GetProgramCodeB(); + unspecialized.type = raw.GetType(); + unspecialized.code = raw.GetCode(); + unspecialized.code_b = raw.GetCodeB(); unspecialized_shaders.emplace(raw.GetUniqueIdentifier(), unspecialized); if (callback) { @@ -708,7 +647,8 @@ Shader ShaderCacheOpenGL::GetStageProgram(Maxwell::ShaderProgram program) { code_b = GetShaderCode(memory_manager, address_b, memory_manager.GetPointer(address_b)); } - const auto unique_identifier = GetUniqueIdentifier(GetProgramType(program), code, code_b); + const auto unique_identifier = GetUniqueIdentifier( + GetShaderType(program), program == Maxwell::ShaderProgram::VertexA, code, code_b); const auto precompiled_variants = GetPrecompiledVariants(unique_identifier); const auto cpu_addr{*memory_manager.GpuToCpuAddress(address)}; const ShaderParameters params{system, disk_cache, precompiled_variants, device, @@ -736,7 +676,7 @@ Shader ShaderCacheOpenGL::GetComputeKernel(GPUVAddr code_addr) { // No kernel found - create a new one auto code{GetShaderCode(memory_manager, code_addr, host_ptr)}; - const auto unique_identifier{GetUniqueIdentifier(ProgramType::Compute, code, {})}; + const auto unique_identifier{GetUniqueIdentifier(ShaderType::Compute, false, code, {})}; const auto precompiled_variants = GetPrecompiledVariants(unique_identifier); const auto cpu_addr{*memory_manager.GpuToCpuAddress(code_addr)}; const ShaderParameters params{system, disk_cache, precompiled_variants, device, diff --git a/src/video_core/renderer_opengl/gl_shader_cache.h b/src/video_core/renderer_opengl/gl_shader_cache.h index 6bd7c9cf1..7b1470db3 100644 --- a/src/video_core/renderer_opengl/gl_shader_cache.h +++ b/src/video_core/renderer_opengl/gl_shader_cache.h @@ -17,6 +17,7 @@ #include <glad/glad.h> #include "common/common_types.h" +#include "video_core/engines/shader_type.h" #include "video_core/rasterizer_cache.h" #include "video_core/renderer_opengl/gl_resource_manager.h" #include "video_core/renderer_opengl/gl_shader_decompiler.h" @@ -47,7 +48,7 @@ using PrecompiledVariants = std::vector<PrecompiledPrograms::iterator>; struct UnspecializedShader { GLShader::ShaderEntries entries; - ProgramType program_type; + Tegra::Engines::ShaderType type; ProgramCode code; ProgramCode code_b; }; @@ -77,7 +78,7 @@ public: } std::size_t GetSizeInBytes() const override { - return program_code.size() * sizeof(u64); + return code.size() * sizeof(u64); } /// Gets the shader entries for the shader @@ -86,7 +87,7 @@ public: } /// Gets the GL program handle for the shader - std::tuple<GLuint, BaseBindings> GetProgramHandle(const ProgramVariant& variant); + GLuint GetHandle(const ProgramVariant& variant); private: struct LockerVariant { @@ -94,11 +95,11 @@ private: std::unordered_map<ProgramVariant, CachedProgram> programs; }; - explicit CachedShader(const ShaderParameters& params, ProgramType program_type, + explicit CachedShader(const ShaderParameters& params, Tegra::Engines::ShaderType shader_type, GLShader::ShaderEntries entries, ProgramCode program_code, ProgramCode program_code_b); - void UpdateVariant(); + bool EnsureValidLockerVariant(); ShaderDiskCacheUsage GetUsage(const ProgramVariant& variant, const VideoCommon::Shader::ConstBufferLocker& locker) const; @@ -110,14 +111,14 @@ private: VAddr cpu_addr{}; u64 unique_identifier{}; - ProgramType program_type{}; + Tegra::Engines::ShaderType shader_type{}; GLShader::ShaderEntries entries; - ProgramCode program_code; - ProgramCode program_code_b; + ProgramCode code; + ProgramCode code_b; - LockerVariant* curr_variant = nullptr; + LockerVariant* curr_locker_variant = nullptr; std::vector<std::unique_ptr<LockerVariant>> locker_variants; }; diff --git a/src/video_core/renderer_opengl/gl_shader_decompiler.cpp b/src/video_core/renderer_opengl/gl_shader_decompiler.cpp index f7b77711a..3d3cd21f3 100644 --- a/src/video_core/renderer_opengl/gl_shader_decompiler.cpp +++ b/src/video_core/renderer_opengl/gl_shader_decompiler.cpp @@ -16,6 +16,7 @@ #include "common/common_types.h" #include "common/logging/log.h" #include "video_core/engines/maxwell_3d.h" +#include "video_core/engines/shader_type.h" #include "video_core/renderer_opengl/gl_device.h" #include "video_core/renderer_opengl/gl_rasterizer.h" #include "video_core/renderer_opengl/gl_shader_decompiler.h" @@ -27,6 +28,7 @@ namespace OpenGL::GLShader { namespace { +using Tegra::Engines::ShaderType; using Tegra::Shader::Attribute; using Tegra::Shader::AttributeUse; using Tegra::Shader::Header; @@ -41,6 +43,9 @@ using namespace VideoCommon::Shader; using Maxwell = Tegra::Engines::Maxwell3D::Regs; using Operation = const OperationNode&; +class ASTDecompiler; +class ExprDecompiler; + enum class Type { Void, Bool, Bool2, Float, Int, Uint, HalfFloat }; struct TextureAoffi {}; @@ -224,7 +229,7 @@ private: Type type{}; }; -constexpr const char* GetTypeString(Type type) { +const char* GetTypeString(Type type) { switch (type) { case Type::Bool: return "bool"; @@ -244,7 +249,7 @@ constexpr const char* GetTypeString(Type type) { } } -constexpr const char* GetImageTypeDeclaration(Tegra::Shader::ImageType image_type) { +const char* GetImageTypeDeclaration(Tegra::Shader::ImageType image_type) { switch (image_type) { case Tegra::Shader::ImageType::Texture1D: return "1D"; @@ -332,16 +337,13 @@ std::string FlowStackTopName(MetaStackClass stack) { return fmt::format("{}_flow_stack_top", GetFlowStackPrefix(stack)); } -constexpr bool IsVertexShader(ProgramType stage) { - return stage == ProgramType::VertexA || stage == ProgramType::VertexB; +[[deprecated]] constexpr bool IsVertexShader(ShaderType stage) { + return stage == ShaderType::Vertex; } -class ASTDecompiler; -class ExprDecompiler; - class GLSLDecompiler final { public: - explicit GLSLDecompiler(const Device& device, const ShaderIR& ir, ProgramType stage, + explicit GLSLDecompiler(const Device& device, const ShaderIR& ir, ShaderType stage, std::string suffix) : device{device}, ir{ir}, stage{stage}, suffix{suffix}, header{ir.GetHeader()} {} @@ -428,7 +430,7 @@ private: } void DeclareGeometry() { - if (stage != ProgramType::Geometry) { + if (stage != ShaderType::Geometry) { return; } @@ -511,10 +513,14 @@ private: } void DeclareLocalMemory() { - // TODO(Rodrigo): Unstub kernel local memory size and pass it from a register at - // specialization time. - const u64 local_memory_size = - stage == ProgramType::Compute ? 0x400 : header.GetLocalMemorySize(); + if (stage == ShaderType::Compute) { + code.AddLine("#ifdef LOCAL_MEMORY_SIZE"); + code.AddLine("uint {}[LOCAL_MEMORY_SIZE];", GetLocalMemory()); + code.AddLine("#endif"); + return; + } + + const u64 local_memory_size = header.GetLocalMemorySize(); if (local_memory_size == 0) { return; } @@ -523,13 +529,6 @@ private: code.AddNewLine(); } - void DeclareSharedMemory() { - if (stage != ProgramType::Compute) { - return; - } - code.AddLine("shared uint {}[];", GetSharedMemory()); - } - void DeclareInternalFlags() { for (u32 flag = 0; flag < static_cast<u32>(InternalFlag::Amount); flag++) { const auto flag_code = static_cast<InternalFlag>(flag); @@ -579,12 +578,12 @@ private: const u32 location{GetGenericAttributeIndex(index)}; std::string name{GetInputAttribute(index)}; - if (stage == ProgramType::Geometry) { + if (stage == ShaderType::Geometry) { name = "gs_" + name + "[]"; } std::string suffix; - if (stage == ProgramType::Fragment) { + if (stage == ShaderType::Fragment) { const auto input_mode{header.ps.GetAttributeUse(location)}; if (skip_unused && input_mode == AttributeUse::Unused) { return; @@ -596,7 +595,7 @@ private: } void DeclareOutputAttributes() { - if (ir.HasPhysicalAttributes() && stage != ProgramType::Fragment) { + if (ir.HasPhysicalAttributes() && stage != ShaderType::Fragment) { for (u32 i = 0; i < GetNumPhysicalVaryings(); ++i) { DeclareOutputAttribute(ToGenericAttribute(i)); } @@ -621,9 +620,9 @@ private: } void DeclareConstantBuffers() { - for (const auto& entry : ir.GetConstantBuffers()) { - const auto [index, size] = entry; - code.AddLine("layout (std140, binding = CBUF_BINDING_{}) uniform {} {{", index, + u32 binding = device.GetBaseBindings(stage).uniform_buffer; + for (const auto& [index, cbuf] : ir.GetConstantBuffers()) { + code.AddLine("layout (std140, binding = {}) uniform {} {{", binding++, GetConstBufferBlock(index)); code.AddLine(" uvec4 {}[{}];", GetConstBuffer(index), MAX_CONSTBUFFER_ELEMENTS); code.AddLine("}};"); @@ -632,9 +631,8 @@ private: } void DeclareGlobalMemory() { - for (const auto& gmem : ir.GetGlobalMemory()) { - const auto& [base, usage] = gmem; - + u32 binding = device.GetBaseBindings(stage).shader_storage_buffer; + for (const auto& [base, usage] : ir.GetGlobalMemory()) { // Since we don't know how the shader will use the shader, hint the driver to disable as // much optimizations as possible std::string qualifier = "coherent volatile"; @@ -644,8 +642,8 @@ private: qualifier += " writeonly"; } - code.AddLine("layout (std430, binding = GMEM_BINDING_{}_{}) {} buffer {} {{", - base.cbuf_index, base.cbuf_offset, qualifier, GetGlobalMemoryBlock(base)); + code.AddLine("layout (std430, binding = {}) {} buffer {} {{", binding++, qualifier, + GetGlobalMemoryBlock(base)); code.AddLine(" uint {}[];", GetGlobalMemory(base)); code.AddLine("}};"); code.AddNewLine(); @@ -653,15 +651,17 @@ private: } void DeclareSamplers() { - const auto& samplers = ir.GetSamplers(); - for (const auto& sampler : samplers) { - const std::string name{GetSampler(sampler)}; - const std::string description{"layout (binding = SAMPLER_BINDING_" + - std::to_string(sampler.GetIndex()) + ") uniform"}; + u32 binding = device.GetBaseBindings(stage).sampler; + for (const auto& sampler : ir.GetSamplers()) { + const std::string name = GetSampler(sampler); + const std::string description = fmt::format("layout (binding = {}) uniform", binding++); + std::string sampler_type = [&]() { + if (sampler.IsBuffer()) { + return "samplerBuffer"; + } switch (sampler.GetType()) { case Tegra::Shader::TextureType::Texture1D: - // Special cased, read below. return "sampler1D"; case Tegra::Shader::TextureType::Texture2D: return "sampler2D"; @@ -681,21 +681,9 @@ private: sampler_type += "Shadow"; } - if (sampler.GetType() == Tegra::Shader::TextureType::Texture1D) { - // 1D textures can be aliased to texture buffers, hide the declarations behind a - // preprocessor flag and use one or the other from the GPU state. This has to be - // done because shaders don't have enough information to determine the texture type. - EmitIfdefIsBuffer(sampler); - code.AddLine("{} samplerBuffer {};", description, name); - code.AddLine("#else"); - code.AddLine("{} {} {};", description, sampler_type, name); - code.AddLine("#endif"); - } else { - // The other texture types (2D, 3D and cubes) don't have this issue. - code.AddLine("{} {} {};", description, sampler_type, name); - } + code.AddLine("{} {} {};", description, sampler_type, name); } - if (!samplers.empty()) { + if (!ir.GetSamplers().empty()) { code.AddNewLine(); } } @@ -718,7 +706,7 @@ private: constexpr u32 element_stride = 4; const u32 address{generic_base + index * generic_stride + element * element_stride}; - const bool declared = stage != ProgramType::Fragment || + const bool declared = stage != ShaderType::Fragment || header.ps.GetAttributeUse(index) != AttributeUse::Unused; const std::string value = declared ? ReadAttribute(attribute, element).AsFloat() : "0.0f"; @@ -735,8 +723,8 @@ private: } void DeclareImages() { - const auto& images{ir.GetImages()}; - for (const auto& image : images) { + u32 binding = device.GetBaseBindings(stage).image; + for (const auto& image : ir.GetImages()) { std::string qualifier = "coherent volatile"; if (image.IsRead() && !image.IsWritten()) { qualifier += " readonly"; @@ -746,10 +734,10 @@ private: const char* format = image.IsAtomic() ? "r32ui, " : ""; const char* type_declaration = GetImageTypeDeclaration(image.GetType()); - code.AddLine("layout ({}binding = IMAGE_BINDING_{}) {} uniform uimage{} {};", format, - image.GetIndex(), qualifier, type_declaration, GetImage(image)); + code.AddLine("layout ({}binding = {}) {} uniform uimage{} {};", format, binding++, + qualifier, type_declaration, GetImage(image)); } - if (!images.empty()) { + if (!ir.GetImages().empty()) { code.AddNewLine(); } } @@ -810,7 +798,7 @@ private: } if (const auto abuf = std::get_if<AbufNode>(&*node)) { - UNIMPLEMENTED_IF_MSG(abuf->IsPhysicalBuffer() && stage == ProgramType::Geometry, + UNIMPLEMENTED_IF_MSG(abuf->IsPhysicalBuffer() && stage == ShaderType::Geometry, "Physical attributes in geometry shaders are not implemented"); if (abuf->IsPhysicalBuffer()) { return {fmt::format("ReadPhysicalAttribute({})", @@ -869,18 +857,13 @@ private: } if (const auto lmem = std::get_if<LmemNode>(&*node)) { - if (stage == ProgramType::Compute) { - LOG_WARNING(Render_OpenGL, "Local memory is stubbed on compute shaders"); - } return { fmt::format("{}[{} >> 2]", GetLocalMemory(), Visit(lmem->GetAddress()).AsUint()), Type::Uint}; } if (const auto smem = std::get_if<SmemNode>(&*node)) { - return { - fmt::format("{}[{} >> 2]", GetSharedMemory(), Visit(smem->GetAddress()).AsUint()), - Type::Uint}; + return {fmt::format("smem[{} >> 2]", Visit(smem->GetAddress()).AsUint()), Type::Uint}; } if (const auto internal_flag = std::get_if<InternalFlagNode>(&*node)) { @@ -910,7 +893,7 @@ private: Expression ReadAttribute(Attribute::Index attribute, u32 element, const Node& buffer = {}) { const auto GeometryPass = [&](std::string_view name) { - if (stage == ProgramType::Geometry && buffer) { + if (stage == ShaderType::Geometry && buffer) { // TODO(Rodrigo): Guard geometry inputs against out of bound reads. Some games // set an 0x80000000 index for those and the shader fails to build. Find out why // this happens and what's its intent. @@ -922,11 +905,11 @@ private: switch (attribute) { case Attribute::Index::Position: switch (stage) { - case ProgramType::Geometry: + case ShaderType::Geometry: return {fmt::format("gl_in[{}].gl_Position{}", Visit(buffer).AsUint(), GetSwizzle(element)), Type::Float}; - case ProgramType::Fragment: + case ShaderType::Fragment: return {element == 3 ? "1.0f" : ("gl_FragCoord"s + GetSwizzle(element)), Type::Float}; default: @@ -960,7 +943,7 @@ private: return {"0", Type::Int}; case Attribute::Index::FrontFacing: // TODO(Subv): Find out what the values are for the other elements. - ASSERT(stage == ProgramType::Fragment); + ASSERT(stage == ShaderType::Fragment); switch (element) { case 3: return {"(gl_FrontFacing ? -1 : 0)", Type::Int}; @@ -986,7 +969,7 @@ private: // be found in fragment shaders, so we disable precise there. There are vertex shaders that // also fail to build but nobody seems to care about those. // Note: Only bugged drivers will skip precise. - const bool disable_precise = device.HasPreciseBug() && stage == ProgramType::Fragment; + const bool disable_precise = device.HasPreciseBug() && stage == ShaderType::Fragment; std::string temporary = code.GenerateTemporary(); code.AddLine("{}{} {} = {};", disable_precise ? "" : "precise ", GetTypeString(type), @@ -1280,17 +1263,12 @@ private: } target = std::move(*output); } else if (const auto lmem = std::get_if<LmemNode>(&*dest)) { - if (stage == ProgramType::Compute) { - LOG_WARNING(Render_OpenGL, "Local memory is stubbed on compute shaders"); - } target = { fmt::format("{}[{} >> 2]", GetLocalMemory(), Visit(lmem->GetAddress()).AsUint()), Type::Uint}; } else if (const auto smem = std::get_if<SmemNode>(&*dest)) { - ASSERT(stage == ProgramType::Compute); - target = { - fmt::format("{}[{} >> 2]", GetSharedMemory(), Visit(smem->GetAddress()).AsUint()), - Type::Uint}; + ASSERT(stage == ShaderType::Compute); + target = {fmt::format("smem[{} >> 2]", Visit(smem->GetAddress()).AsUint()), Type::Uint}; } else if (const auto gmem = std::get_if<GmemNode>(&*dest)) { const std::string real = Visit(gmem->GetRealAddress()).AsUint(); const std::string base = Visit(gmem->GetBaseAddress()).AsUint(); @@ -1529,7 +1507,8 @@ private: } Expression HCastFloat(Operation operation) { - return {fmt::format("vec2({})", VisitOperand(operation, 0).AsFloat()), Type::HalfFloat}; + return {fmt::format("vec2({}, 0.0f)", VisitOperand(operation, 0).AsFloat()), + Type::HalfFloat}; } Expression HUnpack(Operation operation) { @@ -1787,27 +1766,14 @@ private: expr += ", "; } - // Store a copy of the expression without the lod to be used with texture buffers - std::string expr_buffer = expr; - - if (meta->lod) { + if (meta->lod && !meta->sampler.IsBuffer()) { expr += ", "; expr += Visit(meta->lod).AsInt(); } expr += ')'; expr += GetSwizzle(meta->element); - expr_buffer += ')'; - expr_buffer += GetSwizzle(meta->element); - - const std::string tmp{code.GenerateTemporary()}; - EmitIfdefIsBuffer(meta->sampler); - code.AddLine("float {} = {};", tmp, expr_buffer); - code.AddLine("#else"); - code.AddLine("float {} = {};", tmp, expr); - code.AddLine("#endif"); - - return {tmp, Type::Float}; + return {std::move(expr), Type::Float}; } Expression TextureGradient(Operation operation) { @@ -1883,7 +1849,7 @@ private: } void PreExit() { - if (stage != ProgramType::Fragment) { + if (stage != ShaderType::Fragment) { return; } const auto& used_registers = ir.GetRegisters(); @@ -1936,27 +1902,21 @@ private: } Expression EmitVertex(Operation operation) { - ASSERT_MSG(stage == ProgramType::Geometry, + ASSERT_MSG(stage == ShaderType::Geometry, "EmitVertex is expected to be used in a geometry shader."); - - // If a geometry shader is attached, it will always flip (it's the last stage before - // fragment). For more info about flipping, refer to gl_shader_gen.cpp. - code.AddLine("gl_Position.xy *= viewport_flip.xy;"); code.AddLine("EmitVertex();"); return {}; } Expression EndPrimitive(Operation operation) { - ASSERT_MSG(stage == ProgramType::Geometry, + ASSERT_MSG(stage == ShaderType::Geometry, "EndPrimitive is expected to be used in a geometry shader."); - code.AddLine("EndPrimitive();"); return {}; } Expression YNegate(Operation operation) { - // Config pack's third value is Y_NEGATE's state. - return {"config_pack[2]", Type::Uint}; + return {"y_direction", Type::Float}; } template <u32 element> @@ -2248,10 +2208,6 @@ private: return "lmem_" + suffix; } - std::string GetSharedMemory() const { - return fmt::format("smem_{}", suffix); - } - std::string GetInternalFlag(InternalFlag flag) const { constexpr std::array InternalFlagNames = {"zero_flag", "sign_flag", "carry_flag", "overflow_flag"}; @@ -2269,10 +2225,6 @@ private: return GetDeclarationWithSuffix(static_cast<u32>(image.GetIndex()), "image"); } - void EmitIfdefIsBuffer(const Sampler& sampler) { - code.AddLine("#ifdef SAMPLER_{}_IS_BUFFER", sampler.GetIndex()); - } - std::string GetDeclarationWithSuffix(u32 index, std::string_view name) const { return fmt::format("{}_{}_{}", name, index, suffix); } @@ -2291,7 +2243,7 @@ private: const Device& device; const ShaderIR& ir; - const ProgramType stage; + const ShaderType stage; const std::string suffix; const Header header; @@ -2546,7 +2498,7 @@ const float fswzadd_modifiers_b[] = float[4](-1.0f, -1.0f, 1.0f, -1.0f ); )"; } -std::string Decompile(const Device& device, const ShaderIR& ir, ProgramType stage, +std::string Decompile(const Device& device, const ShaderIR& ir, ShaderType stage, const std::string& suffix) { GLSLDecompiler decompiler(device, ir, stage, suffix); decompiler.Decompile(); diff --git a/src/video_core/renderer_opengl/gl_shader_decompiler.h b/src/video_core/renderer_opengl/gl_shader_decompiler.h index b1e75e6cc..7876f48d6 100644 --- a/src/video_core/renderer_opengl/gl_shader_decompiler.h +++ b/src/video_core/renderer_opengl/gl_shader_decompiler.h @@ -10,6 +10,7 @@ #include <vector> #include "common/common_types.h" #include "video_core/engines/maxwell_3d.h" +#include "video_core/engines/shader_type.h" #include "video_core/shader/shader_ir.h" namespace VideoCommon::Shader { @@ -17,20 +18,8 @@ class ShaderIR; } namespace OpenGL { - class Device; - -enum class ProgramType : u32 { - VertexA = 0, - VertexB = 1, - TessellationControl = 2, - TessellationEval = 3, - Geometry = 4, - Fragment = 5, - Compute = 6 -}; - -} // namespace OpenGL +} namespace OpenGL::GLShader { @@ -94,6 +83,6 @@ ShaderEntries GetEntries(const VideoCommon::Shader::ShaderIR& ir); std::string GetCommonDeclarations(); std::string Decompile(const Device& device, const VideoCommon::Shader::ShaderIR& ir, - ProgramType stage, const std::string& suffix); + Tegra::Engines::ShaderType stage, const std::string& suffix); } // namespace OpenGL::GLShader diff --git a/src/video_core/renderer_opengl/gl_shader_disk_cache.cpp b/src/video_core/renderer_opengl/gl_shader_disk_cache.cpp index 184a565e6..cf874a09a 100644 --- a/src/video_core/renderer_opengl/gl_shader_disk_cache.cpp +++ b/src/video_core/renderer_opengl/gl_shader_disk_cache.cpp @@ -3,6 +3,7 @@ // Refer to the license.txt file included. #include <cstring> + #include <fmt/format.h> #include "common/assert.h" @@ -12,50 +13,50 @@ #include "common/logging/log.h" #include "common/scm_rev.h" #include "common/zstd_compression.h" - #include "core/core.h" #include "core/hle/kernel/process.h" #include "core/settings.h" - +#include "video_core/engines/shader_type.h" #include "video_core/renderer_opengl/gl_shader_cache.h" #include "video_core/renderer_opengl/gl_shader_disk_cache.h" namespace OpenGL { +using Tegra::Engines::ShaderType; using VideoCommon::Shader::BindlessSamplerMap; using VideoCommon::Shader::BoundSamplerMap; using VideoCommon::Shader::KeyMap; namespace { +using ShaderCacheVersionHash = std::array<u8, 64>; + +enum class TransferableEntryKind : u32 { + Raw, + Usage, +}; + struct ConstBufferKey { - u32 cbuf; - u32 offset; - u32 value; + u32 cbuf{}; + u32 offset{}; + u32 value{}; }; struct BoundSamplerKey { - u32 offset; - Tegra::Engines::SamplerDescriptor sampler; + u32 offset{}; + Tegra::Engines::SamplerDescriptor sampler{}; }; struct BindlessSamplerKey { - u32 cbuf; - u32 offset; - Tegra::Engines::SamplerDescriptor sampler; -}; - -using ShaderCacheVersionHash = std::array<u8, 64>; - -enum class TransferableEntryKind : u32 { - Raw, - Usage, + u32 cbuf{}; + u32 offset{}; + Tegra::Engines::SamplerDescriptor sampler{}; }; -constexpr u32 NativeVersion = 5; +constexpr u32 NativeVersion = 11; // Making sure sizes doesn't change by accident -static_assert(sizeof(BaseBindings) == 16); +static_assert(sizeof(ProgramVariant) == 20); ShaderCacheVersionHash GetShaderCacheVersionHash() { ShaderCacheVersionHash hash{}; @@ -66,10 +67,10 @@ ShaderCacheVersionHash GetShaderCacheVersionHash() { } // Anonymous namespace -ShaderDiskCacheRaw::ShaderDiskCacheRaw(u64 unique_identifier, ProgramType program_type, - ProgramCode program_code, ProgramCode program_code_b) - : unique_identifier{unique_identifier}, program_type{program_type}, - program_code{std::move(program_code)}, program_code_b{std::move(program_code_b)} {} +ShaderDiskCacheRaw::ShaderDiskCacheRaw(u64 unique_identifier, ShaderType type, ProgramCode code, + ProgramCode code_b) + : unique_identifier{unique_identifier}, type{type}, code{std::move(code)}, code_b{std::move( + code_b)} {} ShaderDiskCacheRaw::ShaderDiskCacheRaw() = default; @@ -77,42 +78,39 @@ ShaderDiskCacheRaw::~ShaderDiskCacheRaw() = default; bool ShaderDiskCacheRaw::Load(FileUtil::IOFile& file) { if (file.ReadBytes(&unique_identifier, sizeof(u64)) != sizeof(u64) || - file.ReadBytes(&program_type, sizeof(u32)) != sizeof(u32)) { + file.ReadBytes(&type, sizeof(u32)) != sizeof(u32)) { return false; } - u32 program_code_size{}; - u32 program_code_size_b{}; - if (file.ReadBytes(&program_code_size, sizeof(u32)) != sizeof(u32) || - file.ReadBytes(&program_code_size_b, sizeof(u32)) != sizeof(u32)) { + u32 code_size{}; + u32 code_size_b{}; + if (file.ReadBytes(&code_size, sizeof(u32)) != sizeof(u32) || + file.ReadBytes(&code_size_b, sizeof(u32)) != sizeof(u32)) { return false; } - program_code.resize(program_code_size); - program_code_b.resize(program_code_size_b); + code.resize(code_size); + code_b.resize(code_size_b); - if (file.ReadArray(program_code.data(), program_code_size) != program_code_size) + if (file.ReadArray(code.data(), code_size) != code_size) return false; - if (HasProgramA() && - file.ReadArray(program_code_b.data(), program_code_size_b) != program_code_size_b) { + if (HasProgramA() && file.ReadArray(code_b.data(), code_size_b) != code_size_b) { return false; } return true; } bool ShaderDiskCacheRaw::Save(FileUtil::IOFile& file) const { - if (file.WriteObject(unique_identifier) != 1 || - file.WriteObject(static_cast<u32>(program_type)) != 1 || - file.WriteObject(static_cast<u32>(program_code.size())) != 1 || - file.WriteObject(static_cast<u32>(program_code_b.size())) != 1) { + if (file.WriteObject(unique_identifier) != 1 || file.WriteObject(static_cast<u32>(type)) != 1 || + file.WriteObject(static_cast<u32>(code.size())) != 1 || + file.WriteObject(static_cast<u32>(code_b.size())) != 1) { return false; } - if (file.WriteArray(program_code.data(), program_code.size()) != program_code.size()) + if (file.WriteArray(code.data(), code.size()) != code.size()) return false; - if (HasProgramA() && - file.WriteArray(program_code_b.data(), program_code_b.size()) != program_code_b.size()) { + if (HasProgramA() && file.WriteArray(code_b.data(), code_b.size()) != code_b.size()) { return false; } return true; diff --git a/src/video_core/renderer_opengl/gl_shader_disk_cache.h b/src/video_core/renderer_opengl/gl_shader_disk_cache.h index db23ada93..69a2fbdda 100644 --- a/src/video_core/renderer_opengl/gl_shader_disk_cache.h +++ b/src/video_core/renderer_opengl/gl_shader_disk_cache.h @@ -4,7 +4,6 @@ #pragma once -#include <bitset> #include <optional> #include <string> #include <tuple> @@ -19,6 +18,7 @@ #include "common/assert.h" #include "common/common_types.h" #include "core/file_sys/vfs_vector.h" +#include "video_core/engines/shader_type.h" #include "video_core/renderer_opengl/gl_shader_gen.h" #include "video_core/shader/const_buffer_locker.h" @@ -37,42 +37,42 @@ struct ShaderDiskCacheDump; using ProgramCode = std::vector<u64>; using ShaderDumpsMap = std::unordered_map<ShaderDiskCacheUsage, ShaderDiskCacheDump>; -using TextureBufferUsage = std::bitset<64>; - -/// Allocated bindings used by an OpenGL shader program -struct BaseBindings { - u32 cbuf{}; - u32 gmem{}; - u32 sampler{}; - u32 image{}; - - bool operator==(const BaseBindings& rhs) const { - return std::tie(cbuf, gmem, sampler, image) == - std::tie(rhs.cbuf, rhs.gmem, rhs.sampler, rhs.image); - } - bool operator!=(const BaseBindings& rhs) const { - return !operator==(rhs); - } -}; -static_assert(std::is_trivially_copyable_v<BaseBindings>); +/// Describes the different variants a program can be compiled with. +struct ProgramVariant final { + ProgramVariant() = default; + + /// Graphics constructor. + explicit constexpr ProgramVariant(GLenum primitive_mode) noexcept + : primitive_mode{primitive_mode} {} + + /// Compute constructor. + explicit constexpr ProgramVariant(u32 block_x, u32 block_y, u32 block_z, u32 shared_memory_size, + u32 local_memory_size) noexcept + : block_x{block_x}, block_y{static_cast<u16>(block_y)}, block_z{static_cast<u16>(block_z)}, + shared_memory_size{shared_memory_size}, local_memory_size{local_memory_size} {} -/// Describes the different variants a single program can be compiled. -struct ProgramVariant { - BaseBindings base_bindings; + // Graphics specific parameters. GLenum primitive_mode{}; - TextureBufferUsage texture_buffer_usage{}; - bool operator==(const ProgramVariant& rhs) const { - return std::tie(base_bindings, primitive_mode, texture_buffer_usage) == - std::tie(rhs.base_bindings, rhs.primitive_mode, rhs.texture_buffer_usage); + // Compute specific parameters. + u32 block_x{}; + u16 block_y{}; + u16 block_z{}; + u32 shared_memory_size{}; + u32 local_memory_size{}; + + bool operator==(const ProgramVariant& rhs) const noexcept { + return std::tie(primitive_mode, block_x, block_y, block_z, shared_memory_size, + local_memory_size) == std::tie(rhs.primitive_mode, rhs.block_x, rhs.block_y, + rhs.block_z, rhs.shared_memory_size, + rhs.local_memory_size); } - bool operator!=(const ProgramVariant& rhs) const { + bool operator!=(const ProgramVariant& rhs) const noexcept { return !operator==(rhs); } }; - static_assert(std::is_trivially_copyable_v<ProgramVariant>); /// Describes how a shader is used. @@ -99,21 +99,14 @@ struct ShaderDiskCacheUsage { namespace std { template <> -struct hash<OpenGL::BaseBindings> { - std::size_t operator()(const OpenGL::BaseBindings& bindings) const noexcept { - return static_cast<std::size_t>(bindings.cbuf) ^ - (static_cast<std::size_t>(bindings.gmem) << 8) ^ - (static_cast<std::size_t>(bindings.sampler) << 16) ^ - (static_cast<std::size_t>(bindings.image) << 24); - } -}; - -template <> struct hash<OpenGL::ProgramVariant> { std::size_t operator()(const OpenGL::ProgramVariant& variant) const noexcept { - return std::hash<OpenGL::BaseBindings>()(variant.base_bindings) ^ - std::hash<OpenGL::TextureBufferUsage>()(variant.texture_buffer_usage) ^ - (static_cast<std::size_t>(variant.primitive_mode) << 6); + return (static_cast<std::size_t>(variant.primitive_mode) << 6) ^ + static_cast<std::size_t>(variant.block_x) ^ + (static_cast<std::size_t>(variant.block_y) << 32) ^ + (static_cast<std::size_t>(variant.block_z) << 48) ^ + (static_cast<std::size_t>(variant.shared_memory_size) << 16) ^ + (static_cast<std::size_t>(variant.local_memory_size) << 36); } }; @@ -121,7 +114,7 @@ template <> struct hash<OpenGL::ShaderDiskCacheUsage> { std::size_t operator()(const OpenGL::ShaderDiskCacheUsage& usage) const noexcept { return static_cast<std::size_t>(usage.unique_identifier) ^ - std::hash<OpenGL::ProgramVariant>()(usage.variant); + std::hash<OpenGL::ProgramVariant>{}(usage.variant); } }; @@ -132,8 +125,8 @@ namespace OpenGL { /// Describes a shader how it's used by the guest GPU class ShaderDiskCacheRaw { public: - explicit ShaderDiskCacheRaw(u64 unique_identifier, ProgramType program_type, - ProgramCode program_code, ProgramCode program_code_b = {}); + explicit ShaderDiskCacheRaw(u64 unique_identifier, Tegra::Engines::ShaderType type, + ProgramCode code, ProgramCode code_b = {}); ShaderDiskCacheRaw(); ~ShaderDiskCacheRaw(); @@ -146,27 +139,26 @@ public: } bool HasProgramA() const { - return program_type == ProgramType::VertexA; + return !code.empty() && !code_b.empty(); } - ProgramType GetProgramType() const { - return program_type; + Tegra::Engines::ShaderType GetType() const { + return type; } - const ProgramCode& GetProgramCode() const { - return program_code; + const ProgramCode& GetCode() const { + return code; } - const ProgramCode& GetProgramCodeB() const { - return program_code_b; + const ProgramCode& GetCodeB() const { + return code_b; } private: u64 unique_identifier{}; - ProgramType program_type{}; - - ProgramCode program_code; - ProgramCode program_code_b; + Tegra::Engines::ShaderType type{}; + ProgramCode code; + ProgramCode code_b; }; /// Contains an OpenGL dumped binary program diff --git a/src/video_core/renderer_opengl/gl_shader_gen.cpp b/src/video_core/renderer_opengl/gl_shader_gen.cpp index 0e22eede9..34946fb47 100644 --- a/src/video_core/renderer_opengl/gl_shader_gen.cpp +++ b/src/video_core/renderer_opengl/gl_shader_gen.cpp @@ -2,8 +2,13 @@ // Licensed under GPLv2 or any later version // Refer to the license.txt file included. +#include <string> + #include <fmt/format.h> + #include "video_core/engines/maxwell_3d.h" +#include "video_core/engines/shader_type.h" +#include "video_core/renderer_opengl/gl_device.h" #include "video_core/renderer_opengl/gl_shader_decompiler.h" #include "video_core/renderer_opengl/gl_shader_gen.h" #include "video_core/shader/shader_ir.h" @@ -11,6 +16,7 @@ namespace OpenGL::GLShader { using Tegra::Engines::Maxwell3D; +using Tegra::Engines::ShaderType; using VideoCommon::Shader::CompileDepth; using VideoCommon::Shader::CompilerSettings; using VideoCommon::Shader::ProgramCode; @@ -18,53 +24,40 @@ using VideoCommon::Shader::ShaderIR; std::string GenerateVertexShader(const Device& device, const ShaderIR& ir, const ShaderIR* ir_b) { std::string out = GetCommonDeclarations(); - out += R"( -layout (std140, binding = EMULATION_UBO_BINDING) uniform vs_config { - vec4 viewport_flip; - uvec4 config_pack; // instance_id, flip_stage, y_direction, padding -}; - -)"; - const auto stage = ir_b ? ProgramType::VertexA : ProgramType::VertexB; - out += Decompile(device, ir, stage, "vertex"); + out += fmt::format(R"( +layout (std140, binding = {}) uniform vs_config {{ + float y_direction; +}}; + +)", + EmulationUniformBlockBinding); + out += Decompile(device, ir, ShaderType::Vertex, "vertex"); if (ir_b) { - out += Decompile(device, *ir_b, ProgramType::VertexB, "vertex_b"); + out += Decompile(device, *ir_b, ShaderType::Vertex, "vertex_b"); } out += R"( void main() { + gl_Position = vec4(0.0f, 0.0f, 0.0f, 1.0f); execute_vertex(); )"; - if (ir_b) { out += " execute_vertex_b();"; } - - out += R"( - - // Set Position Y direction - gl_Position.y *= utof(config_pack[2]); - // Check if the flip stage is VertexB - // Config pack's second value is flip_stage - if (config_pack[1] == 1) { - // Viewport can be flipped, which is unsupported by glViewport - gl_Position.xy *= viewport_flip.xy; - } -} -)"; + out += "}\n"; return out; } std::string GenerateGeometryShader(const Device& device, const ShaderIR& ir) { std::string out = GetCommonDeclarations(); - out += R"( -layout (std140, binding = EMULATION_UBO_BINDING) uniform gs_config { - vec4 viewport_flip; - uvec4 config_pack; // instance_id, flip_stage, y_direction, padding -}; + out += fmt::format(R"( +layout (std140, binding = {}) uniform gs_config {{ + float y_direction; +}}; -)"; - out += Decompile(device, ir, ProgramType::Geometry, "geometry"); +)", + EmulationUniformBlockBinding); + out += Decompile(device, ir, ShaderType::Geometry, "geometry"); out += R"( void main() { @@ -76,7 +69,7 @@ void main() { std::string GenerateFragmentShader(const Device& device, const ShaderIR& ir) { std::string out = GetCommonDeclarations(); - out += R"( + out += fmt::format(R"( layout (location = 0) out vec4 FragColor0; layout (location = 1) out vec4 FragColor1; layout (location = 2) out vec4 FragColor2; @@ -86,13 +79,13 @@ layout (location = 5) out vec4 FragColor5; layout (location = 6) out vec4 FragColor6; layout (location = 7) out vec4 FragColor7; -layout (std140, binding = EMULATION_UBO_BINDING) uniform fs_config { - vec4 viewport_flip; - uvec4 config_pack; // instance_id, flip_stage, y_direction, padding -}; +layout (std140, binding = {}) uniform fs_config {{ + float y_direction; +}}; -)"; - out += Decompile(device, ir, ProgramType::Fragment, "fragment"); +)", + EmulationUniformBlockBinding); + out += Decompile(device, ir, ShaderType::Fragment, "fragment"); out += R"( void main() { @@ -104,7 +97,7 @@ void main() { std::string GenerateComputeShader(const Device& device, const ShaderIR& ir) { std::string out = GetCommonDeclarations(); - out += Decompile(device, ir, ProgramType::Compute, "compute"); + out += Decompile(device, ir, ShaderType::Compute, "compute"); out += R"( void main() { execute_compute(); diff --git a/src/video_core/renderer_opengl/gl_shader_manager.cpp b/src/video_core/renderer_opengl/gl_shader_manager.cpp index b05f90f20..75d3fac04 100644 --- a/src/video_core/renderer_opengl/gl_shader_manager.cpp +++ b/src/video_core/renderer_opengl/gl_shader_manager.cpp @@ -40,27 +40,11 @@ void ProgramManager::UpdatePipeline() { old_state = current_state; } -void MaxwellUniformData::SetFromRegs(const Maxwell3D& maxwell, std::size_t shader_stage) { +void MaxwellUniformData::SetFromRegs(const Maxwell3D& maxwell) { const auto& regs = maxwell.regs; - const auto& state = maxwell.state; - - // TODO(bunnei): Support more than one viewport - viewport_flip[0] = regs.viewport_transform[0].scale_x < 0.0 ? -1.0f : 1.0f; - viewport_flip[1] = regs.viewport_transform[0].scale_y < 0.0 ? -1.0f : 1.0f; - - instance_id = state.current_instance; - - // Assign in which stage the position has to be flipped - // (the last stage before the fragment shader). - constexpr u32 geometry_index = static_cast<u32>(Maxwell3D::Regs::ShaderProgram::Geometry); - if (maxwell.regs.shader_config[geometry_index].enable) { - flip_stage = geometry_index; - } else { - flip_stage = static_cast<u32>(Maxwell3D::Regs::ShaderProgram::VertexB); - } // Y_NEGATE controls what value S2R returns for the Y_DIRECTION system value. - y_direction = regs.screen_y_control.y_negate == 0 ? 1.f : -1.f; + y_direction = regs.screen_y_control.y_negate == 0 ? 1.0f : -1.0f; } } // namespace OpenGL::GLShader diff --git a/src/video_core/renderer_opengl/gl_shader_manager.h b/src/video_core/renderer_opengl/gl_shader_manager.h index 6961e702a..3703e7018 100644 --- a/src/video_core/renderer_opengl/gl_shader_manager.h +++ b/src/video_core/renderer_opengl/gl_shader_manager.h @@ -18,17 +18,12 @@ namespace OpenGL::GLShader { /// @note Always keep a vec4 at the end. The GL spec is not clear whether the alignment at /// the end of a uniform block is included in UNIFORM_BLOCK_DATA_SIZE or not. /// Not following that rule will cause problems on some AMD drivers. -struct MaxwellUniformData { - void SetFromRegs(const Tegra::Engines::Maxwell3D& maxwell, std::size_t shader_stage); - - alignas(16) GLvec4 viewport_flip; - struct alignas(16) { - GLuint instance_id; - GLuint flip_stage; - GLfloat y_direction; - }; +struct alignas(16) MaxwellUniformData { + void SetFromRegs(const Tegra::Engines::Maxwell3D& maxwell); + + GLfloat y_direction; }; -static_assert(sizeof(MaxwellUniformData) == 32, "MaxwellUniformData structure size is incorrect"); +static_assert(sizeof(MaxwellUniformData) == 16, "MaxwellUniformData structure size is incorrect"); static_assert(sizeof(MaxwellUniformData) < 16384, "MaxwellUniformData structure must be less than 16kb as per the OpenGL spec"); diff --git a/src/video_core/renderer_opengl/gl_state.cpp b/src/video_core/renderer_opengl/gl_state.cpp index f25148362..39b3986d3 100644 --- a/src/video_core/renderer_opengl/gl_state.cpp +++ b/src/video_core/renderer_opengl/gl_state.cpp @@ -410,15 +410,31 @@ void OpenGLState::ApplyAlphaTest() { } } +void OpenGLState::ApplyClipControl() { + if (UpdateValue(cur_state.clip_control.origin, clip_control.origin)) { + glClipControl(clip_control.origin, GL_NEGATIVE_ONE_TO_ONE); + } +} + void OpenGLState::ApplyTextures() { - if (const auto update = UpdateArray(cur_state.textures, textures)) { - glBindTextures(update->first, update->second, textures.data() + update->first); + const std::size_t size = std::size(textures); + for (std::size_t i = 0; i < size; ++i) { + if (UpdateValue(cur_state.textures[i], textures[i])) { + // BindTextureUnit doesn't support binding null textures, skip those binds. + // TODO(Rodrigo): Stop using null textures + if (textures[i] != 0) { + glBindTextureUnit(static_cast<GLuint>(i), textures[i]); + } + } } } void OpenGLState::ApplySamplers() { - if (const auto update = UpdateArray(cur_state.samplers, samplers)) { - glBindSamplers(update->first, update->second, samplers.data() + update->first); + const std::size_t size = std::size(samplers); + for (std::size_t i = 0; i < size; ++i) { + if (UpdateValue(cur_state.samplers[i], samplers[i])) { + glBindSampler(static_cast<GLuint>(i), samplers[i]); + } } } @@ -453,6 +469,7 @@ void OpenGLState::Apply() { ApplyImages(); ApplyPolygonOffset(); ApplyAlphaTest(); + ApplyClipControl(); } void OpenGLState::EmulateViewportWithScissor() { diff --git a/src/video_core/renderer_opengl/gl_state.h b/src/video_core/renderer_opengl/gl_state.h index cca25206b..e53c2c5f2 100644 --- a/src/video_core/renderer_opengl/gl_state.h +++ b/src/video_core/renderer_opengl/gl_state.h @@ -96,9 +96,11 @@ public: GLenum operation = GL_COPY; } logic_op; - std::array<GLuint, Tegra::Engines::Maxwell3D::Regs::NumTextureSamplers> textures = {}; - std::array<GLuint, Tegra::Engines::Maxwell3D::Regs::NumTextureSamplers> samplers = {}; - std::array<GLuint, Tegra::Engines::Maxwell3D::Regs::NumImages> images = {}; + static constexpr std::size_t NumSamplers = 32 * 5; + static constexpr std::size_t NumImages = 8 * 5; + std::array<GLuint, NumSamplers> textures = {}; + std::array<GLuint, NumSamplers> samplers = {}; + std::array<GLuint, NumImages> images = {}; struct { GLuint read_framebuffer = 0; // GL_READ_FRAMEBUFFER_BINDING @@ -146,6 +148,10 @@ public: std::array<bool, 8> clip_distance = {}; // GL_CLIP_DISTANCE + struct { + GLenum origin = GL_LOWER_LEFT; + } clip_control; + OpenGLState(); /// Get the currently active OpenGL state @@ -182,6 +188,7 @@ public: void ApplyDepthClamp(); void ApplyPolygonOffset(); void ApplyAlphaTest(); + void ApplyClipControl(); /// Resets any references to the given resource OpenGLState& UnbindTexture(GLuint handle); diff --git a/src/video_core/renderer_opengl/gl_texture_cache.cpp b/src/video_core/renderer_opengl/gl_texture_cache.cpp index 55b3e58b2..b790b0ef4 100644 --- a/src/video_core/renderer_opengl/gl_texture_cache.cpp +++ b/src/video_core/renderer_opengl/gl_texture_cache.cpp @@ -23,7 +23,6 @@ namespace OpenGL { using Tegra::Texture::SwizzleSource; using VideoCore::MortonSwizzleMode; -using VideoCore::Surface::ComponentType; using VideoCore::Surface::PixelFormat; using VideoCore::Surface::SurfaceCompression; using VideoCore::Surface::SurfaceTarget; @@ -40,114 +39,95 @@ struct FormatTuple { GLint internal_format; GLenum format; GLenum type; - ComponentType component_type; bool compressed; }; constexpr std::array<FormatTuple, VideoCore::Surface::MaxPixelFormat> tex_format_tuples = {{ - {GL_RGBA8, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8_REV, ComponentType::UNorm, false}, // ABGR8U - {GL_RGBA8, GL_RGBA, GL_BYTE, ComponentType::SNorm, false}, // ABGR8S - {GL_RGBA8UI, GL_RGBA_INTEGER, GL_UNSIGNED_BYTE, ComponentType::UInt, false}, // ABGR8UI - {GL_RGB565, GL_RGB, GL_UNSIGNED_SHORT_5_6_5_REV, ComponentType::UNorm, false}, // B5G6R5U - {GL_RGB10_A2, GL_RGBA, GL_UNSIGNED_INT_2_10_10_10_REV, ComponentType::UNorm, - false}, // A2B10G10R10U - {GL_RGB5_A1, GL_RGBA, GL_UNSIGNED_SHORT_1_5_5_5_REV, ComponentType::UNorm, false}, // A1B5G5R5U - {GL_R8, GL_RED, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // R8U - {GL_R8UI, GL_RED_INTEGER, GL_UNSIGNED_BYTE, ComponentType::UInt, false}, // R8UI - {GL_RGBA16F, GL_RGBA, GL_HALF_FLOAT, ComponentType::Float, false}, // RGBA16F - {GL_RGBA16, GL_RGBA, GL_UNSIGNED_SHORT, ComponentType::UNorm, false}, // RGBA16U - {GL_RGBA16UI, GL_RGBA_INTEGER, GL_UNSIGNED_SHORT, ComponentType::UInt, false}, // RGBA16UI - {GL_R11F_G11F_B10F, GL_RGB, GL_UNSIGNED_INT_10F_11F_11F_REV, ComponentType::Float, - false}, // R11FG11FB10F - {GL_RGBA32UI, GL_RGBA_INTEGER, GL_UNSIGNED_INT, ComponentType::UInt, false}, // RGBA32UI - {GL_COMPRESSED_RGBA_S3TC_DXT1_EXT, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8, ComponentType::UNorm, - true}, // DXT1 - {GL_COMPRESSED_RGBA_S3TC_DXT3_EXT, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8, ComponentType::UNorm, - true}, // DXT23 - {GL_COMPRESSED_RGBA_S3TC_DXT5_EXT, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8, ComponentType::UNorm, - true}, // DXT45 - {GL_COMPRESSED_RED_RGTC1, GL_RED, GL_UNSIGNED_INT_8_8_8_8, ComponentType::UNorm, true}, // DXN1 - {GL_COMPRESSED_RG_RGTC2, GL_RG, GL_UNSIGNED_INT_8_8_8_8, ComponentType::UNorm, - true}, // DXN2UNORM - {GL_COMPRESSED_SIGNED_RG_RGTC2, GL_RG, GL_INT, ComponentType::SNorm, true}, // DXN2SNORM - {GL_COMPRESSED_RGBA_BPTC_UNORM, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8, ComponentType::UNorm, - true}, // BC7U - {GL_COMPRESSED_RGB_BPTC_UNSIGNED_FLOAT, GL_RGB, GL_UNSIGNED_INT_8_8_8_8, ComponentType::Float, - true}, // BC6H_UF16 - {GL_COMPRESSED_RGB_BPTC_SIGNED_FLOAT, GL_RGB, GL_UNSIGNED_INT_8_8_8_8, ComponentType::Float, - true}, // BC6H_SF16 - {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_4X4 - {GL_RGBA8, GL_BGRA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // BGRA8 - {GL_RGBA32F, GL_RGBA, GL_FLOAT, ComponentType::Float, false}, // RGBA32F - {GL_RG32F, GL_RG, GL_FLOAT, ComponentType::Float, false}, // RG32F - {GL_R32F, GL_RED, GL_FLOAT, ComponentType::Float, false}, // R32F - {GL_R16F, GL_RED, GL_HALF_FLOAT, ComponentType::Float, false}, // R16F - {GL_R16, GL_RED, GL_UNSIGNED_SHORT, ComponentType::UNorm, false}, // R16U - {GL_R16_SNORM, GL_RED, GL_SHORT, ComponentType::SNorm, false}, // R16S - {GL_R16UI, GL_RED_INTEGER, GL_UNSIGNED_SHORT, ComponentType::UInt, false}, // R16UI - {GL_R16I, GL_RED_INTEGER, GL_SHORT, ComponentType::SInt, false}, // R16I - {GL_RG16, GL_RG, GL_UNSIGNED_SHORT, ComponentType::UNorm, false}, // RG16 - {GL_RG16F, GL_RG, GL_HALF_FLOAT, ComponentType::Float, false}, // RG16F - {GL_RG16UI, GL_RG_INTEGER, GL_UNSIGNED_SHORT, ComponentType::UInt, false}, // RG16UI - {GL_RG16I, GL_RG_INTEGER, GL_SHORT, ComponentType::SInt, false}, // RG16I - {GL_RG16_SNORM, GL_RG, GL_SHORT, ComponentType::SNorm, false}, // RG16S - {GL_RGB32F, GL_RGB, GL_FLOAT, ComponentType::Float, false}, // RGB32F - {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8_REV, ComponentType::UNorm, - false}, // RGBA8_SRGB - {GL_RG8, GL_RG, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // RG8U - {GL_RG8, GL_RG, GL_BYTE, ComponentType::SNorm, false}, // RG8S - {GL_RG32UI, GL_RG_INTEGER, GL_UNSIGNED_INT, ComponentType::UInt, false}, // RG32UI - {GL_RGB16F, GL_RGBA16, GL_HALF_FLOAT, ComponentType::Float, false}, // RGBX16F - {GL_R32UI, GL_RED_INTEGER, GL_UNSIGNED_INT, ComponentType::UInt, false}, // R32UI - {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_8X8 - {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_8X5 - {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_5X4 - {GL_SRGB8_ALPHA8, GL_BGRA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // BGRA8 + {GL_RGBA8, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8_REV, false}, // ABGR8U + {GL_RGBA8, GL_RGBA, GL_BYTE, false}, // ABGR8S + {GL_RGBA8UI, GL_RGBA_INTEGER, GL_UNSIGNED_BYTE, false}, // ABGR8UI + {GL_RGB565, GL_RGB, GL_UNSIGNED_SHORT_5_6_5_REV, false}, // B5G6R5U + {GL_RGB10_A2, GL_RGBA, GL_UNSIGNED_INT_2_10_10_10_REV, false}, // A2B10G10R10U + {GL_RGB5_A1, GL_RGBA, GL_UNSIGNED_SHORT_1_5_5_5_REV, false}, // A1B5G5R5U + {GL_R8, GL_RED, GL_UNSIGNED_BYTE, false}, // R8U + {GL_R8UI, GL_RED_INTEGER, GL_UNSIGNED_BYTE, false}, // R8UI + {GL_RGBA16F, GL_RGBA, GL_HALF_FLOAT, false}, // RGBA16F + {GL_RGBA16, GL_RGBA, GL_UNSIGNED_SHORT, false}, // RGBA16U + {GL_RGBA16UI, GL_RGBA_INTEGER, GL_UNSIGNED_SHORT, false}, // RGBA16UI + {GL_R11F_G11F_B10F, GL_RGB, GL_UNSIGNED_INT_10F_11F_11F_REV, false}, // R11FG11FB10F + {GL_RGBA32UI, GL_RGBA_INTEGER, GL_UNSIGNED_INT, false}, // RGBA32UI + {GL_COMPRESSED_RGBA_S3TC_DXT1_EXT, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8, true}, // DXT1 + {GL_COMPRESSED_RGBA_S3TC_DXT3_EXT, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8, true}, // DXT23 + {GL_COMPRESSED_RGBA_S3TC_DXT5_EXT, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8, true}, // DXT45 + {GL_COMPRESSED_RED_RGTC1, GL_RED, GL_UNSIGNED_INT_8_8_8_8, true}, // DXN1 + {GL_COMPRESSED_RG_RGTC2, GL_RG, GL_UNSIGNED_INT_8_8_8_8, true}, // DXN2UNORM + {GL_COMPRESSED_SIGNED_RG_RGTC2, GL_RG, GL_INT, true}, // DXN2SNORM + {GL_COMPRESSED_RGBA_BPTC_UNORM, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8, true}, // BC7U + {GL_COMPRESSED_RGB_BPTC_UNSIGNED_FLOAT, GL_RGB, GL_UNSIGNED_INT_8_8_8_8, true}, // BC6H_UF16 + {GL_COMPRESSED_RGB_BPTC_SIGNED_FLOAT, GL_RGB, GL_UNSIGNED_INT_8_8_8_8, true}, // BC6H_SF16 + {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_4X4 + {GL_RGBA8, GL_BGRA, GL_UNSIGNED_BYTE, false}, // BGRA8 + {GL_RGBA32F, GL_RGBA, GL_FLOAT, false}, // RGBA32F + {GL_RG32F, GL_RG, GL_FLOAT, false}, // RG32F + {GL_R32F, GL_RED, GL_FLOAT, false}, // R32F + {GL_R16F, GL_RED, GL_HALF_FLOAT, false}, // R16F + {GL_R16, GL_RED, GL_UNSIGNED_SHORT, false}, // R16U + {GL_R16_SNORM, GL_RED, GL_SHORT, false}, // R16S + {GL_R16UI, GL_RED_INTEGER, GL_UNSIGNED_SHORT, false}, // R16UI + {GL_R16I, GL_RED_INTEGER, GL_SHORT, false}, // R16I + {GL_RG16, GL_RG, GL_UNSIGNED_SHORT, false}, // RG16 + {GL_RG16F, GL_RG, GL_HALF_FLOAT, false}, // RG16F + {GL_RG16UI, GL_RG_INTEGER, GL_UNSIGNED_SHORT, false}, // RG16UI + {GL_RG16I, GL_RG_INTEGER, GL_SHORT, false}, // RG16I + {GL_RG16_SNORM, GL_RG, GL_SHORT, false}, // RG16S + {GL_RGB32F, GL_RGB, GL_FLOAT, false}, // RGB32F + {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8_REV, false}, // RGBA8_SRGB + {GL_RG8, GL_RG, GL_UNSIGNED_BYTE, false}, // RG8U + {GL_RG8, GL_RG, GL_BYTE, false}, // RG8S + {GL_RG32UI, GL_RG_INTEGER, GL_UNSIGNED_INT, false}, // RG32UI + {GL_RGB16F, GL_RGBA16, GL_HALF_FLOAT, false}, // RGBX16F + {GL_R32UI, GL_RED_INTEGER, GL_UNSIGNED_INT, false}, // R32UI + {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_8X8 + {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_8X5 + {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_5X4 + {GL_SRGB8_ALPHA8, GL_BGRA, GL_UNSIGNED_BYTE, false}, // BGRA8 // Compressed sRGB formats - {GL_COMPRESSED_SRGB_ALPHA_S3TC_DXT1_EXT, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8, ComponentType::UNorm, - true}, // DXT1_SRGB - {GL_COMPRESSED_SRGB_ALPHA_S3TC_DXT3_EXT, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8, ComponentType::UNorm, - true}, // DXT23_SRGB - {GL_COMPRESSED_SRGB_ALPHA_S3TC_DXT5_EXT, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8, ComponentType::UNorm, - true}, // DXT45_SRGB - {GL_COMPRESSED_SRGB_ALPHA_BPTC_UNORM, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8, ComponentType::UNorm, - true}, // BC7U_SRGB - {GL_RGBA4, GL_RGBA, GL_UNSIGNED_SHORT_4_4_4_4_REV, ComponentType::UNorm, false}, // R4G4B4A4U - {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_4X4_SRGB - {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_8X8_SRGB - {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_8X5_SRGB - {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_5X4_SRGB - {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_5X5 - {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_5X5_SRGB - {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_10X8 - {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_10X8_SRGB - {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_6X6 - {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_6X6_SRGB - {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_10X10 - {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_10X10_SRGB - {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_12X12 - {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_12X12_SRGB - {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_8X6 - {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_8X6_SRGB - {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_6X5 - {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, ComponentType::UNorm, false}, // ASTC_2D_6X5_SRGB - {GL_RGB9_E5, GL_RGB, GL_UNSIGNED_INT_5_9_9_9_REV, ComponentType::Float, false}, // E5B9G9R9F + {GL_COMPRESSED_SRGB_ALPHA_S3TC_DXT1_EXT, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8, true}, // DXT1_SRGB + {GL_COMPRESSED_SRGB_ALPHA_S3TC_DXT3_EXT, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8, true}, // DXT23_SRGB + {GL_COMPRESSED_SRGB_ALPHA_S3TC_DXT5_EXT, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8, true}, // DXT45_SRGB + {GL_COMPRESSED_SRGB_ALPHA_BPTC_UNORM, GL_RGBA, GL_UNSIGNED_INT_8_8_8_8, true}, // BC7U_SRGB + {GL_RGBA4, GL_RGBA, GL_UNSIGNED_SHORT_4_4_4_4_REV, false}, // R4G4B4A4U + {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_4X4_SRGB + {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_8X8_SRGB + {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_8X5_SRGB + {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_5X4_SRGB + {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_5X5 + {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_5X5_SRGB + {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_10X8 + {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_10X8_SRGB + {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_6X6 + {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_6X6_SRGB + {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_10X10 + {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_10X10_SRGB + {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_12X12 + {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_12X12_SRGB + {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_8X6 + {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_8X6_SRGB + {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_6X5 + {GL_SRGB8_ALPHA8, GL_RGBA, GL_UNSIGNED_BYTE, false}, // ASTC_2D_6X5_SRGB + {GL_RGB9_E5, GL_RGB, GL_UNSIGNED_INT_5_9_9_9_REV, false}, // E5B9G9R9F // Depth formats - {GL_DEPTH_COMPONENT32F, GL_DEPTH_COMPONENT, GL_FLOAT, ComponentType::Float, false}, // Z32F - {GL_DEPTH_COMPONENT16, GL_DEPTH_COMPONENT, GL_UNSIGNED_SHORT, ComponentType::UNorm, - false}, // Z16 + {GL_DEPTH_COMPONENT32F, GL_DEPTH_COMPONENT, GL_FLOAT, false}, // Z32F + {GL_DEPTH_COMPONENT16, GL_DEPTH_COMPONENT, GL_UNSIGNED_SHORT, false}, // Z16 // DepthStencil formats - {GL_DEPTH24_STENCIL8, GL_DEPTH_STENCIL, GL_UNSIGNED_INT_24_8, ComponentType::UNorm, - false}, // Z24S8 - {GL_DEPTH24_STENCIL8, GL_DEPTH_STENCIL, GL_UNSIGNED_INT_24_8, ComponentType::UNorm, - false}, // S8Z24 - {GL_DEPTH32F_STENCIL8, GL_DEPTH_STENCIL, GL_FLOAT_32_UNSIGNED_INT_24_8_REV, - ComponentType::Float, false}, // Z32FS8 + {GL_DEPTH24_STENCIL8, GL_DEPTH_STENCIL, GL_UNSIGNED_INT_24_8, false}, // Z24S8 + {GL_DEPTH24_STENCIL8, GL_DEPTH_STENCIL, GL_UNSIGNED_INT_24_8, false}, // S8Z24 + {GL_DEPTH32F_STENCIL8, GL_DEPTH_STENCIL, GL_FLOAT_32_UNSIGNED_INT_24_8_REV, false}, // Z32FS8 }}; -const FormatTuple& GetFormatTuple(PixelFormat pixel_format, ComponentType component_type) { +const FormatTuple& GetFormatTuple(PixelFormat pixel_format) { ASSERT(static_cast<std::size_t>(pixel_format) < tex_format_tuples.size()); const auto& format{tex_format_tuples[static_cast<std::size_t>(pixel_format)]}; return format; @@ -249,7 +229,7 @@ OGLTexture CreateTexture(const SurfaceParams& params, GLenum target, GLenum inte CachedSurface::CachedSurface(const GPUVAddr gpu_addr, const SurfaceParams& params) : VideoCommon::SurfaceBase<View>(gpu_addr, params) { - const auto& tuple{GetFormatTuple(params.pixel_format, params.component_type)}; + const auto& tuple{GetFormatTuple(params.pixel_format)}; internal_format = tuple.internal_format; format = tuple.format; type = tuple.type; @@ -451,8 +431,7 @@ OGLTextureView CachedSurfaceView::CreateTextureView() const { texture_view.Create(); const GLuint handle{texture_view.handle}; - const FormatTuple& tuple{ - GetFormatTuple(owner_params.pixel_format, owner_params.component_type)}; + const FormatTuple& tuple{GetFormatTuple(owner_params.pixel_format)}; glTextureView(handle, target, surface.texture.handle, tuple.internal_format, params.base_level, params.num_levels, params.base_layer, params.num_layers); @@ -509,6 +488,7 @@ void TextureCacheOpenGL::ImageBlit(View& src_view, View& dst_view, OpenGLState state; state.draw.read_framebuffer = src_framebuffer.handle; state.draw.draw_framebuffer = dst_framebuffer.handle; + state.framebuffer_srgb.enabled = dst_params.srgb_conversion; state.AllDirty(); state.Apply(); @@ -562,8 +542,8 @@ void TextureCacheOpenGL::BufferCopy(Surface& src_surface, Surface& dst_surface) const auto& dst_params = dst_surface->GetSurfaceParams(); UNIMPLEMENTED_IF(src_params.num_levels > 1 || dst_params.num_levels > 1); - const auto source_format = GetFormatTuple(src_params.pixel_format, src_params.component_type); - const auto dest_format = GetFormatTuple(dst_params.pixel_format, dst_params.component_type); + const auto source_format = GetFormatTuple(src_params.pixel_format); + const auto dest_format = GetFormatTuple(dst_params.pixel_format); const std::size_t source_size = src_surface->GetHostSizeInBytes(); const std::size_t dest_size = dst_surface->GetHostSizeInBytes(); diff --git a/src/video_core/renderer_opengl/renderer_opengl.cpp b/src/video_core/renderer_opengl/renderer_opengl.cpp index 7646cbb0e..a57a564f7 100644 --- a/src/video_core/renderer_opengl/renderer_opengl.cpp +++ b/src/video_core/renderer_opengl/renderer_opengl.cpp @@ -158,7 +158,7 @@ void RendererOpenGL::LoadFBToScreenInfo(const Tegra::FramebufferConfig& framebuf VideoCore::Surface::PixelFormatFromGPUPixelFormat(framebuffer.pixel_format)}; const u32 bytes_per_pixel{VideoCore::Surface::GetBytesPerPixel(pixel_format)}; const u64 size_in_bytes{framebuffer.stride * framebuffer.height * bytes_per_pixel}; - const auto host_ptr{Memory::GetPointer(framebuffer_addr)}; + u8* const host_ptr{system.Memory().GetPointer(framebuffer_addr)}; rasterizer->FlushRegion(ToCacheAddr(host_ptr), size_in_bytes); // TODO(Rodrigo): Read this from HLE diff --git a/src/video_core/renderer_opengl/utils.cpp b/src/video_core/renderer_opengl/utils.cpp index c504a2c1a..9770dda1c 100644 --- a/src/video_core/renderer_opengl/utils.cpp +++ b/src/video_core/renderer_opengl/utils.cpp @@ -3,7 +3,10 @@ // Refer to the license.txt file included. #include <string> +#include <vector> + #include <fmt/format.h> + #include <glad/glad.h> #include "common/assert.h" @@ -48,34 +51,19 @@ BindBuffersRangePushBuffer::BindBuffersRangePushBuffer(GLenum target) : target{t BindBuffersRangePushBuffer::~BindBuffersRangePushBuffer() = default; -void BindBuffersRangePushBuffer::Setup(GLuint first_) { - first = first_; - buffer_pointers.clear(); - offsets.clear(); - sizes.clear(); +void BindBuffersRangePushBuffer::Setup() { + entries.clear(); } -void BindBuffersRangePushBuffer::Push(const GLuint* buffer, GLintptr offset, GLsizeiptr size) { - buffer_pointers.push_back(buffer); - offsets.push_back(offset); - sizes.push_back(size); +void BindBuffersRangePushBuffer::Push(GLuint binding, const GLuint* buffer, GLintptr offset, + GLsizeiptr size) { + entries.push_back(Entry{binding, buffer, offset, size}); } void BindBuffersRangePushBuffer::Bind() { - // Ensure sizes are valid. - const std::size_t count{buffer_pointers.size()}; - DEBUG_ASSERT(count == offsets.size() && count == sizes.size()); - if (count == 0) { - return; + for (const Entry& entry : entries) { + glBindBufferRange(target, entry.binding, *entry.buffer, entry.offset, entry.size); } - - // Dereference buffers. - buffers.resize(count); - std::transform(buffer_pointers.begin(), buffer_pointers.end(), buffers.begin(), - [](const GLuint* pointer) { return *pointer; }); - - glBindBuffersRange(target, first, static_cast<GLsizei>(count), buffers.data(), offsets.data(), - sizes.data()); } void LabelGLObject(GLenum identifier, GLuint handle, VAddr addr, std::string_view extra_info) { diff --git a/src/video_core/renderer_opengl/utils.h b/src/video_core/renderer_opengl/utils.h index 6c2b45546..d56153fe7 100644 --- a/src/video_core/renderer_opengl/utils.h +++ b/src/video_core/renderer_opengl/utils.h @@ -43,20 +43,22 @@ public: explicit BindBuffersRangePushBuffer(GLenum target); ~BindBuffersRangePushBuffer(); - void Setup(GLuint first_); + void Setup(); - void Push(const GLuint* buffer, GLintptr offset, GLsizeiptr size); + void Push(GLuint binding, const GLuint* buffer, GLintptr offset, GLsizeiptr size); void Bind(); private: - GLenum target{}; - GLuint first{}; - std::vector<const GLuint*> buffer_pointers; + struct Entry { + GLuint binding; + const GLuint* buffer; + GLintptr offset; + GLsizeiptr size; + }; - std::vector<GLuint> buffers; - std::vector<GLintptr> offsets; - std::vector<GLsizeiptr> sizes; + GLenum target; + std::vector<Entry> entries; }; void LabelGLObject(GLenum identifier, GLuint handle, VAddr addr, std::string_view extra_info = {}); diff --git a/src/video_core/renderer_vulkan/maxwell_to_vk.cpp b/src/video_core/renderer_vulkan/maxwell_to_vk.cpp index 3c5acda3e..7f0eb6b74 100644 --- a/src/video_core/renderer_vulkan/maxwell_to_vk.cpp +++ b/src/video_core/renderer_vulkan/maxwell_to_vk.cpp @@ -13,6 +13,8 @@ namespace Vulkan::MaxwellToVK { +using Maxwell = Tegra::Engines::Maxwell3D::Regs; + namespace Sampler { vk::Filter Filter(Tegra::Texture::TextureFilter filter) { @@ -95,83 +97,82 @@ vk::CompareOp DepthCompareFunction(Tegra::Texture::DepthCompareFunc depth_compar } // namespace Sampler struct FormatTuple { - vk::Format format; ///< Vulkan format - ComponentType component_type; ///< Abstracted component type - bool attachable; ///< True when this format can be used as an attachment + vk::Format format; ///< Vulkan format + bool attachable; ///< True when this format can be used as an attachment }; static constexpr std::array<FormatTuple, VideoCore::Surface::MaxPixelFormat> tex_format_tuples = {{ - {vk::Format::eA8B8G8R8UnormPack32, ComponentType::UNorm, true}, // ABGR8U - {vk::Format::eUndefined, ComponentType::Invalid, false}, // ABGR8S - {vk::Format::eUndefined, ComponentType::Invalid, false}, // ABGR8UI - {vk::Format::eB5G6R5UnormPack16, ComponentType::UNorm, false}, // B5G6R5U - {vk::Format::eA2B10G10R10UnormPack32, ComponentType::UNorm, true}, // A2B10G10R10U - {vk::Format::eUndefined, ComponentType::Invalid, false}, // A1B5G5R5U - {vk::Format::eR8Unorm, ComponentType::UNorm, true}, // R8U - {vk::Format::eUndefined, ComponentType::Invalid, false}, // R8UI - {vk::Format::eUndefined, ComponentType::Invalid, false}, // RGBA16F - {vk::Format::eUndefined, ComponentType::Invalid, false}, // RGBA16U - {vk::Format::eUndefined, ComponentType::Invalid, false}, // RGBA16UI - {vk::Format::eUndefined, ComponentType::Invalid, false}, // R11FG11FB10F - {vk::Format::eUndefined, ComponentType::Invalid, false}, // RGBA32UI - {vk::Format::eBc1RgbaUnormBlock, ComponentType::UNorm, false}, // DXT1 - {vk::Format::eBc2UnormBlock, ComponentType::UNorm, false}, // DXT23 - {vk::Format::eBc3UnormBlock, ComponentType::UNorm, false}, // DXT45 - {vk::Format::eBc4UnormBlock, ComponentType::UNorm, false}, // DXN1 - {vk::Format::eUndefined, ComponentType::Invalid, false}, // DXN2UNORM - {vk::Format::eUndefined, ComponentType::Invalid, false}, // DXN2SNORM - {vk::Format::eUndefined, ComponentType::Invalid, false}, // BC7U - {vk::Format::eUndefined, ComponentType::Invalid, false}, // BC6H_UF16 - {vk::Format::eUndefined, ComponentType::Invalid, false}, // BC6H_SF16 - {vk::Format::eUndefined, ComponentType::Invalid, false}, // ASTC_2D_4X4 - {vk::Format::eUndefined, ComponentType::Invalid, false}, // BGRA8 - {vk::Format::eUndefined, ComponentType::Invalid, false}, // RGBA32F - {vk::Format::eUndefined, ComponentType::Invalid, false}, // RG32F - {vk::Format::eUndefined, ComponentType::Invalid, false}, // R32F - {vk::Format::eUndefined, ComponentType::Invalid, false}, // R16F - {vk::Format::eUndefined, ComponentType::Invalid, false}, // R16U - {vk::Format::eUndefined, ComponentType::Invalid, false}, // R16S - {vk::Format::eUndefined, ComponentType::Invalid, false}, // R16UI - {vk::Format::eUndefined, ComponentType::Invalid, false}, // R16I - {vk::Format::eUndefined, ComponentType::Invalid, false}, // RG16 - {vk::Format::eUndefined, ComponentType::Invalid, false}, // RG16F - {vk::Format::eUndefined, ComponentType::Invalid, false}, // RG16UI - {vk::Format::eUndefined, ComponentType::Invalid, false}, // RG16I - {vk::Format::eUndefined, ComponentType::Invalid, false}, // RG16S - {vk::Format::eUndefined, ComponentType::Invalid, false}, // RGB32F - {vk::Format::eA8B8G8R8SrgbPack32, ComponentType::UNorm, true}, // RGBA8_SRGB - {vk::Format::eUndefined, ComponentType::Invalid, false}, // RG8U - {vk::Format::eUndefined, ComponentType::Invalid, false}, // RG8S - {vk::Format::eUndefined, ComponentType::Invalid, false}, // RG32UI - {vk::Format::eUndefined, ComponentType::Invalid, false}, // RGBX16F - {vk::Format::eUndefined, ComponentType::Invalid, false}, // R32UI - {vk::Format::eUndefined, ComponentType::Invalid, false}, // ASTC_2D_8X8 - {vk::Format::eUndefined, ComponentType::Invalid, false}, // ASTC_2D_8X5 - {vk::Format::eUndefined, ComponentType::Invalid, false}, // ASTC_2D_5X4 + {vk::Format::eA8B8G8R8UnormPack32, true}, // ABGR8U + {vk::Format::eUndefined, false}, // ABGR8S + {vk::Format::eUndefined, false}, // ABGR8UI + {vk::Format::eB5G6R5UnormPack16, false}, // B5G6R5U + {vk::Format::eA2B10G10R10UnormPack32, true}, // A2B10G10R10U + {vk::Format::eUndefined, false}, // A1B5G5R5U + {vk::Format::eR8Unorm, true}, // R8U + {vk::Format::eUndefined, false}, // R8UI + {vk::Format::eUndefined, false}, // RGBA16F + {vk::Format::eUndefined, false}, // RGBA16U + {vk::Format::eUndefined, false}, // RGBA16UI + {vk::Format::eUndefined, false}, // R11FG11FB10F + {vk::Format::eUndefined, false}, // RGBA32UI + {vk::Format::eBc1RgbaUnormBlock, false}, // DXT1 + {vk::Format::eBc2UnormBlock, false}, // DXT23 + {vk::Format::eBc3UnormBlock, false}, // DXT45 + {vk::Format::eBc4UnormBlock, false}, // DXN1 + {vk::Format::eUndefined, false}, // DXN2UNORM + {vk::Format::eUndefined, false}, // DXN2SNORM + {vk::Format::eUndefined, false}, // BC7U + {vk::Format::eUndefined, false}, // BC6H_UF16 + {vk::Format::eUndefined, false}, // BC6H_SF16 + {vk::Format::eUndefined, false}, // ASTC_2D_4X4 + {vk::Format::eUndefined, false}, // BGRA8 + {vk::Format::eUndefined, false}, // RGBA32F + {vk::Format::eUndefined, false}, // RG32F + {vk::Format::eUndefined, false}, // R32F + {vk::Format::eUndefined, false}, // R16F + {vk::Format::eUndefined, false}, // R16U + {vk::Format::eUndefined, false}, // R16S + {vk::Format::eUndefined, false}, // R16UI + {vk::Format::eUndefined, false}, // R16I + {vk::Format::eUndefined, false}, // RG16 + {vk::Format::eUndefined, false}, // RG16F + {vk::Format::eUndefined, false}, // RG16UI + {vk::Format::eUndefined, false}, // RG16I + {vk::Format::eUndefined, false}, // RG16S + {vk::Format::eUndefined, false}, // RGB32F + {vk::Format::eA8B8G8R8SrgbPack32, true}, // RGBA8_SRGB + {vk::Format::eUndefined, false}, // RG8U + {vk::Format::eUndefined, false}, // RG8S + {vk::Format::eUndefined, false}, // RG32UI + {vk::Format::eUndefined, false}, // RGBX16F + {vk::Format::eUndefined, false}, // R32UI + {vk::Format::eUndefined, false}, // ASTC_2D_8X8 + {vk::Format::eUndefined, false}, // ASTC_2D_8X5 + {vk::Format::eUndefined, false}, // ASTC_2D_5X4 // Compressed sRGB formats - {vk::Format::eUndefined, ComponentType::Invalid, false}, // BGRA8_SRGB - {vk::Format::eUndefined, ComponentType::Invalid, false}, // DXT1_SRGB - {vk::Format::eUndefined, ComponentType::Invalid, false}, // DXT23_SRGB - {vk::Format::eUndefined, ComponentType::Invalid, false}, // DXT45_SRGB - {vk::Format::eUndefined, ComponentType::Invalid, false}, // BC7U_SRGB - {vk::Format::eUndefined, ComponentType::Invalid, false}, // ASTC_2D_4X4_SRGB - {vk::Format::eUndefined, ComponentType::Invalid, false}, // ASTC_2D_8X8_SRGB - {vk::Format::eUndefined, ComponentType::Invalid, false}, // ASTC_2D_8X5_SRGB - {vk::Format::eUndefined, ComponentType::Invalid, false}, // ASTC_2D_5X4_SRGB - {vk::Format::eUndefined, ComponentType::Invalid, false}, // ASTC_2D_5X5 - {vk::Format::eUndefined, ComponentType::Invalid, false}, // ASTC_2D_5X5_SRGB - {vk::Format::eUndefined, ComponentType::Invalid, false}, // ASTC_2D_10X8 - {vk::Format::eUndefined, ComponentType::Invalid, false}, // ASTC_2D_10X8_SRGB + {vk::Format::eUndefined, false}, // BGRA8_SRGB + {vk::Format::eUndefined, false}, // DXT1_SRGB + {vk::Format::eUndefined, false}, // DXT23_SRGB + {vk::Format::eUndefined, false}, // DXT45_SRGB + {vk::Format::eUndefined, false}, // BC7U_SRGB + {vk::Format::eUndefined, false}, // ASTC_2D_4X4_SRGB + {vk::Format::eUndefined, false}, // ASTC_2D_8X8_SRGB + {vk::Format::eUndefined, false}, // ASTC_2D_8X5_SRGB + {vk::Format::eUndefined, false}, // ASTC_2D_5X4_SRGB + {vk::Format::eUndefined, false}, // ASTC_2D_5X5 + {vk::Format::eUndefined, false}, // ASTC_2D_5X5_SRGB + {vk::Format::eUndefined, false}, // ASTC_2D_10X8 + {vk::Format::eUndefined, false}, // ASTC_2D_10X8_SRGB // Depth formats - {vk::Format::eD32Sfloat, ComponentType::Float, true}, // Z32F - {vk::Format::eD16Unorm, ComponentType::UNorm, true}, // Z16 + {vk::Format::eD32Sfloat, true}, // Z32F + {vk::Format::eD16Unorm, true}, // Z16 // DepthStencil formats - {vk::Format::eD24UnormS8Uint, ComponentType::UNorm, true}, // Z24S8 - {vk::Format::eD24UnormS8Uint, ComponentType::UNorm, true}, // S8Z24 (emulated) - {vk::Format::eUndefined, ComponentType::Invalid, false}, // Z32FS8 + {vk::Format::eD24UnormS8Uint, true}, // Z24S8 + {vk::Format::eD24UnormS8Uint, true}, // S8Z24 (emulated) + {vk::Format::eUndefined, false}, // Z32FS8 }}; static constexpr bool IsZetaFormat(PixelFormat pixel_format) { @@ -180,14 +181,13 @@ static constexpr bool IsZetaFormat(PixelFormat pixel_format) { } std::pair<vk::Format, bool> SurfaceFormat(const VKDevice& device, FormatType format_type, - PixelFormat pixel_format, ComponentType component_type) { + PixelFormat pixel_format) { ASSERT(static_cast<std::size_t>(pixel_format) < tex_format_tuples.size()); const auto tuple = tex_format_tuples[static_cast<u32>(pixel_format)]; UNIMPLEMENTED_IF_MSG(tuple.format == vk::Format::eUndefined, - "Unimplemented texture format with pixel format={} and component type={}", - static_cast<u32>(pixel_format), static_cast<u32>(component_type)); - ASSERT_MSG(component_type == tuple.component_type, "Component type mismatch"); + "Unimplemented texture format with pixel format={}", + static_cast<u32>(pixel_format)); auto usage = vk::FormatFeatureFlagBits::eSampledImage | vk::FormatFeatureFlagBits::eTransferDst | vk::FormatFeatureFlagBits::eTransferSrc; @@ -198,17 +198,17 @@ std::pair<vk::Format, bool> SurfaceFormat(const VKDevice& device, FormatType for return {device.GetSupportedFormat(tuple.format, usage, format_type), tuple.attachable}; } -vk::ShaderStageFlagBits ShaderStage(Maxwell::ShaderStage stage) { +vk::ShaderStageFlagBits ShaderStage(Tegra::Engines::ShaderType stage) { switch (stage) { - case Maxwell::ShaderStage::Vertex: + case Tegra::Engines::ShaderType::Vertex: return vk::ShaderStageFlagBits::eVertex; - case Maxwell::ShaderStage::TesselationControl: + case Tegra::Engines::ShaderType::TesselationControl: return vk::ShaderStageFlagBits::eTessellationControl; - case Maxwell::ShaderStage::TesselationEval: + case Tegra::Engines::ShaderType::TesselationEval: return vk::ShaderStageFlagBits::eTessellationEvaluation; - case Maxwell::ShaderStage::Geometry: + case Tegra::Engines::ShaderType::Geometry: return vk::ShaderStageFlagBits::eGeometry; - case Maxwell::ShaderStage::Fragment: + case Tegra::Engines::ShaderType::Fragment: return vk::ShaderStageFlagBits::eFragment; } UNIMPLEMENTED_MSG("Unimplemented shader stage={}", static_cast<u32>(stage)); diff --git a/src/video_core/renderer_vulkan/maxwell_to_vk.h b/src/video_core/renderer_vulkan/maxwell_to_vk.h index 4cadc0721..904a32e01 100644 --- a/src/video_core/renderer_vulkan/maxwell_to_vk.h +++ b/src/video_core/renderer_vulkan/maxwell_to_vk.h @@ -16,7 +16,6 @@ namespace Vulkan::MaxwellToVK { using Maxwell = Tegra::Engines::Maxwell3D::Regs; using PixelFormat = VideoCore::Surface::PixelFormat; -using ComponentType = VideoCore::Surface::ComponentType; namespace Sampler { @@ -31,9 +30,9 @@ vk::CompareOp DepthCompareFunction(Tegra::Texture::DepthCompareFunc depth_compar } // namespace Sampler std::pair<vk::Format, bool> SurfaceFormat(const VKDevice& device, FormatType format_type, - PixelFormat pixel_format, ComponentType component_type); + PixelFormat pixel_format); -vk::ShaderStageFlagBits ShaderStage(Maxwell::ShaderStage stage); +vk::ShaderStageFlagBits ShaderStage(Tegra::Engines::ShaderType stage); vk::PrimitiveTopology PrimitiveTopology(Maxwell::PrimitiveTopology topology); diff --git a/src/video_core/renderer_vulkan/vk_buffer_cache.cpp b/src/video_core/renderer_vulkan/vk_buffer_cache.cpp index d2e9f4031..46da81aaa 100644 --- a/src/video_core/renderer_vulkan/vk_buffer_cache.cpp +++ b/src/video_core/renderer_vulkan/vk_buffer_cache.cpp @@ -24,9 +24,11 @@ CachedBufferEntry::CachedBufferEntry(VAddr cpu_addr, std::size_t size, u64 offse alignment{alignment} {} VKBufferCache::VKBufferCache(Tegra::MemoryManager& tegra_memory_manager, + Memory::Memory& cpu_memory_, VideoCore::RasterizerInterface& rasterizer, const VKDevice& device, VKMemoryManager& memory_manager, VKScheduler& scheduler, u64 size) - : RasterizerCache{rasterizer}, tegra_memory_manager{tegra_memory_manager} { + : RasterizerCache{rasterizer}, tegra_memory_manager{tegra_memory_manager}, cpu_memory{ + cpu_memory_} { const auto usage = vk::BufferUsageFlagBits::eVertexBuffer | vk::BufferUsageFlagBits::eIndexBuffer | vk::BufferUsageFlagBits::eUniformBuffer; @@ -48,9 +50,9 @@ u64 VKBufferCache::UploadMemory(GPUVAddr gpu_addr, std::size_t size, u64 alignme // TODO: Figure out which size is the best for given games. cache &= size >= 2048; - const auto& host_ptr{Memory::GetPointer(*cpu_addr)}; + u8* const host_ptr{cpu_memory.GetPointer(*cpu_addr)}; if (cache) { - auto entry = TryGet(host_ptr); + const auto entry = TryGet(host_ptr); if (entry) { if (entry->GetSize() >= size && entry->GetAlignment() == alignment) { return entry->GetOffset(); @@ -62,7 +64,7 @@ u64 VKBufferCache::UploadMemory(GPUVAddr gpu_addr, std::size_t size, u64 alignme AlignBuffer(alignment); const u64 uploaded_offset = buffer_offset; - if (!host_ptr) { + if (host_ptr == nullptr) { return uploaded_offset; } diff --git a/src/video_core/renderer_vulkan/vk_buffer_cache.h b/src/video_core/renderer_vulkan/vk_buffer_cache.h index 49f13bcdc..daa8ccf66 100644 --- a/src/video_core/renderer_vulkan/vk_buffer_cache.h +++ b/src/video_core/renderer_vulkan/vk_buffer_cache.h @@ -13,6 +13,10 @@ #include "video_core/renderer_vulkan/declarations.h" #include "video_core/renderer_vulkan/vk_scheduler.h" +namespace Memory { +class Memory; +} + namespace Tegra { class MemoryManager; } @@ -58,7 +62,7 @@ private: class VKBufferCache final : public RasterizerCache<std::shared_ptr<CachedBufferEntry>> { public: - explicit VKBufferCache(Tegra::MemoryManager& tegra_memory_manager, + explicit VKBufferCache(Tegra::MemoryManager& tegra_memory_manager, Memory::Memory& cpu_memory_, VideoCore::RasterizerInterface& rasterizer, const VKDevice& device, VKMemoryManager& memory_manager, VKScheduler& scheduler, u64 size); ~VKBufferCache(); @@ -92,6 +96,7 @@ private: void AlignBuffer(std::size_t alignment); Tegra::MemoryManager& tegra_memory_manager; + Memory::Memory& cpu_memory; std::unique_ptr<VKStreamBuffer> stream_buffer; vk::Buffer buffer_handle; diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp index 30a525e5d..76894275b 100644 --- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp +++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp @@ -17,6 +17,7 @@ #include "video_core/engines/maxwell_3d.h" #include "video_core/engines/shader_bytecode.h" #include "video_core/engines/shader_header.h" +#include "video_core/engines/shader_type.h" #include "video_core/renderer_vulkan/vk_device.h" #include "video_core/renderer_vulkan/vk_shader_decompiler.h" #include "video_core/shader/node.h" @@ -25,13 +26,13 @@ namespace Vulkan::VKShader { using Sirit::Id; +using Tegra::Engines::ShaderType; using Tegra::Shader::Attribute; using Tegra::Shader::AttributeUse; using Tegra::Shader::Register; using namespace VideoCommon::Shader; using Maxwell = Tegra::Engines::Maxwell3D::Regs; -using ShaderStage = Tegra::Engines::Maxwell3D::Regs::ShaderStage; using Operation = const OperationNode&; // TODO(Rodrigo): Use rasterizer's value @@ -93,7 +94,7 @@ class ExprDecompiler; class SPIRVDecompiler : public Sirit::Module { public: - explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderStage stage) + explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderType stage) : Module(0x00010300), device{device}, ir{ir}, stage{stage}, header{ir.GetHeader()} { AddCapability(spv::Capability::Shader); AddExtension("SPV_KHR_storage_buffer_storage_class"); @@ -256,21 +257,21 @@ private: } void DeclareVertex() { - if (stage != ShaderStage::Vertex) + if (stage != ShaderType::Vertex) return; DeclareVertexRedeclarations(); } void DeclareGeometry() { - if (stage != ShaderStage::Geometry) + if (stage != ShaderType::Geometry) return; UNIMPLEMENTED(); } void DeclareFragment() { - if (stage != ShaderStage::Fragment) + if (stage != ShaderType::Fragment) return; for (u32 rt = 0; rt < static_cast<u32>(frag_colors.size()); ++rt) { @@ -354,7 +355,7 @@ private: continue; } - UNIMPLEMENTED_IF(stage == ShaderStage::Geometry); + UNIMPLEMENTED_IF(stage == ShaderType::Geometry); const u32 location = GetGenericAttributeLocation(index); const Id id = OpVariable(t_in_float4, spv::StorageClass::Input); @@ -364,7 +365,7 @@ private: Decorate(id, spv::Decoration::Location, location); - if (stage != ShaderStage::Fragment) { + if (stage != ShaderType::Fragment) { continue; } switch (header.ps.GetAttributeUse(location)) { @@ -548,7 +549,7 @@ private: switch (attribute) { case Attribute::Index::Position: - if (stage != ShaderStage::Fragment) { + if (stage != ShaderType::Fragment) { UNIMPLEMENTED(); break; } else { @@ -561,7 +562,7 @@ private: // TODO(Subv): Find out what the values are for the first two elements when inside a // vertex shader, and what's the value of the fourth element when inside a Tess Eval // shader. - ASSERT(stage == ShaderStage::Vertex); + ASSERT(stage == ShaderType::Vertex); switch (element) { case 2: return BitcastFrom<Type::Uint>(Emit(OpLoad(t_uint, instance_index))); @@ -572,7 +573,7 @@ private: return Constant(t_float, 0); case Attribute::Index::FrontFacing: // TODO(Subv): Find out what the values are for the other elements. - ASSERT(stage == ShaderStage::Fragment); + ASSERT(stage == ShaderType::Fragment); if (element == 3) { const Id is_front_facing = Emit(OpLoad(t_bool, front_facing)); const Id true_value = @@ -1080,7 +1081,7 @@ private: Id PreExit() { switch (stage) { - case ShaderStage::Vertex: { + case ShaderType::Vertex: { // TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't // seem to be working on Nvidia's drivers and Intel (mesa and blob) doesn't support it. const Id z_pointer = AccessElement(t_out_float, per_vertex, position_index, 2u); @@ -1090,7 +1091,7 @@ private: Emit(OpStore(z_pointer, depth)); break; } - case ShaderStage::Fragment: { + case ShaderType::Fragment: { const auto SafeGetRegister = [&](u32 reg) { // TODO(Rodrigo): Replace with contains once C++20 releases if (const auto it = registers.find(reg); it != registers.end()) { @@ -1519,7 +1520,7 @@ private: const VKDevice& device; const ShaderIR& ir; - const ShaderStage stage; + const ShaderType stage; const Tegra::Shader::Header header; u64 conditional_nest_count{}; u64 inside_branch{}; @@ -1851,7 +1852,7 @@ void SPIRVDecompiler::DecompileAST() { } DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir, - Maxwell::ShaderStage stage) { + ShaderType stage) { auto decompiler = std::make_unique<SPIRVDecompiler>(device, ir, stage); decompiler->Decompile(); return {std::move(decompiler), decompiler->GetShaderEntries()}; diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.h b/src/video_core/renderer_vulkan/vk_shader_decompiler.h index f90541cc1..203fc00d0 100644 --- a/src/video_core/renderer_vulkan/vk_shader_decompiler.h +++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.h @@ -79,6 +79,6 @@ struct ShaderEntries { using DecompilerResult = std::pair<std::unique_ptr<Sirit::Module>, ShaderEntries>; DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir, - Maxwell::ShaderStage stage); + Tegra::Engines::ShaderType stage); } // namespace Vulkan::VKShader diff --git a/src/video_core/shader/const_buffer_locker.cpp b/src/video_core/shader/const_buffer_locker.cpp index fe467608e..a4a0319eb 100644 --- a/src/video_core/shader/const_buffer_locker.cpp +++ b/src/video_core/shader/const_buffer_locker.cpp @@ -2,13 +2,12 @@ // Licensed under GPLv2 or any later version // Refer to the license.txt file included. -#pragma once - #include <algorithm> -#include <memory> -#include "common/assert.h" +#include <tuple> + #include "common/common_types.h" #include "video_core/engines/maxwell_3d.h" +#include "video_core/engines/shader_type.h" #include "video_core/shader/const_buffer_locker.h" namespace VideoCommon::Shader { @@ -103,8 +102,8 @@ bool ConstBufferLocker::IsConsistent() const { } bool ConstBufferLocker::HasEqualKeys(const ConstBufferLocker& rhs) const { - return keys == rhs.keys && bound_samplers == rhs.bound_samplers && - bindless_samplers == rhs.bindless_samplers; + return std::tie(keys, bound_samplers, bindless_samplers) == + std::tie(rhs.keys, rhs.bound_samplers, rhs.bindless_samplers); } } // namespace VideoCommon::Shader diff --git a/src/video_core/shader/const_buffer_locker.h b/src/video_core/shader/const_buffer_locker.h index 600e2f3c3..d32e2d657 100644 --- a/src/video_core/shader/const_buffer_locker.h +++ b/src/video_core/shader/const_buffer_locker.h @@ -4,10 +4,12 @@ #pragma once +#include <optional> #include <unordered_map> #include "common/common_types.h" #include "common/hash.h" #include "video_core/engines/const_buffer_engine_interface.h" +#include "video_core/engines/shader_type.h" namespace VideoCommon::Shader { @@ -20,7 +22,7 @@ using BindlessSamplerMap = * The ConstBufferLocker is a class use to interface the 3D and compute engines with the shader * compiler. with it, the shader can obtain required data from GPU state and store it for disk * shader compilation. - **/ + */ class ConstBufferLocker { public: explicit ConstBufferLocker(Tegra::Engines::ShaderType shader_stage); diff --git a/src/video_core/shader/decode/other.cpp b/src/video_core/shader/decode/other.cpp index 116b95f76..17cd45d3c 100644 --- a/src/video_core/shader/decode/other.cpp +++ b/src/video_core/shader/decode/other.cpp @@ -256,7 +256,7 @@ u32 ShaderIR::DecodeOther(NodeBlock& bb, u32 pc) { break; } case OpCode::Id::DEPBAR: { - LOG_WARNING(HW_GPU, "DEPBAR instruction is stubbed"); + LOG_DEBUG(HW_GPU, "DEPBAR instruction is stubbed"); break; } default: diff --git a/src/video_core/shader/decode/texture.cpp b/src/video_core/shader/decode/texture.cpp index 9afba2495..da8e886df 100644 --- a/src/video_core/shader/decode/texture.cpp +++ b/src/video_core/shader/decode/texture.cpp @@ -128,8 +128,8 @@ u32 ShaderIR::DecodeTexture(NodeBlock& bb, u32 pc) { } const Node component = Immediate(static_cast<u32>(instr.tld4s.component)); - const auto& sampler = - GetSampler(instr.sampler, {{TextureType::Texture2D, false, depth_compare}}); + const SamplerInfo info{TextureType::Texture2D, false, depth_compare}; + const auto& sampler = GetSampler(instr.sampler, info); Node4 values; for (u32 element = 0; element < values.size(); ++element) { @@ -188,7 +188,7 @@ u32 ShaderIR::DecodeTexture(NodeBlock& bb, u32 pc) { // Sadly, not all texture instructions specify the type of texture their sampler // uses. This must be fixed at a later instance. const auto& sampler = - is_bindless ? GetBindlessSampler(instr.gpr8, {}) : GetSampler(instr.sampler, {}); + is_bindless ? GetBindlessSampler(instr.gpr8) : GetSampler(instr.sampler); u32 indexer = 0; switch (instr.txq.query_type) { @@ -224,8 +224,7 @@ u32 ShaderIR::DecodeTexture(NodeBlock& bb, u32 pc) { auto texture_type = instr.tmml.texture_type.Value(); const bool is_array = instr.tmml.array != 0; const auto& sampler = - is_bindless ? GetBindlessSampler(instr.gpr20, {{texture_type, is_array, false}}) - : GetSampler(instr.sampler, {{texture_type, is_array, false}}); + is_bindless ? GetBindlessSampler(instr.gpr20) : GetSampler(instr.sampler); std::vector<Node> coords; @@ -293,67 +292,50 @@ u32 ShaderIR::DecodeTexture(NodeBlock& bb, u32 pc) { return pc; } -const Sampler& ShaderIR::GetSampler(const Tegra::Shader::Sampler& sampler, - std::optional<SamplerInfo> sampler_info) { - const auto offset = static_cast<u32>(sampler.index.Value()); - - TextureType type; - bool is_array; - bool is_shadow; +ShaderIR::SamplerInfo ShaderIR::GetSamplerInfo(std::optional<SamplerInfo> sampler_info, u32 offset, + std::optional<u32> buffer) { if (sampler_info) { - type = sampler_info->type; - is_array = sampler_info->is_array; - is_shadow = sampler_info->is_shadow; - } else if (const auto sampler = locker.ObtainBoundSampler(offset)) { - type = sampler->texture_type.Value(); - is_array = sampler->is_array.Value() != 0; - is_shadow = sampler->is_shadow.Value() != 0; - } else { + return *sampler_info; + } + const auto sampler = + buffer ? locker.ObtainBindlessSampler(*buffer, offset) : locker.ObtainBoundSampler(offset); + if (!sampler) { LOG_WARNING(HW_GPU, "Unknown sampler info"); - type = TextureType::Texture2D; - is_array = false; - is_shadow = false; + return SamplerInfo{TextureType::Texture2D, false, false, false}; } + return SamplerInfo{sampler->texture_type, sampler->is_array != 0, sampler->is_shadow != 0, + sampler->is_buffer != 0}; +} + +const Sampler& ShaderIR::GetSampler(const Tegra::Shader::Sampler& sampler, + std::optional<SamplerInfo> sampler_info) { + const auto offset = static_cast<u32>(sampler.index.Value()); + const auto info = GetSamplerInfo(sampler_info, offset); // If this sampler has already been used, return the existing mapping. const auto it = std::find_if(used_samplers.begin(), used_samplers.end(), [offset](const Sampler& entry) { return entry.GetOffset() == offset; }); if (it != used_samplers.end()) { - ASSERT(!it->IsBindless() && it->GetType() == type && it->IsArray() == is_array && - it->IsShadow() == is_shadow); + ASSERT(!it->IsBindless() && it->GetType() == info.type && it->IsArray() == info.is_array && + it->IsShadow() == info.is_shadow && it->IsBuffer() == info.is_buffer); return *it; } // Otherwise create a new mapping for this sampler const auto next_index = static_cast<u32>(used_samplers.size()); - return used_samplers.emplace_back(Sampler(next_index, offset, type, is_array, is_shadow)); + return used_samplers.emplace_back(next_index, offset, info.type, info.is_array, info.is_shadow, + info.is_buffer); } -const Sampler& ShaderIR::GetBindlessSampler(const Tegra::Shader::Register& reg, +const Sampler& ShaderIR::GetBindlessSampler(Tegra::Shader::Register reg, std::optional<SamplerInfo> sampler_info) { const Node sampler_register = GetRegister(reg); const auto [base_sampler, buffer, offset] = TrackCbuf(sampler_register, global_code, static_cast<s64>(global_code.size())); ASSERT(base_sampler != nullptr); - TextureType type; - bool is_array; - bool is_shadow; - if (sampler_info) { - type = sampler_info->type; - is_array = sampler_info->is_array; - is_shadow = sampler_info->is_shadow; - } else if (const auto sampler = locker.ObtainBindlessSampler(buffer, offset)) { - type = sampler->texture_type.Value(); - is_array = sampler->is_array.Value() != 0; - is_shadow = sampler->is_shadow.Value() != 0; - } else { - LOG_WARNING(HW_GPU, "Unknown sampler info"); - type = TextureType::Texture2D; - is_array = false; - is_shadow = false; - } + const auto info = GetSamplerInfo(sampler_info, offset, buffer); // If this sampler has already been used, return the existing mapping. const auto it = @@ -362,15 +344,15 @@ const Sampler& ShaderIR::GetBindlessSampler(const Tegra::Shader::Register& reg, return entry.GetBuffer() == buffer && entry.GetOffset() == offset; }); if (it != used_samplers.end()) { - ASSERT(it->IsBindless() && it->GetType() == type && it->IsArray() == is_array && - it->IsShadow() == is_shadow); + ASSERT(it->IsBindless() && it->GetType() == info.type && it->IsArray() == info.is_array && + it->IsShadow() == info.is_shadow); return *it; } // Otherwise create a new mapping for this sampler const auto next_index = static_cast<u32>(used_samplers.size()); - return used_samplers.emplace_back( - Sampler(next_index, offset, buffer, type, is_array, is_shadow)); + return used_samplers.emplace_back(next_index, offset, buffer, info.type, info.is_array, + info.is_shadow, info.is_buffer); } void ShaderIR::WriteTexInstructionFloat(NodeBlock& bb, Instruction instr, const Node4& components) { @@ -455,17 +437,16 @@ Node4 ShaderIR::GetTextureCode(Instruction instr, TextureType texture_type, (texture_type == TextureType::TextureCube && is_array && is_shadow), "This method is not supported."); + const SamplerInfo info{texture_type, is_array, is_shadow, false}; const auto& sampler = - is_bindless ? GetBindlessSampler(*bindless_reg, {{texture_type, is_array, is_shadow}}) - : GetSampler(instr.sampler, {{texture_type, is_array, is_shadow}}); + is_bindless ? GetBindlessSampler(*bindless_reg, info) : GetSampler(instr.sampler, info); const bool lod_needed = process_mode == TextureProcessMode::LZ || process_mode == TextureProcessMode::LL || process_mode == TextureProcessMode::LLA; - // LOD selection (either via bias or explicit textureLod) not - // supported in GL for sampler2DArrayShadow and - // samplerCubeArrayShadow. + // LOD selection (either via bias or explicit textureLod) not supported in GL for + // sampler2DArrayShadow and samplerCubeArrayShadow. const bool gl_lod_supported = !((texture_type == Tegra::Shader::TextureType::Texture2D && is_array && is_shadow) || (texture_type == Tegra::Shader::TextureType::TextureCube && is_array && is_shadow)); @@ -475,8 +456,8 @@ Node4 ShaderIR::GetTextureCode(Instruction instr, TextureType texture_type, UNIMPLEMENTED_IF(process_mode != TextureProcessMode::None && !gl_lod_supported); - Node bias = {}; - Node lod = {}; + Node bias; + Node lod; if (process_mode != TextureProcessMode::None && gl_lod_supported) { switch (process_mode) { case TextureProcessMode::LZ: @@ -612,10 +593,9 @@ Node4 ShaderIR::GetTld4Code(Instruction instr, TextureType texture_type, bool de u64 parameter_register = instr.gpr20.Value(); - const auto& sampler = - is_bindless - ? GetBindlessSampler(parameter_register++, {{texture_type, is_array, depth_compare}}) - : GetSampler(instr.sampler, {{texture_type, is_array, depth_compare}}); + const SamplerInfo info{texture_type, is_array, depth_compare, false}; + const auto& sampler = is_bindless ? GetBindlessSampler(parameter_register++, info) + : GetSampler(instr.sampler, info); std::vector<Node> aoffi; if (is_aoffi) { @@ -662,7 +642,7 @@ Node4 ShaderIR::GetTldCode(Tegra::Shader::Instruction instr) { // const Node aoffi_register{is_aoffi ? GetRegister(gpr20_cursor++) : nullptr}; // const Node multisample{is_multisample ? GetRegister(gpr20_cursor++) : nullptr}; - const auto& sampler = GetSampler(instr.sampler, {{texture_type, is_array, false}}); + const auto& sampler = GetSampler(instr.sampler); Node4 values; for (u32 element = 0; element < values.size(); ++element) { @@ -675,6 +655,8 @@ Node4 ShaderIR::GetTldCode(Tegra::Shader::Instruction instr) { } Node4 ShaderIR::GetTldsCode(Instruction instr, TextureType texture_type, bool is_array) { + const auto& sampler = GetSampler(instr.sampler); + const std::size_t type_coord_count = GetCoordCount(texture_type); const bool lod_enabled = instr.tlds.GetTextureProcessMode() == TextureProcessMode::LL; @@ -698,7 +680,14 @@ Node4 ShaderIR::GetTldsCode(Instruction instr, TextureType texture_type, bool is // When lod is used always is in gpr20 const Node lod = lod_enabled ? GetRegister(instr.gpr20) : Immediate(0); - const auto& sampler = GetSampler(instr.sampler, {{texture_type, is_array, false}}); + // Fill empty entries from the guest sampler. + const std::size_t entry_coord_count = GetCoordCount(sampler.GetType()); + if (type_coord_count != entry_coord_count) { + LOG_WARNING(HW_GPU, "Bound and built texture types mismatch"); + } + for (std::size_t i = type_coord_count; i < entry_coord_count; ++i) { + coords.push_back(GetRegister(Register::ZeroIndex)); + } Node4 values; for (u32 element = 0; element < values.size(); ++element) { diff --git a/src/video_core/shader/node.h b/src/video_core/shader/node.h index 6c5046d3b..b2576bdd6 100644 --- a/src/video_core/shader/node.h +++ b/src/video_core/shader/node.h @@ -228,14 +228,15 @@ class Sampler { public: /// This constructor is for bound samplers constexpr explicit Sampler(u32 index, u32 offset, Tegra::Shader::TextureType type, - bool is_array, bool is_shadow) - : index{index}, offset{offset}, type{type}, is_array{is_array}, is_shadow{is_shadow} {} + bool is_array, bool is_shadow, bool is_buffer) + : index{index}, offset{offset}, type{type}, is_array{is_array}, is_shadow{is_shadow}, + is_buffer{is_buffer} {} /// This constructor is for bindless samplers constexpr explicit Sampler(u32 index, u32 offset, u32 buffer, Tegra::Shader::TextureType type, - bool is_array, bool is_shadow) + bool is_array, bool is_shadow, bool is_buffer) : index{index}, offset{offset}, buffer{buffer}, type{type}, is_array{is_array}, - is_shadow{is_shadow}, is_bindless{true} {} + is_shadow{is_shadow}, is_buffer{is_buffer}, is_bindless{true} {} constexpr u32 GetIndex() const { return index; @@ -261,6 +262,10 @@ public: return is_shadow; } + constexpr bool IsBuffer() const { + return is_buffer; + } + constexpr bool IsBindless() const { return is_bindless; } @@ -273,6 +278,7 @@ private: Tegra::Shader::TextureType type{}; ///< The type used to sample this texture (Texture2D, etc) bool is_array{}; ///< Whether the texture is being sampled as an array texture or not. bool is_shadow{}; ///< Whether the texture is being sampled as a depth texture or not. + bool is_buffer{}; ///< Whether the texture is a texture buffer without sampler. bool is_bindless{}; ///< Whether this sampler belongs to a bindless texture or not. }; diff --git a/src/video_core/shader/shader_ir.h b/src/video_core/shader/shader_ir.h index 76a849818..2f71a50d2 100644 --- a/src/video_core/shader/shader_ir.h +++ b/src/video_core/shader/shader_ir.h @@ -179,6 +179,7 @@ private: Tegra::Shader::TextureType type; bool is_array; bool is_shadow; + bool is_buffer; }; void Decode(); @@ -303,13 +304,17 @@ private: /// Returns a predicate combiner operation OperationCode GetPredicateCombiner(Tegra::Shader::PredOperation operation); + /// Queries the missing sampler info from the execution context. + SamplerInfo GetSamplerInfo(std::optional<SamplerInfo> sampler_info, u32 offset, + std::optional<u32> buffer = std::nullopt); + /// Accesses a texture sampler const Sampler& GetSampler(const Tegra::Shader::Sampler& sampler, - std::optional<SamplerInfo> sampler_info); + std::optional<SamplerInfo> sampler_info = std::nullopt); - // Accesses a texture sampler for a bindless texture. - const Sampler& GetBindlessSampler(const Tegra::Shader::Register& reg, - std::optional<SamplerInfo> sampler_info); + /// Accesses a texture sampler for a bindless texture. + const Sampler& GetBindlessSampler(Tegra::Shader::Register reg, + std::optional<SamplerInfo> sampler_info = std::nullopt); /// Accesses an image. Image& GetImage(Tegra::Shader::Image image, Tegra::Shader::ImageType type); diff --git a/src/video_core/surface.cpp b/src/video_core/surface.cpp index 621136b6e..1655ccf16 100644 --- a/src/video_core/surface.cpp +++ b/src/video_core/surface.cpp @@ -168,309 +168,6 @@ PixelFormat PixelFormatFromRenderTargetFormat(Tegra::RenderTargetFormat format) } } -PixelFormat PixelFormatFromTextureFormat(Tegra::Texture::TextureFormat format, - Tegra::Texture::ComponentType component_type, - bool is_srgb) { - // TODO(Subv): Properly implement this - switch (format) { - case Tegra::Texture::TextureFormat::A8R8G8B8: - if (is_srgb) { - return PixelFormat::RGBA8_SRGB; - } - switch (component_type) { - case Tegra::Texture::ComponentType::UNORM: - return PixelFormat::ABGR8U; - case Tegra::Texture::ComponentType::SNORM: - return PixelFormat::ABGR8S; - case Tegra::Texture::ComponentType::UINT: - return PixelFormat::ABGR8UI; - default: - break; - } - break; - case Tegra::Texture::TextureFormat::B5G6R5: - switch (component_type) { - case Tegra::Texture::ComponentType::UNORM: - return PixelFormat::B5G6R5U; - default: - break; - } - break; - case Tegra::Texture::TextureFormat::A2B10G10R10: - switch (component_type) { - case Tegra::Texture::ComponentType::UNORM: - return PixelFormat::A2B10G10R10U; - default: - break; - } - break; - case Tegra::Texture::TextureFormat::A1B5G5R5: - switch (component_type) { - case Tegra::Texture::ComponentType::UNORM: - return PixelFormat::A1B5G5R5U; - default: - break; - } - break; - case Tegra::Texture::TextureFormat::A4B4G4R4: - switch (component_type) { - case Tegra::Texture::ComponentType::UNORM: - return PixelFormat::R4G4B4A4U; - default: - break; - } - break; - case Tegra::Texture::TextureFormat::R8: - switch (component_type) { - case Tegra::Texture::ComponentType::UNORM: - return PixelFormat::R8U; - case Tegra::Texture::ComponentType::UINT: - return PixelFormat::R8UI; - default: - break; - } - break; - case Tegra::Texture::TextureFormat::G8R8: - // TextureFormat::G8R8 is actually ordered red then green, as such we can use - // PixelFormat::RG8U and PixelFormat::RG8S. This was tested with The Legend of Zelda: Breath - // of the Wild, which uses this format to render the hearts on the UI. - switch (component_type) { - case Tegra::Texture::ComponentType::UNORM: - return PixelFormat::RG8U; - case Tegra::Texture::ComponentType::SNORM: - return PixelFormat::RG8S; - default: - break; - } - break; - case Tegra::Texture::TextureFormat::R16_G16_B16_A16: - switch (component_type) { - case Tegra::Texture::ComponentType::UNORM: - return PixelFormat::RGBA16U; - case Tegra::Texture::ComponentType::FLOAT: - return PixelFormat::RGBA16F; - default: - break; - } - break; - case Tegra::Texture::TextureFormat::BF10GF11RF11: - switch (component_type) { - case Tegra::Texture::ComponentType::FLOAT: - return PixelFormat::R11FG11FB10F; - default: - break; - } - break; - case Tegra::Texture::TextureFormat::R32_G32_B32_A32: - switch (component_type) { - case Tegra::Texture::ComponentType::FLOAT: - return PixelFormat::RGBA32F; - case Tegra::Texture::ComponentType::UINT: - return PixelFormat::RGBA32UI; - default: - break; - } - break; - case Tegra::Texture::TextureFormat::R32_G32: - switch (component_type) { - case Tegra::Texture::ComponentType::FLOAT: - return PixelFormat::RG32F; - case Tegra::Texture::ComponentType::UINT: - return PixelFormat::RG32UI; - default: - break; - } - break; - case Tegra::Texture::TextureFormat::R32_G32_B32: - switch (component_type) { - case Tegra::Texture::ComponentType::FLOAT: - return PixelFormat::RGB32F; - default: - break; - } - break; - case Tegra::Texture::TextureFormat::R16: - switch (component_type) { - case Tegra::Texture::ComponentType::FLOAT: - return PixelFormat::R16F; - case Tegra::Texture::ComponentType::UNORM: - return PixelFormat::R16U; - case Tegra::Texture::ComponentType::SNORM: - return PixelFormat::R16S; - case Tegra::Texture::ComponentType::UINT: - return PixelFormat::R16UI; - case Tegra::Texture::ComponentType::SINT: - return PixelFormat::R16I; - default: - break; - } - break; - case Tegra::Texture::TextureFormat::R32: - switch (component_type) { - case Tegra::Texture::ComponentType::FLOAT: - return PixelFormat::R32F; - case Tegra::Texture::ComponentType::UINT: - return PixelFormat::R32UI; - default: - break; - } - break; - case Tegra::Texture::TextureFormat::E5B9G9R9_SHAREDEXP: - switch (component_type) { - case Tegra::Texture::ComponentType::FLOAT: - return PixelFormat::E5B9G9R9F; - default: - break; - } - break; - case Tegra::Texture::TextureFormat::ZF32: - return PixelFormat::Z32F; - case Tegra::Texture::TextureFormat::Z16: - return PixelFormat::Z16; - case Tegra::Texture::TextureFormat::S8Z24: - return PixelFormat::S8Z24; - case Tegra::Texture::TextureFormat::ZF32_X24S8: - return PixelFormat::Z32FS8; - case Tegra::Texture::TextureFormat::DXT1: - return is_srgb ? PixelFormat::DXT1_SRGB : PixelFormat::DXT1; - case Tegra::Texture::TextureFormat::DXT23: - return is_srgb ? PixelFormat::DXT23_SRGB : PixelFormat::DXT23; - case Tegra::Texture::TextureFormat::DXT45: - return is_srgb ? PixelFormat::DXT45_SRGB : PixelFormat::DXT45; - case Tegra::Texture::TextureFormat::DXN1: - return PixelFormat::DXN1; - case Tegra::Texture::TextureFormat::DXN2: - switch (component_type) { - case Tegra::Texture::ComponentType::UNORM: - return PixelFormat::DXN2UNORM; - case Tegra::Texture::ComponentType::SNORM: - return PixelFormat::DXN2SNORM; - default: - break; - } - break; - case Tegra::Texture::TextureFormat::BC7U: - return is_srgb ? PixelFormat::BC7U_SRGB : PixelFormat::BC7U; - case Tegra::Texture::TextureFormat::BC6H_UF16: - return PixelFormat::BC6H_UF16; - case Tegra::Texture::TextureFormat::BC6H_SF16: - return PixelFormat::BC6H_SF16; - case Tegra::Texture::TextureFormat::ASTC_2D_4X4: - return is_srgb ? PixelFormat::ASTC_2D_4X4_SRGB : PixelFormat::ASTC_2D_4X4; - case Tegra::Texture::TextureFormat::ASTC_2D_5X4: - return is_srgb ? PixelFormat::ASTC_2D_5X4_SRGB : PixelFormat::ASTC_2D_5X4; - case Tegra::Texture::TextureFormat::ASTC_2D_5X5: - return is_srgb ? PixelFormat::ASTC_2D_5X5_SRGB : PixelFormat::ASTC_2D_5X5; - case Tegra::Texture::TextureFormat::ASTC_2D_8X8: - return is_srgb ? PixelFormat::ASTC_2D_8X8_SRGB : PixelFormat::ASTC_2D_8X8; - case Tegra::Texture::TextureFormat::ASTC_2D_8X5: - return is_srgb ? PixelFormat::ASTC_2D_8X5_SRGB : PixelFormat::ASTC_2D_8X5; - case Tegra::Texture::TextureFormat::ASTC_2D_10X8: - return is_srgb ? PixelFormat::ASTC_2D_10X8_SRGB : PixelFormat::ASTC_2D_10X8; - case Tegra::Texture::TextureFormat::ASTC_2D_6X6: - return is_srgb ? PixelFormat::ASTC_2D_6X6_SRGB : PixelFormat::ASTC_2D_6X6; - case Tegra::Texture::TextureFormat::ASTC_2D_10X10: - return is_srgb ? PixelFormat::ASTC_2D_10X10_SRGB : PixelFormat::ASTC_2D_10X10; - case Tegra::Texture::TextureFormat::ASTC_2D_12X12: - return is_srgb ? PixelFormat::ASTC_2D_12X12_SRGB : PixelFormat::ASTC_2D_12X12; - case Tegra::Texture::TextureFormat::ASTC_2D_8X6: - return is_srgb ? PixelFormat::ASTC_2D_8X6_SRGB : PixelFormat::ASTC_2D_8X6; - case Tegra::Texture::TextureFormat::ASTC_2D_6X5: - return is_srgb ? PixelFormat::ASTC_2D_6X5_SRGB : PixelFormat::ASTC_2D_6X5; - case Tegra::Texture::TextureFormat::R16_G16: - switch (component_type) { - case Tegra::Texture::ComponentType::FLOAT: - return PixelFormat::RG16F; - case Tegra::Texture::ComponentType::UNORM: - return PixelFormat::RG16; - case Tegra::Texture::ComponentType::SNORM: - return PixelFormat::RG16S; - case Tegra::Texture::ComponentType::UINT: - return PixelFormat::RG16UI; - case Tegra::Texture::ComponentType::SINT: - return PixelFormat::RG16I; - default: - break; - } - break; - default: - break; - } - LOG_CRITICAL(HW_GPU, "Unimplemented format={}, component_type={}", static_cast<u32>(format), - static_cast<u32>(component_type)); - UNREACHABLE(); - return PixelFormat::ABGR8U; -} - -ComponentType ComponentTypeFromTexture(Tegra::Texture::ComponentType type) { - // TODO(Subv): Implement more component types - switch (type) { - case Tegra::Texture::ComponentType::UNORM: - return ComponentType::UNorm; - case Tegra::Texture::ComponentType::FLOAT: - return ComponentType::Float; - case Tegra::Texture::ComponentType::SNORM: - return ComponentType::SNorm; - case Tegra::Texture::ComponentType::UINT: - return ComponentType::UInt; - case Tegra::Texture::ComponentType::SINT: - return ComponentType::SInt; - default: - LOG_CRITICAL(HW_GPU, "Unimplemented component type={}", static_cast<u32>(type)); - UNREACHABLE(); - return ComponentType::UNorm; - } -} - -ComponentType ComponentTypeFromRenderTarget(Tegra::RenderTargetFormat format) { - // TODO(Subv): Implement more render targets - switch (format) { - case Tegra::RenderTargetFormat::RGBA8_UNORM: - case Tegra::RenderTargetFormat::RGBA8_SRGB: - case Tegra::RenderTargetFormat::BGRA8_UNORM: - case Tegra::RenderTargetFormat::BGRA8_SRGB: - case Tegra::RenderTargetFormat::RGB10_A2_UNORM: - case Tegra::RenderTargetFormat::R8_UNORM: - case Tegra::RenderTargetFormat::RG16_UNORM: - case Tegra::RenderTargetFormat::R16_UNORM: - case Tegra::RenderTargetFormat::B5G6R5_UNORM: - case Tegra::RenderTargetFormat::BGR5A1_UNORM: - case Tegra::RenderTargetFormat::RG8_UNORM: - case Tegra::RenderTargetFormat::RGBA16_UNORM: - return ComponentType::UNorm; - case Tegra::RenderTargetFormat::RGBA8_SNORM: - case Tegra::RenderTargetFormat::RG16_SNORM: - case Tegra::RenderTargetFormat::R16_SNORM: - case Tegra::RenderTargetFormat::RG8_SNORM: - return ComponentType::SNorm; - case Tegra::RenderTargetFormat::RGBA16_FLOAT: - case Tegra::RenderTargetFormat::RGBX16_FLOAT: - case Tegra::RenderTargetFormat::R11G11B10_FLOAT: - case Tegra::RenderTargetFormat::RGBA32_FLOAT: - case Tegra::RenderTargetFormat::RG32_FLOAT: - case Tegra::RenderTargetFormat::RG16_FLOAT: - case Tegra::RenderTargetFormat::R16_FLOAT: - case Tegra::RenderTargetFormat::R32_FLOAT: - return ComponentType::Float; - case Tegra::RenderTargetFormat::RGBA32_UINT: - case Tegra::RenderTargetFormat::RGBA16_UINT: - case Tegra::RenderTargetFormat::RG16_UINT: - case Tegra::RenderTargetFormat::R8_UINT: - case Tegra::RenderTargetFormat::R16_UINT: - case Tegra::RenderTargetFormat::RG32_UINT: - case Tegra::RenderTargetFormat::R32_UINT: - case Tegra::RenderTargetFormat::RGBA8_UINT: - return ComponentType::UInt; - case Tegra::RenderTargetFormat::RG16_SINT: - case Tegra::RenderTargetFormat::R16_SINT: - return ComponentType::SInt; - default: - LOG_CRITICAL(HW_GPU, "Unimplemented format={}", static_cast<u32>(format)); - UNREACHABLE(); - return ComponentType::UNorm; - } -} - PixelFormat PixelFormatFromGPUPixelFormat(Tegra::FramebufferConfig::PixelFormat format) { switch (format) { case Tegra::FramebufferConfig::PixelFormat::ABGR8: @@ -485,22 +182,6 @@ PixelFormat PixelFormatFromGPUPixelFormat(Tegra::FramebufferConfig::PixelFormat } } -ComponentType ComponentTypeFromDepthFormat(Tegra::DepthFormat format) { - switch (format) { - case Tegra::DepthFormat::Z16_UNORM: - case Tegra::DepthFormat::S8_Z24_UNORM: - case Tegra::DepthFormat::Z24_S8_UNORM: - return ComponentType::UNorm; - case Tegra::DepthFormat::Z32_FLOAT: - case Tegra::DepthFormat::Z32_S8_X24_FLOAT: - return ComponentType::Float; - default: - LOG_CRITICAL(HW_GPU, "Unimplemented format={}", static_cast<u32>(format)); - UNREACHABLE(); - return ComponentType::UNorm; - } -} - SurfaceType GetFormatType(PixelFormat pixel_format) { if (static_cast<std::size_t>(pixel_format) < static_cast<std::size_t>(PixelFormat::MaxColorFormat)) { diff --git a/src/video_core/surface.h b/src/video_core/surface.h index d3bcd38c5..0d17a93ed 100644 --- a/src/video_core/surface.h +++ b/src/video_core/surface.h @@ -106,18 +106,8 @@ enum class PixelFormat { Max = MaxDepthStencilFormat, Invalid = 255, }; - static constexpr std::size_t MaxPixelFormat = static_cast<std::size_t>(PixelFormat::Max); -enum class ComponentType { - Invalid = 0, - SNorm = 1, - UNorm = 2, - SInt = 3, - UInt = 4, - Float = 5, -}; - enum class SurfaceType { ColorTexture = 0, Depth = 1, @@ -609,18 +599,8 @@ PixelFormat PixelFormatFromDepthFormat(Tegra::DepthFormat format); PixelFormat PixelFormatFromRenderTargetFormat(Tegra::RenderTargetFormat format); -PixelFormat PixelFormatFromTextureFormat(Tegra::Texture::TextureFormat format, - Tegra::Texture::ComponentType component_type, - bool is_srgb); - -ComponentType ComponentTypeFromTexture(Tegra::Texture::ComponentType type); - -ComponentType ComponentTypeFromRenderTarget(Tegra::RenderTargetFormat format); - PixelFormat PixelFormatFromGPUPixelFormat(Tegra::FramebufferConfig::PixelFormat format); -ComponentType ComponentTypeFromDepthFormat(Tegra::DepthFormat format); - SurfaceType GetFormatType(PixelFormat pixel_format); bool IsPixelFormatASTC(PixelFormat format); diff --git a/src/video_core/texture_cache/format_lookup_table.cpp b/src/video_core/texture_cache/format_lookup_table.cpp new file mode 100644 index 000000000..271e67533 --- /dev/null +++ b/src/video_core/texture_cache/format_lookup_table.cpp @@ -0,0 +1,208 @@ +// Copyright 2019 yuzu Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#include <array> +#include "common/common_types.h" +#include "common/logging/log.h" +#include "video_core/texture_cache/format_lookup_table.h" + +namespace VideoCommon { + +using Tegra::Texture::ComponentType; +using Tegra::Texture::TextureFormat; +using VideoCore::Surface::PixelFormat; + +namespace { + +constexpr auto SNORM = ComponentType::SNORM; +constexpr auto UNORM = ComponentType::UNORM; +constexpr auto SINT = ComponentType::SINT; +constexpr auto UINT = ComponentType::UINT; +constexpr auto SNORM_FORCE_FP16 = ComponentType::SNORM_FORCE_FP16; +constexpr auto UNORM_FORCE_FP16 = ComponentType::UNORM_FORCE_FP16; +constexpr auto FLOAT = ComponentType::FLOAT; +constexpr bool C = false; // Normal color +constexpr bool S = true; // Srgb + +struct Table { + constexpr Table(TextureFormat texture_format, bool is_srgb, ComponentType red_component, + ComponentType green_component, ComponentType blue_component, + ComponentType alpha_component, PixelFormat pixel_format) + : texture_format{texture_format}, pixel_format{pixel_format}, red_component{red_component}, + green_component{green_component}, blue_component{blue_component}, + alpha_component{alpha_component}, is_srgb{is_srgb} {} + + TextureFormat texture_format; + PixelFormat pixel_format; + ComponentType red_component; + ComponentType green_component; + ComponentType blue_component; + ComponentType alpha_component; + bool is_srgb; +}; +constexpr std::array<Table, 74> DefinitionTable = {{ + {TextureFormat::A8R8G8B8, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::ABGR8U}, + {TextureFormat::A8R8G8B8, C, SNORM, SNORM, SNORM, SNORM, PixelFormat::ABGR8S}, + {TextureFormat::A8R8G8B8, C, UINT, UINT, UINT, UINT, PixelFormat::ABGR8UI}, + {TextureFormat::A8R8G8B8, S, UNORM, UNORM, UNORM, UNORM, PixelFormat::RGBA8_SRGB}, + + {TextureFormat::B5G6R5, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::B5G6R5U}, + + {TextureFormat::A2B10G10R10, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::A2B10G10R10U}, + + {TextureFormat::A1B5G5R5, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::A1B5G5R5U}, + + {TextureFormat::A4B4G4R4, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::R4G4B4A4U}, + + {TextureFormat::R8, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::R8U}, + {TextureFormat::R8, C, UINT, UINT, UINT, UINT, PixelFormat::R8UI}, + + {TextureFormat::G8R8, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::RG8U}, + {TextureFormat::G8R8, C, SNORM, SNORM, SNORM, SNORM, PixelFormat::RG8S}, + + {TextureFormat::R16_G16_B16_A16, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::RGBA16U}, + {TextureFormat::R16_G16_B16_A16, C, FLOAT, FLOAT, FLOAT, FLOAT, PixelFormat::RGBA16F}, + {TextureFormat::R16_G16_B16_A16, C, UINT, UINT, UINT, UINT, PixelFormat::RGBA16UI}, + + {TextureFormat::R16_G16, C, FLOAT, FLOAT, FLOAT, FLOAT, PixelFormat::RG16F}, + {TextureFormat::R16_G16, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::RG16}, + {TextureFormat::R16_G16, C, SNORM, SNORM, SNORM, SNORM, PixelFormat::RG16S}, + {TextureFormat::R16_G16, C, UINT, UINT, UINT, UINT, PixelFormat::RG16UI}, + {TextureFormat::R16_G16, C, SINT, SINT, SINT, SINT, PixelFormat::RG16I}, + + {TextureFormat::R16, C, FLOAT, FLOAT, FLOAT, FLOAT, PixelFormat::R16F}, + {TextureFormat::R16, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::R16U}, + {TextureFormat::R16, C, SNORM, SNORM, SNORM, SNORM, PixelFormat::R16S}, + {TextureFormat::R16, C, UINT, UINT, UINT, UINT, PixelFormat::R16UI}, + {TextureFormat::R16, C, SINT, SINT, SINT, SINT, PixelFormat::R16I}, + + {TextureFormat::BF10GF11RF11, C, FLOAT, FLOAT, FLOAT, FLOAT, PixelFormat::R11FG11FB10F}, + + {TextureFormat::R32_G32_B32_A32, C, FLOAT, FLOAT, FLOAT, FLOAT, PixelFormat::RGBA32F}, + {TextureFormat::R32_G32_B32_A32, C, UINT, UINT, UINT, UINT, PixelFormat::RGBA32UI}, + + {TextureFormat::R32_G32_B32, C, FLOAT, FLOAT, FLOAT, FLOAT, PixelFormat::RGB32F}, + + {TextureFormat::R32_G32, C, FLOAT, FLOAT, FLOAT, FLOAT, PixelFormat::RG32F}, + {TextureFormat::R32_G32, C, UINT, UINT, UINT, UINT, PixelFormat::RG32UI}, + + {TextureFormat::R32, C, FLOAT, FLOAT, FLOAT, FLOAT, PixelFormat::R32F}, + {TextureFormat::R32, C, UINT, UINT, UINT, UINT, PixelFormat::R32UI}, + + {TextureFormat::E5B9G9R9_SHAREDEXP, C, FLOAT, FLOAT, FLOAT, FLOAT, PixelFormat::E5B9G9R9F}, + + {TextureFormat::ZF32, C, FLOAT, FLOAT, FLOAT, FLOAT, PixelFormat::Z32F}, + {TextureFormat::Z16, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::Z16}, + {TextureFormat::S8Z24, C, UINT, UNORM, UNORM, UNORM, PixelFormat::S8Z24}, + {TextureFormat::ZF32_X24S8, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::Z32FS8}, + + {TextureFormat::DXT1, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::DXT1}, + {TextureFormat::DXT1, S, UNORM, UNORM, UNORM, UNORM, PixelFormat::DXT1_SRGB}, + + {TextureFormat::DXT23, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::DXT23}, + {TextureFormat::DXT23, S, UNORM, UNORM, UNORM, UNORM, PixelFormat::DXT23_SRGB}, + + {TextureFormat::DXT45, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::DXT45}, + {TextureFormat::DXT45, S, UNORM, UNORM, UNORM, UNORM, PixelFormat::DXT45_SRGB}, + + // TODO: Use a different pixel format for SNORM + {TextureFormat::DXN1, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::DXN1}, + {TextureFormat::DXN1, C, SNORM, SNORM, SNORM, SNORM, PixelFormat::DXN1}, + + {TextureFormat::DXN2, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::DXN2UNORM}, + {TextureFormat::DXN2, C, SNORM, SNORM, SNORM, SNORM, PixelFormat::DXN2SNORM}, + + {TextureFormat::BC7U, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::BC7U}, + {TextureFormat::BC7U, S, UNORM, UNORM, UNORM, UNORM, PixelFormat::BC7U_SRGB}, + + {TextureFormat::BC6H_SF16, C, FLOAT, FLOAT, FLOAT, FLOAT, PixelFormat::BC6H_SF16}, + {TextureFormat::BC6H_UF16, C, FLOAT, FLOAT, FLOAT, FLOAT, PixelFormat::BC6H_UF16}, + + {TextureFormat::ASTC_2D_4X4, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_4X4}, + {TextureFormat::ASTC_2D_4X4, S, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_4X4_SRGB}, + + {TextureFormat::ASTC_2D_5X4, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_5X4}, + {TextureFormat::ASTC_2D_5X4, S, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_5X4_SRGB}, + + {TextureFormat::ASTC_2D_5X5, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_5X5}, + {TextureFormat::ASTC_2D_5X5, S, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_5X5_SRGB}, + + {TextureFormat::ASTC_2D_8X8, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_8X8}, + {TextureFormat::ASTC_2D_8X8, S, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_8X8_SRGB}, + + {TextureFormat::ASTC_2D_8X5, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_8X5}, + {TextureFormat::ASTC_2D_8X5, S, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_8X5_SRGB}, + + {TextureFormat::ASTC_2D_10X8, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_10X8}, + {TextureFormat::ASTC_2D_10X8, S, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_10X8_SRGB}, + + {TextureFormat::ASTC_2D_6X6, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_6X6}, + {TextureFormat::ASTC_2D_6X6, S, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_6X6_SRGB}, + + {TextureFormat::ASTC_2D_10X10, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_10X10}, + {TextureFormat::ASTC_2D_10X10, S, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_10X10_SRGB}, + + {TextureFormat::ASTC_2D_12X12, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_12X12}, + {TextureFormat::ASTC_2D_12X12, S, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_12X12_SRGB}, + + {TextureFormat::ASTC_2D_8X6, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_8X6}, + {TextureFormat::ASTC_2D_8X6, S, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_8X6_SRGB}, + + {TextureFormat::ASTC_2D_6X5, C, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_6X5}, + {TextureFormat::ASTC_2D_6X5, S, UNORM, UNORM, UNORM, UNORM, PixelFormat::ASTC_2D_6X5_SRGB}, +}}; + +} // Anonymous namespace + +FormatLookupTable::FormatLookupTable() { + table.fill(static_cast<u8>(PixelFormat::Invalid)); + + for (const auto& entry : DefinitionTable) { + table[CalculateIndex(entry.texture_format, entry.is_srgb != 0, entry.red_component, + entry.green_component, entry.blue_component, entry.alpha_component)] = + static_cast<u8>(entry.pixel_format); + } +} + +PixelFormat FormatLookupTable::GetPixelFormat(TextureFormat format, bool is_srgb, + ComponentType red_component, + ComponentType green_component, + ComponentType blue_component, + ComponentType alpha_component) const noexcept { + const auto pixel_format = static_cast<PixelFormat>(table[CalculateIndex( + format, is_srgb, red_component, green_component, blue_component, alpha_component)]); + // [[likely]] + if (pixel_format != PixelFormat::Invalid) { + return pixel_format; + } + UNIMPLEMENTED_MSG("texture format={} srgb={} components={{{} {} {} {}}}", + static_cast<int>(format), is_srgb, static_cast<int>(red_component), + static_cast<int>(green_component), static_cast<int>(blue_component), + static_cast<int>(alpha_component)); + return PixelFormat::ABGR8U; +} + +void FormatLookupTable::Set(TextureFormat format, bool is_srgb, ComponentType red_component, + ComponentType green_component, ComponentType blue_component, + ComponentType alpha_component, PixelFormat pixel_format) {} + +std::size_t FormatLookupTable::CalculateIndex(TextureFormat format, bool is_srgb, + ComponentType red_component, + ComponentType green_component, + ComponentType blue_component, + ComponentType alpha_component) noexcept { + const auto format_index = static_cast<std::size_t>(format); + const auto red_index = static_cast<std::size_t>(red_component); + const auto green_index = static_cast<std::size_t>(red_component); + const auto blue_index = static_cast<std::size_t>(red_component); + const auto alpha_index = static_cast<std::size_t>(red_component); + const std::size_t srgb_index = is_srgb ? 1 : 0; + + return format_index * PerFormat + + srgb_index * PerComponent * PerComponent * PerComponent * PerComponent + + alpha_index * PerComponent * PerComponent * PerComponent + + blue_index * PerComponent * PerComponent + green_index * PerComponent + red_index; +} + +} // namespace VideoCommon diff --git a/src/video_core/texture_cache/format_lookup_table.h b/src/video_core/texture_cache/format_lookup_table.h new file mode 100644 index 000000000..aa77e0a5a --- /dev/null +++ b/src/video_core/texture_cache/format_lookup_table.h @@ -0,0 +1,51 @@ +// Copyright 2019 yuzu Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once + +#include <array> +#include <limits> +#include "video_core/surface.h" +#include "video_core/textures/texture.h" + +namespace VideoCommon { + +class FormatLookupTable { +public: + explicit FormatLookupTable(); + + VideoCore::Surface::PixelFormat GetPixelFormat( + Tegra::Texture::TextureFormat format, bool is_srgb, + Tegra::Texture::ComponentType red_component, Tegra::Texture::ComponentType green_component, + Tegra::Texture::ComponentType blue_component, + Tegra::Texture::ComponentType alpha_component) const noexcept; + +private: + static_assert(VideoCore::Surface::MaxPixelFormat <= std::numeric_limits<u8>::max()); + + static constexpr std::size_t NumTextureFormats = 128; + + static constexpr std::size_t PerComponent = 8; + static constexpr std::size_t PerComponents2 = PerComponent * PerComponent; + static constexpr std::size_t PerComponents3 = PerComponents2 * PerComponent; + static constexpr std::size_t PerComponents4 = PerComponents3 * PerComponent; + static constexpr std::size_t PerFormat = PerComponents4 * 2; + + static std::size_t CalculateIndex(Tegra::Texture::TextureFormat format, bool is_srgb, + Tegra::Texture::ComponentType red_component, + Tegra::Texture::ComponentType green_component, + Tegra::Texture::ComponentType blue_component, + Tegra::Texture::ComponentType alpha_component) noexcept; + + void Set(Tegra::Texture::TextureFormat format, bool is_srgb, + Tegra::Texture::ComponentType red_component, + Tegra::Texture::ComponentType green_component, + Tegra::Texture::ComponentType blue_component, + Tegra::Texture::ComponentType alpha_component, + VideoCore::Surface::PixelFormat pixel_format); + + std::array<u8, NumTextureFormats * PerFormat> table; +}; + +} // namespace VideoCommon diff --git a/src/video_core/texture_cache/surface_base.h b/src/video_core/texture_cache/surface_base.h index 1bed82898..5f79bb0aa 100644 --- a/src/video_core/texture_cache/surface_base.h +++ b/src/video_core/texture_cache/surface_base.h @@ -254,16 +254,14 @@ public: if (!layer_mipmap) { return {}; } - const u32 end_layer{layer_mipmap->first}; - const u32 end_mipmap{layer_mipmap->second}; + const auto [end_layer, end_mipmap] = *layer_mipmap; if (layer != end_layer) { if (mipmap == 0 && end_mipmap == 0) { - return GetView(ViewParams(view_params.target, layer, end_layer - layer + 1, 0, 1)); + return GetView(ViewParams(view_params.target, layer, end_layer - layer, 0, 1)); } return {}; } else { - return GetView( - ViewParams(view_params.target, layer, 1, mipmap, end_mipmap - mipmap + 1)); + return GetView(ViewParams(view_params.target, layer, 1, mipmap, end_mipmap - mipmap)); } } @@ -278,8 +276,7 @@ public: if (!layer_mipmap) { return {}; } - const u32 layer{layer_mipmap->first}; - const u32 mipmap{layer_mipmap->second}; + const auto [layer, mipmap] = *layer_mipmap; if (GetMipmapSize(mipmap) != candidate_size) { return EmplaceIrregularView(view_params, view_addr, candidate_size, mipmap, layer); } diff --git a/src/video_core/texture_cache/surface_params.cpp b/src/video_core/texture_cache/surface_params.cpp index 1e4d3fb79..a4f1edd9a 100644 --- a/src/video_core/texture_cache/surface_params.cpp +++ b/src/video_core/texture_cache/surface_params.cpp @@ -2,24 +2,23 @@ // Licensed under GPLv2 or any later version // Refer to the license.txt file included. -#include <map> +#include <algorithm> +#include <string> +#include <tuple> #include "common/alignment.h" #include "common/bit_util.h" #include "core/core.h" #include "video_core/engines/shader_bytecode.h" #include "video_core/surface.h" +#include "video_core/texture_cache/format_lookup_table.h" #include "video_core/texture_cache/surface_params.h" namespace VideoCommon { -using VideoCore::Surface::ComponentTypeFromDepthFormat; -using VideoCore::Surface::ComponentTypeFromRenderTarget; -using VideoCore::Surface::ComponentTypeFromTexture; using VideoCore::Surface::PixelFormat; using VideoCore::Surface::PixelFormatFromDepthFormat; using VideoCore::Surface::PixelFormatFromRenderTargetFormat; -using VideoCore::Surface::PixelFormatFromTextureFormat; using VideoCore::Surface::SurfaceTarget; using VideoCore::Surface::SurfaceTargetFromTextureType; using VideoCore::Surface::SurfaceType; @@ -69,7 +68,8 @@ constexpr u32 GetMipmapSize(bool uncompressed, u32 mip_size, u32 tile) { } // Anonymous namespace -SurfaceParams SurfaceParams::CreateForTexture(const Tegra::Texture::TICEntry& tic, +SurfaceParams SurfaceParams::CreateForTexture(const FormatLookupTable& lookup_table, + const Tegra::Texture::TICEntry& tic, const VideoCommon::Shader::Sampler& entry) { SurfaceParams params; params.is_tiled = tic.IsTiled(); @@ -78,8 +78,8 @@ SurfaceParams SurfaceParams::CreateForTexture(const Tegra::Texture::TICEntry& ti params.block_height = params.is_tiled ? tic.BlockHeight() : 0, params.block_depth = params.is_tiled ? tic.BlockDepth() : 0, params.tile_width_spacing = params.is_tiled ? (1 << tic.tile_width_spacing.Value()) : 1; - params.pixel_format = - PixelFormatFromTextureFormat(tic.format, tic.r_type.Value(), params.srgb_conversion); + params.pixel_format = lookup_table.GetPixelFormat( + tic.format, params.srgb_conversion, tic.r_type, tic.g_type, tic.b_type, tic.a_type); params.type = GetFormatType(params.pixel_format); if (entry.IsShadow() && params.type == SurfaceType::ColorTexture) { switch (params.pixel_format) { @@ -99,7 +99,6 @@ SurfaceParams SurfaceParams::CreateForTexture(const Tegra::Texture::TICEntry& ti } params.type = GetFormatType(params.pixel_format); } - params.component_type = ComponentTypeFromTexture(tic.r_type.Value()); params.type = GetFormatType(params.pixel_format); // TODO: on 1DBuffer we should use the tic info. if (tic.IsBuffer()) { @@ -128,7 +127,8 @@ SurfaceParams SurfaceParams::CreateForTexture(const Tegra::Texture::TICEntry& ti return params; } -SurfaceParams SurfaceParams::CreateForImage(const Tegra::Texture::TICEntry& tic, +SurfaceParams SurfaceParams::CreateForImage(const FormatLookupTable& lookup_table, + const Tegra::Texture::TICEntry& tic, const VideoCommon::Shader::Image& entry) { SurfaceParams params; params.is_tiled = tic.IsTiled(); @@ -137,10 +137,9 @@ SurfaceParams SurfaceParams::CreateForImage(const Tegra::Texture::TICEntry& tic, params.block_height = params.is_tiled ? tic.BlockHeight() : 0, params.block_depth = params.is_tiled ? tic.BlockDepth() : 0, params.tile_width_spacing = params.is_tiled ? (1 << tic.tile_width_spacing.Value()) : 1; - params.pixel_format = - PixelFormatFromTextureFormat(tic.format, tic.r_type.Value(), params.srgb_conversion); + params.pixel_format = lookup_table.GetPixelFormat( + tic.format, params.srgb_conversion, tic.r_type, tic.g_type, tic.b_type, tic.a_type); params.type = GetFormatType(params.pixel_format); - params.component_type = ComponentTypeFromTexture(tic.r_type.Value()); params.type = GetFormatType(params.pixel_format); params.target = ImageTypeToSurfaceTarget(entry.GetType()); // TODO: on 1DBuffer we should use the tic info. @@ -181,7 +180,6 @@ SurfaceParams SurfaceParams::CreateForDepthBuffer( params.block_depth = std::min(block_depth, 5U); params.tile_width_spacing = 1; params.pixel_format = PixelFormatFromDepthFormat(format); - params.component_type = ComponentTypeFromDepthFormat(format); params.type = GetFormatType(params.pixel_format); params.width = zeta_width; params.height = zeta_height; @@ -206,7 +204,6 @@ SurfaceParams SurfaceParams::CreateForFramebuffer(Core::System& system, std::siz params.block_depth = config.memory_layout.block_depth; params.tile_width_spacing = 1; params.pixel_format = PixelFormatFromRenderTargetFormat(config.format); - params.component_type = ComponentTypeFromRenderTarget(config.format); params.type = GetFormatType(params.pixel_format); if (params.is_tiled) { params.pitch = 0; @@ -236,7 +233,6 @@ SurfaceParams SurfaceParams::CreateForFermiCopySurface( params.block_depth = params.is_tiled ? std::min(config.BlockDepth(), 5U) : 0, params.tile_width_spacing = 1; params.pixel_format = PixelFormatFromRenderTargetFormat(config.format); - params.component_type = ComponentTypeFromRenderTarget(config.format); params.type = GetFormatType(params.pixel_format); params.width = config.width; params.height = config.height; @@ -250,6 +246,16 @@ SurfaceParams SurfaceParams::CreateForFermiCopySurface( return params; } +VideoCore::Surface::SurfaceTarget SurfaceParams::ExpectedTarget( + const VideoCommon::Shader::Sampler& entry) { + return TextureTypeToSurfaceTarget(entry.GetType(), entry.IsArray()); +} + +VideoCore::Surface::SurfaceTarget SurfaceParams::ExpectedTarget( + const VideoCommon::Shader::Image& entry) { + return ImageTypeToSurfaceTarget(entry.GetType()); +} + bool SurfaceParams::IsLayered() const { switch (target) { case SurfaceTarget::Texture1DArray: @@ -355,10 +361,10 @@ std::size_t SurfaceParams::GetInnerMipmapMemorySize(u32 level, bool as_host_size bool SurfaceParams::operator==(const SurfaceParams& rhs) const { return std::tie(is_tiled, block_width, block_height, block_depth, tile_width_spacing, width, - height, depth, pitch, num_levels, pixel_format, component_type, type, target) == + height, depth, pitch, num_levels, pixel_format, type, target) == std::tie(rhs.is_tiled, rhs.block_width, rhs.block_height, rhs.block_depth, rhs.tile_width_spacing, rhs.width, rhs.height, rhs.depth, rhs.pitch, - rhs.num_levels, rhs.pixel_format, rhs.component_type, rhs.type, rhs.target); + rhs.num_levels, rhs.pixel_format, rhs.type, rhs.target); } std::string SurfaceParams::TargetName() const { diff --git a/src/video_core/texture_cache/surface_params.h b/src/video_core/texture_cache/surface_params.h index c58e7f8a4..129817ad3 100644 --- a/src/video_core/texture_cache/surface_params.h +++ b/src/video_core/texture_cache/surface_params.h @@ -16,16 +16,20 @@ namespace VideoCommon { +class FormatLookupTable; + using VideoCore::Surface::SurfaceCompression; class SurfaceParams { public: /// Creates SurfaceCachedParams from a texture configuration. - static SurfaceParams CreateForTexture(const Tegra::Texture::TICEntry& tic, + static SurfaceParams CreateForTexture(const FormatLookupTable& lookup_table, + const Tegra::Texture::TICEntry& tic, const VideoCommon::Shader::Sampler& entry); /// Creates SurfaceCachedParams from an image configuration. - static SurfaceParams CreateForImage(const Tegra::Texture::TICEntry& tic, + static SurfaceParams CreateForImage(const FormatLookupTable& lookup_table, + const Tegra::Texture::TICEntry& tic, const VideoCommon::Shader::Image& entry); /// Creates SurfaceCachedParams for a depth buffer configuration. @@ -41,6 +45,14 @@ public: static SurfaceParams CreateForFermiCopySurface( const Tegra::Engines::Fermi2D::Regs::Surface& config); + /// Obtains the texture target from a shader's sampler entry. + static VideoCore::Surface::SurfaceTarget ExpectedTarget( + const VideoCommon::Shader::Sampler& entry); + + /// Obtains the texture target from a shader's sampler entry. + static VideoCore::Surface::SurfaceTarget ExpectedTarget( + const VideoCommon::Shader::Image& entry); + std::size_t Hash() const { return static_cast<std::size_t>( Common::CityHash64(reinterpret_cast<const char*>(this), sizeof(*this))); @@ -248,7 +260,6 @@ public: u32 num_levels; u32 emulated_levels; VideoCore::Surface::PixelFormat pixel_format; - VideoCore::Surface::ComponentType component_type; VideoCore::Surface::SurfaceType type; VideoCore::Surface::SurfaceTarget target; diff --git a/src/video_core/texture_cache/texture_cache.h b/src/video_core/texture_cache/texture_cache.h index 6a92b22d3..02d2e9136 100644 --- a/src/video_core/texture_cache/texture_cache.h +++ b/src/video_core/texture_cache/texture_cache.h @@ -29,6 +29,7 @@ #include "video_core/rasterizer_interface.h" #include "video_core/surface.h" #include "video_core/texture_cache/copy_params.h" +#include "video_core/texture_cache/format_lookup_table.h" #include "video_core/texture_cache/surface_base.h" #include "video_core/texture_cache/surface_params.h" #include "video_core/texture_cache/surface_view.h" @@ -94,10 +95,16 @@ public: std::lock_guard lock{mutex}; const auto gpu_addr{tic.Address()}; if (!gpu_addr) { - return {}; + return GetNullSurface(SurfaceParams::ExpectedTarget(entry)); + } + + const auto host_ptr{system.GPU().MemoryManager().GetPointer(gpu_addr)}; + const auto cache_addr{ToCacheAddr(host_ptr)}; + if (!cache_addr) { + return GetNullSurface(SurfaceParams::ExpectedTarget(entry)); } - const auto params{SurfaceParams::CreateForTexture(tic, entry)}; - const auto [surface, view] = GetSurface(gpu_addr, params, true, false); + const auto params{SurfaceParams::CreateForTexture(format_lookup_table, tic, entry)}; + const auto [surface, view] = GetSurface(gpu_addr, cache_addr, params, true, false); if (guard_samplers) { sampled_textures.push_back(surface); } @@ -109,10 +116,15 @@ public: std::lock_guard lock{mutex}; const auto gpu_addr{tic.Address()}; if (!gpu_addr) { - return {}; + return GetNullSurface(SurfaceParams::ExpectedTarget(entry)); + } + const auto host_ptr{system.GPU().MemoryManager().GetPointer(gpu_addr)}; + const auto cache_addr{ToCacheAddr(host_ptr)}; + if (!cache_addr) { + return GetNullSurface(SurfaceParams::ExpectedTarget(entry)); } - const auto params{SurfaceParams::CreateForImage(tic, entry)}; - const auto [surface, view] = GetSurface(gpu_addr, params, true, false); + const auto params{SurfaceParams::CreateForImage(format_lookup_table, tic, entry)}; + const auto [surface, view] = GetSurface(gpu_addr, cache_addr, params, true, false); if (guard_samplers) { sampled_textures.push_back(surface); } @@ -142,11 +154,17 @@ public: SetEmptyDepthBuffer(); return {}; } + const auto host_ptr{system.GPU().MemoryManager().GetPointer(gpu_addr)}; + const auto cache_addr{ToCacheAddr(host_ptr)}; + if (!cache_addr) { + SetEmptyDepthBuffer(); + return {}; + } const auto depth_params{SurfaceParams::CreateForDepthBuffer( system, regs.zeta_width, regs.zeta_height, regs.zeta.format, regs.zeta.memory_layout.block_width, regs.zeta.memory_layout.block_height, regs.zeta.memory_layout.block_depth, regs.zeta.memory_layout.type)}; - auto surface_view = GetSurface(gpu_addr, depth_params, preserve_contents, true); + auto surface_view = GetSurface(gpu_addr, cache_addr, depth_params, preserve_contents, true); if (depth_buffer.target) depth_buffer.target->MarkAsRenderTarget(false, NO_RT); depth_buffer.target = surface_view.first; @@ -179,8 +197,16 @@ public: return {}; } - auto surface_view = GetSurface(gpu_addr, SurfaceParams::CreateForFramebuffer(system, index), - preserve_contents, true); + const auto host_ptr{system.GPU().MemoryManager().GetPointer(gpu_addr)}; + const auto cache_addr{ToCacheAddr(host_ptr)}; + if (!cache_addr) { + SetEmptyColorBuffer(index); + return {}; + } + + auto surface_view = + GetSurface(gpu_addr, cache_addr, SurfaceParams::CreateForFramebuffer(system, index), + preserve_contents, true); if (render_targets[index].target) render_targets[index].target->MarkAsRenderTarget(false, NO_RT); render_targets[index].target = surface_view.first; @@ -229,8 +255,14 @@ public: const GPUVAddr src_gpu_addr = src_config.Address(); const GPUVAddr dst_gpu_addr = dst_config.Address(); DeduceBestBlit(src_params, dst_params, src_gpu_addr, dst_gpu_addr); - std::pair<TSurface, TView> dst_surface = GetSurface(dst_gpu_addr, dst_params, true, false); - std::pair<TSurface, TView> src_surface = GetSurface(src_gpu_addr, src_params, true, false); + const auto dst_host_ptr{system.GPU().MemoryManager().GetPointer(dst_gpu_addr)}; + const auto dst_cache_addr{ToCacheAddr(dst_host_ptr)}; + const auto src_host_ptr{system.GPU().MemoryManager().GetPointer(src_gpu_addr)}; + const auto src_cache_addr{ToCacheAddr(src_host_ptr)}; + std::pair<TSurface, TView> dst_surface = + GetSurface(dst_gpu_addr, dst_cache_addr, dst_params, true, false); + std::pair<TSurface, TView> src_surface = + GetSurface(src_gpu_addr, src_cache_addr, src_params, true, false); ImageBlit(src_surface.second, dst_surface.second, copy_config); dst_surface.first->MarkAsModified(true, Tick()); } @@ -346,13 +378,6 @@ protected: return new_surface; } - std::pair<TSurface, TView> GetFermiSurface( - const Tegra::Engines::Fermi2D::Regs::Surface& config) { - SurfaceParams params = SurfaceParams::CreateForFermiCopySurface(config); - const GPUVAddr gpu_addr = config.Address(); - return GetSurface(gpu_addr, params, true, false); - } - Core::System& system; private: @@ -485,15 +510,13 @@ private: GetSiblingFormat(cr_params.pixel_format) == params.pixel_format) { SurfaceParams new_params = params; new_params.pixel_format = cr_params.pixel_format; - new_params.component_type = cr_params.component_type; new_params.type = cr_params.type; new_surface = GetUncachedSurface(gpu_addr, new_params); } else { new_surface = GetUncachedSurface(gpu_addr, params); } const auto& final_params = new_surface->GetSurfaceParams(); - if (cr_params.type != final_params.type || - (cr_params.component_type != final_params.component_type)) { + if (cr_params.type != final_params.type) { BufferCopy(current_surface, new_surface); } else { std::vector<CopyParams> bricks = current_surface->BreakDown(final_params); @@ -615,22 +638,9 @@ private: * left blank. * @param is_render Whether or not the surface is a render target. **/ - std::pair<TSurface, TView> GetSurface(const GPUVAddr gpu_addr, const SurfaceParams& params, - bool preserve_contents, bool is_render) { - const auto host_ptr{system.GPU().MemoryManager().GetPointer(gpu_addr)}; - const auto cache_addr{ToCacheAddr(host_ptr)}; - - // Step 0: guarantee a valid surface - if (!cache_addr) { - // Return a null surface if it's invalid - SurfaceParams new_params = params; - new_params.width = 1; - new_params.height = 1; - new_params.depth = 1; - new_params.block_height = 0; - new_params.block_depth = 0; - return InitializeSurface(gpu_addr, new_params, false); - } + std::pair<TSurface, TView> GetSurface(const GPUVAddr gpu_addr, const CacheAddr cache_addr, + const SurfaceParams& params, bool preserve_contents, + bool is_render) { // Step 1 // Check Level 1 Cache for a fast structural match. If candidate surface @@ -795,6 +805,41 @@ private: } /** + * Gets a null surface based on a target texture. + * @param target The target of the null surface. + */ + TView GetNullSurface(SurfaceTarget target) { + const u32 i_target = static_cast<u32>(target); + if (const auto it = invalid_cache.find(i_target); it != invalid_cache.end()) { + return it->second->GetMainView(); + } + SurfaceParams params{}; + params.target = target; + params.is_tiled = false; + params.srgb_conversion = false; + params.is_layered = false; + params.block_width = 0; + params.block_height = 0; + params.block_depth = 0; + params.tile_width_spacing = 1; + params.width = 1; + params.height = 1; + params.depth = 1; + params.pitch = 4; + params.num_levels = 1; + params.emulated_levels = 1; + params.pixel_format = VideoCore::Surface::PixelFormat::RGBA16F; + params.type = VideoCore::Surface::SurfaceType::ColorTexture; + auto surface = CreateSurface(0ULL, params); + invalid_memory.clear(); + invalid_memory.resize(surface->GetHostSizeInBytes(), 0U); + surface->UploadTexture(invalid_memory); + surface->MarkAsModified(false, Tick()); + invalid_cache.emplace(i_target, surface); + return surface->GetMainView(); + } + + /** * Gets the a source and destination starting address and parameters, * and tries to deduce if they are supposed to be depth textures. If so, their * parameters are modified and fixed into so. @@ -835,12 +880,11 @@ private: } } - const auto inherit_format = ([](SurfaceParams& to, TSurface from) { + const auto inherit_format = [](SurfaceParams& to, TSurface from) { const SurfaceParams& params = from->GetSurfaceParams(); to.pixel_format = params.pixel_format; - to.component_type = params.component_type; to.type = params.type; - }); + }; // Now we got the cases where one or both is Depth and the other is not known if (!incomplete_src) { inherit_format(src_params, deduced_src.surface); @@ -956,6 +1000,8 @@ private: VideoCore::RasterizerInterface& rasterizer; + FormatLookupTable format_lookup_table; + u64 ticks{}; // Guards the cache for protection conflicts. @@ -991,6 +1037,11 @@ private: std::vector<TSurface> sampled_textures; + /// This cache stores null surfaces in order to be used as a placeholder + /// for invalid texture calls. + std::unordered_map<u32, TSurface> invalid_cache; + std::vector<u8> invalid_memory; + StagingCache staging_cache; std::recursive_mutex mutex; }; diff --git a/src/yuzu/debugger/wait_tree.cpp b/src/yuzu/debugger/wait_tree.cpp index 188f798c0..727bd8a94 100644 --- a/src/yuzu/debugger/wait_tree.cpp +++ b/src/yuzu/debugger/wait_tree.cpp @@ -57,7 +57,7 @@ std::size_t WaitTreeItem::Row() const { std::vector<std::unique_ptr<WaitTreeThread>> WaitTreeItem::MakeThreadItemList() { std::vector<std::unique_ptr<WaitTreeThread>> item_list; std::size_t row = 0; - auto add_threads = [&](const std::vector<Kernel::SharedPtr<Kernel::Thread>>& threads) { + auto add_threads = [&](const std::vector<std::shared_ptr<Kernel::Thread>>& threads) { for (std::size_t i = 0; i < threads.size(); ++i) { item_list.push_back(std::make_unique<WaitTreeThread>(*threads[i])); item_list.back()->row = row; @@ -80,7 +80,7 @@ QString WaitTreeText::GetText() const { WaitTreeMutexInfo::WaitTreeMutexInfo(VAddr mutex_address, const Kernel::HandleTable& handle_table) : mutex_address(mutex_address) { - mutex_value = Memory::Read32(mutex_address); + mutex_value = Core::System::GetInstance().Memory().Read32(mutex_address); owner_handle = static_cast<Kernel::Handle>(mutex_value & Kernel::Mutex::MutexOwnerMask); owner = handle_table.Get<Kernel::Thread>(owner_handle); } @@ -115,10 +115,11 @@ std::vector<std::unique_ptr<WaitTreeItem>> WaitTreeCallstack::GetChildren() cons std::vector<std::unique_ptr<WaitTreeItem>> list; constexpr std::size_t BaseRegister = 29; + auto& memory = Core::System::GetInstance().Memory(); u64 base_pointer = thread.GetContext().cpu_registers[BaseRegister]; while (base_pointer != 0) { - const u64 lr = Memory::Read64(base_pointer + sizeof(u64)); + const u64 lr = memory.Read64(base_pointer + sizeof(u64)); if (lr == 0) { break; } @@ -126,7 +127,7 @@ std::vector<std::unique_ptr<WaitTreeItem>> WaitTreeCallstack::GetChildren() cons list.push_back(std::make_unique<WaitTreeText>( tr("0x%1").arg(lr - sizeof(u32), 16, 16, QLatin1Char{'0'}))); - base_pointer = Memory::Read64(base_pointer); + base_pointer = memory.Read64(base_pointer); } return list; @@ -172,8 +173,8 @@ std::vector<std::unique_ptr<WaitTreeItem>> WaitTreeWaitObject::GetChildren() con return list; } -WaitTreeObjectList::WaitTreeObjectList( - const std::vector<Kernel::SharedPtr<Kernel::WaitObject>>& list, bool w_all) +WaitTreeObjectList::WaitTreeObjectList(const std::vector<std::shared_ptr<Kernel::WaitObject>>& list, + bool w_all) : object_list(list), wait_all(w_all) {} WaitTreeObjectList::~WaitTreeObjectList() = default; @@ -325,7 +326,7 @@ std::vector<std::unique_ptr<WaitTreeItem>> WaitTreeThread::GetChildren() const { WaitTreeEvent::WaitTreeEvent(const Kernel::ReadableEvent& object) : WaitTreeWaitObject(object) {} WaitTreeEvent::~WaitTreeEvent() = default; -WaitTreeThreadList::WaitTreeThreadList(const std::vector<Kernel::SharedPtr<Kernel::Thread>>& list) +WaitTreeThreadList::WaitTreeThreadList(const std::vector<std::shared_ptr<Kernel::Thread>>& list) : thread_list(list) {} WaitTreeThreadList::~WaitTreeThreadList() = default; diff --git a/src/yuzu/debugger/wait_tree.h b/src/yuzu/debugger/wait_tree.h index f2b13be24..631274a5f 100644 --- a/src/yuzu/debugger/wait_tree.h +++ b/src/yuzu/debugger/wait_tree.h @@ -83,7 +83,7 @@ private: VAddr mutex_address; u32 mutex_value; Kernel::Handle owner_handle; - Kernel::SharedPtr<Kernel::Thread> owner; + std::shared_ptr<Kernel::Thread> owner; }; class WaitTreeCallstack : public WaitTreeExpandableItem { @@ -116,15 +116,14 @@ protected: class WaitTreeObjectList : public WaitTreeExpandableItem { Q_OBJECT public: - WaitTreeObjectList(const std::vector<Kernel::SharedPtr<Kernel::WaitObject>>& list, - bool wait_all); + WaitTreeObjectList(const std::vector<std::shared_ptr<Kernel::WaitObject>>& list, bool wait_all); ~WaitTreeObjectList() override; QString GetText() const override; std::vector<std::unique_ptr<WaitTreeItem>> GetChildren() const override; private: - const std::vector<Kernel::SharedPtr<Kernel::WaitObject>>& object_list; + const std::vector<std::shared_ptr<Kernel::WaitObject>>& object_list; bool wait_all; }; @@ -149,14 +148,14 @@ public: class WaitTreeThreadList : public WaitTreeExpandableItem { Q_OBJECT public: - explicit WaitTreeThreadList(const std::vector<Kernel::SharedPtr<Kernel::Thread>>& list); + explicit WaitTreeThreadList(const std::vector<std::shared_ptr<Kernel::Thread>>& list); ~WaitTreeThreadList() override; QString GetText() const override; std::vector<std::unique_ptr<WaitTreeItem>> GetChildren() const override; private: - const std::vector<Kernel::SharedPtr<Kernel::Thread>>& thread_list; + const std::vector<std::shared_ptr<Kernel::Thread>>& thread_list; }; class WaitTreeModel : public QAbstractItemModel { diff --git a/src/yuzu/main.cpp b/src/yuzu/main.cpp index 160613ee1..867f8e913 100644 --- a/src/yuzu/main.cpp +++ b/src/yuzu/main.cpp @@ -817,6 +817,9 @@ QStringList GMainWindow::GetUnsupportedGLExtensions() { if (!GLAD_GL_ARB_multi_bind) { unsupported_ext.append(QStringLiteral("ARB_multi_bind")); } + if (!GLAD_GL_ARB_clip_control) { + unsupported_ext.append(QStringLiteral("ARB_clip_control")); + } // Extensions required to support some texture formats. if (!GLAD_GL_EXT_texture_compression_s3tc) { diff --git a/src/yuzu/main.ui b/src/yuzu/main.ui index a1ce3c0c3..21f422500 100644 --- a/src/yuzu/main.ui +++ b/src/yuzu/main.ui @@ -140,11 +140,6 @@ <string>Load Folder...</string> </property> </action> - <action name="action_Load_Symbol_Map"> - <property name="text"> - <string>Load Symbol Map...</string> - </property> - </action> <action name="action_Exit"> <property name="text"> <string>E&xit</string> @@ -221,14 +216,6 @@ <string>Show Status Bar</string> </property> </action> - <action name="action_Select_Game_List_Root"> - <property name="text"> - <string>Select Game Directory...</string> - </property> - <property name="toolTip"> - <string>Selects a folder to display in the game list</string> - </property> - </action> <action name="action_Select_NAND_Directory"> <property name="text"> <string>Select NAND Directory...</string> diff --git a/src/yuzu_cmd/emu_window/emu_window_sdl2_gl.cpp b/src/yuzu_cmd/emu_window/emu_window_sdl2_gl.cpp index f91b071bf..6fde694a2 100644 --- a/src/yuzu_cmd/emu_window/emu_window_sdl2_gl.cpp +++ b/src/yuzu_cmd/emu_window/emu_window_sdl2_gl.cpp @@ -50,7 +50,7 @@ private: }; bool EmuWindow_SDL2_GL::SupportsRequiredGLExtensions() { - std::vector<std::string> unsupported_ext; + std::vector<std::string_view> unsupported_ext; if (!GLAD_GL_ARB_buffer_storage) unsupported_ext.push_back("ARB_buffer_storage"); @@ -62,6 +62,8 @@ bool EmuWindow_SDL2_GL::SupportsRequiredGLExtensions() { unsupported_ext.push_back("ARB_texture_mirror_clamp_to_edge"); if (!GLAD_GL_ARB_multi_bind) unsupported_ext.push_back("ARB_multi_bind"); + if (!GLAD_GL_ARB_clip_control) + unsupported_ext.push_back("ARB_clip_control"); // Extensions required to support some texture formats. if (!GLAD_GL_EXT_texture_compression_s3tc) @@ -71,8 +73,8 @@ bool EmuWindow_SDL2_GL::SupportsRequiredGLExtensions() { if (!GLAD_GL_ARB_depth_buffer_float) unsupported_ext.push_back("ARB_depth_buffer_float"); - for (const std::string& ext : unsupported_ext) - LOG_CRITICAL(Frontend, "Unsupported GL extension: {}", ext); + for (const auto& extension : unsupported_ext) + LOG_CRITICAL(Frontend, "Unsupported GL extension: {}", extension); return unsupported_ext.empty(); } |