From ec66910b38c1fb4627ffa2b692939a99077d1290 Mon Sep 17 00:00:00 2001 From: Christian Roessner Date: Wed, 23 Oct 2024 17:03:25 +0200 Subject: [PATCH] Feat: Refactor Lua-Go value conversion and optimize Redis script handling Simplified Lua-Go value conversion by merging functions and reducing redundant code. Implemented a unified Redis script execution mechanism to improve efficiency and maintainability in `haveibeenpwnd.lua` and register Redis Lua functions in their own module. Signed-off-by: Christian Roessner --- server/backend/ldap.go | 2 +- server/backend/lua.go | 4 +- server/global/const.go | 3 + .../lua-plugins.d/actions/haveibeenpwnd.lua | 27 +-- server/lualib/backendresult.go | 5 +- server/lualib/context.go | 36 +--- server/lualib/convert/convert.go | 173 +++++------------- server/lualib/redislib/register.go | 1 + server/lualib/redislib/scripts.go | 63 +++++++ 9 files changed, 136 insertions(+), 178 deletions(-) create mode 100644 server/lualib/redislib/scripts.go diff --git a/server/backend/ldap.go b/server/backend/ldap.go index 5e28458f..2aab1afd 100644 --- a/server/backend/ldap.go +++ b/server/backend/ldap.go @@ -1667,7 +1667,7 @@ func processReply(L *lua.LState, ldapReplyChan chan *LDAPReply) int { convertedMap[key] = list } - resultTable := convert.MapToLuaTable(L, convertedMap) + resultTable := convert.GoToLuaValue(L, convertedMap) if resultTable == nil { L.Push(lua.LString("no result")) diff --git a/server/backend/lua.go b/server/backend/lua.go index f9ccc189..2f7e5471 100644 --- a/server/backend/lua.go +++ b/server/backend/lua.go @@ -351,7 +351,7 @@ func executeAndHandleError(compiledScript *lua.FunctionProto, luaCommand string, // with the custom logs to the LuaReplyChan channel. // // If the Lua request function is LuaCommandListAccounts, the function expects the return value -// to be a Lua table. The function converts the table to a map using the LuaTableToMap function, +// to be a Lua table. The function converts the table to a map using the LuaValueToGo function, // assigns it to the Attributes field of a new LuaBackendResult, and sends it to the LuaReplyChan channel. // // For all other Lua request functions, the function sends an empty LuaBackendResult with the custom logs @@ -390,7 +390,7 @@ func handleReturnTypes(L *lua.LState, nret int, luaRequest *LuaRequest, logs *lu case global.LuaCommandListAccounts: luaRequest.LuaReplyChan <- &lualib.LuaBackendResult{ - Attributes: convert.LuaTableToMap(L.ToTable(-1)), + Attributes: convert.LuaValueToGo(L.ToTable(-1)).(map[any]any), Logs: logs, } diff --git a/server/global/const.go b/server/global/const.go index 355a19e3..0a903f93 100644 --- a/server/global/const.go +++ b/server/global/const.go @@ -1199,6 +1199,9 @@ const ( // LuaFnRedisSCard represents a Lua function that returns the number of elements in a Redis set. LuaFnRedisSCard = "redis_scard" + // LuaFnRedisRunScript is the constant used to denote the operation for running a Lua script in Redis. + LuaFnRedisRunScript = "redis_run_script" + // LuaFnApplyBackendResult applies changes to the backend result from a former authentication process. LuaFnApplyBackendResult = "apply_backend_result" diff --git a/server/lua-plugins.d/actions/haveibeenpwnd.lua b/server/lua-plugins.d/actions/haveibeenpwnd.lua index e100e569..8817f918 100644 --- a/server/lua-plugins.d/actions/haveibeenpwnd.lua +++ b/server/lua-plugins.d/actions/haveibeenpwnd.lua @@ -45,9 +45,6 @@ function nauthilus_call_action(request) dynamic_loader("nauthilus_mail") local nauthilus_mail = require("nauthilus_mail") - dynamic_loader("nauthilus_misc") - local nauthilus_misc = require("nauthilus_misc") - dynamic_loader("nauthilus_context") local nauthilus_context = require("nauthilus_context") @@ -69,8 +66,6 @@ function nauthilus_call_action(request) dynamic_loader("nauthilus_gll_template") local template = require("template") - nauthilus_misc.wait_random(500, 3000) - local redis_key = "ntc:HAVEIBEENPWND:" .. crypto.md5(request.account) local hash = string.lower(crypto.sha1(request.password)) @@ -126,10 +121,23 @@ function nauthilus_call_action(request) nauthilus_context.context_set(N .. "_hash_info", hash:sub(1, 5) .. cmp_hash[2]) nauthilus_builtin.custom_log_add(N .. "_action", "leaked") - local already_sent_mail, err_redis_hget2 = nauthilus_redis.redis_hget(redis_key, "send_mail") - nauthilus_util.if_error_raise(err_redis_hget2) + local script = [[ + local redis_key = KEYS[1] + local send_mail = redis.call('HGET', redis_key, 'send_mail') - if already_sent_mail == "" then + if send_mail == false then + redis.call('HSET', redis_key, 'send_mail', '1') + + return {'send_email', redis_key} + else + return {'email_already_sent'} + end + ]] + + local script_result, err_run_script = nauthilus_redis.redis_run_script(script, { redis_key }) + nauthilus_util.if_error_raise(err_run_script) + + if script_result[1] == "send_mail" then local smtp_use_lmtp = os.environ("SMTP_USE_LMTP") local smtp_server = os.environ("SMTP_SERVER") local smtp_port = os.environ("SMTP_PORT") @@ -167,9 +175,6 @@ function nauthilus_call_action(request) }) nauthilus_util.if_error_raise(err_smtp) - _, err_redis_hset = nauthilus_redis.redis_hset(redis_key, "send_mail", 1) - nauthilus_util.if_error_raise(err_redis_hset) - _, err_redis_expire = nauthilus_redis.redis_expire(redis_key, 86400) nauthilus_util.if_error_raise(err_redis_expire) diff --git a/server/lualib/backendresult.go b/server/lualib/backendresult.go index 143e7301..789134d1 100644 --- a/server/lualib/backendresult.go +++ b/server/lualib/backendresult.go @@ -249,13 +249,12 @@ func backendResultGetSetAttributes(L *lua.LState) int { backendResult := checkBackendResult(L) if L.GetTop() == 2 { - // XXX: We expect keys to be strings! - backendResult.Attributes = convert.LuaTableToMap(L.CheckTable(2)) + backendResult.Attributes = convert.LuaValueToGo(L.CheckTable(2)).(map[any]any) return 0 } - L.Push(convert.MapToLuaTable(L, backendResult.Attributes)) + L.Push(convert.GoToLuaValue(L, backendResult.Attributes)) return 1 } diff --git a/server/lualib/context.go b/server/lualib/context.go index 912eb83d..b32ba908 100644 --- a/server/lualib/context.go +++ b/server/lualib/context.go @@ -16,14 +16,11 @@ package lualib import ( - "fmt" "sync" "time" "github.com/croessner/nauthilus/server/global" - "github.com/croessner/nauthilus/server/log" "github.com/croessner/nauthilus/server/lualib/convert" - "github.com/go-kit/log/level" lua "github.com/yuin/gopher-lua" ) @@ -125,20 +122,9 @@ func (c *Context) Value(_ any) lua.LValue { func ContextSet(ctx *Context) lua.LGFunction { return func(L *lua.LState) int { key := L.CheckString(1) + value := L.CheckAny(2) - switch value := L.Get(2).(type) { - case lua.LString: - ctx.Set(key, string(value)) - case lua.LBool: - ctx.Set(key, bool(value)) - case lua.LNumber: - ctx.Set(key, float64(value)) - case *lua.LTable: - ctx.Set(key, convert.LuaTableToMap(value)) - default: - level.Warn(log.Logger).Log( - global.LogKeyWarning, fmt.Sprintf("Lua key='%v' value='%v' unsupported", key, value)) - } + ctx.Set(key, convert.LuaValueToGo(value)) return 0 } @@ -149,23 +135,9 @@ func ContextSet(ctx *Context) lua.LGFunction { func ContextGet(ctx *Context) lua.LGFunction { return func(L *lua.LState) int { key := L.CheckString(1) + value := ctx.Get(key) - switch value := ctx.Get(key).(type) { - case string: - L.Push(lua.LString(value)) - case bool: - L.Push(lua.LBool(value)) - case float64: - L.Push(lua.LNumber(value)) - case map[any]any: - L.Push(convert.MapToLuaTable(L, value)) - case nil: - L.Push(lua.LNil) - default: - level.Warn(log.Logger).Log( - global.LogKeyWarning, fmt.Sprintf("Lua key='%v' value='%v' unsupported", key, value)) - L.Push(lua.LNil) - } + L.Push(convert.GoToLuaValue(L, value)) return 1 } diff --git a/server/lualib/convert/convert.go b/server/lualib/convert/convert.go index fef8c57b..b4c832b4 100644 --- a/server/lualib/convert/convert.go +++ b/server/lualib/convert/convert.go @@ -113,155 +113,70 @@ func StringCmd(value *redis.StringCmd, valType string, L *lua.LState) error { return nil } -// GoToLuaValue converts a Go value to a corresponding Lua value. -// It accepts an argument 'value' of type 'any' and returns a value of type 'lua.LValue'. -// If the input is a string, it returns a Lua string value (lua.LString). -// If the input is a float64 or an int, it returns a Lua number value (lua.LNumber). -// If the input is a boolean, it returns a Lua boolean value (lua.LBool). -// For any other types, it converts the value to a string and returns a Lua string value (lua.LString). -// The function uses the fmt.Sprintf method to convert values of any type to a string. -// This function is useful for converting Go values to their equivalent Lua values. -// The function is not safe for concurrent use. +// GoToLuaValue converts a Go value to a corresponding Lua value suitable for Lua state operations. func GoToLuaValue(L *lua.LState, value any) lua.LValue { switch v := value.(type) { case string: return lua.LString(v) case float64: return lua.LNumber(v) - case int: + case int64: return lua.LNumber(v) case bool: return lua.LBool(v) - case map[any]any: - return MapToLuaTable(L, v) - default: - return lua.LString(fmt.Sprintf("%v", value)) - } -} - -// LuaTableToMap takes a lua.LTable as input and converts it into a map[any]any. -// The function iterates over each key-value pair in the table and converts the keys and values -// into their corresponding Go types. The converted key-value pairs are then added to a new map, which is -// returned as the result. -// If the input table is nil, the function returns nil. -func LuaTableToMap(table *lua.LTable) map[any]any { - if table == nil { - return nil - } - - result := make(map[any]any) + case []any: + tbl := L.NewTable() - table.ForEach(func(key lua.LValue, value lua.LValue) { - var ( - mapKey any - mapValue any - ) - - switch k := key.(type) { - case lua.LBool: - mapKey = bool(k) - case lua.LNumber: - mapKey = float64(k) - case lua.LString: - mapKey = k.String() - default: - return - } - - switch v := value.(type) { - case lua.LBool: - mapValue = bool(v) - case lua.LNumber: - mapValue = float64(v) - case *lua.LTable: - mapValue = LuaTableToMap(v) - default: - mapValue = v.String() + for _, item := range v { + tbl.Append(GoToLuaValue(L, item)) } - result[mapKey] = mapValue - }) - - return result -} + return tbl + case map[string]any: + tbl := L.NewTable() -// MapToLuaTable takes an *lua.LState and a map[any]any as input and converts it into a *lua.LTable. -// The function iterates over each key-value pair in the map and converts the keys and values -// into their corresponding lua.LValue types. The converted key-value pairs are then added to a new *lua.LTable, -// which is returned as the result. -// If the input map is nil, the function returns nil. -func MapToLuaTable(L *lua.LState, table map[any]any) *lua.LTable { - var ( - key lua.LValue - value lua.LValue - ) - - lTable := L.NewTable() - - if table == nil { - return nil - } - - for k, v := range table { - switch mapKey := k.(type) { - case bool: - key = lua.LBool(mapKey) - case float64: - key = lua.LNumber(mapKey) - case string: - key = lua.LString(mapKey) - default: - return nil + for k, item := range v { + tbl.RawSetString(k, GoToLuaValue(L, item)) } - switch mapValue := v.(type) { - case bool: - value = lua.LBool(mapValue) - case float64: - value = lua.LNumber(mapValue) - case string: - value = lua.LString(mapValue) - case []any: - value = SliceToLuaTable(L, mapValue) // convert []any to *lua.LTable - case map[any]any: - value = MapToLuaTable(L, mapValue) - default: - return nil + return tbl + case map[any]any: + tbl := L.NewTable() + + for k, item := range v { + tbl.RawSet(GoToLuaValue(L, k), GoToLuaValue(L, item)) } - L.RawSet(lTable, key, value) + return tbl + case nil: + return lua.LNil + default: + return lua.LString(fmt.Sprintf("%v", value)) } - - return lTable } -// SliceToLuaTable converts a slice into a Lua table using the provided Lua state. -// It accepts two parameters: -// - L: a pointer to the Lua state -// - slice: a slice of type `any` -// -// For each value in the slice, the function checks the type of the value. -// If the value is a boolean, it sets the Lua table's element at index `i+1` to a Lua boolean with the same value. -// If the value is a float64, it sets the Lua table's element at index `i+1` to a Lua number with the same value. -// If the value is a string, it sets the Lua table's element at the index `i+1` to a Lua string with the same value. -// -// If the value is of any other type, the function returns nil. -// -// Finally, the function returns a pointer to a Lua table that contains all valid values from the slice. -func SliceToLuaTable(L *lua.LState, slice []any) *lua.LTable { - lTable := L.NewTable() - for i, v := range slice { - switch sliceValue := v.(type) { - case bool: - L.RawSetInt(lTable, i+1, lua.LBool(sliceValue)) - case float64: - L.RawSetInt(lTable, i+1, lua.LNumber(sliceValue)) - case string: - L.RawSetInt(lTable, i+1, lua.LString(sliceValue)) - default: - return nil - } +// LuaValueToGo converts a lua.LValue to a corresponding Go value (nil, bool, float64, string, or map). +func LuaValueToGo(value lua.LValue) any { + if value == lua.LNil { + return nil } - return lTable + switch v := value.(type) { + case lua.LBool: + return bool(v) + case lua.LNumber: + return float64(v) + case lua.LString: + return v.String() + case *lua.LTable: + table := make(map[any]any) + + v.ForEach(func(key lua.LValue, v2 lua.LValue) { + table[LuaValueToGo(key)] = LuaValueToGo(v2) + }) + + return table + default: + return v.String() + } } diff --git a/server/lualib/redislib/register.go b/server/lualib/redislib/register.go index dd54812c..bf48e01e 100644 --- a/server/lualib/redislib/register.go +++ b/server/lualib/redislib/register.go @@ -67,6 +67,7 @@ var exportsModRedis = map[string]lua.LGFunction{ global.LuaFnRedisSMembers: RedisSMembers, global.LuaFnRedisSRem: RedisSRem, global.LuaFnRedisSCard: RedisSCard, + global.LuaFnRedisRunScript: RedisRunScript, } // LoaderModRedis initializes a new module for Redis in Lua by setting the functions from the "exportsModRedis" map into diff --git a/server/lualib/redislib/scripts.go b/server/lualib/redislib/scripts.go new file mode 100644 index 00000000..74fa0c62 --- /dev/null +++ b/server/lualib/redislib/scripts.go @@ -0,0 +1,63 @@ +package redislib + +import ( + "github.com/croessner/nauthilus/server/lualib/convert" + "github.com/croessner/nauthilus/server/rediscli" + "github.com/yuin/gopher-lua" +) + +// executeRedisScript executes a given Lua script on the Redis server with specified keys and arguments. +func executeRedisScript(script string, keys []string, args ...any) (any, error) { + evalArgs := make([]any, len(keys)+len(args)) + + for i, key := range keys { + evalArgs[i] = key + } + + for i, arg := range args { + evalArgs[len(keys)+i] = arg + } + + result, err := rediscli.WriteHandle.Eval(ctx, script, keys, evalArgs...).Result() + if err != nil { + return nil, err + } + + return result, nil +} + +// RedisRunScript executes a Redis script with the provided keys and arguments, returning the result or an error as Lua values. +// It expects three arguments: the script string, a table of keys, and a table of arguments. It returns two values: an error message (or nil) and the script result (or nil). +func RedisRunScript(L *lua.LState) int { + var ( + keyList []string + argsList []any + ) + + script := L.CheckString(1) + keys := L.ToTable(2) + args := L.ToTable(3) + + keys.ForEach(func(k, v lua.LValue) { + keyList = append(keyList, v.String()) + }) + + args.ForEach(func(k, v lua.LValue) { + argsList = append(argsList, v.String()) + }) + + result, err := executeRedisScript(script, keyList, argsList...) + if err != nil { + L.Push(lua.LString(err.Error())) + L.Push(lua.LNil) + + return 2 + } + + lResult := convert.GoToLuaValue(L, result) + + L.Push(lua.LNil) + L.Push(lResult) + + return 2 +}