Silvia Terragni commited on
Commit
ce649db
·
1 Parent(s): f1abd41

updated template

Browse files
Files changed (4) hide show
  1. app.py +8 -111
  2. home.py +11 -0
  3. image2text.py +0 -0
  4. text2image.py +106 -0
app.py CHANGED
@@ -1,112 +1,9 @@
1
- import io
2
- import os
3
- import requests
4
- import zipfile
5
- import natsort
6
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
7
- from pathlib import Path
8
- from stqdm import stqdm
9
  import streamlit as st
10
- from jax import numpy as jnp
11
- import transformers
12
- from transformers import AutoTokenizer
13
- from torchvision.transforms import Compose, CenterCrop, Normalize, Resize, ToTensor
14
- from torchvision.transforms.functional import InterpolationMode
15
- from modeling_hybrid_clip import FlaxHybridCLIP
16
-
17
- import utils
18
-
19
-
20
- @st.cache(hash_funcs={FlaxHybridCLIP: lambda _: None})
21
- def get_model():
22
- return FlaxHybridCLIP.from_pretrained("clip-italian/clip-italian")
23
-
24
-
25
- @st.cache(hash_funcs={transformers.models.bert.tokenization_bert_fast.BertTokenizerFast: lambda _: None})
26
- def get_tokenizer():
27
- return AutoTokenizer.from_pretrained("dbmdz/bert-base-italian-xxl-uncased", cache_dir="./", use_fast=True)
28
-
29
-
30
- @st.cache(suppress_st_warning=True)
31
- def download_images():
32
- # from sentence_transformers import SentenceTransformer, util
33
- img_folder = "photos/"
34
- if not os.path.exists(img_folder) or len(os.listdir(img_folder)) == 0:
35
- os.makedirs(img_folder, exist_ok=True)
36
-
37
- photo_filename = "unsplash-25k-photos.zip"
38
- if not os.path.exists(photo_filename): # Download dataset if does not exist
39
- print(f"Downloading {photo_filename}...")
40
- response = requests.get(f"http://sbert.net/datasets/{photo_filename}", stream=True)
41
- total_size_in_bytes= int(response.headers.get('content-length', 0))
42
- block_size = 1024 #1 Kb
43
- progress_bar = stqdm(total=total_size_in_bytes) # , unit='iB', unit_scale=True
44
- content = io.BytesIO()
45
- for data in response.iter_content(block_size):
46
- progress_bar.update(len(data))
47
- content.write(data)
48
- progress_bar.close()
49
- z = zipfile.ZipFile(content)
50
- # content.close()
51
- print("Extracting the dataset...")
52
- z.extractall(path=img_folder)
53
- print("Done.")
54
-
55
-
56
- @st.cache()
57
- def get_image_features():
58
- return jnp.load("static/features/features.npy")
59
-
60
-
61
- def read_markdown_file(markdown_file):
62
- return Path(markdown_file).read_text()
63
-
64
-
65
- """
66
-
67
- # 👋 Ciao!
68
-
69
- # CLIP Italian Demo
70
- ## HF-Flax Community Week
71
-
72
- In this demo you can search for images in the Unsplash 25k Photos dataset.
73
-
74
- 🤌 Italian mode on! 🤌
75
-
76
- """
77
-
78
- query = st.text_input("Insert an italian query text here...")
79
- if query:
80
- with st.spinner("Computing in progress..."):
81
- model = get_model()
82
- download_images()
83
-
84
- image_features = get_image_features()
85
-
86
- model = get_model()
87
- tokenizer = get_tokenizer()
88
-
89
- image_size = model.config.vision_config.image_size
90
-
91
- val_preprocess = Compose(
92
- [
93
- Resize([image_size], interpolation=InterpolationMode.BICUBIC),
94
- CenterCrop(image_size),
95
- ToTensor(),
96
- Normalize(
97
- (0.48145466, 0.4578275, 0.40821073),
98
- (0.26862954, 0.26130258, 0.27577711),
99
- ),
100
- ]
101
- )
102
-
103
- dataset = utils.CustomDataSet("photos/", transform=val_preprocess)
104
-
105
- image_paths = utils.find_image(
106
- query, model, dataset, tokenizer, image_features, n=2
107
- )
108
-
109
- st.image(image_paths)
110
-
111
- intro_markdown = read_markdown_file("introduction.md")
112
- st.markdown(intro_markdown, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import image2text
3
+ import text2image
4
+ import home
5
+
6
+ PAGES = {"Home": home, "Text to Image": text2image, "Image to Text": image2text}
7
+ st.sidebar.title("Navigation")
8
+ page = st.sidebar.selectbox("Choose a task", list(PAGES.keys()))
9
+ PAGES[page].app()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
home.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import streamlit as st
3
+
4
+
5
+ def read_markdown_file(markdown_file):
6
+ return Path(markdown_file).read_text()
7
+
8
+
9
+ def app():
10
+ intro_markdown = read_markdown_file("introduction.md")
11
+ st.markdown(intro_markdown, unsafe_allow_html=True)
image2text.py ADDED
File without changes
text2image.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import requests
4
+ import zipfile
5
+ import natsort
6
+
7
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
8
+ from stqdm import stqdm
9
+ import streamlit as st
10
+ from jax import numpy as jnp
11
+ import transformers
12
+ from transformers import AutoTokenizer
13
+ from torchvision.transforms import Compose, CenterCrop, Normalize, Resize, ToTensor
14
+ from torchvision.transforms.functional import InterpolationMode
15
+ from modeling_hybrid_clip import FlaxHybridCLIP
16
+
17
+ import utils
18
+
19
+
20
+ @st.cache(hash_funcs={FlaxHybridCLIP: lambda _: None})
21
+ def get_model():
22
+ return FlaxHybridCLIP.from_pretrained("clip-italian/clip-italian")
23
+
24
+
25
+ @st.cache(hash_funcs={transformers.models.bert.tokenization_bert_fast.BertTokenizerFast: lambda _: None})
26
+ def get_tokenizer():
27
+ return AutoTokenizer.from_pretrained("dbmdz/bert-base-italian-xxl-uncased", cache_dir="./", use_fast=True)
28
+
29
+
30
+ @st.cache(suppress_st_warning=True)
31
+ def download_images():
32
+ # from sentence_transformers import SentenceTransformer, util
33
+ img_folder = "photos/"
34
+ if not os.path.exists(img_folder) or len(os.listdir(img_folder)) == 0:
35
+ os.makedirs(img_folder, exist_ok=True)
36
+
37
+ photo_filename = "unsplash-25k-photos.zip"
38
+ if not os.path.exists(photo_filename): # Download dataset if does not exist
39
+ print(f"Downloading {photo_filename}...")
40
+ response = requests.get(f"http://sbert.net/datasets/{photo_filename}", stream=True)
41
+ total_size_in_bytes = int(response.headers.get('content-length', 0))
42
+ block_size = 1024 # 1 Kb
43
+ progress_bar = stqdm(total=total_size_in_bytes) # , unit='iB', unit_scale=True
44
+ content = io.BytesIO()
45
+ for data in response.iter_content(block_size):
46
+ progress_bar.update(len(data))
47
+ content.write(data)
48
+ progress_bar.close()
49
+ z = zipfile.ZipFile(content)
50
+ # content.close()
51
+ print("Extracting the dataset...")
52
+ z.extractall(path=img_folder)
53
+ print("Done.")
54
+
55
+
56
+ @st.cache()
57
+ def get_image_features():
58
+ return jnp.load("static/features/features.npy")
59
+
60
+ def app():
61
+
62
+ """
63
+
64
+ # 👋 Ciao!
65
+
66
+ # CLIP Italian Demo
67
+ ## HF-Flax Community Week
68
+
69
+ In this demo you can search for images in the Unsplash 25k Photos dataset.
70
+
71
+ 🤌 Italian mode on! 🤌
72
+
73
+ """
74
+
75
+ query = st.text_input("Insert an italian query text here...")
76
+ if query:
77
+ with st.spinner("Computing in progress..."):
78
+ model = get_model()
79
+ download_images()
80
+
81
+ image_features = get_image_features()
82
+
83
+ model = get_model()
84
+ tokenizer = get_tokenizer()
85
+
86
+ image_size = model.config.vision_config.image_size
87
+
88
+ val_preprocess = Compose(
89
+ [
90
+ Resize([image_size], interpolation=InterpolationMode.BICUBIC),
91
+ CenterCrop(image_size),
92
+ ToTensor(),
93
+ Normalize(
94
+ (0.48145466, 0.4578275, 0.40821073),
95
+ (0.26862954, 0.26130258, 0.27577711),
96
+ ),
97
+ ]
98
+ )
99
+
100
+ dataset = utils.CustomDataSet("photos/", transform=val_preprocess)
101
+
102
+ image_paths = utils.find_imageread_markdown_file(
103
+ query, model, dataset, tokenizer, image_features, n=2
104
+ )
105
+
106
+ st.image(image_paths)