diff --git a/tencentpretrain/trainer.py b/tencentpretrain/trainer.py index ba0d46a..f5e5dd6 100755 --- a/tencentpretrain/trainer.py +++ b/tencentpretrain/trainer.py @@ -245,6 +245,16 @@ def train(self, args, local_rank, global_rank, loader, model, optimizer, schedul self.current_step += 1 + if self.total_steps % self.save_checkpoint_steps != 0: + if args.deepspeed: + if args.use_lora: + if global_rank == 0: + save_model(model, self.output_model_path + "-" + str(self.total_steps), args.use_lora) + else: + model.save_checkpoint(self.output_model_path, str(self.total_steps)) + else: + save_model(model, self.output_model_path + "-" + str(self.total_steps), args.use_lora) + class MlmTrainer(Trainer): def __init__(self, args):