eswardivi's picture
Update app.py
7115ad7 verified
raw
history blame
3.3 kB
import gradio as gr
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
BitsAndBytesConfig,
)
import os
from threading import Thread
import spaces
import time
# Hugging Face access token; a missing HF_TOKEN variable raises KeyError here.
token = os.environ["HF_TOKEN"]

# Single source of truth for the checkpoint used by both the model and the
# tokenizer below.
_MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

# Load the weights in 4-bit, using float16 as the compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_ID,
    quantization_config=quantization_config,
    token=token,
)
tok = AutoTokenizer.from_pretrained(_MODEL_ID, token=token)

# Generation stops on either the tokenizer's EOS id or the "<|eot_id|>" token.
terminators = [tok.eos_token_id, tok.convert_tokens_to_ids("<|eot_id|>")]

# Pick the device that tokenized prompts are moved to before generation.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print("Using CPU")
else:
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")

# NOTE: the model itself is intentionally not moved with `model.to(device)`;
# the original author disabled that call with the remark "Dispatch Errors".
# NOTE(review): presumably requests a GPU allocation of up to 150 s per call —
# confirm against the `spaces` package documentation.
@spaces.GPU(duration=150)
def chat(message, history, temperature, do_sample, max_tokens):
    """Stream a model reply to ``message`` and finish with timing statistics.

    This is a generator: it yields the reply text accumulated so far each time
    the streamer produces a new chunk, and finally yields the full reply with
    a timing summary appended.

    Args:
        message: The new user message.
        history: Previous turns; each item is indexed as ``item[0]`` (user
            text) and ``item[1]`` (assistant text, or ``None`` when absent).
        temperature: Sampling temperature. A value of 0 forces greedy decoding.
        do_sample: Whether to sample. When false, greedy decoding is used
            regardless of ``temperature``.
        max_tokens: Upper bound on the number of newly generated tokens.

    Yields:
        The partial reply after every streamed chunk, then the complete reply
        followed by the time to first token and tokens-per-second figures.
    """
    start_time = time.time()

    # Rebuild the conversation as role/content dicts for the chat template.
    # (Named `conversation` so it does not shadow this function's own name.)
    conversation = []
    for item in history:
        conversation.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            conversation.append({"role": "assistant", "content": item[1]})
    conversation.append({"role": "user", "content": message})

    prompt = tok.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tok([prompt], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    # Honour the "Sampling" checkbox; a temperature of 0 cannot be sampled
    # from, so it also selects greedy decoding.
    use_sampling = bool(do_sample) and temperature > 0
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        do_sample=use_sampling,
        eos_token_id=terminators,
    )
    if use_sampling:
        # Only pass the temperature when it is actually used by generation.
        generate_kwargs["temperature"] = temperature

    # Generation runs in a background thread so the streamer can be consumed
    # here while tokens are still being produced.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    partial_text = ""
    first_token_time = None
    for new_text in streamer:
        if first_token_time is None:
            first_token_time = time.time() - start_time
        partial_text += new_text
        yield partial_text
    # The streamer is exhausted, so generation is finishing; reap the thread.
    worker.join()

    total_time = time.time() - start_time
    tokens = len(tok.tokenize(partial_text))
    tokens_per_second = tokens / total_time if total_time > 0 else 0

    # The streamer may finish without yielding anything; avoid formatting None.
    if first_token_time is None:
        first_token_text = "n/a"
    else:
        first_token_text = f"{first_token_time:.2f} seconds"
    timing_info = (
        f"\n\nTime taken to first token: {first_token_text}"
        f"\nTokens per second: {tokens_per_second:.2f}"
    )
    yield partial_text + timing_info
# Collapsed accordion that holds the generation parameters.
parameters_accordion = gr.Accordion(
    label="⚙️ Parameters", open=False, render=False
)

# The extra inputs are passed to `chat` after (message, history), in this
# order: temperature, do_sample, max_tokens.
temperature_slider = gr.Slider(
    minimum=0,
    maximum=1,
    step=0.1,
    value=0.9,
    label="Temperature",
    render=False,
)
sampling_checkbox = gr.Checkbox(label="Sampling", value=True)
max_tokens_slider = gr.Slider(
    minimum=128,
    maximum=4096,
    step=1,
    value=512,
    label="Max new tokens",
    render=False,
)

# Chat UI wired to the streaming `chat` generator.
demo = gr.ChatInterface(
    fn=chat,
    examples=[["Write me a poem about Machine Learning."]],
    additional_inputs_accordion=parameters_accordion,
    additional_inputs=[
        temperature_slider,
        sampling_checkbox,
        max_tokens_slider,
    ],
    stop_btn="Stop Generation",
    title="Chat With LLMs",
    description="Now Running [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) in 4bit",
)
demo.launch()