Skip to content

Commit

Permalink
Add function to remove attributes from backend result
Browse files Browse the repository at this point in the history
A new function 'removeFromBackendResult' is added to the 'filter.go' file which allows for the removal of attributes from the backend result set. It is implemented in the Lua library and integrated into the 'CallFilterLua' function. Modifications were also done to accommodate this change in 'server/core/auth.go', and the relevant constant was added in 'global/const.go'. Minor adjustments in the authentication operation mode and local cache TTL configuration were also done.

Signed-off-by: Christian Roessner <c@roessner.co>
  • Loading branch information
Christian Roessner committed Jul 12, 2024
1 parent 05f7e5f commit 698feeb
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 15 deletions.
6 changes: 3 additions & 3 deletions server/config/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,9 @@ func (c *Config) setConfigMaxActionWorkers() {
// Otherwise, the field will be assigned the value of 5 seconds.
// Please note that this method does not return errors.
func (c *Config) setLocalCacheTTL() {
if val := viper.GetDuration("local_cache_auth_ttl"); val > 5*time.Second {
if val < time.Hour {
c.LocalCacheAuthTTL = val
if val := viper.GetDuration("local_cache_auth_ttl"); val*time.Second > 5*time.Second {
if val*time.Second < time.Hour {
c.LocalCacheAuthTTL = val * time.Second
} else {
c.LocalCacheAuthTTL = time.Hour
}
Expand Down
11 changes: 10 additions & 1 deletion server/core/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -1683,7 +1683,7 @@ func (a *Authentication) filterLua(passDBResult *PassDBResult, ctx *gin.Context)
},
}

filterResult, luaBackendResult, err := filterRequest.CallFilterLua(ctx)
filterResult, luaBackendResult, removeAttributes, err := filterRequest.CallFilterLua(ctx)
if err != nil {
if !errors.Is(err, errors2.ErrNoFiltersDefined) {
level.Error(logging.DefaultErrLogger).Log(global.LogKeyGUID, a.GUID, global.LogKeyError, err.Error())
Expand All @@ -1703,6 +1703,10 @@ func (a *Authentication) filterLua(passDBResult *PassDBResult, ctx *gin.Context)
return global.AuthResultFail
}

for _, attributeName := range removeAttributes {
delete(a.Attributes, attributeName)
}

if luaBackendResult != nil {
// XXX: We currently only support changing attributes from the Authentication object.
if (*luaBackendResult).Attributes != nil {
Expand Down Expand Up @@ -1849,6 +1853,11 @@ func (a *Authentication) getUserAccountFromRedis() (accountName string, err erro
func (a *Authentication) setOperationMode(ctx *gin.Context) {
guid := ctx.GetString(global.CtxGUIDKey)

// We reset flags, because they might have been cached in the in-memory cahce.
a.NoAuth = false
a.ListAccounts = false
a.MonitoringFlags = []global.Monitoring{}

switch ctx.Query("mode") {
case "no-auth":
util.DebugModule(global.DbgAuth, global.LogKeyGUID, guid, global.LogKeyMsg, "mode=no-auth")
Expand Down
33 changes: 33 additions & 0 deletions server/global/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,36 @@ const (
LuaCommandAddMFAValue
)

// LuaFnCtxSet represents the function name for "context_set" in Lua
// LuaFnCtxGet represents the function name for "context_get" in Lua
// LuaFnCtxDelete represents the function name for "context_delete" in Lua
// LuaFnAddCustomLog represents the function name for "custom_log_add" in Lua
// LuaFnBackendVerifyPassword represents the function name for "nauthilus_backend_verify_password" in Lua
// LuaFnBackendListAccounts represents the function name for "nauthilus_backend_list_accounts" in Lua
// LuaFnBackendAddTOTPSecret represents the function name for "nauthilus_backend_add_totp" in Lua
// LuaModUtil represents the module name for "nauthilus_util" in Lua
// LuaFnCallFeature represents the function name for "nauthilus_call_feature" in Lua
// LuaFnCallAction represents the function name for "nauthilus_call_action" in Lua
// LuaFnCallFilter represents the function name for "nauthilus_call_filter" in Lua
// LuaFnRunCallback represents the constant string "nauthilus_run_callback".
// LuaFnGetBackendServers represents the Lua function name "get_backend_servers" that retrieves the backend servers.
// LuaFnSelectBackendServer represents the constant used as the key for the Lua function "select_backend_server".
// LuaFnSetStatusMessage represents the Lua function name for setting the status message of a Lua request.
// LuaFnGetAllHTTPRequestHeaders represents the function name for "get_all_http_request_headers" in Lua
// LuaFnGetHTTPRequestHeader represents the function name for "get_http_request_header" in Lua
// LuaFnGetHTTPRequestBody represents the function name for "get_http_request_body" in Lua
// LuaFnRedisGet represents the function name for "redis_get_str" in Lua
// LuaFnRedisSet represents the function name for "redis_set_str" in Lua
// LuaFnRedisIncr represents a constant string identifier for the Lua function redis_incr.
// LuaFnRedisDel represents the function name for "redis_det" in Lua
// LuaFnRedisExpire represents the function name for "redis_expire" in Lua
// LuaFnRedisHGet represents the function name for "redis_hget" in Lua.
// LuaFnRedisHSet represents the function name for "redis_hset" in Lua
// LuaFnRedisHDel represents the function name for "redis_hdel" in Lua
// LuaFnRedisHLen represents the function name for "redis_hlen" in Lua.
// LuaFnRedisHGetAll represents the function name for "redis_hgetall" in Lua
// LuaFNRedisHIncrBy represents the function name for "redis_hincrby" in Lua.
// LuaFN
const (
// LuaFnCtxSet represents the function name for "context_set" in Lua
LuaFnCtxSet = "context_set"
Expand Down Expand Up @@ -1009,6 +1039,9 @@ const (
// LuaFnApplyBackendResult applies changes to the backend result from a former authentication process.
LuaFnApplyBackendResult = "apply_backend_result"

// LuaFnRemoveFromBackendResult represents the function to remove an attribute from the backend result set.
LuaFnRemoveFromBackendResult = "remove_from_backend_result"

// LuaFnCheckBackendConnection represents the Lua function name for checking the backend connection.
LuaFnCheckBackendConnection = "check_backend_connection"

Expand Down
11 changes: 4 additions & 7 deletions server/lua-plugins.d/filters/monitoring.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,8 @@ local N = "monitoring"

---@type table wanted_protocols
local wanted_protocols = {
[1] = "imap",
[2] = "imapa",
[3] = "pop3",
[4] = "pop3s",
[5] = "lmtp",
[6] = "lmtps",
[7] = "sieve", -- Not sure about this
"imap", "imapa", "pop3", "pop3s", "lmtp", "lmtps",
"sieve", -- Not sure about this
}

---@param request table
Expand Down Expand Up @@ -39,6 +34,8 @@ function nauthilus_call_filter(request)
end

if skip_and_accept_filter then
nauthilus.remove_from_backend_result({ "Proxy-Host" })

return nauthilus.FILTER_ACCEPT, nauthilus.FILTER_RESULT_OK
end

Expand Down
41 changes: 37 additions & 4 deletions server/lualib/filter/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,38 @@ func applyBackendResult(backendResult **lualib.LuaBackendResult) lua.LGFunction
}
}

// removeFromBackendResult is a function that creates and returns a Lua LGFunction.
// The LGFunction takes a Lua state as argument and modifies a slice (attributes)
// by appending values from a Lua table passed as argument to the LGFunction.
// The function returns 0, indicating no values are returned to Lua.
// If the attributes slice is nil, the function returns 0 immediately.
// The function extracts a Lua table from the Lua stack and iterates over its
// values. For each value, it appends its string representation to the attributes slice.
// Finally, the function returns 0 to Lua.
//
// Params:
//
// attributes *[]string : Pointer to a slice of strings to store the extracted attributes
//
// Returns:
//
// the LGFunction that takes a Lua state as argument and modifies the attributes slice
func removeFromBackendResult(attributes *[]string) lua.LGFunction {
return func(L *lua.LState) int {
if attributes == nil {
return 0
}

attributeTable := L.ToTable(1)

attributeTable.ForEach(func(_, value lua.LValue) {
*attributes = append(*attributes, value.String())
})

return 0
}
}

// setGlobals is a function that initializes a set of global variables in the provided lua.LState.
// The globals are set using the provided context (r) and lua table (globals).
// The following lua variables are set:
Expand All @@ -323,7 +355,7 @@ func applyBackendResult(backendResult **lualib.LuaBackendResult) lua.LGFunction
// Returns:
//
// A new request table
func setGlobals(ctx *gin.Context, r *Request, L *lua.LState, backendResult **lualib.LuaBackendResult) *lua.LTable {
func setGlobals(ctx *gin.Context, r *Request, L *lua.LState, backendResult **lualib.LuaBackendResult, removeAttributes *[]string) *lua.LTable {
r.Logs = new(lualib.CustomLogKeyValue)

globals := L.NewTable()
Expand All @@ -336,6 +368,7 @@ func setGlobals(ctx *gin.Context, r *Request, L *lua.LState, backendResult **lua
globals.RawSetString(global.LuaFnAddCustomLog, L.NewFunction(lualib.AddCustomLog(r.Logs)))
globals.RawSetString(global.LuaFnSetStatusMessage, L.NewFunction(lualib.SetStatusMessage(&r.StatusMessage)))
globals.RawSetString(global.LuaFnApplyBackendResult, L.NewFunction(applyBackendResult(backendResult)))
globals.RawSetString(global.LuaFnRemoveFromBackendResult, L.NewFunction(removeFromBackendResult(removeAttributes)))
globals.RawSetString(global.LuaFnGetAllHTTPRequestHeaders, L.NewFunction(lualib.GetAllHTTPRequestHeaders(ctx.Request)))
globals.RawSetString(global.LuaFnGetHTTPRequestHeader, L.NewFunction(lualib.GetHTTPRequestHeader(ctx.Request)))

Expand Down Expand Up @@ -470,9 +503,9 @@ func logResult(r *Request, script *LuaFilter, action bool, ret int) {
// executes successfully or all scripts have been attempted.
// If the context has been cancelled, the function returns without executing any more scripts.
// If a script returns an error, it is skipped and the next script is tried.
func (r *Request) CallFilterLua(ctx *gin.Context) (action bool, backendResult *lualib.LuaBackendResult, err error) {
func (r *Request) CallFilterLua(ctx *gin.Context) (action bool, backendResult *lualib.LuaBackendResult, removeAttributes []string, err error) {
if LuaFilters == nil || len(LuaFilters.LuaScripts) == 0 {
return false, nil, errors2.ErrNoFiltersDefined
return false, nil, nil, errors2.ErrNoFiltersDefined
}

LuaFilters.Mu.RLock()
Expand All @@ -484,7 +517,7 @@ func (r *Request) CallFilterLua(ctx *gin.Context) (action bool, backendResult *l
defer LuaPool.Put(L)
defer L.SetGlobal(global.LuaDefaultTable, lua.LNil)

globals := setGlobals(ctx, r, L, &backendResult)
globals := setGlobals(ctx, r, L, &backendResult, &removeAttributes)
request := setRequest(r, L)

for _, script := range LuaFilters.LuaScripts {
Expand Down

0 comments on commit 698feeb

Please sign in to comment.