---
library_name: peft
language:
- ja
pipeline_tag: text-generation
---
# Fine-tuned OpenCALM-7B Adapters for Meeting Summarization
## Description
This repository contains LoRA adapter weights fine-tuned on top of the OpenCALM-7B ([Andonian et al., 2021](https://huggingface.co/cyberagent/open-calm-7b)) model for Japanese meeting summarization.
## Usage
### Load model and tokenizer
Since these LoRA adapters were trained with QLoRA ([Dettmers et al., 2023](https://arxiv.org/abs/2305.14314)), loading the base model in its 4-bit quantized form is recommended for reliable results.
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# 4-bit NF4 quantization with double quantization, matching the QLoRA training setup
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained("cyberagent/open-calm-7b")

# Load the quantized base model, then attach the LoRA adapters
model = AutoModelForCausalLM.from_pretrained(
    "cyberagent/open-calm-7b",
    quantization_config=bnb_config,
    device_map="auto",
)
model = PeftModel.from_pretrained(model, "haih2/open-calm-7b-summarizer-lora")
```
### Generate summary
The prompt provided to the model has two parts:
* the first specifies the length of the summary to be generated,
* the second is the source meeting text to be summarized.
```python
# Prompt = length constraint ("within 50 characters") + source meeting text
prompt = "この段落の要約50字以内生成:次に、私立高校の生徒に対する留学支援についてでございますが、都内の私立高校は、それぞれの学校における教育方針に基づきまして、生徒の留学先として海外の学校と提携するなど、既にさまざまな独自の取り組みを進めております。\\nこうした状況等を踏まえ、私立高校を対象とした留学支援のあり方について、今後検討してまいります。\\n\n"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    tokens = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_k=32,
        top_p=0.9,
        repetition_penalty=1.0,   # default: no repetition penalty
        no_repeat_ngram_size=0,   # default: n-gram repetition allowed
        pad_token_id=tokenizer.pad_token_id,
    )

output = tokenizer.decode(tokens[0], skip_special_tokens=True)
print(output)
```
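Note that `tokens[0]` contains the prompt tokens followed by the generated ones, so `output` echoes the prompt. To keep only the generated summary, you can slice off the input length — a minimal sketch:

```python
# Decode only the newly generated tokens (everything after the prompt)
input_length = inputs["input_ids"].shape[1]
summary = tokenizer.decode(tokens[0][input_length:], skip_special_tokens=True)
print(summary)
```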
## Prompt Format
Any prompt can be used, but it is recommended to include a `length` part and a `source` part, for example:
```
"この段落を{length}に要約しなさい:{source}\n要約:"
```
or
```
"この段落の要約{length}生成:{source}\n"
```
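As an illustration, either template can be filled in with a Python f-string; the `length` and `source` values below are hypothetical placeholders:

```python
length = "50字以内"  # "within 50 characters"
source = "次に、私立高校の生徒に対する留学支援についてでございますが、..."  # meeting text (truncated)

# Second template from above
prompt = f"この段落の要約{length}生成:{source}\n"
```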
## Fine-tuning Details
### Dataset
* [Minutes of congressional meetings](https://github.com/kmr-y/NTCIR14-QALab-PoliInfo-FormalRunDataset/tree/master) provided by QA Lab PoliInfo.
### Fine-tuning procedure
The OpenCALM-7B model was fine-tuned on the dataset above using the QLoRA method with the prompt `この段落の要約{length}生成:{source}\n`. The main hyperparameters are listed below:
| Hyperparameter | Value |
|----------------|----------------:|
| **Optimizer** <br>   beta_1 <br>   beta_2 <br>   weight decay | AdamW <br> 0.9 <br> 0.999 <br> 0.01 |
| **Learning rate** <br>   scheduler type | 2e-5 <br> linear |
| **LoRA** <br>   target modules <br>   r <br>   alpha <br>   dropout | <br> query_key_value, dense <br> 4 <br> 64 <br> 0.05 |
| **QLoRA** <br>   compute dtype <br>   storage dtype <br>   quantization strategy | <br> float16 <br> nf4 <br> double quantization |
| **Sequence length** | 1536 |
| **Batch size** | 4 |
| **Gradient accumulation steps** | 2 |
| **Epochs** | 10 |
| **Warmup steps** | 200 |
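
For reference, a minimal sketch of how these hyperparameters map onto a `peft`/`transformers` training setup; this is an assumption based on the table above, not the exact training script (data loading and the `Trainer` call are omitted):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# QLoRA quantization: NF4 storage dtype, float16 compute dtype, double quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    "cyberagent/open-calm-7b",
    quantization_config=bnb_config,
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)

# LoRA on the attention (query_key_value) and dense projections, r=4, alpha=64
lora_config = LoraConfig(
    r=4,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=["query_key_value", "dense"],
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# AdamW (beta_1=0.9, beta_2=0.999 are the defaults), linear schedule, 200 warmup steps;
# inputs are assumed to be truncated/padded to 1536 tokens at tokenization time (not shown)
training_args = TrainingArguments(
    output_dir="open-calm-7b-summarizer-lora",  # hypothetical output directory
    optim="adamw_torch",
    weight_decay=0.01,
    learning_rate=2e-5,
    lr_scheduler_type="linear",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    num_train_epochs=10,
    warmup_steps=200,
    fp16=True,
)
```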
## Evaluation
### Testing data & Metric
We evaluated the model on two sets: one for *multi-topic* summarization and the other for *single-topic* summarization. ROUGE-L (F1-based), computed over text tokenized with the Japanese [MeCab tokenizer](https://pypi.org/project/mecab-python3/), was used as the evaluation metric.
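A minimal sketch of this metric — LCS-based ROUGE-L F1 over MeCab-tokenized text — is shown below; it assumes `mecab-python3` with a dictionary such as `unidic-lite` is installed and is not the exact evaluation script:

```python
import MeCab

tagger = MeCab.Tagger("-Owakati")  # wakati mode: space-separated surface tokens

def tokenize(text: str) -> list[str]:
    return tagger.parse(text).strip().split()

def lcs_length(a: list[str], b: list[str]) -> int:
    # Longest common subsequence via dynamic programming
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            dp[i + 1][j + 1] = dp[i][j] + 1 if x == y else max(dp[i][j + 1], dp[i + 1][j])
    return dp[-1][-1]

def rouge_l_f1(reference: str, candidate: str) -> float:
    ref, cand = tokenize(reference), tokenize(candidate)
    lcs = lcs_length(ref, cand)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(cand), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)
```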
### Results
| Solution/Model | ROUGE-L <br> (multi-topic) | ROUGE-L <br> (single-topic) |
|----------------|:--------------------------:|:---------------------------:|
|1st place solution* |34.12 |**34.44**|
|2nd place solution* |32.79 |33.65 |
|*OpenCALM-7B (QLoRA)*|***36.75***|*33.31* |
*\* These scores are extracted from this [leaderboard](https://github.com/PoliInfo/PoliInfo.github.io/blob/master/FormalRunResult.md) for the summarization task.*