Spaces:
Runtime error
Runtime error
import streamlit as st | |
from transformers import pipeline as transformers_pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification | |
import pandas as pd | |
import torch | |
import requests | |
from PIL import Image | |
import io | |
# Hugging Face checkpoint identifiers used by this app.
_SIMILARITY_CHECKPOINT = "AidenYan/MiniLM_L6_v2_finetuned_ISOM5240_Group27"
_STORY_CHECKPOINT = "pranavpsv/genre-story-generator-v2"

# Sequence-classification model and its tokenizer; the predicted class id is
# what select_closest_sentence() turns into a "Label_<id>" string.
similarity_tokenizer = AutoTokenizer.from_pretrained(_SIMILARITY_CHECKPOINT)
similarity_model = AutoModelForSequenceClassification.from_pretrained(_SIMILARITY_CHECKPOINT)

# Causal language model and its tokenizer, used for story generation.
story_tokenizer = AutoTokenizer.from_pretrained(_STORY_CHECKPOINT)
story_model = AutoModelForCausalLM.from_pretrained(_STORY_CHECKPOINT)

# Lookup table read by get_image_url_for_label(); it is queried through its
# 'Label' and 'ImageURL' columns.
labels_df = pd.read_csv("labels_to_image_urls.csv")  # Make sure to update this path
@st.cache_resource(show_spinner=False)
def _load_image_captioning_pipeline():
    """Build the image-captioning pipeline once and reuse it afterwards.

    Constructing the pipeline loads a full model, which is far too expensive
    to repeat for every caption request. ``st.cache_resource`` keeps the
    returned object alive across Streamlit script reruns.
    """
    return transformers_pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")


def image_to_text_from_url(image_url):
    """
    Generates a caption from the image at the given URL using an image-to-text pipeline.

    Args:
        image_url: Location of the image, passed straight to the pipeline.

    Returns:
        The ``'generated_text'`` field of the first result the pipeline
        returns for the image.
    """
    captioner = _load_image_captioning_pipeline()
    results = captioner(image_url)
    return results[0]['generated_text']
def generate_mask_from_result(input_text):
    """
    Stand-in text-processing step: prefix the input with a fixed marker.

    This is placeholder logic and is meant to be replaced with the real
    processing once it exists.

    Args:
        input_text: The string to process.

    Returns:
        ``"Processed input: "`` followed by ``input_text``.
    """
    prefix = "Processed input: "
    return prefix + input_text
def generate_story_from_text(input_text):
    """
    Produce a story that starts from ``input_text`` using the causal language model.

    Args:
        input_text: Prompt text that seeds the generation.

    Returns:
        The first generated sequence decoded back to a string, with special
        tokens removed.
    """
    prompt_ids = story_tokenizer.encode(input_text, return_tensors='pt')
    generated_sequences = story_model.generate(
        prompt_ids,
        max_length=200,
        num_return_sequences=1,
    )
    # Only one sequence is requested, so the first entry is the story.
    first_sequence = generated_sequences[0]
    return story_tokenizer.decode(first_sequence, skip_special_tokens=True)
def select_closest_sentence(generated_text):
    """
    Predicts the similarity label for the generated text using the similarity model.

    Args:
        generated_text: The text to classify.

    Returns:
        A string of the form ``"Label_<id>"`` where ``<id>`` is the index of
        the highest-scoring class for the text.
    """
    # Truncate so that over-long text cannot exceed the model's maximum
    # sequence length and fail inside the forward pass.
    inputs = similarity_tokenizer(generated_text, return_tensors="pt", truncation=True)

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = similarity_model(**inputs)

    # Softmax is monotonic, so the arg-max of the raw logits already
    # identifies the most probable class; take it for the single input row.
    label_id = outputs.logits.argmax(dim=1)[0].item()
    return f"Label_{label_id}"
def get_image_url_for_label(label):
    """
    Look up the image URL associated with ``label`` in the labels dataframe.

    Args:
        label: Value to match against the dataframe's 'Label' column.

    Returns:
        The 'ImageURL' value of the first matching row, or ``None`` when no
        row matches.
    """
    matching_rows = labels_df.loc[labels_df['Label'] == label]
    if matching_rows.empty:
        return None
    return matching_rows['ImageURL'].values[0]
def display_image_from_url(image_url):
    """
    Displays an image in the Streamlit app given its URL.

    The image is downloaded, decoded with PIL and rendered with ``st.image``.
    Any failure along the way is reported in the app through ``st.error``
    instead of being raised.

    Args:
        image_url: URL of the image to download and show.
    """
    try:
        # A timeout keeps an unresponsive host from hanging the app forever.
        response = requests.get(image_url, timeout=10)
        # Surface HTTP errors directly rather than handing an error page to
        # PIL and getting a confusing "cannot identify image file" message.
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))
        st.image(image, use_column_width=True)
    except Exception as e:
        # UI boundary: show the problem to the user and keep the app running.
        st.error(f"Failed to load image from URL: {e}")
def _show_story_and_image(processed_text):
    """Render the generated story and, when one is mapped, the matching image."""
    story_text = generate_story_from_text(processed_text)
    st.text_area('Generated Story:', story_text, height=300)

    closest_label = select_closest_sentence(processed_text)
    mapped_image_url = get_image_url_for_label(closest_label)
    if mapped_image_url:
        display_image_from_url(mapped_image_url)


def main():
    """
    Streamlit entry point for the SmartCart app.

    The user picks either free text or an image URL as input. For a URL the
    image is captioned first. The resulting text is then processed, turned
    into a story, classified into a label, and the image mapped to that
    label (if any) is displayed.
    """
    st.title("SmartCart (Product Recommender)")

    # User input for text or URL
    input_option = st.radio("Select input option:", ("Text", "URL"))

    # Handling input via text
    if input_option == "Text":
        text_input = st.text_input("Enter the text:")
        # The button is evaluated first so it is always rendered, even while
        # the text box is still empty.
        if st.button("Generate Story and Image") and text_input:
            _show_story_and_image(generate_mask_from_result(text_input))

    # Handling input via image URL
    elif input_option == "URL":
        image_url = st.text_input("Enter the image URL:")
        if st.button("Generate Story and Image") and image_url:
            image_text = image_to_text_from_url(image_url)
            _show_story_and_image(generate_mask_from_result(image_text))


if __name__ == "__main__":
    main()