Jiayi-Pan committed on
Commit
7bf6be3
1 Parent(s): 928c42f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import (
2
+ AutoModelForCausalLM,
3
+ AutoTokenizer,
4
+ __version__,
5
+ GenerationConfig,
6
+ )
7
+ from PIL import Image
8
+ import gradio as gr
9
+ import argparse
10
+ import tempfile
11
+
12
+ import os
13
+ from PIL import Image
14
+ import json
15
+ from tqdm import tqdm
16
+ import easyocr
17
+
18
# Refuse to start unless the installed transformers release is exactly the
# pinned one; the message tells the user how to install it.
assert __version__ == "4.32.0", (
    "Please use transformers version 4.32.0, pip install transformers==4.32.0"
)

# Build the English EasyOCR reader a single time at import so the OCR model
# is loaded into memory once and reused by every request.
reader = easyocr.Reader(["en"])
25
+
26
+
27
def get_easy_text(img_file):
    """Run OCR on ``img_file`` and return the recognized text.

    The module-level EasyOCR ``reader`` is called with ``detail=0`` and
    ``paragraph=True``. When it hands back a list, the entries are joined
    with newlines into one string; any other return value is passed through
    unchanged.
    """
    ocr_output = reader.readtext(img_file, detail=0, paragraph=True)
    if not isinstance(ocr_output, list):
        return ocr_output
    return "\n".join(ocr_output)
32
+
33
# Hub identifier used for both the model and the tokenizer below.
model_name = "DigitalAgent/Captioner"

# Load the checkpoint onto the GPU (running the repository's custom code),
# then put it in eval mode and cast the weights to half precision.
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="cuda", trust_remote_code=True
).eval().half()

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Generation settings passed to ``model.chat`` on every request.
# NOTE(review): "chat_format" and "max_window_size" are not standard
# transformers generation options; presumably they are consumed by the
# checkpoint's remote chat code -- confirm against the model repository.
generation_config = GenerationConfig.from_dict(
    {
        "chat_format": "chatml",
        # Sampling is enabled; top_k=0 with top_p=0.3 leaves nucleus sampling
        # as the only truncation (per transformers' generation semantics).
        "do_sample": True,
        "eos_token_id": 151643,
        "max_new_tokens": 2048,
        "max_window_size": 6144,
        # Padding reuses the same token id as end-of-sequence.
        "pad_token_id": 151643,
        "repetition_penalty": 1.2,
        "top_k": 0,
        "top_p": 0.3,
        "transformers_version": "4.31.0",
    }
)
56
+
57
+
58
def generate(image: Image.Image) -> str:
    """Describe a screenshot using its OCR text and the captioning model.

    The image is written to a temporary ``.jpg`` file, OCR'd with
    ``get_easy_text``, and then passed (by file path) together with a text
    prompt containing the OCR result to ``model.chat``.

    Args:
        image: The screenshot as a PIL image (the Gradio input component is
            configured with ``type="pil"``).

    Returns:
        The response string produced by ``model.chat``.

    Raises:
        gr.Error: If no image was supplied (Gradio passes ``None`` when the
            input component is empty).
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please provide a screenshot to describe.")

    # JPEG cannot store alpha/palette modes; Pillow raises OSError when asked
    # to save e.g. an RGBA image as .jpg, so normalize to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # The temp file must stay alive until the model has consumed it, because
    # the query references the image by path; keep everything inside `with`.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=True) as tmp:
        image.save(tmp.name)
        ocr_result = get_easy_text(tmp.name)
        text = f"Please describe the screenshot above in details.\nOCR Result:\n{ocr_result}"
        # Each request is a fresh single-turn conversation.
        history = []
        input_data = [{"image": tmp.name}, {"text": text}]
        query = tokenizer.from_list_format(input_data)
        # model.chat returns a (response, history) pair; only the response
        # text is needed here.
        response, _ = model.chat(
            tokenizer, query=query, history=history, generation_config=generation_config
        )
    return response
70
+
71
+
72
def main(port, share):
    """Serve ``generate`` through a Gradio interface.

    Args:
        port: Value forwarded to ``launch`` as ``server_port``.
        share: Value forwarded to ``launch`` as ``share``.
    """
    # Single PIL image in, plain text out; at most one request runs at a time.
    screenshot_input = gr.Image(type="pil")
    demo = gr.Interface(
        fn=generate,
        inputs=[screenshot_input],
        outputs="text",
        concurrency_limit=1,
    )
    launch_options = {"server_port": port, "share": share}
    demo.queue().launch(**launch_options)
77
+
78
+
79
if __name__ == "__main__":
    # Command-line entry point: an optional integer --port and a --share
    # flag (off unless given), both forwarded to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int)
    parser.add_argument("--share", action="store_true", default=False)
    args = parser.parse_args()
    main(port=args.port, share=args.share)