Spaces:
Runtime error
import gradio as gr | |
import time | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from langchain.memory import ConversationBufferWindowMemory | |
from peft import PeftModel | |
import torch | |
import re | |
print("Initializing model")

# Hugging Face Hub IDs of the base model and the fine-tuned PEFT adapter.
BASE_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"
ADAPTER_ID = "nuratamton/story_sculptor_mistral"

# Pick the device before loading so the weights can be loaded in a precision
# that suits it: float16 halves GPU memory for the 7B model, while CPU
# inference stays in float32.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_dtype = torch.float16 if device.type == "cuda" else torch.float32

# Initialize the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# NOTE(review): "[PAD]" is a brand-new token, but the model's embeddings are
# not resized for it. A single prompt is never actually padded, so the id is
# not fed to the model here; confirm against how the adapter was trained.
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

# ``base_model`` is the loaded base model object (the Hub ID lives in
# BASE_MODEL_ID instead of being shadowed by the model).
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, torch_dtype=model_dtype)
ft_model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
ft_model.eval()
ft_model.to(device)

# Conversation memory holding a sliding window of recent exchanges (k=10).
memory = ConversationBufferWindowMemory(k=10)
def slow_echo(message, history):
    """Stream the game agent's reply character by character for Gradio.

    Args:
        message: The user's latest chat message.
        history: Chat history supplied by ``gr.ChatInterface``. It is unused
            here; conversation context comes from the module-level
            ``memory`` read inside ``chat_interface``.

    Yields:
        Successively longer prefixes of the reply, ending with the full text,
        so the UI renders a typing effect.
    """
    reply = chat_interface(message)
    if not reply:
        # Without this, an empty reply would yield nothing at all and Gradio
        # would receive no value for the turn.
        yield reply
        return
    for end in range(1, len(reply) + 1):
        time.sleep(0.05)
        yield reply[:end]
def chat_interface(user_in):
    """Generate the game agent's reply to one user turn.

    The prompt wraps the windowed conversation history (read from the
    module-level ``memory``) and the new user input in Mistral
    ``[INST] ... [/INST]`` tags. It samples a continuation from ``ft_model``,
    saves the exchange back into ``memory`` and returns the decoded reply.

    Args:
        user_in: The raw text the user typed.

    Returns:
        The model's reply as a string. Returns ``"Goodbye!"`` when the user
        types ``quit``; case and surrounding whitespace are ignored.
    """
    # NOTE(review): ``memory`` is one module-level object, so every session
    # connected to this app shares the same conversation history. Per-session
    # context would need the ``history`` Gradio passes to the callback.
    if user_in.strip().lower() == "quit":
        return "Goodbye!"

    memory_context = memory.load_memory_variables({})["history"]
    # The history buffer presumably renders as "Human: ...\nAI: ..." lines
    # (LangChain's default prefixes). Without a separator, the new input
    # would be glued onto the end of the last AI reply.
    if memory_context:
        memory_context += "\n"
    user_input = (
        "[INST] Continue the game and maintain context and keep the story "
        f"consistent throughout: {memory_context}{user_in}[/INST]"
    )

    encodings = tokenizer(user_input, return_tensors="pt", padding=True).to(device)
    input_ids = encodings["input_ids"]
    output_ids = ft_model.generate(
        input_ids,
        attention_mask=encodings["attention_mask"],
        max_new_tokens=1000,
        num_return_sequences=1,
        do_sample=True,
        temperature=1.1,
        top_p=0.9,
        repetition_penalty=1.2,
        # generate() otherwise falls back to the EOS id with a warning. Being
        # explicit also avoids the "[PAD]" id, which may lie outside the
        # model's (un-resized) embedding table.
        pad_token_id=tokenizer.eos_token_id,
    )
    # Keep only the newly generated tokens, dropping the echoed prompt.
    generated_ids = output_ids[0, input_ids.shape[-1]:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    memory.save_context({"input": user_in}, {"output": response})
    print(f"Game Agent: {response}")
    return response
# Wrap the streaming generator in Gradio's chat UI. Enable the request queue,
# which older Gradio versions require for generator callbacks. Then start the
# server with a public share link.
iface = gr.ChatInterface(slow_echo)
iface = iface.queue()
iface.launch(share=True)