diff options
Diffstat (limited to 'src/core/internal_network/network.cpp')
-rw-r--r-- | src/core/internal_network/network.cpp | 274 |
1 files changed, 220 insertions, 54 deletions
diff --git a/src/core/internal_network/network.cpp b/src/core/internal_network/network.cpp index 75ac10a9c..39381e06e 100644 --- a/src/core/internal_network/network.cpp +++ b/src/core/internal_network/network.cpp @@ -121,6 +121,8 @@ Errno TranslateNativeError(int e) { return Errno::MSGSIZE; case WSAETIMEDOUT: return Errno::TIMEDOUT; + case WSAEINPROGRESS: + return Errno::INPROGRESS; default: UNIMPLEMENTED_MSG("Unimplemented errno={}", e); return Errno::OTHER; @@ -195,6 +197,8 @@ bool EnableNonBlock(int fd, bool enable) { Errno TranslateNativeError(int e) { switch (e) { + case 0: + return Errno::SUCCESS; case EBADF: return Errno::BADF; case EINVAL: @@ -219,8 +223,10 @@ Errno TranslateNativeError(int e) { return Errno::MSGSIZE; case ETIMEDOUT: return Errno::TIMEDOUT; + case EINPROGRESS: + return Errno::INPROGRESS; default: - UNIMPLEMENTED_MSG("Unimplemented errno={}", e); + UNIMPLEMENTED_MSG("Unimplemented errno={} ({})", e, strerror(e)); return Errno::OTHER; } } @@ -234,15 +240,84 @@ Errno GetAndLogLastError() { int e = errno; #endif const Errno err = TranslateNativeError(e); - if (err == Errno::AGAIN || err == Errno::TIMEDOUT) { + if (err == Errno::AGAIN || err == Errno::TIMEDOUT || err == Errno::INPROGRESS) { + // These happen during normal operation, so only log them at debug level. + LOG_DEBUG(Network, "Socket operation error: {}", Common::NativeErrorToString(e)); return err; } LOG_ERROR(Network, "Socket operation error: {}", Common::NativeErrorToString(e)); return err; } -int TranslateDomain(Domain domain) { +GetAddrInfoError TranslateGetAddrInfoErrorFromNative(int gai_err) { + switch (gai_err) { + case 0: + return GetAddrInfoError::SUCCESS; +#ifdef EAI_ADDRFAMILY + case EAI_ADDRFAMILY: + return GetAddrInfoError::ADDRFAMILY; +#endif + case EAI_AGAIN: + return GetAddrInfoError::AGAIN; + case EAI_BADFLAGS: + return GetAddrInfoError::BADFLAGS; + case EAI_FAIL: + return GetAddrInfoError::FAIL; + case EAI_FAMILY: + return GetAddrInfoError::FAMILY; + case EAI_MEMORY: + return GetAddrInfoError::MEMORY; + case EAI_NONAME: + return GetAddrInfoError::NONAME; + case EAI_SERVICE: + return GetAddrInfoError::SERVICE; + case EAI_SOCKTYPE: + return GetAddrInfoError::SOCKTYPE; + // These codes may not be defined on all systems: +#ifdef EAI_SYSTEM + case EAI_SYSTEM: + return GetAddrInfoError::SYSTEM; +#endif +#ifdef EAI_BADHINTS + case EAI_BADHINTS: + return GetAddrInfoError::BADHINTS; +#endif +#ifdef EAI_PROTOCOL + case EAI_PROTOCOL: + return GetAddrInfoError::PROTOCOL; +#endif +#ifdef EAI_OVERFLOW + case EAI_OVERFLOW: + return GetAddrInfoError::OVERFLOW_; +#endif + default: +#ifdef EAI_NODATA + // This can't be a case statement because it would create a duplicate + // case on Windows where EAI_NODATA is an alias for EAI_NONAME. + if (gai_err == EAI_NODATA) { + return GetAddrInfoError::NODATA; + } +#endif + return GetAddrInfoError::OTHER; + } +} + +Domain TranslateDomainFromNative(int domain) { switch (domain) { + case 0: + return Domain::Unspecified; + case AF_INET: + return Domain::INET; + default: + UNIMPLEMENTED_MSG("Unhandled domain={}", domain); + return Domain::INET; + } +} + +int TranslateDomainToNative(Domain domain) { + switch (domain) { + case Domain::Unspecified: + return 0; case Domain::INET: return AF_INET; default: @@ -251,20 +326,58 @@ int TranslateDomain(Domain domain) { } } -int TranslateType(Type type) { +Type TranslateTypeFromNative(int type) { + switch (type) { + case 0: + return Type::Unspecified; + case SOCK_STREAM: + return Type::STREAM; + case SOCK_DGRAM: + return Type::DGRAM; + case SOCK_RAW: + return Type::RAW; + case SOCK_SEQPACKET: + return Type::SEQPACKET; + default: + UNIMPLEMENTED_MSG("Unimplemented type={}", type); + return Type::STREAM; + } +} + +int TranslateTypeToNative(Type type) { switch (type) { + case Type::Unspecified: + return 0; case Type::STREAM: return SOCK_STREAM; case Type::DGRAM: return SOCK_DGRAM; + case Type::RAW: + return SOCK_RAW; default: UNIMPLEMENTED_MSG("Unimplemented type={}", type); return 0; } } -int TranslateProtocol(Protocol protocol) { +Protocol TranslateProtocolFromNative(int protocol) { + switch (protocol) { + case 0: + return Protocol::Unspecified; + case IPPROTO_TCP: + return Protocol::TCP; + case IPPROTO_UDP: + return Protocol::UDP; + default: + UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol); + return Protocol::Unspecified; + } +} + +int TranslateProtocolToNative(Protocol protocol) { switch (protocol) { + case Protocol::Unspecified: + return 0; case Protocol::TCP: return IPPROTO_TCP; case Protocol::UDP: @@ -275,21 +388,10 @@ int TranslateProtocol(Protocol protocol) { } } -SockAddrIn TranslateToSockAddrIn(sockaddr input_) { - sockaddr_in input; - std::memcpy(&input, &input_, sizeof(input)); - +SockAddrIn TranslateToSockAddrIn(sockaddr_in input, size_t input_len) { SockAddrIn result; - switch (input.sin_family) { - case AF_INET: - result.family = Domain::INET; - break; - default: - UNIMPLEMENTED_MSG("Unhandled sockaddr family={}", input.sin_family); - result.family = Domain::INET; - break; - } + result.family = TranslateDomainFromNative(input.sin_family); result.portno = ntohs(input.sin_port); @@ -301,22 +403,28 @@ SockAddrIn TranslateToSockAddrIn(sockaddr input_) { short TranslatePollEvents(PollEvents events) { short result = 0; - if (True(events & PollEvents::In)) { - events &= ~PollEvents::In; - result |= POLLIN; - } - if (True(events & PollEvents::Pri)) { - events &= ~PollEvents::Pri; + const auto translate = [&result, &events](PollEvents guest, short host) { + if (True(events & guest)) { + events &= ~guest; + result |= host; + } + }; + + translate(PollEvents::In, POLLIN); + translate(PollEvents::Pri, POLLPRI); + translate(PollEvents::Out, POLLOUT); + translate(PollEvents::Err, POLLERR); + translate(PollEvents::Hup, POLLHUP); + translate(PollEvents::Nval, POLLNVAL); + translate(PollEvents::RdNorm, POLLRDNORM); + translate(PollEvents::RdBand, POLLRDBAND); + translate(PollEvents::WrBand, POLLWRBAND); + #ifdef _WIN32 + if (True(events & PollEvents::Pri)) { LOG_WARNING(Service, "Winsock doesn't support POLLPRI"); -#else - result |= POLLPRI; -#endif - } - if (True(events & PollEvents::Out)) { - events &= ~PollEvents::Out; - result |= POLLOUT; } +#endif UNIMPLEMENTED_IF_MSG((u16)events != 0, "Unhandled guest events=0x{:x}", (u16)events); @@ -337,6 +445,10 @@ PollEvents TranslatePollRevents(short revents) { translate(POLLOUT, PollEvents::Out); translate(POLLERR, PollEvents::Err); translate(POLLHUP, PollEvents::Hup); + translate(POLLNVAL, PollEvents::Nval); + translate(POLLRDNORM, PollEvents::RdNorm); + translate(POLLRDBAND, PollEvents::RdBand); + translate(POLLWRBAND, PollEvents::WrBand); UNIMPLEMENTED_IF_MSG(revents != 0, "Unhandled host revents=0x{:x}", revents); @@ -360,12 +472,53 @@ std::optional<IPv4Address> GetHostIPv4Address() { return {}; } - std::array<char, 16> ip_addr = {}; - ASSERT(inet_ntop(AF_INET, &network_interface->ip_address, ip_addr.data(), sizeof(ip_addr)) != - nullptr); return TranslateIPv4(network_interface->ip_address); } +std::string IPv4AddressToString(IPv4Address ip_addr) { + std::array<char, INET_ADDRSTRLEN> buf = {}; + ASSERT(inet_ntop(AF_INET, &ip_addr, buf.data(), sizeof(buf)) == buf.data()); + return std::string(buf.data()); +} + +u32 IPv4AddressToInteger(IPv4Address ip_addr) { + return static_cast<u32>(ip_addr[0]) << 24 | static_cast<u32>(ip_addr[1]) << 16 | + static_cast<u32>(ip_addr[2]) << 8 | static_cast<u32>(ip_addr[3]); +} + +#undef GetAddrInfo // Windows defines it as a macro + +Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddrInfo( + const std::string& host, const std::optional<std::string>& service) { + addrinfo hints{}; + hints.ai_family = AF_INET; // Switch only supports IPv4. + addrinfo* addrinfo; + s32 gai_err = getaddrinfo(host.c_str(), service.has_value() ? service->c_str() : nullptr, + &hints, &addrinfo); + if (gai_err != 0) { + return Common::Unexpected(TranslateGetAddrInfoErrorFromNative(gai_err)); + } + std::vector<AddrInfo> ret; + for (auto* current = addrinfo; current; current = current->ai_next) { + // We should only get AF_INET results due to the hints value. + ASSERT_OR_EXECUTE(addrinfo->ai_family == AF_INET && + addrinfo->ai_addrlen == sizeof(sockaddr_in), + continue;); + + AddrInfo& out = ret.emplace_back(); + out.family = TranslateDomainFromNative(current->ai_family); + out.socket_type = TranslateTypeFromNative(current->ai_socktype); + out.protocol = TranslateProtocolFromNative(current->ai_protocol); + out.addr = TranslateToSockAddrIn(*reinterpret_cast<sockaddr_in*>(current->ai_addr), + current->ai_addrlen); + if (current->ai_canonname != nullptr) { + out.canon_name = current->ai_canonname; + } + } + freeaddrinfo(addrinfo); + return ret; +} + std::pair<s32, Errno> Poll(std::vector<PollFD>& pollfds, s32 timeout) { const size_t num = pollfds.size(); @@ -411,6 +564,18 @@ Socket::Socket(Socket&& rhs) noexcept { } template <typename T> +std::pair<T, Errno> Socket::GetSockOpt(SOCKET fd_, int option) { + T value{}; + socklen_t len = sizeof(value); + const int result = getsockopt(fd_, SOL_SOCKET, option, reinterpret_cast<char*>(&value), &len); + if (result != SOCKET_ERROR) { + ASSERT(len == sizeof(value)); + return {value, Errno::SUCCESS}; + } + return {value, GetAndLogLastError()}; +} + +template <typename T> Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) { const int result = setsockopt(fd_, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value)); @@ -421,7 +586,8 @@ Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) { } Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) { - fd = socket(TranslateDomain(domain), TranslateType(type), TranslateProtocol(protocol)); + fd = socket(TranslateDomainToNative(domain), TranslateTypeToNative(type), + TranslateProtocolToNative(protocol)); if (fd != INVALID_SOCKET) { return Errno::SUCCESS; } @@ -430,19 +596,17 @@ Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) { } std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() { - sockaddr addr; + sockaddr_in addr; socklen_t addrlen = sizeof(addr); - const SOCKET new_socket = accept(fd, &addr, &addrlen); + const SOCKET new_socket = accept(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen); if (new_socket == INVALID_SOCKET) { return {AcceptResult{}, GetAndLogLastError()}; } - ASSERT(addrlen == sizeof(sockaddr_in)); - AcceptResult result{ .socket = std::make_unique<Socket>(new_socket), - .sockaddr_in = TranslateToSockAddrIn(addr), + .sockaddr_in = TranslateToSockAddrIn(addr, addrlen), }; return {std::move(result), Errno::SUCCESS}; @@ -458,25 +622,23 @@ Errno Socket::Connect(SockAddrIn addr_in) { } std::pair<SockAddrIn, Errno> Socket::GetPeerName() { - sockaddr addr; + sockaddr_in addr; socklen_t addrlen = sizeof(addr); - if (getpeername(fd, &addr, &addrlen) == SOCKET_ERROR) { + if (getpeername(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) { return {SockAddrIn{}, GetAndLogLastError()}; } - ASSERT(addrlen == sizeof(sockaddr_in)); - return {TranslateToSockAddrIn(addr), Errno::SUCCESS}; + return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS}; } std::pair<SockAddrIn, Errno> Socket::GetSockName() { - sockaddr addr; + sockaddr_in addr; socklen_t addrlen = sizeof(addr); - if (getsockname(fd, &addr, &addrlen) == SOCKET_ERROR) { + if (getsockname(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) { return {SockAddrIn{}, GetAndLogLastError()}; } - ASSERT(addrlen == sizeof(sockaddr_in)); - return {TranslateToSockAddrIn(addr), Errno::SUCCESS}; + return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS}; } Errno Socket::Bind(SockAddrIn addr) { @@ -519,7 +681,7 @@ Errno Socket::Shutdown(ShutdownHow how) { return GetAndLogLastError(); } -std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) { +std::pair<s32, Errno> Socket::Recv(int flags, std::span<u8> message) { ASSERT(flags == 0); ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); @@ -532,21 +694,20 @@ std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) { return {-1, GetAndLogLastError()}; } -std::pair<s32, Errno> Socket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) { +std::pair<s32, Errno> Socket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) { ASSERT(flags == 0); ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); - sockaddr addr_in{}; + sockaddr_in addr_in{}; socklen_t addrlen = sizeof(addr_in); socklen_t* const p_addrlen = addr ? &addrlen : nullptr; - sockaddr* const p_addr_in = addr ? &addr_in : nullptr; + sockaddr* const p_addr_in = addr ? reinterpret_cast<sockaddr*>(&addr_in) : nullptr; const auto result = recvfrom(fd, reinterpret_cast<char*>(message.data()), static_cast<int>(message.size()), 0, p_addr_in, p_addrlen); if (result != SOCKET_ERROR) { if (addr) { - ASSERT(addrlen == sizeof(addr_in)); - *addr = TranslateToSockAddrIn(addr_in); + *addr = TranslateToSockAddrIn(addr_in, addrlen); } return {static_cast<s32>(result), Errno::SUCCESS}; } @@ -597,6 +758,11 @@ Errno Socket::Close() { return Errno::SUCCESS; } +std::pair<Errno, Errno> Socket::GetPendingError() { + auto [pending_err, getsockopt_err] = GetSockOpt<int>(fd, SO_ERROR); + return {TranslateNativeError(pending_err), getsockopt_err}; +} + Errno Socket::SetLinger(bool enable, u32 linger) { return SetSockOpt(fd, SO_LINGER, MakeLinger(enable, linger)); } |