Nechba commited on
Commit
78d298c
Β·
1 Parent(s): 46e1ee4

fisrt commit

Browse files
Files changed (6) hide show
  1. 135_back.jpg +0 -0
  2. 135_front.jpg +0 -0
  3. 141_back.jpg +0 -0
  4. 141_front.jpg +0 -0
  5. app.py +178 -0
  6. requirements.txt +11 -0
135_back.jpg ADDED
135_front.jpg ADDED
141_back.jpg ADDED
141_front.jpg ADDED
app.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import gradio as gr
4
+ import spaces
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
+ import os
7
+ from threading import Thread
8
+
9
+ MODEL_LIST = ["THUDM/glm-4v-9b"]
10
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
11
+ MODEL_ID = os.environ.get("MODEL_ID")
12
+ MODEL_NAME = MODEL_ID.split("/")[-1]
13
+
14
+ TITLE = f'<br><center>πŸš€ Coin Generative Recognition</a></center>'
15
+
16
+ DESCRIPTION = f"""
17
+ <center>
18
+ <p>
19
+ A Space for Vision/Multimodal
20
+ <br>
21
+ <br>
22
+ ✨ Tips: Send messages or upload multiple IMAGES at a time.
23
+ <br>
24
+ ✨ Tips: Please increase MAX LENGTH when dealing with files.
25
+ <br>
26
+ πŸ€™ Supported Format: png, jpg, webp
27
+ <br>
28
+ πŸ™‡β€β™‚οΈ May be rebuilding from time to time.
29
+ </p>
30
+ </center>"""
31
+
32
+ CSS = """
33
+ h1 {
34
+ text-align: center;
35
+ display: block;
36
+ }
37
+ img {
38
+ max-width: 100%; /* Make sure images are not wider than their container */
39
+ height: auto; /* Maintain aspect ratio */
40
+ max-height: 300px; /* Limit the height of images */
41
+ }
42
+ """
43
+ import os
44
+ # Directory where the model and tokenizer will be saved
45
+
46
+
47
+ model = AutoModelForCausalLM.from_pretrained(
48
+ MODEL_ID,
49
+ torch_dtype=torch.bfloat16,
50
+ low_cpu_mem_usage=True,
51
+ trust_remote_code=True
52
+ ).to(0)
53
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
54
+ model.eval()
55
+
56
+
57
+ def merge_images(paths):
58
+ images = [Image.open(path).convert('RGB') for path in paths]
59
+ widths, heights = zip(*(i.size for i in images))
60
+ total_width = sum(widths)
61
+ max_height = max(heights)
62
+ new_im = Image.new('RGB', (total_width, max_height))
63
+ x_offset = 0
64
+ for im in images:
65
+ new_im.paste(im, (x_offset,0))
66
+ x_offset += im.width
67
+ return new_im
68
+
69
+ def mode_load(paths):
70
+ if all(path.lower().endswith(('png', 'jpg', 'jpeg', 'webp')) for path in paths):
71
+ content = merge_images(paths)
72
+ choice = "image"
73
+ return choice, content
74
+ else:
75
+ raise gr.Error("Unsupported file types. Please upload only images.")
76
+
77
+ @spaces.GPU()
78
+ def stream_chat(message, history: list, temperature: float, max_length: int, top_p: float, top_k: int, penalty: float):
79
+ conversation = []
80
+ if message["files"]:
81
+ choice, contents = mode_load(message["files"])
82
+ conversation.append({"role": "user", "image": contents, "content": message['text']})
83
+ elif message["files"] and len(message["files"]) == 1:
84
+ content = Image.open( message["files"][-1]).convert('RGB')
85
+ choice = "image"
86
+ conversation.append({"role": "user", "image": content, "content": message['text']})
87
+ else:
88
+ raise gr.Error("Please upload one or more images.")
89
+ input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
90
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
91
+ generate_kwargs = dict(
92
+ max_length=max_length,
93
+ streamer=streamer,
94
+ do_sample=True,
95
+ top_p=top_p,
96
+ top_k=top_k,
97
+ temperature=temperature,
98
+ repetition_penalty=penalty,
99
+ eos_token_id=[151329, 151336, 151338],
100
+ )
101
+ gen_kwargs = {**input_ids, **generate_kwargs}
102
+ with torch.no_grad():
103
+ thread = Thread(target=model.generate, kwargs=gen_kwargs)
104
+ thread.start()
105
+ buffer = ""
106
+ for new_text in streamer:
107
+ buffer += new_text
108
+ yield buffer
109
+
110
+ chatbot = gr.Chatbot(label="Chatbox", height=600, placeholder=DESCRIPTION)
111
+ chat_input = gr.MultimodalTextbox(
112
+ interactive=True,
113
+ placeholder="Enter message or upload images...",
114
+ show_label=False,
115
+ file_count="multiple",
116
+ )
117
+
118
+ EXAMPLES = [
119
+ [{"text": "Give me Country,Denomination andΒ year as json format.", "files": ["./135_back.jpg", "./135_front.jpg"]}],
120
+ [{"text": "Give me Country,Denomination andΒ year as json format.", "files": ["./141_back.jpg","./141_front.jpg"]}]
121
+ ]
122
+
123
+ with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
124
+ gr.HTML(TITLE)
125
+ gr.ChatInterface(
126
+ fn=stream_chat,
127
+ multimodal=True,
128
+ textbox=chat_input,
129
+ chatbot=chatbot,
130
+ fill_height=True,
131
+ additional_inputs_accordion=gr.Accordion(label="βš™οΈ Parameters", open=False, render=False),
132
+ additional_inputs=[
133
+ gr.Slider(
134
+ minimum=0,
135
+ maximum=1,
136
+ step=0.1,
137
+ value=0.8,
138
+ label="Temperature",
139
+ render=False,
140
+ ),
141
+ gr.Slider(
142
+ minimum=1024,
143
+ maximum=8192,
144
+ step=1,
145
+ value=4096,
146
+ label="Max Length",
147
+ render=False,
148
+ ),
149
+ gr.Slider(
150
+ minimum=0.0,
151
+ maximum=1.0,
152
+ step=0.1,
153
+ value=1.0,
154
+ label="top_p",
155
+ render=False,
156
+ ),
157
+ gr.Slider(
158
+ minimum=1,
159
+ maximum=20,
160
+ step=1,
161
+ value=10,
162
+ label="top_k",
163
+ render=False,
164
+ ),
165
+ gr.Slider(
166
+ minimum=0.0,
167
+ maximum=2.0,
168
+ step=0.1,
169
+ value=1.0,
170
+ label="Repetition penalty",
171
+ render=False,
172
+ ),
173
+ ],
174
+ ),
175
+ gr.Examples(EXAMPLES, [chat_input])
176
+
177
+ if __name__ == "__main__":
178
+ demo.queue(api_open=False).launch(show_api=False, share=False)
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Pillow
2
+ torch
3
+ torchvision
4
+ transformers
5
+ sentencepiece
6
+ opencv-python
7
+ accelerate
8
+ tiktoken
9
+ PyMuPDF
10
+ python-docx
11
+ python-pptx