|
import streamlit as st |
|
import numpy as np |
|
|
|
from st_btn_select import st_btn_select |
|
from streamlit_option_menu import option_menu |
|
|
|
from cgi import test |
|
import streamlit as st |
|
import pandas as pd |
|
from PIL import Image |
|
import os |
|
import glob |
|
|
|
from transformers import CLIPVisionModel, AutoTokenizer, AutoModel |
|
from transformers import ViTFeatureExtractor, ViTForImageClassification |
|
|
|
import torch |
|
from tqdm import tqdm |
|
from PIL import Image |
|
import numpy as np |
|
from torch.utils.data import DataLoader |
|
from transformers import default_data_collator |
|
|
|
from torch.utils.data import Dataset, DataLoader |
|
import torchvision.transforms as transforms |
|
|
|
from bokeh.models.widgets import Button |
|
from bokeh.models import CustomJS |
|
from streamlit_bokeh_events import streamlit_bokeh_events |
|
|
|
from webcam import webcam |
|
|
|
|
|
# Directory of playable audio clips; CLIPDemo.image_search builds result paths under it.
MP3_ROOT_PATH = "sample_mp3/"
# Directory globbed for the "*.jpeg" spectrogram images that get embedded.
SPECTROGRAMS_PATH = "sample_spectrograms/"

# Side length (pixels) every spectrogram is resized to by VisionDataset.
IMAGE_SIZE = 224
# Per-channel normalization statistics applied by VisionDataset.preprocess.
# NOTE(review): these appear to be the OpenAI CLIP image statistics — confirm
# they match what the fine-tuned vision encoder was trained with.
MEAN = torch.tensor((0.48145466, 0.4578275, 0.40821073))
STD = torch.tensor((0.26862954, 0.26130258, 0.27577711))

# Identifier passed to AutoTokenizer.from_pretrained for the text tokenizer.
TEXT_MODEL = 'bert-base-uncased'

# Local checkpoint directories for the fine-tuned text and vision encoders.
CLIP_TEXT_MODEL_PATH = "text_model/"
CLIP_VISION_MODEL_PATH = "vision_model/"
|
|
|
|
|
def streamlit_menu(example=1):
    """Render the page-navigation menu and return the option the user picked.

    Layouts:

    * ``1`` - a vertical menu titled "Main Menu", placed in the sidebar;
    * ``2`` - an untitled horizontal menu in the main area;
    * ``3`` - layout ``2`` with custom colors, font sizes and hover styling.

    Args:
        example: Which layout to render (1, 2 or 3).

    Returns:
        The value returned by ``option_menu`` — one of ``"Text"``,
        ``"Audio"`` or ``"Camera"``.

    Raises:
        ValueError: If ``example`` is not 1, 2 or 3.
    """
    # Everything except title, orientation and styling is shared by all layouts.
    menu_kwargs = dict(
        options=["Text", "Audio", "Camera"],
        icons=["chat-text", "mic", "camera"],
        menu_icon="cast",
        default_index=0,
    )

    if example == 1:
        with st.sidebar:
            return option_menu(menu_title="Main Menu", **menu_kwargs)

    if example == 2:
        return option_menu(menu_title=None, orientation="horizontal", **menu_kwargs)

    if example == 3:
        return option_menu(
            menu_title=None,
            orientation="horizontal",
            styles={
                "container": {"padding": "0!important", "background-color": "#fafafa"},
                "icon": {"color": "#ffde59", "font-size": "25px"},
                "nav-link": {
                    "font-size": "25px",
                    "text-align": "left",
                    "margin": "0px",
                    "--hover-color": "#eee",
                },
                "nav-link-selected": {"background-color": "#5271ff"},
            },
            **menu_kwargs,
        )

    # Fail loudly instead of silently returning None (which renders no page).
    raise ValueError(f"Unsupported menu layout: example={example!r} (expected 1, 2 or 3)")
