From 6a6466575fc9e75dc6ed233babfb80872a719dc8 Mon Sep 17 00:00:00 2001 From: Christian Roessner Date: Fri, 13 Sep 2024 17:02:21 +0200 Subject: [PATCH] Fix: Handle closed channels in backend request handlers Added non-blocking select statements to channel operations in LDAP and Lua backend handlers. This ensures that attempts to write to closed channels do not cause goroutines to hang, and returns appropriate errors where necessary. Signed-off-by: Christian Roessner --- server/backend/ldap.go | 33 +++++++++++++++++++++++++++------ server/backend/lua.go | 5 ++++- server/core/ldap.go | 24 ++++++++++++++++++++---- server/core/lua.go | 19 ++++++++++++++++--- server/errors/errors.go | 1 + 5 files changed, 68 insertions(+), 14 deletions(-) diff --git a/server/backend/ldap.go b/server/backend/ldap.go index d28c5537..904daa2f 100644 --- a/server/backend/ldap.go +++ b/server/backend/ldap.go @@ -1196,7 +1196,10 @@ func (l *LDAPConnection) modifyAdd(ldapRequest *LDAPRequest) (err error) { // Then it sets the state of the connection to global.LDAPStateFree. // Finally, it unlocks the state of the connection using the ldapPool.conn[index].Mu.Unlock() method. func sendLDAPReplyAndUnlockState[T PoolRequest[T]](ldapPool *LDAPPool, index int, request T, ldapReply *LDAPReply) { - request.GetLDAPReplyChan() <- ldapReply + select { + case request.GetLDAPReplyChan() <- ldapReply: + default: + } ldapPool.conn[index].Mu.Lock() @@ -1294,7 +1297,10 @@ func (l *LDAPPool) proccessLookupRequest(index int, ldapRequest *LDAPRequest, ld ldapReply := &LDAPReply{} if ldapReply.Err = l.checkConnection(ldapRequest.GUID, index); ldapReply.Err != nil { - ldapRequest.LDAPReplyChan <- ldapReply + select { + case ldapRequest.LDAPReplyChan <- ldapReply: + default: + } return } @@ -1356,7 +1362,10 @@ func LDAPMainWorker(ctx context.Context) { case ldapRequest := <-LDAPRequestChan: // Check that we have enough idle connections. if err := ldapPool.setIdleConnections(true); err != nil { - ldapRequest.LDAPReplyChan <- &LDAPReply{Err: err} + select { + case ldapRequest.LDAPReplyChan <- &LDAPReply{Err: err}: + default: + } } ldapPool.handleLookupRequest(ldapRequest, &ldapWaitGroup) @@ -1405,7 +1414,10 @@ func (l *LDAPPool) processAuthRequest(index int, ldapAuthRequest *LDAPAuthReques ldapReply := &LDAPReply{} if ldapReply.Err = l.checkConnection(ldapAuthRequest.GUID, index); ldapReply.Err != nil { - ldapAuthRequest.LDAPReplyChan <- ldapReply + select { + case ldapAuthRequest.LDAPReplyChan <- ldapReply: + default: + } return } @@ -1460,7 +1472,10 @@ func LDAPAuthWorker(ctx context.Context) { case ldapAuthRequest := <-LDAPAuthRequestChan: // Check that we have enough idle connections. if err := ldapPool.setIdleConnections(false); err != nil { - ldapAuthRequest.LDAPReplyChan <- &LDAPReply{Err: err} + select { + case ldapAuthRequest.LDAPReplyChan <- &LDAPReply{Err: err}: + default: + } } ldapPool.handleAuthRequest(ldapAuthRequest, &ldapWaitGroup) @@ -1512,7 +1527,13 @@ func LuaLDAPSearch(ctx context.Context) lua.LGFunction { ldapRequest := createLDAPRequest(fieldValues, scope, ctx) - LDAPRequestChan <- ldapRequest + select { + case LDAPRequestChan <- ldapRequest: + default: + L.RaiseError(errors.ErrClosedChannel.Error()) + + return 1 + } return processReply(L, ldapRequest.GetLDAPReplyChan()) } diff --git a/server/backend/lua.go b/server/backend/lua.go index 220df3d8..1a4a548b 100644 --- a/server/backend/lua.go +++ b/server/backend/lua.go @@ -403,8 +403,11 @@ func processError(err error, luaRequest *LuaRequest, logs *lualib.CustomLogKeyVa global.LogKeyError, err, ) - luaRequest.LuaReplyChan <- &lualib.LuaBackendResult{ + select { + case luaRequest.LuaReplyChan <- &lualib.LuaBackendResult{ Err: err, Logs: logs, + }: + default: } } diff --git a/server/core/ldap.go b/server/core/ldap.go index 36d7f62f..a24d7354 100644 --- a/server/core/ldap.go +++ b/server/core/ldap.go @@ -175,7 +175,11 @@ func ldapPassDB(auth *AuthState) (passDBResult *PassDBResult, err error) { } // Find user with account status enabled - backend.LDAPRequestChan <- ldapRequest + select { + case backend.LDAPRequestChan <- ldapRequest: + default: + return passDBResult, errors.ErrClosedChannel + } ldapReply = <-ldapReplyChan @@ -234,7 +238,11 @@ func ldapPassDB(auth *AuthState) (passDBResult *PassDBResult, err error) { HTTPClientContext: auth.HTTPClientContext, } - backend.LDAPAuthRequestChan <- ldapUserBindRequest + select { + case backend.LDAPAuthRequestChan <- ldapUserBindRequest: + default: + return passDBResult, errors.ErrClosedChannel + } ldapReply = <-ldapReplyChan @@ -333,7 +341,11 @@ func ldapAccountDB(auth *AuthState) (accounts AccountList, err error) { } // Find user with account status enabled - backend.LDAPRequestChan <- ldapRequest + select { + case backend.LDAPRequestChan <- ldapRequest: + default: + return accounts, errors.ErrClosedChannel + } ldapReply = <-ldapReplyChan @@ -424,7 +436,11 @@ func ldapAddTOTPSecret(auth *AuthState, totp *TOTPSecret) (err error) { ldapRequest.ModifyAttributes = make(backend.LDAPModifyAttributes, 2) ldapRequest.ModifyAttributes[configField] = []string{totp.getValue()} - backend.LDAPRequestChan <- ldapRequest + select { + case backend.LDAPRequestChan <- ldapRequest: + default: + return errors.ErrClosedChannel + } ldapReply = <-ldapReplyChan diff --git a/server/core/lua.go b/server/core/lua.go index c882d778..6452c2ba 100644 --- a/server/core/lua.go +++ b/server/core/lua.go @@ -18,6 +18,7 @@ package core import ( "github.com/croessner/nauthilus/server/backend" "github.com/croessner/nauthilus/server/config" + "github.com/croessner/nauthilus/server/errors" "github.com/croessner/nauthilus/server/global" "github.com/croessner/nauthilus/server/lualib" "github.com/croessner/nauthilus/server/stats" @@ -87,7 +88,11 @@ func luaPassDB(auth *AuthState) (passDBResult *PassDBResult, err error) { }, } - backend.LuaRequestChan <- luaRequest + select { + case backend.LuaRequestChan <- luaRequest: + default: + return passDBResult, errors.ErrClosedChannel + } luaBackendResult = <-luaReplyChan @@ -175,7 +180,11 @@ func luaAccountDB(auth *AuthState) (accounts AccountList, err error) { }, } - backend.LuaRequestChan <- luaRequest + select { + case backend.LuaRequestChan <- luaRequest: + default: + return accounts, errors.ErrClosedChannel + } luaBackendResult = <-luaReplyChan @@ -229,7 +238,11 @@ func luaAddTOTPSecret(auth *AuthState, totp *TOTPSecret) (err error) { }, } - backend.LuaRequestChan <- luaRequest + select { + case backend.LuaRequestChan <- luaRequest: + default: + return errors.ErrClosedChannel + } luaBackendResult = <-luaReplyChan diff --git a/server/errors/errors.go b/server/errors/errors.go index 725b5a34..36b1eba8 100644 --- a/server/errors/errors.go +++ b/server/errors/errors.go @@ -144,6 +144,7 @@ var ( // common. var ( + ErrClosedChannel = errors.New("channel closed") ErrNoPassDBResult = errors.New("no pass Database result") ErrUnknownCause = errors.New("something went wrong") )