eswardivi's picture
Added better inference techniques
7dc3087 verified
raw
history blame
No virus
2.51 kB
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import os
from threading import Thread
import spaces
import time
# Checkpoint served by this app (shared by the model and tokenizer loads below).
MODEL_ID = "google/gemma-1.1-7b-it"

# Hub access token. `.get` returns None when HF_TOKEN is unset, which lets
# huggingface_hub fall back to a locally cached login instead of failing with
# KeyError at import time.
token = os.environ.get("HF_TOKEN")

# Load the weights in 4-bit via bitsandbytes, computing in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    token=token,
)
tok = AutoTokenizer.from_pretrained(MODEL_ID, token=token)

# A bitsandbytes-quantized model is already placed on its device by
# `from_pretrained`, and transformers refuses `.to(...)` on such models, so the
# device is read from the model rather than the model being moved to it.
device = model.device
if device.type == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    print("Using CPU")

# BetterTransformer needs the optional `optimum` package and is unavailable for
# architectures that already use native scaled-dot-product attention; treat the
# conversion as an optional speed-up instead of a hard startup requirement.
try:
    model = model.to_bettertransformer()
except (ImportError, ValueError, NotImplementedError) as err:
    print(f"BetterTransformer not applied ({err}); using the model's native attention.")
@spaces.GPU
def chat(message, history):
    """Stream a model reply to *message*, then append timing statistics.

    Args:
        message: The new user turn.
        history: Earlier turns as passed by ``gr.ChatInterface``; each item is
            indexed as ``[user_text, assistant_text]`` and ``assistant_text``
            may be ``None``.

    Yields:
        The reply accumulated so far after every streamed chunk, and finally
        the complete reply followed by time-to-first-token and tokens/second.
    """
    # Prefix of the statistics appended to each final reply. Gradio keeps that
    # final reply in `history`, so the suffix is cut off again below before the
    # turn is fed back to the model as part of the prompt.
    timing_marker = "\nTime taken to first token: "
    start_time = time.time()

    # Rebuild the full conversation in the role/content form the tokenizer's
    # chat template consumes.
    conversation = []
    for turn in history:
        conversation.append({"role": "user", "content": turn[0]})
        reply = turn[1]
        if reply is not None:
            if isinstance(reply, str):
                reply = reply.split(timing_marker, 1)[0]
            conversation.append({"role": "assistant", "content": reply})
    conversation.append({"role": "user", "content": message})

    prompt = tok.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    model_inputs = tok([prompt], return_tensors="pt").to(device)

    # skip_prompt keeps the echoed input out of the stream; timeout bounds how
    # long each iteration waits for the next chunk.
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    # generate() blocks until done, so run it in a worker thread and consume
    # the streamer on this one.
    generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
    generation_thread.start()

    partial_text = ""
    first_token_time = None
    for new_text in streamer:
        if first_token_time is None:
            first_token_time = time.time() - start_time
        partial_text += new_text
        yield partial_text
    generation_thread.join()

    # Throughput over the whole request, including the time to first token.
    total_time = time.time() - start_time
    tokens = len(tok.tokenize(partial_text))
    tokens_per_second = tokens / total_time if total_time > 0 else 0

    # first_token_time is only set inside the loop; guard the empty-stream case
    # so the ":.2f" format below cannot fail on None.
    if first_token_time is None:
        first_token_text = "n/a"
    else:
        first_token_text = f"{first_token_time:.2f} seconds"
    timing_info = f"{timing_marker}{first_token_text}\nTokens per second: {tokens_per_second:.2f}"
    yield partial_text + timing_info
# Chat UI wrapping the streaming `chat` generator, with one starter prompt.
demo = gr.ChatInterface(
    fn=chat,
    examples=[["Write me a poem about Machine Learning."]],
    title="Chat With LLMS",
)
# Start the Gradio server.
demo.launch()