diff --git a/internal/scheduler/jobdb/jobdb.go b/internal/scheduler/jobdb/jobdb.go index 220a0c56bd3..c95c1425eb8 100644 --- a/internal/scheduler/jobdb/jobdb.go +++ b/internal/scheduler/jobdb/jobdb.go @@ -72,7 +72,6 @@ type JobDb struct { jobsByQueue map[string]immutable.SortedSet[*Job] jobsByPoolAndQueue map[string]map[string]immutable.SortedSet[*Job] leasedJobs *immutable.Set[*Job] - terminalJobs *immutable.Set[*Job] unvalidatedJobs *immutable.Set[*Job] // Configured priority classes. priorityClasses map[string]types.PriorityClass @@ -137,7 +136,6 @@ func NewJobDbWithSchedulingKeyGenerator( } unvalidatedJobs := immutable.NewSet[*Job](JobHasher{}) leasedJobs := immutable.NewSet[*Job](JobHasher{}) - terminalJobs := immutable.NewSet[*Job](JobHasher{}) return &JobDb{ jobsById: immutable.NewMap[string, *Job](nil), jobsByRunId: immutable.NewMap[string, string](nil), @@ -145,7 +143,6 @@ func NewJobDbWithSchedulingKeyGenerator( jobsByQueue: map[string]immutable.SortedSet[*Job]{}, jobsByPoolAndQueue: map[string]map[string]immutable.SortedSet[*Job]{}, leasedJobs: &leasedJobs, - terminalJobs: &terminalJobs, unvalidatedJobs: &unvalidatedJobs, priorityClasses: priorityClasses, defaultPriorityClass: defaultPriorityClass, @@ -178,7 +175,6 @@ func (jobDb *JobDb) Clone() *JobDb { jobsByQueue: maps.Clone(jobDb.jobsByQueue), jobsByPoolAndQueue: deepClone(jobDb.jobsByPoolAndQueue), leasedJobs: jobDb.leasedJobs, - terminalJobs: jobDb.terminalJobs, unvalidatedJobs: jobDb.unvalidatedJobs, priorityClasses: jobDb.priorityClasses, defaultPriorityClass: jobDb.defaultPriorityClass, @@ -353,7 +349,6 @@ func (jobDb *JobDb) ReadTxn() *Txn { jobsByQueue: jobDb.jobsByQueue, jobsByPoolAndQueue: jobDb.jobsByPoolAndQueue, leasedJobs: jobDb.leasedJobs, - terminalJobs: jobDb.terminalJobs, unvalidatedJobs: jobDb.unvalidatedJobs, bidPriceSnapshot: jobDb.bidPriceSnapshot, active: true, @@ -376,7 +371,6 @@ func (jobDb *JobDb) WriteTxn() *Txn { jobsByQueue: maps.Clone(jobDb.jobsByQueue), jobsByPoolAndQueue: deepClone(jobDb.jobsByPoolAndQueue), leasedJobs: jobDb.leasedJobs, - terminalJobs: jobDb.terminalJobs, unvalidatedJobs: jobDb.unvalidatedJobs, bidPriceSnapshot: jobDb.bidPriceSnapshot, active: true, @@ -399,7 +393,6 @@ func (jobDb *JobDb) DryRunTxn() *Txn { jobsByQueue: maps.Clone(jobDb.jobsByQueue), jobsByPoolAndQueue: deepClone(jobDb.jobsByPoolAndQueue), leasedJobs: jobDb.leasedJobs, - terminalJobs: jobDb.terminalJobs, unvalidatedJobs: jobDb.unvalidatedJobs, bidPriceSnapshot: jobDb.bidPriceSnapshot, active: true, @@ -444,8 +437,6 @@ type Txn struct { jobsByPoolAndQueue map[string]map[string]immutable.SortedSet[*Job] // Jobs that are currently leased leasedJobs *immutable.Set[*Job] - // Jobs that are currently in a terminal state - terminalJobs *immutable.Set[*Job] // Jobs that require submit checking unvalidatedJobs *immutable.Set[*Job] // The current snapshot of bid prices - allowing look up of bidding prices on job creation @@ -473,7 +464,6 @@ func (txn *Txn) Commit() { txn.jobDb.jobsByQueue = txn.jobsByQueue txn.jobDb.jobsByPoolAndQueue = txn.jobsByPoolAndQueue txn.jobDb.leasedJobs = txn.leasedJobs - txn.jobDb.terminalJobs = txn.terminalJobs txn.jobDb.unvalidatedJobs = txn.unvalidatedJobs txn.jobDb.bidPriceSnapshot = txn.bidPriceSnapshot @@ -614,11 +604,6 @@ func (txn *Txn) Upsert(jobs []*Job) error { txn.leasedJobs = &newLeasedJobs } - if existingJob.InTerminalState() { - newTerminalJobs := txn.terminalJobs.Delete(existingJob) - txn.terminalJobs = &newTerminalJobs - } - if !existingJob.Validated() { newUnvalidatedJobs := txn.unvalidatedJobs.Delete(existingJob) txn.unvalidatedJobs = &newUnvalidatedJobs @@ -629,7 +614,7 @@ func (txn *Txn) Upsert(jobs []*Job) error { // Now need to insert jobs, runs and queuedJobs. This can be done in parallel. wg := sync.WaitGroup{} - wg.Add(7) + wg.Add(6) // jobs go func() { @@ -792,30 +777,6 @@ func (txn *Txn) Upsert(jobs []*Job) error { } }() - // Terminal jobs - go func() { - defer wg.Done() - if hasJobs { - for _, job := range jobs { - if job.InTerminalState() { - terminalJobs := txn.terminalJobs.Add(job) - txn.terminalJobs = &terminalJobs - } - } - } else { - terminalJobs := map[*Job]bool{} - - for _, job := range jobs { - if job.InTerminalState() { - terminalJobs[job] = true - } - } - - terminalJobsImmutable := immutable.NewSet[*Job](JobHasher{}, maps.Keys(terminalJobs)...) - txn.terminalJobs = &terminalJobsImmutable - } - }() - // Unvalidated jobs go func() { defer wg.Done() @@ -957,11 +918,6 @@ func (txn *Txn) GetAllLeasedJobs() []*Job { return txn.leasedJobs.Items() } -// GetAllTerminalJobs returns all terminal jobs in the database -func (txn *Txn) GetAllTerminalJobs() []*Job { - return txn.terminalJobs.Items() -} - // GetAll returns all jobs in the database. func (txn *Txn) GetAll() []*Job { allJobs := make([]*Job, 0, txn.jobsById.Len()) @@ -1034,9 +990,6 @@ func (txn *Txn) delete(jobId string) { newLeasedJobs := txn.leasedJobs.Delete(job) txn.leasedJobs = &newLeasedJobs - newTerminalJobs := txn.terminalJobs.Delete(job) - txn.terminalJobs = &newTerminalJobs - newUnvalidatedJobs := txn.unvalidatedJobs.Delete(job) txn.unvalidatedJobs = &newUnvalidatedJobs } diff --git a/internal/scheduler/jobdb/jobdb_test.go b/internal/scheduler/jobdb/jobdb_test.go index d763e559ce5..defd59081f9 100644 --- a/internal/scheduler/jobdb/jobdb_test.go +++ b/internal/scheduler/jobdb/jobdb_test.go @@ -147,68 +147,6 @@ func TestJobDb_LeasedJobs_Deleted(t *testing.T) { assert.Empty(t, txn.GetAllLeasedJobs()) } -func TestJobDb_TestGetTerminalJobs(t *testing.T) { - jobDb := NewTestJobDb() - job1 := newJob().WithQueued(false).WithNewRun("executor", "nodeId", "nodeName", "pool", 5) - job2 := newJob().WithQueued(true) - job3 := newJob().WithQueued(false).WithSucceeded(true) - job4 := newJob().WithQueued(false).WithCancelled(true) - job5 := newJob().WithQueued(false).WithFailed(true) - job6 := newJob().WithQueued(true).WithFailed(true) - txn := jobDb.WriteTxn() - - err := txn.Upsert([]*Job{job1, job2, job3, job4, job5, job6}) - require.NoError(t, err) - - expected := []*Job{job3, job4, job5, job6} - actual := txn.GetAllTerminalJobs() - sort.SliceStable(actual, func(i, j int) bool { return actual[i].id < actual[j].id }) - sort.SliceStable(expected, func(i, j int) bool { return expected[i].id < expected[j].id }) - assert.Equal(t, expected, actual) -} - -func TestJobDb_TerminalJobs_Lifecycle(t *testing.T) { - jobDb := NewTestJobDb() - - upsert := func(jobDb *JobDb, job *Job) { - txn := jobDb.WriteTxn() - err := txn.Upsert([]*Job{job}) - require.NoError(t, err) - txn.Commit() - } - - job1 := newJob().WithQueued(true) - upsert(jobDb, job1) - assert.Empty(t, jobDb.ReadTxn().GetAllTerminalJobs()) - - // leased - job1 = job1.WithQueued(false).WithNewRun("executor", "nodeId", "nodeName", "pool", 5) - upsert(jobDb, job1) - assert.Empty(t, jobDb.ReadTxn().GetAllTerminalJobs()) - - // finished - job1 = job1.WithSucceeded(true) - upsert(jobDb, job1) - assert.NotEmpty(t, jobDb.ReadTxn().GetAllTerminalJobs()) -} - -func TestJobDb_TerminalJobs_Deleted(t *testing.T) { - jobDb := NewTestJobDb() - job1 := newJob().WithFailed(true) - txn := jobDb.WriteTxn() - - err := txn.Upsert([]*Job{job1}) - require.NoError(t, err) - - expected := []*Job{job1} - actual := txn.GetAllTerminalJobs() - assert.Equal(t, expected, actual) - - err = txn.BatchDelete([]string{job1.Id()}) - require.NoError(t, err) - assert.Empty(t, txn.GetAllTerminalJobs()) -} - func TestJobDb_TestGetUnvalidated(t *testing.T) { jobDb := NewTestJobDb() job1 := newJob().WithValidated(false) diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index 70dd98c804c..1d00dcf6e19 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -285,7 +285,7 @@ func (s *Scheduler) cycle(ctx *armadacontext.Context, updateAll bool, leaderToke }(ctx) // Update job state. ctx.Info("Syncing internal state with database") - updatedJobs, jsts, newJobsSerial, newRunsSerial, err := s.syncState(ctx, false, cycleNumber%10 == 0) + updatedJobs, jsts, newJobsSerial, newRunsSerial, err := s.syncState(ctx, false) if err != nil { return false, err } @@ -438,7 +438,7 @@ func (s *Scheduler) cycle(ctx *armadacontext.Context, updateAll bool, leaderToke // syncState updates jobs in jobDb to match state in postgres and returns all updated jobs along with // the new jobsSerial and runsSerial cursor values that should be applied once the resulting events // have been published successfully. -func (s *Scheduler) syncState(ctx *armadacontext.Context, initial, fullJobGc bool) ([]*jobdb.Job, []jobdb.JobStateTransitions, int64, int64, error) { +func (s *Scheduler) syncState(ctx *armadacontext.Context, initial bool) ([]*jobdb.Job, []jobdb.JobStateTransitions, int64, int64, error) { txn := s.jobDb.WriteTxn() defer txn.Abort() @@ -501,25 +501,12 @@ func (s *Scheduler) syncState(ctx *armadacontext.Context, initial, fullJobGc boo // Delete jobs in a terminal state. idsOfJobsToDelete := make([]string, 0) - deletionCandidates := jobDbJobs - if fullJobGc { - // Occasional full gc so jobs that were not deleted - // earlier as ShortJobPenalty was being applied - // eventually get deleted. - deletionCandidates = txn.GetAll() - } - shortJobCount := 0 - for _, j := range deletionCandidates { - if !j.InTerminalState() { - continue - } - if s.shortJobPenalty.ShouldApplyPenalty(j) { - shortJobCount++ - continue + for _, j := range jobDbJobs { + if j.InTerminalState() { + idsOfJobsToDelete = append(idsOfJobsToDelete, j.Id()) } - idsOfJobsToDelete = append(idsOfJobsToDelete, j.Id()) } - ctx.Logger().Infof("Deleting %d jobs out of %d considered for deletion (%d short jobs, full job gc=%t)", len(idsOfJobsToDelete), len(deletionCandidates), shortJobCount, fullJobGc) + ctx.Logger().Infof("Deleting %d terminal jobs out of %d updated jobs", len(idsOfJobsToDelete), len(jobDbJobs)) if err := txn.BatchDelete(idsOfJobsToDelete); err != nil { return nil, nil, 0, 0, err } @@ -1070,6 +1057,9 @@ func (s *Scheduler) generateUpdateMessagesFromJob(ctx *armadacontext.Context, jo } if !origJob.Equal(job) { + if job.InTerminalState() { + s.shortJobPenalty.ReportFinishedJob(job) + } if err := txn.Upsert([]*jobdb.Job{job}); err != nil { return nil, err } @@ -1128,7 +1118,9 @@ func (s *Scheduler) expireJobsIfNecessary(ctx *armadacontext.Context, txn *jobdb run := job.LatestRun() if run != nil && !job.Queued() && staleExecutors[run.Executor()] { ctx.Warnf("Cancelling job %s as it is running on lost executor %s", job.Id(), run.Executor()) - jobsToUpdate = append(jobsToUpdate, job.WithQueued(false).WithFailed(true).WithUpdatedRun(run.WithFailed(true))) + expiredJob := job.WithQueued(false).WithFailed(true).WithUpdatedRun(run.WithFailed(true)) + s.shortJobPenalty.ReportFinishedJob(expiredJob) + jobsToUpdate = append(jobsToUpdate, expiredJob) leaseExpiredError := &armadaevents.Error{ Terminal: true, @@ -1234,6 +1226,7 @@ func (s *Scheduler) submitCheck(ctx *armadacontext.Context, txn *jobdb.Txn) ([]* } } else { job = job.WithFailed(true).WithQueued(false) + s.shortJobPenalty.ReportFinishedJob(job) jobsToUpdate = append(jobsToUpdate, job) es.Events[0].Event = &armadaevents.EventSequence_Event_JobErrors{ @@ -1275,7 +1268,7 @@ func (s *Scheduler) initialise(ctx *armadacontext.Context) error { case <-ctx.Done(): return nil default: - if _, _, newJobsSerial, newRunsSerial, err := s.syncState(ctx, true, false); err != nil { + if _, _, newJobsSerial, newRunsSerial, err := s.syncState(ctx, true); err != nil { ctx.Logger(). WithStacktrace(err). Error("failed to initialise; trying again in 1 second") diff --git a/internal/scheduler/scheduler_test.go b/internal/scheduler/scheduler_test.go index 5a3368a70a2..ab487ad6bae 100644 --- a/internal/scheduler/scheduler_test.go +++ b/internal/scheduler/scheduler_test.go @@ -2009,7 +2009,7 @@ func TestScheduler_TestSyncInitialState(t *testing.T) { // which must be consistent within tests. sched.jobDb = testfixtures.NewJobDb(testfixtures.TestResourceListFactory) - initialJobs, _, newJobsSerial, newRunsSerial, err := sched.syncState(ctx, true, false) + initialJobs, _, newJobsSerial, newRunsSerial, err := sched.syncState(ctx, true) require.NoError(t, err) sched.jobsSerial = newJobsSerial sched.runsSerial = newRunsSerial @@ -2232,7 +2232,7 @@ func TestScheduler_TestSyncState(t *testing.T) { require.NoError(t, err) txn.Commit() - updatedJobs, _, _, _, err := sched.syncState(ctx, false, false) + updatedJobs, _, _, _, err := sched.syncState(ctx, false) require.NoError(t, err) expectedJobDb := testfixtures.NewJobDbWithJobs(tc.expectedUpdatedJobs) @@ -4048,3 +4048,103 @@ func TestAppendEventSequencesFromPreemptedJobs_NilPreemptingJob(t *testing.T) { assert.Equal(t, preemptedRun.Id(), preemptedEvent.PreemptedRunId) assert.Equal(t, "", preemptedEvent.PreemptingJobId) } + +func TestScheduler_ReportsShortJobPenaltyAtTerminalisationSites(t *testing.T) { + const cutoff = time.Minute + now := time.Now() + runningTime := now.Add(-30 * time.Second) // within cutoff, so the job qualifies for a penalty + + newSchedulerWithPenalty := func(t *testing.T, submitCheckSuccess bool, staleExecutor bool) (*Scheduler, *scheduling.ShortJobPenalty) { + penalty := scheduling.NewShortJobPenalty(map[string]time.Duration{testfixtures.TestPool: cutoff}) + penalty.SetNow(now) + + clusterTimeout := 1 * time.Hour + testClock := clock.NewFakeClock(now) + heartbeat := testClock.Now() + if staleExecutor { + heartbeat = heartbeat.Add(-2 * clusterTimeout) + } + + sched, err := NewScheduler( + testfixtures.NewJobDb(testfixtures.TestResourceListFactory), + &testJobRepository{}, + &testExecutorRepository{updateTimes: map[string]time.Time{"testExecutor": heartbeat}}, + runner.NewSyncSchedulingRunner(&testSchedulingAlgo{}), + leaderelection.NewStandaloneLeaderController(), + &testPublisher{}, + &testSubmitChecker{checkSuccess: submitCheckSuccess}, + &testGangValidator{validateSuccess: true}, + 1*time.Second, + 5*time.Second, + clusterTimeout, + penalty, + maxNumberOfAttempts, + nodeIdLabel, + schedulerMetrics, + pricing.NoopBidPriceProvider{}, + []string{}, + &testQueueCache{}, + ) + require.NoError(t, err) + sched.clock = testClock + return sched, penalty + } + + seed := func(t *testing.T, sched *Scheduler, job *jobdb.Job) { + txn := sched.jobDb.WriteTxn() + require.NoError(t, txn.Upsert([]*jobdb.Job{job})) + txn.Commit() + } + + runCycle := func(t *testing.T, sched *Scheduler, updateAll bool) { + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) + defer cancel() + _, err := sched.cycle(ctx, updateAll, sched.leaderController.GetToken(), false, 1) + require.NoError(t, err) + } + + shortJob := func(validated bool) *jobdb.Job { + job := testfixtures.Test32Cpu256GiJob(testfixtures.TestQueue, testfixtures.PriorityClass2). + WithQueued(false). + WithValidated(validated). + WithNewRun("testExecutor", "test-node", "node", testfixtures.TestPool, 5) + return job.WithUpdatedRun(job.LatestRun().WithRunningTime(&runningTime)) + } + + assertPenalised := func(t *testing.T, penalty *scheduling.ShortJobPenalty, job *jobdb.Job) { + penalties := penalty.Snapshot().GetPenaltiesForPool(testfixtures.TestPool) + require.True(t, penalties[testfixtures.TestQueue].Equal(job.AllResourceRequirements()), + "expected penalty for queue %s to equal the job's resources", testfixtures.TestQueue) + } + + t.Run("succeeded job via generateUpdateMessages", func(t *testing.T) { + sched, penalty := newSchedulerWithPenalty(t, true, false) + job := shortJob(true) + job = job.WithUpdatedRun(job.LatestRun().WithSucceeded(true)) + seed(t, sched, job) + + runCycle(t, sched, true) + + assertPenalised(t, penalty, job) + }) + + t.Run("expired job via expireJobsIfNecessary", func(t *testing.T) { + sched, penalty := newSchedulerWithPenalty(t, true, true) + job := shortJob(true) + seed(t, sched, job) + + runCycle(t, sched, false) + + assertPenalised(t, penalty, job) + }) + + t.Run("rejected job via submitCheck", func(t *testing.T) { + sched, penalty := newSchedulerWithPenalty(t, false, false) + job := shortJob(false) + seed(t, sched, job) + + runCycle(t, sched, false) + + assertPenalised(t, penalty, job) + }) +} diff --git a/internal/scheduler/scheduling/scheduling_algo.go b/internal/scheduler/scheduling/scheduling_algo.go index 130e476e3f1..2392bffff5f 100644 --- a/internal/scheduler/scheduling/scheduling_algo.go +++ b/internal/scheduler/scheduling/scheduling_algo.go @@ -137,6 +137,8 @@ func (l *FairSchedulingAlgo) Schedule( return nil, err } + shortJobPenalty := l.shortJobPenalty.Snapshot() + reconciliationByPool, err := l.reconcilePools(ctx, txn, executors) if err != nil { return nil, err @@ -153,7 +155,7 @@ func (l *FairSchedulingAlgo) Schedule( if reconciliation.Err() != nil { outcome = reconciliation.Outcome() } else { - outcome, schedulingResult, err = l.runPoolSchedulingRound(ctx, pool, txn, executors) + outcome, schedulingResult, err = l.runPoolSchedulingRound(ctx, pool, txn, executors, shortJobPenalty) if err != nil { return nil, err } @@ -207,6 +209,7 @@ func (l *FairSchedulingAlgo) runPoolSchedulingRound( pool configuration.PoolConfig, txn *jobdb.Txn, executors []*schedulerobjects.Executor, + shortJobPenalty *ShortJobPenaltySnapshot, ) (*PoolSchedulingOutcome, *SchedulingResult, error) { select { case <-ctx.Done(): @@ -217,7 +220,7 @@ func (l *FairSchedulingAlgo) runPoolSchedulingRound( // It is important to pass the validated executors here // This is because the validation ensures those nodes are inline with the jobs // If we use a different copy of nodes (possibly more to date copy) it may no longer align with the jobs/runs - fsctx, err := l.newFairSchedulingAlgoContext(ctx, txn, executors, pool) + fsctx, err := l.newFairSchedulingAlgoContext(ctx, txn, executors, pool, shortJobPenalty) if err != nil { return NewPoolSchedulingOutcome(PoolSchedulingTerminationReasonSchedulingDisabled, errors.WithMessagef(err, "failed to create scheduling algo context")), nil, nil } @@ -410,7 +413,7 @@ func markAsFailedReconciliation(clock clock.Clock, job *jobdb.Job) *jobdb.Job { return job } -func (l *FairSchedulingAlgo) newFairSchedulingAlgoContext(ctx *armadacontext.Context, txn *jobdb.Txn, executors []*schedulerobjects.Executor, currentPool configuration.PoolConfig) (*FairSchedulingAlgoContext, error) { +func (l *FairSchedulingAlgo) newFairSchedulingAlgoContext(ctx *armadacontext.Context, txn *jobdb.Txn, executors []*schedulerobjects.Executor, currentPool configuration.PoolConfig, shortJobPenalty *ShortJobPenaltySnapshot) (*FairSchedulingAlgoContext, error) { queues, err := l.queueCache.GetAll(ctx) if err != nil { return nil, err @@ -435,16 +438,12 @@ func (l *FairSchedulingAlgo) newFairSchedulingAlgoContext(ctx *armadacontext.Con // - Jobs active on the nodes of this pool // - These are used to populate the jobdb, calculate demand/fairshare // - This may include nodes from other pools, especially if the nodes pool has changed - // - Terminal jobs of this pool - // - For calculating short job penalty // - Jobs queued against home/away pools relevant to the pool being computed // - This is to calculate demand on both home and away pools leasedJobs := txn.GetAllLeasedJobs() - terminalJobs := txn.GetAllTerminalJobs() queuedJobs := getQueuedJobs(txn, allPools) - allJobs := make([]*jobdb.Job, 0, len(leasedJobs)+len(terminalJobs)+len(queuedJobs)) + allJobs := make([]*jobdb.Job, 0, len(leasedJobs)+len(queuedJobs)) allJobs = append(allJobs, leasedJobs...) - allJobs = append(allJobs, terminalJobs...) allJobs = append(allJobs, queuedJobs...) jobSchedulingInfo, err := l.calculateJobSchedulingInfo(ctx, @@ -455,7 +454,8 @@ func (l *FairSchedulingAlgo) newFairSchedulingAlgoContext(ctx *armadacontext.Con allJobs, currentPool.Name, awayAllocationPools, - allPools) + allPools, + shortJobPenalty) if err != nil { return nil, err } @@ -580,13 +580,13 @@ type jobSchedulingInfo struct { func (l *FairSchedulingAlgo) calculateJobSchedulingInfo(ctx *armadacontext.Context, activeExecutorsSet map[string]bool, queues map[string]*api.Queue, jobs []*jobdb.Job, currentPool string, awayAllocationPools []string, allPools []string, + shortJobPenalty *ShortJobPenaltySnapshot, ) (*jobSchedulingInfo, error) { jobsByExecutorId := make(map[string][]*jobdb.Job) jobsByPool := make(map[string][]*jobdb.Job) demandByQueueAndPriorityClass := make(map[string]map[string]internaltypes.ResourceList) allocatedByQueueAndPriorityClass := make(map[string]map[string]internaltypes.ResourceList) awayAllocatedByQueueAndPriorityClass := make(map[string]map[string]internaltypes.ResourceList) - shortJobPenaltyByQueue := make(map[string]internaltypes.ResourceList) for _, job := range jobs { queue, present := queues[job.Queue()] @@ -595,14 +595,6 @@ func (l *FairSchedulingAlgo) calculateJobSchedulingInfo(ctx *armadacontext.Conte continue } - if l.shortJobPenalty.ShouldApplyPenalty(job) { - jobPool := job.LatestRun().Pool() - jobRequirements := job.AllResourceRequirements() - if jobPool == currentPool { - shortJobPenaltyByQueue[queue.Name] = shortJobPenaltyByQueue[queue.Name].Add(jobRequirements) - } - } - if job.InTerminalState() { continue } @@ -680,6 +672,7 @@ func (l *FairSchedulingAlgo) calculateJobSchedulingInfo(ctx *armadacontext.Conte jobsByExecutorId[executorId] = append(jobsByExecutorId[executorId], job) } + shortJobPenaltyByQueue := shortJobPenalty.GetPenaltiesForPool(currentPool) return &jobSchedulingInfo{ jobsByExecutorId: jobsByExecutorId, jobsByPool: jobsByPool, diff --git a/internal/scheduler/scheduling/short_job_penalty.go b/internal/scheduler/scheduling/short_job_penalty.go index bfd9563e2f7..f9c45d8fa6d 100644 --- a/internal/scheduler/scheduling/short_job_penalty.go +++ b/internal/scheduler/scheduling/short_job_penalty.go @@ -1,21 +1,20 @@ package scheduling import ( + "container/heap" + "maps" "time" + "github.com/armadaproject/armada/internal/scheduler/internaltypes" "github.com/armadaproject/armada/internal/scheduler/jobdb" ) -// Used to penalize short-running jobs by pretending they -// ran for some minimum length when calculating costs. -type ShortJobPenalty struct { - cutoffDurationByPool map[string]time.Duration - now time.Time -} - func NewShortJobPenalty(cutoffDurationByPool map[string]time.Duration) *ShortJobPenalty { return &ShortJobPenalty{ cutoffDurationByPool: cutoffDurationByPool, + penaltyByJobID: map[string]*penaltyEntry{}, + expiry: &entryHeap{}, + sums: map[string]map[string]internaltypes.ResourceList{}, } } @@ -23,10 +22,12 @@ func (sjp *ShortJobPenalty) SetNow(now time.Time) { if sjp == nil { return } + sjp.mu.Lock() + defer sjp.mu.Unlock() sjp.now = now } -func (sjp *ShortJobPenalty) ShouldApplyPenalty(job *jobdb.Job) bool { +func (sjp *ShortJobPenalty) shouldApplyPenalty(job *jobdb.Job) bool { if sjp == nil || sjp.now.IsZero() { return false } @@ -51,3 +52,89 @@ func (sjp *ShortJobPenalty) ShouldApplyPenalty(job *jobdb.Job) bool { return sjp.now.Sub(*jobStart) < sjp.cutoffDurationByPool[jobRun.Pool()] } + +// ReportFinishedJob applies a terminal short job's resources to its (pool, queue) sums once. +// Non-terminal and duplicate jobs are ignored. +func (sjp *ShortJobPenalty) ReportFinishedJob(job *jobdb.Job) { + if sjp == nil { + return + } + sjp.mu.Lock() + defer sjp.mu.Unlock() + sjp.expireUpTo(sjp.now) + + if _, alreadyCounted := sjp.penaltyByJobID[job.Id()]; alreadyCounted { + return + } + if !sjp.shouldApplyPenalty(job) { + return + } + + run := job.LatestRun() + pool := run.Pool() + queue := job.Queue() + resources := job.AllResourceRequirements() + deadline := run.RunningTime().Add(sjp.cutoffDurationByPool[pool]) + + e := &penaltyEntry{ + jobID: job.Id(), + pool: pool, + queue: queue, + resources: resources, + deadline: deadline, + } + sjp.penaltyByJobID[job.Id()] = e + heap.Push(sjp.expiry, e) + sjp.addToSums(pool, queue, resources) +} + +// Snapshot expires entries up to the current now and returns an immutable, +// deep-copied view of the per-(pool,queue) penalty sums. +func (sjp *ShortJobPenalty) Snapshot() *ShortJobPenaltySnapshot { + if sjp == nil { + return &ShortJobPenaltySnapshot{} + } + sjp.mu.Lock() + defer sjp.mu.Unlock() + sjp.expireUpTo(sjp.now) + + sums := make(map[string]map[string]internaltypes.ResourceList, len(sjp.sums)) + for pool, queueSums := range sjp.sums { + inner := make(map[string]internaltypes.ResourceList, len(queueSums)) + maps.Copy(inner, queueSums) + sums[pool] = inner + } + return &ShortJobPenaltySnapshot{sums: sums} +} + +// expireUpTo pops every entry whose deadline is at or before now, +// subtracting its penalty contribution +func (sjp *ShortJobPenalty) expireUpTo(now time.Time) { + for sjp.expiry.Len() > 0 && !sjp.expiry.peek().deadline.After(now) { + e := heap.Pop(sjp.expiry).(*penaltyEntry) + sjp.subtractFromSums(e.pool, e.queue, e.resources) + delete(sjp.penaltyByJobID, e.jobID) + } +} + +func (sjp *ShortJobPenalty) addToSums(pool, queue string, resources internaltypes.ResourceList) { + queueSums, ok := sjp.sums[pool] + if !ok { + queueSums = map[string]internaltypes.ResourceList{} + sjp.sums[pool] = queueSums + } + queueSums[queue] = queueSums[queue].Add(resources) +} + +func (sjp *ShortJobPenalty) subtractFromSums(pool, queue string, resources internaltypes.ResourceList) { + queueSums := sjp.sums[pool] + remaining := queueSums[queue].Subtract(resources) + if remaining.AllZero() { + delete(queueSums, queue) + if len(queueSums) == 0 { + delete(sjp.sums, pool) + } + return + } + queueSums[queue] = remaining +} diff --git a/internal/scheduler/scheduling/short_job_penalty_test.go b/internal/scheduler/scheduling/short_job_penalty_test.go index f995927c016..c37a4dfe7d1 100644 --- a/internal/scheduler/scheduling/short_job_penalty_test.go +++ b/internal/scheduler/scheduling/short_job_penalty_test.go @@ -1,11 +1,13 @@ package scheduling import ( + "sync" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/armadaproject/armada/internal/scheduler/internaltypes" "github.com/armadaproject/armada/internal/scheduler/jobdb" "github.com/armadaproject/armada/internal/scheduler/testfixtures" ) @@ -13,12 +15,43 @@ import ( func TestNilSjpReturnsFalse(t *testing.T) { var nilSjp *ShortJobPenalty = nil job := shortTestJob(time.Now()).WithSucceeded(true) - assert.False(t, nilSjp.ShouldApplyPenalty(job)) + assert.False(t, nilSjp.shouldApplyPenalty(job)) +} + +func TestNilSjpIsSafeToCall(t *testing.T) { + var nilSjp *ShortJobPenalty = nil + job := shortTestJob(time.Now()).WithSucceeded(true) + assert.NotPanics(t, func() { + nilSjp.SetNow(time.Now()) + nilSjp.ReportFinishedJob(job) + }) + assert.Nil(t, nilSjp.Snapshot().GetPenaltiesForPool(testfixtures.TestPool)) } func TestTimeNotSetReturnsFalse(t *testing.T) { job := shortTestJob(time.Now()).WithSucceeded(true) - assert.False(t, makeSut().ShouldApplyPenalty(job)) + assert.False(t, makeSut().shouldApplyPenalty(job)) +} + +func TestJobWithNoRunReturnsFalse(t *testing.T) { + now := time.Now() + sut := makeSut() + sut.SetNow(now) + + job := testfixtures.Test32Cpu256GiJob("q", testfixtures.PriorityClass2).WithSucceeded(true) + assert.False(t, sut.shouldApplyPenalty(job)) +} + +func TestJobWithNoRunningTimeReturnsFalse(t *testing.T) { + now := time.Now() + sut := makeSut() + sut.SetNow(now) + + job := testfixtures.Test32Cpu256GiJob("q", testfixtures.PriorityClass2). + WithNewRun("testExecutor", "test-node", "node", testfixtures.TestPool, 5). + WithSucceeded(true) + assert.Nil(t, job.LatestRun().RunningTime()) + assert.False(t, sut.shouldApplyPenalty(job)) } func TestLongSucceededJobReturnsFalse(t *testing.T) { @@ -27,7 +60,7 @@ func TestLongSucceededJobReturnsFalse(t *testing.T) { sut.SetNow(now) job := longTestJob(now).WithSucceeded(true) - assert.False(t, sut.ShouldApplyPenalty(job)) + assert.False(t, sut.shouldApplyPenalty(job)) } func TestShortRunningJobReturnsFalse(t *testing.T) { @@ -36,7 +69,7 @@ func TestShortRunningJobReturnsFalse(t *testing.T) { sut.SetNow(now) job := shortTestJob(now) - assert.False(t, sut.ShouldApplyPenalty(job)) + assert.False(t, sut.shouldApplyPenalty(job)) } func TestShortSucceededJobReturnsTrue(t *testing.T) { @@ -45,7 +78,25 @@ func TestShortSucceededJobReturnsTrue(t *testing.T) { sut.SetNow(now) job := shortTestJob(now).WithSucceeded(true) - assert.True(t, sut.ShouldApplyPenalty(job)) + assert.True(t, sut.shouldApplyPenalty(job)) +} + +func TestShortCancelledJobReturnsTrue(t *testing.T) { + now := time.Now() + sut := makeSut() + sut.SetNow(now) + + job := shortTestJob(now).WithCancelled(true) + assert.True(t, sut.shouldApplyPenalty(job)) +} + +func TestShortFailedJobReturnsTrue(t *testing.T) { + now := time.Now() + sut := makeSut() + sut.SetNow(now) + + job := shortTestJob(now).WithFailed(true) + assert.True(t, sut.shouldApplyPenalty(job)) } func TestShortPreemptedJobReturnsFalse(t *testing.T) { @@ -55,7 +106,7 @@ func TestShortPreemptedJobReturnsFalse(t *testing.T) { job := shortTestJob(now).WithSucceeded(true) job = job.WithUpdatedRun(job.LatestRun().WithPreempted(true)) - assert.False(t, sut.ShouldApplyPenalty(job)) + assert.False(t, sut.shouldApplyPenalty(job)) } func TestShortJobWithPreemptRequestedReturnsFalse(t *testing.T) { @@ -65,7 +116,7 @@ func TestShortJobWithPreemptRequestedReturnsFalse(t *testing.T) { job := shortTestJob(now).WithSucceeded(true) job = job.WithUpdatedRun(job.LatestRun().WithPreemptRequested(true)) - assert.False(t, sut.ShouldApplyPenalty(job)) + assert.False(t, sut.shouldApplyPenalty(job)) } func TestShortJobWithPreemptedTimeSetReturnsFalse(t *testing.T) { @@ -75,7 +126,7 @@ func TestShortJobWithPreemptedTimeSetReturnsFalse(t *testing.T) { job := shortTestJob(now).WithSucceeded(true) job = job.WithUpdatedRun(job.LatestRun().WithPreemptedTime(&now)) - assert.False(t, sut.ShouldApplyPenalty(job)) + assert.False(t, sut.shouldApplyPenalty(job)) } func makeSut() *ShortJobPenalty { @@ -95,3 +146,182 @@ func testJob(runningTime time.Time) *jobdb.Job { run := job.LatestRun() return job.WithUpdatedRun(run.WithRunningTime(&runningTime)) } + +func TestAccumulatesPerQueue(t *testing.T) { + now := time.Now() + sut := makeSut() + sut.SetNow(now) + + jobA := testJobForQueue("q1", now.Add(-30*time.Second)).WithSucceeded(true) + jobB := testJobForQueue("q1", now.Add(-20*time.Second)).WithSucceeded(true) + jobC := testJobForQueue("q2", now.Add(-20*time.Second)).WithSucceeded(true) + + sut.ReportFinishedJob(jobA) + sut.ReportFinishedJob(jobB) + sut.ReportFinishedJob(jobC) + + penalties := sut.Snapshot().GetPenaltiesForPool(testfixtures.TestPool) + expectedTwoJobs := jobA.AllResourceRequirements().Add(jobB.AllResourceRequirements()) + assert.True(t, penalties["q1"].Equal(expectedTwoJobs)) + assert.True(t, penalties["q2"].Equal(jobC.AllResourceRequirements())) +} + +func TestNonQualifyingJobIsNotCharged(t *testing.T) { + now := time.Now() + sut := makeSut() + sut.SetNow(now) + + longJob := longTestJob(now).WithSucceeded(true) + sut.ReportFinishedJob(longJob) + + penalties := sut.Snapshot().GetPenaltiesForPool(testfixtures.TestPool) + assert.Empty(t, penalties) +} + +func TestDedupSameJobReportedTwice(t *testing.T) { + now := time.Now() + sut := makeSut() + sut.SetNow(now) + + job := testJobForQueue("q1", now.Add(-30*time.Second)).WithSucceeded(true) + sut.ReportFinishedJob(job) + sut.ReportFinishedJob(job) + + penalties := sut.Snapshot().GetPenaltiesForPool(testfixtures.TestPool) + assert.True(t, penalties["q1"].Equal(job.AllResourceRequirements())) +} + +func TestEntryExpiresExactlyAtDeadline(t *testing.T) { + start := time.Now() + sut := makeSut() + job := testJobForQueue("q1", start).WithSucceeded(true) + + reportNow := start.Add(30 * time.Second) + sut.SetNow(reportNow) + sut.ReportFinishedJob(job) + assert.True(t, sut.Snapshot().GetPenaltiesForPool(testfixtures.TestPool)["q1"].Equal(job.AllResourceRequirements())) + + atDeadline := start.Add(time.Minute) + sut.SetNow(atDeadline) + assert.Empty(t, sut.Snapshot().GetPenaltiesForPool(testfixtures.TestPool)) +} + +func TestPartialExpiryLeavesRemainder(t *testing.T) { + start := time.Now() + sut := makeSut() + sut.SetNow(start) + + early := testJobForQueue("q1", start.Add(-50*time.Second)).WithSucceeded(true) + late := testJobForQueue("q1", start.Add(-20*time.Second)).WithSucceeded(true) + sut.ReportFinishedJob(early) + sut.ReportFinishedJob(late) + + sut.SetNow(start.Add(20 * time.Second)) + penalties := sut.Snapshot().GetPenaltiesForPool(testfixtures.TestPool) + assert.True(t, penalties["q1"].Equal(late.AllResourceRequirements())) +} + +func TestPostExpiryReReportNeverReQualifies(t *testing.T) { + start := time.Now() + sut := makeSut() + job := testJobForQueue("q1", start).WithSucceeded(true) + + sut.SetNow(start.Add(10 * time.Second)) + sut.ReportFinishedJob(job) + sut.SetNow(start.Add(2 * time.Minute)) + assert.Empty(t, sut.Snapshot().GetPenaltiesForPool(testfixtures.TestPool)) + sut.ReportFinishedJob(job) + assert.Empty(t, sut.Snapshot().GetPenaltiesForPool(testfixtures.TestPool)) +} + +func TestPerPoolCutoffAndPoolIsolation(t *testing.T) { + now := time.Now() + sut := NewShortJobPenalty(map[string]time.Duration{ + "poolA": time.Minute, + "poolB": time.Hour, + }) + sut.SetNow(now) + + jobA := testJobForPool("q1", "poolA", now.Add(-30*time.Minute)).WithSucceeded(true) + jobB := testJobForPool("q1", "poolB", now.Add(-30*time.Minute)).WithSucceeded(true) + jobC := testJobForPool("q1", "poolC", now.Add(-1*time.Second)).WithSucceeded(true) + + sut.ReportFinishedJob(jobA) + sut.ReportFinishedJob(jobB) + sut.ReportFinishedJob(jobC) + + assert.Empty(t, sut.Snapshot().GetPenaltiesForPool("poolA")) + assert.True(t, sut.Snapshot().GetPenaltiesForPool("poolB")["q1"].Equal(jobB.AllResourceRequirements())) + assert.Empty(t, sut.Snapshot().GetPenaltiesForPool("poolC")) +} + +func TestGetPenaltiesForUnknownPoolIsEmpty(t *testing.T) { + now := time.Now() + sut := makeSut() + sut.SetNow(now) + assert.Empty(t, sut.Snapshot().GetPenaltiesForPool("does-not-exist")) +} + +func TestShortJobPenalty_ConcurrentReportsAreCorrect(t *testing.T) { + now := time.Now() + sut := makeSut() + sut.SetNow(now) + + const numWorkers = 8 + const jobsPerWorker = 500 + runningTime := now.Add(-10 * time.Second) + + batches := make([][]*jobdb.Job, numWorkers) + expected := internaltypes.ResourceList{} + for w := range batches { + batch := make([]*jobdb.Job, jobsPerWorker) + for i := range batch { + job := testJobForQueue("q1", runningTime).WithSucceeded(true) + batch[i] = job + expected = expected.Add(job.AllResourceRequirements()) + } + batches[w] = batch + } + + done := make(chan struct{}) + var readers sync.WaitGroup + readers.Add(1) + go func() { + defer readers.Done() + for { + select { + case <-done: + return + default: + sut.Snapshot().GetPenaltiesForPool(testfixtures.TestPool) + } + } + }() + + var writers sync.WaitGroup + for _, batch := range batches { + writers.Add(1) + go func(batch []*jobdb.Job) { + defer writers.Done() + for _, job := range batch { + sut.ReportFinishedJob(job) + } + }(batch) + } + writers.Wait() + close(done) + readers.Wait() + + penalties := sut.Snapshot().GetPenaltiesForPool(testfixtures.TestPool) + assert.True(t, penalties["q1"].Equal(expected)) +} + +func testJobForQueue(queue string, runningTime time.Time) *jobdb.Job { + return testJobForPool(queue, testfixtures.TestPool, runningTime) +} + +func testJobForPool(queue string, pool string, runningTime time.Time) *jobdb.Job { + job := testfixtures.Test32Cpu256GiJob(queue, testfixtures.PriorityClass2).WithNewRun("testExecutor", "test-node", "node", pool, 5) + run := job.LatestRun() + return job.WithUpdatedRun(run.WithRunningTime(&runningTime)) +} diff --git a/internal/scheduler/scheduling/short_job_penalty_types.go b/internal/scheduler/scheduling/short_job_penalty_types.go new file mode 100644 index 00000000000..025c2413dd3 --- /dev/null +++ b/internal/scheduler/scheduling/short_job_penalty_types.go @@ -0,0 +1,69 @@ +package scheduling + +import ( + "container/heap" + "sync" + "time" + + "github.com/armadaproject/armada/internal/scheduler/internaltypes" +) + +type penaltyEntry struct { + jobID string + pool string + queue string + resources internaltypes.ResourceList + // deadline is runStart + cutoffDurationByPool[pool], fixed at insert time + deadline time.Time +} + +// ShortJobPenalty owns job penalty state keyed by (pool, queue). +type ShortJobPenalty struct { + mu sync.Mutex + cutoffDurationByPool map[string]time.Duration + now time.Time + + penaltyByJobID map[string]*penaltyEntry + expiry *entryHeap + // Derived cache of the per-(pool,queue) running total + sums map[string]map[string]internaltypes.ResourceList +} + +type ShortJobPenaltySnapshot struct { + sums map[string]map[string]internaltypes.ResourceList +} + +func (s *ShortJobPenaltySnapshot) GetPenaltiesForPool(pool string) map[string]internaltypes.ResourceList { + if s == nil { + return nil + } + return s.sums[pool] +} + +// entryHeap is a min-heap of penaltyEntry ordered by deadline. +type entryHeap []*penaltyEntry + +func (h entryHeap) Len() int { return len(h) } +func (h entryHeap) Less(i, j int) bool { return h[i].deadline.Before(h[j].deadline) } +func (h entryHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *entryHeap) Push(x any) { + *h = append(*h, x.(*penaltyEntry)) +} + +func (h *entryHeap) Pop() any { + old := *h + n := len(old) + e := old[n-1] + old[n-1] = nil + *h = old[:n-1] + return e +} + +func (h entryHeap) peek() *penaltyEntry { + return h[0] +} + +var _ heap.Interface = (*entryHeap)(nil)