HTTPDownloader: Move to common

This commit is contained in:
Connor McLaughlin
2022-03-26 23:05:02 +10:00
parent cb51ab7197
commit d5128a5ea9
15 changed files with 54 additions and 51 deletions

View File

@ -113,25 +113,6 @@ if(ENABLE_DISCORD_PRESENCE)
endif()
if(ENABLE_CHEEVOS)
target_sources(frontend-common PRIVATE
http_downloader.cpp
http_downloader.h
)
if(WIN32)
target_sources(frontend-common PRIVATE
http_downloader_winhttp.cpp
http_downloader_winhttp.h
)
elseif(NOT ANDROID)
target_sources(frontend-common PRIVATE
http_downloader_curl.cpp
http_downloader_curl.h
)
target_link_libraries(frontend-common PRIVATE
CURL::libcurl
)
endif()
target_sources(frontend-common PRIVATE
cheevos.cpp
cheevos.h

View File

@ -1,6 +1,7 @@
#include "cheevos.h"
#include "common/cd_image.h"
#include "common/file_system.h"
#include "common/http_downloader.h"
#include "common/log.h"
#include "common/md5_digest.h"
#include "common/platform.h"
@ -13,7 +14,6 @@
#include "core/host_display.h"
#include "core/system.h"
#include "fullscreen_ui.h"
#include "http_downloader.h"
#include "imgui_fullscreen.h"
#include "rapidjson/document.h"
#include "rc_url.h"

View File

@ -17,13 +17,6 @@
<ClCompile Include="d3d12_host_display.cpp" />
<ClCompile Include="game_list.cpp" />
<ClCompile Include="game_settings.cpp" />
<ClCompile Include="http_downloader.cpp" />
<ClCompile Include="http_downloader_uwp.cpp">
<ExcludedFromBuild Condition="'$(BuildingForUWP)'!='true'">true</ExcludedFromBuild>
</ClCompile>
<ClCompile Include="http_downloader_winhttp.cpp">
<ExcludedFromBuild Condition="'$(BuildingForUWP)'=='true'">true</ExcludedFromBuild>
</ClCompile>
<ClCompile Include="icon.cpp" />
<ClCompile Include="imgui_fullscreen.cpp" />
<ClCompile Include="imgui_impl_dx11.cpp" />
@ -69,13 +62,6 @@
<ClInclude Include="d3d12_host_display.h" />
<ClInclude Include="game_list.h" />
<ClInclude Include="game_settings.h" />
<ClInclude Include="http_downloader.h" />
<ClInclude Include="http_downloader_uwp.h">
<ExcludedFromBuild Condition="'$(BuildingForUWP)'!='true'">true</ExcludedFromBuild>
</ClInclude>
<ClInclude Include="http_downloader_winhttp.h">
<ExcludedFromBuild Condition="'$(BuildingForUWP)'=='true'">true</ExcludedFromBuild>
</ClInclude>
<ClInclude Include="icon.h" />
<ClInclude Include="imgui_fullscreen.h" />
<ClInclude Include="imgui_impl_dx11.h" />

View File

@ -28,15 +28,12 @@
<ClCompile Include="fullscreen_ui.cpp" />
<ClCompile Include="fullscreen_ui_progress_callback.cpp" />
<ClCompile Include="cheevos.cpp" />
<ClCompile Include="http_downloader.cpp" />
<ClCompile Include="http_downloader_winhttp.cpp" />
<ClCompile Include="input_overlay_ui.cpp" />
<ClCompile Include="game_database.cpp" />
<ClCompile Include="inhibit_screensaver.cpp" />
<ClCompile Include="xaudio2_audio_stream.cpp" />
<ClCompile Include="d3d12_host_display.cpp" />
<ClCompile Include="imgui_impl_dx12.cpp" />
<ClCompile Include="http_downloader_uwp.cpp" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="icon.h" />
@ -66,15 +63,12 @@
<ClInclude Include="fullscreen_ui.h" />
<ClInclude Include="fullscreen_ui_progress_callback.h" />
<ClInclude Include="cheevos.h" />
<ClInclude Include="http_downloader.h" />
<ClInclude Include="http_downloader_winhttp.h" />
<ClInclude Include="input_overlay_ui.h" />
<ClInclude Include="game_database.h" />
<ClInclude Include="inhibit_screensaver.h" />
<ClInclude Include="xaudio2_audio_stream.h" />
<ClInclude Include="d3d12_host_display.h" />
<ClInclude Include="imgui_impl_dx12.h" />
<ClInclude Include="http_downloader_uwp.h" />
</ItemGroup>
<ItemGroup>
<None Include="font_roboto_regular.inl" />

View File

@ -1,186 +0,0 @@
#include "http_downloader.h"
#include "common/assert.h"
#include "common/log.h"
#include "common/timer.h"
Log_SetChannel(HTTPDownloader);
static constexpr float DEFAULT_TIMEOUT_IN_SECONDS = 30;
static constexpr u32 DEFAULT_MAX_ACTIVE_REQUESTS = 4;
namespace FrontendCommon {
const char HTTPDownloader::DEFAULT_USER_AGENT[] =
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:85.0) Gecko/20100101 Firefox/85.0";
HTTPDownloader::HTTPDownloader()
: m_timeout(DEFAULT_TIMEOUT_IN_SECONDS), m_max_active_requests(DEFAULT_MAX_ACTIVE_REQUESTS)
{
}
HTTPDownloader::~HTTPDownloader() = default;
void HTTPDownloader::SetTimeout(float timeout)
{
m_timeout = timeout;
}
void HTTPDownloader::SetMaxActiveRequests(u32 max_active_requests)
{
Assert(max_active_requests > 0);
m_max_active_requests = max_active_requests;
}
void HTTPDownloader::CreateRequest(std::string url, Request::Callback callback)
{
Request* req = InternalCreateRequest();
req->parent = this;
req->type = Request::Type::Get;
req->url = std::move(url);
req->callback = std::move(callback);
req->start_time = Common::Timer::GetValue();
std::unique_lock<std::mutex> lock(m_pending_http_request_lock);
if (LockedGetActiveRequestCount() < m_max_active_requests)
{
if (!StartRequest(req))
return;
}
LockedAddRequest(req);
}
void HTTPDownloader::CreatePostRequest(std::string url, std::string post_data, Request::Callback callback)
{
Request* req = InternalCreateRequest();
req->parent = this;
req->type = Request::Type::Post;
req->url = std::move(url);
req->post_data = std::move(post_data);
req->callback = std::move(callback);
req->start_time = Common::Timer::GetValue();
std::unique_lock<std::mutex> lock(m_pending_http_request_lock);
if (LockedGetActiveRequestCount() < m_max_active_requests)
{
if (!StartRequest(req))
return;
}
LockedAddRequest(req);
}
void HTTPDownloader::LockedPollRequests(std::unique_lock<std::mutex>& lock)
{
if (m_pending_http_requests.empty())
return;
InternalPollRequests();
const Common::Timer::Value current_time = Common::Timer::GetValue();
u32 active_requests = 0;
u32 unstarted_requests = 0;
for (size_t index = 0; index < m_pending_http_requests.size();)
{
Request* req = m_pending_http_requests[index];
if (req->state == Request::State::Pending)
{
unstarted_requests++;
index++;
continue;
}
if (req->state == Request::State::Started && current_time >= req->start_time &&
Common::Timer::ConvertValueToSeconds(current_time - req->start_time) >= m_timeout)
{
// request timed out
Log_ErrorPrintf("Request for '%s' timed out", req->url.c_str());
req->state.store(Request::State::Cancelled);
m_pending_http_requests.erase(m_pending_http_requests.begin() + index);
lock.unlock();
req->callback(-1, Request::Data());
CloseRequest(req);
lock.lock();
continue;
}
if (req->state != Request::State::Complete)
{
active_requests++;
index++;
continue;
}
// request complete
Log_VerbosePrintf("Request for '%s' complete, returned status code %u and %zu bytes", req->url.c_str(),
req->status_code, req->data.size());
m_pending_http_requests.erase(m_pending_http_requests.begin() + index);
// run callback with lock unheld
lock.unlock();
req->callback(req->status_code, req->data);
CloseRequest(req);
lock.lock();
}
// start new requests when we finished some
if (unstarted_requests > 0 && active_requests < m_max_active_requests)
{
for (size_t index = 0; index < m_pending_http_requests.size();)
{
Request* req = m_pending_http_requests[index];
if (req->state != Request::State::Pending)
{
index++;
continue;
}
if (!StartRequest(req))
{
m_pending_http_requests.erase(m_pending_http_requests.begin() + index);
continue;
}
active_requests++;
index++;
if (active_requests >= m_max_active_requests)
break;
}
}
}
void HTTPDownloader::PollRequests()
{
std::unique_lock<std::mutex> lock(m_pending_http_request_lock);
LockedPollRequests(lock);
}
void HTTPDownloader::WaitForAllRequests()
{
std::unique_lock<std::mutex> lock(m_pending_http_request_lock);
while (!m_pending_http_requests.empty())
LockedPollRequests(lock);
}
void HTTPDownloader::LockedAddRequest(Request* request)
{
m_pending_http_requests.push_back(request);
}
u32 HTTPDownloader::LockedGetActiveRequestCount()
{
u32 count = 0;
for (Request* req : m_pending_http_requests)
{
if (req->state == Request::State::Started || req->state == Request::State::Receiving)
count++;
}
return count;
}
} // namespace FrontendCommon

