Spaces:
Runtime error
Runtime error
Commit
·
4045aa3
1
Parent(s):
dc7d2f7
Update app.py
Browse files
app.py
CHANGED
@@ -5,6 +5,11 @@ from pathlib import Path
|
|
5 |
from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
|
6 |
# Pattern to ignore all the text after 2 or more full stops
|
7 |
regex_pattern = "[.]{2,}"
|
|
|
|
|
|
|
|
|
|
|
8 |
def post_process(text):
|
9 |
try:
|
10 |
text = text.strip()
|
@@ -17,15 +22,18 @@ def predict(image, max_length=64, num_beams=4):
|
|
17 |
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
|
18 |
pixel_values = pixel_values.to(device)
|
19 |
with torch.no_grad():
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
29 |
|
30 |
model_path = "team-indain-image-caption/hindi-image-captioning"
|
31 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
5 |
from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
|
6 |
# Pattern to ignore all the text after 2 or more full stops
|
7 |
regex_pattern = "[.]{2,}"
|
8 |
+
#sample = val_dataset[800]
|
9 |
+
#model = model.cuda()
|
10 |
+
#print(tokenizer.decode(model.generate(sample['pixel_values'].unsqueeze(0).cuda())[0]).replace('<|endoftext|>', '').split('\n')[0],'\n\n\n')
|
11 |
+
|
12 |
+
|
13 |
def post_process(text):
|
14 |
try:
|
15 |
text = text.strip()
|
|
|
22 |
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
|
23 |
pixel_values = pixel_values.to(device)
|
24 |
with torch.no_grad():
|
25 |
+
text = model.generate(pixel_values.unsqueeze(0).cuda())
|
26 |
+
text = tokenizer.decode(text.replace('<|endoftext|>', '').split('\n')[0],'\n\n\n')
|
27 |
+
# output_ids = model.generate(
|
28 |
+
# pixel_values,
|
29 |
+
# max_length=max_length,
|
30 |
+
# num_beams=num_beams,
|
31 |
+
# return_dict_in_generate=True,
|
32 |
+
#).sequences
|
33 |
+
|
34 |
+
#preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
35 |
+
#pred = post_process(preds[0])
|
36 |
+
return text
|
37 |
|
38 |
model_path = "team-indain-image-caption/hindi-image-captioning"
|
39 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|