Singularity666 commited on
Commit
751d17d
0 Parent(s):

Duplicate from Singularity666/RadiXGPT_

Browse files
Files changed (9) hide show
  1. .gitattributes +34 -0
  2. README.md +14 -0
  3. app.py +140 -0
  4. gitignore.txt +7 -0
  5. main.py +335 -0
  6. requirements.txt +14 -0
  7. saved_text_embeddings.pt +3 -0
  8. testing_df.csv +0 -0
  9. weights.pt +3 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: RadiXGPT
3
+ emoji: 🏢
4
+ colorFrom: purple
5
+ colorTo: blue
6
+ sdk: streamlit
7
+ sdk_version: 1.19.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: bigscience-openrail-m
11
+ duplicated_from: Singularity666/RadiXGPT_
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pickle
3
+ import pandas as pd
4
+ import torch
5
+ from PIL import Image
6
+ import numpy as np
7
+ from main import predict_caption, CLIPModel, get_text_embeddings
8
+ import openai
9
+ import base64
10
+ from docx import Document
11
+ from docx.enum.text import WD_PARAGRAPH_ALIGNMENT
12
+ from io import BytesIO
13
+ import re
14
+
15
+ openai.api_key = "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"
16
+
17
+ st.markdown(
18
+ """
19
+ <style>
20
+ body {
21
+ background-color: transparent;
22
+ }
23
+ .container {
24
+ display: flex;
25
+ justify-content: center;
26
+ align-items: center;
27
+ background-color: rgba(255, 255, 255, 0.7);
28
+ border-radius: 15px;
29
+ padding: 20px;
30
+ }
31
+ .stApp {
32
+ background-color: transparent;
33
+ }
34
+ .stText, .stMarkdown, .stTextInput>label, .stButton>button>span {
35
+ color: #1c1c1c !important; /* Set the dark text color for text elements */
36
+ }
37
+ .stButton>button>span {
38
+ color: initial !important; /* Reset the text color for the 'Generate Caption' button */
39
+ }
40
+ .stMarkdown h1, .stMarkdown h2 {
41
+ color: #ff6b81 !important; /* Set the text color of h1 and h2 elements to soft red-pink */
42
+ font-weight: bold; /* Set the font weight to bold */
43
+ border: 2px solid #ff6b81; /* Add a bold border around the headers */
44
+ padding: 10px; /* Add padding to the headers */
45
+ border-radius: 5px; /* Add border-radius to the headers */
46
+ }
47
+ </style>
48
+ """,
49
+ unsafe_allow_html=True,
50
+ )
51
+
52
+ device = torch.device("cpu")
53
+
54
+ testing_df = pd.read_csv("testing_df.csv")
55
+ model = CLIPModel().to(device)
56
+ model.load_state_dict(torch.load("weights.pt", map_location=torch.device('cpu')))
57
+ text_embeddings = torch.load('saved_text_embeddings.pt', map_location=device)
58
+
59
+ def download_link(content, filename, link_text):
60
+ b64 = base64.b64encode(content).decode()
61
+ href = f'<a href="data:application/octet-stream;base64,{b64}" download="{filename}">{link_text}</a>'
62
+ return href
63
+
64
+ def show_predicted_caption(image, top_k=8):
65
+ matches = predict_caption(
66
+ image, model, text_embeddings, testing_df["caption"]
67
+ )[:top_k]
68
+ cleaned_matches = [re.sub(r'\s\(ROCO_\d+\)', '', match) for match in matches] # Add this line to clean the matches
69
+ return cleaned_matches # Return the cleaned_matches instead of matches
70
+
71
+ def generate_radiology_report(prompt):
72
+ response = openai.Completion.create(
73
+ engine="text-davinci-003",
74
+ prompt=prompt,
75
+ max_tokens=800,
76
+ n=1,
77
+ stop=None,
78
+ temperature=1,
79
+ )
80
+ report = response.choices[0].text.strip()
81
+ # Remove reference string from the report
82
+ report = re.sub(r'\(ROCO_\d+\)', '', report).strip()
83
+ return report
84
+
85
+
86
+ def save_as_docx(text, filename):
87
+ document = Document()
88
+ document.add_paragraph(text)
89
+ with BytesIO() as output:
90
+ document.save(output)
91
+ output.seek(0)
92
+ return output.getvalue()
93
+
94
+ st.title("RadiXGPT: An Evolution of machine doctors towards Radiology")
95
+
96
+ # Collect user's personal information
97
+ st.subheader("Personal Information")
98
+ first_name = st.text_input("First Name")
99
+ last_name = st.text_input("Last Name")
100
+ age = st.number_input("Age", min_value=0, max_value=120, value=25, step=1)
101
+ gender = st.selectbox("Gender", ["Male", "Female", "Other"])
102
+
103
+ st.write("Upload Scan to get Radiological Report:")
104
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
105
+ if uploaded_file is not None:
106
+ image = Image.open(uploaded_file)
107
+ st.image(image, caption="Uploaded Image", use_column_width=True)
108
+ st.write("")
109
+
110
+ if st.button("Generate Caption"):
111
+ with st.spinner("Generating caption..."):
112
+ image_np = np.array(image)
113
+ caption = show_predicted_caption(image_np)[0]
114
+
115
+ st.success(f"Caption: {caption}")
116
+
117
+ # Generate the radiology report
118
+ radiology_report = generate_radiology_report(f"Write Complete Radiology Report for this with clinical info, subjective, Assessment, Finding, Impressions, Conclusion and more in proper order : {caption}")
119
+
120
+ # Add personal information to the radiology report
121
+ radiology_report_with_personal_info = f"Patient Name: {first_name} {last_name}\nAge: {age}\nGender: {gender}\n\n{radiology_report}"
122
+
123
+ st.header("Radiology Report")
124
+ st.write(radiology_report_with_personal_info)
125
+ st.markdown(download_link(save_as_docx(radiology_report_with_personal_info, "radiology_report.docx"), "radiology_report.docx", "Download Report as DOCX"), unsafe_allow_html=True)
126
+
127
+ feedback_options = ["Satisfied", "Not Satisfied"]
128
+ selected_feedback = st.radio("Please provide feedback on the generated report:", feedback_options)
129
+
130
+ if selected_feedback == "Not Satisfied":
131
+ if st.button("Regenerate Report"):
132
+ with st.spinner("Regenerating report..."):
133
+ alternative_caption = get_alternative_caption(image_np, model, text_embeddings, testing_df["caption"])
134
+ regenerated_radiology_report = generate_radiology_report(f"Write Complete Radiology Report for this with clinical info, subjective, Assessment, Finding, Impressions, Conclusion and more in proper order : {alternative_caption}")
135
+
136
+ regenerated_radiology_report_with_personal_info = f"Patient Name: {first_name} {last_name}\nAge: {age}\nGender: {gender}\n\n{regenerated_radiology_report}"
137
+
138
+ st.header("Regenerated Radiology Report")
139
+ st.write(regenerated_radiology_report_with_personal_info)
140
+ st.markdown(download_link(save_as_docx(regenerated_radiology_report_with_personal_info, "regenerated_radiology_report.docx"), "regenerated_radiology_report.docx", "Download Regenerated Report as DOCX"), unsafe_allow_html=True)
gitignore.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ *.pyd
5
+ *~
6
+ .env
7
+
main.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from tqdm.autonotebook import tqdm
3
+ from transformers import AutoTokenizer, AutoModel
4
+ from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
5
+ import albumentations as A
6
+ import cv2
7
+ import timm
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+ device = torch.device("cpu")
12
+
13
+ class CFG:
14
+ debug = False
15
+ image_path = '/content/content/new_images_v5'
16
+ captions_path = '/content/content/all_data/new_caption.csv'
17
+ batch_size = 12
18
+ num_workers = 2
19
+ head_lr = 1e-3
20
+ image_encoder_lr = 1e-4
21
+ text_encoder_lr = 1e-5
22
+ weight_decay = 1e-3
23
+ patience = 1
24
+ factor = 0.8
25
+ epochs = 2
26
+ saved_model_clinical = '/content/content/new_weights.pt'
27
+ trained_model = 'clinical_bert_weights.pt'
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+
30
+ model_name = 'resnet50'
31
+ image_embedding = 2048
32
+ text_encoder_model = "distilbert-base-uncased"
33
+ clinical_encoder_model = "emilyalsentzer/Bio_ClinicalBERT"
34
+ text_embedding = 768
35
+ text_tokenizer = "distilbert-base-uncased"
36
+ max_length = 200
37
+
38
+ pretrained = True # for both image encoder and text encoder
39
+ trainable = True # for both image encoder and text encoder
40
+ temperature = 1.0
41
+
42
+ # image size
43
+ size = 224
44
+
45
+ # for projection head; used for both image and text encoders
46
+ num_projection_layers = 1
47
+ projection_dim = 256
48
+ dropout = 0.1
49
+
50
+
51
+ def build_loaders(dataframe, tokenizer, mode):
52
+ transforms = get_transforms(mode=mode)
53
+ dataset = CLIPDataset(
54
+ dataframe["image"].values,
55
+ dataframe["caption"].values,
56
+ tokenizer=tokenizer,
57
+ transforms=transforms,
58
+ )
59
+
60
+ dataloader = torch.utils.data.DataLoader(
61
+ dataset,
62
+ batch_size=CFG.batch_size,
63
+ num_workers=CFG.num_workers,
64
+ shuffle=True if mode == "train" else False,
65
+ )
66
+ return dataloader
67
+
68
+
69
+
70
+ class AvgMeter:
71
+ def __init__(self, name="Metric"):
72
+ self.name = name
73
+ self.reset()
74
+
75
+ def reset(self):
76
+ self.avg, self.sum, self.count = [0] * 3
77
+
78
+ def update(self, val, count=1):
79
+ self.count += count
80
+ self.sum += val * count
81
+ self.avg = self.sum / self.count
82
+
83
+ def __repr__(self):
84
+ text = f"{self.name}: {self.avg:.4f}"
85
+ return text
86
+
87
+ def get_lr(optimizer):
88
+ for param_group in optimizer.param_groups:
89
+ return param_group["lr"]
90
+
91
+
92
+ # Custom dataset object. Will tokenize text and apply transforms to images before yielding them.
93
+
94
+ class CLIPDataset(torch.utils.data.Dataset):
95
+ def __init__(self, image_filenames, captions, tokenizer, transforms):
96
+ """
97
+ image_filenames and cpations must have the same length; so, if there are
98
+ multiple captions for each image, the image_filenames must have repetitive
99
+ file names
100
+ """
101
+
102
+ self.image_filenames = image_filenames
103
+ self.captions = list(captions)
104
+ self.skippedImgCount = 0
105
+ self.encoded_captions = tokenizer(
106
+ list(captions), padding=True, truncation=True, max_length=CFG.max_length
107
+ )
108
+ self.transforms = transforms
109
+
110
+ def __getitem__(self, idx):
111
+ item = {
112
+ key: torch.tensor(values[idx])
113
+ for key, values in self.encoded_captions.items()
114
+ }
115
+
116
+ image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
117
+ if image is None:
118
+ # Skip the current example and move to the next one
119
+ self.skippedImgCount += 1
120
+ return self.__getitem__((idx + 1) % len(self))
121
+
122
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
123
+ image = self.transforms(image=image)['image']
124
+ item['image'] = torch.tensor(image).permute(2, 0, 1).float()
125
+ item['caption'] = self.captions[idx]
126
+
127
+ return item
128
+
129
+ def __len__(self):
130
+ return len(self.captions)
131
+
132
+
133
+ def get_transforms(mode="train"):
134
+ if mode == "train":
135
+ return A.Compose(
136
+ [
137
+ A.Resize(CFG.size, CFG.size, always_apply=True),
138
+ A.Normalize(max_pixel_value=255.0, always_apply=True),
139
+ ]
140
+ )
141
+ else:
142
+ return A.Compose(
143
+ [
144
+ A.Resize(CFG.size, CFG.size, always_apply=True),
145
+ A.Normalize(max_pixel_value=255.0, always_apply=True),
146
+ ]
147
+ )
148
+
149
+
150
+ class ImageEncoder(nn.Module):
151
+ """
152
+ Encode images to a fixed size vector
153
+ """
154
+
155
+ def __init__(
156
+ self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
157
+ ):
158
+ super().__init__()
159
+ self.model = timm.create_model(
160
+ model_name, pretrained, num_classes=0, global_pool="avg"
161
+ )
162
+ for p in self.model.parameters():
163
+ p.requires_grad = trainable
164
+
165
+ def forward(self, x):
166
+ return self.model(x)
167
+
168
+ class TextEncoder(nn.Module):
169
+ def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
170
+ super().__init__()
171
+ if pretrained:
172
+ # self.model = DistilBertModel.from_pretrained(model_name)
173
+
174
+ # Use Bio-ClinicalBERT
175
+ self.model = AutoModel.from_pretrained(CFG.clinical_encoder_model)
176
+
177
+ else:
178
+ self.model = DistilBertModel(config=DistilBertConfig())
179
+
180
+ for p in self.model.parameters():
181
+ p.requires_grad = trainable
182
+
183
+ # we are using the CLS token hidden representation as the sentence's embedding
184
+ self.target_token_idx = 0
185
+
186
+ def forward(self, input_ids, attention_mask):
187
+ output = self.model(input_ids=input_ids, attention_mask=attention_mask)
188
+ last_hidden_state = output.last_hidden_state
189
+ return last_hidden_state[:, self.target_token_idx, :]
190
+
191
+
192
+ # Get both image and text encodings into a same size matrix
193
+ class ProjectionHead(nn.Module):
194
+ def __init__(
195
+ self,
196
+ embedding_dim,
197
+ projection_dim=CFG.projection_dim,
198
+ dropout=CFG.dropout
199
+ ):
200
+ super().__init__()
201
+ self.projection = nn.Linear(embedding_dim, projection_dim)
202
+ self.gelu = nn.GELU()
203
+ self.fc = nn.Linear(projection_dim, projection_dim)
204
+ self.dropout = nn.Dropout(dropout)
205
+ self.layer_norm = nn.LayerNorm(projection_dim)
206
+
207
+ def forward(self, x):
208
+ projected = self.projection(x)
209
+ x = self.gelu(projected)
210
+ x = self.fc(x)
211
+ x = self.dropout(x)
212
+ x = x + projected
213
+ x = self.layer_norm(x)
214
+ return x
215
+
216
+
217
+ class CLIPModel(nn.Module):
218
+ def __init__(
219
+ self,
220
+ temperature=CFG.temperature,
221
+ image_embedding=CFG.image_embedding,
222
+ text_embedding=CFG.text_embedding,
223
+ ):
224
+ super().__init__()
225
+ self.image_encoder = ImageEncoder()
226
+ self.text_encoder = TextEncoder()
227
+ self.image_projection = ProjectionHead(embedding_dim=image_embedding)
228
+ self.text_projection = ProjectionHead(embedding_dim=text_embedding)
229
+ self.temperature = temperature
230
+
231
+ def forward(self, batch):
232
+ # Getting Image and Text Features
233
+ image_features = self.image_encoder(batch["image"])
234
+ text_features = self.text_encoder(
235
+ input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
236
+ )
237
+ # Getting Image and Text Embeddings (with same dimension)
238
+ image_embeddings = self.image_projection(image_features)
239
+ text_embeddings = self.text_projection(text_features)
240
+
241
+ # Calculating the Loss
242
+ logits = (text_embeddings @ image_embeddings.T) / self.temperature
243
+ images_similarity = image_embeddings @ image_embeddings.T
244
+ texts_similarity = text_embeddings @ text_embeddings.T
245
+ targets = F.softmax(
246
+ (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
247
+ )
248
+ texts_loss = cross_entropy(logits, targets, reduction='none')
249
+ images_loss = cross_entropy(logits.T, targets.T, reduction='none')
250
+ loss = (images_loss + texts_loss) / 2.0 # shape: (batch_size)
251
+ return loss.mean()
252
+ def cross_entropy(preds, targets, reduction='none'):
253
+ log_softmax = nn.LogSoftmax(dim=-1)
254
+ loss = (-targets * log_softmax(preds)).sum(1)
255
+ if reduction == "none":
256
+ return loss
257
+ elif reduction == "mean":
258
+ return loss.mean()
259
+
260
+
261
+
262
+
263
+
264
+
265
+
266
+
267
+
268
+
269
+
270
+
271
+
272
+
273
+
274
+
275
+
276
+
277
+ # INFERENCE CODE
278
+ def get_image_embeddings(image):
279
+ # preprocess the image
280
+ if image is None:
281
+ print("Image not found!")
282
+ return None
283
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
284
+ image = get_transforms("valid")(image=image)['image']
285
+ image = image.reshape(3, 224, 224)
286
+ model = CLIPModel().to(device)
287
+ model.load_state_dict(torch.load('weights.pt', map_location=device))
288
+ model.eval()
289
+
290
+ with torch.no_grad():
291
+ image_tensor = torch.from_numpy(image)
292
+ image_features = model.image_encoder(image_tensor.unsqueeze(0).to(device))
293
+ image_embeddings = model.image_projection(image_features)
294
+ image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
295
+
296
+ return image_embeddings
297
+
298
+
299
+ def predict_caption(image, model, text_embeddings, captions, n=2):
300
+ # get the image embeddings
301
+ image_embeddings = get_image_embeddings(image)
302
+ if image_embeddings is None:
303
+ return None
304
+
305
+ # normalize the embeddings
306
+ image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
307
+ text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
308
+ # calculate the dot product of image and text embeddings
309
+ dot_similarity = image_embeddings_n @ text_embeddings_n.T
310
+
311
+ # get the top n matches
312
+ values, indices = torch.topk(dot_similarity.squeeze(0), n)
313
+ indices = indices.cpu().numpy().tolist()
314
+ matches = [captions[idx] for idx in indices]
315
+
316
+ return matches
317
+
318
+ def get_text_embeddings(valid_df):
319
+ tokenizer = AutoTokenizer.from_pretrained(CFG.clinical_encoder_model)
320
+ valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
321
+
322
+ model = CLIPModel().to(device)
323
+ model.load_state_dict(torch.load("weights.pt", map_location=device))
324
+ model.eval()
325
+
326
+ valid_text_embeddings = []
327
+ with torch.no_grad():
328
+ for batch in tqdm(valid_loader):
329
+ text_features = model.text_encoder(
330
+ input_ids=batch["input_ids"].to(device), attention_mask=batch["attention_mask"].to(device)
331
+ )
332
+ text_embeddings = model.text_projection(text_features)
333
+ valid_text_embeddings.append(text_embeddings)
334
+
335
+ return model, torch.cat(valid_text_embeddings)
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ opencv-python==4.5.4.60
3
+ transformers
4
+ albumentations
5
+ timm
6
+ tqdm
7
+ openai
8
+ streamlit==0.84.0
9
+ pandas
10
+ torch
11
+ Pillow
12
+ numpy
13
+ huggingface-hub
14
+ python-docx
saved_text_embeddings.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2da707e595ccab006a2159d26d55469c7f015ea4a4dbc645154972fa96a17cc5
3
+ size 7538453
testing_df.csv ADDED
The diff for this file is too large to render. See raw diff
 
weights.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7be01315d364ccdce0cc260f8b43c7381e3629e670017e5aaa9bcc6ca172eb34
3
+ size 531111517