# Spaces: Runtime error
import functools
import logging
import os
import re

import gradio as gr
import torch
from gtts import gTTS
from transformers import AutoModelForCausalLM, AutoTokenizer
# Set up logging | |
# Records at INFO and above go to app.log, prefixed with timestamp and level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename='app.log',
)
# Function to set up the model and tokenizer | |
@functools.lru_cache(maxsize=1)
def setup_model(model_name):
    """Load the causal language model and tokenizer for ``model_name``.

    The result is memoized: calling again with the same ``model_name``
    returns the already-loaded objects instead of loading the weights a
    second time. Only the most recently requested model is kept
    (``maxsize=1``). A call that raises is not cached, so a failed load is
    retried on the next call.

    Args:
        model_name: Model identifier passed to ``from_pretrained`` for both
            the model and the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been switched to
        evaluation mode with ``model.eval()``.
    """
    logging.info('Setting up model and tokenizer.')
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        trust_remote_code=False,
        revision="main",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model.eval()
    logging.info('Model and tokenizer setup completed.')
    return model, tokenizer
# Function to generate a response from the model | |
def generate_response(model, tokenizer, prompt, max_new_tokens=140):
    """Generate a completion for ``prompt`` and return only the new text.

    Args:
        model: Object exposing ``device`` and ``generate(**encoding,
            max_new_tokens=...)``.
        tokenizer: Callable that encodes ``prompt`` (called with
            ``return_tensors="pt"``) and exposes ``decode``.
        prompt: Full prompt string to feed the model.
        max_new_tokens: Upper bound on generated tokens, forwarded to
            ``model.generate``.

    Returns:
        The decoded text of the tokens that follow the prompt in the first
        generated sequence, with special tokens skipped and surrounding
        whitespace stripped. May be an empty string if nothing was generated.
    """
    logging.info('Generating response for the prompt.')
    # Move the whole encoding (ids plus any attention mask) to wherever the
    # model lives instead of assuming a CUDA device.
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = encoded["input_ids"].shape[-1]
    outputs = model.generate(**encoded, max_new_tokens=max_new_tokens)
    # The generated sequence starts with the prompt tokens; slicing them off
    # keeps multi-line replies intact, unlike splitting the text on "\n".
    new_tokens = outputs[0][prompt_length:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    logging.info('Response generated.')
    return response.strip()
# Function to remove various tags using regular expressions | |
def remove_tags(text):
    """Return ``text`` with every tag-like span deleted.

    Removed spans are angle-bracket tags (``<...>``), square-bracket markers
    (``[...]``) and brace-wrapped parentheses (``{(...)}``). Everything else,
    including surrounding whitespace, is left untouched.
    """
    logging.info('Removing tags from the text.')
    patterns = (
        r"<[^>]*>",  # standard HTML-style tags
        r"<.*?>|\[.*?\]|{\s*?\(.*?\)\s*}",  # custom, non-standard markers
    )
    # Alternation of both patterns, tried left to right at each position.
    stripped = re.compile("|".join(patterns)).sub("", text)
    logging.info('Tags removed.')
    return stripped
# Function to generate the audio file | |
def text_to_speech(text, filename="response.mp3"):
    """Synthesize ``text`` with gTTS and save the audio to ``filename``.

    Args:
        text: The text handed to ``gTTS``.
        filename: Destination path for the saved audio.

    Returns:
        ``filename``, unchanged, so callers can pass it on directly.
    """
    logging.info('Generating speech audio file.')
    gTTS(text).save(filename)
    logging.info('Speech audio file saved.')
    return filename
# Main function for the Gradio app | |
def main(comment):
    """Gradio callback: answer ``comment`` and synthesize the answer as audio.

    Args:
        comment: Text entered by the user. Empty, whitespace-only or ``None``
            values are rejected before any model work is done.

    Returns:
        A ``(text, audio_path)`` tuple. ``audio_path`` is the file returned by
        ``text_to_speech`` on success and ``None`` when no comment was given
        or when generation/synthesis failed.
    """
    logging.info('Main function triggered.')

    # Guard first: there is no point loading the model just to ask for input.
    if not comment or not comment.strip():
        logging.warning('No comment entered.')
        return "Please enter a comment to generate a response.", None

    instructions_string = (
        "virtual marketer assistant, communicates in business, focused on services, "
        "escalating to technical depth upon request. It reacts to feedback aptly and ends responses "
        "with its signature Mr.jon will tailor the length of its responses to match the individual's comment, "
        "providing concise acknowledgments to brief expressions of gratitude or feedback, thus keeping the interaction natural and supportive.\n"
    )
    model_name = "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ"
    inst_marker = "[/INST]"

    try:
        model, tokenizer = setup_model(model_name)
        prompt = f"[INST] {instructions_string} \n{comment} \n[/INST]"
        response = generate_response(model, tokenizer, prompt)

        # Apply tag removal before displaying the response.
        cleaned = remove_tags(response)
        # Drop one literal trailing "[/INST]" if present. str.rstrip would
        # treat the argument as a character set and also eat trailing
        # letters such as "S" or "T" from the reply itself.
        if cleaned.endswith(inst_marker):
            cleaned = cleaned[:-len(inst_marker)]

        if not cleaned.strip():
            # Nothing to speak; report it explicitly rather than passing
            # blank text on to speech synthesis.
            logging.warning('Model returned an empty response.')
            return "An error occurred. Please try again later.", None

        audio_file = text_to_speech(cleaned)
        logging.info('Response and audio file generated.')
        return cleaned, audio_file
    except Exception as e:
        # Top-level UI boundary: keep the app responsive, but record the
        # traceback so the failure can be diagnosed from the log.
        logging.exception('Error occurred: %s', e)
        return "An error occurred. Please try again later.", None
# UI wiring: one two-line textbox in; the reply text and a file come out.
iface = gr.Interface(
    title="Virtual Marketer Assistant",
    description="Enter a comment and get a response from the virtual marketer assistant. Download the response as an MP3 file.",
    fn=main,
    inputs=gr.Textbox(lines=2, placeholder="Enter a comment..."),
    outputs=["text", "file"],
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch(share=True)