Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions internal/cmd/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,24 @@ import (
)

var (
openSecretsStore = secrets.OpenDefault
authorizeGoogle = googleauth.Authorize
startManageServer = googleauth.StartManageServer
checkRefreshToken = googleauth.CheckRefreshToken
openSecretsStore = secrets.OpenDefault
authorizeGoogle = googleauth.Authorize
startManageServer = googleauth.StartManageServer
checkRefreshToken = googleauth.CheckRefreshToken
ensureKeychainAccess = defaultEnsureKeychainAccess
)

// defaultEnsureKeychainAccess verifies keychain is accessible before starting OAuth flow.
func defaultEnsureKeychainAccess() error {
store, err := secrets.OpenDefault()
if err != nil {
return fmt.Errorf("keychain access: %w", err)
}
// Trigger a read to verify keychain access
_, _ = store.Keys()
return nil
}

type AuthCmd struct {
Credentials AuthCredentialsCmd `cmd:"" name:"credentials" help:"Store OAuth client credentials"`
Add AuthAddCmd `cmd:"" name:"add" help:"Authorize and store a refresh token"`
Expand Down Expand Up @@ -300,6 +312,11 @@ type AuthAddCmd struct {
func (c *AuthAddCmd) Run(ctx context.Context) error {
u := ui.FromContext(ctx)

// Verify keychain access before starting the OAuth flow
if err := ensureKeychainAccess(); err != nil {
return err
}

var services []googleauth.Service
if strings.EqualFold(strings.TrimSpace(c.ServicesCSV), "") || strings.EqualFold(strings.TrimSpace(c.ServicesCSV), "all") {
services = googleauth.AllServices()
Expand Down
43 changes: 43 additions & 0 deletions internal/cmd/auth_add_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ package cmd
import (
"context"
"encoding/json"
"errors"
"io"
"strings"
"testing"

"github.com/steipete/gogcli/internal/googleauth"
"github.com/steipete/gogcli/internal/secrets"
"github.com/steipete/gogcli/internal/ui"
)

func TestAuthAddCmd_JSON(t *testing.T) {
Expand Down Expand Up @@ -70,3 +73,43 @@ func TestAuthAddCmd_JSON(t *testing.T) {
t.Fatalf("unexpected token: %#v", tok)
}
}

func TestAuthAddCmd_KeychainError(t *testing.T) {
origAuth := authorizeGoogle
origOpen := openSecretsStore
origKeychain := ensureKeychainAccess
t.Cleanup(func() {
authorizeGoogle = origAuth
openSecretsStore = origOpen
ensureKeychainAccess = origKeychain
})

// Simulate keychain locked error
ensureKeychainAccess = func() error {
return errors.New("keychain is locked")
}

authCalled := false
authorizeGoogle = func(ctx context.Context, opts googleauth.AuthorizeOptions) (string, error) {
authCalled = true
return "rt", nil
}

cmd := &AuthAddCmd{Email: "test@example.com", ServicesCSV: "gmail"}
u, uiErr := ui.New(ui.Options{Stdout: io.Discard, Stderr: io.Discard, Color: "never"})
if uiErr != nil {
t.Fatalf("ui.New: %v", uiErr)
}
ctx := ui.WithUI(context.Background(), u)
err := cmd.Run(ctx)

if err == nil {
t.Fatal("expected error when keychain is locked")
}
if !strings.Contains(err.Error(), "keychain") {
t.Errorf("expected error to mention keychain, got: %v", err)
}
if authCalled {
t.Error("authorizeGoogle should not be called when keychain check fails")
}
}
15 changes: 9 additions & 6 deletions internal/secrets/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,16 @@ type Token struct {
RefreshToken string `json:"-"`
}

const keyringPasswordEnv = "GOG_KEYRING_PASSWORD" //nolint:gosec // env var name, not a credential
const keyringBackendEnv = "GOG_KEYRING_BACKEND" //nolint:gosec // env var name, not a credential
const (
keyringPasswordEnv = "GOG_KEYRING_PASSWORD" //nolint:gosec // env var name, not a credential
keyringBackendEnv = "GOG_KEYRING_BACKEND" //nolint:gosec // env var name, not a credential
)

var (
errMissingEmail = errors.New("missing email")
errMissingRefreshToken = errors.New("missing refresh token")
errNoTTY = errors.New("no TTY available for keyring file backend password prompt")
errMissingEmail = errors.New("missing email")
errMissingRefreshToken = errors.New("missing refresh token")
errNoTTY = errors.New("no TTY available for keyring file backend password prompt")
errInvalidKeyringBackend = errors.New("invalid keyring backend")
)

func allowedBackendsFromEnv() ([]keyring.BackendType, error) {
Expand All @@ -55,7 +58,7 @@ func allowedBackendsFromEnv() ([]keyring.BackendType, error) {
case "file":
return []keyring.BackendType{keyring.FileBackend}, nil
default:
return nil, fmt.Errorf("invalid %s (expected auto, keychain, or file)", keyringBackendEnv)
return nil, fmt.Errorf("%w: %s (expected auto, keychain, or file)", errInvalidKeyringBackend, keyringBackendEnv)
}
}

Expand Down
45 changes: 45 additions & 0 deletions internal/secrets/store_more_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package secrets

import (
"errors"
"testing"

"github.com/99designs/keyring"
Expand Down Expand Up @@ -86,3 +87,47 @@ func TestFileKeyringPasswordFunc(t *testing.T) {
t.Fatalf("expected secret, got %q err=%v", res.got, res.err)
}
}

func TestAllowedBackendsFromEnv(t *testing.T) {
tests := []struct {
name string
envVal string
wantLen int
wantErr bool
}{
{"empty defaults to nil", "", 0, false},
{"auto defaults to nil", "auto", 0, false},
{"keychain returns one backend", "keychain", 1, false},
{"file returns one backend", "file", 1, false},
{"invalid returns error", "invalid", 0, true},
{"whitespace trimmed", " keychain ", 1, false},
{"case insensitive", "KEYCHAIN", 1, false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv(keyringBackendEnv, tt.envVal)
backends, err := allowedBackendsFromEnv()

if tt.wantErr {
if err == nil {
t.Fatal("expected error")
}

if !errors.Is(err, errInvalidKeyringBackend) {
t.Errorf("expected errInvalidKeyringBackend, got %v", err)
}

return
}

if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if len(backends) != tt.wantLen {
t.Errorf("expected %d backends, got %d", tt.wantLen, len(backends))
}
})
}
}