View File

@ -1,85 +0,0 @@
#pragma once
#include "common/types.h"
#include <atomic>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
namespace FrontendCommon {
class HTTPDownloader
{
public:
enum : s32
{
HTTP_OK = 200
};
struct Request
{
using Data = std::vector<u8>;
using Callback = std::function<void(s32 status_code, const Data& data)>;
enum class Type
{
Get,
Post,
};
enum class State
{
Pending,
Cancelled,
Started,
Receiving,
Complete,
};
HTTPDownloader* parent;
Callback callback;
std::string url;
std::string post_data;
Data data;
u64 start_time;
s32 status_code = 0;
u32 content_length = 0;
Type type = Type::Get;
std::atomic<State> state{State::Pending};
};
HTTPDownloader();
virtual ~HTTPDownloader();
static std::unique_ptr<HTTPDownloader> Create(const char* user_agent = DEFAULT_USER_AGENT);
void SetTimeout(float timeout);
void SetMaxActiveRequests(u32 max_active_requests);
void CreateRequest(std::string url, Request::Callback callback);
void CreatePostRequest(std::string url, std::string post_data, Request::Callback callback);
void PollRequests();
void WaitForAllRequests();
static const char DEFAULT_USER_AGENT[];
protected:
virtual Request* InternalCreateRequest() = 0;
virtual void InternalPollRequests() = 0;
virtual bool StartRequest(Request* request) = 0;
virtual void CloseRequest(Request* request) = 0;
void LockedAddRequest(Request* request);
u32 LockedGetActiveRequestCount();
void LockedPollRequests(std::unique_lock<std::mutex>& lock);
float m_timeout;
u32 m_max_active_requests;
std::mutex m_pending_http_request_lock;
std::vector<Request*> m_pending_http_requests;
};
} // namespace FrontendCommon

