max committed on
Commit
4c3ee55
1 Parent(s): 24a3787

initial commit

Browse files
Files changed (2) hide show
  1. app.py +115 -0
  2. clip_texts_1_fp16.pkl +3 -0
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ import gradio.components as gc
3
+ import gradio as gr
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ from PIL import Image
9
+ from transformers import CLIPModel, CLIPProcessor
10
# All tensors and the model are placed on this device.
device = 'cpu'
# The contexts are entered by hand and never exited, so they stay active for
# the lifetime of the thread that imports this module.
# NOTE(review): both settings are thread-local in PyTorch, and the autocast
# context targets 'cuda' while `device` is 'cpu' -- confirm this is intended.
torch.no_grad().__enter__()
torch.autocast('cuda').__enter__()

# %%

# Table of precomputed text embeddings: the index provides the `word` column
# once reset, and the values are the embedding matrix.
# NOTE(review): unpickling can execute arbitrary code; only load this file
# from a trusted source.
t = pd.read_pickle("clip_texts_1_fp16.pkl")
words = t.reset_index()["word"]
wordsv = torch.tensor(t.values).to(device)

# %%

# %%
model_name = "openai/clip-vit-large-patch14"
# Load the CLIP weights, switch to inference mode and move to `device`.
mmm = CLIPModel.from_pretrained(model_name).eval().to(device)

processor = CLIPProcessor.from_pretrained(model_name)
29
+
30
+ # %%
31
+
32
+
33
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    """Spherically interpolate between two vectors ``v0`` and ``v1``.

    Args:
        t: Interpolation fraction; ``0`` yields ``v0`` and ``1`` yields ``v1``.
        v0: Start vector, either a NumPy array or a torch tensor.
        v1: End vector, of the same kind and shape as ``v0``.
        DOT_THRESHOLD: If the absolute cosine of the angle between the inputs
            exceeds this value they are treated as (anti)parallel and plain
            linear interpolation is used instead.

    Returns:
        The interpolated vector. A NumPy array when ``v0`` is a NumPy array;
        otherwise a torch tensor moved to the device ``v0`` came from.
    """
    inputs_are_torch = not isinstance(v0, np.ndarray)
    if inputs_are_torch:
        input_device = v0.device
        # detach() so tensors that still track gradients can be converted.
        v0 = v0.detach().cpu().numpy()
        v1 = v1.detach().cpu().numpy()

    norm_product = np.linalg.norm(v0) * np.linalg.norm(v1)
    if norm_product == 0:
        # A zero-length vector has no direction, so the angle is undefined
        # (the division below would produce NaN); fall back to a straight
        # linear blend.
        v2 = (1 - t) * v0 + t * v1
    else:
        dot = np.sum(v0 * v1 / norm_product)
        if np.abs(dot) > DOT_THRESHOLD:
            # Nearly colinear: sin(theta) is close to 0, so the spherical
            # formula is numerically unstable; interpolate linearly.
            v2 = (1 - t) * v0 + t * v1
        else:
            theta_0 = np.arccos(dot)
            sin_theta_0 = np.sin(theta_0)
            theta_t = theta_0 * t
            s0 = np.sin(theta_0 - theta_t) / sin_theta_0
            s1 = np.sin(theta_t) / sin_theta_0
            v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)

    return v2
58
+
59
+
60
def query(text: str, img: Image.Image, limit: int, score_threshold: float, slerp_degree: float):
    """Return the best-scoring vocabulary words for a text and/or an image.

    Args:
        text: Search text; an empty (or missing) value means "no text".
        img: Search image, or ``None`` when no image was provided.
        limit: Maximum number of words to return.
        score_threshold: Only words scoring strictly above this are returned.
        slerp_degree: Interpolation fraction passed to ``slerp`` when both a
            text and an image are given (0 = text only, 1 = image only).

    Returns:
        One ``"word: score "`` entry per line, best score first. An empty
        string when neither text nor image was supplied.
    """
    text_vec = None
    image_vec = None

    # Grad mode is thread-local in PyTorch, so a no_grad entered at module
    # level does not apply if this handler is invoked from another thread;
    # disable autograd explicitly for the forward passes.
    with torch.no_grad():
        if text:
            inp = processor(text=text, return_tensors='pt').to(device)
            text_vec = mmm.get_text_features(**inp).detach().cpu().numpy()[0]

        if img is not None:
            inp = processor(images=[img], return_tensors="pt").to(device)
            image_vec = mmm.get_image_features(**inp).detach().cpu().numpy()[0]

    if text_vec is not None and image_vec is not None:
        out = slerp(slerp_degree, text_vec, image_vec)
    elif image_vec is not None:
        out = image_vec
    elif text_vec is not None:
        out = text_vec
    else:
        # Nothing to search for; previously this path raised
        # UnboundLocalError.
        return ''

    # Dot product of the query embedding with every vocabulary embedding.
    # NOTE(review): neither side is normalised here, so this is a cosine
    # similarity only if the stored embeddings and the query are unit length.
    scores = np.dot(out, wordsv.T)

    ranked = pd.concat(
        [words, pd.Series(scores, name='score')],
        axis=1,
    ).sort_values('score', ascending=False)
    # head() needs an integer; coerce in case the UI delivers a float.
    topk = ranked[ranked['score'] > score_threshold].head(int(limit))

    return "\n".join(
        f'{word}: {score:.2f} '
        for word, score in topk.itertuples(index=False)
    )
96
+
97
+
98
# Input widgets, listed in the order in which `query` takes its arguments.
searchtext = gc.Textbox(lines=2, placeholder="Search text")
searchimage = gc.Image(shape=(224, 224), label="Search image", type='pil')
inp_limit = gc.Slider(1, 50, 10, step=1, label='Limit')
score_threshold = gc.Slider(0, 30, 0, step=.5, label='Score threshold')
slerp_degree = gc.Slider(
    0, 1, 0.5,
    step=.01,
    label='Slerp degree (if both text and image are provided)\nFinds a midpoint between image and text embeddings',
)
query_inputs = [searchtext, searchimage, inp_limit, score_threshold, slerp_degree]

# Link to the word list, shown in the interface description.
dsurl = 'https://www.kaggle.com/datasets/yk1598/479k-english-words'

# Build the interface first, then start serving it.
demo = gr.Interface(
    fn=query,
    inputs=query_inputs,
    outputs=[gc.Textbox(label='Top words')],
    title="Initial Token Finder for Textual Inversion",
    description=f"find the closest single token word for a given text and/or image.\nbased on {model_name}.\n\nData: {dsurl}",
    analytics_enabled=False,
    allow_flagging='never',
)
demo.launch()
clip_texts_1_fp16.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bbcc5d3d464979b764c0b8a69a58f28f5bf941bf10b3501b513d3b28fcb17876
3
+ size 39828901