konstantinG commited on
Commit
45e49d7
1 Parent(s): 77fd092

Upload 2 files

Browse files
Files changed (2) hide show
  1. main.py +50 -0
  2. requirements.txt +94 -0
main.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import random
3
+ import os
4
+ import clip
5
+ from funcs.get_similarity import get_similarity_score, create_filelist, load_embeddings, find_matches
6
+ from funcs.fiass_similaruty import load_embeddings, encode_text, find_matches_fiass
7
+ import torch.nn.functional as F
8
+ import pandas as pd
9
+
10
+ device = 'cpu'
11
+ model_path = "weights/ViT-B-32.pt"
12
+
13
+ model, preprocess = clip.load('ViT-B/32', device)
14
+
15
+ file_name = create_filelist('img')
16
+ features = load_embeddings('embeddings/emb_images_5000.npy')
17
+ df = pd.read_csv('data/results.csv')
18
+
19
+ random_queries = ['friends playing cards', 'rock band playing on guitars', 'policeman cross the road',
20
+ 'sleeping kids', 'football team playing on the grass' , 'learning programming'
21
+ ]
22
+
23
+ st.header('Find my pic!')
24
+
25
+ request = st.text_input('Write a description of the picture', ' Two people at the photo')
26
+
27
+ img_count = st.slider('How much pic you need?', 4, 8, 6, 2)
28
+
29
+
30
+ matches = find_matches_fiass(features, request, file_name, img_count)
31
+ row1, row2 = st.columns(2)
32
+
33
+ if st.button('Find!'):
34
+
35
+ selected_filenames = matches
36
+
37
+ for i in range(int(img_count/2)):
38
+ filename = selected_filenames[i]
39
+ img_path = filename
40
+ img_discription = df[df['image_name'] == filename.split('/')[1]]['comment'].iloc[0]
41
+ with row1:
42
+ st.image(img_path, width=300, caption=img_discription)
43
+
44
+ # display next 3 images in the second row
45
+ for i in range(int(img_count/2), img_count):
46
+ filename = selected_filenames[i]
47
+ img_path = filename
48
+ img_discription = df[df['image_name'] == filename.split('/')[1]]['comment'].iloc[0]
49
+ with row2:
50
+ st.image(img_path, width=300, caption=img_discription)
requirements.txt ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ clip @ git+https://github.com/openai/CLIP.git@a9b1bf5920416aaeaec965c25dd9e8f98c864f16
2
+ clip-by-openai==1.1
3
+ altair==4.2.2
4
+ asttokens==2.2.1
5
+ attrs==22.2.0
6
+ backcall==0.2.0
7
+ beautifulsoup4==4.12.0
8
+ blinker==1.5
9
+ cachetools==5.3.0
10
+ certifi==2022.12.7
11
+ charset-normalizer==3.1.0
12
+ click==8.1.3
13
+ clip==1.0
14
+ colorama==0.4.6
15
+ comm==0.1.3
16
+ contourpy==1.0.7
17
+ cycler==0.11.0
18
+ debugpy==1.6.6
19
+ decorator==5.1.1
20
+ entrypoints==0.4
21
+ executing==1.2.0
22
+ faiss-cpu==1.7.3
23
+ filelock==3.10.0
24
+ fonttools==4.39.2
25
+ ftfy==6.1.1
26
+ gdown==4.6.4
27
+ gitdb==4.0.10
28
+ GitPython==3.1.31
29
+ idna==3.4
30
+ importlib-metadata==6.1.0
31
+ ipykernel==6.22.0
32
+ ipython==8.11.0
33
+ jedi==0.18.2
34
+ Jinja2==3.1.2
35
+ jsonschema==4.17.3
36
+ jupyter_client==8.1.0
37
+ jupyter_core==5.3.0
38
+ kiwisolver==1.4.4
39
+ markdown-it-py==2.2.0
40
+ MarkupSafe==2.1.2
41
+ matplotlib==3.7.1
42
+ matplotlib-inline==0.1.6
43
+ mdurl==0.1.2
44
+ mpmath==1.3.0
45
+ nest-asyncio==1.5.6
46
+ networkx==3.0
47
+ numpy==1.24.2
48
+ opencv-python==4.7.0.72
49
+ packaging==23.0
50
+ pandas==1.5.3
51
+ parso==0.8.3
52
+ pickleshare==0.7.5
53
+ Pillow==9.4.0
54
+ platformdirs==3.1.1
55
+ prompt-toolkit==3.0.38
56
+ protobuf==3.20.3
57
+ psutil==5.9.4
58
+ pure-eval==0.2.2
59
+ pyarrow==11.0.0
60
+ pydeck==0.8.0
61
+ Pygments==2.14.0
62
+ Pympler==1.0.1
63
+ pyparsing==3.0.9
64
+ pyrsistent==0.19.3
65
+ PySocks==1.7.1
66
+ python-dateutil==2.8.2
67
+ pytz==2022.7.1
68
+ pytz-deprecation-shim==0.1.0.post0
69
+ pyzmq==25.0.2
70
+ regex==2023.3.23
71
+ requests==2.28.2
72
+ rich==13.3.2
73
+ semver==2.13.0
74
+ six==1.16.0
75
+ smmap==5.0.0
76
+ soupsieve==2.4
77
+ stack-data==0.6.2
78
+ streamlit==1.20.0
79
+ sympy==1.11.1
80
+ toml==0.10.2
81
+ toolz==0.12.0
82
+ torch==1.7.1
83
+ torchvision==0.8.2
84
+ tornado==6.2
85
+ tqdm==4.65.0
86
+ traitlets==5.9.0
87
+ typing_extensions==4.5.0
88
+ tzdata==2022.7
89
+ tzlocal==4.3
90
+ urllib3==1.26.15
91
+ validators==0.20.0
92
+ watchdog==3.0.0
93
+ wcwidth==0.2.6
94
+ zipp==3.15.0