Spaces:
Runtime error
Runtime error
seanbenhur
committed on
Commit
•
00ca6f9
1
Parent(s):
a94eec7
fix bugs
Browse files
app.py
CHANGED
@@ -19,11 +19,16 @@ def post_process(text):
|
|
19 |
pass
|
20 |
return text
|
21 |
def predict(image, max_length=64, num_beams=4):
|
|
|
22 |
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
|
23 |
pixel_values = pixel_values.to(device)
|
24 |
with torch.no_grad():
|
25 |
-
text = model.generate(pixel_values.
|
26 |
-
text =
|
|
|
|
|
|
|
|
|
27 |
# output_ids = model.generate(
|
28 |
# pixel_values,
|
29 |
# max_length=max_length,
|
@@ -33,10 +38,10 @@ def predict(image, max_length=64, num_beams=4):
|
|
33 |
|
34 |
#preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
35 |
#pred = post_process(preds[0])
|
36 |
-
return text
|
37 |
|
38 |
model_path = "team-indain-image-caption/hindi-image-captioning"
|
39 |
-
device =
|
40 |
# Load model.
|
41 |
model = VisionEncoderDecoderModel.from_pretrained(model_path)
|
42 |
model.to(device)
|
|
|
19 |
pass
|
20 |
return text
|
21 |
def predict(image, max_length=64, num_beams=4):
|
22 |
+
image = image.convert('RGB')
|
23 |
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
|
24 |
pixel_values = pixel_values.to(device)
|
25 |
with torch.no_grad():
|
26 |
+
text = tokenizer.decode(model.generate(pixel_values.cpu())[0])
|
27 |
+
text = text.replace('<|endoftext|>', '').split('\n')
|
28 |
+
#[0],'\n\n\n'
|
29 |
+
#text[0]
|
30 |
+
#text = model.generate(pixel_values.cpu())
|
31 |
+
#text = tokenizer.decode(text.replace('<|endoftext|>', '').split('\n')[0],'\n\n\n')
|
32 |
# output_ids = model.generate(
|
33 |
# pixel_values,
|
34 |
# max_length=max_length,
|
|
|
38 |
|
39 |
#preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
40 |
#pred = post_process(preds[0])
|
41 |
+
return text[0]
|
42 |
|
43 |
model_path = "team-indain-image-caption/hindi-image-captioning"
|
44 |
+
device = "cpu"
|
45 |
# Load model.
|
46 |
model = VisionEncoderDecoderModel.from_pretrained(model_path)
|
47 |
model.to(device)
|