---
library_name: transformers
datasets:
- kaist-ai/CoT-Collection
---

# Model Card for b1ade-1b


Instruction fine-tuned 1B-parameter model; pass in:

1. `context: <...>`
2. `question: <...>`

and expect an `answer: <...>` in return; a minimal prompt sketch is shown below.
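
For illustration, a single prompt/completion pair looks like the following (the context and question here are made-up placeholders, and the exact wording of the model's completion may vary):

```
# Illustrative prompt; the model is expected to continue with "answer: <...>"
prompt = (
    "context: <Water boils at 100 degrees Celsius at sea level.>\n"
    " question: <Question: At what temperature does water boil at sea level?>\n"
)
# Expected continuation, roughly: "answer: <100 degrees Celsius>"
```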

See the implementation example below (also see https://huggingface.co/spaces/w601sxs/b1ade-1b):

```
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    StoppingCriteria,
    StoppingCriteriaList,
)

BASE_MODEL = "w601sxs/b1ade-1b-bf16"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Load in bfloat16 and let accelerate place the weights; layers that do not fit
# on the GPU are offloaded to the "offload" folder.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    offload_folder="offload",
)
model.eval()


class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation as soon as the last generated token is one of the keyword ids."""

    def __init__(self, keywords_ids: list):
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return input_ids[0][-1] in self.keywords


# Answers are emitted as "answer: <...>", so stop at the first '>' token.
stop_words = ['>', ' >', '> ']
stop_ids = [tokenizer.encode(w)[0] for w in stop_words]
stop_criteria = StoppingCriteriaList([KeywordsStoppingCriteria(keywords_ids=stop_ids)])


def predict(text):
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            max_new_tokens=128,
            stopping_criteria=stop_criteria,
        )
    # Decode, keep only the part after "answer:", and drop any echoed prompt.
    out_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].split("answer:")[-1]
    answer = out_text.split(text)[-1]
    print(answer)
    return answer


predict("context: <The center contact of the bulb typically connects to the medium-power filament, and the ring connects to the low-power filament. Thus, if a 3-way bulb is screwed into a standard light socket that has only a center contact, only the medium-power filament operates. In the case of the 50 W / 100 W / 150 W bulb, putting this bulb in a regular lamp socket will result in it behaving like a normal 100W bulb.>\n question: <Question: Do 3 way light bulbs work in any lamp?>\n")
```
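
Because answers are wrapped as `answer: <...>`, the keyword stopping criterion ends generation at the first `>` token; if you want longer, free-form continuations, you can drop `stopping_criteria` or raise `max_new_tokens`.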