File size: 3,606 Bytes
72da74c
e690c95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79fe4a6
 
 
9615074
 
79fe4a6
9615074
53c910a
79fe4a6
9615074
79fe4a6
 
517ddd6
79fe4a6
53c910a
79fe4a6
 
 
 
 
 
 
517ddd6
9615074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79fe4a6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
---
library_name: peft
---
## Training procedure


The following `bitsandbytes` quantization config was used during training:
- quant_method: bitsandbytes
- load_in_8bit: False
- load_in_4bit: True
- llm_int8_threshold: 6.0
- llm_int8_skip_modules: None
- llm_int8_enable_fp32_cpu_offload: False
- llm_int8_has_fp16_weight: False
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: False
- bnb_4bit_compute_dtype: float16
### Framework versions


- PEFT 0.5.0

## Getting started
```python
import os

import torch
from datasets import Dataset
from huggingface_hub import login
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, AddedToken

# Load the base model, the tokenizer and the PEFT adapter.
BASE_MODEL_ID = "meta-llama/Llama-2-7b-hf"
ADAPTER_ID = "ChangeIsKey/llama-7b-lexical-substitution"

# Llama-2 is a gated model, so an access token is required. Read it from the
# HF_TOKEN environment variable instead of hard-coding it in the script; if
# the variable is unset, `login` asks for the token interactively.
login(token=os.environ.get("HF_TOKEN"))

base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, device_map='auto')

# Llama-2's tokenizer ships with transformers, so `trust_remote_code` is not
# needed; leaving it off avoids executing arbitrary code from the Hub repo.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=False)
# Register the special marker tokens; <|answer|> and <|end|> delimit the
# model's answer in the prompt and in the generated text.
tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("<|s|>"), AddedToken("<|answer|>"), AddedToken("<|end|>")]})
# Batched inference needs a pad token; Llama-2 does not define one.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Generation continues from the end of the prompt, so pad on the left.
tokenizer.padding_side = 'left'
# Grow the embedding matrix to cover the added tokens before attaching the
# adapter (presumably matching the vocabulary the adapter was trained with).
base_model.resize_token_embeddings(len(tokenizer))

model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
model.eval()

# Helpers to build prompts and run inference with the model
def formatting_func(records):
    """Build one generation prompt per record.

    The target span ``example[start:end]`` is wrapped in ``**...**``. The
    prompt ends with the ``<|answer|>`` marker, after which the model writes
    its substitutes.

    Args:
        records: A column-indexable batch: anything where
            ``records['example']``, ``records['start']`` and
            ``records['end']`` return equal-length sequences, e.g. a
            ``datasets.Dataset`` or a dict of lists.

    Returns:
        A list of prompt strings, one per record.
    """
    text_batch = []
    # Iterate the columns in lockstep instead of materializing rows by index.
    for example, start, end in zip(records['example'], records['start'], records['end']):
        target = f'**{example[start:end]}**'
        # The spaces around the target are kept exactly as in the original
        # prompt format (presumably the fine-tuning format), even though they
        # double up with the sentence's own whitespace.
        input_text = f'{example[:start]} {target} {example[end:]}'
        text_batch.append(f"{input_text}<|answer|>")
    return text_batch

def tokenization(dataset, device="cuda", max_length=512):
    """Tokenize a batch of records into padded PyTorch model inputs.

    Args:
        dataset: Records accepted by ``formatting_func`` (``example``,
            ``start`` and ``end`` columns).
        device: Device the returned tensors are moved to. Defaults to
            ``"cuda"``; pass e.g. ``model.device`` when the model lives
            elsewhere (``device_map='auto'`` may place it on CPU).
        max_length: Prompts longer than this many tokens are truncated.

    Returns:
        The tokenizer's batch encoding as PyTorch tensors on ``device``.

    NOTE(review): truncation drops tokens from the tokenizer's
    ``truncation_side`` (right by default), which would cut off the trailing
    ``<|answer|>`` marker of an over-long prompt.
    """
    return tokenizer(formatting_func(dataset),
                     truncation=True,
                     max_length=max_length,
                     padding=True,
                     return_tensors="pt").to(device)


def parse_substitutes(decoded):
    """Extract the comma-joined substitutes from one decoded generation.

    Everything after the first ``<|answer|>`` marker is kept (any later
    markers are replaced by a space). That text is split on ``<|end|>``, and
    the last piece is discarded because it is either trailing text after the
    final ``<|end|>`` or an unterminated (truncated) substitute. Blank pieces
    are dropped.
    """
    answer = " ".join(decoded.split('<|answer|>')[1:])
    pieces = answer.split('<|end|>')[:-1]
    return ", ".join(s.strip() for s in pieces if s.strip())


# A toy example: the same target word ("jam") used in two different senses.
examples = [{'example': 'The traffic jam on the highway made everyone late for work.', 'start': 12, 'end': 15},
            {'example': 'I spread a generous layer of strawberry jam on my toast this morning', 'start': 40, 'end': 43}]
dataset = Dataset.from_list(examples)


batch_size = 32
output = []

with torch.no_grad():
    for i in range(0, len(dataset), batch_size):
        batch = dataset.select(range(i, min(len(dataset), i + batch_size)))
        model_input = tokenization(batch)

        # A tiny temperature with sampling enabled makes decoding near-greedy.
        output_ids = model.generate(**model_input,
                                    do_sample=True,
                                    num_return_sequences=1,
                                    max_new_tokens=30,
                                    temperature=0.00001,
                                    repetition_penalty=1/0.85,
                                    top_k=40,
                                    top_p=0.1,
                                    # Pad finished sequences with the same token the
                                    # tokenizer used for the prompts, rather than
                                    # whatever default the generation config supplies.
                                    pad_token_id=tokenizer.pad_token_id)

        # Keep special tokens: <|answer|> and <|end|> delimit the substitutes.
        answers = tokenizer.batch_decode(output_ids, skip_special_tokens=False)
        output.extend(parse_substitutes(answer) for answer in answers)

# Attach the substitutes to the dataset and print them next to their targets.
dataset = dataset.add_column('substitutes', output)
for row in dataset:
    target = row['example'][row['start']:row['end']]
    print(f"Target: {target}\nExample: {row['example']}\nSubstitutes: {row['substitutes']}\n")
```