import io

import pandas as pd
import requests
import streamlit as st
import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline as transformers_pipeline,
)

# Load tokenizer/model pairs once at import time:
#  - the similarity classifier maps generated text to a label index, and
#  - the causal LM turns a prompt into a short story.
similarity_tokenizer = AutoTokenizer.from_pretrained("AidenYan/MiniLM_L6_v2_finetuned_ISOM5240_Group27")
similarity_model = AutoModelForSequenceClassification.from_pretrained("AidenYan/MiniLM_L6_v2_finetuned_ISOM5240_Group27")
story_tokenizer = AutoTokenizer.from_pretrained("pranavpsv/genre-story-generator-v2")
story_model = AutoModelForCausalLM.from_pretrained("pranavpsv/genre-story-generator-v2")

# Label -> image-URL lookup table; expected columns: 'Label', 'ImageURL'.
labels_df = pd.read_csv("labels_to_image_urls.csv")  # Make sure to update this path

# Shared image-captioning pipeline, created lazily on first use (see
# _get_captioning_pipeline). The original code rebuilt this pipeline —
# reloading the full ViT-GPT2 model — on every single call.
_image_to_text_pipeline = None


def _get_captioning_pipeline():
    """Return the module-level image-to-text pipeline, creating it on first use."""
    global _image_to_text_pipeline
    if _image_to_text_pipeline is None:
        _image_to_text_pipeline = transformers_pipeline(
            "image-to-text", model="nlpconnect/vit-gpt2-image-captioning"
        )
    return _image_to_text_pipeline


def image_to_text_from_url(image_url):
    """
    Generates a caption from the image at the given URL using an image-to-text pipeline.

    Args:
        image_url: URL of the image to caption (the pipeline fetches it itself).

    Returns:
        The generated caption string for the first (and only) result.
    """
    return _get_captioning_pipeline()(image_url)[0]['generated_text']


def generate_mask_from_result(input_text):
    """
    Placeholder for generating a mask from the result. This should be replaced
    with your actual logic.

    Currently just prefixes the text so downstream stages receive a non-empty,
    recognizable prompt.
    """
    # Placeholder logic, replace with actual text processing if needed
    return "Processed input: " + input_text


def generate_story_from_text(input_text):
    """
    Generates a story based on the input text using a causal language model.

    Args:
        input_text: Prompt to seed the story generator.

    Returns:
        The decoded story (special tokens stripped), at most 200 tokens long.
    """
    input_ids = story_tokenizer.encode(input_text, return_tensors='pt')
    story_outputs = story_model.generate(input_ids, max_length=200, num_return_sequences=1)
    return story_tokenizer.decode(story_outputs[0], skip_special_tokens=True)


def select_closest_sentence(generated_text):
    """
    Predicts the similarity label for the generated text using the similarity model.

    Args:
        generated_text: Text to classify.

    Returns:
        A string of the form "Label_<argmax index>" naming the predicted class.
    """
    inputs = similarity_tokenizer(generated_text, return_tensors="pt")
    # Inference only — disable autograd to avoid building a gradient graph.
    with torch.no_grad():
        outputs = similarity_model(**inputs)
    predictions = torch.softmax(outputs.logits, dim=1)
    label_id = predictions.argmax().item()
    return f"Label_{label_id}"


def get_image_url_for_label(label):
    """
    Returns the image URL for a given label from the labels dataframe.

    Args:
        label: Label string to look up in the 'Label' column.

    Returns:
        The matching 'ImageURL' value, or None when the label is unknown.
    """
    row = labels_df[labels_df['Label'] == label]
    if not row.empty:
        return row['ImageURL'].values[0]
    else:
        return None


def display_image_from_url(image_url):
    """
    Displays an image in the Streamlit app given its URL.

    Shows a Streamlit error message instead of raising when the download or
    decoding fails.
    """
    try:
        # Timeout so a dead host can't hang the app; raise_for_status so an
        # HTML error page is never handed to PIL as if it were image bytes.
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))
        st.image(image, use_column_width=True)
    except Exception as e:
        st.error(f"Failed to load image from URL: {e}")


def main():
    """Streamlit entry point: takes text or an image URL, generates a story,
    classifies it, and shows the image mapped to the predicted label."""
    st.title("SmartCart (Product Recommender)")
    # User input for text or URL
    input_option = st.radio("Select input option:", ("Text", "URL"))

    # Handling input via text
    if input_option == "Text":
        text_input = st.text_input("Enter the text:")
        if st.button("Generate Story and Image") and text_input:
            processed_text = generate_mask_from_result(text_input)
            story_text = generate_story_from_text(processed_text)
            st.text_area('Generated Story:', story_text, height=300)
            closest_label = select_closest_sentence(processed_text)
            image_url = get_image_url_for_label(closest_label)
            if image_url:
                display_image_from_url(image_url)

    # Handling input via image URL
    elif input_option == "URL":
        image_url = st.text_input("Enter the image URL:")
        if st.button("Generate Story and Image") and image_url:
            image_text = image_to_text_from_url(image_url)
            processed_text = generate_mask_from_result(image_text)
            story_text = generate_story_from_text(processed_text)
            st.text_area('Generated Story:', story_text, height=300)
            closest_label = select_closest_sentence(processed_text)
            mapped_image_url = get_image_url_for_label(closest_label)
            if mapped_image_url:
                display_image_from_url(mapped_image_url)


if __name__ == "__main__":
    main()