Spaces:
Sleeping
Sleeping
import os
import subprocess
import sys
import tempfile

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipProcessor,
    GPT2LMHeadModel,
    TextGenerationPipeline,
)
# NOTE(review): these pip calls run *after* ``import gradio`` at the top of the
# file, so they cannot change the gradio version already loaded in this
# process; pin ``gradio==3.50`` in the Space's requirements / sdk_version
# instead. ``sys.executable -m pip`` guarantees the pip belonging to the
# running interpreter is used (a bare ``pip`` on PATH may target another Python).
subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "gradio"], check=False)
subprocess.run([sys.executable, "-m", "pip", "install", "gradio==3.50"], check=False)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image-captioning model (BLIP) and its paired processor; the model lives on ``device``.
image_processor = BlipProcessor.from_pretrained("AdhamEhab/ImageCaptioning")
image_model = BlipForConditionalGeneration.from_pretrained("AdhamEhab/ImageCaptioning").to(device)

# Story-generation model (GPT-2) and tokenizer, wrapped in a text-generation
# pipeline placed on the same device as the captioning model (previously the
# pipeline always ran on the CPU even when a GPU was available).
story_generation_model = GPT2LMHeadModel.from_pretrained("AdhamEhab/StoryGen")
story_generation_tokenizer = AutoTokenizer.from_pretrained("AdhamEhab/StoryGen")
generator = TextGenerationPipeline(
    model=story_generation_model,
    tokenizer=story_generation_tokenizer,
    device=device,
)
def generate_caption(image):
    """Generate a caption for an image with the BLIP captioning model.

    Args:
        image: The image to caption. Accepted forms: a filesystem path
            (``str`` / ``os.PathLike``, as produced by
            ``gr.Image(type="filepath")``), a file wrapper exposing its path
            via a ``.name`` attribute (e.g. a ``tempfile`` wrapper), or an
            already loaded ``PIL.Image.Image``.

    Returns:
        The decoded caption string, or a message starting with
        ``"An error occurred: "`` when no image is given or captioning fails.
    """
    # Guard against a missing image (e.g. the button clicked before any upload).
    if image is None:
        return "An error occurred: no image was provided."
    try:
        if isinstance(image, Image.Image):
            rgb_image = image.convert("RGB")
        else:
            # Plain paths are used as-is; file wrappers carry their path in ``.name``.
            if isinstance(image, (str, os.PathLike)):
                image_path = image
            else:
                image_path = getattr(image, "name", image)
            # ``with`` releases the file handle; ``convert`` returns a fully loaded
            # copy, so it stays usable after the file is closed. RGB normalises
            # palette / RGBA / grayscale uploads to one format.
            with Image.open(image_path) as opened:
                rgb_image = opened.convert("RGB")
        # Image-only call: text padding/truncation options do not apply here.
        inputs = image_processor(rgb_image, return_tensors="pt").to(device)
        with torch.no_grad():
            caption_ids = image_model.generate(**inputs)
        return image_processor.decode(caption_ids[0], skip_special_tokens=True)
    except Exception as e:  # UI boundary: show the error text instead of crashing the request
        return f"An error occurred: {e}"
def generate_story(prompt, *, max_length=200):
    """Continue *prompt* into a story using the module-level text-generation pipeline.

    Args:
        prompt: The opening text the story should continue.
        max_length: Upper bound on the generated sequence length, passed
            straight to the pipeline. Defaults to 200, the previously
            hard-coded value. Keyword-only so that Gradio's single text
            input still maps onto ``prompt``.

    Returns:
        The first result's ``generated_text`` (sampled, so it varies between
        calls), or a message starting with ``"An error occurred: "`` if
        generation fails.
    """
    try:
        outputs = generator(prompt, max_length=max_length, do_sample=True)
        return outputs[0]["generated_text"]
    except Exception as e:  # UI boundary: show the error text instead of crashing the request
        return f"An error occurred: {e}"
# Stand-alone single-task interfaces.
# NOTE(review): neither of these is launched below -- only ``combined_interface``
# is served. They are kept so the module-level names remain available.
image_caption_interface = gr.Interface(
    fn=generate_caption,
    # ``gr.Image`` replaces the deprecated ``gr.inputs.Image`` (removed in Gradio 4).
    inputs=gr.Image(type="filepath", label="Upload Image"),
    outputs="text",
    title="Image Captioning",
    description="Generate a caption for the provided image.",
)
story_generation_interface = gr.Interface(
    fn=generate_story,
    inputs="text",
    outputs="text",
    title="Story Generation",
    description="Generate a story based on the provided prompt.",
)

# Combined two-column app: image captioning on the left, story generation on
# the right. The CSS rule hides the page footer.
with gr.Blocks(css="footer{display:none !important}") as combined_interface:
    gr.Markdown(
        """
# The Seer's Legacy
Crafting tales from the fabric of imagination.
"""
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="filepath", label="Upload Image")
            image_output = gr.Textbox(label="Image Caption")
            image_btn = gr.Button("Generate Image Caption")
        with gr.Column():
            # ``gr.Textbox`` replaces the deprecated ``gr.inputs.Textbox``.
            text_input = gr.Textbox(label="Story Prompt")
            story_output = gr.Textbox(label="Generated Story")
            story_btn = gr.Button("Generate Story")
    # Event listeners are registered inside the Blocks context.
    image_btn.click(generate_caption, inputs=[image_input], outputs=[image_output])
    story_btn.click(generate_story, inputs=[text_input], outputs=[story_output])

# Launch the combined interface.
combined_interface.launch()