diff --git a/src/http_client_extension.cpp b/src/http_client_extension.cpp index 131dba5..dee7518 100644 --- a/src/http_client_extension.cpp +++ b/src/http_client_extension.cpp @@ -125,6 +125,25 @@ static std::string GetHttpErrorMessage(const duckdb_httplib_openssl::Result &res return err_message; } +// Helper function to convert list of entries to a map of parameters. +template +static int ConvertListEntryToMap(const list_entry_t& list_entry, const duckdb::Vector& input, T& result) { + for (idx_t i = list_entry.offset; i < list_entry.offset + list_entry.length; i++) { + const auto &child_value = input.GetValue(i); + + Vector tmp(child_value); + auto &children = StructVector::GetEntries(tmp); + + if (children.size() == 2) { + auto name = FlatVector::GetData(*children[0]); + auto data = FlatVector::GetData(*children[1]); + std::string key = name->GetString(); + std::string val = data->GetString(); + result.emplace(key, val); + } + } + return result.size(); +} static void HTTPGetRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.data.size() == 1); @@ -149,6 +168,50 @@ static void HTTPGetRequestFunction(DataChunk &args, ExpressionState &state, Vect }); } +static void HTTPGetExRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.data.size() == 3); + + using STRING_TYPE = PrimitiveType; + using LENTRY_TYPE = PrimitiveType; + + auto &url_vector = args.data[0]; + auto &headers_vector = args.data[1]; + auto &headers_entry = ListVector::GetEntry(headers_vector); + auto ¶ms_vector = args.data[2]; + auto ¶ms_entry = ListVector::GetEntry(params_vector); + + GenericExecutor::ExecuteTernary( + url_vector, headers_vector, params_vector, result, args.size(), + [&](STRING_TYPE url, LENTRY_TYPE headers, LENTRY_TYPE params) { + std::string url_str = url.val.GetString(); + + // Use helper to setup client and parse URL + auto client_and_path = SetupHttpClient(url_str); + auto &client = client_and_path.first; + auto &path = client_and_path.second; + + // Prepare headers + duckdb_httplib_openssl::Headers header_map; + auto header_list = headers.val; + ConvertListEntryToMap(header_list, headers_entry, header_map); + + // Prepare params + duckdb_httplib_openssl::Params param_map; + auto params_list = params.val; + ConvertListEntryToMap(params_list, params_entry, param_map); + + // Make the POST request with headers and params + auto res = client.Get(path.c_str(), param_map, header_map); + if (res) { + std::string response = GetJsonResponse(res->status, res->reason, res->body); + return StringVector::AddString(result, response); + } else { + std::string response = GetJsonResponse(-1, GetHttpErrorMessage(res, "GET"), ""); + return StringVector::AddString(result, response); + } + }); +} + static void HTTPPostRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.data.size() == 3); @@ -173,20 +236,7 @@ static void HTTPPostRequestFunction(DataChunk &args, ExpressionState &state, Vec // Prepare headers duckdb_httplib_openssl::Headers header_map; auto header_list = headers.val; - for (idx_t i = header_list.offset; i < header_list.offset + header_list.length; i++) { - const auto &child_value = headers_entry.GetValue(i); - - Vector tmp(child_value); - auto &children = StructVector::GetEntries(tmp); - - if (children.size() == 2) { - auto name = FlatVector::GetData(*children[0]); - auto data = FlatVector::GetData(*children[1]); - std::string key = name->GetString(); - std::string val = data->GetString(); - header_map.emplace(key, val); - } - } + ConvertListEntryToMap(header_list, headers_entry, header_map); // Make the POST request with headers and body auto res = client.Post(path.c_str(), header_map, body.val.GetString(), "application/json"); @@ -204,6 +254,10 @@ static void HTTPPostRequestFunction(DataChunk &args, ExpressionState &state, Vec static void LoadInternal(DatabaseInstance &instance) { ScalarFunctionSet http_get("http_get"); http_get.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::JSON(), HTTPGetRequestFunction)); + http_get.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR), + LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR)}, + LogicalType::JSON(), HTTPGetExRequestFunction)); ExtensionUtil::RegisterFunction(instance, http_get); ScalarFunctionSet http_post("http_post"); diff --git a/test/sql/httpclient.test b/test/sql/httpclient.test index c658552..7744239 100644 --- a/test/sql/httpclient.test +++ b/test/sql/httpclient.test @@ -39,6 +39,38 @@ FROM ---- 200 OK httpbin.org +# Confirm the GET extension works with headers and params +query III +WITH __input AS ( + SELECT + http_get( + 'https://httpbin.org/delay/0', + headers => MAP { + 'accept': 'application/json', + }, + params => MAP { + 'limit': 10 + } + ) AS res +), +__response AS ( + SELECT + (res->>'status')::INT AS status, + (res->>'reason') AS reason, + unnest( from_json(((res->>'body')::JSON)->'headers', '{"Host": "VARCHAR"}') ) AS features + FROM + __input +) +SELECT + __response.status, + __response.reason, + __response.Host AS host +FROM + __response +; +---- +200 OK httpbin.org + # Confirm the POST extension works query III WITH __input AS (