# Seers_Legacy / app.py
# Hugging Face Space file by AdhamEhab - commit 8209b0c ("Update app.py", verified), 3.44 kB.
import requests
from PIL import Image
import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer, GPT2LMHeadModel, TextGenerationPipeline
import torch
import tempfile
import os
# Force-pin Gradio to the 3.50 release by shelling out to pip at startup:
# first remove whatever version is installed, then install gradio==3.50.
# NOTE(review): `import gradio as gr` has already executed above, so the
# module object bound to `gr` in this process is the version that was
# installed before these commands ran. Presumably the pin only takes full
# effect on the next start of the app - confirm, and consider pinning the
# version in the project's requirements instead.
# NOTE(review): the exit status returned by os.system is discarded, so a
# failed uninstall/install goes unnoticed here.
os.system("pip uninstall -y gradio")
os.system("pip install gradio==3.50")
# Run on the GPU when CUDA is available, otherwise stay on the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Image captioning: the BLIP processor and the conditional-generation model
# are both loaded from the same pretrained identifier; the model is then
# moved to `device`.
_CAPTION_REPO = "AdhamEhab/ImageCaptioning"
image_processor = BlipProcessor.from_pretrained(_CAPTION_REPO)
image_model = BlipForConditionalGeneration.from_pretrained(_CAPTION_REPO)
image_model = image_model.to(device)

# Story generation: a GPT-2 language-model head and its tokenizer, wrapped in
# a text-generation pipeline that `generate_story` calls.
_STORY_REPO = "AdhamEhab/StoryGen"
story_generation_model = GPT2LMHeadModel.from_pretrained(_STORY_REPO)
story_generation_tokenizer = AutoTokenizer.from_pretrained(_STORY_REPO)
generator = TextGenerationPipeline(
    model=story_generation_model,
    tokenizer=story_generation_tokenizer,
)
# Define a function to generate caption from an image
def generate_caption(image):
    """Generate a text caption for an image with the captioning model.

    Args:
        image: Path to an image file, or a temporary-file wrapper whose
            ``name`` attribute holds that path. ``None`` (nothing uploaded)
            is reported as an error message instead of being handed to PIL.

    Returns:
        The decoded caption string, or a string starting with
        ``"An error occurred: "`` that describes the failure. Errors are
        returned rather than raised so the caller can display them.
    """
    # Guard the "button clicked without an image" case explicitly; otherwise
    # PIL fails with an obscure attribute error on None.
    if image is None:
        return "An error occurred: no image was provided."

    try:
        # A temporary-file wrapper carries the path in ``name``; any other
        # value is passed to PIL unchanged.
        if isinstance(image, tempfile._TemporaryFileWrapper):
            image_path = image.name
        else:
            image_path = image

        # Image.open keeps the underlying file handle open; the context
        # manager releases it once the processor has consumed the pixels.
        with Image.open(image_path) as picture:
            inputs = image_processor(
                picture,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
            )

        # Inference only: no gradients are needed for generation.
        with torch.no_grad():
            caption_ids = image_model.generate(**inputs.to(device))

        # Decode the first (only) generated sequence into plain text.
        return image_processor.decode(caption_ids[0], skip_special_tokens=True)
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Define a function to generate a story based on a prompt
def generate_story(prompt):
    """Generate a story that continues the given prompt.

    Args:
        prompt: Text handed to the text-generation pipeline as the start of
            the story.

    Returns:
        The ``generated_text`` of the first pipeline result (sampling
        enabled, ``max_length`` of 200), or a string starting with
        ``"An error occurred: "`` if generation raised.
    """
    try:
        candidates = generator(prompt, max_length=200, do_sample=True)
        first_candidate = candidates[0]
        return first_candidate['generated_text']
    except Exception as err:
        return f"An error occurred: {err!s}"
# Create Gradio interfaces
# Standalone image-to-caption interface.
# Uses the top-level `gr.Image` component: the `gr.inputs.*` namespace is
# deprecated in Gradio 3.x and no longer exists in Gradio 4, while `gr.Image`
# accepts the same `type`/`label` arguments.
image_caption_interface = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="filepath", label="Upload Image"),
    outputs="text",
    title="Image Captioning",
    description="Generate a caption for the provided image.",
)
# Standalone prompt-to-story interface: plain text in, plain text out,
# backed by `generate_story`.
story_generation_interface = gr.Interface(
    generate_story,
    "text",
    "text",
    title="Story Generation",
    description="Generate a story based on the provided prompt.",
)
# Build the combined page: both tools side by side in a single Blocks layout
with gr.Blocks(css="footer{display:none !important}") as combined_interface:
    # Page heading; the text is kept flush-left so the Markdown source is
    # exactly the string shown in the original file.
    gr.Markdown(
        """
# The Seer's Legacy
Crafting tales from the fabric of imagination.
""")
    with gr.Row():
        # Left column: image upload -> caption.
        with gr.Column():
            # `gr.Image` replaces the deprecated `gr.inputs.Image`, which is
            # not available in Gradio 4.
            image_input = gr.Image(type="filepath", label="Upload Image")
            image_output = gr.Text(label="Image Caption")
            image_btn = gr.Button("Generate Image Caption")
        # Right column: prompt -> story.
        with gr.Column():
            # `gr.Textbox` replaces the deprecated `gr.inputs.Textbox`.
            text_input = gr.Textbox(label="Story Prompt")
            story_output = gr.Text(label="Generated Story")
            story_btn = gr.Button("Generate Story")
    # Event bindings are registered while the Blocks context is still open.
    image_btn.click(generate_caption, inputs=[image_input], outputs=[image_output])
    story_btn.click(generate_story, inputs=[text_input], outputs=[story_output])
# Start the Gradio app for the combined interface, using launch()'s defaults
# (no arguments are passed).
combined_interface.launch()