|
|
|
|
|
|
|
def draw_sidebar(
    key,
    plot=False,
):
    """Render a demo sidebar: a markdown heading, a rating slider and a button group.

    Args:
        key: Streamlit widget key given to the rating slider.
        plot: Unused by this function.
    """
    heading = """
        # Sidebar

        ```python
        Think.
        Search.
        Feel.
        ```
        """
    st.write(heading)

    st.slider(
        "From 1 to 10, how cool is this app?",
        min_value=1,
        max_value=10,
        key=key,
    )

    # Three-way button selector, initially highlighting the last choice.
    choices = ('option1', 'option2', 'option3')
    picked = st_btn_select(choices, index=2)
    st.write(f'Selected option: {picked}')
|
|
|
|
|
|
|
|
|
|
|
class VisionDataset(Dataset):
    """Map-style dataset yielding normalized image tensors for a list of paths.

    Each item is the image at ``image_paths[idx]``, converted to RGB, resized
    to ``IMAGE_SIZE`` x ``IMAGE_SIZE``, turned into a tensor and normalized
    with ``MEAN`` / ``STD``.
    """

    # Class-level transform pipeline shared by all instances.
    preprocess = transforms.Compose(
        [
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=MEAN, std=STD),
        ]
    )

    def __init__(self, image_paths: list):
        self.image_paths = image_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        rgb_image = Image.open(path).convert('RGB')
        return self.preprocess(rgb_image)
|
|
|
|
|
class TextDataset(Dataset):
    """Map-style dataset of texts tokenized once up front to a fixed length.

    Every text is padded/truncated to ``max_len`` tokens at construction time.
    Items are ``{'input_ids': [...], 'attention_mask': [...]}`` dicts, which
    ``CLIPDemo.compute_text_embeddings`` batches with ``default_data_collator``.

    Args:
        text: The strings to tokenize.
        tokenizer: Hugging Face-style tokenizer callable.
        max_len: Length each sequence is padded or truncated to.
    """

    def __init__(self, text: list, tokenizer, max_len):
        self.len = len(text)
        self.tokens = tokenizer(text, padding='max_length',
                                max_length=max_len, truncation=True)

    def __getitem__(self, idx):
        # Look up per-key lists rather than ``self.tokens[idx]``: integer
        # indexing of a BatchEncoding is only available for *fast* tokenizers,
        # whereas key lookup works for fast and slow tokenizers alike and
        # yields the same ids/mask values.
        return {
            'input_ids': self.tokens['input_ids'][idx],
            'attention_mask': self.tokens['attention_mask'][idx],
        }

    def __len__(self):
        return self.len
