Skip to content

Commit

Permalink
Fix: Handle closed channels in backend request handlers
Browse files Browse the repository at this point in the history
Added non-blocking select statements to channel operations in LDAP and Lua backend handlers. This ensures that attempts to write to closed channels do not cause goroutines to hang, and returns appropriate errors where necessary.

Signed-off-by: Christian Roessner <c@roessner.co>
  • Loading branch information
Christian Roessner committed Sep 13, 2024
1 parent 4072186 commit 6a64665
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 14 deletions.
33 changes: 27 additions & 6 deletions server/backend/ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -1196,7 +1196,10 @@ func (l *LDAPConnection) modifyAdd(ldapRequest *LDAPRequest) (err error) {
// Then it sets the state of the connection to global.LDAPStateFree.
// Finally, it unlocks the state of the connection using the ldapPool.conn[index].Mu.Unlock() method.
func sendLDAPReplyAndUnlockState[T PoolRequest[T]](ldapPool *LDAPPool, index int, request T, ldapReply *LDAPReply) {
request.GetLDAPReplyChan() <- ldapReply
select {
case request.GetLDAPReplyChan() <- ldapReply:
default:
}

ldapPool.conn[index].Mu.Lock()

Expand Down Expand Up @@ -1294,7 +1297,10 @@ func (l *LDAPPool) proccessLookupRequest(index int, ldapRequest *LDAPRequest, ld
ldapReply := &LDAPReply{}

if ldapReply.Err = l.checkConnection(ldapRequest.GUID, index); ldapReply.Err != nil {
ldapRequest.LDAPReplyChan <- ldapReply
select {
case ldapRequest.LDAPReplyChan <- ldapReply:
default:
}

return
}
Expand Down Expand Up @@ -1356,7 +1362,10 @@ func LDAPMainWorker(ctx context.Context) {
case ldapRequest := <-LDAPRequestChan:
// Check that we have enough idle connections.
if err := ldapPool.setIdleConnections(true); err != nil {
ldapRequest.LDAPReplyChan <- &LDAPReply{Err: err}
select {
case ldapRequest.LDAPReplyChan <- &LDAPReply{Err: err}:
default:
}
}

ldapPool.handleLookupRequest(ldapRequest, &ldapWaitGroup)
Expand Down Expand Up @@ -1405,7 +1414,10 @@ func (l *LDAPPool) processAuthRequest(index int, ldapAuthRequest *LDAPAuthReques
ldapReply := &LDAPReply{}

if ldapReply.Err = l.checkConnection(ldapAuthRequest.GUID, index); ldapReply.Err != nil {
ldapAuthRequest.LDAPReplyChan <- ldapReply
select {
case ldapAuthRequest.LDAPReplyChan <- ldapReply:
default:
}

return
}
Expand Down Expand Up @@ -1460,7 +1472,10 @@ func LDAPAuthWorker(ctx context.Context) {
case ldapAuthRequest := <-LDAPAuthRequestChan:
// Check that we have enough idle connections.
if err := ldapPool.setIdleConnections(false); err != nil {
ldapAuthRequest.LDAPReplyChan <- &LDAPReply{Err: err}
select {
case ldapAuthRequest.LDAPReplyChan <- &LDAPReply{Err: err}:
default:
}
}

ldapPool.handleAuthRequest(ldapAuthRequest, &ldapWaitGroup)
Expand Down Expand Up @@ -1512,7 +1527,13 @@ func LuaLDAPSearch(ctx context.Context) lua.LGFunction {

ldapRequest := createLDAPRequest(fieldValues, scope, ctx)

LDAPRequestChan <- ldapRequest
select {
case LDAPRequestChan <- ldapRequest:
default:
L.RaiseError(errors.ErrClosedChannel.Error())

return 1
}

return processReply(L, ldapRequest.GetLDAPReplyChan())
}
Expand Down
5 changes: 4 additions & 1 deletion server/backend/lua.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,11 @@ func processError(err error, luaRequest *LuaRequest, logs *lualib.CustomLogKeyVa
global.LogKeyError, err,
)

luaRequest.LuaReplyChan <- &lualib.LuaBackendResult{
select {
case luaRequest.LuaReplyChan <- &lualib.LuaBackendResult{
Err: err,
Logs: logs,
}:
default:
}
}
24 changes: 20 additions & 4 deletions server/core/ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,11 @@ func ldapPassDB(auth *AuthState) (passDBResult *PassDBResult, err error) {
}

// Find user with account status enabled
backend.LDAPRequestChan <- ldapRequest
select {
case backend.LDAPRequestChan <- ldapRequest:
default:
return passDBResult, errors.ErrClosedChannel
}

ldapReply = <-ldapReplyChan

Expand Down Expand Up @@ -234,7 +238,11 @@ func ldapPassDB(auth *AuthState) (passDBResult *PassDBResult, err error) {
HTTPClientContext: auth.HTTPClientContext,
}

backend.LDAPAuthRequestChan <- ldapUserBindRequest
select {
case backend.LDAPAuthRequestChan <- ldapUserBindRequest:
default:
return passDBResult, errors.ErrClosedChannel
}

ldapReply = <-ldapReplyChan

Expand Down Expand Up @@ -333,7 +341,11 @@ func ldapAccountDB(auth *AuthState) (accounts AccountList, err error) {
}

// Find user with account status enabled
backend.LDAPRequestChan <- ldapRequest
select {
case backend.LDAPRequestChan <- ldapRequest:
default:
return accounts, errors.ErrClosedChannel
}

ldapReply = <-ldapReplyChan

Expand Down Expand Up @@ -424,7 +436,11 @@ func ldapAddTOTPSecret(auth *AuthState, totp *TOTPSecret) (err error) {
ldapRequest.ModifyAttributes = make(backend.LDAPModifyAttributes, 2)
ldapRequest.ModifyAttributes[configField] = []string{totp.getValue()}

backend.LDAPRequestChan <- ldapRequest
select {
case backend.LDAPRequestChan <- ldapRequest:
default:
return errors.ErrClosedChannel
}

ldapReply = <-ldapReplyChan

Expand Down
19 changes: 16 additions & 3 deletions server/core/lua.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package core
import (
"github.com/croessner/nauthilus/server/backend"
"github.com/croessner/nauthilus/server/config"
"github.com/croessner/nauthilus/server/errors"
"github.com/croessner/nauthilus/server/global"
"github.com/croessner/nauthilus/server/lualib"
"github.com/croessner/nauthilus/server/stats"
Expand Down Expand Up @@ -87,7 +88,11 @@ func luaPassDB(auth *AuthState) (passDBResult *PassDBResult, err error) {
},
}

backend.LuaRequestChan <- luaRequest
select {
case backend.LuaRequestChan <- luaRequest:
default:
return passDBResult, errors.ErrClosedChannel
}

luaBackendResult = <-luaReplyChan

Expand Down Expand Up @@ -175,7 +180,11 @@ func luaAccountDB(auth *AuthState) (accounts AccountList, err error) {
},
}

backend.LuaRequestChan <- luaRequest
select {
case backend.LuaRequestChan <- luaRequest:
default:
return accounts, errors.ErrClosedChannel
}

luaBackendResult = <-luaReplyChan

Expand Down Expand Up @@ -229,7 +238,11 @@ func luaAddTOTPSecret(auth *AuthState, totp *TOTPSecret) (err error) {
},
}

backend.LuaRequestChan <- luaRequest
select {
case backend.LuaRequestChan <- luaRequest:
default:
return errors.ErrClosedChannel
}

luaBackendResult = <-luaReplyChan

Expand Down
1 change: 1 addition & 0 deletions server/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ var (
// common.

var (
ErrClosedChannel = errors.New("channel closed")
ErrNoPassDBResult = errors.New("no pass Database result")
ErrUnknownCause = errors.New("something went wrong")
)
Expand Down

0 comments on commit 6a64665

Please sign in to comment.