iclalcetin committed (verified)
Commit fc25067 · 1 Parent(s): eeda621

Update app.py

Files changed (1):
  1. app.py +61 -27
app.py CHANGED
@@ -1,25 +1,26 @@
-import os
 import random
 import gradio as gr
 from huggingface_hub import InferenceClient
 
-hf_token = os.getenv("HF_TOKEN")
-client = InferenceClient("google/gemma-7b", use_auth_token=hf_token)
+
+client = InferenceClient("google/gemma-7b")
 
 def format_prompt(message, history):
     prompt = ""
     if history:
         for user_prompt, bot_response in history:
             prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
-            prompt += f"<start_of_turn>model{bot_response}<end_of_turn>"
+            prompt += f"<start_of_turn>model{bot_response}"
     prompt += f"<start_of_turn>user{message}<end_of_turn><start_of_turn>model"
     return prompt
 
-def generate(prompt, history, temperature=0.7, max_new_tokens=1024, top_p=0.90, repetition_penalty=1.0):
-    # Adjust parameters as necessary
+
+def generate(prompt, history, temperature=0.7, max_new_tokens=1024, top_p=0.90, repetition_penalty=0.9):
     temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
     top_p = float(top_p)
-
+
     if not history:
         history = []
 
@@ -36,34 +37,67 @@ def generate(prompt, history, temperature=0.7, max_new_tokens=1024, top_p=0.90,
 
     formatted_prompt = format_prompt(prompt, history)
 
-    stream = client(text=formatted_prompt, parameters=generate_kwargs, wait_for_model=True)
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
     output = ""
 
-    for response in stream["generated_text"]:
-        output += response
+    for response in stream:
+        output += response.token.text
         yield output
     history.append((prompt, output))
     return output
 
-# Setup Gradio Interface
-chatbot_ui = gr.Chatbot()
-
-def chat_interface(prompt, temperature=0.7, max_new_tokens=160, top_p=0.90, repetition_penalty=1.0):
-    history = []  # Initialize or fetch existing history
-    return generate(prompt, history, temperature, max_new_tokens, top_p, repetition_penalty)
+
+mychatbot = gr.Chatbot(
+    avatar_images=["./user.png", "./botgm.png"], bubble_full_width=False, show_label=False, show_copy_button=True, likeable=True,)
 
-iface = gr.Interface(fn=chat_interface,
-                     inputs=[gr.Textbox(label="Your Message"),
-                             gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.7, step=0.01),
-                             gr.Slider(label="Max new tokens", minimum=1, maximum=512, value=160),
-                             gr.Slider(label="Top-p", minimum=0.0, maximum=1.0, value=0.90),
-                             gr.Slider(label="Repetition Penalty", minimum=0.1, maximum=2.0, value=1.0)],
-                     outputs=chatbot_ui,
-                     live=True)
+additional_inputs=[
+    gr.Slider(
+        label="Temperature",
+        value=0.7,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.01,
+        interactive=True,
+        info="Higher values generate more diverse outputs",
+    ),
+    gr.Slider(
+        label="Max new tokens",
+        value=6400,
+        minimum=0,
+        maximum=8000,
+        step=64,
+        interactive=True,
+        info="The maximum number of new tokens",
+    ),
+    gr.Slider(
+        label="Top-p",
+        value=0.90,
+        minimum=0.0,
+        maximum=1,
+        step=0.01,
+        interactive=True,
+        info="Higher values sample more low-probability tokens",
+    ),
+    gr.Slider(
+        label="Repetition penalty",
+        value=1.0,
+        minimum=0.1,
+        maximum=2.0,
+        step=0.1,
+        interactive=True,
+        info="Penalize repeated tokens",
+    )
+]
 
+iface = gr.ChatInterface(fn=generate,
+                         chatbot=mychatbot,
+                         additional_inputs=additional_inputs,
+                         retry_btn=None,
+                         undo_btn=None
+                         )
 
-with gr.Blocks() as app:
+with gr.Blocks() as demo:
     gr.HTML("<center><h1>Chat with GEMMA 7B</h1></center>")
     iface.render()
-
-app.launch()
+
+demo.queue().launch(show_api=False)
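
For reference, a minimal sketch of the prompt string the post-commit format_prompt builds; the one-turn history and message below are hypothetical. Note that after this change the model turn is no longer closed with <end_of_turn>.

# Copy of the post-commit format_prompt, exercised with a hypothetical
# one-turn history ("Hi" -> "Hello!") and a new message.
def format_prompt(message, history):
    prompt = ""
    if history:
        for user_prompt, bot_response in history:
            prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
            prompt += f"<start_of_turn>model{bot_response}"
    prompt += f"<start_of_turn>user{message}<end_of_turn><start_of_turn>model"
    return prompt

print(format_prompt("How are you?", [("Hi", "Hello!")]))
# <start_of_turn>userHi<end_of_turn><start_of_turn>modelHello!<start_of_turn>userHow are you?<end_of_turn><start_of_turn>model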
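Both hunks skip the middle of generate(), where generate_kwargs is assembled. The diff never shows those lines, so the sketch below is a hypothetical reconstruction only: the keys mirror the generate() parameters, and the seed would at least account for the otherwise-unused import random kept at the top of app.py.

import random

# Hypothetical reconstruction of the elided generate_kwargs block inside
# generate(); sample values stand in for the function's parameters.
temperature, max_new_tokens, top_p, repetition_penalty = 0.7, 1024, 0.90, 0.9
generate_kwargs = dict(
    temperature=temperature,
    max_new_tokens=max_new_tokens,
    top_p=top_p,
    repetition_penalty=repetition_penalty,
    do_sample=True,
    seed=random.randint(0, 10**7),  # assumption; not shown in the diff
)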
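The switch from calling the client directly to client.text_generation with stream=True and details=True changes what the loop iterates over: each item is a stream response whose generated token is exposed as .token.text. A self-contained sketch of that huggingface_hub pattern, with an illustrative prompt and parameters:

from huggingface_hub import InferenceClient

client = InferenceClient("google/gemma-7b")
stream = client.text_generation(
    "<start_of_turn>userHello<end_of_turn><start_of_turn>model",
    max_new_tokens=64,       # illustrative value
    temperature=0.7,
    stream=True,             # yield chunks as tokens arrive
    details=True,            # each chunk carries token metadata
    return_full_text=False,  # do not echo the prompt back
)
for chunk in stream:
    # with details=True, each streamed chunk exposes the token text
    print(chunk.token.text, end="", flush=True)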