toandev committed
Commit abdf424 (verified)
1 parent: 1d66770

Update app.py

Files changed (1)
  app.py  +86 -61
app.py CHANGED
@@ -1,109 +1,134 @@
+import time
+from threading import Thread
+from typing import Dict, List
+
+import gradio as gr
+import spaces
+import torch
+from PIL import Image
 from transformers import (
-    MllamaForConditionalGeneration,
     AutoProcessor,
+    MllamaForConditionalGeneration,
     TextIteratorStreamer,
 )
-from PIL import Image
-import requests
-import torch
-from threading import Thread
-import gradio as gr
-from gradio import FileData
-import time
-import spaces
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-ckpt = "toandev/Viet-Receipt-Llama-3.2-11B-Vision-Instruct"
+# Constants
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+CHECKPOINT = "toandev/Viet-Receipt-Llama-3.2-11B-Vision-Instruct"
+
+# Model initialization
 model = MllamaForConditionalGeneration.from_pretrained(
-    ckpt, torch_dtype=torch.bfloat16
-).to(device)
-processor = AutoProcessor.from_pretrained(ckpt)
+    CHECKPOINT, torch_dtype=torch.bfloat16
+).to(DEVICE)
+processor = AutoProcessor.from_pretrained(CHECKPOINT)
 
 
-@spaces.GPU
-def bot_streaming(message, history, max_new_tokens=250):
-
-    txt = message["text"]
-    ext_buffer = f"{txt}"
-
+def process_chat_history(history: List) -> tuple[List[Dict], List[Image.Image]]:
+    """
+    Process chat history to extract messages and images.
+
+    Args:
+        history: List of chat messages
+
+    Returns:
+        Tuple containing processed messages and images
+    """
     messages = []
     images = []
 
     for i, msg in enumerate(history):
         if isinstance(msg[0], tuple):
-            messages.append(
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": history[i + 1][0]},
-                        {"type": "image"},
-                    ],
-                }
-            )
-            messages.append(
-                {
-                    "role": "assistant",
-                    "content": [{"type": "text", "text": history[i + 1][1]}],
-                }
+            messages.extend(
+                [
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "text", "text": history[i + 1][0]},
+                            {"type": "image"},
+                        ],
+                    },
+                    {
+                        "role": "assistant",
+                        "content": [{"type": "text", "text": history[i + 1][1]}],
+                    },
+                ]
             )
             images.append(Image.open(msg[0][0]).convert("RGB"))
         elif isinstance(history[i - 1], tuple) and isinstance(msg[0], str):
-            # messages are already handled
-            pass
-        elif isinstance(history[i - 1][0], str) and isinstance(
-            msg[0], str
-        ):  # text only turn
-            messages.append(
-                {"role": "user", "content": [{"type": "text", "text": msg[0]}]}
-            )
-            messages.append(
-                {"role": "assistant", "content": [{"type": "text", "text": msg[1]}]}
+            continue
+        elif isinstance(history[i - 1][0], str) and isinstance(msg[0], str):
+            messages.extend(
+                [
+                    {"role": "user", "content": [{"type": "text", "text": msg[0]}]},
+                    {
+                        "role": "assistant",
+                        "content": [{"type": "text", "text": msg[1]}],
+                    },
+                ]
            )
 
-    # add current message
-    if len(message["files"]) == 1:
-
-        if isinstance(message["files"][0], str):  # examples
-            image = Image.open(message["files"][0]).convert("RGB")
-        else:  # regular input
-            image = Image.open(message["files"][0]["path"]).convert("RGB")
+    return messages, images
+
+
+@spaces.GPU
+def bot_streaming(message: Dict, history: List, max_new_tokens: int = 250) -> str:
+    """
+    Generate streaming responses for the chatbot.
+
+    Args:
+        message: Current message containing text and files
+        history: Chat history
+        max_new_tokens: Maximum number of tokens to generate
+
+    Yields:
+        Generated text buffer
+    """
+    text = message["text"]
+    messages, images = process_chat_history(history)
+
+    # Handle current message
+    if len(message["files"]) == 1:
+        image = (
+            Image.open(message["files"][0])
+            if isinstance(message["files"][0], str)
+            else Image.open(message["files"][0]["path"])
+        ).convert("RGB")
         images.append(image)
         messages.append(
             {
                 "role": "user",
-                "content": [{"type": "text", "text": txt}, {"type": "image"}],
+                "content": [{"type": "text", "text": text}, {"type": "image"}],
             }
         )
     else:
-        messages.append({"role": "user", "content": [{"type": "text", "text": txt}]})
+        messages.append({"role": "user", "content": [{"type": "text", "text": text}]})
 
+    # Process inputs
     texts = processor.apply_chat_template(messages, add_generation_prompt=True)
+    inputs = (
+        processor(text=texts, images=images, return_tensors="pt")
+        if images
+        else processor(text=texts, return_tensors="pt")
+    ).to(DEVICE)
 
-    if images == []:
-        inputs = processor(text=texts, return_tensors="pt").to(device)
-    else:
-        inputs = processor(text=texts, images=images, return_tensors="pt").to(device)
+    # Setup streaming
     streamer = TextIteratorStreamer(
         processor, skip_special_tokens=True, skip_prompt=True
     )
-
     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
-    generated_text = ""
 
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
-    buffer = ""
 
+    buffer = ""
     for new_text in streamer:
         buffer += new_text
-        generated_text_without_prompt = buffer
         time.sleep(0.01)
         yield buffer
 
 
 demo = gr.ChatInterface(
     fn=bot_streaming,
-    title="Multimodal Llama",
     textbox=gr.MultimodalTextbox(),
     additional_inputs=[
         gr.Slider(
@@ -115,10 +140,10 @@ demo = gr.ChatInterface(
         )
     ],
     cache_examples=False,
-    description="Try Multimodal Llama by Meta with transformers in this demo. Upload an image, and start chatting about it, or simply try one of the examples below. To learn more about Llama Vision, visit [our blog post](https://huggingface.co/blog/llama32). ",
    stop_btn="Stop Generation",
     fill_height=True,
     multimodal=True,
 )
 
-demo.launch(debug=True)
+if __name__ == "__main__":
+    demo.launch(debug=True)
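
For a quick sense of what the refactor changes, here is a minimal, hypothetical sanity check of the new process_chat_history helper on a Gradio tuple-style history. The file name, chat turns, and amounts below are made up for illustration, and the helper is assumed to be available in the session (for example, copied from the updated app.py):

# Hypothetical sketch exercising process_chat_history from the updated app.py.
# The receipt image and the chat turns are invented purely for illustration.
from PIL import Image

Image.new("RGB", (64, 64)).save("receipt.jpg")  # stand-in for an uploaded receipt

history = [
    [("receipt.jpg",), None],                      # image-upload turn: (path,) tuple
    ["What is the total amount?", "120,000 VND"],  # text turn referring to that image
]

messages, images = process_chat_history(history)
# messages -> one user turn with a text part plus an {"type": "image"} placeholder,
#             followed by the assistant reply; the second history entry is skipped
#             because it was already folded into the image turn.
# images   -> [<PIL.Image.Image>] loaded from "receipt.jpg"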