vithacocf commited on
Commit
25db7d4
·
verified ·
1 Parent(s): 0fb18ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -45
app.py CHANGED
@@ -49,70 +49,60 @@
49
 
50
  # Code fix
51
  import gradio as gr
52
- # from transformers import AutoProcessor, AutoModelForVision2Seq
53
  from PIL import Image, UnidentifiedImageError
54
- # import torch
55
  import os
56
 
57
- # # Cấu hình thiết bị
58
- # device = "cuda" if torch.cuda.is_available() else "cpu"
59
- # torch.cuda.empty_cache()
60
-
61
- # # Load mô hình
62
- # model_id = "prithivMLmods/Camel-Doc-OCR-062825"
63
- # processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
64
- # model = AutoModelForVision2Seq.from_pretrained(
65
- # model_id,
66
- # torch_dtype=torch.float16 if device == "cuda" else torch.float32,
67
- # trust_remote_code=True
68
- # ).to(device)
69
 
70
- # Hỗ trợ định dạng ảnh
71
- def is_supported_image(image):
72
- return isinstance(image, Image.Image)
 
 
 
 
 
73
 
74
- # Chuyển PNG sang JPG
75
  def convert_png_to_jpg(image):
76
- converted = Image.new("RGB", image.size, (255, 255, 255))
77
- converted.paste(image)
78
- return converted
 
 
79
 
80
  # Hàm chính
81
- def predict(image_path, prompt=None):
82
- if not isinstance(image_path, str) or not os.path.exists(image_path):
83
- return "=Không tìm thấy ảnh. Vui lòng thử lại sau khi upload thành công."
84
 
85
  if prompt is None or prompt.strip() == "":
86
  return "=Vui lòng nhập prompt để trích xuất dữ liệu."
87
 
88
  try:
89
- image = Image.open(image_path).convert("RGB")
90
-
91
- if image.mode in ["RGBA", "LA"]:
92
- new_img = Image.new("RGB", image.size, (255, 255, 255))
93
- new_img.paste(image)
94
- image = new_img
 
 
 
 
 
 
 
95
 
96
  except UnidentifiedImageError:
97
  return "=Không thể đọc ảnh. Ảnh có thể bị hỏng hoặc sai định dạng."
98
  except Exception as e:
99
  return f"=Lỗi khi xử lý ảnh: {str(e)}"
100
 
101
- # inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
102
-
103
- # generated_ids = model.generate(
104
- # **inputs,
105
- # max_new_tokens=512,
106
- # do_sample=False,
107
- # use_cache=False,
108
- # eos_token_id=processor.tokenizer.eos_token_id,
109
- # pad_token_id=processor.tokenizer.pad_token_id
110
- # )
111
-
112
- # result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
113
- result = "aaa"
114
- return result
115
-
116
  demo = gr.Interface(
117
  fn=predict,
118
  inputs=[
 
49
 
50
  # Code fix
51
  import gradio as gr
52
+ from transformers import AutoProcessor, AutoModelForVision2Seq
53
  from PIL import Image, UnidentifiedImageError
54
+ import torch
55
  import os
56
 
57
+ # Cấu hình thiết bị
58
+ device = "cuda" if torch.cuda.is_available() else "cpu"
59
+ torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
60
 
61
+ # Load hình
62
+ model_id = "prithivMLmods/Camel-Doc-OCR-062825"
63
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
64
+ model = AutoModelForVision2Seq.from_pretrained(
65
+ model_id,
66
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
67
+ trust_remote_code=True
68
+ ).to(device)
69
 
70
+ # Hàm xử ảnh (nếu có kênh alpha)
71
  def convert_png_to_jpg(image):
72
+ if image.mode in ["RGBA", "LA"]:
73
+ converted = Image.new("RGB", image.size, (255, 255, 255))
74
+ converted.paste(image, mask=image.split()[-1]) # Dùng alpha làm mask
75
+ return converted
76
+ return image.convert("RGB")
77
 
78
  # Hàm chính
79
+ def predict(image, prompt=None):
80
+ if image is None:
81
+ return "=Vui lòng tải lên ảnh hợp lệ."
82
 
83
  if prompt is None or prompt.strip() == "":
84
  return "=Vui lòng nhập prompt để trích xuất dữ liệu."
85
 
86
  try:
87
+ image = convert_png_to_jpg(image)
88
+
89
+ inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
90
+ generated_ids = model.generate(
91
+ **inputs,
92
+ max_new_tokens=512,
93
+ do_sample=False,
94
+ use_cache=False,
95
+ eos_token_id=processor.tokenizer.eos_token_id,
96
+ pad_token_id=processor.tokenizer.pad_token_id
97
+ )
98
+ result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
99
+ return result
100
 
101
  except UnidentifiedImageError:
102
  return "=Không thể đọc ảnh. Ảnh có thể bị hỏng hoặc sai định dạng."
103
  except Exception as e:
104
  return f"=Lỗi khi xử lý ảnh: {str(e)}"
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  demo = gr.Interface(
107
  fn=predict,
108
  inputs=[