Spaces:
Sleeping
Sleeping
import streamlit as st | |
from PIL import Image | |
from itertools import cycle | |
import transformers | |
import pandas as pd | |
import numpy as np | |
import os | |
import torch | |
import skimage | |
import requests | |
import numpy as np | |
import pandas as pd | |
from PIL import Image | |
from io import BytesIO | |
from datasets import load_dataset | |
from collections import OrderedDict | |
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer | |
from sklearn.metrics.pairwise import cosine_similarity | |
import random | |
# from hugchat import hugchat | |
# from hugchat.login import Login | |
def get_model_info(model_ID, device):
    """Load a CLIP checkpoint together with its processor and tokenizer.

    Args:
        model_ID: Hub id or local directory handed to every ``from_pretrained``.
        device: Torch device string the model is moved to ("cpu" or "cuda").

    Returns:
        A ``(model, processor, tokenizer)`` tuple.
    """
    clip_model = CLIPModel.from_pretrained(model_ID)
    clip_model = clip_model.to(device)
    clip_processor = CLIPProcessor.from_pretrained(model_ID)
    clip_tokenizer = CLIPTokenizer.from_pretrained(model_ID)
    return clip_model, clip_processor, clip_tokenizer
# set_page_config must be the first Streamlit command of every script run, so
# it comes before anything that could render (including a cache spinner).
st.set_page_config(page_title="ππ¬ MuseChat")  # NOTE(review): title emoji looks mis-encoded

device = "cuda" if torch.cuda.is_available() else "cpu"
model_ID = "./fashion-clip"


@st.cache_resource(show_spinner=False)
def _load_clip(model_id, device_name):
    """Load the CLIP model, processor and tokenizer once per server process.

    Streamlit re-executes this script on every user interaction; without the
    cache the checkpoint was re-read from disk and re-moved to the device on
    each rerun. Cached resources are shared by all sessions, so treat the
    returned objects as read-only.
    """
    return get_model_info(model_id, device_name)


model, processor, tokenizer = _load_clip(model_ID, device)
# Product catalogue read by generate_response (columns "image_path" and
# "img_embeddings"). Pickle files can execute code on load: only ship a
# trusted, locally produced file here.
data = pd.read_pickle("./code/image_info.pkl")
def random_choice_gen():
    """Return a pseudo-random placeholder price.

    The value is drawn uniformly from the grid 50.25, 54.75, ..., 149.25
    (``np.arange(50.25, 150.75, 4.5)``) using NumPy's global RNG, so results
    are only reproducible if ``np.random`` is seeded.

    Returns:
        numpy.float64: the sampled price.
    """
    price_grid = np.arange(50.25, 150.75, 4.5)
    (price,) = np.random.choice(price_grid, size=1)
    return price
# Hugging Face credentials, collected in the sidebar. The chat input further
# down is disabled until both fields are non-empty.
with st.sidebar:
    st.title('ππ¬ MuseChat')  # NOTE(review): title emoji looks mis-encoded; restore the intended one
    hf_email = st.text_input('Enter E-mail:', type='password')
    hf_pass = st.text_input('Enter password:', type='password')
    if not (hf_email and hf_pass):
        # Streamlit validates `icon` and raises for anything that is not a
        # single emoji, so the mis-encoded text previously crashed the page.
        st.warning('Please enter your credentials!', icon='⚠️')
    else:
        st.success('Proceed to entering your prompt message!', icon='👉')
    st.markdown('Interact with MuseChat!')
# Seed the chat history once per browser session. The opening assistant entry
# is a pair of identical greeting strings; the history replay below recognises
# it by comparing the second element against the greeting text.
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "assistant", "content": ("How may I help you?",) * 2}
    ]
def generate_output(filteredImages, caption):
    """Render images in a three-column grid, each with its caption.

    Images fill the columns left to right and wrap around every three items.

    Args:
        filteredImages: sequence of images accepted by ``st.image`` (in this
            app, the file paths returned by ``generate_response``).
        caption: captions matched to ``filteredImages`` by position; it must be
            at least as long as ``filteredImages``.
    """
    grid = st.columns(3)
    for position, picture in enumerate(filteredImages):
        grid[position % 3].image(picture, width=150, caption=caption[position])
# Every earlier user message, in order; later code prepends the most recent
# ones to a new query so follow-up requests refine the previous search.
old_context = []

# Replay the stored conversation on each rerun.
for message in st.session_state.messages:
    role = message["role"]
    content = message["content"]
    with st.chat_message(role):
        if role == "user":
            st.write(content)
            old_context.append(content)
        elif content[1] == "How may I help you?":
            # Opening greeting: a (text, text) pair, show it as plain text.
            st.write(content[0])
        else:
            # Search result: (image paths, price captions).
            generate_output(content[0], content[1])
