# Author: kingabzpro
# Commit message: Update app.py
# Commit: 7d6a5e7
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
import time
import numpy as np
from torch.nn import functional as F
import os
from threading import Thread
# Text shown in the chat UI.
title = "🦅Falcon 🗨️ChatBot"
description = "Falcon-RW-1B is a 1B parameters causal decoder-only model built by TII and trained on 350B tokens of RefinedWeb."
examples = [["How are you?"]]

# Hugging Face Hub id shared by the tokenizer and the model.
MODEL_NAME = "tiiuae/falcon-rw-1b"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Nothing in this file moves the model (or its inputs) to a GPU, so generation
# runs on CPU. Many PyTorch builds lack CPU float16 kernels for ops such as
# addmm / layer_norm ("... not implemented for 'Half'"), and where they exist
# they are slow, so load full-precision weights (roughly 2x the memory of fp16).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.float32,
)
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation when the newest token is a stop id.

    Only the first sequence of the batch is inspected.

    Args:
        stop_ids: Token ids that end generation. Defaults to ``(0,)``, the id
            this app has always used. NOTE(review): confirm that id 0 is a
            special token for the falcon-rw-1b tokenizer; in GPT-2-style
            vocabularies it is an ordinary token, which would cut replies
            short whenever it is sampled.
    """

    def __init__(self, stop_ids=(0,)):
        super().__init__()
        # Built once here instead of on every generation step; a frozenset
        # makes the per-step membership test O(1).
        self.stop_ids = frozenset(int(token_id) for token_id in stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence is a stop id."""
        # input_ids is indexed [sequence][position]; [0][-1] is the most
        # recently generated token of the first sequence.
        return int(input_ids[0][-1]) in self.stop_ids
def user(message, history):
    """Clear the input box and add the user's message as a new pending turn.

    Returns a pair ``("", new_history)``. ``new_history`` is a new list: the
    entries of ``history`` followed by ``[message, ""]``, where the empty
    string is the not-yet-generated bot reply. ``history`` is not mutated.

    NOTE(review): not referenced by the gr.ChatInterface setup in this file.
    """
    pending_turn = [message, ""]
    new_history = history + [pending_turn]
    cleared_textbox = ""
    return cleared_textbox, new_history
def chat(curr_system_message, history):
    """Stream the model's reply to the latest message of a conversation.

    ``gr.ChatInterface`` calls its function as ``fn(message, history)``.
    Despite its historical name, ``curr_system_message`` therefore receives
    the newest user message. ``history`` holds only the *earlier* turns as
    ``[user_text, bot_text]`` pairs, and is empty on the first turn.

    The prompt writes every earlier turn as ``<user>: ...<chatbot>: ...`` and
    ends with the new message followed by an open ``<chatbot>: `` slot, so the
    model continues as the bot. ``model.generate`` runs in a background thread
    and its text is read back through a ``TextIteratorStreamer``.

    Args:
        curr_system_message: The newest user message.
        history: Earlier ``[user_text, bot_text]`` turns; not mutated.

    Yields:
        The reply generated so far, as one growing string (the value
        ``gr.ChatInterface`` displays for a streaming function).
    """
    # Earlier turns, plus the new message with an empty reply slot. ``or ""``
    # guards against turns whose reply is None.
    turns = [(user_text, bot_text or "") for user_text, bot_text in history]
    turns.append((curr_system_message, ""))
    prompt = "".join(
        "<user>: " + user_text + "<chatbot>: " + bot_text
        for user_text, bot_text in turns
    )

    tokens = tokenizer([prompt], return_tensors="pt")
    # skip_prompt: stream only newly generated text, not the echoed prompt.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        # Keep the inputs on whatever device the model lives on.
        input_ids=tokens.input_ids.to(model.device),
        attention_mask=tokens.attention_mask.to(model.device),
        streamer=streamer,
        # NOTE(review): max_length counts prompt tokens too, so a long enough
        # history leaves no room for a reply; consider max_new_tokens and/or
        # trimming old turns.
        max_length=2048,
        do_sample=True,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        temperature=0.7,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )
    # generate() blocks until it finishes, so run it in a worker thread and
    # consume the streamed text chunks here as they arrive.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    return partial_text
# Chat UI: gr.ChatInterface calls chat(message, history) on each submission
# and shows whatever chat yields in the chatbot panel.
demo = gr.ChatInterface(
    chat,
    title=title,
    description=description,
    examples=examples,
    # Cache the example outputs on the server (Gradio runs chat on them).
    cache_examples=True,
    retry_btn=None,
    undo_btn="Delete Previous",
    clear_btn="Clear",
    chatbot=gr.Chatbot(height=300),
    textbox=gr.Textbox(placeholder="Chat with me"),
)
# Enable the request queue (streaming generator handlers run through it),
# then start the server.
demo.queue()
demo.launch()