|
|
|
|
|
class CLIPDemo:
    """Text-to-music search over embeddings of spectrogram images.

    Wraps a fine-tuned vision encoder (applied to spectrogram images) and a
    fine-tuned text encoder. A free-text query is embedded and compared, by
    cosine similarity, with a pre-computed bank of image embeddings.

    Example::

        demo = CLIPDemo(vision_encoder, text_encoder, tokenizer)
        demo.compute_image_embeddings(image_paths)
        song_paths = demo.image_search("a happy song")
    """

    def __init__(self, vision_encoder, text_encoder, tokenizer,
                 batch_size: int = 64, max_len: int = 64, device='cuda'):
        """Put both encoders in eval mode on ``device`` and store settings.

        Args:
            vision_encoder: Fine-tuned vision encoder; called with
                ``pixel_values=`` and read via ``.pooler_output``.
            text_encoder: Fine-tuned text encoder; called with ``input_ids=``
                and ``attention_mask=`` and read via ``.pooler_output``.
            tokenizer: Tokenizer matching ``text_encoder``.
            batch_size: Mini-batch size used when computing embeddings.
            max_len: Tokenizer max length used by ``compute_text_embeddings``.
            device: Device (string or ``torch.device``) the encoders run on.
        """
        self.vision_encoder = vision_encoder.eval().to(device)
        self.text_encoder = text_encoder.eval().to(device)
        self.batch_size = batch_size
        self.device = device
        self.tokenizer = tokenizer
        self.max_len = max_len
        # Populated by compute_text_embeddings / compute_image_embeddings.
        self.text = None
        self.text_embeddings_ = None
        self.image_paths = None
        self.image_embeddings_ = None

    def compute_image_embeddings(self, image_paths: list):
        """Embed every image in ``image_paths`` and cache the result.

        Shows a Streamlit progress bar while batches are processed. Stores
        the paths in ``self.image_paths`` and the stacked embeddings (one row
        per path, in the same order) in ``self.image_embeddings_``.

        Raises:
            ValueError: If ``image_paths`` is empty.
        """
        if not image_paths:
            raise ValueError("compute_image_embeddings() received no image paths")
        self.image_paths = image_paths
        dataloader = DataLoader(VisionDataset(image_paths=image_paths),
                                batch_size=self.batch_size)
        n_batches = len(dataloader)
        embeddings = []
        bar = st.progress(0)
        with torch.no_grad():
            for i, images in enumerate(tqdm(dataloader, desc='computing image embeddings')):
                image_embedding = self.vision_encoder(
                    pixel_values=images.to(self.device)).pooler_output
                embeddings.append(image_embedding)
                # Report (i + 1) batches done so the bar reaches 100% at the end.
                bar.progress(int((i + 1) / n_batches * 100))
        bar.empty()
        self.image_embeddings_ = torch.cat(embeddings)

    def compute_text_embeddings(self, text: list):
        """Embed every string in ``text`` and cache it in ``self.text_embeddings_``.

        Raises:
            ValueError: If ``text`` is empty.
        """
        if not text:
            raise ValueError("compute_text_embeddings() received no text")
        self.text = text
        dataloader = DataLoader(
            TextDataset(text=text, tokenizer=self.tokenizer, max_len=self.max_len),
            batch_size=self.batch_size, collate_fn=default_data_collator)
        embeddings = []
        with torch.no_grad():
            for tokens in tqdm(dataloader, desc='computing text embeddings'):
                text_embedding = self.text_encoder(
                    input_ids=tokens["input_ids"].to(self.device),
                    attention_mask=tokens["attention_mask"].to(self.device)).pooler_output
                embeddings.append(text_embedding)
        self.text_embeddings_ = torch.cat(embeddings)

    def text_query_embedding(self, query: str = 'A happy song'):
        """Return the text-encoder ``pooler_output`` for a single query string.

        The query is tokenized as a batch of one (presumably giving a
        ``(1, dim)`` tensor — confirm against the encoder). Truncation is
        enabled so an over-long query cannot exceed the tokenizer's model
        maximum length.
        """
        tokens = self.tokenizer(query, return_tensors='pt', truncation=True)
        with torch.no_grad():
            text_embedding = self.text_encoder(
                input_ids=tokens["input_ids"].to(self.device),
                attention_mask=tokens["attention_mask"].to(self.device)).pooler_output
        return text_embedding

    def most_similars(self, embeddings_1, embeddings_2):
        """Rank ``embeddings_1`` by cosine similarity to ``embeddings_2``.

        The two tensors are compared with ``torch.cosine_similarity`` (e.g.
        an ``(N, D)`` bank against a ``(1, D)`` query). Returns
        ``(values, indices)`` sorted by decreasing similarity, moved to CPU.
        """
        values, indices = torch.cosine_similarity(
            embeddings_1, embeddings_2).sort(descending=True)
        return values.cpu(), indices.cpu()

    def image_search(self, query: str, top_k=10):
        """Return audio paths for the ``top_k`` spectrograms most similar to ``query``.

        Each spectrogram filename stem is parsed as an integer track id. The
        result path is ``MP3_ROOT_PATH`` joined with that id zero-padded to
        six digits, without a file extension.

        Args:
            query: Free-text description of the desired music.
            top_k: Maximum number of results to return.

        Returns:
            List of audio paths, ordered from most to least similar.

        Raises:
            RuntimeError: If ``compute_image_embeddings`` has not been called.
        """
        if self.image_embeddings_ is None:
            raise RuntimeError("call compute_image_embeddings() before image_search()")
        query_embedding = self.text_query_embedding(query=query)
        _, indices = self.most_similars(self.image_embeddings_, query_embedding)

        songs_path = []
        for idx in indices[:top_k].tolist():
            filename = os.path.basename(self.image_paths[idx])
            track_id = int(os.path.splitext(filename)[0])
            # os.path.join avoids the doubled "/" that string concatenation
            # produced, since MP3_ROOT_PATH already ends with a slash.
            songs_path.append(os.path.join(MP3_ROOT_PATH, f"{track_id:06d}"))
        return songs_path
