# Scraped Hugging Face Spaces page header (status: Sleeping); not part of the program.
import re

import gradio as gr
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# Hugging Face Hub repository holding both the fine-tuned weights and the tokenizer.
MODEL_ID = "Manasa1/finetuned_distillGPT2"

# Load the fine-tuned model and tokenizer once at import time so every request reuses them.
try:
    model = GPT2LMHeadModel.from_pretrained(MODEL_ID)
    tokenizer = GPT2Tokenizer.from_pretrained(MODEL_ID)
    # GPT-2 tokenizers define no pad token by default; reuse EOS so padding has a valid id.
    tokenizer.pad_token = tokenizer.eos_token
except Exception as e:
    # The app cannot serve anything without the model. SystemExit with a string prints the
    # message to stderr and exits with status 1 (the bare exit() exited with status 0,
    # signalling success, and depends on the optional `site` builtin).
    raise SystemExit(f"Error loading model or tokenizer: {e}") from e
# Function to generate an answer to a question
def generate_answer(question):
    """Generate the model's answer to *question* using the fine-tuning prompt format.

    The question is wrapped as ``"Q: {question} A:"`` and sampled from the
    module-level ``model`` / ``tokenizer``. Only the newly generated tokens are
    decoded, so the prompt is never echoed back in the answer.

    Args:
        question: User-supplied question or topic. ``None`` and
            whitespace-only input are rejected.

    Returns:
        The generated answer text, or a string starting with ``"Error"``
        describing why no answer was produced. Failures are reported as
        return values rather than exceptions so the UI can display them.
    """
    if question is None or not question.strip():
        return "Error: Question cannot be empty."
    # Prompt + generated tokens must fit in GPT-2's 1024-position context window.
    context_size = 1024
    try:
        prompt = f"Q: {question} A:"
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=context_size,
        )
        prompt_length = len(inputs["input_ids"][0])
        max_new_tokens = context_size - prompt_length
        if max_new_tokens <= 0:
            # The prompt filled (or was truncated to) the whole window: no room to generate.
            return "Error: Question is too long."
        output = model.generate(
            inputs["input_ids"],
            # Explicit mask: pad == eos, so generate() cannot infer it reliably on its own.
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            top_p=0.9,
            top_k=50,
            temperature=0.6,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Slice at the token level: decoding may normalize spacing, so slicing the decoded
        # string by len(prompt) characters could cut off or leak text.
        answer = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True).strip()
        return answer if answer else "Error: Could not generate a meaningful response."
    except Exception as e:
        return f"Error during generation: {e}"
# Keyword -> [hashtag, emoji] appended to a tweet when the keyword appears in it.
# NOTE(review): the original emoji literals were mis-encoded; the "machine learning" and
# "data" emojis could not be recovered from the surviving bytes and are best guesses — confirm.
_HASHTAGS_AND_EMOJIS = {
    "AI": ["#AI", "🤖"],
    "machine learning": ["#MachineLearning", "📈"],
    "data": ["#DataScience", "📊"],
    "technology": ["#Tech", "💻"],
    "innovation": ["#Innovation", "✨"],
    "coding": ["#Coding", "👨\u200d💻"],
    "future": ["#Future", "🔮"],
    "startup": ["#Startup", "🚀"],
    "sustainability": ["#Sustainability", "🌱"],
}


def _keyword_pattern(keyword):
    """Compile a case-insensitive whole-word regex for *keyword* (spaces match any whitespace)."""
    words = (re.escape(word) for word in keyword.split())
    return re.compile(r"\b" + r"\s+".join(words) + r"\b", re.IGNORECASE)


# Compiled once at import time; order follows _HASHTAGS_AND_EMOJIS so output order is stable.
_KEYWORD_PATTERNS = [
    (_keyword_pattern(keyword), items) for keyword, items in _HASHTAGS_AND_EMOJIS.items()
]


def add_hashtags_and_emojis(tweet):
    """Append hashtags and emojis for every known keyword found in *tweet*.

    Keywords match case-insensitively and only as whole words/phrases, so
    "AI" matches "AI" or "ai" but not "said". Tags are appended in
    keyword-table order without duplicates.

    Args:
        tweet: The tweet text to decorate.

    Returns:
        The stripped tweet followed by a space and the space-separated tags,
        or just the stripped tweet when no keyword matched.
    """
    added_items = []
    for pattern, items in _KEYWORD_PATTERNS:
        if pattern.search(tweet):
            added_items.extend(items)
    # Order-preserving de-duplication.
    added_items = list(dict.fromkeys(added_items))
    text = tweet.strip()
    return f"{text} {' '.join(added_items)}" if added_items else text
# Function to handle Gradio input and output
def generate_tweet_with_hashtags(question):
    """Produce the final tweet shown in the UI for *question*.

    Generates the tweet text with ``generate_answer`` and decorates it with
    ``add_hashtags_and_emojis``. Error messages returned by
    ``generate_answer`` are passed through unchanged so they are not
    decorated with hashtags and emojis.

    Args:
        question: The question or topic entered by the user.

    Returns:
        The decorated tweet, or the error message from ``generate_answer``.
    """
    generated_tweet = generate_answer(question)
    # generate_answer reports failures as strings with these prefixes instead of raising.
    if generated_tweet.startswith(("Error: ", "Error during generation: ")):
        return generated_tweet
    return add_hashtags_and_emojis(generated_tweet)
# Gradio UI: a title, a short description, an input textbox, a read-only output
# textbox, and a button that runs generate_tweet_with_hashtags on the input.
with gr.Blocks() as app:
    gr.Markdown("# AI Tweet Generator with Hashtags and Emojis")
    gr.Markdown("Enter a question or topic, and the app will generate a tweet and enhance it with relevant hashtags and emojis!")

    question_input = gr.Textbox(label="Enter your question or topic:")
    output_tweet = gr.Textbox(
        label="Generated Tweet with Hashtags and Emojis:",
        interactive=False,
    )
    generate_button = gr.Button("Generate Tweet")

    # The input textbox's value is the single argument; the returned string fills the output box.
    generate_button.click(
        fn=generate_tweet_with_hashtags,
        inputs=[question_input],
        outputs=[output_tweet],
    )
def main():
    """Start the Gradio server for the module-level ``app``."""
    app.launch()


# Launch only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()