diff --git a/server/backend/ldap.go b/server/backend/ldap.go index d28c5537..904daa2f 100644 --- a/server/backend/ldap.go +++ b/server/backend/ldap.go @@ -1196,7 +1196,10 @@ func (l *LDAPConnection) modifyAdd(ldapRequest *LDAPRequest) (err error) { // Then it sets the state of the connection to global.LDAPStateFree. // Finally, it unlocks the state of the connection using the ldapPool.conn[index].Mu.Unlock() method. func sendLDAPReplyAndUnlockState[T PoolRequest[T]](ldapPool *LDAPPool, index int, request T, ldapReply *LDAPReply) { - request.GetLDAPReplyChan() <- ldapReply + select { + case request.GetLDAPReplyChan() <- ldapReply: + default: + } ldapPool.conn[index].Mu.Lock() @@ -1294,7 +1297,10 @@ func (l *LDAPPool) proccessLookupRequest(index int, ldapRequest *LDAPRequest, ld ldapReply := &LDAPReply{} if ldapReply.Err = l.checkConnection(ldapRequest.GUID, index); ldapReply.Err != nil { - ldapRequest.LDAPReplyChan <- ldapReply + select { + case ldapRequest.LDAPReplyChan <- ldapReply: + default: + } return } @@ -1356,7 +1362,10 @@ func LDAPMainWorker(ctx context.Context) { case ldapRequest := <-LDAPRequestChan: // Check that we have enough idle connections. if err := ldapPool.setIdleConnections(true); err != nil { - ldapRequest.LDAPReplyChan <- &LDAPReply{Err: err} + select { + case ldapRequest.LDAPReplyChan <- &LDAPReply{Err: err}: + default: + } } ldapPool.handleLookupRequest(ldapRequest, &ldapWaitGroup) @@ -1405,7 +1414,10 @@ func (l *LDAPPool) processAuthRequest(index int, ldapAuthRequest *LDAPAuthReques ldapReply := &LDAPReply{} if ldapReply.Err = l.checkConnection(ldapAuthRequest.GUID, index); ldapReply.Err != nil { - ldapAuthRequest.LDAPReplyChan <- ldapReply + select { + case ldapAuthRequest.LDAPReplyChan <- ldapReply: + default: + } return } @@ -1460,7 +1472,10 @@ func LDAPAuthWorker(ctx context.Context) { case ldapAuthRequest := <-LDAPAuthRequestChan: // Check that we have enough idle connections. if err := ldapPool.setIdleConnections(false); err != nil { - ldapAuthRequest.LDAPReplyChan <- &LDAPReply{Err: err} + select { + case ldapAuthRequest.LDAPReplyChan <- &LDAPReply{Err: err}: + default: + } } ldapPool.handleAuthRequest(ldapAuthRequest, &ldapWaitGroup) @@ -1512,7 +1527,13 @@ func LuaLDAPSearch(ctx context.Context) lua.LGFunction { ldapRequest := createLDAPRequest(fieldValues, scope, ctx) - LDAPRequestChan <- ldapRequest + select { + case LDAPRequestChan <- ldapRequest: + default: + L.RaiseError(errors.ErrClosedChannel.Error()) + + return 1 + } return processReply(L, ldapRequest.GetLDAPReplyChan()) } diff --git a/server/backend/lua.go b/server/backend/lua.go index 220df3d8..1a4a548b 100644 --- a/server/backend/lua.go +++ b/server/backend/lua.go @@ -403,8 +403,11 @@ func processError(err error, luaRequest *LuaRequest, logs *lualib.CustomLogKeyVa global.LogKeyError, err, ) - luaRequest.LuaReplyChan <- &lualib.LuaBackendResult{ + select { + case luaRequest.LuaReplyChan <- &lualib.LuaBackendResult{ Err: err, Logs: logs, + }: + default: } } diff --git a/server/core/ldap.go b/server/core/ldap.go index 36d7f62f..a24d7354 100644 --- a/server/core/ldap.go +++ b/server/core/ldap.go @@ -175,7 +175,11 @@ func ldapPassDB(auth *AuthState) (passDBResult *PassDBResult, err error) { } // Find user with account status enabled - backend.LDAPRequestChan <- ldapRequest + select { + case backend.LDAPRequestChan <- ldapRequest: + default: + return passDBResult, errors.ErrClosedChannel + } ldapReply = <-ldapReplyChan @@ -234,7 +238,11 @@ func ldapPassDB(auth *AuthState) (passDBResult *PassDBResult, err error) { HTTPClientContext: auth.HTTPClientContext, } - backend.LDAPAuthRequestChan <- ldapUserBindRequest + select { + case backend.LDAPAuthRequestChan <- ldapUserBindRequest: + default: + return passDBResult, errors.ErrClosedChannel + } ldapReply = <-ldapReplyChan @@ -333,7 +341,11 @@ func ldapAccountDB(auth *AuthState) (accounts AccountList, err error) { } // Find user with account status enabled - backend.LDAPRequestChan <- ldapRequest + select { + case backend.LDAPRequestChan <- ldapRequest: + default: + return accounts, errors.ErrClosedChannel + } ldapReply = <-ldapReplyChan @@ -424,7 +436,11 @@ func ldapAddTOTPSecret(auth *AuthState, totp *TOTPSecret) (err error) { ldapRequest.ModifyAttributes = make(backend.LDAPModifyAttributes, 2) ldapRequest.ModifyAttributes[configField] = []string{totp.getValue()} - backend.LDAPRequestChan <- ldapRequest + select { + case backend.LDAPRequestChan <- ldapRequest: + default: + return errors.ErrClosedChannel + } ldapReply = <-ldapReplyChan diff --git a/server/core/lua.go b/server/core/lua.go index c882d778..6452c2ba 100644 --- a/server/core/lua.go +++ b/server/core/lua.go @@ -18,6 +18,7 @@ package core import ( "github.com/croessner/nauthilus/server/backend" "github.com/croessner/nauthilus/server/config" + "github.com/croessner/nauthilus/server/errors" "github.com/croessner/nauthilus/server/global" "github.com/croessner/nauthilus/server/lualib" "github.com/croessner/nauthilus/server/stats" @@ -87,7 +88,11 @@ func luaPassDB(auth *AuthState) (passDBResult *PassDBResult, err error) { }, } - backend.LuaRequestChan <- luaRequest + select { + case backend.LuaRequestChan <- luaRequest: + default: + return passDBResult, errors.ErrClosedChannel + } luaBackendResult = <-luaReplyChan @@ -175,7 +180,11 @@ func luaAccountDB(auth *AuthState) (accounts AccountList, err error) { }, } - backend.LuaRequestChan <- luaRequest + select { + case backend.LuaRequestChan <- luaRequest: + default: + return accounts, errors.ErrClosedChannel + } luaBackendResult = <-luaReplyChan @@ -229,7 +238,11 @@ func luaAddTOTPSecret(auth *AuthState, totp *TOTPSecret) (err error) { }, } - backend.LuaRequestChan <- luaRequest + select { + case backend.LuaRequestChan <- luaRequest: + default: + return errors.ErrClosedChannel + } luaBackendResult = <-luaReplyChan diff --git a/server/errors/errors.go b/server/errors/errors.go index 725b5a34..36b1eba8 100644 --- a/server/errors/errors.go +++ b/server/errors/errors.go @@ -144,6 +144,7 @@ var ( // common. var ( + ErrClosedChannel = errors.New("channel closed") ErrNoPassDBResult = errors.New("no pass Database result") ErrUnknownCause = errors.New("something went wrong") )