From e083fe896b993e73f385c7a028a8b389cc6e9f9d Mon Sep 17 00:00:00 2001 From: Christian Roessner Date: Thu, 26 Sep 2024 18:20:36 +0200 Subject: [PATCH] Fix: Return http.Client from registerDynamicLoader functions Modified registerDynamicLoader functions to return an http.Client and ensured idle connections are properly closed. This enhances resource management and prevents potential issues with hanging HTTP connections. Signed-off-by: Christian Roessner --- server/backend/lua.go | 12 +++++++++--- server/core/hydra.go | 12 ++++++++++++ server/lualib/action/action.go | 10 +++++++--- server/lualib/callback/callback.go | 13 ++++++++++--- server/lualib/feature/feature.go | 12 +++++++++--- server/lualib/filter/filter.go | 12 +++++++++--- server/lualib/loader.go | 9 ++++++--- server/util/util.go | 11 +++++++++++ 8 files changed, 73 insertions(+), 18 deletions(-) diff --git a/server/backend/lua.go b/server/backend/lua.go index 220df3d8..47045c7f 100644 --- a/server/backend/lua.go +++ b/server/backend/lua.go @@ -18,6 +18,7 @@ package backend import ( "context" "fmt" + "net/http" "time" "github.com/croessner/nauthilus/server/config" @@ -128,7 +129,7 @@ func LuaMainWorker(ctx context.Context) { // - luaRequest: The *LuaRequest object containing the request parameters. // // Returns: None. -func registerDynamicLoader(L *lua.LState, ctx context.Context, luaRequest *LuaRequest) { +func registerDynamicLoader(L *lua.LState, ctx context.Context, luaRequest *LuaRequest) (httpClient *http.Client) { dynamicLoader := L.NewFunction(func(L *lua.LState) int { modName := L.CheckString(1) @@ -137,13 +138,15 @@ func registerDynamicLoader(L *lua.LState, ctx context.Context, luaRequest *LuaRe return 0 } - lualib.RegisterCommonLuaLibraries(L, modName, registry) + httpClient = lualib.RegisterCommonLuaLibraries(L, modName, registry) registerModule(L, ctx, luaRequest, modName, registry) return 0 }) L.SetGlobal("dynamic_loader", dynamicLoader) + + return httpClient } // registerModule registers a module in the Lua state based on the given modName. @@ -220,7 +223,10 @@ func handleLuaRequest(ctx context.Context, luaRequest *LuaRequest, compiledScrip global.LuaBackendResultAttributes, ) - registerDynamicLoader(L, ctx, luaRequest) + httpClient := registerDynamicLoader(L, ctx, luaRequest) + + defer util.CloseIdleHTTPConnections(httpClient) + setupGlobals(luaRequest, L, logs) request := L.NewTable() diff --git a/server/core/hydra.go b/server/core/hydra.go index 9eca51c5..4dbc5bc7 100644 --- a/server/core/hydra.go +++ b/server/core/hydra.go @@ -1131,6 +1131,8 @@ func loginGETHandler(ctx *gin.Context) { apiConfig.initialize() + defer util.CloseIdleHTTPConnections(apiConfig.httpClient) + apiConfig.challenge = loginChallenge apiConfig.csrfToken = ctx.GetString(global.CtxCSRFTokenKey) @@ -1663,6 +1665,8 @@ func loginPOSTHandler(ctx *gin.Context) { apiConfig.initialize() + defer util.CloseIdleHTTPConnections(apiConfig.httpClient) + apiConfig.challenge = loginChallenge auth, err := initializeAuthLogin(ctx) @@ -2214,6 +2218,8 @@ func consentGETHandler(ctx *gin.Context) { apiConfig.initialize() + defer util.CloseIdleHTTPConnections(apiConfig.httpClient) + apiConfig.challenge = consentChallenge apiConfig.csrfToken = ctx.GetString(global.CtxCSRFTokenKey) @@ -2456,6 +2462,8 @@ func consentPOSTHandler(ctx *gin.Context) { apiConfig.initialize() + defer util.CloseIdleHTTPConnections(apiConfig.httpClient) + apiConfig.challenge = consentChallenge apiConfig.consentRequest, httpResponse, err = apiConfig.apiClient.OAuth2API.GetOAuth2ConsentRequest(ctx).ConsentChallenge( @@ -2564,6 +2572,8 @@ func logoutGETHandler(ctx *gin.Context) { apiConfig.initialize() + defer util.CloseIdleHTTPConnections(apiConfig.httpClient) + apiConfig.challenge = logoutChallenge apiConfig.csrfToken = ctx.GetString(global.CtxCSRFTokenKey) @@ -2712,6 +2722,8 @@ func logoutPOSTHandler(ctx *gin.Context) { apiConfig.initialize() + defer util.CloseIdleHTTPConnections(apiConfig.httpClient) + apiConfig.challenge = logoutChallenge apiConfig.logoutRequest, httpResponse, err = apiConfig.apiClient.OAuth2API.GetOAuth2LogoutRequest(ctx).LogoutChallenge( diff --git a/server/lualib/action/action.go b/server/lualib/action/action.go index 9a2d37ac..dccd168d 100644 --- a/server/lualib/action/action.go +++ b/server/lualib/action/action.go @@ -220,7 +220,7 @@ func (aw *Worker) loadScript(luaAction *LuaScriptAction, scriptName string, scri // dynamic loader function. // // Note that this documentation assumes familiarity with the Lua programming language and its module system. -func (aw *Worker) registerDynamicLoader(L *lua.LState, httpRequest *http.Request) { +func (aw *Worker) registerDynamicLoader(L *lua.LState, httpRequest *http.Request) (httpClient *http.Client) { dynamicLoader := L.NewFunction(func(L *lua.LState) int { modName := L.CheckString(1) @@ -229,13 +229,15 @@ func (aw *Worker) registerDynamicLoader(L *lua.LState, httpRequest *http.Request return 0 } - lualib.RegisterCommonLuaLibraries(L, modName, registry) + httpClient = lualib.RegisterCommonLuaLibraries(L, modName, registry) aw.registerModule(L, httpRequest, modName, registry) return 0 }) L.SetGlobal("dynamic_loader", dynamicLoader) + + return httpClient } // registerModule registers a Lua module in the given Lua state. @@ -326,7 +328,9 @@ func (aw *Worker) handleRequest(httpRequest *http.Request) { defer L.Close() - aw.registerDynamicLoader(L, httpRequest) + httpClient := aw.registerDynamicLoader(L, httpRequest) + + defer util.CloseIdleHTTPConnections(httpClient) logs := new(lualib.CustomLogKeyValue) diff --git a/server/lualib/callback/callback.go b/server/lualib/callback/callback.go index f669e7d0..9c9c0d14 100644 --- a/server/lualib/callback/callback.go +++ b/server/lualib/callback/callback.go @@ -17,6 +17,7 @@ package callback import ( "context" + "net/http" "sync" "time" @@ -25,6 +26,7 @@ import ( "github.com/croessner/nauthilus/server/global" "github.com/croessner/nauthilus/server/log" "github.com/croessner/nauthilus/server/lualib" + "github.com/croessner/nauthilus/server/util" "github.com/gin-gonic/gin" "github.com/go-kit/log/level" "github.com/spf13/viper" @@ -157,7 +159,7 @@ func setupLogging(L *lua.LState) *lua.LTable { // Note: The implementation of the dynamic loader function is not shown in this // documentation. Please refer to the source code for more details on the // implementation of the dynamic loader function. -func registerDynamicLoader(L *lua.LState, ctx *gin.Context) { +func registerDynamicLoader(L *lua.LState, ctx *gin.Context) (httpClient *http.Client) { dynamicLoader := L.NewFunction(func(L *lua.LState) int { modName := L.CheckString(1) @@ -166,13 +168,15 @@ func registerDynamicLoader(L *lua.LState, ctx *gin.Context) { return 0 } - lualib.RegisterCommonLuaLibraries(L, modName, registry) + httpClient = lualib.RegisterCommonLuaLibraries(L, modName, registry) registerModule(L, ctx, modName, registry) return 0 }) L.SetGlobal("dynamic_loader", dynamicLoader) + + return httpClient } // registerModule registers a Lua module in the provided Lua state (L). @@ -222,7 +226,10 @@ func RunCallbackLuaRequest(ctx *gin.Context) (err error) { defer L.Close() - registerDynamicLoader(L, ctx) + httpClient := registerDynamicLoader(L, ctx) + + defer util.CloseIdleHTTPConnections(httpClient) + L.SetContext(luaCtx) logTable := setupLogging(L) diff --git a/server/lualib/feature/feature.go b/server/lualib/feature/feature.go index 0447b0e3..26f14210 100644 --- a/server/lualib/feature/feature.go +++ b/server/lualib/feature/feature.go @@ -19,6 +19,7 @@ import ( "context" stderrors "errors" "fmt" + "net/http" "sync" "time" @@ -166,7 +167,7 @@ type Request struct { // - ctx *gin.Context: the gin Context containing the request data // // Returns: none -func (r *Request) registerDynamicLoader(L *lua.LState, ctx *gin.Context) { +func (r *Request) registerDynamicLoader(L *lua.LState, ctx *gin.Context) (httpClient *http.Client) { dynamicLoader := L.NewFunction(func(L *lua.LState) int { modName := L.CheckString(1) @@ -175,13 +176,15 @@ func (r *Request) registerDynamicLoader(L *lua.LState, ctx *gin.Context) { return 0 } - lualib.RegisterCommonLuaLibraries(L, modName, registry) + httpClient = lualib.RegisterCommonLuaLibraries(L, modName, registry) r.registerModule(L, ctx, modName, registry) return 0 }) L.SetGlobal("dynamic_loader", dynamicLoader) + + return httpClient } // registerModule registers a module in the LuaState based on the provided module name. @@ -233,7 +236,10 @@ func (r *Request) CallFeatureLua(ctx *gin.Context) (triggered bool, abortFeature defer L.Close() - r.registerDynamicLoader(L, ctx) + httpClient := r.registerDynamicLoader(L, ctx) + + defer util.CloseIdleHTTPConnections(httpClient) + r.setGlobals(L) request := r.setRequest(L) diff --git a/server/lualib/filter/filter.go b/server/lualib/filter/filter.go index f11924a0..edcd998b 100644 --- a/server/lualib/filter/filter.go +++ b/server/lualib/filter/filter.go @@ -18,6 +18,7 @@ package filter import ( "context" stderrors "errors" + "net/http" "sync" "time" @@ -47,7 +48,7 @@ import ( // Then, it calls registerModule to register module-specific libraries based on the module name. // After registering the libraries, it sets the global variable "dynamic_loader" in the Lua state to the created function. // The function does not return any value. -func registerDynamicLoader(L *lua.LState, ctx *gin.Context, r *Request, backendResult **lualib.LuaBackendResult, removeAttributes *[]string) { +func registerDynamicLoader(L *lua.LState, ctx *gin.Context, r *Request, backendResult **lualib.LuaBackendResult, removeAttributes *[]string) (httpClient *http.Client) { dynamicLoader := L.NewFunction(func(L *lua.LState) int { modName := L.CheckString(1) @@ -56,13 +57,15 @@ func registerDynamicLoader(L *lua.LState, ctx *gin.Context, r *Request, backendR return 0 } - lualib.RegisterCommonLuaLibraries(L, modName, registry) + httpClient = lualib.RegisterCommonLuaLibraries(L, modName, registry) registerModule(L, ctx, r, modName, registry, backendResult, removeAttributes) return 0 }) L.SetGlobal("dynamic_loader", dynamicLoader) + + return httpClient } // registerModule registers a Lua module based on the given modName. It loads and preloads the respective Lua functions @@ -614,7 +617,10 @@ func (r *Request) CallFilterLua(ctx *gin.Context) (action bool, backendResult *l defer L.Close() - registerDynamicLoader(L, ctx, r, &backendResult, &removeAttributes) + httpClient := registerDynamicLoader(L, ctx, r, &backendResult, &removeAttributes) + + defer util.CloseIdleHTTPConnections(httpClient) + lualib.RegisterBackendResultType(L, global.LuaBackendResultAttributes) setGlobals(r, L) diff --git a/server/lualib/loader.go b/server/lualib/loader.go index 79de7d3a..1b83617a 100644 --- a/server/lualib/loader.go +++ b/server/lualib/loader.go @@ -74,7 +74,7 @@ import ( // Please refer to the individual module documentations for more details on each Preload function. // Please also note that the declaration codes for the constants used in the switch cases are not shown here. // Refer to the module documentations for the declaration codes of the constants. -func RegisterCommonLuaLibraries(L *lua.LState, modName string, registry map[string]bool) { +func RegisterCommonLuaLibraries(L *lua.LState, modName string, registry map[string]bool) (httpClient *stdhttp.Client) { switch modName { case global.LuaModGLLPlugin: plugin.Preload(L) @@ -147,7 +147,7 @@ func RegisterCommonLuaLibraries(L *lua.LState, modName string, registry map[stri case global.LuaModGLuaCrypto: gluacrypto.Preload(L) case global.LuaModGLuaHTTP: - httpClient := &stdhttp.Client{ + httpClient = &stdhttp.Client{ Timeout: 60 * stdtime.Second, Transport: &stdhttp.Transport{ TLSClientConfig: &tls.Config{ @@ -163,7 +163,8 @@ func RegisterCommonLuaLibraries(L *lua.LState, modName string, registry map[stri case global.LuaModRedis: L.PreloadModule(modName, redislib.LoaderModRedis) case global.LuaModMail: - mailModule := NewMailModule(&smtp.EmailClient{}) + smtpClient := &smtp.EmailClient{} + mailModule := NewMailModule(smtpClient) L.PreloadModule(modName, mailModule.Loader) case global.LuaModMisc: @@ -173,4 +174,6 @@ func RegisterCommonLuaLibraries(L *lua.LState, modName string, registry map[stri } registry[modName] = true + + return } diff --git a/server/util/util.go b/server/util/util.go index 0de6caab..f5e32dc2 100644 --- a/server/util/util.go +++ b/server/util/util.go @@ -25,6 +25,7 @@ import ( "fmt" "hash" "net" + "net/http" "regexp" "runtime" "strings" @@ -516,3 +517,13 @@ func NewDNSResolver() (resolver *net.Resolver) { return } + +// CloseIdleHTTPConnections closes any idle connections used by the provided HTTP client. +// If the client is nil or the transport is not of type *http.Transport, it does nothing. +func CloseIdleHTTPConnections(httpClient *http.Client) { + if httpClient != nil { + if transport, ok := httpClient.Transport.(*http.Transport); ok { + transport.CloseIdleConnections() + } + } +}