File size: 2,326 Bytes
821c8d6 9355cba 821c8d6 3bdec0d 821c8d6 3bdec0d 821c8d6 3bdec0d 821c8d6 3bdec0d 821c8d6 3bdec0d 821c8d6 3bdec0d 821c8d6 3bdec0d 821c8d6 3bdec0d 821c8d6 3bdec0d 821c8d6 3bdec0d 821c8d6 3bdec0d 821c8d6 3bdec0d 821c8d6 3bdec0d 821c8d6 3bdec0d 821c8d6 3bdec0d 821c8d6 3bdec0d 9355cba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
---
library_name: transformers
datasets:
- kaist-ai/CoT-Collection
---
# Model Card for b1ade-1b
Instruction fine-tuned 1B-parameter model. Pass in:
1. `context: <...>`
2. `question: <...>`
and expect an `answer: <...>`
See the implementation example below (also see https://huggingface.co/spaces/w601sxs/b1ade-1b):
```python
import torch
import transformers
import os, time
import tempfile
from transformers import AutoTokenizer, AutoModelForCausalLM
# Checkpoint identifier shared by the tokenizer and the model.
BASE_MODEL = "w601sxs/b1ade-1b-bf16"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Load the weights in bfloat16; device placement is left to device_map="auto",
# with "offload" as the folder used for any offloaded weights.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    offload_folder="offload",
)

# Put the model in evaluation mode before generating.
model.eval()
from transformers import StoppingCriteria, AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once the newest token is one of the given token ids.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.
    """

    def __init__(self, keywords_ids: list):
        """Store the token ids that should end generation.

        Args:
            keywords_ids: Token ids; generation stops when the most recently
                generated token equals any of them.
        """
        super().__init__()
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence is a stop id.

        Args:
            input_ids: Token ids generated so far; indexed as ``[0][-1]``.
            scores: Unused; accepted to match the stopping-criteria call signature.
            **kwargs: Unused extra arguments.
        """
        # Convert the 0-d tensor to a plain int so the membership test compares
        # integers instead of relying on implicit tensor truthiness.
        last_token_id = int(input_ids[0][-1])
        return last_token_id in self.keywords
# Strings whose first token should end generation: a closing ">" with and
# without surrounding whitespace.
stop_words = ['>', ' >', '> ']

# Keep the first token id of each stop word. Special tokens are disabled so a
# tokenizer that prepends a BOS token does not turn every entry into the BOS id.
stop_ids = [tokenizer.encode(w, add_special_tokens=False)[0] for w in stop_words]

stop_criteria = StoppingCriteriaList([KeywordsStoppingCriteria(keywords_ids=stop_ids)])
def predict(text):
    """Generate a completion for ``text`` and print the answer portion.

    Args:
        text: Prompt string passed to the tokenizer and the model.

    Returns:
        None. The decoded output after the last ``answer:`` marker (and after
        the prompt itself, if it is echoed) is printed to stdout.
    """
    # Follow the model's own device instead of assuming CUDA is available.
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            # Pass the mask explicitly rather than letting generate() infer it.
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=128,
            stopping_criteria=stop_criteria,
        )
    # Only one prompt is submitted, so decode the single returned sequence.
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    out_text = decoded.split("answer:")[-1]
    print(out_text.split(text)[-1])
# Example call: the prompt supplies a `context: <...>` passage followed by a
# `question: <...>` line; predict() prints the generated answer.
predict("context: <The center contact of the bulb typically connects to the medium-power filament, and the ring connects to the low-power filament. Thus, if a 3-way bulb is screwed into a standard light socket that has only a center contact, only the medium-power filament operates. In the case of the 50 W / 100 W / 150 W bulb, putting this bulb in a regular lamp socket will result in it behaving like a normal 100W bulb.>\n question: <Question: Do 3 way light bulbs work in any lamp?>\n")
``` |