Aekanun committed
Commit 21cb9fc
Parent: a88db73
Files changed (1)
  1. app.py +97 -60
app.py CHANGED
@@ -1,75 +1,112 @@
 import os
+import warnings
 import torch
-from transformers import AutoModelForCausalLM, AutoProcessor
+import gc
+from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
 from PIL import Image
 import gradio as gr
 
-# Login to Hugging Face Hub
-from huggingface_hub import login
-token = os.environ.get('HUGGING_FACE_HUB_TOKEN')
-if token:
-    login(token=token)
+warnings.filterwarnings('ignore')
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 
-def load_model():
-    base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
-    hub_model_path = "Aekanun/thai-handwriting-llm"
-
-    processor = AutoProcessor.from_pretrained(base_model_path, token=token)
-
-    # Load the model as CausalLM, per the task_type in adapter_config
-    model = AutoModelForCausalLM.from_pretrained(
-        hub_model_path,
-        trust_remote_code=True,
-        target_modules=["q_proj", "v_proj"],  # from adapter_config
-        token=token
-    )
-
-    return model, processor
+# Global variables
+model = None
+processor = None
+
+if torch.cuda.is_available():
+    torch.cuda.empty_cache()
+    gc.collect()
+    print("เคลียร์ CUDA cache เรียบร้อยแล้ว")
 
-model, processor = load_model()
+def load_model_and_processor():
+    """Load the model and processor"""
+    global model, processor
+    print("กำลังโหลดโมเดลและ processor...")
+
+    try:
+        base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+        hub_model_path = "Aekanun/thai-handwriting-llm"
+
+        # Configure BitsAndBytes the same way as the original
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.bfloat16
+        )
+
+        # Load the processor the same way as the original (no token)
+        processor = AutoProcessor.from_pretrained(base_model_path)
+
+        # Load the model from the Hub the same way as the original
+        print("กำลังโหลดโมเดลจาก Hub...")
+        model = AutoModelForVision2Seq.from_pretrained(
+            hub_model_path,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+            quantization_config=bnb_config,
+            trust_remote_code=True
+        )
+        print("โหลดโมเดลจาก Hub สำเร็จ!")
+
+        return True
+    except Exception as e:
+        print(f"เกิดข้อผิดพลาดในการโหลดโมเดล: {str(e)}")
+        return False
 
-def process_image(image):
+def process_handwriting(image):
     if image is None:
         return "กรุณาอัพโหลดรูปภาพ"
-
-    if not isinstance(image, Image.Image):
-        image = Image.fromarray(image)
 
-    if image.mode != "RGB":
-        image = image.convert("RGB")
+    try:
+        if not isinstance(image, Image.Image):
+            image = Image.fromarray(image)
+
+        if image.mode != "RGB":
+            image = image.convert("RGB")
 
-    prompt = "Transcribe the Thai handwritten text from the provided image.\nOnly return the transcription in Thai language."
-
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {"type": "text", "text": prompt},
-                {"type": "image", "image": image}
-            ],
-        }
-    ]
+        prompt = """Transcribe the Thai handwritten text from the provided image.
+Only return the transcription in Thai language."""
 
-    text = processor.apply_chat_template(messages, tokenize=False)
-    inputs = processor(text=text, images=image, return_tensors="pt")
-
-    with torch.no_grad():
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=256,
-            do_sample=False,
-            pad_token_id=processor.tokenizer.pad_token_id
-        )
-
-    transcription = processor.decode(outputs[0], skip_special_tokens=True)
-    return transcription.strip()
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": prompt},
+                    {"type": "image", "image": image}
+                ],
+            }
+        ]
 
-demo = gr.Interface(
-    fn=process_image,
-    inputs=gr.Image(type="pil"),
-    outputs="text",
-    title="Thai Handwriting OCR",
-)
+        text = processor.apply_chat_template(messages, tokenize=False)
+        inputs = processor(text=text, images=image, return_tensors="pt")
+        inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=256,
+                do_sample=False,
+                pad_token_id=processor.tokenizer.pad_token_id
+            )
+
+        transcription = processor.decode(outputs[0], skip_special_tokens=True)
+        return transcription.strip()
+
+    except Exception as e:
+        return f"เกิดข้อผิดพลาด: {str(e)}"
+
+print("กำลังเริ่มต้นแอปพลิเคชัน...")
+if load_model_and_processor():
+    demo = gr.Interface(
+        fn=process_handwriting,
+        inputs=gr.Image(type="pil", label="อัพโหลดรูปลายมือเขียนภาษาไทย"),
+        outputs=gr.Textbox(label="ข้อความที่แปลงได้"),
+        title="Thai Handwriting Recognition",
+        description="อัพโหลดรูปภาพลายมือเขียนภาษาไทยเพื่อแปลงเป็นข้อความ"
+    )
 
-if __name__ == "__main__":
-    demo.launch()
+    if __name__ == "__main__":
+        demo.launch()
+else:
+    print("ไม่สามารถเริ่มต้นแอปพลิเคชันได้")