Skip to content

Commit

Permalink
Merge pull request #140 from croessner/features
Browse files Browse the repository at this point in the history
Feat: Refactor Lua-Go value conversion and optimize Redis script hand…
  • Loading branch information
croessner authored Oct 23, 2024
2 parents 35ad02d + ec66910 commit c2ec615
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 178 deletions.
2 changes: 1 addition & 1 deletion server/backend/ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -1667,7 +1667,7 @@ func processReply(L *lua.LState, ldapReplyChan chan *LDAPReply) int {
convertedMap[key] = list
}

resultTable := convert.MapToLuaTable(L, convertedMap)
resultTable := convert.GoToLuaValue(L, convertedMap)

if resultTable == nil {
L.Push(lua.LString("no result"))
Expand Down
4 changes: 2 additions & 2 deletions server/backend/lua.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ func executeAndHandleError(compiledScript *lua.FunctionProto, luaCommand string,
// with the custom logs to the LuaReplyChan channel.
//
// If the Lua request function is LuaCommandListAccounts, the function expects the return value
// to be a Lua table. The function converts the table to a map using the LuaTableToMap function,
// to be a Lua table. The function converts the table to a map using the LuaValueToGo function,
// assigns it to the Attributes field of a new LuaBackendResult, and sends it to the LuaReplyChan channel.
//
// For all other Lua request functions, the function sends an empty LuaBackendResult with the custom logs
Expand Down Expand Up @@ -390,7 +390,7 @@ func handleReturnTypes(L *lua.LState, nret int, luaRequest *LuaRequest, logs *lu

case global.LuaCommandListAccounts:
luaRequest.LuaReplyChan <- &lualib.LuaBackendResult{
Attributes: convert.LuaTableToMap(L.ToTable(-1)),
Attributes: convert.LuaValueToGo(L.ToTable(-1)).(map[any]any),
Logs: logs,
}

Expand Down
3 changes: 3 additions & 0 deletions server/global/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -1199,6 +1199,9 @@ const (
// LuaFnRedisSCard represents a Lua function that returns the number of elements in a Redis set.
LuaFnRedisSCard = "redis_scard"

// LuaFnRedisRunScript is the constant used to denote the operation for running a Lua script in Redis.
LuaFnRedisRunScript = "redis_run_script"

// LuaFnApplyBackendResult applies changes to the backend result from a former authentication process.
LuaFnApplyBackendResult = "apply_backend_result"

Expand Down
27 changes: 16 additions & 11 deletions server/lua-plugins.d/actions/haveibeenpwnd.lua
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ function nauthilus_call_action(request)
dynamic_loader("nauthilus_mail")
local nauthilus_mail = require("nauthilus_mail")

dynamic_loader("nauthilus_misc")
local nauthilus_misc = require("nauthilus_misc")

dynamic_loader("nauthilus_context")
local nauthilus_context = require("nauthilus_context")

Expand All @@ -69,8 +66,6 @@ function nauthilus_call_action(request)
dynamic_loader("nauthilus_gll_template")
local template = require("template")

nauthilus_misc.wait_random(500, 3000)

local redis_key = "ntc:HAVEIBEENPWND:" .. crypto.md5(request.account)
local hash = string.lower(crypto.sha1(request.password))

Expand Down Expand Up @@ -126,10 +121,23 @@ function nauthilus_call_action(request)
nauthilus_context.context_set(N .. "_hash_info", hash:sub(1, 5) .. cmp_hash[2])
nauthilus_builtin.custom_log_add(N .. "_action", "leaked")

local already_sent_mail, err_redis_hget2 = nauthilus_redis.redis_hget(redis_key, "send_mail")
nauthilus_util.if_error_raise(err_redis_hget2)
local script = [[
local redis_key = KEYS[1]
local send_mail = redis.call('HGET', redis_key, 'send_mail')
if already_sent_mail == "" then
if send_mail == false then
redis.call('HSET', redis_key, 'send_mail', '1')
return {'send_email', redis_key}
else
return {'email_already_sent'}
end
]]

local script_result, err_run_script = nauthilus_redis.redis_run_script(script, { redis_key })
nauthilus_util.if_error_raise(err_run_script)

if script_result[1] == "send_mail" then
local smtp_use_lmtp = os.environ("SMTP_USE_LMTP")
local smtp_server = os.environ("SMTP_SERVER")
local smtp_port = os.environ("SMTP_PORT")
Expand Down Expand Up @@ -167,9 +175,6 @@ function nauthilus_call_action(request)
})
nauthilus_util.if_error_raise(err_smtp)

_, err_redis_hset = nauthilus_redis.redis_hset(redis_key, "send_mail", 1)
nauthilus_util.if_error_raise(err_redis_hset)

_, err_redis_expire = nauthilus_redis.redis_expire(redis_key, 86400)
nauthilus_util.if_error_raise(err_redis_expire)

Expand Down
5 changes: 2 additions & 3 deletions server/lualib/backendresult.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,12 @@ func backendResultGetSetAttributes(L *lua.LState) int {
backendResult := checkBackendResult(L)

if L.GetTop() == 2 {
// XXX: We expect keys to be strings!
backendResult.Attributes = convert.LuaTableToMap(L.CheckTable(2))
backendResult.Attributes = convert.LuaValueToGo(L.CheckTable(2)).(map[any]any)

return 0
}

L.Push(convert.MapToLuaTable(L, backendResult.Attributes))
L.Push(convert.GoToLuaValue(L, backendResult.Attributes))

return 1
}
36 changes: 4 additions & 32 deletions server/lualib/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,11 @@
package lualib

import (
"fmt"
"sync"
"time"

"github.com/croessner/nauthilus/server/global"
"github.com/croessner/nauthilus/server/log"
"github.com/croessner/nauthilus/server/lualib/convert"
"github.com/go-kit/log/level"
lua "github.com/yuin/gopher-lua"
)

Expand Down Expand Up @@ -125,20 +122,9 @@ func (c *Context) Value(_ any) lua.LValue {
func ContextSet(ctx *Context) lua.LGFunction {
return func(L *lua.LState) int {
key := L.CheckString(1)
value := L.CheckAny(2)

switch value := L.Get(2).(type) {
case lua.LString:
ctx.Set(key, string(value))
case lua.LBool:
ctx.Set(key, bool(value))
case lua.LNumber:
ctx.Set(key, float64(value))
case *lua.LTable:
ctx.Set(key, convert.LuaTableToMap(value))
default:
level.Warn(log.Logger).Log(
global.LogKeyWarning, fmt.Sprintf("Lua key='%v' value='%v' unsupported", key, value))
}
ctx.Set(key, convert.LuaValueToGo(value))

return 0
}
Expand All @@ -149,23 +135,9 @@ func ContextSet(ctx *Context) lua.LGFunction {
func ContextGet(ctx *Context) lua.LGFunction {
return func(L *lua.LState) int {
key := L.CheckString(1)
value := ctx.Get(key)

switch value := ctx.Get(key).(type) {
case string:
L.Push(lua.LString(value))
case bool:
L.Push(lua.LBool(value))
case float64:
L.Push(lua.LNumber(value))
case map[any]any:
L.Push(convert.MapToLuaTable(L, value))
case nil:
L.Push(lua.LNil)
default:
level.Warn(log.Logger).Log(
global.LogKeyWarning, fmt.Sprintf("Lua key='%v' value='%v' unsupported", key, value))
L.Push(lua.LNil)
}
L.Push(convert.GoToLuaValue(L, value))

return 1
}
Expand Down
173 changes: 44 additions & 129 deletions server/lualib/convert/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,155 +113,70 @@ func StringCmd(value *redis.StringCmd, valType string, L *lua.LState) error {
return nil
}

// GoToLuaValue converts a Go value to a corresponding Lua value.
// It accepts an argument 'value' of type 'any' and returns a value of type 'lua.LValue'.
// If the input is a string, it returns a Lua string value (lua.LString).
// If the input is a float64 or an int, it returns a Lua number value (lua.LNumber).
// If the input is a boolean, it returns a Lua boolean value (lua.LBool).
// For any other types, it converts the value to a string and returns a Lua string value (lua.LString).
// The function uses the fmt.Sprintf method to convert values of any type to a string.
// This function is useful for converting Go values to their equivalent Lua values.
// The function is not safe for concurrent use.
// GoToLuaValue converts a Go value to a corresponding Lua value suitable for Lua state operations.
func GoToLuaValue(L *lua.LState, value any) lua.LValue {
switch v := value.(type) {
case string:
return lua.LString(v)
case float64:
return lua.LNumber(v)
case int:
case int64:
return lua.LNumber(v)
case bool:
return lua.LBool(v)
case map[any]any:
return MapToLuaTable(L, v)
default:
return lua.LString(fmt.Sprintf("%v", value))
}
}

// LuaTableToMap takes a lua.LTable as input and converts it into a map[any]any.
// The function iterates over each key-value pair in the table and converts the keys and values
// into their corresponding Go types. The converted key-value pairs are then added to a new map, which is
// returned as the result.
// If the input table is nil, the function returns nil.
func LuaTableToMap(table *lua.LTable) map[any]any {
if table == nil {
return nil
}

result := make(map[any]any)
case []any:
tbl := L.NewTable()

table.ForEach(func(key lua.LValue, value lua.LValue) {
var (
mapKey any
mapValue any
)

switch k := key.(type) {
case lua.LBool:
mapKey = bool(k)
case lua.LNumber:
mapKey = float64(k)
case lua.LString:
mapKey = k.String()
default:
return
}

switch v := value.(type) {
case lua.LBool:
mapValue = bool(v)
case lua.LNumber:
mapValue = float64(v)
case *lua.LTable:
mapValue = LuaTableToMap(v)
default:
mapValue = v.String()
for _, item := range v {
tbl.Append(GoToLuaValue(L, item))
}

result[mapKey] = mapValue
})

return result
}
return tbl
case map[string]any:
tbl := L.NewTable()

// MapToLuaTable takes an *lua.LState and a map[any]any as input and converts it into a *lua.LTable.
// The function iterates over each key-value pair in the map and converts the keys and values
// into their corresponding lua.LValue types. The converted key-value pairs are then added to a new *lua.LTable,
// which is returned as the result.
// If the input map is nil, the function returns nil.
func MapToLuaTable(L *lua.LState, table map[any]any) *lua.LTable {
var (
key lua.LValue
value lua.LValue
)

lTable := L.NewTable()

if table == nil {
return nil
}

for k, v := range table {
switch mapKey := k.(type) {
case bool:
key = lua.LBool(mapKey)
case float64:
key = lua.LNumber(mapKey)
case string:
key = lua.LString(mapKey)
default:
return nil
for k, item := range v {
tbl.RawSetString(k, GoToLuaValue(L, item))
}

switch mapValue := v.(type) {
case bool:
value = lua.LBool(mapValue)
case float64:
value = lua.LNumber(mapValue)
case string:
value = lua.LString(mapValue)
case []any:
value = SliceToLuaTable(L, mapValue) // convert []any to *lua.LTable
case map[any]any:
value = MapToLuaTable(L, mapValue)
default:
return nil
return tbl
case map[any]any:
tbl := L.NewTable()

for k, item := range v {
tbl.RawSet(GoToLuaValue(L, k), GoToLuaValue(L, item))
}

L.RawSet(lTable, key, value)
return tbl
case nil:
return lua.LNil
default:
return lua.LString(fmt.Sprintf("%v", value))
}

return lTable
}

// SliceToLuaTable converts a slice into a Lua table using the provided Lua state.
// It accepts two parameters:
// - L: a pointer to the Lua state
// - slice: a slice of type `any`
//
// For each value in the slice, the function checks the type of the value.
// If the value is a boolean, it sets the Lua table's element at index `i+1` to a Lua boolean with the same value.
// If the value is a float64, it sets the Lua table's element at index `i+1` to a Lua number with the same value.
// If the value is a string, it sets the Lua table's element at the index `i+1` to a Lua string with the same value.
//
// If the value is of any other type, the function returns nil.
//
// Finally, the function returns a pointer to a Lua table that contains all valid values from the slice.
func SliceToLuaTable(L *lua.LState, slice []any) *lua.LTable {
lTable := L.NewTable()
for i, v := range slice {
switch sliceValue := v.(type) {
case bool:
L.RawSetInt(lTable, i+1, lua.LBool(sliceValue))
case float64:
L.RawSetInt(lTable, i+1, lua.LNumber(sliceValue))
case string:
L.RawSetInt(lTable, i+1, lua.LString(sliceValue))
default:
return nil
}
// LuaValueToGo converts a lua.LValue to a corresponding Go value (nil, bool, float64, string, or map).
func LuaValueToGo(value lua.LValue) any {
if value == lua.LNil {
return nil
}

return lTable
switch v := value.(type) {
case lua.LBool:
return bool(v)
case lua.LNumber:
return float64(v)
case lua.LString:
return v.String()
case *lua.LTable:
table := make(map[any]any)

v.ForEach(func(key lua.LValue, v2 lua.LValue) {
table[LuaValueToGo(key)] = LuaValueToGo(v2)
})

return table
default:
return v.String()
}
}
1 change: 1 addition & 0 deletions server/lualib/redislib/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ var exportsModRedis = map[string]lua.LGFunction{
global.LuaFnRedisSMembers: RedisSMembers,
global.LuaFnRedisSRem: RedisSRem,
global.LuaFnRedisSCard: RedisSCard,
global.LuaFnRedisRunScript: RedisRunScript,
}

// LoaderModRedis initializes a new module for Redis in Lua by setting the functions from the "exportsModRedis" map into
Expand Down
Loading

0 comments on commit c2ec615

Please sign in to comment.