GhAyoub committed
Commit ed51681 · 1 Parent(s): b717804

[OCR API] Reverted.

Files changed (1)
  main.py +10 -39
main.py CHANGED
@@ -2,76 +2,47 @@ from fastapi import FastAPI, Query
 from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
 from qwen_vl_utils import process_vision_info
 import torch
-import requests
-from PIL import Image
-from io import BytesIO
 
 app = FastAPI()
 
-# Load model and processor
 checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
-
-# Check for Metal GPU support on macOS
-device = "mps" if torch.backends.mps.is_available() else "cpu"
-
+min_pixels = 256*28*28
+max_pixels = 1280*28*28
 processor = AutoProcessor.from_pretrained(
     checkpoint,
-    min_pixels=256 * 28 * 28,
-    max_pixels=1280 * 28 * 28
+    min_pixels=min_pixels,
+    max_pixels=max_pixels
 )
-
 model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     checkpoint,
-    torch_dtype=torch.float16 if device == "mps" else torch.bfloat16,  # Use float16 on Apple Metal
-    device_map={"": 0} if device == "mps" else "cpu",
-    attn_implementation="flash_attention_2",  # If it supports Mac
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    # attn_implementation="flash_attention_2",
 )
 
-
-# Function to load and resize images (reduces processing time)
-def load_and_resize_image(image_url):
-    response = requests.get(image_url)
-    image = Image.open(BytesIO(response.content)).convert("RGB")
-    image = image.resize((512, 512))  # Resize to 512x512 to speed up processing
-    return image
-
-
 @app.get("/")
 def read_root():
     return {"message": "API is live. Use the /predict endpoint."}
 
-
 @app.get("/predict")
 def predict(image_url: str = Query(...), prompt: str = Query(...)):
     messages = [
         {"role": "system", "content": "You are a helpful assistant with vision abilities."},
         {"role": "user", "content": [{"type": "image", "image": image_url}, {"type": "text", "text": prompt}]},
     ]
-
     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-
-    # Process image
-    image_inputs = [load_and_resize_image(image_url)]
-    video_inputs = None
-
-    # Process inputs
+    image_inputs, video_inputs = process_vision_info(messages)
     inputs = processor(
         text=[text],
         images=image_inputs,
         videos=video_inputs,
         padding=True,
-        truncation=True,  # Ensures token limit
-        max_length=512,  # Prevents excessive memory usage
         return_tensors="pt",
-    ).to(device)
-
-    # Generate response
+    ).to(model.device)
     with torch.no_grad():
-        generated_ids = model.generate(**inputs, max_new_tokens=64)  # Reduced for faster inference
-
+        generated_ids = model.generate(**inputs, max_new_tokens=128)
     generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
     output_texts = processor.batch_decode(
         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
     )
-
     return {"response": output_texts[0]}