camel-5b-hf / README.md
kiranr's picture
Update README.md
c59532d
|
raw
history blame
No virus
1.6 kB
metadata
license: apache-2.0

Usage:

import os 
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face access token, read from the environment
# (set it in the terminal with `export HF_TOKEN=hf_***`). When the variable is
# unset, `True` is passed instead, which makes `from_pretrained` fall back to
# the token stored by `huggingface-cli login`.
auth_token = os.environ.get("HF_TOKEN", True)

model_name = "Writer/camel-5b"

# Hub options shared by the tokenizer and model loaders.
_hub_kwargs = {"use_auth_token": auth_token}

tokenizer = AutoTokenizer.from_pretrained(model_name, **_hub_kwargs)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",          # let transformers/accelerate place the weights
    torch_dtype=torch.float16,  # load weights in half precision
    **_hub_kwargs,
)


# The task the model is asked to perform.
instruction = "Describe a futuristic device that revolutionizes space travel."

# Optional extra context for the instruction. Leave empty to use the
# "prompt_no_input" template. A dedicated name is used so the template
# selection below does not accidentally test Python's `input` builtin
# (which is always truthy).
input_text = ""

# Prompt templates; `{instruction}` and `{input}` are filled in with
# `str.format`.
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        # NOTE(review): unlike "prompt_no_input", this sentence has no
        # trailing period; confirm which form the model was trained on.
        "Write a response that appropriately completes the request\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}


def build_prompt(instruction: str, input_text: str = "") -> str:
    """Return the model prompt for *instruction*.

    Uses the "prompt_input" template when *input_text* is non-empty, and the
    "prompt_no_input" template otherwise.
    """
    if input_text:
        return PROMPT_DICT["prompt_input"].format(
            instruction=instruction, input=input_text
        )
    return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)


text = build_prompt(instruction, input_text)

# Tokenize the prompt and move it to the model's device. With
# device_map="auto" the model is not guaranteed to live on "cuda" (e.g. on a
# CPU-only machine), so follow the model rather than hard-coding a device.
model_inputs = tokenizer(text, return_tensors="pt").to(model.device)

# max_new_tokens bounds only the generated continuation. max_length would also
# count the prompt tokens, leaving little (or, for long prompts, no) room for
# the response.
output_ids = model.generate(
    **model_inputs,
    max_new_tokens=100,
)

# Full decoded text (prompt + response) for the single prompt in the batch.
output_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]

# A causal LM's generate() output starts with the prompt tokens. Decode only
# the tokens that follow the prompt instead of splitting on "### Response:".
# Splitting raises IndexError when the marker is missing, and truncates the
# answer when the model emits the marker again.
prompt_length = model_inputs["input_ids"].shape[1]
clean_output = tokenizer.decode(
    output_ids[0][prompt_length:], skip_special_tokens=True
).strip()

print(clean_output)