gradio-server / app.py
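"""Gradio chat app for a text-adventure "game agent": Mistral-7B-Instruct with a
LoRA adapter (nuratamton/story_sculptor_mistral), using a sliding-window
conversation memory to keep the story consistent across turns."""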
import time

import gradio as gr
import torch
from langchain.memory import ConversationBufferWindowMemory
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
print("Initializing model")
# Initialize the tokenizer and model
base_model = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
base_model = AutoModelForCausalLM.from_pretrained(base_model)
ft_model = PeftModel.from_pretrained(base_model, "nuratamton/story_sculptor_mistral")
# ft_model = ft_model.merge_and_unload()
ft_model.eval()
# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ft_model.to(device)
# Remember only the last 10 conversation turns
memory = ConversationBufferWindowMemory(k=10)
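# For illustration only: after one saved turn, loading the window memory yields
# something like the following (the "Human:"/"AI:" prefixes are LangChain's defaults):
#     memory.load_memory_variables({})
#     # -> {"history": "Human: look around\nAI: You stand in a torchlit hall..."}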

def slow_echo(message, history):
    """Stream the reply one character at a time for a typing effect."""
    response = chat_interface(message)
    for i in range(len(response)):
        time.sleep(0.05)
        yield response[: i + 1]
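
# gr.ChatInterface feeds each user message (plus the chat history) to slow_echo
# and re-renders the reply with every yielded prefix, so the text appears to be
# typed out live.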

def chat_interface(user_in):
    """Generate one game turn from the conversation memory and the new user input."""
    if user_in.lower() == "quit":
        return "Goodbye!"

    # Prepend the recent conversation history so the story stays consistent
    memory_context = memory.load_memory_variables({})["history"]
    user_input = f"[INST] Continue the game and maintain context and keep the story consistent throughout: {memory_context}{user_in}[/INST]"

    encodings = tokenizer(user_input, return_tensors="pt", padding=True).to(device)
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]

    # Sample a continuation; inference only, so no gradients are needed
    with torch.no_grad():
        output_ids = ft_model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=1000,
            num_return_sequences=1,
            do_sample=True,
            temperature=1.1,
            top_p=0.9,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens, skipping the prompt
    generated_ids = output_ids[0, input_ids.shape[-1]:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Store the exchange so the next turn's memory_context includes it
    memory.save_context({"input": user_in}, {"output": response})
    print(f"Game Agent: {response}")
    return response

# Wire the streaming handler into a Gradio chat UI; queue() enables generator streaming
iface = gr.ChatInterface(slow_echo).queue()
iface.launch(share=True)
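# Note: share=True creates a temporary public gradio.live link for local runs;
# on Hugging Face Spaces the app is already hosted publicly, so the flag is
# unnecessary there.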