diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go
index 3138fd034..e0f564552 100644
--- a/pkg/auth/auth.go
+++ b/pkg/auth/auth.go
@@ -4,10 +4,13 @@ import (
 	"bufio"
 	"errors"
 	"fmt"
+	"net"
+	"net/url"
 	"os"
 	"os/exec"
 	"runtime"
 	"strings"
+	"time"
 
 	"github.com/brevdev/brev-cli/pkg/config"
 	"github.com/brevdev/brev-cli/pkg/entity"
@@ -17,6 +20,12 @@ import (
 	"github.com/pkg/browser"
 )
 
+// refreshBeforeExpiry is how far in advance of access-token expiration the
+// CLI refreshes. Using a window larger than typical request RTTs avoids 401
+// round-trips at the tail of a token's life, at the cost of refreshing a
+// small number of still-valid tokens.
+const refreshBeforeExpiry = 5 * time.Minute
+
 type LoginAuth struct {
 	Auth
 }
@@ -146,6 +155,13 @@ func (t Auth) GetFreshAccessTokenOrNil() (string, error) {
 		return "", nil
 	}
 
+	// Older CLI versions stored the literal string "auto-login" when
+	// `brev login --token` had no real refresh token to save. Treat it as
+	// absent so we do not attempt to exchange it with the IdP and fail.
+	if tokens.RefreshToken == autoLoginSentinel {
+		tokens.RefreshToken = ""
+	}
+
 	// should always at least have access token?
 	if tokens.AccessToken == "" {
 		breverrors.GetDefaultErrorReporter().ReportMessage("access token is an empty string but shouldn't be")
@@ -154,20 +170,123 @@ func (t Auth) GetFreshAccessTokenOrNil() (string, error) {
 	if err != nil {
 		return "", breverrors.WrapAndTrace(err)
 	}
-	if !isAccessTokenValid && tokens.RefreshToken != "" {
-		tokens, err = t.getNewTokensWithRefreshOrNil(tokens.RefreshToken)
-		if err != nil {
-			return "", breverrors.WrapAndTrace(err)
+
+	// Trigger a refresh when the token is invalid OR when it is still valid
+	// but close enough to expiry that the next API call is likely to race
+	// the exp boundary. The proactive branch is tolerant of refresh failure:
+	// if the IdP is briefly unreachable, fall back to the (still-valid)
+	// current access token rather than logging the user out.
+	expiringSoon := isAccessTokenValid && tokens.RefreshToken != "" && accessTokenExpiresSoon(tokens)
+	if !isAccessTokenValid || expiringSoon {
+		if tokens.RefreshToken == "" {
+			// Access token is expired and we have no refresh token. Returning
+			// the expired token here would just cause a 401 on the next API
+			// call; return empty so callers can prompt for re-login instead.
+			return "", nil
 		}
-		if tokens == nil {
+		newTokens, refreshErr := t.getNewTokensWithRefreshOrNil(tokens.RefreshToken)
+		if expiringSoon && (refreshErr != nil || newTokens == nil) {
+			// Proactive refresh failed or produced no tokens, but the current
+			// token still validates: keep it and retry on the next call.
+			return tokens.AccessToken, nil
+		}
+		if refreshErr != nil {
+			return "", breverrors.WrapAndTrace(refreshErr)
+		}
+		if newTokens == nil {
 			return "", nil
 		}
-	} else if tokens.RefreshToken == "" && tokens.AccessToken == "" {
-		return "", nil
+		tokens = newTokens
 	}
 	return tokens.AccessToken, nil
 }
 
+// accessTokenExpiresSoon reports whether the stored access token's
+// expiration is within refreshBeforeExpiry of now. It prefers the persisted
+// AccessTokenExp field (written by populateTokenTimestamps on save) and
+// falls back to decoding the access JWT for files written by older CLI
+// versions that never persisted the claim.
+func accessTokenExpiresSoon(tokens *entity.AuthTokens) bool {
+	var exp time.Time
+	if tokens.AccessTokenExp != nil {
+		exp = *tokens.AccessTokenExp
+	} else {
+		exp, _ = accessTokenClaims(tokens.AccessToken)
+	}
+	if exp.IsZero() {
+		return false
+	}
+	return time.Until(exp) < refreshBeforeExpiry
+}
+
+// accessTokenClaims parses the access JWT without signature verification
+// and returns its exp and iat claims. Missing or malformed claims are
+// returned as the zero time.Time; the caller is responsible for guarding
+// with IsZero().
+func accessTokenClaims(token string) (exp, iat time.Time) {
+	if token == "" {
+		return time.Time{}, time.Time{}
+	}
+	parser := jwt.Parser{}
+	ptoken, _, err := parser.ParseUnverified(token, jwt.MapClaims{})
+	if err != nil {
+		return time.Time{}, time.Time{}
+	}
+	claims, ok := ptoken.Claims.(jwt.MapClaims)
+	if !ok {
+		return time.Time{}, time.Time{}
+	}
+	if v, ok := claims["exp"].(float64); ok {
+		exp = time.Unix(int64(v), 0)
+	}
+	if v, ok := claims["iat"].(float64); ok {
+		iat = time.Unix(int64(v), 0)
+	}
+	return exp, iat
+}
+
+// populateTokenTimestamps fills in AccessTokenExp and IssuedAt from the
+// access JWT when they are not already set. Safe to call on any AuthTokens
+// value; missing or non-JWT access tokens leave the fields nil.
+func populateTokenTimestamps(tokens *entity.AuthTokens) {
+	if tokens == nil || tokens.AccessToken == "" {
+		return
+	}
+	exp, iat := accessTokenClaims(tokens.AccessToken)
+	if tokens.AccessTokenExp == nil && !exp.IsZero() {
+		tokens.AccessTokenExp = &exp
+	}
+	if tokens.IssuedAt == nil && !iat.IsZero() {
+		tokens.IssuedAt = &iat
+	}
+}
+
+// isTransientRefreshError reports whether an error from the OAuth refresh
+// call is a transient network condition (timeout, connection refused,
+// DNS failure, etc.) as opposed to an authoritative rejection of the
+// refresh token by the IdP. Transient errors should not force the user to
+// re-login.
+func isTransientRefreshError(err error) bool {
+	if err == nil {
+		return false
+	}
+	var urlErr *url.Error
+	if errors.As(err, &urlErr) {
+		if urlErr.Timeout() {
+			return true
+		}
+	}
+	var netErr net.Error
+	if errors.As(err, &netErr) && netErr.Timeout() {
+		return true
+	}
+	// DNS / connection-refused / TLS handshake errors surface as url.Error
+	// wrapping an *net.OpError. Treat connection-level failures as
+	// transient: the refresh token is probably fine, the network isn't.
+	var opErr *net.OpError
+	return errors.As(err, &opErr)
+}
+
 // Prompts for login and returns tokens, and saves to store
 func (t Auth) PromptForLogin() (*LoginTokens, error) {
 	shouldLogin, err := t.shouldLogin()
@@ -197,25 +316,42 @@ func shouldLogin() (bool, error) {
 	return trimmed == "y" || trimmed == "", nil
 }
 
+// autoLoginSentinel is a legacy value older CLI versions stored in place of
+// a real token when `brev login --token` was used. It is not a valid token of
+// any kind; treat it as "absent" on read.
+const autoLoginSentinel = "auto-login"
+
 func (t Auth) LoginWithToken(token string) error {
 	valid, err := isAccessTokenValid(token)
 	if err != nil {
 		return breverrors.WrapAndTrace(err)
 	}
 	if valid {
-		err := t.authStore.SaveAuthTokens(entity.AuthTokens{
+		// The token is a self-contained JWT access token with no accompanying
+		// refresh token. Previously we stored the string "auto-login" in the
+		// RefreshToken slot; when the access token expired the refresh path
+		// then attempted to exchange that sentinel with the IdP, which always
+		// failed, logging the user out every time the short-lived access
+		// token aged out. Store an empty RefreshToken instead so the refresh
+		// path correctly recognizes there is nothing to refresh and prompts
+		// for a fresh login exactly once.
+		fmt.Fprintln(os.Stderr, "Note: tokens from --token cannot be refreshed; re-run `brev login` when the session expires.")
+		tokens := entity.AuthTokens{
 			AccessToken:  token,
-			RefreshToken: "auto-login",
-		})
-		if err != nil {
+			RefreshToken: "",
+		}
+		populateTokenTimestamps(&tokens)
+		if err := t.authStore.SaveAuthTokens(tokens); err != nil {
 			return breverrors.WrapAndTrace(err)
 		}
 	} else {
-		err := t.authStore.SaveAuthTokens(entity.AuthTokens{
-			AccessToken:  "auto-login",
+		// The token is not a JWT, assume it is a refresh token. The access
+		// token slot is filled with the sentinel so the first API call
+		// triggers a refresh to populate a real access token.
+		if err := t.authStore.SaveAuthTokens(entity.AuthTokens{
+			AccessToken:  autoLoginSentinel,
 			RefreshToken: token,
-		})
-		if err != nil {
+		}); err != nil {
 			return breverrors.WrapAndTrace(err)
 		}
 	}
@@ -322,13 +458,20 @@ func (t Auth) getSavedTokensOrNil() (*entity.AuthTokens, error) {
 // gets new access and refresh token or returns nil if refresh token expired, and updates store
 func (t Auth) getNewTokensWithRefreshOrNil(refreshToken string) (*entity.AuthTokens, error) {
 	tokens, err := t.oauth.GetNewAuthTokensWithRefresh(refreshToken)
-	// TODO 2 handle if 403 invalid grant
-	// https://stackoverflow.com/questions/57383523/how-to-detect-when-an-oauth2-refresh-token-expired
 	if err != nil {
 		if strings.Contains(err.Error(), "not implemented") {
 			return nil, nil
 		}
-		return nil, breverrors.WrapAndTrace(err)
+		if isTransientRefreshError(err) {
+			// Network hiccup; do not clear the user's session. Surface the
+			// error so the caller can decide whether to swallow it (when
+			// the current access token is still valid) or propagate it.
+			return nil, breverrors.WrapAndTrace(fmt.Errorf("could not reach auth provider to refresh session: %w", err))
+		}
+		// Definitive rejection from the IdP. Tell the user in plain
+		// language rather than burying it in a stack trace.
+		fmt.Fprintln(os.Stderr, "Your brev session could not be refreshed; re-run `brev login`.")
+		return nil, nil
 	}
 	if tokens == nil {
 		return nil, nil
@@ -336,6 +479,7 @@ func (t Auth) getNewTokensWithRefreshOrNil(refreshToken string) (*entity.AuthTok
 	if tokens.RefreshToken == "" {
 		tokens.RefreshToken = refreshToken
 	}
+	populateTokenTimestamps(tokens)
 
 	err = t.authStore.SaveAuthTokens(*tokens)
 	if err != nil {
diff --git a/pkg/entity/entity.go b/pkg/entity/entity.go
index 868f1fee9..d359fb7a0 100644
--- a/pkg/entity/entity.go
+++ b/pkg/entity/entity.go
@@ -27,6 +27,14 @@ var LegacyWorkspaceGroups = map[string]bool{
 type AuthTokens struct {
 	AccessToken  string `json:"access_token"`
 	RefreshToken string `json:"refresh_token"`
+	// AccessTokenExp and IssuedAt are populated from the access JWT's `exp`
+	// and `iat` claims when available. They let the CLI refresh proactively
+	// before the access token expires, and let UX surfaces like `brev
+	// status` display session lifetime without re-parsing the JWT. Both are
+	// optional: files written by older CLI versions lack these fields, and
+	// tokens whose JWTs do not carry the claims will leave them nil.
+	AccessTokenExp *time.Time `json:"access_token_exp,omitempty"`
+	IssuedAt       *time.Time `json:"issued_at,omitempty"`
 }
 
 type IDEConfig struct {
diff --git a/pkg/store/http.go b/pkg/store/http.go
index 60884f810..1bb264971 100644
--- a/pkg/store/http.go
+++ b/pkg/store/http.go
@@ -104,7 +104,7 @@ func (s *AuthHTTPStore) SetForbiddenStatusRetryHandler(handler func() error) err
 	}
 	attemptsThresh := 1
 	s.authHTTPClient.restyClient.OnAfterResponse(func(c *resty.Client, r *resty.Response) error {
-		if r.StatusCode() == http.StatusForbidden && r.Request.Attempt < attemptsThresh+1 {
+		if isAuthFailure(r.StatusCode()) && r.Request.Attempt < attemptsThresh+1 {
 			err := handler()
 			if err != nil {
 				return breverrors.WrapAndTrace(err)
@@ -117,7 +117,7 @@ func (s *AuthHTTPStore) SetForbiddenStatusRetryHandler(handler func() error) err
 			if e != nil {
 				return false
 			}
-			return r.StatusCode() == http.StatusForbidden
+			return isAuthFailure(r.StatusCode())
 		})
 
 	s.authHTTPClient.restyClient.SetRetryCount(attemptsThresh)
@@ -125,6 +125,14 @@ func (s *AuthHTTPStore) SetForbiddenStatusRetryHandler(handler func() error) err
 	return nil
 }
 
+// isAuthFailure reports whether an HTTP status code indicates the caller's
+// credentials are missing, invalid, or expired. Both 401 Unauthorized and
+// 403 Forbidden can signal an expired access token from Brev's APIs, so we
+// treat both as triggers for the refresh-and-retry path.
+func isAuthFailure(code int) bool {
+	return code == http.StatusUnauthorized || code == http.StatusForbidden
+}
+
 type AuthHTTPClient struct {
 	restyClient *resty.Client
 	auth        Auth