Skip to content

Commit

Permalink
Fix: Refactor session handling and error handling.
Browse files Browse the repository at this point in the history
Revised session validation logic, removed redundant code, and replaced hardcoded protocol checks with a lookup table for better maintainability. Introduced a new error message "ErrFilterFailed" to provide more specific feedback on filter execution failures. These updates aim to improve code readability, error traceability, and system robustness.

Signed-off-by: Christian Roessner <c@roessner.co>
  • Loading branch information
Christian Roessner committed Dec 9, 2024
1 parent 143b813 commit a67369c
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 94 deletions.
1 change: 1 addition & 0 deletions server/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ var (
ErrNoFiltersDefined = errors.New("no filters defined")
ErrFilterLuaNameMissing = errors.New("filter 'name' sttribute missing")
ErrFilterLuaScriptPathEmpty = errors.New("filter 'script_path' attribute missing")
ErrFilterFailed = errors.New("filter failed")
)

// misc.
Expand Down
152 changes: 70 additions & 82 deletions server/lua-plugins.d/filters/monitoring.lua
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,24 @@
dynamic_loader("nauthilus_backend")
local nauthilus_backend = require("nauthilus_backend")

local N = "monitoring"

local wanted_protocols = {
"imap", "imapa", "pop3", "pop3s", "lmtp", "lmtps",
"sieve", -- Not sure about this
local N = "director"

local WANTED_PROTOCOLS = {
imap = true,
imapa = true,
pop3 = true,
pop3s = true,
lmtp = true,
lmtps = true,
sieve = true,
}

function nauthilus_call_filter(request)
local skip_and_accept_filter = false

-- Dovecot userdb request
if request.authenticated and request.no_auth then
skip_and_accept_filter = true
end

-- Dovecot passdb request
if request.authenticated and not request.no_auth then
skip_and_accept_filter = true

for _, proto in ipairs(wanted_protocols) do
if proto == request.protocol then
skip_and_accept_filter = false

break
end
end
if not request.authenticated then
return nauthilus_builtin.FILTER_REJECT, nauthilus_builtin.FILTER_RESULT_OK
end

if skip_and_accept_filter then
nauthilus_backend.remove_from_backend_result({ "Proxy-Host" })

if not WANTED_PROTOCOLS[request.protocol] then
return nauthilus_builtin.FILTER_ACCEPT, nauthilus_builtin.FILTER_RESULT_OK
end

Expand All @@ -66,6 +53,7 @@ function nauthilus_call_filter(request)

dynamic_loader("nauthilus_redis")
local nauthilus_redis = require("nauthilus_redis")
local redis_key = "ntc:DS:" .. request.account

local custom_pool = "default"
local custom_pool_name = os.getenv("CUSTOM_REDIS_POOL_NAME")
Expand All @@ -76,7 +64,7 @@ function nauthilus_call_filter(request)
nauthilus_util.if_error_raise(err_redis_client)
end

local function set_initial_expiry(redis_key)
local function set_initial_expiry()
local length, err_redis_hlen = nauthilus_redis.redis_hlen(custom_pool, redis_key)
if err_redis_hlen then
if err_redis_hlen ~= "redis: nil" then
Expand All @@ -90,30 +78,25 @@ function nauthilus_call_filter(request)
end
end

dynamic_loader("nauthilus_gluacrypto")
local crypto = require("crypto")

local function add_session(session, server)
if session == nil then
return
end
local function invalidate_stale_sessions()
local _, err_redis_hdel = nauthilus_redis.redis_del(custom_pool, redis_key)

local redis_key = "ntc:DS:" .. crypto.md5(request.account)
nauthilus_util.if_error_raise(err_redis_hdel)
end

local function add_session(session, server)
local _, err_redis_hset = nauthilus_redis.redis_hset(custom_pool, redis_key, session, server)
if err_redis_hset then
nauthilus_builtin.custom_log_add(N .. "_redis_hset_error", err_redis_hset)

return
end

set_initial_expiry(redis_key)
set_initial_expiry()
nauthilus_builtin.custom_log_add(N .. "_dovecot_session", session)
end

local function get_server_from_sessions(session)
local redis_key = "ntc:DS:" .. crypto.md5(request.account)

local server_from_session, err_redis_hget = nauthilus_redis.redis_hget(custom_pool, redis_key, session)
if err_redis_hget then
if err_redis_hget ~= "redis: nil" then
Expand Down Expand Up @@ -143,70 +126,75 @@ function nauthilus_call_filter(request)
return nil
end

-- Only look for backend servers, if a user was authenticated (passdb requests)
if request.authenticated and not request.no_auth then
local num_of_bs = 0
local function preprocess_backend_servers(backend_servers)
local valid_servers = {}

local backend_servers = nauthilus_backend.get_backend_servers()
if nauthilus_util.is_table(backend_servers) then
num_of_bs = nauthilus_util.table_length(backend_servers)

local server_host = ""
local new_server_host = ""

local session = get_dovecot_session()
if session then
local maybe_server = get_server_from_sessions(session)
if maybe_server then
server_host = maybe_server
end
for _, server in ipairs(backend_servers) do
if server.protocol == request.protocol then
table.insert(valid_servers, server)
end
end

if num_of_bs > 0 then
local attributes = {}

local b = nauthilus_backend_result.new()
return valid_servers
end

for _, server in ipairs(backend_servers) do
new_server_host = server.host
local server_host
local session = get_dovecot_session()

if server_host == new_server_host then
attributes["Proxy-Host"] = server_host
if session then
local valid_servers = preprocess_backend_servers(nauthilus_backend.get_backend_servers())
local num_of_bs = nauthilus_util.table_length(valid_servers)

add_session(session, server_host)
nauthilus_builtin.custom_log_add(N .. "_backend_server_current", server_host)
if num_of_bs > 0 then
local maybe_server = get_server_from_sessions(session)

b:attributes(attributes)
nauthilus_backend.apply_backend_result(b)
if maybe_server then
for _, server in ipairs(valid_servers) do
if server.host == maybe_server then
server_host = maybe_server

break
end
end

if server_host ~= new_server_host then
-- Put your own logic here to select a proper server for the user. In this demo, the last server
-- available is always used.
attributes["Proxy-Host"] = new_server_host
if not server_host then
invalidate_stale_sessions()

add_session(session, new_server_host)
nauthilus_builtin.custom_log_add(N .. "_backend_server_new", new_server_host)

b:attributes(attributes)
nauthilus_backend.apply_backend_result(b)
server_host = valid_servers[math.random(1, num_of_bs)].host
end
else
server_host = valid_servers[math.random(1, num_of_bs)].host
end
end

if num_of_bs == 0 then
nauthilus_builtin.custom_log_add(N .. "_backend_server", "failed")
nauthilus_builtin.status_message_set("No backend servers are available")
if server_host then
local backend_result = nauthilus_backend_result.new()
local attributes = {}

add_session(session, server_host)

return nauthilus_builtin.FILTER_ACCEPT, nauthilus_builtin.FILTER_RESULT_FAIL
local expected_server = get_server_from_sessions(session)

-- Another client might have been faster at the same point in time...
if expected_server and server_host ~= expected_server then
server_host = expected_server
end

attributes["Proxy-Host"] = server_host

nauthilus_builtin.custom_log_add(N .. "_backend_server", server_host)

backend_result:attributes(attributes)
nauthilus_backend.apply_backend_result(backend_result)
end
end

return nauthilus_builtin.FILTER_ACCEPT, nauthilus_builtin.FILTER_RESULT_OK
if server_host == nil then
nauthilus_builtin.custom_log_add(N .. "_backend_server", "failed")
nauthilus_builtin.status_message_set("No backend servers are available")

return nauthilus_builtin.FILTER_ACCEPT, nauthilus_builtin.FILTER_RESULT_FAIL
end

-- Anything else must be a rejected request
return nauthilus_builtin.FILTER_REJECT, nauthilus_builtin.FILTER_RESULT_OK
return nauthilus_builtin.FILTER_ACCEPT, nauthilus_builtin.FILTER_RESULT_OK
end
5 changes: 1 addition & 4 deletions server/lua-plugins.d/hooks/dovecot-session-cleaner.lua
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ local nauthilus_redis = require("nauthilus_redis")
dynamic_loader("nauthilus_http_request")
local nauthilus_http_request = require("nauthilus_http_request")

dynamic_loader("nauthilus_gluacrypto")
local crypto = require("crypto")

dynamic_loader("nauthilus_gll_json")
local json = require("json")

Expand Down Expand Up @@ -108,7 +105,7 @@ function nauthilus_run_hook(logging, session)

if result.category == "service:imap" or result.category == "service:pop3" or result.category == "service:lmtp" or result.category == "service:sieve" then
if result.dovecot_session ~= "unknown" then
local redis_key = "ntc:DS:" .. crypto.md5(result.user)
local redis_key = "ntc:DS:" .. result.user

if is_cmd_noop then
result.cmd = "NOOP"
Expand Down
20 changes: 12 additions & 8 deletions server/lualib/filter/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package filter
import (
"context"
stderrors "errors"
"fmt"
"net/http"
"sync"
"time"
Expand Down Expand Up @@ -445,6 +446,8 @@ func setRequest(r *Request, L *lua.LState) *lua.LTable {
// It also calls the Lua function with the given parameters and logs the result.
// The function will return a boolean indicating whether the Lua function was called successfully, and an error if any occurred.
func executeScriptWithinContext(request *lua.LTable, script *LuaFilter, r *Request, ctx *gin.Context, L *lua.LState) (bool, error) {
var err error

stopTimer := stats.PrometheusTimer(definitions.PromFilter, script.Name)

if stopTimer != nil {
Expand Down Expand Up @@ -486,11 +489,15 @@ func executeScriptWithinContext(request *lua.LTable, script *LuaFilter, r *Reque

logResult(r, script, action, result)

if result != 0 {
err = fmt.Errorf("%v: %s", errors.ErrFilterFailed, script.Name)
}

if action {
return true, nil
return true, err
}

return false, nil
return false, err
}

// logError is a function that logs error information when a LuaFilter script fails during a Request session.
Expand All @@ -516,12 +523,9 @@ func logResult(r *Request, script *LuaFilter, action bool, ret int) {
"result", resultMap[ret],
}

if ret != 0 {

if r.Logs != nil {
for index := range *r.Logs {
logs = append(logs, (*r.Logs)[index])
}
if r.Logs != nil {
for index := range *r.Logs {
logs = append(logs, (*r.Logs)[index])
}
}

Expand Down

0 comments on commit a67369c

Please sign in to comment.