|
|
|
|
|
def draw_text(
    key,
    plot=False,
    device=None,
):
    """Render the "Text" page: find songs from form choices or a free-text description.

    On first use, loads the encoders and tokenizer and embeds up to 1000
    spectrograms. The resulting ``CLIPDemo`` is stored in
    ``st.session_state["model"]`` and reused afterwards. Song titles and
    artists are looked up in the module-level ``df`` metadata table.

    Args:
        key: Unused by this function.
        plot: Unused by this function.
        device: Device the encoders are moved to.
    """
    image = Image.open("data/logo.png")
    st.image(image, use_column_width="always")

    if 'model' not in st.session_state:
        text_encoder = AutoModel.from_pretrained(CLIP_TEXT_MODEL_PATH, local_files_only=True)
        vision_encoder = CLIPVisionModel.from_pretrained(CLIP_VISION_MODEL_PATH, local_files_only=True).to(device)
        tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)
        model = CLIPDemo(vision_encoder=vision_encoder, text_encoder=text_encoder, tokenizer=tokenizer, device=device)
        # glob order is arbitrary; sort so the same 1000 spectrograms are indexed every run.
        spectrograms = sorted(glob.glob(os.path.join(SPECTROGRAMS_PATH, "*.jpeg")))[:1000]
        model.compute_image_embeddings(spectrograms)
        st.session_state["model"] = model

    # NOTE(review): bare literals are presumably rendered by Streamlit "magic"
    # as vertical spacers — keep them as standalone expressions.
    ""
    ""

    moods = ['-', 'angry', 'calm', 'happy', 'sad']
    genres = ['-', 'house', 'pop', 'rock', 'techno']
    artists = ['-', 'bad dad', 'lazy magnet', 'the astronauts', 'yan yalego']
    years = ['-', '80s', '90s', '2000s', '2010s']

    col1, col2 = st.columns(2)
    mood = col1.selectbox('Which mood do you feel right now?', moods, help="Select a mood here")
    genre = col2.selectbox('Which genre do you want to listen?', genres, help="Select a genre here")
    artist = col1.selectbox('Which artist do you like best?', artists, help="Select an artist here")
    year = col2.selectbox('Which period do you want to relive?', years, help="Select a period here")
    button_form = st.button('Search', key="button_form")

    st.text_input("Otherwise, describe the song you are looking for!", value="", key="sentence")
    button_sentence = st.button('Search', key="button_sentence")

    # Form choices other than the "-" placeholder, in mood/genre/artist/year order.
    form_terms = [term for term in (mood, genre, artist, year) if term != "-"]

    if button_sentence and st.session_state.sentence != "":
        sentence = st.session_state.sentence
    elif button_form and form_terms:
        sentence = " ".join(form_terms)
    else:
        return

    song_paths = st.session_state.model.image_search(sentence)
    for song in song_paths:
        # image_search returns paths ending in a six-digit track id.
        track_id = int(song[-6:])
        matches = df.loc[df['track_id'] == track_id]
        if matches.empty:
            # No metadata row for this track: skip it instead of raising IndexError.
            continue
        song_name = matches['track_title'].iloc[0]
        artist_name = matches['artist_name'].iloc[0]
        st.write(f'**"{song_name}"** by {artist_name}')
        st.audio(song + ".ogg", format="audio/ogg", start_time=0)
