b1ade-1b-bf16 / README.md
w601sxs's picture
Update README.md
9355cba verified
metadata
library_name: transformers
datasets:
  - kaist-ai/CoT-Collection

Model Card for b1ade-1b

Instruction-fine-tuned 1B-parameter model. Pass in a prompt of the form:

  1. context: <...>
  2. question: <...>

and expect an output of the form `answer: <...>`.

See the implementation example below (also see the demo Space at https://huggingface.co/spaces/w601sxs/b1ade-1b):

import torch
import transformers
import os, time
import tempfile
from transformers import AutoTokenizer, AutoModelForCausalLM


BASE_MODEL = "w601sxs/b1ade-1b-bf16"

# Tokenizer and bf16 weights both come from the same Hub checkpoint id.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# device_map="auto" hands device placement to transformers; offload_folder is
# the local directory used for any weights that end up offloaded.
# nn.Module.eval() returns the module itself, so the switch to inference mode
# is chained directly onto the load.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    offload_folder="offload",
).eval()

from transformers import StoppingCriteria, AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once the newest token of the first sequence is a keyword id.

    Args:
        keywords_ids: token ids that end generation as soon as one is produced.
    """

    def __init__(self, keywords_ids: list):
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only row 0 of the batch is inspected; `scores` and kwargs are unused.
        newest_token = input_ids[0, -1]
        return newest_token in self.keywords


# Generation halts as soon as the model emits a ">" token, i.e. the closing
# bracket of the `answer: <...>` output format. Space-padded variants are
# listed because they may tokenize to different ids.
stop_words = ['>', ' >', '> ']
# add_special_tokens=False prevents a tokenizer-inserted prefix token (e.g. a
# BOS id) from being taken as index 0 instead of the stop word's own id.
# NOTE(review): only the first id of each encoding is kept, so this assumes each
# stop word encodes to a single token for this tokenizer -- confirm.
stop_ids = [tokenizer.encode(w, add_special_tokens=False)[0] for w in stop_words]
stop_criteria = StoppingCriteriaList([KeywordsStoppingCriteria(keywords_ids=stop_ids)])

def predict(text, *, max_new_tokens=128):
    """Generate an answer for a ``context: <...>\\n question: <...>\\n`` prompt.

    The answer is printed and also returned.

    Args:
        text: prompt string in the context/question format shown above.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        The generated continuation after its last ``answer:`` marker, or the
        whole continuation if the model did not emit that marker.
    """
    # Move inputs to wherever the model was placed rather than assuming CUDA:
    # with device_map="auto" the model may live on CPU when no GPU exists.
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            # Pass the mask explicitly so generate() does not have to infer it.
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=max_new_tokens,
            stopping_criteria=stop_criteria,
        )
    # Decode only the newly generated tokens so the prompt can never leak into
    # the answer (string-splitting on the prompt breaks when decoding does not
    # reproduce it exactly).
    completion = tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens=True)
    answer = completion.split("answer:")[-1]
    print(answer)
    return answer



# Example run: a passage about 3-way light bulbs as the context, followed by a
# question about it. Adjacent literals are joined into one prompt string.
predict(
    "context: <The center contact of the bulb typically connects to the "
    "medium-power filament, and the ring connects to the low-power filament. "
    "Thus, if a 3-way bulb is screwed into a standard light socket that has "
    "only a center contact, only the medium-power filament operates. "
    "In the case of the 50 W / 100 W / 150 W bulb, putting this bulb in a "
    "regular lamp socket will result in it behaving like a normal 100W bulb.>\n"
    " question: <Question: Do 3 way light bulbs work in any lamp?>\n"
)