Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 150 additions & 13 deletions pkg/retrieve/retrieve.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@ import (
"fmt"
"io"
"os"
"os/signal"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"

log "github.com/sirupsen/logrus"
Expand All @@ -25,6 +29,10 @@ import (
"github.com/redhat-openshift-ecosystem/opct/pkg/status"
)

// retrieveFunc is the function used by retrieveResultsRetry for each attempt.
// It is a package-level variable to allow injection in tests.
var retrieveFunc = retrieveResults

func NewCmdRetrieve() *cobra.Command {
var skipRedact bool

Expand All @@ -34,6 +42,9 @@ func NewCmdRetrieve() *cobra.Command {
Short: "Collect results from validation environment",
Long: `Downloads the results archive from the validation environment`,
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
defer cancel()
Comment on lines 44 to +46

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Locate the file
ls -la pkg/retrieve/retrieve.go

# Show the RunE block around the claimed lines
nl -ba pkg/retrieve/retrieve.go | sed -n '1,120p'

# Find and print retrieveResultsRetry / retrieveResults / downloadFromPod sections
rg -n "func (retrieveResultsRetry|retrieveResults|downloadFromPod)\\b" pkg/retrieve/retrieve.go
nl -ba pkg/retrieve/retrieve.go | sed -n '120,260p'

# Check for any other context.Background() usage in this file
rg -n "context\\.Background\\(" pkg/retrieve/retrieve.go

# Check where ctx is used (call sites) in retrieveResults and surrounding helpers
rg -n "\\bctx\\b" pkg/retrieve/retrieve.go

Repository: redhat-openshift-ecosystem/opct

Length of output: 211


🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="pkg/retrieve/retrieve.go"
echo "== File: $FILE =="
wc -l "$FILE"

echo "== RunE context block (approx lines 35-70) =="
sed -n '35,80p' "$FILE"

echo "== Function headers (retrieveResultsRetry / retrieveResults / downloadFromPod) =="
rg -n "func (retrieveResultsRetry|retrieveResults|downloadFromPod)\\b" "$FILE"

echo "== retrieveResultsRetry body (approx lines 80-190) =="
sed -n '80,200p' "$FILE"

echo "== retrieveResults body (if different) (approx lines 200-320) =="
sed -n '200,340p' "$FILE"

echo "== downloadFromPod body (approx lines 340-520) =="
sed -n '340,520p' "$FILE"

echo "== Any context.Background usage in this file =="
rg -n "context\\.Background\\(" "$FILE" || true

echo "== Any ctx usage in this file (to spot gaps) =="
rg -n "\\bctx\\b" "$FILE" | head -n 120

Repository: redhat-openshift-ecosystem/opct

Length of output: 11877


Derive signal context from cmd.Context() and propagate ctx through local scan/redact + extract phases.

  • NewCmdRetrieve’s RunE roots the signal-aware context from context.Background(), discarding any cancellation/deadline supplied via Cobra/ExecuteContext; base it on cmd.Context() instead.
  • retrieveResults uses ctx only for downloadFromPod; the scan/redact (cleaner.ScanPatchTarGzipReaderFor), file writing (io.Copy), and extraction (sonobuoyclient.UntarAll + rename) run without any ctx checks, so termination signals won’t reliably interrupt long local phases.
Suggested fix
-			ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
+			baseCtx := cmd.Context()
+			if baseCtx == nil {
+				baseCtx = context.Background()
+			}
+			ctx, cancel := signal.NotifyContext(baseCtx, syscall.SIGTERM, syscall.SIGINT)
 			defer cancel()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
defer cancel()
RunE: func(cmd *cobra.Command, args []string) error {
baseCtx := cmd.Context()
if baseCtx == nil {
baseCtx = context.Background()
}
ctx, cancel := signal.NotifyContext(baseCtx, syscall.SIGTERM, syscall.SIGINT)
defer cancel()
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@pkg/retrieve/retrieve.go` around lines 44 - 46, Change RunE in NewCmdRetrieve
to derive the signal-aware context from cmd.Context() (use
signal.NotifyContext(cmd.Context(), ...)) instead of context.Background(), and
pass that ctx into retrieveResults; inside retrieveResults ensure the same ctx
is used not only for downloadFromPod but also for the local phases by (a)
calling cleaner.ScanPatchTarGzipReaderFor with a context-aware variant or
wrapping checks for ctx.Done while reading/writing, (b) making the file
write/io.Copy abort when ctx is cancelled (check ctx.Done and return), and (c)
invoking sonobuoyclient.UntarAll and the subsequent rename with cancellation
support (either by using context-aware versions or checking ctx between steps
and returning early). Update function signatures (retrieveResults and any helper
calls) to accept the ctx where needed so cancellation/deadline from
cmd.Context() propagates through downloadFromPod,
cleaner.ScanPatchTarGzipReaderFor, io.Copy, sonobuoyclient.UntarAll and rename
operations.

Source: Coding guidelines


if skipRedact {
log.Warn("═════════════════════════════════════════════════════════════")
log.Warn("WARNING: --debug-only-skip-redact is enabled")
Expand Down Expand Up @@ -65,7 +76,7 @@ func NewCmdRetrieve() *cobra.Command {
}

log.Info("Collecting results...")
if err := retrieveResultsRetry(destinationDirectory); err != nil {
if err := retrieveResultsRetry(ctx, destinationDirectory); err != nil {
return fmt.Errorf("retrieve finished with errors: %v", err)
}

Expand All @@ -81,19 +92,29 @@ func NewCmdRetrieve() *cobra.Command {
return cmd
}

func retrieveResultsRetry(destinationDirectory string) error {
func retrieveResultsRetry(ctx context.Context, destinationDirectory string) error {
var err error
limit := 10
pause := time.Second * 2
retries := 1
for retries <= limit {
err = retrieveResults(destinationDirectory)
select {
case <-ctx.Done():
return fmt.Errorf("retrieval cancelled: %w", ctx.Err())
default:
}

err = retrieveFunc(ctx, destinationDirectory)
if err != nil {
log.Error(err)
if retries+1 < limit {
log.Warnf("Retrying retrieval %d more times after %d sec", limit-retries, pause/time.Second)
}
time.Sleep(pause)
select {
case <-ctx.Done():
return fmt.Errorf("retrieval cancelled during retry wait: %w", ctx.Err())
case <-time.After(pause):
}
retries++
continue
}
Expand All @@ -103,9 +124,9 @@ func retrieveResultsRetry(destinationDirectory string) error {
return fmt.Errorf("retrieval retry limit reached")
}

func retrieveResults(destinationDirectory string) error {
func retrieveResults(ctx context.Context, destinationDirectory string) error {
// Phase 1: Download archive to temp file
tmpFile, err := downloadFromPod()
tmpFile, err := downloadFromPod(ctx)
if err != nil {
return fmt.Errorf("error retrieving results from aggregator server: %w", err)
}
Expand Down Expand Up @@ -166,7 +187,7 @@ func retrieveResults(destinationDirectory string) error {

// downloadFromPod downloads the results archive from the sonobuoy aggregator pod
// to a temp file using WebSocket (with SPDY fallback).
func downloadFromPod() (string, error) {
func downloadFromPod(ctx context.Context) (string, error) {
cli, err := opclient.NewClient()
if err != nil {
return "", fmt.Errorf("error creating kubernetes client: %w", err)
Expand Down Expand Up @@ -215,10 +236,15 @@ func downloadFromPod() (string, error) {
log.Debugf("Discovered aggregator server running on pod %s/%s...", pkg.CertificationNamespace, podName)
startTime := time.Now()

err = exec.StreamWithContext(context.Background(), remotecommand.StreamOptions{
Stdout: tmpFile,
// Wrap temp file with progress tracking to log download progress every 30s
pw := newProgressWriter(tmpFile, 30*time.Second)

err = exec.StreamWithContext(ctx, remotecommand.StreamOptions{
Stdout: pw,
Tty: false,
})
pw.Close()

if err != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpFile.Name())
Expand All @@ -228,11 +254,122 @@ func downloadFromPod() (string, error) {
return "", fmt.Errorf("error closing temp file: %w", err)
}

fi, err := os.Stat(tmpFile.Name())
log.Infof("Downloaded %s in %s", formatBytes(pw.BytesWritten()), time.Since(startTime).Round(time.Second))

return tmpFile.Name(), nil
}

// progressReader wraps an io.Reader and periodically logs the number of bytes read.
type progressReader struct {
r io.Reader
bytes atomic.Int64
interval time.Duration
done chan struct{}
closeOnce sync.Once
}

// newProgressReader creates a progressReader that logs bytes read every interval.
func newProgressReader(r io.Reader, interval time.Duration) *progressReader {
pr := &progressReader{
r: r,
interval: interval,
done: make(chan struct{}),
}
go pr.logProgress()
return pr
}

func (pr *progressReader) Read(p []byte) (int, error) {
n, err := pr.r.Read(p)
if n > 0 {
pr.bytes.Add(int64(n))
}
if err != nil {
return "", fmt.Errorf("error stat temp file: %w", err)
pr.closeOnce.Do(func() { close(pr.done) })
}
log.Infof("Downloaded %.1f MB in %s", float64(fi.Size())/(1024*1024), time.Since(startTime).Round(time.Second))
return n, err
}

return tmpFile.Name(), nil
// BytesRead returns the total number of bytes read so far.
func (pr *progressReader) BytesRead() int64 {
return pr.bytes.Load()
}

func (pr *progressReader) logProgress() {
ticker := time.NewTicker(pr.interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
log.Infof("Retrieve in progress: %s received so far...", formatBytes(pr.bytes.Load()))
case <-pr.done:
return
}
}
}

// progressWriter wraps an io.Writer and periodically logs the number of bytes written.
// It mirrors progressReader but for write operations (e.g., streaming download to disk).
type progressWriter struct {
w io.Writer
bytes atomic.Int64
interval time.Duration
done chan struct{}
closeOnce sync.Once
}

// newProgressWriter creates a progressWriter that logs bytes written every interval.
func newProgressWriter(w io.Writer, interval time.Duration) *progressWriter {
pw := &progressWriter{
w: w,
interval: interval,
done: make(chan struct{}),
}
go pw.logProgress()
return pw
}

func (pw *progressWriter) Write(p []byte) (int, error) {
n, err := pw.w.Write(p)
if n > 0 {
pw.bytes.Add(int64(n))
}
return n, err
}

// BytesWritten returns the total number of bytes written so far.
func (pw *progressWriter) BytesWritten() int64 {
return pw.bytes.Load()
}

// Close stops the progress logging goroutine.
func (pw *progressWriter) Close() {
pw.closeOnce.Do(func() { close(pw.done) })
}

func (pw *progressWriter) logProgress() {
ticker := time.NewTicker(pw.interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
log.Infof("Retrieve in progress: %s received so far...", formatBytes(pw.bytes.Load()))
case <-pw.done:
return
}
}
}

// formatBytes converts bytes to human-readable format (KiB, MiB, GiB).
func formatBytes(bytes int64) string {
const unit = 1024
if bytes < unit {
return fmt.Sprintf("%d B", bytes)
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
Loading
Loading