"""Gradio chat demo that streams replies from the tiiuae/falcon-rw-1b model."""

import os
import time
from threading import Thread

import gradio as gr
import numpy as np
import torch
from torch.nn import functional as F
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Static UI text handed to gr.ChatInterface at the bottom of the file.
title = "🦅Falcon 🗨️ChatBot"
description = "Falcon-RW-1B is a 1B parameters causal decoder-only model built by TII and trained on 350B tokens of RefinedWeb."
examples = [["How are you?"]]

# The tokenizer and model are loaded once at import time and shared by every
# chat request.
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-rw-1b")
# NOTE(review): the weights are loaded as float16 and never moved to a GPU.
# Some torch builds lack half-precision CPU kernels -- confirm the deployment
# target supports this before relying on it.
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-rw-1b",
    trust_remote_code=True,
    torch_dtype=torch.float16,
)


class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation on specific token ids."""

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:
        """Return True when the newest token of the first sequence is a stop id.

        Only ``input_ids[0]`` is inspected; ``scores`` and ``kwargs`` are
        accepted to satisfy the criterion call signature and are unused.
        """
        # NOTE(review): id 0 is hard-coded -- verify that it is the intended
        # stop token for this tokenizer (eos_token_id is passed separately to
        # generate()).
        stop_ids = [0]
        last_token = input_ids[0][-1]
        return any(last_token == stop_id for stop_id in stop_ids)


def user(message, history):
    """Return an empty textbox value and the history extended by *message*.

    The new turn is appended as ``[message, ""]`` so the reply slot can be
    filled in later. The input list is not mutated. This helper is not
    referenced anywhere else in this file.
    """
    return "", history + [[message, ""]]


def _build_prompt(message, history):
    """Flatten prior ``[user, bot]`` turns plus *message* into one prompt.

    Each utterance is prefixed with ``": "``; the prompt ends with an empty
    reply slot for the model to continue.
    """
    turns = [
        ": " + user_text + ": " + (bot_text or "")
        for user_text, bot_text in history
    ]
    turns.append(": " + message + ": ")
    return "".join(turns)


def chat(curr_system_message, history):
    """Stream the model's reply to the latest user message.

    ``gr.ChatInterface`` calls this function as ``fn(message, history)``.

    Args:
        curr_system_message: The newest user message. The parameter name is
            kept for backward compatibility; it is not a system prompt.
        history: Prior turns as ``[user_text, bot_text]`` pairs, not including
            the newest message. May be empty or ``None``.

    Yields:
        The reply text accumulated so far, growing with each streamed chunk.
        ``history`` is never mutated; the chat interface owns it.
    """
    prompt = _build_prompt(curr_system_message, history or [])
    tokens = tokenizer([prompt], return_tensors="pt")

    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=10.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generate_kwargs = dict(
        input_ids=tokens.input_ids,
        attention_mask=tokens.attention_mask,
        streamer=streamer,
        max_length=2048,
        do_sample=True,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        temperature=0.7,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )

    # generate() blocks until it is finished, so it runs in a worker thread
    # while this generator drains the streamer and forwards text to the UI.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    # The streamer is exhausted only once generation has ended, so this join
    # returns promptly and avoids leaving the worker thread dangling.
    worker.join()


demo = gr.ChatInterface(
    chat,
    title=title,
    description=description,
    examples=examples,
    cache_examples=True,
    retry_btn=None,
    undo_btn="Delete Previous",
    clear_btn="Clear",
    chatbot=gr.Chatbot(height=300),
    textbox=gr.Textbox(placeholder="Chat with me"),
)
demo.queue().launch()