Skip to content

Commit

Permalink
Fix: Refactor Redis cache handling and brute force counter logic
Browse files Browse the repository at this point in the history
Initialize PositivePasswordCache before LoadCacheFromRedis call and remove BruteForceBucketCache type. Replace generic RedisCache with PositivePasswordCache and refactor brute force counter functions to use a new load function.

Signed-off-by: Christian Roessner <c@roessner.co>
  • Loading branch information
Christian Roessner committed Nov 12, 2024
1 parent 7ffc802 commit 97798b3
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 31 deletions.
33 changes: 11 additions & 22 deletions server/backend/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ import (
"github.com/redis/go-redis/v9"
)

// BruteForceBucketCache is a Redis cache. It is a union member of RedisCache.
type BruteForceBucketCache uint

// PasswordHistory is a map of hashed passwords with their failure counter.
type PasswordHistory map[string]uint

Expand All @@ -43,19 +40,13 @@ type PasswordHistory map[string]uint
// refreshed upon continuous requests. If the Redis TTL has expired, the object is removed from the cache to force a refresh
// of the user data from underlying databases.
type PositivePasswordCache struct {
Backend global.Backend `redis:"passdb_backend"`
Password string `redis:"password"`
AccountField *string `redis:"account_field"`
TOTPSecretField *string `redis:"totp_secret_field"`
UniqueUserIDField *string `redis:"webauth_userid_field"`
DisplayNameField *string `redis:"display_name_field"`
Attributes DatabaseResult `redis:"attributes"`
}

// RedisCache is a union used for LoadCacheFromRedis and SaveUserDataToRedis Redis routines.
// These routines are generics.
type RedisCache interface {
PositivePasswordCache | BruteForceBucketCache
Backend global.Backend `json:"passdb_backend"`
Password string `json:"password"`
AccountField *string `json:"account_field"`
TOTPSecretField *string `json:"totp_secret_field"`
UniqueUserIDField *string `json:"webauth_userid_field"`
DisplayNameField *string `json:"display_name_field"`
Attributes DatabaseResult `json:"attributes"`
}

