---
license: apache-2.0
tags:
- generated_from_trainer
datasets:
- cnn_dailymail
model-index:
- name: QLoRA-Flan-T5-Small
  results: []
---
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->
# QLoRA-Flan-T5-Small
This model is a fine-tuned version of [google/flan-t5-small](https://huggingface.co/google/flan-t5-small) on the cnn_dailymail dataset.
## Model description
More information needed
## Intended uses & limitations
More information needed
## Training and evaluation data
More information needed
## How to use the model
1. Loading the model
```python
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the PEFT config, which records the base checkpoint the adapter was trained on
peft_model_id = "emonty777/QLoRA-Flan-T5-Small"
config = PeftConfig.from_pretrained(peft_model_id)

# The tokenizer is the same for CPU and GPU
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

if torch.cuda.is_available():
    # GPU: load the base model in 8-bit on device 0 (needs bitsandbytes and accelerate)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        config.base_model_name_or_path, load_in_8bit=True, device_map={"": 0}
    )
    # Attach the LoRA adapter weights on the same device
    model = PeftModel.from_pretrained(model, peft_model_id, device_map={"": 0})
else:
    # CPU: load the base model in full precision, then attach the LoRA adapter
    model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
    model = PeftModel.from_pretrained(model, peft_model_id)

model.eval()
```
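Once `model` and `tokenizer` are loaded as above, a summary can be generated roughly as follows. This is a minimal sketch, not the author's inference script: the helper name `summarize`, the `"summarize: "` prompt prefix, and the token limits are assumptions and may not match the preprocessing used during fine-tuning.

```python
def summarize(model, tokenizer, article, max_input_tokens=512, max_new_tokens=64):
    """Generate a summary for one article (hypothetical helper)."""
    # Assumption: a T5-style task prefix; the prompt used in training is not documented
    prompt = "summarize: " + article
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_tokens,
    )
    # Move the input tensors to the device the model lives on
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
    # Pass everything by keyword, which is what the PEFT wrapper's generate expects
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```

Calling `summarize(model, tokenizer, article_text)` returns the decoded summary as a string. Generation settings such as beam search can be passed through `model.generate` in the same way if needed.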
## Training procedure
### Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 3e-05
- train_batch_size: 8
- eval_batch_size: 8
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- num_epochs: 4
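
To illustrate what the `linear` scheduler does with the learning rate of 3e-05, the sketch below computes the rate at a given optimizer step. It assumes zero warmup steps (none are listed above), and `total_steps` is a hypothetical value, since the number of training steps is not reported here.

```python
def linear_lr(step, total_steps, base_lr=3e-5, warmup_steps=0):
    """Learning rate at `step` under linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        # Ramp up linearly from 0 to base_lr during warmup
        return base_lr * step / max(1, warmup_steps)
    # Decay linearly from base_lr to 0 over the remaining steps
    remaining = max(0.0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


# Example with a hypothetical run of 1000 steps
for step in (0, 500, 1000):
    print(step, linear_lr(step, total_steps=1000))
```

With no warmup, the rate starts at the full 3e-05, is halved at the midpoint, and reaches zero at the final step.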
### Training results
Evaluated on the full CNN/DailyMail test set:

| Metric | Recall | Precision | F-score |
|---|---|---|---|
| rouge-1 | 0.3484396421841008 | 0.37845620239152916 | 0.3484265780526604 |
| rouge-2 | 0.1472418310455188 | 0.15418276080118026 | 0.14343059577230782 |
| rouge-l | 0.3280567401095563 | 0.3565504002457199 | 0.32809541498574013 |
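
For readers unfamiliar with how ROUGE recall, precision, and F-score relate, the sketch below computes ROUGE-N for a single candidate/reference pair. It is a simplified illustration using lowercased whitespace tokens, not the scorer that produced the numbers above; the function name `rouge_n` is hypothetical.

```python
from collections import Counter


def rouge_n(candidate, reference, n=1):
    """Simplified ROUGE-N for one summary pair, using whitespace tokens."""
    def ngrams(text):
        tokens = text.lower().split()
        return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    cand, ref = ngrams(candidate), ngrams(reference)
    # Clipped overlap: each n-gram counts at most as often as it appears in both texts
    overlap = sum((cand & ref).values())
    recall = overlap / sum(ref.values()) if ref else 0.0
    precision = overlap / sum(cand.values()) if cand else 0.0
    total = precision + recall
    f_score = 2 * precision * recall / total if total else 0.0
    return {"r": recall, "p": precision, "f": f_score}


scores = rouge_n("the cat sat on the mat", "the cat lay on the mat")
print(scores)
```

Note that the reported F-scores are lower than the harmonic mean of the reported precision and recall, which would be consistent with the scores being computed per article and then averaged over the test set.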
### Framework versions
- Transformers 4.27.1
- PyTorch 2.0.1+cu118
- Datasets 2.9.0
- Tokenizers 0.13.3