# eswardivi's picture
# Update app.py
# 5e407f5 verified
import gradio as gr
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
BitsAndBytesConfig,
)
import os
from threading import Thread
import spaces
import time
# Access token for the Hugging Face Hub; a missing HF_TOKEN raises KeyError here.
token = os.environ["HF_TOKEN"]

# Single place for the checkpoint name shared by the model and the tokenizer.
MODEL_ID = "KissanAI/llama3-8b-dhenu-0.1-sft-16bit"

# Load the weights in 4-bit, computing in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    token=token,
)
tok = AutoTokenizer.from_pretrained(MODEL_ID, token=token)

# Token ids passed to generate() as end-of-sequence markers: the tokenizer's
# own EOS token plus the "<|eot_id|>" token.
terminators = [tok.eos_token_id, tok.convert_tokens_to_ids("<|eot_id|>")]

# Pick the device that the tokenized inputs are moved to.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    print("Using CPU")

# NOTE: the model is deliberately not moved with `model.to(device)`; the
# original author recorded that doing so produced "Dispatch Errors".
@spaces.GPU()
def chat(message, history, temperature, do_sample, max_tokens):
    """Stream a single-turn answer to ``message`` and append timing statistics.

    The message is inserted into a fixed agricultural-assistant prompt,
    tokenized, and passed to ``model.generate`` on a background thread. Text
    is read back through a ``TextIteratorStreamer`` and yielded incrementally.

    Args:
        message: The user's question.
        history: Previous chat turns supplied by the chat UI. Currently
            unused: every request is answered from ``message`` alone.
        temperature: Sampling temperature. Only forwarded to ``generate``
            when sampling is active.
        do_sample: Whether to sample instead of decoding greedily. Sampling
            is only active when this is truthy and ``temperature`` is
            greater than zero.
        max_tokens: Upper bound on the number of newly generated tokens.

    Yields:
        The accumulated response text after each streamed chunk, followed by
        one final value consisting of the full response plus a footer with
        the time to first chunk and the tokens-per-second rate.
    """
    # Same text as the original triple-quoted template, including the
    # leading newline.
    prompt_template = (
        "\n"
        "You are a helpful Agricultural assistant for farmers. You are given the following input. Please complete the response briefly.\n"
        "## Question:\n"
        "{}\n"
        "## Response:\n"
        "{}"
    )

    # perf_counter is monotonic, so the measured durations cannot be skewed
    # by wall-clock adjustments.
    start_time = time.perf_counter()

    # The response slot is left empty so the model completes it.
    prompt = prompt_template.format(message, "")
    model_inputs = tok(prompt, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    # Honour the "Sampling" checkbox; a temperature of 0 always means greedy
    # decoding, as in the original code.
    sampling = bool(do_sample) and temperature > 0

    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=sampling,
        repetition_penalty=1.2,
        # NOTE(review): kept from the original; this turns off the key/value
        # cache during generation - confirm it is still required.
        use_cache=False,
        eos_token_id=terminators,
    )
    if sampling:
        # Temperature is only meaningful when sampling, so it is not passed
        # for greedy decoding.
        generate_kwargs["temperature"] = temperature

    # generate() blocks until finished, so it runs on a worker thread while
    # this generator consumes the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    partial_text = ""
    first_token_time = None
    for new_text in streamer:
        if first_token_time is None:
            first_token_time = time.perf_counter() - start_time
        partial_text += new_text
        yield partial_text

    total_time = time.perf_counter() - start_time
    tokens = len(tok.tokenize(partial_text))
    tokens_per_second = tokens / total_time if total_time > 0 else 0

    # Guard against an empty stream: formatting None with ":.2f" would raise.
    if first_token_time is None:
        first_token_text = "n/a"
    else:
        first_token_text = f"{first_token_time:.2f} seconds"

    timing_info = (
        f"\n\nTime taken to first token: {first_token_text}"
        f"\nTokens per second: {tokens_per_second:.2f}"
    )
    yield partial_text + timing_info
# Build the parameter widgets first, in the same order the original created
# them: accordion, temperature, sampling, max tokens.
parameters_accordion = gr.Accordion(
    label="⚙️ Parameters", open=False, render=False
)
temperature_slider = gr.Slider(
    minimum=0, maximum=1, step=0.1, value=0.5, label="Temperature", render=False
)
sampling_checkbox = gr.Checkbox(label="Sampling", value=False)
max_tokens_slider = gr.Slider(
    minimum=128,
    maximum=4096,
    step=1,
    value=512,
    label="Max new tokens",
    render=False,
)

# The additional inputs are listed in the order of chat()'s extra parameters:
# temperature, do_sample, max_tokens.
demo = gr.ChatInterface(
    fn=chat,
    examples=[["I'm a farmer from Odisha, how do I take care of whitefly in my cotton crop?"]],
    additional_inputs_accordion=parameters_accordion,
    additional_inputs=[temperature_slider, sampling_checkbox, max_tokens_slider],
    stop_btn="Stop Generation",
    title="Chat With LLMs",
    description="Now Running [KissanAI/llama3-8b-dhenu-0.1-sft-16bit](https://huggingface.co/KissanAI/llama3-8b-dhenu-0.1-sft-16bit) in 4bit",
)

# Start the web app as soon as the script is executed.
demo.launch()