seanbenhur commited on
Commit
4045aa3
1 Parent(s): dc7d2f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -9
app.py CHANGED
@@ -5,6 +5,11 @@ from pathlib import Path
5
  from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
6
  # Pattern to ignore all the text after 2 or more full stops
7
  regex_pattern = "[.]{2,}"
 
 
 
 
 
8
  def post_process(text):
9
  try:
10
  text = text.strip()
@@ -17,15 +22,18 @@ def predict(image, max_length=64, num_beams=4):
17
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
18
  pixel_values = pixel_values.to(device)
19
  with torch.no_grad():
20
- output_ids = model.generate(
21
- pixel_values,
22
- max_length=max_length,
23
- num_beams=num_beams,
24
- return_dict_in_generate=True,
25
- ).sequences
26
- preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
27
- pred = post_process(preds[0])
28
- return pred
 
 
 
29
 
30
  model_path = "team-indain-image-caption/hindi-image-captioning"
31
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
5
  from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
6
  # Pattern to ignore all the text after 2 or more full stops
7
  regex_pattern = "[.]{2,}"
8
+ #sample = val_dataset[800]
9
+ #model = model.cuda()
10
+ #print(tokenizer.decode(model.generate(sample['pixel_values'].unsqueeze(0).cuda())[0]).replace('<|endoftext|>', '').split('\n')[0],'\n\n\n')
11
+
12
+
13
  def post_process(text):
14
  try:
15
  text = text.strip()
 
22
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
23
  pixel_values = pixel_values.to(device)
24
  with torch.no_grad():
25
+ text = model.generate(pixel_values.unsqueeze(0).cuda())
26
+ text = tokenizer.decode(text.replace('<|endoftext|>', '').split('\n')[0],'\n\n\n')
27
+ # output_ids = model.generate(
28
+ # pixel_values,
29
+ # max_length=max_length,
30
+ # num_beams=num_beams,
31
+ # return_dict_in_generate=True,
32
+ #).sequences
33
+
34
+ #preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
35
+ #pred = post_process(preds[0])
36
+ return text
37
 
38
  model_path = "team-indain-image-caption/hindi-image-captioning"
39
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")