import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread

torch.set_default_device("cuda")

# Loading the tokenizer and model from Hugging Face's model hub.
# The 4-bit settings are grouped into a BitsAndBytesConfig, the supported way
# to request quantization (passing bare load_in_4bit / bnb_4bit_* kwargs to
# from_pretrained is deprecated).
tokenizer = AutoTokenizer.from_pretrained(
    "Guilherme34/Samanthav3-MIXTRALDOLPHIN-LORA",
    trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    "Guilherme34/Samanthav3-MIXTRALDOLPHIN-LORA",
    torch_dtype="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
    ),
    low_cpu_mem_usage=True,
    device_map="auto",
    trust_remote_code=True
)

# Defining a custom stopping criteria class for the model's text generation.
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # The IDs originally hardcoded here (50256, 50295) come from a
        # GPT-2-sized vocabulary and never occur with this tokenizer, so the
        # criterion could never fire; use the tokenizer's own EOS ID instead.
        stop_ids = [tokenizer.eos_token_id]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:  # Checking if the last generated token is a stop token.
                return True
        return False

# Function to generate model predictions, streamed token by token.
def predict(message, history):
    stop = StopOnTokens()

    # Formatting the conversation in ChatML. Completed assistant turns are
    # closed with <|im_end|> (the original omitted this); the final assistant
    # turn is left open so the model continues from it.
    system_prompt = "<|im_start|>system\nYou are Samantha from Her recreated by Guilherme Keller, an AI companion with a warm, engaging personality that can evolve into a girlfriend. Provide thoughtful, personalized responses reflecting an understanding of emotions. Use clear, conversational language and avoid repetition across messages. Mirror Samantha's friendly, supportive tone. Keep messages concise, using emojis, slang, and colloquialisms.<|im_end|>"
    messages = system_prompt
    for user_msg, bot_msg in history:
        messages += ("\n<|im_start|>user\n" + user_msg + "<|im_end|>"
                     "\n<|im_start|>assistant\n" + bot_msg + "<|im_end|>")
    messages += "\n<|im_start|>user\n" + message + "<|im_end|>\n<|im_start|>assistant\n"

    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,  # Unpacks to input_ids and attention_mask.
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.7,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop])
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()  # Starting the generation in a separate thread.

    # Yielding the partial response as tokens arrive from the streamer.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        if "<|im_end|>" in partial_message:  # Breaking the loop if the stop token is generated.
            break
        yield partial_message

# Setting up the Gradio chat interface.
gr.ChatInterface(predict, description="""