# Artic-Intell / app.py
# Source: Hugging Face Space file page (author: Vitrous; commit e4ec528, "Update app.py", verified; 4.2 kB).
import gradio as gr
import plotly.express as px
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BlenderbotForConditionalGeneration
# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# GPU-only tuning. The allocator option is exported before the cache is
# flushed, and this process is then capped at 80% of device memory.
if device == "cuda":
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    torch.cuda.empty_cache()
    torch.cuda.set_per_process_memory_fraction(0.8)  # tune the fraction as needed

# System prompt inserted into the chat template by chat_responses(); empty placeholder for now.
system_message = ""
def hermes_model():
    """Load the CapybaraHermes-2.5 Mistral-7B AWQ checkpoint and its tokenizer.

    The model is loaded with ``low_cpu_mem_usage=True`` and
    ``device_map="auto"``, which leaves weight placement to the loader.
    NOTE(review): not called in the code shown; the app currently uses
    ``blender_model()`` instead.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint = "TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    llm = AutoModelForCausalLM.from_pretrained(
        checkpoint, low_cpu_mem_usage=True, device_map="auto"
    )
    return llm, tok
def blender_model():
    """Load the ``facebook/blenderbot-400M-distill`` model and its tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint = "facebook/blenderbot-400M-distill"
    seq2seq = BlenderbotForConditionalGeneration.from_pretrained(checkpoint)
    tok = AutoTokenizer.from_pretrained(checkpoint)
    return seq2seq, tok


# Loaded once at import time; chat_response() and chat_responses() read these globals.
model, tokenizer = blender_model()
def chat_response(msg_prompt: str) -> str:
    """Generate one Blenderbot reply for ``msg_prompt``.

    Uses the module-level ``model`` and ``tokenizer`` (set from
    ``blender_model()`` in this file). Failures are not raised. The exception
    is logged with its traceback and its message is returned as the reply, so
    the chat UI keeps running.

    Args:
        msg_prompt: The user's message text.

    Returns:
        The decoded reply, or the exception message if tokenization or
        generation failed.
    """
    import logging  # only needed on the error path; keeps the file header untouched

    try:
        # truncation=True cuts the prompt to tokenizer.model_max_length.
        # Blenderbot presumably has a short input window, so without this,
        # long prompts error out inside generate() (confirm model_max_length
        # for this checkpoint). .to(model.device) keeps the input tensors next
        # to the weights if the model is ever placed on a GPU.
        inputs = tokenizer(msg_prompt, return_tensors="pt", truncation=True).to(model.device)
        reply_ids = model.generate(**inputs)
        return tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]
    except Exception as e:  # deliberate UI boundary: report the error instead of crashing
        logging.getLogger(__name__).exception("chat_response failed")
        return str(e)