|
|
|
|
|
def draw_audio(
    key,
    plot=False,
    device=None,
):
    """Render the "Audio" page: dictate a song description and get matching songs.

    On first use, loads the encoders and tokenizer and embeds up to 1000
    spectrograms. The resulting ``CLIPDemo`` is stored in
    ``st.session_state["model"]``. A Bokeh button runs browser speech
    recognition, and the transcript is used as the search query. Song titles
    and artists are looked up in the module-level ``df`` metadata table.

    Args:
        key: Unused by this function.
        plot: Unused by this function.
        device: Device the encoders are moved to.
    """
    image = Image.open("data/logo.png")
    st.image(image, use_column_width="always")

    if 'model' not in st.session_state:
        text_encoder = AutoModel.from_pretrained(CLIP_TEXT_MODEL_PATH, local_files_only=True)
        vision_encoder = CLIPVisionModel.from_pretrained(CLIP_VISION_MODEL_PATH, local_files_only=True).to(device)
        tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)
        model = CLIPDemo(vision_encoder=vision_encoder, text_encoder=text_encoder, tokenizer=tokenizer, device=device)
        # glob order is arbitrary; sort so the same 1000 spectrograms are indexed every run.
        spectrograms = sorted(glob.glob(os.path.join(SPECTROGRAMS_PATH, "*.jpeg")))[:1000]
        model.compute_image_embeddings(spectrograms)
        st.session_state["model"] = model

    # NOTE(review): bare literals are presumably rendered by Streamlit "magic"
    # as vertical spacers — keep them as standalone expressions.
    ""
    ""

    st.write("Please, describe the kind of song you are looking for!")
    stt_button = Button(label="Start Recording", margin=[5,5,5,200], width=200, default_size=10, width_policy='auto', button_type='primary')

    # On click, the browser's webkitSpeechRecognition transcribes speech and the
    # final transcript is dispatched as a "GET_TEXT" DOM event.
    stt_button.js_on_event("button_click", CustomJS(code="""
        var recognition = new webkitSpeechRecognition();
        recognition.continuous = false;
        recognition.interimResults = true;

        recognition.onresult = function (e) {
            var value = "";
            for (var i = e.resultIndex; i < e.results.length; ++i) {
                if (e.results[i].isFinal) {
                    value += e.results[i][0].transcript;
                }
            }
            if ( value != "") {
                document.dispatchEvent(new CustomEvent("GET_TEXT", {detail: value}));
            }
        }
        recognition.start();
        """))

    # Forwards the "GET_TEXT" event payload back to Python.
    result = streamlit_bokeh_events(
        stt_button,
        events="GET_TEXT",
        key="listen",
        refresh_on_update=False,
        override_height=75,
        debounce_time=0)

    if not result or "GET_TEXT" not in result:
        return

    sentence = result.get("GET_TEXT")
    st.write('You asked for: "' + sentence + '"')

    song_paths = st.session_state.model.image_search(sentence)
    for song in song_paths:
        # image_search returns paths ending in a six-digit track id.
        track_id = int(song[-6:])
        matches = df.loc[df['track_id'] == track_id]
        if matches.empty:
            # No metadata row for this track: skip it instead of raising IndexError.
            continue
        song_name = matches['track_title'].iloc[0]
        artist_name = matches['artist_name'].iloc[0]
        st.write(f'**"{song_name}"** by {artist_name}')
        st.audio(song + ".ogg", format="audio/ogg", start_time=0)
