seanbenhur committed on
Commit
00ca6f9
1 Parent(s): a94eec7
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -19,11 +19,16 @@ def post_process(text):
19
  pass
20
  return text
21
  def predict(image, max_length=64, num_beams=4):
 
22
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
23
  pixel_values = pixel_values.to(device)
24
  with torch.no_grad():
25
- text = model.generate(pixel_values.unsqueeze(0).cpu())
26
- text = tokenizer.decode(text.replace('<|endoftext|>', '').split('\n')[0],'\n\n\n')
 
 
 
 
27
  # output_ids = model.generate(
28
  # pixel_values,
29
  # max_length=max_length,
@@ -33,10 +38,10 @@ def predict(image, max_length=64, num_beams=4):
33
 
34
  #preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
35
  #pred = post_process(preds[0])
36
- return text
37
 
38
  model_path = "team-indain-image-caption/hindi-image-captioning"
39
- device = torch.device("cuda:0" if torch.cuda.is_available() else"cpu")
40
  # Load model.
41
  model = VisionEncoderDecoderModel.from_pretrained(model_path)
42
  model.to(device)
 
19
  pass
20
  return text
21
  def predict(image, max_length=64, num_beams=4):
22
+ image = image.convert('RGB')
23
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
24
  pixel_values = pixel_values.to(device)
25
  with torch.no_grad():
26
+ text = tokenizer.decode(model.generate(pixel_values.cpu())[0])
27
+ text = text.replace('<|endoftext|>', '').split('\n')
28
+ #[0],'\n\n\n'
29
+ #text[0]
30
+ #text = model.generate(pixel_values.cpu())
31
+ #text = tokenizer.decode(text.replace('<|endoftext|>', '').split('\n')[0],'\n\n\n')
32
  # output_ids = model.generate(
33
  # pixel_values,
34
  # max_length=max_length,
 
38
 
39
  #preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
40
  #pred = post_process(preds[0])
41
+ return text[0]
42
 
43
  model_path = "team-indain-image-caption/hindi-image-captioning"
44
+ device = "cpu"
45
  # Load model.
46
  model = VisionEncoderDecoderModel.from_pretrained(model_path)
47
  model.to(device)