Omnibus committed
Commit 0651449
1 parent: cd7ba0f

Update app.py

Files changed (1):
  app.py (+35 -20)

app.py CHANGED
@@ -8,33 +8,37 @@ models=[
     "google/gemma-7b",
     "google/gemma-7b-it",
     "google/gemma-2b",
-    "google/gemma-2b-it"
-    ]
-clients=[
-    InferenceClient(models[0]),
-    InferenceClient(models[1]),
-    InferenceClient(models[2]),
-    InferenceClient(models[3]),
+    "google/gemma-2b-it",
     ]
+client_z=[]
 
 VERBOSE=False
 
-def format_prompt(message, history):
+def load_models(inp):
+    if VERBOSE==True:
+        print(type(inp))
+        print(inp)
+        print(models[inp])
+    client_z.clear()
+    client_z.append(InferenceClient(models[inp]))
+    return gr.update(label=models[inp])
+
+def format_prompt(message, history, cust_p):
     prompt = ""
     if history:
-        #<start_of_turn>userHow does the brain work?<end_of_turn><start_of_turn>model
         for user_prompt, bot_response in history:
-            prompt += f"{user_prompt}\n"
-            #print(prompt)
-            prompt += f"{bot_response}\n"
-            #print(prompt)
-    prompt += f"<start_of_turn>user{message}<end_of_turn><start_of_turn>model"
+            prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
+            prompt += f"<start_of_turn>model{bot_response}<end_of_turn>"
+        if VERBOSE==True:
+            print(prompt)
+    #prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
+    prompt+=cust_p.replace("USER_INPUT",message)
     return prompt
 
-def chat_inf(system_prompt,prompt,history,memory,client_choice,seed,temp,tokens,top_p,rep_p,chat_mem):
+def chat_inf(system_prompt,prompt,history,memory,client_choice,seed,temp,tokens,top_p,rep_p,chat_mem,cust_p):
     #token max=8192
     hist_len=0
-    client=clients[int(client_choice)-1]
+    client=client_z[0]
     if not history:
         history = []
         hist_len=0
@@ -58,8 +62,10 @@ def chat_inf(system_prompt,prompt,history,memory,client_choice,seed,temp,tokens,
         do_sample=True,
         seed=seed,
     )
-    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", memory[0-chat_mem:])
-    #print("\n######### PROMPT "+str(len(formatted_prompt)))
+    if system_prompt:
+        formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", memory[0-chat_mem:],cust_p)
+    else:
+        formatted_prompt = format_prompt(prompt, memory[0-chat_mem:],cust_p)
     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
     output = ""
     for response in stream:
@@ -68,6 +74,7 @@ def chat_inf(system_prompt,prompt,history,memory,client_choice,seed,temp,tokens,
     history.append((prompt,output))
     memory.append((prompt,output))
     yield history,memory
+
     if VERBOSE==True:
         print("\n######### HIST "+str(in_len))
         print("\n######### TOKENS "+str(tokens))
@@ -109,6 +116,8 @@ with gr.Blocks() as app:
             stop_btn=gr.Button("Stop")
             clear_btn=gr.Button("Clear")
             client_choice=gr.Dropdown(label="Models",type='index',choices=[c for c in models],value=models[0],interactive=True)
+            with gr.Accordion("Prompt Format",open=False):
+                custom_prompt=gr.Textbox(label="Modify Prompt Format", info="For testing purposes. 'USER_INPUT' is where 'SYSTEM_PROMPT, PROMPT' will be placed", lines=5,value="<start_of_turn>userUSER_INPUT<end_of_turn><start_of_turn>model")
         with gr.Column(scale=1):
             with gr.Group():
                 rand = gr.Checkbox(label="Random Seed", value=True)
@@ -131,9 +140,15 @@
             theme=gr.Radio(label="Theme", choices=["light","dark"],value="light")
             chatblock=gr.Dropdown(label="Chatblocks",info="Choose specific blocks of chat",choices=[c for c in range(1,40)],multiselect=True)
 
+
+    client_choice.change(load_models,client_choice,[chat_b])
+    app.load(load_models,client_choice,[chat_b])
+
     im_go=im_btn.click(get_screenshot,[chat_b,im_height,im_width,chatblock,theme,wait_time],img)
-    chat_sub=inp.submit(check_rand,[rand,seed],seed).then(chat_inf,[sys_inp,inp,chat_b,memory,client_choice,seed,temp,tokens,top_p,rep_p,chat_mem],[chat_b,memory])
-    go=btn.click(check_rand,[rand,seed],seed).then(chat_inf,[sys_inp,inp,chat_b,memory,client_choice,seed,temp,tokens,top_p,rep_p,chat_mem],[chat_b,memory])
+
+    chat_sub=inp.submit(check_rand,[rand,seed],seed).then(chat_inf,[sys_inp,inp,chat_b,memory,client_choice,seed,temp,tokens,top_p,rep_p,chat_mem,custom_prompt],[chat_b,memory])
+    go=btn.click(check_rand,[rand,seed],seed).then(chat_inf,[sys_inp,inp,chat_b,memory,client_choice,seed,temp,tokens,top_p,rep_p,chat_mem,custom_prompt],[chat_b,memory])
+
     stop_btn.click(None,None,None,cancels=[go,im_go,chat_sub])
     clear_btn.click(clear_fn,None,[inp,sys_inp,chat_b,memory])
 app.queue(default_concurrency_limit=10).launch()
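
The rewritten format_prompt is easiest to follow from the string it renders. Below is a minimal standalone sketch of the new prompt flow, using the commit's own logic and the default template from the new "Modify Prompt Format" textbox; the example history and message are invented for illustration:

# Default template shipped in this commit; USER_INPUT is the slot where
# "SYSTEM_PROMPT, PROMPT" (or just PROMPT) is spliced in.
cust_p = "<start_of_turn>userUSER_INPUT<end_of_turn><start_of_turn>model"

def format_prompt(message, history, cust_p):
    # Replay prior turns in Gemma's turn markup, then append the new
    # message through the custom template (same logic as the diff above).
    prompt = ""
    if history:
        for user_prompt, bot_response in history:
            prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
            prompt += f"<start_of_turn>model{bot_response}<end_of_turn>"
    prompt += cust_p.replace("USER_INPUT", message)
    return prompt

history = [("Hi", "Hello!")]  # hypothetical earlier turn
print(format_prompt("What is Gemma?", history, cust_p))
# -> <start_of_turn>userHi<end_of_turn><start_of_turn>modelHello!<end_of_turn><start_of_turn>userWhat is Gemma?<end_of_turn><start_of_turn>model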
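The client handling also moves from eager to lazy: instead of building four InferenceClient objects at import time, the app keeps a single client in client_z and swaps it whenever the Models dropdown changes (and once on app.load). A minimal sketch of that pattern, assuming huggingface_hub is installed; the Gradio wiring from the diff is omitted:

from huggingface_hub import InferenceClient

models = ["google/gemma-7b", "google/gemma-7b-it",
          "google/gemma-2b", "google/gemma-2b-it"]
client_z = []  # holds exactly one client: the currently selected model

def load_models(inp):
    # inp is the dropdown's integer index (type='index' in the diff)
    client_z.clear()
    client_z.append(InferenceClient(models[inp]))

load_models(3)        # select "google/gemma-2b-it"
client = client_z[0]  # chat_inf now always reads the single live client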