From e45df6fe783f3e613f1d381b8215e2873f3fb63d Mon Sep 17 00:00:00 2001 From: lmangani Date: Wed, 16 Oct 2024 08:28:58 +0000 Subject: [PATCH] try secrets --- src/httpserver_extension.cpp | 72 ++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/src/httpserver_extension.cpp b/src/httpserver_extension.cpp index 770bb0b..8e70b8d 100644 --- a/src/httpserver_extension.cpp +++ b/src/httpserver_extension.cpp @@ -8,6 +8,7 @@ #include "duckdb/common/atomic.hpp" #include "duckdb/common/exception/http_exception.hpp" #include "duckdb/common/allocator.hpp" +#include "duckdb/main/secret/secret_manager.hpp" #include #define CPPHTTPLIB_OPENSSL_SUPPORT @@ -25,6 +26,47 @@ using namespace duckdb_yyjson; // NOLINT namespace duckdb { +static unique_ptr CreateHTTPSecretFunction(ClientContext &, CreateSecretInput &input) { + // apply any overridden settings + vector prefix_paths; + auto result = make_uniq(prefix_paths, "http", "config", input.name); + for (const auto &named_param : input.options) { + auto lower_name = StringUtil::Lower(named_param.first); + + if (lower_name == "token") { + result->secret_map["token"] = named_param.second.ToString(); + } else { + throw InternalException("Unknown named parameter passed to CreateHTTPSecretFunction: " + lower_name); + } + } + + //! Set redact keys + result->redact_keys = {"token"}; + + return std::move(result); +} + +static void SetHTTPSecretParameters(CreateSecretFunction &function) { + function.named_parameters["token"] = LogicalType::VARCHAR; +} + +unique_ptr GetSecret(ClientContext &context, const string &secret_name) { + auto &secret_manager = SecretManager::Get(context); + auto transaction = CatalogTransaction::GetSystemCatalogTransaction(context); + // FIXME: this should be adjusted once the `GetSecretByName` API supports this + // use case + auto secret_entry = secret_manager.GetSecretByName(transaction, secret_name, "memory"); + if (secret_entry) { + return secret_entry; + } + secret_entry = secret_manager.GetSecretByName(transaction, secret_name, "local_file"); + if (secret_entry) { + return secret_entry; + } + return nullptr; +} + + struct HttpServerState { std::unique_ptr server; std::unique_ptr server_thread; @@ -152,6 +194,22 @@ std::string base64_decode(const std::string &in) { // Auth Check bool IsAuthenticated(const duckdb_httplib_openssl::Request& req) { + + /* TODO: No context + string secret_name = "__default_http"; + auto secret_entry = GetSecret(context, secret_name); + if (secret_entry) { + // secret found - read data + const auto &kv_secret = dynamic_cast(*secret_entry->secret); + string new_connection_info; + Value input_val = kv_secret.TryGetValue("token"); + if (!input_val.IsNull() ) { + return input_val.ToString(); + } + } + */ + + if (global_state.auth_token.empty()) { return true; // No authentication required if no token is set } @@ -412,6 +470,20 @@ static void HttpServerCleanup() { } static void LoadInternal(DatabaseInstance &instance) { + + SecretType secret_type; + secret_type.name = "http"; + secret_type.deserializer = KeyValueSecret::Deserialize; + secret_type.default_provider = "config"; + + ExtensionUtil::RegisterSecretType(instance, secret_type); + + CreateSecretFunction http_secret_function = {"http", "config", CreateHTTPSecretFunction}; + SetHTTPSecretParameters(http_secret_function); + ExtensionUtil::RegisterFunction(instance, http_secret_function); + + + auto httpserve_start = ScalarFunction("httpserve_start", {LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR}, LogicalType::VARCHAR,