View File

@ -1,162 +0,0 @@
#include "http_downloader_curl.h"
#include "common/assert.h"
#include "common/log.h"
#include "common/string_util.h"
#include "common/timer.h"
#include <algorithm>
#include <functional>
#include <pthread.h>
#include <signal.h>
Log_SetChannel(HTTPDownloaderCurl);
namespace FrontendCommon {
HTTPDownloaderCurl::HTTPDownloaderCurl() : HTTPDownloader() {}
HTTPDownloaderCurl::~HTTPDownloaderCurl() = default;
std::unique_ptr<HTTPDownloader> HTTPDownloader::Create(const char* user_agent)
{
std::unique_ptr<HTTPDownloaderCurl> instance(std::make_unique<HTTPDownloaderCurl>());
if (!instance->Initialize(user_agent))
return {};
return instance;
}
static bool s_curl_initialized = false;
static std::once_flag s_curl_initialized_once_flag;
bool HTTPDownloaderCurl::Initialize(const char* user_agent)
{
if (!s_curl_initialized)
{
std::call_once(s_curl_initialized_once_flag, []() {
s_curl_initialized = curl_global_init(CURL_GLOBAL_ALL) == CURLE_OK;
if (s_curl_initialized)
{
std::atexit([]() {
curl_global_cleanup();
s_curl_initialized = false;
});
}
});
if (!s_curl_initialized)
{
Log_ErrorPrint("curl_global_init() failed");
return false;
}
}
m_user_agent = user_agent;
m_thread_pool = std::make_unique<cb::ThreadPool>(m_max_active_requests);
return true;
}
size_t HTTPDownloaderCurl::WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata)
{
Request* req = static_cast<Request*>(userdata);
const size_t current_size = req->data.size();
const size_t transfer_size = size * nmemb;
const size_t new_size = current_size + transfer_size;
req->data.resize(new_size);
std::memcpy(&req->data[current_size], ptr, transfer_size);
return nmemb;
}
void HTTPDownloaderCurl::ProcessRequest(Request* req)
{
std::unique_lock<std::mutex> cancel_lock(m_cancel_mutex);
if (req->closed.load())
return;
cancel_lock.unlock();
// Apparently OpenSSL can fire SIGPIPE...
sigset_t old_block_mask = {};
sigset_t new_block_mask = {};
sigemptyset(&old_block_mask);
sigemptyset(&new_block_mask);
sigaddset(&new_block_mask, SIGPIPE);
if (pthread_sigmask(SIG_BLOCK, &new_block_mask, &old_block_mask) != 0)
Log_WarningPrint("Failed to block SIGPIPE");
req->start_time = Common::Timer::GetValue();
int ret = curl_easy_perform(req->handle);
if (ret == CURLE_OK)
{
long response_code = 0;
curl_easy_getinfo(req->handle, CURLINFO_RESPONSE_CODE, &response_code);
req->status_code = static_cast<s32>(response_code);
Log_DevPrintf("Request for '%s' returned status code %d and %zu bytes", req->url.c_str(), req->status_code,
req->data.size());
}
else
{
Log_ErrorPrintf("Request for '%s' returned %d", req->url.c_str(), ret);
}
curl_easy_cleanup(req->handle);
if (pthread_sigmask(SIG_UNBLOCK, &new_block_mask, &old_block_mask) != 0)
Log_WarningPrint("Failed to unblock SIGPIPE");
cancel_lock.lock();
req->state = Request::State::Complete;
if (req->closed.load())
delete req;
else
req->closed.store(true);
}
HTTPDownloader::Request* HTTPDownloaderCurl::InternalCreateRequest()
{
Request* req = new Request();
req->handle = curl_easy_init();
if (!req->handle)
{
delete req;
return nullptr;
}
return req;
}
void HTTPDownloaderCurl::InternalPollRequests()
{
// noop - uses thread pool
}
bool HTTPDownloaderCurl::StartRequest(HTTPDownloader::Request* request)
{
Request* req = static_cast<Request*>(request);
curl_easy_setopt(req->handle, CURLOPT_URL, request->url.c_str());
curl_easy_setopt(req->handle, CURLOPT_USERAGENT, m_user_agent.c_str());
curl_easy_setopt(req->handle, CURLOPT_WRITEFUNCTION, &HTTPDownloaderCurl::WriteCallback);
curl_easy_setopt(req->handle, CURLOPT_WRITEDATA, req);
curl_easy_setopt(req->handle, CURLOPT_NOSIGNAL, 1);
if (request->type == Request::Type::Post)
{
curl_easy_setopt(req->handle, CURLOPT_POST, 1L);
curl_easy_setopt(req->handle, CURLOPT_POSTFIELDS, request->post_data.c_str());
}
Log_DevPrintf("Started HTTP request for '%s'", req->url.c_str());
req->state = Request::State::Started;
req->start_time = Common::Timer::GetValue();
m_thread_pool->Schedule(std::bind(&HTTPDownloaderCurl::ProcessRequest, this, req));
return true;
}
void HTTPDownloaderCurl::CloseRequest(HTTPDownloader::Request* request)
{
std::unique_lock<std::mutex> cancel_lock(m_cancel_mutex);
Request* req = static_cast<Request*>(request);
if (req->closed.load())
delete req;
else
req->closed.store(true);
}
} // namespace FrontendCommon

