File size: 1,875 Bytes
114578c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d9e246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import streamlit as st
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration, GPT2LMHeadModel, GPT2Tokenizer

# Model loading.
# Streamlit re-executes this whole script from the top on every widget
# interaction (each upload, each button click). Loading the checkpoints at
# bare module level therefore re-loaded every model on every rerun.
# st.cache_resource keeps each loader's return value alive across reruns,
# so each model is loaded only once per server process.

BLIP_MODEL_ID = "Salesforce/blip-image-captioning-large"
GPT2_MODEL_ID = "gpt2"


@st.cache_resource
def _load_blip():
    """Load the BLIP processor and captioning model.

    Returns:
        A ``(processor, model)`` tuple for BLIP_MODEL_ID.
    """
    blip_processor = BlipProcessor.from_pretrained(BLIP_MODEL_ID)
    blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_MODEL_ID)
    return blip_processor, blip_model


@st.cache_resource
def _load_gpt2():
    """Load the GPT-2 tokenizer and language model used for story generation.

    Returns:
        A ``(tokenizer, model)`` tuple for GPT2_MODEL_ID.
    """
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained(GPT2_MODEL_ID)
    gpt2_model = GPT2LMHeadModel.from_pretrained(GPT2_MODEL_ID)
    return gpt2_tokenizer, gpt2_model


# Keep the original module-level names so the rest of the script is unchanged.
processor, model = _load_blip()
tokenizer_gpt2, model_gpt2 = _load_gpt2()

# Streamlit app: upload an image, caption it with BLIP, and (on request)
# expand the caption into a short story with GPT-2.
st.title("Image to Story Generator")

# Uploading the image
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_image is not None:
    # Normalise to RGB so PNGs with an alpha channel / palette images are
    # accepted by the BLIP processor.
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Generate image caption.
    # NOTE: Streamlit reruns the whole script on every interaction, so this
    # also executes again when the story button below is clicked.
    st.write("Generating caption...")
    inputs = processor(image, return_tensors="pt")
    out = model.generate(**inputs)
    caption = processor.decode(out[0], skip_special_tokens=True)
    st.write(f"Caption: {caption}")

    # Generate story from caption -- only when the user asks for it.
    # All of the GPT-2 work must live inside this branch: st.button() is
    # False on every run except the one triggered by the click, and running
    # the generation unconditionally referenced an unbound ``story_prompt``.
    if st.button('Generate Story from Caption'):
        st.write("Generating story...")
        story_prompt = f"Based on the image, here's a story: {caption}"

        # Tokenize the prompt; calling the tokenizer (rather than .encode)
        # also yields an attention_mask, which generate() needs because
        # pad_token_id is set to the EOS token below.
        encoded_prompt = tokenizer_gpt2(story_prompt, return_tensors='pt')

        # Generate text using GPT-2 model.
        # do_sample=True is required for temperature/top_k/top_p to have any
        # effect; without it generate() decodes greedily and ignores them.
        story = model_gpt2.generate(
            **encoded_prompt,
            max_length=200,  # counts prompt tokens as well as generated ones
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.9,
            top_k=50,
            top_p=0.95,
            # GPT-2 has no pad token; reuse EOS to silence the padding warning.
            pad_token_id=tokenizer_gpt2.eos_token_id,
        )

        # Decode and display the story
        story_text = tokenizer_gpt2.decode(story[0], skip_special_tokens=True)
        st.text_area("Generated Story", story_text, height=250)