def get_single_text_embedding(text):
    """Encode ``text`` with the module-level CLIP model.

    Args:
        text: the query string.

    Returns:
        numpy.ndarray: the CLIP text features with one row per input text
        (a single row for one string), on the CPU.
    """
    # CLIP's text encoder has a fixed number of position embeddings; longer
    # inputs (e.g. a prompt with earlier messages prepended) fail in the
    # forward pass, so truncate to the model's limit.
    max_len = model.config.text_config.max_position_embeddings
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_len)
    # The tokenizer returns CPU tensors, but the model may live on CUDA.
    inputs = inputs.to(model.device)
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        text_embeddings = model.get_text_features(**inputs)
    return text_embeddings.cpu().numpy()
def correct_paths(a):
    """Map a stored image path to its location under ./fashion-dataset.

    Only the final "/"-separated component (the file name) is kept.

    - File names of exactly 9 characters (e.g. "12345.jpg") resolve to
      "./fashion-dataset/images<first char>/<name>". There is no separator
      between "images" and that character, e.g. "images1/12345.jpg".
      Presumably the dataset is sharded by leading digit; confirm against
      the on-disk layout.
    - Every other name resolves to "./fashion-dataset/images/<name>".
    """
    file_name = a.rsplit("/", 1)[-1]
    if len(file_name) == 9:
        return f"./fashion-dataset/images{file_name[0]}/{file_name}"
    return f"./fashion-dataset/images/{file_name}"
def generate_response(prompt_input, email, passwd, top_k=6):
    """Find the catalogue images whose CLIP embeddings best match a text query.

    Args:
        prompt_input: free-text query.
        email: accepted for interface compatibility; not used here.
        passwd: accepted for interface compatibility; not used here.
        top_k: number of results to return (default 6, the former hard-coded
            value).

    Returns:
        ``(image_list, resp)``:

        - ``image_list`` holds the local paths (via ``correct_paths``) of the
          ``top_k`` rows of ``data`` with the highest cosine similarity,
          best first.
        - ``resp`` holds one "Price : $ <value>" caption per image, with a
          placeholder price from ``random_choice_gen``.

    Side effects:
        Stores every row's similarity in ``data["cos_sim"]``.
    """
    if data.empty:
        return [], []

    text_embeddings = get_single_text_embedding(prompt_input)
    # Score the whole catalogue in one cosine_similarity call instead of one
    # call per row. np.vstack accepts per-row embeddings stored as (1, D) or
    # (D,) arrays (assumed layout of image_info.pkl; confirm if it changes).
    image_matrix = np.vstack(data["img_embeddings"].to_list())
    data["cos_sim"] = cosine_similarity(text_embeddings, image_matrix)[0]

    best_matches = data.sort_values(by="cos_sim", ascending=False).head(top_k)
    # Build the paths from the selection directly rather than adding a column
    # to a slice of `data` (which triggers SettingWithCopyWarning).
    image_list = [correct_paths(path) for path in best_matches["image_path"]]
    resp = ["Price : $ " + str(random_choice_gen()) for _ in image_list]
    return image_list, resp
flag_refresh = 0

# User-provided prompt (the input stays disabled until credentials are entered).
if prompt := st.chat_input(disabled=not (hf_email and hf_pass)):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)
    print(prompt)
    print("old_context", old_context)
    # A greeting or "refresh" starts a new search conversation: earlier prompts
    # stop being prepended to the query. Matching is case-insensitive (the
    # original list spelled out "hi", "Hi" and "HI" by hand).
    if prompt.strip().casefold() in {"hi", "how are you?", "refresh"}:
        flag_refresh = len(old_context)
        # old_context is rebuilt from the full history on every rerun, so the
        # reset must be persisted. Skip all user messages up to and including
        # this one (it will sit at index len(old_context) on the next run).
        st.session_state["context_offset"] = len(old_context) + 1

# Drop user messages that precede the latest reset. On the reset run itself
# the offset is past the end of the list, so the context is empty.
old_context = old_context[st.session_state.get("context_offset", 0):]
# Generate a new response if the last message is not from the assistant.
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            if prompt is None:
                # No new input on this rerun (e.g. the previous run failed
                # after saving the user's message): answer the stored message.
                # The history replay already put it into old_context, so drop
                # it there to avoid repeating it in the query.
                prompt = st.session_state.messages[-1]["content"]
                if old_context and old_context[-1] == prompt:
                    old_context = old_context[:-1]
            # Prepend up to three earlier user messages so follow-ups refine
            # the previous search; with no history this is just the prompt.
            prompt = " ".join([*old_context[-3:], prompt])
            print(prompt)
            response, caption = generate_response(prompt, hf_email, hf_pass)
            generate_output(response, caption)
    message = {"role": "assistant", "content": (response, caption)}
    st.session_state.messages.append(message)