View File

@ -1,40 +0,0 @@
#pragma once
#include "http_downloader.h"
#include "common/thirdparty/thread_pool.h"
#include <atomic>
#include <memory>
#include <mutex>
#include <curl/curl.h>
namespace FrontendCommon {
class HTTPDownloaderCurl final : public HTTPDownloader
{
public:
HTTPDownloaderCurl();
~HTTPDownloaderCurl() override;
bool Initialize(const char* user_agent);
protected:
Request* InternalCreateRequest() override;
void InternalPollRequests() override;
bool StartRequest(HTTPDownloader::Request* request) override;
void CloseRequest(HTTPDownloader::Request* request) override;
private:
struct Request : HTTPDownloader::Request
{
CURL* handle = nullptr;
std::atomic_bool closed{false};
};
static size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata);
void ProcessRequest(Request* req);
std::string m_user_agent;
std::unique_ptr<cb::ThreadPool> m_thread_pool;
std::mutex m_cancel_mutex;
};
} // namespace FrontendCommon

View File

@ -1,165 +0,0 @@
#include "http_downloader_uwp.h"
#include "common/assert.h"
#include "common/log.h"
#include "common/string_util.h"
#include "common/timer.h"
#include <algorithm>
Log_SetChannel(HTTPDownloaderWinHttp);
#include <winrt/Windows.Foundation.h>
#include <winrt/Windows.Storage.Streams.h>
#include <winrt/Windows.Web.Http.Headers.h>
using namespace winrt::Windows::Foundation;
using namespace winrt::Windows::Web::Http;
namespace FrontendCommon {
HTTPDownloaderUWP::HTTPDownloaderUWP(std::string user_agent) : HTTPDownloader(), m_user_agent(std::move(user_agent)) {}
HTTPDownloaderUWP::~HTTPDownloaderUWP() = default;
std::unique_ptr<HTTPDownloader> HTTPDownloader::Create(const char* user_agent)
{
std::string user_agent_str;
if (user_agent)
user_agent_str = user_agent;
return std::make_unique<HTTPDownloaderUWP>(user_agent ? std::string(user_agent) : std::string());
}
HTTPDownloader::Request* HTTPDownloaderUWP::InternalCreateRequest()
{
Request* req = new Request();
return req;
}
void HTTPDownloaderUWP::InternalPollRequests()
{
// noop - uses async
}
bool HTTPDownloaderUWP::StartRequest(HTTPDownloader::Request* request)
{
Request* req = static_cast<Request*>(request);
try
{
const std::wstring url_wide(StringUtil::UTF8StringToWideString(req->url));
const Uri uri(url_wide);
if (!m_user_agent.empty() &&
!req->client.DefaultRequestHeaders().UserAgent().TryParseAdd(StringUtil::UTF8StringToWideString(m_user_agent)))
{
Log_WarningPrintf("Failed to set user agent to '%s'", m_user_agent.c_str());
}
if (req->type == Request::Type::Post)
{
const HttpStringContent post_content(StringUtil::UTF8StringToWideString(req->post_data),
winrt::Windows::Storage::Streams::UnicodeEncoding::Utf8,
L"application/x-www-form-urlencoded");
req->request_async = req->client.PostAsync(uri, post_content);
}
else
{
req->request_async = req->client.GetAsync(uri);
}
req->request_async.Completed(
[req](const IAsyncOperationWithProgress<HttpResponseMessage, HttpProgress>& operation, AsyncStatus status) {
if (status == AsyncStatus::Completed)
{
Log_DevPrintf("Request for '%s' completed start portion", req->url.c_str());
try
{
req->state.store(Request::State::Receiving);
req->start_time = Common::Timer::GetValue();
const HttpResponseMessage response(req->request_async.get());
req->status_code = static_cast<s32>(response.StatusCode());
const IHttpContent content(response.Content());
req->receive_async = content.ReadAsBufferAsync();
req->receive_async.Completed(
[req](
const IAsyncOperationWithProgress<winrt::Windows::Storage::Streams::IBuffer, uint64_t>& inner_operation,
AsyncStatus inner_status) {
if (inner_status == AsyncStatus::Completed)
{
const winrt::Windows::Storage::Streams::IBuffer buffer(inner_operation.get());
if (buffer && buffer.Length() > 0)
{
req->data.resize(buffer.Length());
std::memcpy(req->data.data(), buffer.data(), req->data.size());
}
Log_DevPrintf("End of request '%s', %zu bytes received", req->url.c_str(), req->data.size());
req->state.store(Request::State::Complete);
}
else if (inner_status == AsyncStatus::Canceled)
{
// don't do anything, the request has been freed
}
else
{
Log_ErrorPrintf("Request for '%s' failed during recieve phase: %08X", req->url.c_str(),
inner_operation.ErrorCode().value);
req->status_code = -1;
req->state.store(Request::State::Complete);
}
});
}
catch (const winrt::hresult_error& err)
{
Log_ErrorPrintf("Failed to receive HTTP request for '%s': %08X %s", req->url.c_str(), err.code(),
StringUtil::WideStringToUTF8String(err.message()).c_str());
req->status_code = -1;
req->state.store(Request::State::Complete);
}
req->receive_async = nullptr;
}
else if (status == AsyncStatus::Canceled)
{
// don't do anything, the request has been freed
}
else
{
Log_ErrorPrintf("Request for '%s' failed during start phase: %08X", req->url.c_str(),
operation.ErrorCode().value);
req->status_code = -1;
req->state.store(Request::State::Complete);
}
req->request_async = nullptr;
});
}
catch (const winrt::hresult_error& err)
{
Log_ErrorPrintf("Failed to start HTTP request for '%s': %08X %s", req->url.c_str(), err.code(),
StringUtil::WideStringToUTF8String(err.message()).c_str());
req->callback(-1, req->data);
delete req;
return false;
}
Log_DevPrintf("Started HTTP request for '%s'", req->url.c_str());
req->state = Request::State::Started;
req->start_time = Common::Timer::GetValue();
return true;
}
void HTTPDownloaderUWP::CloseRequest(HTTPDownloader::Request* request)
{
Request* req = static_cast<Request*>(request);
if (req->request_async)
req->request_async.Cancel();
if (req->receive_async)
req->receive_async.Cancel();
req->client.Close();
delete req;
}
} // namespace FrontendCommon