|
|
|
|
|
def draw_camera(
    key,
    plot=False,
    device=None,
):
    """Render the "Camera" page: classify the user's mood from a webcam snapshot
    and recommend songs for it.

    On first use, loads the encoders and tokenizer and embeds up to 1000
    spectrograms. The resulting ``CLIPDemo`` is stored in
    ``st.session_state["model"]``. The ViT emotion classifier is loaded once
    and cached in ``st.session_state``. The predicted emotion label is used
    as the search query. Song titles and artists are looked up in the
    module-level ``df`` metadata table.

    Args:
        key: Unused by this function.
        plot: Unused by this function.
        device: Device the CLIP encoders are moved to.
    """
    image = Image.open("data/logo.png")
    st.image(image, use_column_width="always")

    if 'model' not in st.session_state:
        text_encoder = AutoModel.from_pretrained(CLIP_TEXT_MODEL_PATH, local_files_only=True)
        vision_encoder = CLIPVisionModel.from_pretrained(CLIP_VISION_MODEL_PATH, local_files_only=True).to(device)
        tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)
        model = CLIPDemo(vision_encoder=vision_encoder, text_encoder=text_encoder, tokenizer=tokenizer, device=device)
        # glob order is arbitrary; sort so the same 1000 spectrograms are indexed every run.
        spectrograms = sorted(glob.glob(os.path.join(SPECTROGRAMS_PATH, "*.jpeg")))[:1000]
        model.compute_image_embeddings(spectrograms)
        st.session_state["model"] = model

    # NOTE(review): bare literals are presumably rendered by Streamlit "magic"
    # as vertical spacers — keep them as standalone expressions.
    ""
    ""

    st.write("Please, show us how you are feeling today!")
    captured_image = webcam()
    if captured_image is None:
        st.write("Waiting for capture...")
        return

    captured_image = captured_image.convert("RGB")

    # Load the emotion classifier once per session instead of re-reading it
    # from disk on every capture. Both are stored only after both loads succeed.
    if "vit_model" not in st.session_state:
        feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
        classifier = ViTForImageClassification.from_pretrained("ViT_ER/best_checkpoint", local_files_only=True)
        st.session_state["vit_feature_extractor"] = feature_extractor
        st.session_state["vit_model"] = classifier
    vit_feature_extractor = st.session_state["vit_feature_extractor"]
    vit_model = st.session_state["vit_model"]

    inputs = vit_feature_extractor(images=[captured_image], return_tensors="pt")
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = vit_model(**inputs)

    # NOTE(review): assumed to list labels in the checkpoint's output-index order — confirm.
    emotions = ['Anger', 'Disgust', 'Fear', 'Happiness', 'Sadness', 'Surprise', 'Neutral']
    # The batch holds a single image, so take the arg-max of its one row of logits.
    mood = emotions[int(outputs.logits.argmax(dim=-1)[0])]

    st.write(f"Your mood seems to be **{mood.lower()}** today! Here's a song for you that matches with how you feel!")

    song_paths = st.session_state.model.image_search(mood)
    for song in song_paths:
        # image_search returns paths ending in a six-digit track id.
        track_id = int(song[-6:])
        matches = df.loc[df['track_id'] == track_id]
        if matches.empty:
            # No metadata row for this track: skip it instead of raising IndexError.
            continue
        song_name = matches['track_title'].iloc[0]
        artist_name = matches['artist_name'].iloc[0]
        st.write(f'**"{song_name}"** by {artist_name}')
        st.audio(song + ".ogg", format="audio/ogg", start_time=0)
|
|
|
|
|
|
|
# Page entry point: draw the navigation menu, load metadata, pick a device,
# then render whichever page was selected.
selected = streamlit_menu(example=3)
# Track metadata (track_id, track_title, artist_name, ...) read by the page functions.
df = pd.read_csv('full_metadata.csv', index_col=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# "Camera" deliberately has no entry: that page currently renders nothing.
# NOTE(review): draw_camera exists but is not wired up here — confirm intent.
page_renderers = {
    "Text": lambda: draw_text("text", plot=True, device=device),
    "Audio": lambda: draw_audio("audio", plot=True, device=device),
}
render_page = page_renderers.get(selected)
if render_page is not None:
    render_page()
|
|
|
|
|
|
|
|