Skip to content

Commit

Permalink
Merge pull request #28 from martinohansen/main
Browse files Browse the repository at this point in the history
fix: 401 invalid token
  • Loading branch information
frieser authored Dec 23, 2024
2 parents 34afc3b + 6dddb28 commit 4bd54a9
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 67 deletions.
54 changes: 13 additions & 41 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,70 +30,42 @@ type Transport struct {
// StartTokenHandler handles token refreshes in the background
func (c *Client) StartTokenHandler(ctx context.Context) error {
// Initialize the first token
token, err := c.newToken(ctx)
err := c.newToken(ctx)
if err != nil {
return errors.New("getting initial token: " + err.Error())
}
c.m.Lock()
c.token = token
c.m.Unlock()

go c.tokenHandler(ctx)
return nil
}

// tokenHandler gets a new token using the refresh token and a new pair when the
// refresh token expires.
// refresh token expires
func (c *Client) tokenHandler(ctx context.Context) {
newTokenTimer := time.NewTimer(0) // Start immediately
refreshTokenTimer := time.NewTimer(0) // Start immediately
defer func() {
newTokenTimer.Stop()
refreshTokenTimer.Stop()
}()

resetTimer := func(timer *time.Timer, expiryTime time.Time) {
if !timer.Stop() {
<-timer.C
}
timer.Reset(time.Until(expiryTime))
}
refresh := time.NewTicker(time.Hour * 12) // 12 hours
new := time.NewTicker(time.Hour * 24 * 14) // 14 days
defer refresh.Stop()
defer new.Stop()

for {
c.m.RLock()
newTokenExpiry := c.token.accessExpires(2)
refreshTokenExpiry := c.token.refreshExpires(2)
c.m.RUnlock()

resetTimer(newTokenTimer, newTokenExpiry)
resetTimer(refreshTokenTimer, refreshTokenExpiry)

select {
case <-ctx.Done():
return
case <-newTokenTimer.C:
if token, err := c.newToken(ctx); err != nil {

case <-new.C:
if err := c.newToken(ctx); err != nil {
// TODO(Martin): Improve error handling
panic(fmt.Sprintf("getting new token: %s", err))
} else {
c.updateToken(token)
}
case <-refreshTokenTimer.C:
if token, err := c.refreshToken(ctx); err != nil {

case <-refresh.C:
if err := c.refreshToken(ctx); err != nil {
panic(fmt.Sprintf("refreshing token: %s", err))
} else {
c.updateToken(token)
}
}
}
}

// updateToken updates the client's token
func (c *Client) updateToken(t *Token) {
c.m.Lock()
defer c.m.Unlock()
c.token = t
}

func (t Transport) RoundTrip(req *http.Request) (*http.Response, error) {
req.URL.Scheme = "https"
req.URL.Host = baseUrl
Expand Down
3 changes: 1 addition & 2 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,11 @@ func initTestClient(t *testing.T) *Client {
c.c.Transport = Transport{rt: http.DefaultTransport, cli: c}

// Initialize the first token
token, err := c.newToken(context.Background())
err := c.newToken(context.Background())
if err != nil {
t.Fatalf("newToken: %s", err)
}

c.token = token
sharedClient = c
})

Expand Down
50 changes: 26 additions & 24 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"net/http"
"net/url"
"strings"
"time"
)

type Token struct {
Expand All @@ -31,23 +30,17 @@ const tokenPath = "token"
const tokenNewPath = "new/"
const tokenRefreshPath = "refresh/"

// accessExpires returns the time when access token expires divided by divisor
func (t *Token) accessExpires(divisor int) time.Time {
return time.Now().Add(time.Second * time.Duration(t.AccessExpires/divisor))
}

// refreshExpires returns the time when refresh token expires divided by divisor
func (t *Token) refreshExpires(divisor int) time.Time {
return time.Now().Add(time.Second * time.Duration(t.RefreshExpires/divisor))
}
// newToken gets a new access token
func (c *Client) newToken(ctx context.Context) error {
c.m.Lock()
defer c.m.Unlock()

func (c *Client) newToken(ctx context.Context) (*Token, error) {
data, err := json.Marshal(Secret{
SecretId: c.secretId,
AccessId: c.secretKey,
})
if err != nil {
return nil, err
return err
}

req := &http.Request{
Expand All @@ -61,29 +54,35 @@ func (c *Client) newToken(ctx context.Context) (*Token, error) {

resp, err := c.c.Do(req)
if err != nil {
return nil, err
return err
}
defer resp.Body.Close()

body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, readErr
return readErr
}
if resp.StatusCode != http.StatusOK {
return nil, &APIError{StatusCode: resp.StatusCode, Body: string(body)}
return &APIError{StatusCode: resp.StatusCode, Body: string(body)}
}

t := &Token{}
if err := json.Unmarshal(body, t); err != nil {
return nil, err
return err
}
return t, nil

c.token = t
return nil
}

func (c *Client) refreshToken(ctx context.Context) (*Token, error) {
// refreshToken gets a new access token using the refresh token
func (c *Client) refreshToken(ctx context.Context) error {
c.m.Lock()
defer c.m.Unlock()

data, err := json.Marshal(TokenRefresh{Refresh: c.token.Refresh})
if err != nil {
return nil, err
return err
}

req := &http.Request{
Expand All @@ -97,21 +96,24 @@ func (c *Client) refreshToken(ctx context.Context) (*Token, error) {

resp, err := c.c.Do(req)
if err != nil {
return nil, err
return err
}
defer resp.Body.Close()

body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, readErr
return readErr
}
if resp.StatusCode != http.StatusOK {
return nil, &APIError{StatusCode: resp.StatusCode, Body: string(body)}
return &APIError{StatusCode: resp.StatusCode, Body: string(body)}
}

t := &Token{}
if err := json.Unmarshal(body, t); err != nil {
return nil, err
return err
}
return t, nil

c.token.Access = t.Access
c.token.AccessExpires = t.AccessExpires
return nil
}

0 comments on commit 4bd54a9

Please sign in to comment.