Shriharsh committed
Commit 69bb03a
1 Parent(s): 0660ab2

Upload 2 files

Files changed (2)
  1. app.py +97 -0
  2. requirements.txt.txt +1 -0
app.py ADDED
@@ -0,0 +1,97 @@
+ import random
+
+ import gradio as gr
+ from huggingface_hub import InferenceClient
+
+ client = InferenceClient(
+     "google/gemma-7b-it"
+ )
+
+
+ def format_prompt(message, history):
+     # Gemma turn markup, e.g.:
+     # <start_of_turn>userWhat is recession?<end_of_turn><start_of_turn>model
+     prompt = ""
+     if history:
+         for user_prompt, bot_response in history:
+             prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
+             prompt += f"<start_of_turn>model{bot_response}"
+     prompt += f"<start_of_turn>user{message}<end_of_turn><start_of_turn>model"
+     return prompt
+
+
+ def chat_inf(system_prompt, prompt, history, seed, temp, tokens, top_p, rep_p):
+     # Gemma's context window tops out at 8192 tokens
+     if not history:
+         history = []
+
+     generate_kwargs = dict(
+         temperature=temp,
+         max_new_tokens=tokens,
+         top_p=top_p,
+         repetition_penalty=rep_p,
+         do_sample=True,
+         seed=seed,
+     )
+
+     formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
+     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+     output = ""
+
+     for response in stream:
+         output += response.token.text
+         # Stream the partial reply without dropping earlier turns from the display
+         yield history + [(prompt, output)]
+     history.append((prompt, output))
+     yield history
+
+
+ def check_rand(rand, seed):
+     # Reconstructed helper (referenced by the event wiring below but missing
+     # from the upload): draw a fresh seed when "Random Seed" is checked,
+     # otherwise keep the slider's current value.
+     if rand:
+         return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=random.randint(1, 1111111111111111))
+     return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=int(seed))
+
+
+ def clear_fn():
+     # Reconstructed helper for the Clear button; returning None empties the chatbot.
+     return None
+
+
+ # Initial seed for the slider (referenced below but never defined in the upload)
+ rand_val = random.randint(1, 1111111111111111)
+
+ with gr.Blocks() as app:
+     gr.HTML("""<center><h1 style='font-size:xx-large;'>Google Gemma 7B Chat</h1></center>""")
+     chat_b = gr.Chatbot(height=450, layout="bubble")
+
+     with gr.Group():
+         with gr.Row():
+             with gr.Column(scale=3):
+                 inp = gr.Textbox(label="Prompt")
+                 sys_inp = gr.Textbox(label="System Prompt (optional)")
+                 with gr.Row():
+                     with gr.Column(scale=2):
+                         btn = gr.Button("Chat")
+                     with gr.Column(scale=1):
+                         with gr.Group():
+                             stop_btn = gr.Button("Stop")
+                             clear_btn = gr.Button("Clear")
+
+             with gr.Column(scale=1):
+                 with gr.Group():
+                     rand = gr.Checkbox(label="Random Seed", value=True)
+                     seed = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=rand_val)
+                     tokens = gr.Slider(label="Max new tokens", value=6400, minimum=0, maximum=8000, step=64, interactive=True, visible=True, info="The maximum number of new tokens to generate")
+                     temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
+                     top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
+                     rep_p = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)
+
+     chat_sub = inp.submit(check_rand, [rand, seed], seed).then(chat_inf, [sys_inp, inp, chat_b, seed, temp, tokens, top_p, rep_p], chat_b)
+     go = btn.click(check_rand, [rand, seed], seed).then(chat_inf, [sys_inp, inp, chat_b, seed, temp, tokens, top_p, rep_p], chat_b)
+     stop_btn.click(None, None, None, cancels=[go, chat_sub])  # im_go was referenced here but never defined; removed
+     clear_btn.click(clear_fn, None, [chat_b])
+
+ app.queue(default_concurrency_limit=10).launch()
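
For a quick sanity check of the turn markup, here is a hypothetical call to format_prompt with a one-turn history (illustrative values only):

    # Hypothetical history entry, purely to illustrate the prompt layout
    history = [("What is recession?", "A period of economic decline.")]
    print(format_prompt("Give an example", history))
    # -> <start_of_turn>userWhat is recession?<end_of_turn><start_of_turn>modelA period of economic decline.<start_of_turn>userGive an example<end_of_turn><start_of_turn>model

Note that this implementation never closes a model turn with <end_of_turn>; Gemma's reference chat format does, so the markup above is worth double-checking against the model card.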
requirements.txt.txt ADDED
@@ -0,0 +1 @@
+ huggingface_hub