diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e7d756e7..033aecf86 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ ### Fixed +- Auth: bind browser, manual, remote, and account-manager OAuth exchanges with S256 PKCE; unfinished pre-PKCE manual flows must restart at step 1. (#693, #725) — thanks @TurboTheTurtle. - Docs: reset inherited text styles before applying Markdown find-replace formatting so leading bold spans and later inline styles stay paired correctly. (#735) — thanks @sebsnyk. - Docs: accept leading-dash Markdown list values in `docs cell-update --content` and reject nonempty Markdown that produces no editable cell text. (#733) — thanks @sebsnyk. - Docs: keep inline Markdown find-replace fragments inside their existing paragraph unless the replacement explicitly ends with a newline. (#736) — thanks @sebsnyk. diff --git a/docs/auth-clients.md b/docs/auth-clients.md index 369f6154d..f0ac87739 100644 --- a/docs/auth-clients.md +++ b/docs/auth-clients.md @@ -64,6 +64,11 @@ Shows stored credential files plus any configured domain mappings. - Legacy `token:` entries are copied to `token:default:` the first time they are read. - Legacy `default_account` is still respected for the default client. +- Browser, manual, remote, and account-manager authorization use S256 PKCE. + Manual state includes a short-lived verifier under the active `gog` config + directory. Keep the same `GOG_HOME` and `--client` between remote steps. +- Manual or remote authorization started before v0.24.0 cannot be completed + after upgrading. Run step 1 again to generate a PKCE-bound URL. ## Workspace service accounts diff --git a/docs/quickstart.md b/docs/quickstart.md index 872e65da5..5a2a18de9 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -64,6 +64,10 @@ refresh token in your OS keyring (Keychain on macOS, Secret Service on Linux, Credential Manager on Windows). Headless? Add `--manual` for a paste-the-URL flow, or `--remote --step 1`/`--step 2` for fully split server runs. +Installed-app authorization uses S256 PKCE. Complete a manual or remote flow +with the same `gog` home and client that generated its URL. After upgrading +from a pre-PKCE release, restart any unfinished flow at step 1. + Verify: ```bash diff --git a/docs/spec.md b/docs/spec.md index 53acd101e..0cd42cf6c 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -131,6 +131,10 @@ Implementation: `internal/secrets/store.go`. - Supports a remote/server-friendly 2-step manual flow: - Step 1 prints an auth URL (`gog auth add ... --remote --step 1`) - Step 2 exchanges the pasted redirect URL and requires `state` validation (`--remote --step 2 --auth-url ...`) + - Browser, manual, remote, and account-manager flows bind authorization + requests and token exchanges with S256 PKCE. + - Remote steps must share the same config home and OAuth client. Unfinished + pre-v0.24.0 flows must restart at step 1. - Refresh token issuance: - requests `access_type=offline` - supports `--force-consent` to force the consent prompt when Google doesn't return a refresh token @@ -151,7 +155,7 @@ Scope selection note: - `credentials-.json` (OAuth client id/secret; named clients) - State: - `state/gmail-watch/.json` (Gmail watch state) - - `oauth-manual-state-.json` (temporary manual OAuth state cache; expires quickly; no tokens) + - `oauth-manual-state-.json` (temporary manual OAuth state and PKCE verifier cache; expires quickly; no tokens) - Secrets: - refresh tokens in keyring diff --git a/internal/googleauth/accounts_server.go b/internal/googleauth/accounts_server.go index 309211933..8dd954a6a 100644 --- a/internal/googleauth/accounts_server.go +++ b/internal/googleauth/accounts_server.go @@ -58,7 +58,8 @@ type ManageServer struct { fetchIdentity func(ctx context.Context, tok *oauth2.Token) (Identity, error) oauthMu sync.Mutex oauthState string - oauthStates map[string]struct{} + oauthVerifier string + oauthStates map[string]string resultCh chan error } @@ -286,7 +287,8 @@ func (ms *ManageServer) handleAuthStart(w http.ResponseWriter, r *http.Request) return } - ms.addOAuthState(state) + codeVerifier := generateVerifierFn() + ms.addOAuthState(state, codeVerifier) services := manageServices(ms.opts.Services) @@ -306,7 +308,7 @@ func (ms *ManageServer) handleAuthStart(w http.ResponseWriter, r *http.Request) Scopes: scopes, } - authURL := cfg.AuthCodeURL(state, authURLParams(ms.opts.ForceConsent, true)...) + authURL := cfg.AuthCodeURL(state, pkceAuthURLParams(ms.opts.ForceConsent, true, codeVerifier)...) http.Redirect(w, r, authURL, http.StatusFound) } @@ -340,7 +342,8 @@ func (ms *ManageServer) handleAuthUpgrade(w http.ResponseWriter, r *http.Request return } - ms.addOAuthState(state) + codeVerifier := generateVerifierFn() + ms.addOAuthState(state, codeVerifier) // Use requested manage services (exclude Keep) services := manageServices(ms.opts.Services) @@ -364,7 +367,7 @@ func (ms *ManageServer) handleAuthUpgrade(w http.ResponseWriter, r *http.Request // Always force consent for upgrades to ensure user sees all scopes // Add login_hint to pre-select the account authURL := cfg.AuthCodeURL(state, - append(authURLParams(true, true), + append(pkceAuthURLParams(true, true, codeVerifier), oauth2.SetAuthURLParam("login_hint", email))...) http.Redirect(w, r, authURL, http.StatusFound) @@ -382,13 +385,21 @@ func (ms *ManageServer) handleOAuthCallback(w http.ResponseWriter, r *http.Reque return } - if !ms.consumeOAuthState(q.Get("state")) { + codeVerifier, ok := ms.consumeOAuthState(q.Get("state")) + if !ok { w.WriteHeader(http.StatusBadRequest) renderErrorPage(w, "State mismatch - possible CSRF attack. Please try again.") return } + if codeVerifier == "" { + w.WriteHeader(http.StatusBadRequest) + renderErrorPage(w, "Missing PKCE verifier. Please try again.") + + return + } + code := q.Get("code") if code == "" { w.WriteHeader(http.StatusBadRequest) @@ -428,7 +439,7 @@ func (ms *ManageServer) handleOAuthCallback(w http.ResponseWriter, r *http.Reque ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) defer cancel() - tok, err := cfg.Exchange(ctx, code) + tok, err := cfg.Exchange(ctx, code, oauth2.VerifierOption(codeVerifier)) if err != nil { w.WriteHeader(http.StatusInternalServerError) renderErrorPage(w, "Failed to exchange code for token: "+err.Error()) @@ -597,38 +608,48 @@ type defaultAccountDeleter interface { DeleteDefaultAccount(client string) error } -func (ms *ManageServer) addOAuthState(state string) { +func (ms *ManageServer) addOAuthState(state string, codeVerifier string) { ms.oauthMu.Lock() defer ms.oauthMu.Unlock() ms.oauthState = state + ms.oauthVerifier = codeVerifier + if ms.oauthStates == nil { - ms.oauthStates = make(map[string]struct{}) + ms.oauthStates = make(map[string]string) } - ms.oauthStates[state] = struct{}{} + ms.oauthStates[state] = codeVerifier } -func (ms *ManageServer) consumeOAuthState(state string) bool { +func (ms *ManageServer) consumeOAuthState(state string) (string, bool) { ms.oauthMu.Lock() defer ms.oauthMu.Unlock() if ms.oauthStates != nil { - if _, ok := ms.oauthStates[state]; ok { + if codeVerifier, ok := ms.oauthStates[state]; ok { delete(ms.oauthStates, state) - return true + + if state == ms.oauthState { + ms.oauthState = "" + ms.oauthVerifier = "" + } + + return codeVerifier, true } - return false + return "", false } if state == "" || state != ms.oauthState { - return false + return "", false } ms.oauthState = "" + codeVerifier := ms.oauthVerifier + ms.oauthVerifier = "" - return true + return codeVerifier, true } func (ms *ManageServer) accountExists(email string) bool { diff --git a/internal/googleauth/accounts_server_more_test.go b/internal/googleauth/accounts_server_more_test.go index 49adef8fe..c6686f9e9 100644 --- a/internal/googleauth/accounts_server_more_test.go +++ b/internal/googleauth/accounts_server_more_test.go @@ -113,9 +113,10 @@ func TestManageServerHandleOAuthCallback_ReadCredsError(t *testing.T) { t.Cleanup(func() { _ = ln.Close() }) ms := &ManageServer{ - oauthState: "state1", - listener: ln, - store: &fakeStore{}, + oauthState: "state1", + oauthVerifier: testCodeVerifier, + listener: ln, + store: &fakeStore{}, } rr := httptest.NewRecorder() @@ -144,10 +145,11 @@ func TestManageServerHandleOAuthCallback_ScopesError(t *testing.T) { t.Cleanup(func() { _ = ln.Close() }) ms := &ManageServer{ - oauthState: "state1", - listener: ln, - store: &fakeStore{}, - opts: ManageServerOptions{Services: []Service{Service("nope")}}, + oauthState: "state1", + oauthVerifier: testCodeVerifier, + listener: ln, + store: &fakeStore{}, + opts: ManageServerOptions{Services: []Service{Service("nope")}}, } rr := httptest.NewRecorder() @@ -188,10 +190,11 @@ func TestManageServerHandleOAuthCallback_ExchangeError(t *testing.T) { t.Cleanup(func() { _ = ln.Close() }) ms := &ManageServer{ - oauthState: "state1", - listener: ln, - store: &fakeStore{}, - opts: ManageServerOptions{Services: []Service{ServiceGmail}}, + oauthState: "state1", + oauthVerifier: testCodeVerifier, + listener: ln, + store: &fakeStore{}, + opts: ManageServerOptions{Services: []Service{ServiceGmail}}, } rr := httptest.NewRecorder() @@ -237,10 +240,11 @@ func TestManageServerHandleOAuthCallback_MissingRefreshToken(t *testing.T) { t.Cleanup(func() { _ = ln.Close() }) ms := &ManageServer{ - oauthState: "state1", - listener: ln, - store: &fakeStore{}, - opts: ManageServerOptions{Services: []Service{ServiceGmail}}, + oauthState: "state1", + oauthVerifier: testCodeVerifier, + listener: ln, + store: &fakeStore{}, + opts: ManageServerOptions{Services: []Service{ServiceGmail}}, } rr := httptest.NewRecorder() @@ -287,9 +291,10 @@ func TestManageServerHandleOAuthCallback_FetchEmailError(t *testing.T) { t.Cleanup(func() { _ = ln.Close() }) ms := &ManageServer{ - oauthState: "state1", - listener: ln, - store: &fakeStore{}, + oauthState: "state1", + oauthVerifier: testCodeVerifier, + listener: ln, + store: &fakeStore{}, fetchIdentity: func(context.Context, *oauth2.Token) (Identity, error) { return Identity{}, errTestStoreBoom }, diff --git a/internal/googleauth/accounts_server_test.go b/internal/googleauth/accounts_server_test.go index 1d86f8ddf..f70a39c2a 100644 --- a/internal/googleauth/accounts_server_test.go +++ b/internal/googleauth/accounts_server_test.go @@ -246,26 +246,31 @@ func TestManageServer_HandleListAccounts_StaleDefaultFallsBackToFirst(t *testing func TestManageServer_OAuthStatesAreIndependent(t *testing.T) { ms := &ManageServer{} - ms.addOAuthState("state1") - ms.addOAuthState("state2") + ms.addOAuthState("state1", "verifier1") + ms.addOAuthState("state2", "verifier2") - if !ms.consumeOAuthState("state1") { + if verifier, ok := ms.consumeOAuthState("state1"); !ok || verifier != "verifier1" { t.Fatalf("expected first state accepted") } - if ms.consumeOAuthState("state1") { + if _, ok := ms.consumeOAuthState("state1"); ok { t.Fatalf("expected consumed state rejected") } - if !ms.consumeOAuthState("state2") { + if verifier, ok := ms.consumeOAuthState("state2"); !ok || verifier != "verifier2" { t.Fatalf("expected second state accepted") } + + if ms.oauthState != "" || ms.oauthVerifier != "" { + t.Fatalf("expected consumed current verifier to be cleared, got state=%q verifier=%q", ms.oauthState, ms.oauthVerifier) + } } func TestManageServer_HandleOAuthCallback_ErrorAndValidation(t *testing.T) { ms := &ManageServer{ - csrfToken: "csrf", - oauthState: "state1", + csrfToken: "csrf", + oauthState: "state1", + oauthVerifier: testCodeVerifier, } // Need a listener for redirectURI generation even though we don't reach exchange. ln, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0") @@ -534,17 +539,20 @@ func TestManageServer_HandleAuthStart(t *testing.T) { origRead := readClientCredentials origState := randomStateFn origEndpoint := oauthEndpoint + origVerifier := generateVerifierFn t.Cleanup(func() { readClientCredentials = origRead randomStateFn = origState oauthEndpoint = origEndpoint + generateVerifierFn = origVerifier }) readClientCredentials = func(string) (config.ClientCredentials, error) { return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil } randomStateFn = func() (string, error) { return "state123", nil } + generateVerifierFn = func() string { return testCodeVerifier } oauthEndpoint = oauth2.Endpoint{AuthURL: "http://example.com/auth", TokenURL: "http://example.com/token"} ln, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0") @@ -581,6 +589,22 @@ func TestManageServer_HandleAuthStart(t *testing.T) { t.Fatalf("expected oauthState set") } + if ms.oauthVerifier != testCodeVerifier { + t.Fatalf("expected oauthVerifier set") + } + + if got := parsed.Query().Get("code_challenge_method"); got != "S256" { + t.Fatalf("expected S256 challenge method, got %q", got) + } + + if got, want := parsed.Query().Get("code_challenge"), pkceChallengeForTest(); got != want { + t.Fatalf("unexpected code_challenge: got %q want %q", got, want) + } + + if got := parsed.Query().Get("code_verifier"); got != "" { + t.Fatalf("code_verifier must not be exposed in auth URL, got %q", got) + } + if redirectURI := parsed.Query().Get("redirect_uri"); !strings.Contains(redirectURI, "127.0.0.1:") { t.Fatalf("expected redirect uri, got %q", redirectURI) } @@ -743,6 +767,10 @@ func TestManageServer_HandleOAuthCallback_Success(t *testing.T) { t.Fatalf("expected code=abc, got %q", r.Form.Get("code")) } + if got := r.Form.Get("code_verifier"); got != testCodeVerifier { + t.Fatalf("expected code_verifier, got %q", got) + } + w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]any{ "access_token": "token", @@ -764,9 +792,10 @@ func TestManageServer_HandleOAuthCallback_Success(t *testing.T) { store := &fakeStore{} ms := &ManageServer{ - oauthState: "state1", - listener: ln, - store: store, + oauthState: "state1", + oauthVerifier: testCodeVerifier, + listener: ln, + store: store, fetchIdentity: func(ctx context.Context, tok *oauth2.Token) (Identity, error) { return fetchUserIdentityWithURL(ctx, tok.AccessToken, userinfoSrv.URL+"/oauth2/v2/userinfo") }, @@ -849,9 +878,10 @@ func TestManageServer_HandleOAuthCallback_MigratesAndDeletesAliasAfterSetToken(t }}, } ms := &ManageServer{ - oauthState: "state1", - listener: ln, - store: store, + oauthState: "state1", + oauthVerifier: testCodeVerifier, + listener: ln, + store: store, fetchIdentity: func(context.Context, *oauth2.Token) (Identity, error) { return Identity{Subject: "sub-123", Email: "new@example.com"}, nil }, @@ -933,9 +963,10 @@ func TestManageServer_HandleOAuthCallback_FileBackendSkipsKeychain(t *testing.T) store := &fakeStore{} ms := &ManageServer{ - oauthState: "state1", - listener: ln, - store: store, + oauthState: "state1", + oauthVerifier: testCodeVerifier, + listener: ln, + store: store, fetchIdentity: func(ctx context.Context, tok *oauth2.Token) (Identity, error) { return Identity{Email: "me@example.com"}, nil }, @@ -997,10 +1028,11 @@ func TestManageServer_HandleOAuthCallback_Success_IDTokenEmail(t *testing.T) { store := &fakeStore{} ms := &ManageServer{ - oauthState: "state1", - listener: ln, - store: store, - opts: ManageServerOptions{Services: []Service{ServiceGmail}}, + oauthState: "state1", + oauthVerifier: testCodeVerifier, + listener: ln, + store: store, + opts: ManageServerOptions{Services: []Service{ServiceGmail}}, } rr := httptest.NewRecorder() @@ -1171,17 +1203,20 @@ func TestManageServer_HandleAuthUpgrade(t *testing.T) { origRead := readClientCredentials origState := randomStateFn origEndpoint := oauthEndpoint + origVerifier := generateVerifierFn t.Cleanup(func() { readClientCredentials = origRead randomStateFn = origState oauthEndpoint = origEndpoint + generateVerifierFn = origVerifier }) readClientCredentials = func(string) (config.ClientCredentials, error) { return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil } randomStateFn = func() (string, error) { return "state456", nil } + generateVerifierFn = func() string { return testCodeVerifier } oauthEndpoint = oauth2.Endpoint{AuthURL: "http://example.com/auth", TokenURL: "http://example.com/token"} ln, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0") @@ -1223,6 +1258,18 @@ func TestManageServer_HandleAuthUpgrade(t *testing.T) { t.Fatalf("expected oauthState set") } + if ms.oauthVerifier != testCodeVerifier { + t.Fatalf("expected oauthVerifier set") + } + + if got := parsed.Query().Get("code_challenge_method"); got != "S256" { + t.Fatalf("expected S256 challenge method, got %q", got) + } + + if got, want := parsed.Query().Get("code_challenge"), pkceChallengeForTest(); got != want { + t.Fatalf("unexpected code_challenge: got %q want %q", got, want) + } + scope := parsed.Query().Get("scope") expectedScopes, err := ScopesForManage([]Service{ServiceGmail}) diff --git a/internal/googleauth/manual_state.go b/internal/googleauth/manual_state.go index b5d47f77d..af20f76fb 100644 --- a/internal/googleauth/manual_state.go +++ b/internal/googleauth/manual_state.go @@ -30,6 +30,7 @@ type manualState struct { Scopes []string `json:"scopes"` ForceConsent bool `json:"force_consent,omitempty"` RedirectURI string `json:"redirect_uri,omitempty"` + CodeVerifier string `json:"code_verifier,omitempty"` CreatedAt time.Time `json:"created_at"` } @@ -121,6 +122,12 @@ func loadManualState(client string, scopes []string, forceConsent bool) (manualS continue } + // CodeVerifier is required for PKCE-bound step 2 exchanges. + // Older cache entries (pre-PKCE) should not be reused. + if strings.TrimSpace(st.CodeVerifier) == "" { + continue + } + if bestState.State == "" || st.CreatedAt.After(bestCreated) { bestState = st bestCreated = st.CreatedAt @@ -163,7 +170,7 @@ func loadManualStateByPath(path string) (manualState, bool, error) { return st, true, nil } -func saveManualState(client string, scopes []string, forceConsent bool, state string, redirectURI string) error { +func saveManualState(client string, scopes []string, forceConsent bool, state string, redirectURI string, codeVerifier string) error { path, err := manualStatePathFor(state) if err != nil { return err @@ -175,6 +182,7 @@ func saveManualState(client string, scopes []string, forceConsent bool, state st Scopes: normalizeScopes(scopes), ForceConsent: forceConsent, RedirectURI: strings.TrimSpace(redirectURI), + CodeVerifier: strings.TrimSpace(codeVerifier), CreatedAt: manualStateNowFn().UTC(), } diff --git a/internal/googleauth/manual_state_test.go b/internal/googleauth/manual_state_test.go index 9580e474f..2f5133101 100644 --- a/internal/googleauth/manual_state_test.go +++ b/internal/googleauth/manual_state_test.go @@ -67,6 +67,67 @@ func TestManualAuthURL_ReusesState(t *testing.T) { } } +func TestManualAuthURL_UsesPKCEAndPersistsVerifier(t *testing.T) { + origRead := readClientCredentials + origEndpoint := oauthEndpoint + origState := randomStateFn + origVerifier := generateVerifierFn + + t.Cleanup(func() { + readClientCredentials = origRead + oauthEndpoint = origEndpoint + randomStateFn = origState + generateVerifierFn = origVerifier + }) + + useTempManualStatePath(t) + + readClientCredentials = func(string) (config.ClientCredentials, error) { + return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil + } + oauthEndpoint = oauth2EndpointForTest("http://example.com") + randomStateFn = func() (string, error) { return "state1", nil } + generateVerifierFn = func() string { return testCodeVerifier } + + res, err := ManualAuthURL(context.Background(), AuthorizeOptions{ + Scopes: []string{"s1"}, + Manual: true, + }) + if err != nil { + t.Fatalf("ManualAuthURL: %v", err) + } + + parsed, err := url.Parse(res.URL) + if err != nil { + t.Fatalf("parse auth URL: %v", err) + } + + if got := parsed.Query().Get("code_challenge_method"); got != "S256" { + t.Fatalf("expected S256 challenge method, got %q", got) + } + + if got, want := parsed.Query().Get("code_challenge"), pkceChallengeForTest(); got != want { + t.Fatalf("unexpected code challenge: got %q want %q", got, want) + } + + if got := parsed.Query().Get("code_verifier"); got != "" { + t.Fatalf("code_verifier must not be exposed in auth URL, got %q", got) + } + + st, ok, err := loadManualState("", []string{"s1"}, false) + if err != nil { + t.Fatalf("load manual state: %v", err) + } + + if !ok { + t.Fatalf("expected manual state") + } + + if st.CodeVerifier != testCodeVerifier { + t.Fatalf("expected persisted verifier") + } +} + func TestManualAuthURL_UsesRedirectURIOverride(t *testing.T) { origRead := readClientCredentials origEndpoint := oauthEndpoint diff --git a/internal/googleauth/oauth_flow.go b/internal/googleauth/oauth_flow.go index d6aabac9d..ea7defe50 100644 --- a/internal/googleauth/oauth_flow.go +++ b/internal/googleauth/oauth_flow.go @@ -58,6 +58,7 @@ var ( openBrowserFn = openBrowser oauthEndpoint = google.Endpoint randomStateFn = randomState + generateVerifierFn = oauth2.GenerateVerifier manualRedirectURIFn = randomManualRedirectURI ) @@ -70,8 +71,8 @@ var ( errMissingScopes = errors.New("missing scopes") errNoCodeInURL = errors.New("no code found in URL") errNoRefreshToken = errors.New("no refresh token received; try again with --force-consent") - errManualStateMissing = errors.New("manual auth state missing; run remote step 1 again") - errManualStateMismatch = errors.New("manual auth state mismatch; run remote step 1 again") + errManualStateMissing = errors.New("manual auth state missing; start a new manual flow or run remote step 1 again") + errManualStateMismatch = errors.New("manual auth state mismatch; start a new manual flow or run remote step 1 again") errStateMismatch = errors.New("state mismatch") errInvalidAuthorizeOptionsAuthURLAndCode = errors.New("cannot combine auth-url with auth-code") @@ -222,7 +223,8 @@ func authorizeServer(ctx context.Context, opts AuthorizeOptions, creds config.Cl } }() - authURL := cfg.AuthCodeURL(state, authURLParams(opts.ForceConsent, !opts.DisableIncludeGrantedScopes)...) + codeVerifier := generateVerifierFn() + authURL := cfg.AuthCodeURL(state, pkceAuthURLParams(opts.ForceConsent, !opts.DisableIncludeGrantedScopes, codeVerifier)...) fmt.Fprintln(os.Stderr, "Opening browser for authorization…") fmt.Fprintln(os.Stderr, "If the browser doesn't open, visit this URL:") @@ -238,7 +240,7 @@ func authorizeServer(ctx context.Context, opts AuthorizeOptions, creds config.Cl fmt.Fprintln(os.Stderr, "Authorization received. Finishing…") var tok *oauth2.Token - if t, exchangeErr := cfg.Exchange(ctx, code); exchangeErr != nil { + if t, exchangeErr := cfg.Exchange(ctx, code, oauth2.VerifierOption(codeVerifier)); exchangeErr != nil { _ = srv.Close() return "", fmt.Errorf("exchange code: %w", exchangeErr) @@ -280,6 +282,10 @@ func authURLParams(forceConsent bool, includeGrantedScopes bool) []oauth2.AuthCo return opts } +func pkceAuthURLParams(forceConsent bool, includeGrantedScopes bool, codeVerifier string) []oauth2.AuthCodeOption { + return append(authURLParams(forceConsent, includeGrantedScopes), oauth2.S256ChallengeOption(codeVerifier)) +} + func randomState() (string, error) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { diff --git a/internal/googleauth/oauth_flow_authorize_test.go b/internal/googleauth/oauth_flow_authorize_test.go index 4bbec8f28..b5ee135b9 100644 --- a/internal/googleauth/oauth_flow_authorize_test.go +++ b/internal/googleauth/oauth_flow_authorize_test.go @@ -2,6 +2,8 @@ package googleauth import ( "context" + "crypto/sha256" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -18,9 +20,22 @@ import ( "github.com/steipete/gogcli/internal/config" ) -var errMissingRedirectState = errors.New("missing redirect/state") +var ( + errMissingRedirectState = errors.New("missing redirect/state") + errUnexpectedCodeChallengeMethod = errors.New("unexpected code_challenge_method") + errUnexpectedCodeChallenge = errors.New("unexpected code_challenge") + errExposedCodeVerifier = errors.New("auth URL exposed code_verifier") +) + +const ( + testRedirectURI = "http://127.0.0.1:55555/oauth2/callback" + testCodeVerifier = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~" +) -const testRedirectURI = "http://127.0.0.1:55555/oauth2/callback" +func pkceChallengeForTest() string { + sum := sha256.Sum256([]byte(testCodeVerifier)) + return base64.RawURLEncoding.EncodeToString(sum[:]) +} func useManualRedirectURI(t *testing.T) { t.Helper() @@ -280,7 +295,7 @@ func TestAuthorize_Manual_AuthCode(t *testing.T) { return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil } - if err := saveManualState("", []string{"s1"}, false, "state123", testRedirectURI); err != nil { + if err := saveManualState("", []string{"s1"}, false, "state123", testRedirectURI, testCodeVerifier); err != nil { t.Fatalf("save manual state: %v", err) } @@ -328,6 +343,9 @@ func TestAuthorize_Manual_AuthCode_WithRedirectURI(t *testing.T) { } wantRedirectURI := "https://host.example/oauth2/callback" + if err := saveManualState("", []string{"s1"}, false, "state123", wantRedirectURI, testCodeVerifier); err != nil { + t.Fatalf("save manual state: %v", err) + } tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/token" { @@ -387,6 +405,9 @@ func TestAuthorize_Manual_AuthURL_PrefersAuthURLRedirectOverOverride(t *testing. } redirectFromAuthURL := "https://from-auth-url.example/oauth2/callback" + if err := saveManualState("", []string{"s1"}, false, "state123", redirectFromAuthURL, testCodeVerifier); err != nil { + t.Fatalf("save manual state: %v", err) + } tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/token" { @@ -510,7 +531,7 @@ func TestAuthorize_Manual_AuthURL_RequireStateMissingForDifferentState(t *testin } oauthEndpoint = oauth2EndpointForTest("http://example.com") - if err := saveManualState("default", []string{"s1"}, false, "state123", "http://127.0.0.1:55555/oauth2/callback"); err != nil { + if err := saveManualState("default", []string{"s1"}, false, "state123", "http://127.0.0.1:55555/oauth2/callback", testCodeVerifier); err != nil { t.Fatalf("save manual state: %v", err) } @@ -535,18 +556,45 @@ func TestAuthorize_ServerFlow_Success(t *testing.T) { origRead := readClientCredentials origEndpoint := oauthEndpoint origOpen := openBrowserFn + origVerifier := generateVerifierFn t.Cleanup(func() { readClientCredentials = origRead oauthEndpoint = origEndpoint openBrowserFn = origOpen + generateVerifierFn = origVerifier }) readClientCredentials = func(string) (config.ClientCredentials, error) { return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil } - tokenSrv := newTokenServer(t) + generateVerifierFn = func() string { return testCodeVerifier } + + tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/token" { + http.NotFound(w, r) + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "bad form", http.StatusBadRequest) + return + } + + if got := r.Form.Get("code_verifier"); got != testCodeVerifier { + http.Error(w, "bad code_verifier", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "at", + "refresh_token": "rt", + "token_type": "Bearer", + "expires_in": 3600, + }) + })) defer tokenSrv.Close() oauthEndpoint = oauth2EndpointForTest(tokenSrv.URL) @@ -560,6 +608,18 @@ func TestAuthorize_ServerFlow_Success(t *testing.T) { var state string + if got := q.Get("code_challenge_method"); got != "S256" { + return fmt.Errorf("%w: got %q", errUnexpectedCodeChallengeMethod, got) + } + + if got, want := q.Get("code_challenge"), pkceChallengeForTest(); got != want { + return fmt.Errorf("%w: got %q want %q", errUnexpectedCodeChallenge, got, want) + } + + if got := q.Get("code_verifier"); got != "" { + return fmt.Errorf("%w: got %q", errExposedCodeVerifier, got) + } + if s := q.Get("state"); redirect == "" || s == "" { return errMissingRedirectState } else { @@ -599,6 +659,94 @@ func TestAuthorize_ServerFlow_Success(t *testing.T) { } } +func TestAuthorize_Manual_AuthURL_UsesStoredPKCEVerifier(t *testing.T) { + origRead := readClientCredentials + origEndpoint := oauthEndpoint + origState := randomStateFn + origVerifier := generateVerifierFn + + t.Cleanup(func() { + readClientCredentials = origRead + oauthEndpoint = origEndpoint + randomStateFn = origState + generateVerifierFn = origVerifier + }) + + useTempManualStatePath(t) + + readClientCredentials = func(string) (config.ClientCredentials, error) { + return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil + } + randomStateFn = func() (string, error) { return "state123", nil } + generateVerifierFn = func() string { return testCodeVerifier } + + tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/token" { + http.NotFound(w, r) + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "bad form", http.StatusBadRequest) + return + } + + if got := r.Form.Get("code_verifier"); got != testCodeVerifier { + http.Error(w, "bad code_verifier", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "at", + "refresh_token": "rt", + "token_type": "Bearer", + "expires_in": 3600, + }) + })) + defer tokenSrv.Close() + oauthEndpoint = oauth2EndpointForTest(tokenSrv.URL) + + res, err := ManualAuthURL(context.Background(), AuthorizeOptions{ + Scopes: []string{"s1"}, + Manual: true, + Timeout: 2 * time.Second, + }) + if err != nil { + t.Fatalf("ManualAuthURL: %v", err) + } + + parsed, err := url.Parse(res.URL) + if err != nil { + t.Fatalf("parse auth URL: %v", err) + } + + if got := parsed.Query().Get("code_challenge_method"); got != "S256" { + t.Fatalf("expected S256 challenge method, got %q", got) + } + + if got, want := parsed.Query().Get("code_challenge"), pkceChallengeForTest(); got != want { + t.Fatalf("unexpected code challenge: got %q want %q", got, want) + } + + redirectURI := parsed.Query().Get("redirect_uri") + state := parsed.Query().Get("state") + + rt, err := Authorize(context.Background(), AuthorizeOptions{ + Scopes: []string{"s1"}, + Manual: true, + AuthURL: redirectURI + "?code=abc&state=" + url.QueryEscape(state), + Timeout: 2 * time.Second, + }) + if err != nil { + t.Fatalf("Authorize: %v", err) + } + + if rt != "rt" { + t.Fatalf("unexpected refresh token: %q", rt) + } +} + func TestAuthorize_ServerFlow_CallbackErrors(t *testing.T) { tests := []struct { name string diff --git a/internal/googleauth/oauth_flow_manual.go b/internal/googleauth/oauth_flow_manual.go index 60d6c685c..de656c024 100644 --- a/internal/googleauth/oauth_flow_manual.go +++ b/internal/googleauth/oauth_flow_manual.go @@ -98,7 +98,19 @@ func authorizeManualWithCode( return "", errMissingRedirectURI } - tok, exchangeErr := cfg.Exchange(ctx, code) + if st.CodeVerifier == "" && gotState == "" { + if cached, ok, err := loadManualState(opts.Client, opts.Scopes, opts.ForceConsent); err != nil { + return "", err + } else if ok && cfg.RedirectURL == cached.RedirectURI { + st = cached + } + } + + if strings.TrimSpace(st.CodeVerifier) == "" { + return "", errManualStateMissing + } + + tok, exchangeErr := cfg.Exchange(ctx, code, oauth2.VerifierOption(st.CodeVerifier)) if exchangeErr != nil { return "", fmt.Errorf("exchange code: %w", exchangeErr) } @@ -107,8 +119,8 @@ func authorizeManualWithCode( return "", errNoRefreshToken } - if gotState != "" { - _ = clearManualState(gotState) + if st.State != "" { + _ = clearManualState(st.State) } return tok.RefreshToken, nil @@ -121,7 +133,7 @@ func authorizeManualInteractive(ctx context.Context, opts AuthorizeOptions, cfg } cfg.RedirectURL = setup.redirectURI - authURL := cfg.AuthCodeURL(setup.state, authURLParams(opts.ForceConsent, !opts.DisableIncludeGrantedScopes)...) + authURL := cfg.AuthCodeURL(setup.state, pkceAuthURLParams(opts.ForceConsent, !opts.DisableIncludeGrantedScopes, setup.codeVerifier)...) fmt.Fprintln(os.Stderr, "Visit this URL to authorize:") fmt.Fprintln(os.Stderr, authURL) @@ -161,7 +173,7 @@ func authorizeManualInteractive(ctx context.Context, opts AuthorizeOptions, cfg } } - tok, exchangeErr := cfg.Exchange(ctx, code) + tok, exchangeErr := cfg.Exchange(ctx, code, oauth2.VerifierOption(setup.codeVerifier)) if exchangeErr != nil { return "", fmt.Errorf("exchange code: %w", exchangeErr) } @@ -251,15 +263,16 @@ func ManualAuthURL(ctx context.Context, opts AuthorizeOptions) (ManualAuthURLRes } return ManualAuthURLResult{ - URL: cfg.AuthCodeURL(setup.state, authURLParams(opts.ForceConsent, !opts.DisableIncludeGrantedScopes)...), + URL: cfg.AuthCodeURL(setup.state, pkceAuthURLParams(opts.ForceConsent, !opts.DisableIncludeGrantedScopes, setup.codeVerifier)...), StateReused: setup.reused, }, nil } type manualAuthSetupResult struct { - state string - redirectURI string - reused bool + state string + redirectURI string + codeVerifier string + reused bool } func manualAuthSetup(ctx context.Context, opts AuthorizeOptions) (manualAuthSetupResult, error) { @@ -280,11 +293,13 @@ func manualAuthSetup(ctx context.Context, opts AuthorizeOptions) (manualAuthSetu state := st.State redirectURI := st.RedirectURI + codeVerifier := st.CodeVerifier if redirectURIOverride != "" { if !reused || st.RedirectURI != redirectURIOverride { reused = false redirectURI = redirectURIOverride + codeVerifier = "" } } @@ -301,14 +316,17 @@ func manualAuthSetup(ctx context.Context, opts AuthorizeOptions) (manualAuthSetu return manualAuthSetupResult{}, err } - if err := saveManualState(opts.Client, opts.Scopes, opts.ForceConsent, state, redirectURI); err != nil { + codeVerifier = generateVerifierFn() + + if err := saveManualState(opts.Client, opts.Scopes, opts.ForceConsent, state, redirectURI, codeVerifier); err != nil { return manualAuthSetupResult{}, err } } return manualAuthSetupResult{ - state: state, - redirectURI: redirectURI, - reused: reused, + state: state, + redirectURI: redirectURI, + codeVerifier: codeVerifier, + reused: reused, }, nil }