# Jim_Aiden / app.py — SmartCart Streamlit demo (Hugging Face Space, commit a648d47)
import streamlit as st
from transformers import pipeline as transformers_pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
import pandas as pd
import torch
import requests
from PIL import Image
import io
# Load tokenizer and models for similarity and story generation.
# NOTE: these downloads/loads run at import time, once per Streamlit process.
similarity_tokenizer = AutoTokenizer.from_pretrained("AidenYan/MiniLM_L6_v2_finetuned_ISOM5240_Group27")
similarity_model = AutoModelForSequenceClassification.from_pretrained("AidenYan/MiniLM_L6_v2_finetuned_ISOM5240_Group27")
story_tokenizer = AutoTokenizer.from_pretrained("pranavpsv/genre-story-generator-v2")
story_model = AutoModelForCausalLM.from_pretrained("pranavpsv/genre-story-generator-v2")
# Load the label -> image-URL mapping into a dataframe.
# Expected columns (used below): 'Label' and 'ImageURL'.
labels_df = pd.read_csv("labels_to_image_urls.csv") # Make sure to update this path
def image_to_text_from_url(image_url):
    """
    Generate a caption for the image at the given URL.

    Args:
        image_url: URL (or path) accepted by the HF image-to-text pipeline.

    Returns:
        The generated caption string.
    """
    # Build the captioning pipeline once and reuse it across calls: the
    # original code reconstructed it (reloading the full vit-gpt2 model)
    # on every request, which is needlessly slow.
    if not hasattr(image_to_text_from_url, "_pipeline"):
        image_to_text_from_url._pipeline = transformers_pipeline(
            "image-to-text", model="nlpconnect/vit-gpt2-image-captioning"
        )
    return image_to_text_from_url._pipeline(image_url)[0]['generated_text']
def generate_mask_from_result(input_text):
    """
    Placeholder processing step: tag the input with a fixed prefix.

    Replace this with the real masking/text-processing logic when available.
    """
    return f"Processed input: {input_text}"
def generate_story_from_text(input_text):
    """
    Produce a story continuation of *input_text* using the causal LM.

    Returns the decoded text (special tokens stripped), capped at 200 tokens.
    """
    encoded = story_tokenizer.encode(input_text, return_tensors='pt')
    generated = story_model.generate(encoded, max_length=200, num_return_sequences=1)
    story = story_tokenizer.decode(generated[0], skip_special_tokens=True)
    return story
def select_closest_sentence(generated_text):
    """
    Classify *generated_text* with the similarity model.

    Returns the winning class formatted as "Label_<id>", matching the
    'Label' column of the label-to-URL mapping.
    """
    encoded = similarity_tokenizer(generated_text, return_tensors="pt")
    logits = similarity_model(**encoded).logits
    # softmax is monotonic, so this argmax equals the argmax over raw logits;
    # probabilities are computed to mirror the original implementation.
    probabilities = torch.softmax(logits, dim=1)
    winner = probabilities.argmax().item()
    return f"Label_{winner}"
def get_image_url_for_label(label):
    """
    Look up the image URL mapped to *label* in the labels dataframe.

    Returns the first matching URL, or None when the label is unknown.
    """
    matches = labels_df.loc[labels_df['Label'] == label, 'ImageURL']
    if matches.empty:
        return None
    return matches.values[0]
def display_image_from_url(image_url):
    """
    Fetch the image at *image_url* and render it in the Streamlit app.

    On any failure (network error, HTTP error status, undecodable image),
    shows a Streamlit error message instead of raising.
    """
    try:
        # timeout= keeps a dead host from hanging the Streamlit script run;
        # raise_for_status() turns a 4xx/5xx response into a reported error
        # instead of handing an HTML error page's bytes to PIL.
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))
        st.image(image, use_column_width=True)
    except Exception as e:
        st.error(f"Failed to load image from URL: {e}")
def _render_story_and_image(seed_text):
    """Run the shared pipeline: process text, generate a story, classify it,
    and display the mapped image (if any)."""
    processed_text = generate_mask_from_result(seed_text)
    story_text = generate_story_from_text(processed_text)
    st.text_area('Generated Story:', story_text, height=300)
    closest_label = select_closest_sentence(processed_text)
    image_url = get_image_url_for_label(closest_label)
    if image_url:
        display_image_from_url(image_url)

def main():
    """Streamlit entry point: accept either free text or an image URL and
    render the generated story plus its recommended product image."""
    st.title("SmartCart (Product Recommender)")
    # User input for text or URL
    input_option = st.radio("Select input option:", ("Text", "URL"))
    # Handling input via text
    if input_option == "Text":
        text_input = st.text_input("Enter the text:")
        if st.button("Generate Story and Image") and text_input:
            _render_story_and_image(text_input)
    # Handling input via image URL
    elif input_option == "URL":
        image_url = st.text_input("Enter the image URL:")
        if st.button("Generate Story and Image") and image_url:
            # Caption the image first, then feed the caption through the
            # same pipeline used for plain text input.
            _render_story_and_image(image_to_text_from_url(image_url))

if __name__ == "__main__":
    main()