def chat_responses(msg_prompt: str) -> str:
    """Generate a sampled reply from a text-generation pipeline using a ChatML prompt.

    The prompt wraps ``system_message`` and ``msg_prompt`` in ChatML
    ``<|im_start|>``/``<|im_end|>`` markers, presumably the format expected by
    the CapybaraHermes checkpoint in ``hermes_model()`` (per its model card).
    Only the newly generated text is returned, not the prompt.

    NOTE(review): the pipeline is built from the module-level ``model`` and
    ``tokenizer``, which are currently the Blenderbot pair from
    ``blender_model()``. A "text-generation" pipeline presumably needs a
    causal LM such as the one ``hermes_model()`` returns. Confirm before
    wiring this function into ``bot``.

    Args:
        msg_prompt: The user's message prompt.

    Returns:
        The model's reply, stripped of surrounding whitespace, or the
        exception message if pipeline construction or generation failed.
    """
    generation_params = {
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.95,
        "top_k": 40,
        "max_new_tokens": 512,
        "repetition_penalty": 1.1,
    }
    try:
        # Built inside the try so an incompatible model is reported through the
        # same "return the error text" contract as generation failures.
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, **generation_params)
        prompt_template = (
            f"<|im_start|>system\n{system_message}<|im_end|>\n"
            f"<|im_start|>user\n{msg_prompt}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        # return_full_text=False returns only the completion, so a reply that
        # happens to contain the word "assistant" is no longer cut short.
        generated = pipe(prompt_template, return_full_text=False)[0]["generated_text"]
        # Drop anything after an end-of-turn marker in case it is decoded as text.
        return generated.split("<|im_end|>", 1)[0].strip()
    except Exception as e:
        return str(e)
def random_plot():
    """Build a scatter plot of Plotly's bundled iris sample dataset.

    Sepal width is plotted against sepal length. Points are coloured by
    species and sized by petal length, and petal width appears on hover.
    Despite the name, nothing is randomised: every call builds the same figure.
    """
    iris = px.data.iris()
    return px.scatter(
        iris,
        x="sepal_width",
        y="sepal_length",
        color="species",
        size="petal_length",
        hover_data=["petal_width"],
    )
def print_like_dislike(x: gr.LikeData):
    """Print a chatbot like/dislike event to stdout (demonstration only).

    Outputs the rated message's index, its value and whether it was liked.
    The ``gr.LikeData`` annotation is what tells Gradio to pass the event data in.
    """
    index, value, liked = x.index, x.value, x.liked
    print(index, value, liked)
def add_message(history, message, files):
    """Append the user's uploaded files and text message to the chat history.

    Each uploaded file becomes its own turn, with a one-element tuple in the
    user slot. The text message is appended last, so ``bot`` (which reads the
    final turn) answers it.

    Args:
        history: Current chat history, a list of ``(user, bot)`` pairs from the
            chatbot component. ``None`` is treated as an empty history.
        message: Text from the input box, or ``None``.
        files: Uploaded files from the file component (presumably a list of
            file paths; confirm against the installed Gradio version), or
            ``None``.

    Returns:
        A tuple of the updated history and an update that clears the textbox
        and locks it while the reply is generated. The UI wiring re-enables the
        textbox with a follow-up ``.then`` after ``bot`` returns.
    """
    if history is None:
        history = []
    for file in files or []:
        history.append(((file,), None))
    if message is not None:
        history.append((message, None))
    # interactive=False prevents a second submit while the bot is still answering.
    return history, gr.update(value=None, interactive=False)
def bot(history):
    """Fill in the bot's reply for the most recent user turn.

    Args:
        history: Chat history as a list of ``[user, bot]`` pairs (lists or
            tuples). File turns carry a tuple in the user slot instead of text.

    Returns:
        The same history object. If it is non-empty and its last user turn is
        text, that last pair is replaced with ``[user_message, reply]``. File
        turns and empty/``None`` histories are returned unchanged.
    """
    if not history:
        return history
    user_message = history[-1][0]
    # Only text can be sent to the model; passing a file tuple to the tokenizer
    # would just surface an error string as the "reply".
    if not isinstance(user_message, str):
        return history
    bot_response = chat_response(user_message)
    # Replace the pair rather than assigning into it, so tuple pairs work too.
    history[-1] = [user_message, bot_response]
    return history
# Built once at import time; not currently placed in the UI below.
fig = random_plot()

# Gradio interface setup
with gr.Blocks(fill_height=True) as demo:
    chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, scale=1)
    with gr.Row():
        chat_input = gr.Textbox(placeholder="Enter message...", show_label=False)
        file_input = gr.File(label="Upload file(s)", file_count="multiple")

    # Step 1: append the user's files and text to the history.
    chat_msg = chat_input.submit(add_message, [chatbot, chat_input, file_input], [chatbot, chat_input])
    # Clear the file picker once its files are in the history. Otherwise
    # add_message would append the same files again on every later message.
    chat_msg.then(lambda: gr.update(value=None), None, [file_input])
    # Step 2: generate the reply. Step 3: make the textbox editable again.
    bot_msg = chat_msg.then(bot, chatbot, chatbot)
    bot_msg.then(lambda: gr.update(interactive=True), None, [chat_input])

    chatbot.like(print_like_dislike, None, None)

demo.queue()
demo.launch()