From ad374ef5e2cc01a7ac7841325e1999282c2097da Mon Sep 17 00:00:00 2001 From: Stenzek Date: Sun, 21 Jul 2024 15:35:56 +1000 Subject: [PATCH] Sockets: Use epoll on Linux --- src/util/sockets.cpp | 134 +++++++++++++++++++++++++++++++++++++------ src/util/sockets.h | 7 +++ 2 files changed, 123 insertions(+), 18 deletions(-) diff --git a/src/util/sockets.cpp b/src/util/sockets.cpp index 0c5e82b33..15f9ae6f4 100644 --- a/src/util/sockets.cpp +++ b/src/util/sockets.cpp @@ -42,6 +42,10 @@ using nfds_t = ULONG; #include #include +#ifdef __linux__ +#include +#endif + #define ioctlsocket ioctl #define closesocket close #define WSAEWOULDBLOCK EAGAIN @@ -227,16 +231,42 @@ SocketMultiplexer::~SocketMultiplexer() { CloseAll(); +#ifdef __linux__ + if (m_epoll_fd >= 0) + close(m_epoll_fd); +#else if (m_poll_array) std::free(m_poll_array); +#endif } std::unique_ptr SocketMultiplexer::Create(Error* error) { - if (!PlatformMisc::InitializeSocketSupport(error)) - return {}; + std::unique_ptr ret; + if (PlatformMisc::InitializeSocketSupport(error)) + { + ret = std::unique_ptr(new SocketMultiplexer()); + if (!ret->Initialize(error)) + ret.reset(); + } - return std::unique_ptr(new SocketMultiplexer()); + return ret; +} + +bool SocketMultiplexer::Initialize(Error* error) +{ +#ifdef __linux__ + m_epoll_fd = epoll_create1(0); + if (m_epoll_fd < 0) + { + Error::SetErrno(error, "epoll_create1() failed: ", errno); + return false; + } + + return true; +#else + return true; +#endif } std::shared_ptr SocketMultiplexer::InternalCreateListenSocket(const SocketAddress& address, @@ -325,8 +355,13 @@ std::shared_ptr SocketMultiplexer::InternalConnectStreamSocket(con void SocketMultiplexer::AddOpenSocket(std::shared_ptr socket) { - std::unique_lock lock(m_open_sockets_lock); +#ifdef __linux__ + struct epoll_event ev = {.events = 0u, .data = {.fd = socket->GetDescriptor()}}; + if (epoll_ctl(m_epoll_fd, EPOLL_CTL_ADD, socket->GetDescriptor(), &ev) != 0) [[unlikely]] + ERROR_LOG("epoll_ctl() to add socket failed: {}", Error::CreateErrno(errno).GetDescription()); +#endif + std::unique_lock lock(m_open_sockets_lock); DebugAssert(m_open_sockets.find(socket->GetDescriptor()) == m_open_sockets.end()); m_open_sockets.emplace(socket->GetDescriptor(), std::move(socket)); } @@ -339,27 +374,29 @@ void SocketMultiplexer::AddClientSocket(std::shared_ptr socket) void SocketMultiplexer::RemoveOpenSocket(BaseSocket* socket) { -#ifdef _DEBUG - { - std::unique_lock lock(m_poll_array_lock); - for (size_t i = 0; i < m_poll_array_active_size; i++) - { - pollfd& pfd = m_poll_array[i]; - DebugAssert(pfd.fd != socket->GetDescriptor()); - } - } -#endif - std::unique_lock lock(m_open_sockets_lock); const auto iter = m_open_sockets.find(socket->GetDescriptor()); Assert(iter != m_open_sockets.end()); m_open_sockets.erase(iter); +#ifdef __linux__ + if (epoll_ctl(m_epoll_fd, EPOLL_CTL_DEL, socket->GetDescriptor(), nullptr) != 0) [[unlikely]] + ERROR_LOG("epoll_ctl() to remove socket failed: {}", Error::CreateErrno(errno).GetDescription()); +#else +#ifdef _DEBUG + for (size_t i = 0; i < m_poll_array_active_size; i++) + { + pollfd& pfd = m_poll_array[i]; + DebugAssert(pfd.fd != socket->GetDescriptor()); + } +#endif + // Update size. size_t new_active_size = 0; for (size_t i = 0; i < m_poll_array_active_size; i++) new_active_size = (m_poll_array[i].fd != INVALID_SOCKET) ? (i + 1) : new_active_size; m_poll_array_active_size = new_active_size; +#endif } void SocketMultiplexer::RemoveClientSocket(BaseSocket* socket) @@ -400,6 +437,11 @@ void SocketMultiplexer::CloseAll() void SocketMultiplexer::SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events) { +#ifdef __linux__ + struct epoll_event ev = {.events = events, .data = {.fd = descriptor}}; + if (epoll_ctl(m_epoll_fd, EPOLL_CTL_MOD, descriptor, &ev) != 0) [[unlikely]] + ERROR_LOG("epoll_ctl() for events 0x{:x} failed: {}", events, Error::CreateErrno(errno).GetDescription()); +#else std::unique_lock lock(m_poll_array_lock); size_t free_slot = m_poll_array_active_size; for (size_t i = 0; i < m_poll_array_active_size; i++) @@ -440,10 +482,64 @@ void SocketMultiplexer::SetNotificationMask(BaseSocket* socket, SocketDescriptor m_poll_array[free_slot] = {.fd = descriptor, .events = static_cast(events), .revents = 0}; m_poll_array_active_size = free_slot + 1; +#endif } bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds) { +#ifdef __linux__ + constexpr int MAX_EVENTS = 128; + struct epoll_event events[MAX_EVENTS]; + + const int nevents = epoll_wait(m_epoll_fd, events, MAX_EVENTS, static_cast(milliseconds)); + if (nevents <= 0) + return false; + + // find sockets that triggered, we use an array here so we can avoid holding the lock, and if a socket disconnects + using PendingSocketPair = std::pair, u32>; + PendingSocketPair* triggered_sockets = + reinterpret_cast(alloca(sizeof(PendingSocketPair) * static_cast(nevents))); + size_t num_triggered_sockets = 0; + { + std::unique_lock open_lock(m_open_sockets_lock); + for (int i = 0; i < nevents; i++) + { + const epoll_event& ev = events[i]; + const auto iter = m_open_sockets.find(ev.data.fd); + if (iter == m_open_sockets.end()) [[unlikely]] + { + ERROR_LOG("Attempting to look up unknown socket {}, this should never happen.", ev.data.fd); + continue; + } + + // we add a reference here in case the read kills it with a write pending, or something like that + new (&triggered_sockets[num_triggered_sockets++]) PendingSocketPair(iter->second->shared_from_this(), ev.events); + } + } + + // fire events + for (size_t i = 0; i < num_triggered_sockets; i++) + { + PendingSocketPair& psp = triggered_sockets[i]; + + // fire events + if (psp.second & (EPOLLRDHUP | EPOLLHUP | EPOLLERR)) + { + psp.first->OnHangupEvent(); + } + else + { + if (psp.second & EPOLLIN) + psp.first->OnReadEvent(); + if (psp.second & EPOLLOUT) + psp.first->OnWriteEvent(); + } + + psp.first.~shared_ptr(); + } + + return true; +#else std::unique_lock lock(m_poll_array_lock); if (m_poll_array_active_size == 0) return false; @@ -454,7 +550,8 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds) // find sockets that triggered, we use an array here so we can avoid holding the lock, and if a socket disconnects using PendingSocketPair = std::pair, u32>; - PendingSocketPair* triggered_sockets = reinterpret_cast(alloca(sizeof(PendingSocketPair) * res)); + PendingSocketPair* triggered_sockets = + reinterpret_cast(alloca(sizeof(PendingSocketPair) * static_cast(res))); size_t num_triggered_sockets = 0; { std::unique_lock open_lock(m_open_sockets_lock); @@ -467,7 +564,7 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds) const auto iter = m_open_sockets.find(pfd.fd); if (iter == m_open_sockets.end()) [[unlikely]] { - ERROR_LOG("Attempting to look up known socket {}, this should never happen.", pfd.fd); + ERROR_LOG("Attempting to look up unknown socket {}, this should never happen.", pfd.fd); continue; } @@ -481,7 +578,7 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds) lock.unlock(); // fire events - for (u32 i = 0; i < num_triggered_sockets; i++) + for (size_t i = 0; i < num_triggered_sockets; i++) { PendingSocketPair& psp = triggered_sockets[i]; @@ -502,6 +599,7 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds) } return true; +#endif } ListenSocket::ListenSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor, diff --git a/src/util/sockets.h b/src/util/sockets.h index 2a4375059..ac0e58a34 100644 --- a/src/util/sockets.h +++ b/src/util/sockets.h @@ -135,6 +135,9 @@ private: // Hide the constructor. SocketMultiplexer(); + // Initialization. + bool Initialize(Error* error); + // Tracking of open sockets. void AddOpenSocket(std::shared_ptr socket); void AddClientSocket(std::shared_ptr socket); @@ -148,10 +151,14 @@ private: // We store the fd in the struct to avoid the cache miss reading the object. using SocketMap = std::unordered_map>; +#ifdef __linux__ + int m_epoll_fd = -1; +#else std::mutex m_poll_array_lock; pollfd* m_poll_array = nullptr; size_t m_poll_array_active_size = 0; size_t m_poll_array_max_size = 0; +#endif std::mutex m_open_sockets_lock; SocketMap m_open_sockets;