Update trainer.py

Fix LoRA trainer bug
This commit is contained in:
Andy
2025-05-15 23:08:57 +08:00
committed by GitHub
parent 6c14becf98
commit db5a9fea4e

View File

@@ -76,6 +76,7 @@ class Pipeline(LightningModule):
except ImportError:
raise ImportError("Please install peft library to use LoRA training")
with open(lora_config_path, encoding="utf-8") as f:
import json
lora_config = json.load(f)
lora_config = LoraConfig(**lora_config)
transformers.add_adapter(adapter_config=lora_config, adapter_name=adapter_name)
@@ -825,6 +826,7 @@ def main(args):
dataset_path=args.dataset_path,
checkpoint_dir=args.checkpoint_dir,
adapter_name=args.exp_name,
lora_config_path=args.lora_config_path
)
checkpoint_callback = ModelCheckpoint(
monitor=None,