ysharma (HF staff) committed
Commit 9bdc545
1 Parent(s): fd28db2

Added streaming support

Files changed (1):
  1. app.py +69 -32
app.py CHANGED

```diff
@@ -1,63 +1,100 @@
 import gradio as gr
 import os
 import spaces
-from transformers import GemmaTokenizer, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, GemmaTokenizer, TextIteratorStreamer
+from threading import Thread
+
 
 # Set an environment variable
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
+DESCRIPTION = """\
+<h1><center> CodeGemma </center></h1>
+This Space demonstrates model [CodeGemma-7b-it](https://huggingface.co/google/codegemma-7b-it) by Google. CodeGemma is a collection of lightweight open code models built on top of Gemma. Feel free to play with it, or duplicate to run privately!
+🔎 For more details about the CodeGemma release and how to use the models with `transformers`, take a look [at our blog post](https://huggingface.co/blog/codegemma).
+"""
+
+PLACEHOLDER = """
+<div style="opacity: 0.65;">
+    <img src="https://ysharma-dummy-chat-app.hf.space/file=/tmp/gradio/7dd7659cff2eab51f0f5336f378edfca01dd16fa/gemma_lockup_vertical_full-color_rgb.png" style="width:30%;">
+    <br><b>CodeGemma-7B-IT Chatbot</b>
+</div>
+"""
+
+
 # Load the tokenizer and model
 tokenizer = GemmaTokenizer.from_pretrained("google/codegemma-7b-it")
 model = AutoModelForCausalLM.from_pretrained("google/codegemma-7b-it", device_map="auto")
 
+
 @spaces.GPU(duration=120)
-def codegemma(message: str, history: list, temperature: float, max_new_tokens: int) -> str:
+def codegemma(message: str,
+              history: list,
+              temperature: float,
+              max_new_tokens: int
+              ) -> str:
     """
-    Generate a response using the CodeGemma model.
-
+    Generate a streaming response using the CodeGemma model.
     Args:
         message (str): The input message.
         history (list): The conversation history used by ChatInterface.
         temperature (float): The temperature for generating the response.
         max_new_tokens (int): The maximum number of new tokens to generate.
-
     Returns:
         str: The generated response.
     """
-    input_ids = tokenizer(message, return_tensors="pt").to("cuda:0")
-    outputs = model.generate(
-        **input_ids,
-        temperature=temperature,
+    input_ids = tokenizer.encode(message, return_tensors="pt").to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        streamer=streamer,
         max_new_tokens=max_new_tokens,
+        do_sample=True,
+        temperature=temperature,
     )
-    response = tokenizer.decode(outputs[0])
-    return response
 
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
 
-placeholder = """
-<div style="opacity: 0.65;">
-    <img src="https://ysharma-dummy-chat-app.hf.space/file=/tmp/gradio/7dd7659cff2eab51f0f5336f378edfca01dd16fa/gemma_lockup_vertical_full-color_rgb.png" style="width:30%;">
-    <br><b>CodeGemma-7B-IT Chatbot</b>
-</div>
-"""
-
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
+
 
 # Gradio block
-chatbot = gr.Chatbot(placeholder=placeholder,)
+chatbot = gr.Chatbot(placeholder=PLACEHOLDER, height=500)
+
 with gr.Blocks(fill_height=True) as demo:
-    gr.Markdown("# CODEGEMMA-7b-IT")
-    gr.ChatInterface(codegemma,
+    gr.Markdown(DESCRIPTION)
+
+    gr.ChatInterface(
+        fn=codegemma,
         chatbot=chatbot,
         fill_height=True,
         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
         additional_inputs=[
-            gr.Slider(0, 1, 0.95, label="Temperature", render=False),
-            gr.Slider(128, 4096, 512, label="Max new tokens", render=False),
+            gr.Slider(minimum=0,
+                      maximum=1,
+                      step=0.1,
+                      value=0.95,
+                      label="Temperature",
+                      render=False),
+            gr.Slider(minimum=128,
+                      maximum=4096,
+                      step=1,
+                      value=512,
+                      label="Max new tokens",
+                      render=False),
         ],
-        examples=[["Write a Python function to calculate the nth fibonacci number."]],
+        examples=[
+            ["Write a Python function to calculate the nth fibonacci number."]
+        ],
         cache_examples=False,
     )
 
 if __name__ == "__main__":
-    demo.launch(debug=False)
+    demo.launch()
```
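Editor's note on the pattern this commit introduces: `model.generate` is a blocking call, so it is moved onto a background `Thread` and its output is read back through a `TextIteratorStreamer`, turning `codegemma` into a generator; `gr.ChatInterface` detects the yielded partial strings and re-renders the reply as it grows. Below is a minimal, self-contained sketch of that recipe, not the Space's code: it swaps in `gpt2` purely so it runs on CPU without gated model access, and `stream_reply` is an illustrative name.

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# gpt2 stands in for google/codegemma-7b-it so the sketch runs anywhere.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")


def stream_reply(message: str, temperature: float = 0.95, max_new_tokens: int = 64):
    """Yield the growing response string, one decoded chunk at a time."""
    input_ids = tokenizer.encode(message, return_tensors="pt").to(model.device)

    # skip_prompt=True keeps the echoed prompt out of the streamed output;
    # timeout bounds how long the iterator blocks if generation stalls.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    # generate() blocks until the sequence is finished, so it runs on a
    # worker thread while this function drains the streamer.
    Thread(
        target=model.generate,
        kwargs=dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            pad_token_id=tokenizer.eos_token_id,  # gpt2 has no pad token
        ),
    ).start()

    chunks = []
    for text in streamer:
        chunks.append(text)
        # Yield the accumulated text: a streaming chat UI replaces the
        # message in place on every yield, which reads as live streaming.
        yield "".join(chunks)


if __name__ == "__main__":
    for partial in stream_reply("def fibonacci(n):"):
        print(partial)
```

The same generator shape, with a `history` parameter added to match `gr.ChatInterface`'s calling convention, is exactly what the updated `codegemma` passes as `fn=` above.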