vithacocf commited on
Commit
76a5fff
·
verified ·
1 Parent(s): 4be1568

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -6
app.py CHANGED
@@ -1,11 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from transformers import AutoProcessor, AutoModelForVision2Seq
3
- from PIL import Image
4
  import torch
 
5
 
 
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
7
  torch.cuda.empty_cache()
8
 
 
9
  model_id = "prithivMLmods/Camel-Doc-OCR-062825"
10
  processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
11
  model = AutoModelForVision2Seq.from_pretrained(
@@ -14,18 +67,46 @@ model = AutoModelForVision2Seq.from_pretrained(
14
  trust_remote_code=True
15
  ).to(device)
16
 
 
 
 
 
 
 
 
 
 
 
 
17
  def predict(image, prompt=None):
18
- image = image.convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # Cực kỳ quan trọng: text="" bắt buộc phải có
 
 
 
 
 
 
 
21
  inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
22
- # In debug để kiểm tra input_ids
23
- print(">>> input_ids shape:", inputs.input_ids.shape)
24
  generated_ids = model.generate(
25
  **inputs,
26
  max_new_tokens=512,
27
  do_sample=False,
28
- use_cache=False, # ✅ Thêm dòng này để fix lỗi cache_position
29
  eos_token_id=processor.tokenizer.eos_token_id,
30
  pad_token_id=processor.tokenizer.pad_token_id
31
  )
@@ -33,6 +114,7 @@ def predict(image, prompt=None):
33
  result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
34
  return result
35
 
 
36
  demo = gr.Interface(
37
  fn=predict,
38
  inputs=[
 
1
+ # Code anh Thang
2
+ # import gradio as gr
3
+ # from transformers import AutoProcessor, AutoModelForVision2Seq
4
+ # from PIL import Image
5
+ # import torch
6
+
7
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
8
+ # torch.cuda.empty_cache()
9
+
10
+ # model_id = "prithivMLmods/Camel-Doc-OCR-062825"
11
+ # processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
12
+ # model = AutoModelForVision2Seq.from_pretrained(
13
+ # model_id,
14
+ # torch_dtype=torch.float16 if device == "cuda" else torch.float32,
15
+ # trust_remote_code=True
16
+ # ).to(device)
17
+
18
+ # def predict(image, prompt=None):
19
+ # image = image.convert("RGB")
20
+
21
+ # # Cực kỳ quan trọng: text="" bắt buộc phải có
22
+ # inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
23
+ # # In debug để kiểm tra input_ids
24
+ # print(">>> input_ids shape:", inputs.input_ids.shape)
25
+ # generated_ids = model.generate(
26
+ # **inputs,
27
+ # max_new_tokens=512,
28
+ # do_sample=False,
29
+ # use_cache=False, # ✅ Thêm dòng này để fix lỗi cache_position
30
+ # eos_token_id=processor.tokenizer.eos_token_id,
31
+ # pad_token_id=processor.tokenizer.pad_token_id
32
+ # )
33
+
34
+ # result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
35
+ # return result
36
+
37
+ # demo = gr.Interface(
38
+ # fn=predict,
39
+ # inputs=[
40
+ # gr.Image(type="pil", label="Tải ảnh tài liệu lên"),
41
+ # gr.Textbox(label="Gợi ý (tuỳ chọn)", placeholder="VD: Trích số hóa đơn")
42
+ # ],
43
+ # outputs="text",
44
+ # title="Camel-Doc OCR - Trích xuất văn bản từ ảnh"
45
+ # )
46
+
47
+ # if __name__ == "__main__":
48
+ # demo.launch()
49
+
50
+ # Code fix
51
  import gradio as gr
52
  from transformers import AutoProcessor, AutoModelForVision2Seq
53
+ from PIL import Image, UnidentifiedImageError
54
  import torch
55
+ import os
56
 
57
+ # Cấu hình thiết bị
58
  device = "cuda" if torch.cuda.is_available() else "cpu"
59
  torch.cuda.empty_cache()
60
 
61
+ # Load mô hình
62
  model_id = "prithivMLmods/Camel-Doc-OCR-062825"
63
  processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
64
  model = AutoModelForVision2Seq.from_pretrained(
 
67
  trust_remote_code=True
68
  ).to(device)
69
 
70
+ # Hỗ trợ định dạng ảnh
71
+ def is_supported_image(image):
72
+ return isinstance(image, Image.Image)
73
+
74
+ # Chuyển PNG sang JPG
75
+ def convert_png_to_jpg(image):
76
+ converted = Image.new("RGB", image.size, (255, 255, 255))
77
+ converted.paste(image)
78
+ return converted
79
+
80
+ # Hàm chính
81
  def predict(image, prompt=None):
82
+ # Kiểm tra ảnh hợp lệ
83
+ if not is_supported_image(image):
84
+ return "Không hỗ trợ định dạng file này. Vui lòng tải ảnh đúng."
85
+
86
+ # Prompt rỗng
87
+ if prompt is None or prompt.strip() == "":
88
+ return "Vui lòng nhập prompt để trích xuất dữ liệu từ ảnh."
89
+
90
+ try:
91
+ # Nếu ảnh là PNG có alpha, convert sang RGB
92
+ if image.mode == "RGBA" or image.mode == "LA":
93
+ image = convert_png_to_jpg(image)
94
 
95
+ image = image.convert("RGB")
96
+
97
+ except UnidentifiedImageError:
98
+ return "Không thể đọc ảnh. Vui lòng kiểm tra lại định dạng hoặc ảnh bị lỗi."
99
+ except Exception as e:
100
+ return f"Lỗi khi xử lý ảnh: {str(e)}"
101
+
102
+ # Inference
103
  inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
104
+
 
105
  generated_ids = model.generate(
106
  **inputs,
107
  max_new_tokens=512,
108
  do_sample=False,
109
+ use_cache=False, # fix cache_position
110
  eos_token_id=processor.tokenizer.eos_token_id,
111
  pad_token_id=processor.tokenizer.pad_token_id
112
  )
 
114
  result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
115
  return result
116
 
117
+
118
  demo = gr.Interface(
119
  fn=predict,
120
  inputs=[