kirp@umich.edu commited on
Commit
fa7ceb4
·
1 Parent(s): 629edc1

comment batch generating

Browse files
Files changed (2) hide show
  1. ocr.py +7 -7
  2. output.png +2 -2
ocr.py CHANGED
@@ -5,7 +5,7 @@ from PIL import Image, ImageDraw
5
  from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration
6
 
7
  repo = "microsoft/kosmos-2.5"
8
- device = "cuda:1"
9
  dtype = torch.bfloat16
10
  model = Kosmos2_5ForConditionalGeneration.from_pretrained(repo, device_map=device, torch_dtype=dtype)
11
  processor = AutoProcessor.from_pretrained(repo)
@@ -22,12 +22,12 @@ raw_width, raw_height = image.size
22
  scale_height = raw_height / height
23
  scale_width = raw_width / width
24
 
25
- # bs > 1, batch decoding sample
26
- inputs = processor(text=[prompt, prompt], images=[image,image], return_tensors="pt")
27
- height, width = inputs.pop("height"), inputs.pop("width")
28
- raw_width, raw_height = image.size
29
- scale_height = raw_height / height[0]
30
- scale_width = raw_width / width[0]
31
 
32
  inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
33
  inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
 
5
  from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration
6
 
7
  repo = "microsoft/kosmos-2.5"
8
+ device = "cuda:0"
9
  dtype = torch.bfloat16
10
  model = Kosmos2_5ForConditionalGeneration.from_pretrained(repo, device_map=device, torch_dtype=dtype)
11
  processor = AutoProcessor.from_pretrained(repo)
 
22
  scale_height = raw_height / height
23
  scale_width = raw_width / width
24
 
25
+ # bs > 1, batch generation
26
+ # inputs = processor(text=[prompt, prompt], images=[image,image], return_tensors="pt")
27
+ # height, width = inputs.pop("height"), inputs.pop("width")
28
+ # raw_width, raw_height = image.size
29
+ # scale_height = raw_height / height[0]
30
+ # scale_width = raw_width / width[0]
31
 
32
  inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
33
  inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
output.png CHANGED

Git LFS Details

  • SHA256: 410b17e2b48d588c7bd9317e924e69841c0b9670848fe0efa217389d74882d32
  • Pointer size: 132 Bytes
  • Size of remote file: 1.66 MB

Git LFS Details

  • SHA256: d95e7707ae64bea0e864438f3efdf5a501c46714afc3a909b84a29c3feca0b16
  • Pointer size: 132 Bytes
  • Size of remote file: 1.66 MB