diff --git a/internal/repository/share.go b/internal/repository/share.go index c211f7e..652a225 100644 --- a/internal/repository/share.go +++ b/internal/repository/share.go @@ -12,6 +12,7 @@ type ShareRepository interface { SearchShare(ctx context.Context, email string, uniqueName string) ([]*model.Share, error) DeleteShare(ctx context.Context, id int64) error GetShareByUniqueName(ctx context.Context, uniqueName string) (*model.Share, error) + GetSharesByAccountId(ctx context.Context, accountId int) ([]*model.Share, error) } func NewShareRepository( @@ -26,6 +27,14 @@ type shareRepository struct { *Repository } +func (r *shareRepository) GetSharesByAccountId(ctx context.Context, accountId int) ([]*model.Share, error) { + var shares []*model.Share + if err := r.DB(ctx).Where("account_id = ?", accountId).Find(&shares).Error; err != nil { + return nil, err + } + return shares, nil +} + func (r *shareRepository) GetShareByUniqueName(ctx context.Context, uniqueName string) (*model.Share, error) { var share model.Share if err := r.DB(ctx).Where("unique_name = ?", uniqueName).First(&share).Error; err != nil { diff --git a/internal/service/account.go b/internal/service/account.go index f3c2456..7027137 100644 --- a/internal/service/account.go +++ b/internal/service/account.go @@ -76,7 +76,7 @@ func (s *accountService) RefreshAccount(ctx context.Context, id int64) error { return err } // 刷新此Account的所有ShareToken - shares, err := s.shareService.SearchShare(ctx, account.Email, "") + shares, err := s.shareService.GetSharesByAccountId(ctx, int(account.ID)) if err != nil { return err } @@ -90,7 +90,12 @@ func (s *accountService) RefreshAccount(ctx context.Context, id int64) error { } func (s *accountService) Update(ctx context.Context, account *model.Account) error { - err := s.accountRepository.Update(ctx, account) + // 刷新所有share + err := s.RefreshAccount(ctx, int64(account.ID)) + if err != nil { + return err + } + err = s.accountRepository.Update(ctx, account) if err != nil { return err } diff --git a/internal/service/share.go b/internal/service/share.go index 8759bc3..23a1f84 100644 --- a/internal/service/share.go +++ b/internal/service/share.go @@ -6,7 +6,6 @@ import ( "PandoraHelper/internal/repository" "context" "fmt" - "github.com/gin-gonic/gin" "github.com/go-resty/resty/v2" "github.com/spf13/viper" "go.uber.org/zap" @@ -24,8 +23,9 @@ type ShareService interface { SearchShare(ctx context.Context, email string, uniqueName string) ([]*model.Share, error) DeleteShare(ctx context.Context, id int64) error LoginShareByPassword(ctx context.Context, username string, password string) (string, error) - ShareStatistic(ctx *gin.Context, accountId int) (interface{}, interface{}) - ShareResetPassword(ctx *gin.Context, uniqueName string, password string, newPassword string, confirmNewPassword string) error + ShareStatistic(ctx context.Context, accountId int) (interface{}, interface{}) + ShareResetPassword(ctx context.Context, uniqueName string, password string, newPassword string, confirmNewPassword string) error + GetSharesByAccountId(ctx context.Context, accountId int) ([]*model.Share, error) } func NewShareService(service *Service, shareRepository repository.ShareRepository, viper *viper.Viper, coordinator *Coordinator) ShareService { @@ -44,7 +44,11 @@ type shareService struct { accountService AccountService } -func (s *shareService) ShareResetPassword(ctx *gin.Context, uniqueName string, password string, newPassword string, confirmNewPassword string) error { +func (s *shareService) GetSharesByAccountId(ctx context.Context, accountId int) ([]*model.Share, error) { + return s.shareRepository.GetSharesByAccountId(ctx, accountId) +} + +func (s *shareService) ShareResetPassword(ctx context.Context, uniqueName string, password string, newPassword string, confirmNewPassword string) error { share, err := s.shareRepository.GetShareByUniqueName(ctx, uniqueName) if err != nil { return err @@ -64,7 +68,7 @@ func (s *shareService) ShareResetPassword(ctx *gin.Context, uniqueName string, p } // ShareStatistic 转换为Go语言 -func (s *shareService) ShareStatistic(ctx *gin.Context, accountId int) (interface{}, interface{}) { +func (s *shareService) ShareStatistic(ctx context.Context, accountId int) (interface{}, interface{}) { account, err := s.accountService.GetAccount(ctx, int64(accountId)) if err != nil { return nil, err