// LookupUserAccountFromRedis returns the user account value from the user Redis hash.
Expand All @@ -81,7 +72,7 @@ func LookupUserAccountFromRedis(ctx context.Context, username string) (accountNa
// If there is an error retrieving the value from Redis, it returns isRedisErr=true and err.
// Otherwise, it unmarshals the value into the cache pointer and returns isRedisErr=false and err=nil.
// It also logs any error messages using the Logger.
func LoadCacheFromRedis[T RedisCache](ctx context.Context, key string, cache **T) (isRedisErr bool, err error) {
func LoadCacheFromRedis(ctx context.Context, key string, ucp *PositivePasswordCache) (isRedisErr bool, err error) {
var redisValue []byte

defer stats.RedisReadCounter.Inc()
Expand All @@ -96,24 +87,22 @@ func LoadCacheFromRedis[T RedisCache](ctx context.Context, key string, cache **T
return true, err
}

*cache = new(T)

if err = json.Unmarshal(redisValue, *cache); err != nil {
if err = json.Unmarshal(redisValue, ucp); err != nil {
level.Error(log.Logger).Log(global.LogKeyMsg, err)

return
}

util.DebugModule(
global.DbgCache,
global.LogKeyMsg, "Load password history from redis", "type", fmt.Sprintf("%T", **cache))
global.LogKeyMsg, "Load password history from redis", "type", fmt.Sprintf("%T", *ucp))

return false, nil
}

// SaveUserDataToRedis is a generic routine to store a cache object on Redis. The type is a RedisCache, which is a
// union.
func SaveUserDataToRedis[T RedisCache](ctx context.Context, guid string, key string, ttl uint, cache *T) {
func SaveUserDataToRedis(ctx context.Context, guid string, key string, ttl uint, cache *PositivePasswordCache) {
var result string

util.DebugModule(
Expand Down
6 changes: 4 additions & 2 deletions server/core/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -1616,10 +1616,12 @@ func (a *AuthState) postVerificationProcesses(ctx *gin.Context, useCache bool, b
DisplayNameField: a.DisplayNameField,
Password: func() string {
if a.Password != "" {
return util.GetHash(util.PreparePassword(a.Password))
passwordShort := util.GetHash(util.PreparePassword(a.Password))

return passwordShort
}

return a.Password
return ""
}(),
Backend: a.SourcePassDBBackend,
Attributes: a.Attributes,
Expand Down
46 changes: 40 additions & 6 deletions server/core/bruteforce.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
package core

import (
"context"
"encoding/json"
stderrors "errors"
"fmt"
"net"
Expand All @@ -38,6 +40,9 @@ import (
"github.com/redis/go-redis/v9"
)

// BruteForceBucketCounter represents a cache mechanism to handle brute force attack mitigation using brute force buckets.
type BruteForceBucketCounter uint

// isRepeatingWrongPassword is a method associated with the AuthState struct used to check for repeated wrong password usage.
// It retrieves and loads a password history from Redis using a certain key.
// The function then checks if the current password has previously been within the loaded history and if its attempt count exceeds one.
Expand Down Expand Up @@ -489,23 +494,52 @@ func (a *AuthState) saveFailedPasswordCounterInRedis() {
}
}

// loadBruteForceBucketCounterFromRedis is a method on the AuthState struct that loads the brute force
// loadBruteForceBucketCounterFromRedis loads a bucket counter from Redis.
// Increments Redis read operations counter.
// On success, unmarshals the Redis value into bucketCounter.
// If the key doesn't exist, returns nil.
// Logs errors if Redis operations or JSON unmarshalling fails.
func loadBruteForceBucketCounterFromRedis(ctx context.Context, key string, bucketCounter *BruteForceBucketCounter) (err error) {
var redisValue []byte

defer stats.RedisReadCounter.Inc()

if redisValue, err = rediscli.ReadHandle.Get(ctx, key).Bytes(); err != nil {
if stderrors.Is(err, redis.Nil) {
return nil
}

level.Error(log.Logger).Log(global.LogKeyMsg, err)

return err
}

if err = json.Unmarshal(redisValue, bucketCounter); err != nil {
level.Error(log.Logger).Log(global.LogKeyMsg, err)

return
}

return nil
}

// loadBruteForceBucketCounter is a method on the AuthState struct that loads the brute force
// bucket counter from Redis and updates the BruteForceCounter map. The given BruteForceRule is used to generate the Redis key.
// If the key is not empty, it retrieves the counter-value from Redis using the backend.LoadCacheFromRedis function.
// If an error occurs while loading the cache, the function returns.
// If the BruteForceCounter is not initialized, it creates a new map.
// Finally, it updates the BruteForceCounter map with the counter-value retrieved from Redis using the rule name as the key.
func (a *AuthState) loadBruteForceBucketCounterFromRedis(rule *config.BruteForceRule) {
func (a *AuthState) loadBruteForceBucketCounter(rule *config.BruteForceRule) {
if !config.LoadableConfig.HasFeature(global.FeatureBruteForce) {
return
}

cache := new(backend.BruteForceBucketCache)
bucketCounter := new(BruteForceBucketCounter)

if key := a.getBruteForceBucketRedisKey(rule); key != "" {
util.DebugModule(global.DbgBf, global.LogKeyGUID, a.GUID, "load_key", key)

if _, err := backend.LoadCacheFromRedis(a.HTTPClientContext, key, &cache); err != nil {
if err := loadBruteForceBucketCounterFromRedis(a.HTTPClientContext, key, bucketCounter); err != nil {
return
}
}
Expand All @@ -514,7 +548,7 @@ func (a *AuthState) loadBruteForceBucketCounterFromRedis(rule *config.BruteForce
a.BruteForceCounter = make(map[string]uint)
}

a.BruteForceCounter[rule.Name] = uint(*cache)
a.BruteForceCounter[rule.Name] = uint(*bucketCounter)
}

// saveBruteForceBucketCounterToRedis is a method on the AuthState struct that saves brute force
Expand Down Expand Up @@ -763,7 +797,7 @@ func (a *AuthState) checkBucketOverLimit(rules []config.BruteForceRule, network
continue
}

a.loadBruteForceBucketCounterFromRedis(&rules[ruleNumber])
a.loadBruteForceBucketCounter(&rules[ruleNumber])

// The counter goes from 0...N-1, but the 'failed_requests' setting from 1...N
if a.BruteForceCounter[rules[ruleNumber].Name]+1 > rules[ruleNumber].FailedRequests {
Expand Down
4 changes: 3 additions & 1 deletion server/core/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ func cachePassDB(auth *AuthState) (passDBResult *PassDBResult, err error) {
if accountName != "" {
redisPosUserKey := config.LoadableConfig.Server.Redis.Prefix + "ucp:" + cacheName + ":" + accountName

if _, err = backend.LoadCacheFromRedis(auth.HTTPClientContext, redisPosUserKey, &ppc); err != nil {
ppc = &backend.PositivePasswordCache{}

if _, err = backend.LoadCacheFromRedis(auth.HTTPClientContext, redisPosUserKey, ppc); err != nil {
return
}
}
Expand Down

0 comments on commit 97798b3

Please sign in to comment.