hiyouga committed on
Commit
d6a8ce7
1 Parent(s): 9a8b9e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -44
app.py CHANGED
@@ -1,16 +1,40 @@
1
  from threading import Thread
2
- from typing import Dict
3
 
4
  import gradio as gr
5
  import spaces
6
- import torch
7
- from PIL import Image
8
- from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, TextIteratorStreamer
9
 
10
 
11
- TITLE = "<h1><center>Chat with PaliGemma-3B-Chat-v0.2</center></h1>"
12
 
13
- DESCRIPTION = "<h3><center>Visit <a href='https://huggingface.co/BUAADreamer/PaliGemma-3B-Chat-v0.2' target='_blank'>our model page</a> for details.</center></h3>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  CSS = """
16
  .duplicate-button {
@@ -22,57 +46,32 @@ CSS = """
22
  """
23
 
24
 
25
- model_id = "BUAADreamer/PaliGemma-3B-Chat-v0.2"
26
- tokenizer = AutoTokenizer.from_pretrained(model_id)
27
- processor = AutoProcessor.from_pretrained(model_id)
28
- model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
29
 
30
 
31
  @spaces.GPU
32
- def stream_chat(message: Dict[str, str], history: list):
33
- # Turn 1:
34
- # {'text': 'what is this', 'files': ['image-xxx.jpg']}
35
- # []
36
-
37
- # Turn 2:
38
- # {'text': 'continue?', 'files': []}
39
- # [[('image-xxx.jpg',), None], ['what is this', 'a image.']]
40
-
41
- image_path = None
42
- if len(message["files"]) != 0:
43
- image_path = message["files"][0]
44
-
45
- if len(history) != 0 and isinstance(history[0][0], tuple):
46
- image_path = history[0][0][0]
47
- history = history[1:]
48
-
49
- if image_path is not None:
50
- image = Image.open(image_path).convert("RGB")
51
- else:
52
- image = Image.new("RGB", (100, 100), (255, 255, 255))
53
-
54
- pixel_values = processor(images=[image], return_tensors="pt").to(model.device)["pixel_values"]
55
-
56
- conversation = []
57
  for prompt, answer in history:
58
  conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
59
 
60
- conversation.append({"role": "user", "content": message["text"]})
61
-
62
- input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
63
- image_token_id = tokenizer.convert_tokens_to_ids("<image>")
64
- image_prefix = torch.empty((1, getattr(processor, "image_seq_length")), dtype=input_ids.dtype).fill_(image_token_id)
65
- input_ids = torch.cat((image_prefix, input_ids), dim=-1).to(model.device)
66
 
 
 
 
67
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
68
 
69
  generate_kwargs = dict(
70
  input_ids=input_ids,
71
- pixel_values=pixel_values,
72
  streamer=streamer,
73
- max_new_tokens=256,
 
74
  do_sample=True,
75
  )
 
 
76
 
77
  t = Thread(target=model.generate, kwargs=generate_kwargs)
78
  t.start()
@@ -91,9 +90,40 @@ with gr.Blocks(css=CSS) as demo:
91
  gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
92
  gr.ChatInterface(
93
  fn=stream_chat,
94
- multimodal=True,
95
  chatbot=chatbot,
96
  fill_height=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  cache_examples=False,
98
  )
99
 
 
1
  from threading import Thread
 
2
 
3
  import gradio as gr
4
  import spaces
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 
6
 
7
 
8
+ TITLE = "<h1><center>Chat with Gemma-2-9B-Chinese-Chat</center></h1>"
9
 
10
+ DESCRIPTION = "<h3><center>Visit <a href='https://huggingface.co/shenzhi-wang/Gemma-2-9B-Chinese-Chat' target='_blank'>our model page</a> for details.</center></h3>"
11
+
12
+ DEFAULT_SYSTEM = "You are a helpful assistant."
13
+
14
+ TOOL_EXAMPLE = '''You have access to the following tools:
15
+ ```python
16
+ def generate_password(length: int, include_symbols: Optional[bool]):
17
+ """
18
+ Generate a random password.
19
+
20
+ Args:
21
+ length (int): The length of the password
22
+ include_symbols (Optional[bool]): Include symbols in the password
23
+ """
24
+ pass
25
+ ```
26
+
27
+ Write "Action:" followed by a list of actions in JSON that you want to call, e.g.
28
+ Action:
29
+ ```json
30
+ [
31
+ {
32
+ "name": "tool name (one of [generate_password])",
33
+ "arguments": "the input to the tool"
34
+ }
35
+ ]
36
+ ```
37
+ '''
38
 
39
  CSS = """
40
  .duplicate-button {
 
46
  """
47
 
48
 
49
+ tokenizer = AutoTokenizer.from_pretrained("shenzhi-wang/Gemma-2-9B-Chinese-Chat")
50
+ model = AutoModelForCausalLM.from_pretrained("shenzhi-wang/Gemma-2-9B-Chinese-Chat", device_map="auto")
 
 
51
 
52
 
53
  @spaces.GPU
54
+ def stream_chat(message: str, history: list, system: str, temperature: float, max_new_tokens: int):
55
+ conversation = [{"role": "system", "content": system or DEFAULT_SYSTEM}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  for prompt, answer in history:
57
  conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
58
 
59
+ conversation.append({"role": "user", "content": message})
 
 
 
 
 
60
 
61
+ input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(
62
+ model.device
63
+ )
64
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
65
 
66
  generate_kwargs = dict(
67
  input_ids=input_ids,
 
68
  streamer=streamer,
69
+ max_new_tokens=max_new_tokens,
70
+ temperature=temperature,
71
  do_sample=True,
72
  )
73
+ if temperature == 0:
74
+ generate_kwargs["do_sample"] = False
75
 
76
  t = Thread(target=model.generate, kwargs=generate_kwargs)
77
  t.start()
 
90
  gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
91
  gr.ChatInterface(
92
  fn=stream_chat,
 
93
  chatbot=chatbot,
94
  fill_height=True,
95
+ additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
96
+ additional_inputs=[
97
+ gr.Text(
98
+ value="",
99
+ label="System",
100
+ render=False,
101
+ ),
102
+ gr.Slider(
103
+ minimum=0,
104
+ maximum=1,
105
+ step=0.1,
106
+ value=0.8,
107
+ label="Temperature",
108
+ render=False,
109
+ ),
110
+ gr.Slider(
111
+ minimum=128,
112
+ maximum=4096,
113
+ step=1,
114
+ value=1024,
115
+ label="Max new tokens",
116
+ render=False,
117
+ ),
118
+ ],
119
+ examples=[
120
+ ["我的蓝牙耳机坏了,我该去看牙科还是耳鼻喉科?", ""],
121
+ ["7年前,妈妈年龄是儿子的6倍,儿子今年12岁,妈妈今年多少岁?", ""],
122
+ ["我的笔记本找不到了。", "扮演诸葛亮和我对话。"],
123
+ ["我想要一个新的密码,长度为8位,包含特殊符号。", TOOL_EXAMPLE],
124
+ ["How are you today?", "You are Taylor Swift, use beautiful lyrics to answer questions."],
125
+ ["用C++实现KMP算法,并加上中文注释", ""],
126
+ ],
127
  cache_examples=False,
128
  )
129