View File

@ -1,37 +0,0 @@
#pragma once
#include "http_downloader.h"
#include "common/windows_headers.h"
#include <winrt/windows.Web.Http.h>
namespace FrontendCommon {
class HTTPDownloaderUWP final : public HTTPDownloader
{
public:
HTTPDownloaderUWP(std::string user_agent);
~HTTPDownloaderUWP() override;
protected:
Request* InternalCreateRequest() override;
void InternalPollRequests() override;
bool StartRequest(HTTPDownloader::Request* request) override;
void CloseRequest(HTTPDownloader::Request* request) override;
private:
struct Request : HTTPDownloader::Request
{
std::wstring object_name;
winrt::Windows::Web::Http::HttpClient client;
winrt::Windows::Foundation::IAsyncOperationWithProgress<winrt::Windows::Web::Http::HttpResponseMessage,
winrt::Windows::Web::Http::HttpProgress>
request_async{nullptr};
winrt::Windows::Foundation::IAsyncOperationWithProgress<winrt::Windows::Storage::Streams::IBuffer, uint64_t>
receive_async{};
};
std::string m_user_agent;
};
} // namespace FrontendCommon

View File

@ -1,300 +0,0 @@
#include "http_downloader_winhttp.h"
#include "common/assert.h"
#include "common/log.h"
#include "common/string_util.h"
#include "common/timer.h"
#include <VersionHelpers.h>
#include <algorithm>
Log_SetChannel(HTTPDownloaderWinHttp);
#pragma comment(lib, "winhttp.lib")
namespace FrontendCommon {
HTTPDownloaderWinHttp::HTTPDownloaderWinHttp() : HTTPDownloader() {}
HTTPDownloaderWinHttp::~HTTPDownloaderWinHttp()
{
if (m_hSession)
{
WinHttpSetStatusCallback(m_hSession, nullptr, WINHTTP_CALLBACK_FLAG_ALL_NOTIFICATIONS, NULL);
WinHttpCloseHandle(m_hSession);
}
}
std::unique_ptr<HTTPDownloader> HTTPDownloader::Create(const char* user_agent)
{
std::unique_ptr<HTTPDownloaderWinHttp> instance(std::make_unique<HTTPDownloaderWinHttp>());
if (!instance->Initialize(user_agent))
return {};
return instance;
}
bool HTTPDownloaderWinHttp::Initialize(const char* user_agent)
{
// WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY is not supported before Win8.1.
const DWORD dwAccessType =
IsWindows8Point1OrGreater() ? WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY : WINHTTP_ACCESS_TYPE_DEFAULT_PROXY;
m_hSession = WinHttpOpen(StringUtil::UTF8StringToWideString(user_agent).c_str(), dwAccessType, nullptr, nullptr,
WINHTTP_FLAG_ASYNC);
if (m_hSession == NULL)
{
Log_ErrorPrintf("WinHttpOpen() failed: %u", GetLastError());
return false;
}
const DWORD notification_flags = WINHTTP_CALLBACK_FLAG_ALL_COMPLETIONS | WINHTTP_CALLBACK_FLAG_REQUEST_ERROR |
WINHTTP_CALLBACK_FLAG_HANDLES | WINHTTP_CALLBACK_FLAG_SECURE_FAILURE;
if (WinHttpSetStatusCallback(m_hSession, HTTPStatusCallback, notification_flags, NULL) ==
WINHTTP_INVALID_STATUS_CALLBACK)
{
Log_ErrorPrintf("WinHttpSetStatusCallback() failed: %u", GetLastError());
return false;
}
return true;
}
void CALLBACK HTTPDownloaderWinHttp::HTTPStatusCallback(HINTERNET hRequest, DWORD_PTR dwContext, DWORD dwInternetStatus,
LPVOID lpvStatusInformation, DWORD dwStatusInformationLength)
{
Request* req = reinterpret_cast<Request*>(dwContext);
switch (dwInternetStatus)
{
case WINHTTP_CALLBACK_STATUS_HANDLE_CREATED:
return;
case WINHTTP_CALLBACK_STATUS_HANDLE_CLOSING:
{
if (!req)
return;
DebugAssert(hRequest == req->hRequest);
HTTPDownloaderWinHttp* parent = static_cast<HTTPDownloaderWinHttp*>(req->parent);
std::unique_lock<std::mutex> lock(parent->m_pending_http_request_lock);
Assert(std::none_of(parent->m_pending_http_requests.begin(), parent->m_pending_http_requests.end(),
[req](HTTPDownloader::Request* it) { return it == req; }));
// we can clean up the connection as well
DebugAssert(req->hConnection != NULL);
WinHttpCloseHandle(req->hConnection);
delete req;
return;
}
case WINHTTP_CALLBACK_STATUS_REQUEST_ERROR:
{
const WINHTTP_ASYNC_RESULT* res = reinterpret_cast<const WINHTTP_ASYNC_RESULT*>(lpvStatusInformation);
Log_ErrorPrintf("WinHttp async function %p returned error %u", res->dwResult, res->dwError);
req->status_code = -1;
req->state.store(Request::State::Complete);
return;
}
case WINHTTP_CALLBACK_STATUS_SENDREQUEST_COMPLETE:
{
Log_DevPrintf("SendRequest complete");
if (!WinHttpReceiveResponse(hRequest, nullptr))
{
Log_ErrorPrintf("WinHttpReceiveResponse() failed: %u", GetLastError());
req->status_code = -1;
req->state.store(Request::State::Complete);
}
return;
}
case WINHTTP_CALLBACK_STATUS_HEADERS_AVAILABLE:
{
Log_DevPrintf("Headers available");
DWORD buffer_size = sizeof(req->status_code);
if (!WinHttpQueryHeaders(hRequest, WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER,
WINHTTP_HEADER_NAME_BY_INDEX, &req->status_code, &buffer_size, WINHTTP_NO_HEADER_INDEX))
{
Log_ErrorPrintf("WinHttpQueryHeaders() for status code failed: %u", GetLastError());
req->status_code = -1;
req->state.store(Request::State::Complete);
return;
}
buffer_size = sizeof(req->content_length);
if (!WinHttpQueryHeaders(hRequest, WINHTTP_QUERY_CONTENT_LENGTH | WINHTTP_QUERY_FLAG_NUMBER,
WINHTTP_HEADER_NAME_BY_INDEX, &req->content_length, &buffer_size,
WINHTTP_NO_HEADER_INDEX))
{
if (GetLastError() != ERROR_WINHTTP_HEADER_NOT_FOUND)
Log_WarningPrintf("WinHttpQueryHeaders() for content length failed: %u", GetLastError());
req->content_length = 0;
}
Log_DevPrintf("Status code %d, content-length is %u", req->status_code, req->content_length);
req->data.reserve(req->content_length);
req->state = Request::State::Receiving;
// start reading
if (!WinHttpQueryDataAvailable(hRequest, nullptr) && GetLastError() != ERROR_IO_PENDING)
{
Log_ErrorPrintf("WinHttpQueryDataAvailable() failed: %u", GetLastError());
req->status_code = -1;
req->state.store(Request::State::Complete);
}
return;
}
case WINHTTP_CALLBACK_STATUS_DATA_AVAILABLE:
{
DWORD bytes_available;
std::memcpy(&bytes_available, lpvStatusInformation, sizeof(bytes_available));
if (bytes_available == 0)
{
// end of request
Log_DevPrintf("End of request '%s', %zu bytes received", req->url.c_str(), req->data.size());
req->state.store(Request::State::Complete);
return;
}
// start the transfer
Log_DevPrintf("%u bytes available", bytes_available);
req->io_position = static_cast<u32>(req->data.size());
req->data.resize(req->io_position + bytes_available);
if (!WinHttpReadData(hRequest, req->data.data() + req->io_position, bytes_available, nullptr) &&
GetLastError() != ERROR_IO_PENDING)
{
Log_ErrorPrintf("WinHttpReadData() failed: %u", GetLastError());
req->status_code = -1;
req->state.store(Request::State::Complete);
}
return;
}
case WINHTTP_CALLBACK_STATUS_READ_COMPLETE:
{
Log_DevPrintf("Read of %u complete", dwStatusInformationLength);
const u32 new_size = req->io_position + dwStatusInformationLength;
Assert(new_size <= req->data.size());
req->data.resize(new_size);
req->start_time = Common::Timer::GetValue();
if (!WinHttpQueryDataAvailable(hRequest, nullptr) && GetLastError() != ERROR_IO_PENDING)
{
Log_ErrorPrintf("WinHttpQueryDataAvailable() failed: %u", GetLastError());
req->status_code = -1;
req->state.store(Request::State::Complete);
}
return;
}
default:
// unhandled, ignore
return;
}
}
HTTPDownloader::Request* HTTPDownloaderWinHttp::InternalCreateRequest()
{
Request* req = new Request();
return req;
}
void HTTPDownloaderWinHttp::InternalPollRequests()
{
// noop - it uses windows's worker threads
}
bool HTTPDownloaderWinHttp::StartRequest(HTTPDownloader::Request* request)
{
Request* req = static_cast<Request*>(request);
std::wstring host_name;
host_name.resize(req->url.size());
req->object_name.resize(req->url.size());
URL_COMPONENTSW uc = {};
uc.dwStructSize = sizeof(uc);
uc.lpszHostName = host_name.data();
uc.dwHostNameLength = static_cast<DWORD>(host_name.size());
uc.lpszUrlPath = req->object_name.data();
uc.dwUrlPathLength = static_cast<DWORD>(req->object_name.size());
const std::wstring url_wide(StringUtil::UTF8StringToWideString(req->url));
if (!WinHttpCrackUrl(url_wide.c_str(), static_cast<DWORD>(url_wide.size()), 0, &uc))
{
Log_ErrorPrintf("WinHttpCrackUrl() failed: %u", GetLastError());
req->callback(-1, req->data);
delete req;
return false;
}
host_name.resize(uc.dwHostNameLength);
req->object_name.resize(uc.dwUrlPathLength);
req->hConnection = WinHttpConnect(m_hSession, host_name.c_str(), uc.nPort, 0);
if (!req->hConnection)
{
Log_ErrorPrintf("Failed to start HTTP request for '%s': %u", req->url.c_str(), GetLastError());
req->callback(-1, req->data);
delete req;
return false;
}
const DWORD request_flags = uc.nScheme == INTERNET_SCHEME_HTTPS ? WINHTTP_FLAG_SECURE : 0;
req->hRequest =
WinHttpOpenRequest(req->hConnection, (req->type == HTTPDownloader::Request::Type::Post) ? L"POST" : L"GET",
req->object_name.c_str(), NULL, NULL, NULL, request_flags);
if (!req->hRequest)
{
Log_ErrorPrintf("WinHttpOpenRequest() failed: %u", GetLastError());
WinHttpCloseHandle(req->hConnection);
return false;
}
BOOL result;
if (req->type == HTTPDownloader::Request::Type::Post)
{
const std::wstring_view additional_headers(L"Content-Type: application/x-www-form-urlencoded\r\n");
result = WinHttpSendRequest(req->hRequest, additional_headers.data(), static_cast<DWORD>(additional_headers.size()),
req->post_data.data(), static_cast<DWORD>(req->post_data.size()),
static_cast<DWORD>(req->post_data.size()), reinterpret_cast<DWORD_PTR>(req));
}
else
{
result = WinHttpSendRequest(req->hRequest, WINHTTP_NO_ADDITIONAL_HEADERS, 0, WINHTTP_NO_REQUEST_DATA, 0, 0,
reinterpret_cast<DWORD_PTR>(req));
}
if (!result && GetLastError() != ERROR_IO_PENDING)
{
Log_ErrorPrintf("WinHttpSendRequest() failed: %u", GetLastError());
req->status_code = -1;
req->state.store(Request::State::Complete);
}
Log_DevPrintf("Started HTTP request for '%s'", req->url.c_str());
req->state = Request::State::Started;
req->start_time = Common::Timer::GetValue();
return true;
}
void HTTPDownloaderWinHttp::CloseRequest(HTTPDownloader::Request* request)
{
Request* req = static_cast<Request*>(request);
if (req->hRequest != NULL)
{
// req will be freed by the callback.
// the callback can fire immediately here if there's nothing running async, so don't touch req afterwards
WinHttpCloseHandle(req->hRequest);
return;
}
if (req->hConnection != NULL)
WinHttpCloseHandle(req->hConnection);
delete req;
}
} // namespace FrontendCommon

View File

@ -1,39 +0,0 @@
#pragma once
#include "http_downloader.h"
#include "common/windows_headers.h"
#include <winhttp.h>
namespace FrontendCommon {
class HTTPDownloaderWinHttp final : public HTTPDownloader
{
public:
HTTPDownloaderWinHttp();
~HTTPDownloaderWinHttp() override;
bool Initialize(const char* user_agent);
protected:
Request* InternalCreateRequest() override;
void InternalPollRequests() override;
bool StartRequest(HTTPDownloader::Request* request) override;
void CloseRequest(HTTPDownloader::Request* request) override;
private:
struct Request : HTTPDownloader::Request
{
std::wstring object_name;
HINTERNET hConnection = NULL;
HINTERNET hRequest = NULL;
u32 io_position = 0;
};
static void CALLBACK HTTPStatusCallback(HINTERNET hInternet, DWORD_PTR dwContext, DWORD dwInternetStatus,
LPVOID lpvStatusInformation, DWORD dwStatusInformationLength);
HINTERNET m_hSession = NULL;
};
} // namespace FrontendCommon