Skip to content

Commit

Permalink
Fix: Return http.Client from registerDynamicLoader functions
Browse files Browse the repository at this point in the history
Modified registerDynamicLoader functions to return an http.Client and ensured idle connections are properly closed. This enhances resource management and prevents potential issues with hanging HTTP connections.

Signed-off-by: Christian Roessner <c@roessner.co>
  • Loading branch information
Christian Roessner committed Sep 26, 2024
1 parent f490fda commit e083fe8
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 18 deletions.
12 changes: 9 additions & 3 deletions server/backend/lua.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package backend
import (
"context"
"fmt"
"net/http"
"time"

"github.com/croessner/nauthilus/server/config"
Expand Down Expand Up @@ -128,7 +129,7 @@ func LuaMainWorker(ctx context.Context) {
// - luaRequest: The *LuaRequest object containing the request parameters.
//
// Returns: None.
func registerDynamicLoader(L *lua.LState, ctx context.Context, luaRequest *LuaRequest) {
func registerDynamicLoader(L *lua.LState, ctx context.Context, luaRequest *LuaRequest) (httpClient *http.Client) {
dynamicLoader := L.NewFunction(func(L *lua.LState) int {
modName := L.CheckString(1)

Expand All @@ -137,13 +138,15 @@ func registerDynamicLoader(L *lua.LState, ctx context.Context, luaRequest *LuaRe
return 0
}

lualib.RegisterCommonLuaLibraries(L, modName, registry)
httpClient = lualib.RegisterCommonLuaLibraries(L, modName, registry)
registerModule(L, ctx, luaRequest, modName, registry)

return 0
})

L.SetGlobal("dynamic_loader", dynamicLoader)

return httpClient
}

// registerModule registers a module in the Lua state based on the given modName.
Expand Down Expand Up @@ -220,7 +223,10 @@ func handleLuaRequest(ctx context.Context, luaRequest *LuaRequest, compiledScrip
global.LuaBackendResultAttributes,
)

registerDynamicLoader(L, ctx, luaRequest)
httpClient := registerDynamicLoader(L, ctx, luaRequest)

defer util.CloseIdleHTTPConnections(httpClient)

setupGlobals(luaRequest, L, logs)

request := L.NewTable()
Expand Down
12 changes: 12 additions & 0 deletions server/core/hydra.go
Original file line number Diff line number Diff line change
Expand Up @@ -1131,6 +1131,8 @@ func loginGETHandler(ctx *gin.Context) {

apiConfig.initialize()

defer util.CloseIdleHTTPConnections(apiConfig.httpClient)

apiConfig.challenge = loginChallenge
apiConfig.csrfToken = ctx.GetString(global.CtxCSRFTokenKey)

Expand Down Expand Up @@ -1663,6 +1665,8 @@ func loginPOSTHandler(ctx *gin.Context) {

apiConfig.initialize()

defer util.CloseIdleHTTPConnections(apiConfig.httpClient)

apiConfig.challenge = loginChallenge

auth, err := initializeAuthLogin(ctx)
Expand Down Expand Up @@ -2214,6 +2218,8 @@ func consentGETHandler(ctx *gin.Context) {

apiConfig.initialize()

defer util.CloseIdleHTTPConnections(apiConfig.httpClient)

apiConfig.challenge = consentChallenge
apiConfig.csrfToken = ctx.GetString(global.CtxCSRFTokenKey)

Expand Down Expand Up @@ -2456,6 +2462,8 @@ func consentPOSTHandler(ctx *gin.Context) {

apiConfig.initialize()

defer util.CloseIdleHTTPConnections(apiConfig.httpClient)

apiConfig.challenge = consentChallenge

apiConfig.consentRequest, httpResponse, err = apiConfig.apiClient.OAuth2API.GetOAuth2ConsentRequest(ctx).ConsentChallenge(
Expand Down Expand Up @@ -2564,6 +2572,8 @@ func logoutGETHandler(ctx *gin.Context) {

apiConfig.initialize()

defer util.CloseIdleHTTPConnections(apiConfig.httpClient)

apiConfig.challenge = logoutChallenge
apiConfig.csrfToken = ctx.GetString(global.CtxCSRFTokenKey)

Expand Down Expand Up @@ -2712,6 +2722,8 @@ func logoutPOSTHandler(ctx *gin.Context) {

apiConfig.initialize()

defer util.CloseIdleHTTPConnections(apiConfig.httpClient)

apiConfig.challenge = logoutChallenge

apiConfig.logoutRequest, httpResponse, err = apiConfig.apiClient.OAuth2API.GetOAuth2LogoutRequest(ctx).LogoutChallenge(
Expand Down
10 changes: 7 additions & 3 deletions server/lualib/action/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func (aw *Worker) loadScript(luaAction *LuaScriptAction, scriptName string, scri
// dynamic loader function.
//
// Note that this documentation assumes familiarity with the Lua programming language and its module system.
func (aw *Worker) registerDynamicLoader(L *lua.LState, httpRequest *http.Request) {
func (aw *Worker) registerDynamicLoader(L *lua.LState, httpRequest *http.Request) (httpClient *http.Client) {
dynamicLoader := L.NewFunction(func(L *lua.LState) int {
modName := L.CheckString(1)

Expand All @@ -229,13 +229,15 @@ func (aw *Worker) registerDynamicLoader(L *lua.LState, httpRequest *http.Request
return 0
}

lualib.RegisterCommonLuaLibraries(L, modName, registry)
httpClient = lualib.RegisterCommonLuaLibraries(L, modName, registry)
aw.registerModule(L, httpRequest, modName, registry)

return 0
})

L.SetGlobal("dynamic_loader", dynamicLoader)

return httpClient
}

// registerModule registers a Lua module in the given Lua state.
Expand Down Expand Up @@ -326,7 +328,9 @@ func (aw *Worker) handleRequest(httpRequest *http.Request) {

defer L.Close()

aw.registerDynamicLoader(L, httpRequest)
httpClient := aw.registerDynamicLoader(L, httpRequest)

defer util.CloseIdleHTTPConnections(httpClient)

logs := new(lualib.CustomLogKeyValue)

Expand Down
13 changes: 10 additions & 3 deletions server/lualib/callback/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package callback

import (
"context"
"net/http"
"sync"
"time"

Expand All @@ -25,6 +26,7 @@ import (
"github.com/croessner/nauthilus/server/global"
"github.com/croessner/nauthilus/server/log"
"github.com/croessner/nauthilus/server/lualib"
"github.com/croessner/nauthilus/server/util"
"github.com/gin-gonic/gin"
"github.com/go-kit/log/level"
"github.com/spf13/viper"
Expand Down Expand Up @@ -157,7 +159,7 @@ func setupLogging(L *lua.LState) *lua.LTable {
// Note: The implementation of the dynamic loader function is not shown in this
// documentation. Please refer to the source code for more details on the
// implementation of the dynamic loader function.
func registerDynamicLoader(L *lua.LState, ctx *gin.Context) {
func registerDynamicLoader(L *lua.LState, ctx *gin.Context) (httpClient *http.Client) {
dynamicLoader := L.NewFunction(func(L *lua.LState) int {
modName := L.CheckString(1)

Expand All @@ -166,13 +168,15 @@ func registerDynamicLoader(L *lua.LState, ctx *gin.Context) {
return 0
}

lualib.RegisterCommonLuaLibraries(L, modName, registry)
httpClient = lualib.RegisterCommonLuaLibraries(L, modName, registry)
registerModule(L, ctx, modName, registry)

return 0
})

L.SetGlobal("dynamic_loader", dynamicLoader)

return httpClient
}

// registerModule registers a Lua module in the provided Lua state (L).
Expand Down Expand Up @@ -222,7 +226,10 @@ func RunCallbackLuaRequest(ctx *gin.Context) (err error) {

defer L.Close()

registerDynamicLoader(L, ctx)
httpClient := registerDynamicLoader(L, ctx)

defer util.CloseIdleHTTPConnections(httpClient)

L.SetContext(luaCtx)

logTable := setupLogging(L)
Expand Down
12 changes: 9 additions & 3 deletions server/lualib/feature/feature.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"context"
stderrors "errors"
"fmt"
"net/http"
"sync"
"time"

Expand Down Expand Up @@ -166,7 +167,7 @@ type Request struct {
// - ctx *gin.Context: the gin Context containing the request data
//
// Returns: none
func (r *Request) registerDynamicLoader(L *lua.LState, ctx *gin.Context) {
func (r *Request) registerDynamicLoader(L *lua.LState, ctx *gin.Context) (httpClient *http.Client) {
dynamicLoader := L.NewFunction(func(L *lua.LState) int {
modName := L.CheckString(1)

Expand All @@ -175,13 +176,15 @@ func (r *Request) registerDynamicLoader(L *lua.LState, ctx *gin.Context) {
return 0
}

lualib.RegisterCommonLuaLibraries(L, modName, registry)
httpClient = lualib.RegisterCommonLuaLibraries(L, modName, registry)
r.registerModule(L, ctx, modName, registry)

return 0
})

L.SetGlobal("dynamic_loader", dynamicLoader)

return httpClient
}

// registerModule registers a module in the LuaState based on the provided module name.
Expand Down Expand Up @@ -233,7 +236,10 @@ func (r *Request) CallFeatureLua(ctx *gin.Context) (triggered bool, abortFeature

defer L.Close()

r.registerDynamicLoader(L, ctx)
httpClient := r.registerDynamicLoader(L, ctx)

defer util.CloseIdleHTTPConnections(httpClient)

r.setGlobals(L)

request := r.setRequest(L)
Expand Down
12 changes: 9 additions & 3 deletions server/lualib/filter/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package filter
import (
"context"
stderrors "errors"
"net/http"
"sync"
"time"

Expand Down Expand Up @@ -47,7 +48,7 @@ import (
// Then, it calls registerModule to register module-specific libraries based on the module name.
// After registering the libraries, it sets the global variable "dynamic_loader" in the Lua state to the created function.
// The function does not return any value.
func registerDynamicLoader(L *lua.LState, ctx *gin.Context, r *Request, backendResult **lualib.LuaBackendResult, removeAttributes *[]string) {
func registerDynamicLoader(L *lua.LState, ctx *gin.Context, r *Request, backendResult **lualib.LuaBackendResult, removeAttributes *[]string) (httpClient *http.Client) {
dynamicLoader := L.NewFunction(func(L *lua.LState) int {
modName := L.CheckString(1)

Expand All @@ -56,13 +57,15 @@ func registerDynamicLoader(L *lua.LState, ctx *gin.Context, r *Request, backendR
return 0
}

lualib.RegisterCommonLuaLibraries(L, modName, registry)
httpClient = lualib.RegisterCommonLuaLibraries(L, modName, registry)
registerModule(L, ctx, r, modName, registry, backendResult, removeAttributes)

return 0
})

L.SetGlobal("dynamic_loader", dynamicLoader)

return httpClient
}

// registerModule registers a Lua module based on the given modName. It loads and preloads the respective Lua functions
Expand Down Expand Up @@ -614,7 +617,10 @@ func (r *Request) CallFilterLua(ctx *gin.Context) (action bool, backendResult *l

defer L.Close()

registerDynamicLoader(L, ctx, r, &backendResult, &removeAttributes)
httpClient := registerDynamicLoader(L, ctx, r, &backendResult, &removeAttributes)

defer util.CloseIdleHTTPConnections(httpClient)

lualib.RegisterBackendResultType(L, global.LuaBackendResultAttributes)
setGlobals(r, L)

Expand Down
9 changes: 6 additions & 3 deletions server/lualib/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ import (
// Please refer to the individual module documentations for more details on each Preload function.
// Please also note that the declaration codes for the constants used in the switch cases are not shown here.
// Refer to the module documentations for the declaration codes of the constants.
func RegisterCommonLuaLibraries(L *lua.LState, modName string, registry map[string]bool) {
func RegisterCommonLuaLibraries(L *lua.LState, modName string, registry map[string]bool) (httpClient *stdhttp.Client) {
switch modName {
case global.LuaModGLLPlugin:
plugin.Preload(L)
Expand Down Expand Up @@ -147,7 +147,7 @@ func RegisterCommonLuaLibraries(L *lua.LState, modName string, registry map[stri
case global.LuaModGLuaCrypto:
gluacrypto.Preload(L)
case global.LuaModGLuaHTTP:
httpClient := &stdhttp.Client{
httpClient = &stdhttp.Client{
Timeout: 60 * stdtime.Second,
Transport: &stdhttp.Transport{
TLSClientConfig: &tls.Config{
Expand All @@ -163,7 +163,8 @@ func RegisterCommonLuaLibraries(L *lua.LState, modName string, registry map[stri
case global.LuaModRedis:
L.PreloadModule(modName, redislib.LoaderModRedis)
case global.LuaModMail:
mailModule := NewMailModule(&smtp.EmailClient{})
smtpClient := &smtp.EmailClient{}
mailModule := NewMailModule(smtpClient)

L.PreloadModule(modName, mailModule.Loader)
case global.LuaModMisc:
Expand All @@ -173,4 +174,6 @@ func RegisterCommonLuaLibraries(L *lua.LState, modName string, registry map[stri
}

registry[modName] = true

return
}
11 changes: 11 additions & 0 deletions server/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"fmt"
"hash"
"net"
"net/http"
"regexp"
"runtime"
"strings"
Expand Down Expand Up @@ -516,3 +517,13 @@ func NewDNSResolver() (resolver *net.Resolver) {

return
}

// CloseIdleHTTPConnections closes any idle connections used by the provided HTTP client.
// If the client is nil or the transport is not of type *http.Transport, it does nothing.
func CloseIdleHTTPConnections(httpClient *http.Client) {
if httpClient != nil {
if transport, ok := httpClient.Transport.(*http.Transport); ok {
transport.CloseIdleConnections()
}
}
}

0 comments on commit e083fe8

Please sign in to comment.