Skip to content

Commit

Permalink
Feat: Refactor authentication handling and naming conventions
Browse files Browse the repository at this point in the history
Renamed several functions and constants to improve code clarity around authentication. Updated comments and function logic to use more descriptive names, focusing on authentication workflows and handling across different service types.

Signed-off-by: Christian Roessner <c@roessner.co>
  • Loading branch information
Christian Roessner committed Nov 21, 2024
1 parent 8ade1a6 commit 1f08b09
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 94 deletions.
52 changes: 26 additions & 26 deletions server/core/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ type AuthState struct {
// StatusMessage is the HTTP response payload that is sent to the remote server that asked for authentication.
StatusMessage string

// Service is set by Nauthilus depending on the router endpoint. Look at httpQueryHandler for the structure of available
// Service is set by Nauthilus depending on the router endpoint. Look at requestHandler for the structure of available
// endpoints.
Service string

Expand Down Expand Up @@ -580,9 +580,9 @@ func (a *AuthState) authOK(ctx *gin.Context) {
switch a.Service {
case global.ServNginx:
setNginxHeaders(ctx, a)
case global.ServDovecot:
setDovecotHeaders(ctx, a)
case global.ServUserInfo, global.ServJSON:
case global.ServHeader:
setHeaderHeaders(ctx, a)
case global.ServJSON:
sendAuthResponse(ctx, a)
}

Expand All @@ -597,13 +597,13 @@ func (a *AuthState) authOK(ctx *gin.Context) {

// setCommonHeaders sets common headers for the given gin.Context and AuthState.
// It sets the "Auth-Status" header to "OK" and the "X-Nauthilus-Session" header to the GUID of the AuthState.
// If the AuthState's Service is not global.ServBasicAuth, and the HaveAccountField flag is true,
// If the AuthState's Service is not global.ServBasic, and the HaveAccountField flag is true,
// it retrieves the account from the AuthState and sets the "Auth-User" header
func setCommonHeaders(ctx *gin.Context, a *AuthState) {
ctx.Header("Auth-Status", "OK")
ctx.Header("X-Nauthilus-Session", *a.GUID)

if a.Service != global.ServBasicAuth {
if a.Service != global.ServBasic {
if account, found := a.getAccountOk(); found {
ctx.Header("Auth-User", account)
}
Expand Down Expand Up @@ -650,7 +650,7 @@ func setNginxHeaders(ctx *gin.Context, a *AuthState) {
}
}

// setDovecotHeaders sets the specified headers in the given gin.Context based on the attributes in the AuthState object.
// setHeaderHeaders sets the specified headers in the given gin.Context based on the attributes in the AuthState object.
// It iterates through the attributes and calls the handleAttributeValue function for each attribute.
//
// Parameters:
Expand All @@ -665,12 +665,12 @@ func setNginxHeaders(ctx *gin.Context, a *AuthState) {
// "Attribute2": []any{"Value2_1", "Value2_2"},
// },
// }
// setDovecotHeaders(ctx, a)
// setHeaderHeaders(ctx, a)
//
// Resulting headers in ctx:
// - X-Nauthilus-Attribute1: "Value1"
// - X-Nauthilus-Attribute2: "Value2_1,Value2_2"
func setDovecotHeaders(ctx *gin.Context, a *AuthState) {
func setHeaderHeaders(ctx *gin.Context, a *AuthState) {
if a.Attributes != nil && len(a.Attributes) > 0 {
for name, value := range a.Attributes {
handleAttributeValue(ctx, name, value)
Expand Down Expand Up @@ -805,7 +805,7 @@ func (a *AuthState) setFailureHeaders(ctx *gin.Context) {

ctx.Header("Auth-Wait", fmt.Sprintf("%v", waitDelay))
}
case global.ServUserInfo, global.ServJSON:
case global.ServJSON:
ctx.Header("Content-Type", "application/json; charset=UTF-8")

if a.PasswordHistory != nil {
Expand Down Expand Up @@ -859,7 +859,7 @@ func (a *AuthState) setSMPTHeaders(ctx *gin.Context) {
//
// func (a *AuthState) authTempFail(ctx *gin.Context, reason string) {
// ...
// if a.Service == global.ServUserInfo {
// if a.Service == global.ServJSON {
// a.sendAuthResponse(ctx, reason)
// return
// }
Expand All @@ -875,7 +875,6 @@ func (a *AuthState) setUserInfoHeaders(ctx *gin.Context, reason string) {
}

ctx.Header("Content-Type", "application/json; charset=UTF-8")
ctx.Header("X-User-Found", fmt.Sprintf("%v", a.UserFound))

ctx.JSON(a.StatusCodeInternalError, &errType{Error: reason})
}
Expand All @@ -892,12 +891,12 @@ func (a *AuthState) setUserInfoHeaders(ctx *gin.Context, reason string) {
//
// Usage example:
//
// func (a *AuthState) generic(ctx *gin.Context) {
// func (a *AuthState) handleAuthentication(ctx *gin.Context) {
// ...
// a.authTempFail(ctx, global.TempFailDefault)
// ...
// }
// func (a *AuthState) saslAuthd(ctx *gin.Context) {
// func (a *AuthState) handleSASLAuthdAuthentication(ctx *gin.Context) {
// ...
// a.authTempFail(ctx, global.TempFailDefault)
// ...
Expand All @@ -915,8 +914,9 @@ func (a *AuthState) authTempFail(ctx *gin.Context, reason string) {

a.StatusMessage = reason

if a.Service == global.ServUserInfo {
if a.Service == global.ServJSON {
a.setUserInfoHeaders(ctx, reason)

return
}

Expand Down Expand Up @@ -1117,11 +1117,11 @@ func updateAuthentication(a *AuthState, passDBResult *PassDBResult, passDB *Pass
// setStatusCodes sets different status codes for various services.
func (a *AuthState) setStatusCodes(service string) error {
switch service {
case global.ServNginx, global.ServDovecot:
case global.ServNginx:
a.StatusCodeOK = http.StatusOK
a.StatusCodeInternalError = http.StatusOK
a.StatusCodeFail = http.StatusOK
case global.ServSaslauthd, global.ServBasicAuth, global.ServOryHydra, global.ServUserInfo, global.ServJSON, global.ServCallback:
case global.ServSaslauthd, global.ServBasic, global.ServOryHydra, global.ServHeader, global.ServJSON, global.ServCallback:
a.StatusCodeOK = http.StatusOK
a.StatusCodeInternalError = http.StatusInternalServerError
a.StatusCodeFail = http.StatusForbidden
Expand Down Expand Up @@ -2163,10 +2163,10 @@ func setupBodyBasedAuth(ctx *gin.Context, auth *AuthState) {
}
}

// setupHTTPBasiAuth sets up basic authentication for HTTP requests.
// setupHTTPBasicAuth sets up basic authentication for HTTP requests.
// It takes in a gin.Context object and a pointer to an AuthState object.
// It calls the withClientInfo, withLocalInfo, withUserAgent, and withXSSL methods of the AuthState object to set client, local, user-agent, and X-SSL information, respectively
func setupHTTPBasiAuth(ctx *gin.Context, auth *AuthState) {
func setupHTTPBasicAuth(ctx *gin.Context, auth *AuthState) {
// NOTE: We must get username and password later!
auth.withClientInfo(ctx)
auth.withLocalInfo(ctx)
Expand All @@ -2192,9 +2192,9 @@ func (a *AuthState) initMethodAndUserAgent() *AuthState {
// setupAuth sets up the authentication based on the service parameter in the gin context.
// It takes the gin context and an AuthState struct as input.
//
// If the service parameter is "nginx", "dovecot", or "user", it calls the setupHeaderBasedAuth function.
// If the service parameter is "nginx" or "header", it calls the setupHeaderBasedAuth function.
// If the service parameter is "saslauthd", it calls the setupBodyBasedAuth function.
// If the service parameter is "basicauth", it calls the setupHTTPBasiAuth function.
// If the service parameter is "basicauth", it calls the setupHTTPBasicAuth function.
//
// After setting up the authentication, it calls the withDefaults method on the AuthState struct.
//
Expand All @@ -2208,19 +2208,19 @@ func setupAuth(ctx *gin.Context, auth *AuthState) {
auth.Protocol = &config.Protocol{}

switch ctx.Param("service") {
case global.ServNginx, global.ServDovecot, global.ServUserInfo:
case global.ServNginx, global.ServHeader:
setupHeaderBasedAuth(ctx, auth)
case global.ServSaslauthd, global.ServJSON:
setupBodyBasedAuth(ctx, auth)
case global.ServBasicAuth:
setupHTTPBasiAuth(ctx, auth)
case global.ServBasic:
setupHTTPBasicAuth(ctx, auth)
case global.ServCallback:
auth.withDefaults(ctx)

return
}

if ctx.Query("mode") != "list-accounts" && ctx.Param("service") != global.ServBasicAuth {
if ctx.Query("mode") != "list-accounts" && ctx.Param("service") != global.ServBasic {
if !util.ValidateUsername(auth.Username) {
auth.Username = ""

Expand Down Expand Up @@ -2293,7 +2293,7 @@ func (a *AuthState) withDefaults(ctx *gin.Context) *AuthState {
a.Service = ctx.Param("service")
a.Context = ctx.MustGet(global.CtxDataExchangeKey).(*lualib.Context)

if a.Service == global.ServBasicAuth {
if a.Service == global.ServBasic {
a.Protocol.Set(global.ProtoHTTP)
}

Expand Down
2 changes: 1 addition & 1 deletion server/core/features.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func isLocalOrEmptyIP(ip string) bool {
return ip == global.Localhost4 || ip == global.Localhost6 || ip == ""
}

// logAddMessage logs a message with the specified parameters using the global logger. It is intended to be a generic logging function.
// logAddMessage logs a message with the specified parameters using the global logger. It is intended to be a handleAuthentication logging function.
//
// Parameters:
// - auth: Pointer to AuthState
Expand Down
57 changes: 18 additions & 39 deletions server/core/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ var (
LangBundle *i18n.Bundle
)

// RESTResult is a generic JSON result object for the Nauthilus REST API.
// RESTResult is a handleAuthentication JSON result object for the Nauthilus REST API.
type RESTResult struct {
// GUID represents a unique identifier for a session. It is a string field used in the RESTResult struct
// and is also annotated with the json tag "session".
Expand Down Expand Up @@ -165,12 +165,12 @@ func (w *customWriter) Write(data []byte) (numBytes int, err error) {
}

//nolint:gocognit // Main logic
func httpQueryHandler(ctx *gin.Context) {
func requestHandler(ctx *gin.Context) {
if ctx.FullPath() == "/ping" {
healthCheck(ctx)
} else {
switch ctx.Param("category") {
case global.CatMail, global.CatGeneric:
case global.CatAuth:
auth := NewAuthState(ctx)
if auth == nil {
ctx.AbortWithStatus(http.StatusBadRequest)
Expand All @@ -185,42 +185,21 @@ func httpQueryHandler(ctx *gin.Context) {
}

switch ctx.Param("service") {
case global.ServNginx, global.ServDovecot, global.ServUserInfo, global.ServJSON:
auth.generic(ctx)
case global.ServBasic, global.ServNginx, global.ServHeader, global.ServJSON:
auth.handleAuthentication(ctx)
case global.ServSaslauthd:
auth.saslAuthd(ctx)
auth.handleSASLAuthdAuthentication(ctx)
case global.ServCallback:
auth.callback(ctx)
auth.handleCallback(ctx)
ctx.Status(auth.StatusCodeOK)
default:
ctx.AbortWithStatus(http.StatusNotFound)
}

case global.CatHTTP:
auth := NewAuthState(ctx)
if auth == nil {
ctx.AbortWithStatus(http.StatusBadRequest)

return
}

if found, reject := auth.preproccessAuthRequest(ctx); reject {
return
} else if found {
auth.withClientInfo(ctx).withLocalInfo(ctx).withUserAgent(ctx).withXSSL(ctx)
}

switch ctx.Param("service") {
case global.ServBasicAuth:
auth.generic(ctx)
default:
ctx.AbortWithStatus(http.StatusNotFound)
}

case global.CatBruteForce:
switch ctx.Param("service") {
case global.ServList:
listBruteforce(ctx)
hanldeBruteForceList(ctx)
default:
ctx.AbortWithStatus(http.StatusNotFound)
}
Expand All @@ -239,23 +218,23 @@ func httpQueryHandler(ctx *gin.Context) {
// 2. It uses a switch statement to handle different category values.
// 3. For the "cache" category, it retrieves the "service" parameter and uses a switch statement
// to handle different service values.
// 4. For the "flush" service, it calls the flushCache function.
// 4. For the "flush" service, it calls the handleUserFlush function.
// 5. For the "bruteforce" category, it retrieves the "service" parameter and uses a switch statement
// to handle different service values.
// 6. For the "flush" service, it calls the flushBruteForceRule function.
// 6. For the "flush" service, it calls the handleBruteForceRuleFlush function.
func httpCacheHandler(ctx *gin.Context) {
//nolint:gocritic // Prepared for future commands
switch ctx.Param("category") {
case global.CatCache:
switch ctx.Param("service") {
case global.ServFlush:
flushCache(ctx)
handleUserFlush(ctx)
}

case global.CatBruteForce:
switch ctx.Param("service") {
case global.ServFlush:
flushBruteForceRule(ctx)
handleBruteForceRuleFlush(ctx)
}
}
}
Expand Down Expand Up @@ -359,7 +338,7 @@ func basicAuthMiddleware() gin.HandlerFunc {
guid := ctx.GetString(global.CtxGUIDKey)

// Note: Chicken-egg problem.
if ctx.Param("category") == global.CatHTTP && ctx.Param("service") == global.ServBasicAuth {
if ctx.Param("category") == global.CatAuth && ctx.Param("service") == global.ServBasic {
level.Warn(log.Logger).Log(
global.LogKeyGUID, guid,
global.LogKeyMsg, "Disabling HTTP basic Auth",
Expand Down Expand Up @@ -732,8 +711,8 @@ func setupNotifyEndpoint(router *gin.Engine, sessionStore sessions.Store) {
// it adds a middleware to the group that implements basic authentication.
//
// It then adds three endpoints to the group:
// - A GET endpoint with the path "/:category/:service" that is handled by the luaContextMiddleware and httpQueryHandler functions.
// - A POST endpoint with the path "/:category/:service" that is also handled by the luaContextMiddleware and httpQueryHandler functions.
// - A GET endpoint with the path "/:category/:service" that is handled by the luaContextMiddleware and requestHandler functions.
// - A POST endpoint with the path "/:category/:service" that is also handled by the luaContextMiddleware and requestHandler functions.
// - A DELETE endpoint with the path "/:category/:service" that is handled by the httpCacheHandler function.
func setupBackChannelEndpoints(router *gin.Engine) {
group := router.Group("/api/v1")
Expand All @@ -742,8 +721,8 @@ func setupBackChannelEndpoints(router *gin.Engine) {
group.Use(basicAuthMiddleware())
}

group.GET("/:category/:service", luaContextMiddleware(), httpQueryHandler)
group.POST("/:category/:service", luaContextMiddleware(), httpQueryHandler)
group.GET("/:category/:service", luaContextMiddleware(), requestHandler)
group.POST("/:category/:service", luaContextMiddleware(), requestHandler)
group.DELETE("/:category/:service", httpCacheHandler)
}

Expand Down Expand Up @@ -990,7 +969,7 @@ func setupRouter(router *gin.Engine) {
router.GET("/metrics", gin.WrapF(promhttp.Handler().ServeHTTP))

// Healthcheck
router.GET("/ping", httpQueryHandler)
router.GET("/ping", requestHandler)

// Parse static folder for template files
router.LoadHTMLGlob(viper.GetString("html_static_content_path") + "/*.html")
Expand Down
Loading

0 comments on commit 1f08b09

Please sign in to comment.