GenAI_project / models /lora_model.py
jaothan's picture
Upload 24 files
fa64206 verified
raw
history blame
387 Bytes
from transformers import BertForSequenceClassification, LoRAConfig
def get_lora_model(config):
    """Build a BERT sequence classifier with the 'imdb_lora' LoRA adapter attached.

    Args:
        config: Nested mapping; ``config['model']['lora']`` must provide the
            keys ``'r'`` and ``'alpha'``, which are passed to ``LoRAConfig``.

    Returns:
        The ``BertForSequenceClassification`` instance loaded from
        ``'bert-base-uncased'`` after ``add_lora`` and ``train_lora`` have
        been called on it with the name ``'imdb_lora'``.

    Raises:
        KeyError: If ``config`` lacks ``'model'``, ``'lora'``, ``'r'`` or
            ``'alpha'``.
    """
    # NOTE(review): `LoRAConfig`, `add_lora` and `train_lora` do not appear to
    # be part of the stock Hugging Face `transformers` API -- confirm which
    # library or fork is expected to provide them.

    # Read the hyperparameters before touching the pretrained weights so that
    # a malformed config fails immediately instead of after an expensive load.
    lora_settings = config['model']['lora']
    lora_config = LoRAConfig(
        r=lora_settings['r'],
        alpha=lora_settings['alpha'],
    )

    model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
    model.add_lora('imdb_lora', config=lora_config)
    # presumably switches the model to train only the adapter weights -- TODO confirm
    model.train_lora('imdb_lora')
    return model