diff --git a/trainer/mamba_trainer.py b/trainer/mamba_trainer.py index 77708ef..7f5e933 100644 --- a/trainer/mamba_trainer.py +++ b/trainer/mamba_trainer.py @@ -1,7 +1,7 @@ from transformers import Trainer import torch import os - +import json class MambaTrainer(Trainer): def compute_loss(self, model, inputs, return_outputs=False): @@ -20,6 +20,8 @@ def compute_loss(self, model, inputs, return_outputs=False): def save_model(self, output_dir, _internal_call): if not os.path.exists(output_dir): os.makedirs(output_dir) - + with open(output_dir, 'w') as f: + json.dump(self.model.config.__dict__, f) torch.save(self.model.state_dict(), f"{output_dir}/pytorch_model.bin") - self.tokenizer.save_pretrained(output_dir) \ No newline at end of file + self.tokenizer.save_pretrained(output_dir) + \ No newline at end of file