Staticaliza committed on
Commit
44bd5ae
1 Parent(s): b09a184

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -66
app.py CHANGED
@@ -6,16 +6,6 @@ import re
6
 
7
  API_TOKEN = os.environ.get("API_TOKEN")
8
 
9
- KEY = os.environ.get("KEY")
10
-
11
- SPECIAL_SYMBOLS = ["⠀", "⠀"] # ["‹", "›"] ['"', '"']
12
-
13
- DEFAULT_INPUT = "User: Hi!"
14
- DEFAULT_WRAP = "Statical: %s"
15
- DEFAULT_INSTRUCTION = "Statical is a helpful chatbot who is communicating with people."
16
-
17
- DEFAULT_STOPS = '["⠀", "⠀"]' # '["‹", "›"]' '[\"\\\"\"]'
18
-
19
  API_ENDPOINTS = {
20
  "Falcon": "tiiuae/falcon-180B-chat",
21
  "Llama": "meta-llama/Llama-2-70b-chat-hf",
@@ -33,35 +23,11 @@ for model_name, model_endpoint in API_ENDPOINTS.items():
33
  CHOICES.append(model_name)
34
  CLIENTS[model_name] = InferenceClient(model_endpoint, headers = { "Authorization": f"Bearer {API_TOKEN}" })
35
 
36
- def format(instruction, history, input, wrap):
37
- sy_l, sy_r = SPECIAL_SYMBOLS[0], SPECIAL_SYMBOLS[1]
38
- wrapped_input = wrap % ("")
39
- formatted_history = "".join(f"{sy_l}{message[0]}{sy_r}\n{sy_l}{message[1]}{sy_r}\n" for message in history)
40
- formatted_input = f"{sy_l}INSTRUCTIONS: {instruction}{sy_r}\n{formatted_history}{sy_l}{input}{sy_r}\n{sy_l}"
41
- return f"{formatted_input}{wrapped_input}", formatted_input
42
-
43
- def predict(access_key, instruction, history, input, wrap, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
44
-
45
- if (access_key != KEY):
46
- print(">>> MODEL FAILED: Input: " + input + ", Attempted Key: " + access_key)
47
- return ("[UNAUTHORIZED ACCESS]", input, []);
48
-
49
- instruction = instruction or DEFAULT_INSTRUCTION
50
- history = history or []
51
- input = input or ""
52
- wrap = wrap or ""
53
- stop_seqs = stop_seqs or DEFAULT_STOPS
54
-
55
  stops = json.loads(stop_seqs)
56
-
57
- formatted_input, formatted_input_base = format(instruction, history, input, wrap)
58
-
59
- print(seed)
60
- print(formatted_input)
61
- print(model)
62
-
63
  response = CLIENTS[model].text_generation(
64
- formatted_input,
65
  temperature = temperature,
66
  max_new_tokens = max_tokens,
67
  top_p = top_p,
@@ -75,27 +41,7 @@ def predict(access_key, instruction, history, input, wrap, model, temperature, t
75
  return_full_text = False
76
  )
77
 
78
- sy_l, sy_r = SPECIAL_SYMBOLS[0], SPECIAL_SYMBOLS[1]
79
- result = wrap % (response)
80
-
81
- for stop in stops:
82
- result = result.split(stop, 1)[0]
83
- for symbol in stops:
84
- result = result.replace(symbol, '')
85
-
86
- history = history + [[input, result]]
87
-
88
- print(f"---\nUSER: {input}\nBOT: {result}\n---")
89
-
90
- return (result, input, history)
91
-
92
- def clear_history():
93
- print(">>> HISTORY CLEARED!")
94
- return []
95
-
96
- def maintain_cloud():
97
- print(">>> SPACE MAINTAINED!")
98
- return ("SUCCESS!", "SUCCESS!")
99
 
100
  with gr.Blocks() as demo:
101
  with gr.Row(variant = "panel"):
@@ -103,14 +49,9 @@ with gr.Blocks() as demo:
103
 
104
  with gr.Row():
105
  with gr.Column():
106
- history = gr.Chatbot(label = "History", elem_id = "chatbot")
107
  input = gr.Textbox(label = "Input", value = DEFAULT_INPUT, lines = 2)
108
- wrap = gr.Textbox(label = "Wrap", value = DEFAULT_WRAP, lines = 1)
109
- instruction = gr.Textbox(label = "Instruction", value = DEFAULT_INSTRUCTION, lines = 4)
110
- access_key = gr.Textbox(label = "Access Key", lines = 1)
111
  run = gr.Button("▶")
112
  clear = gr.Button("🗑️")
113
- cloud = gr.Button("☁️")
114
 
115
  with gr.Column():
116
  model = gr.Dropdown(choices = CHOICES, value = next(iter(API_ENDPOINTS)), interactive = True, label = "Model")
@@ -119,15 +60,14 @@ with gr.Blocks() as demo:
119
  top_k = gr.Slider( minimum = 1, maximum = 2048, value = 50, step = 1, interactive = True, label = "Top K" )
120
  rep_p = gr.Slider( minimum = 0.01, maximum = 2, value = 1.2, step = 0.01, interactive = True, label = "Repetition Penalty" )
121
  max_tokens = gr.Slider( minimum = 1, maximum = 2048, value = 32, step = 64, interactive = True, label = "Max New Tokens" )
122
- stop_seqs = gr.Textbox( value = DEFAULT_STOPS, interactive = True, label = "Stop Sequences ( JSON Array / 4 Max )" )
123
  seed = gr.Slider( minimum = 0, maximum = 9007199254740991, value = 42, step = 1, interactive = True, label = "Seed" )
124
 
125
  with gr.Row():
126
  with gr.Column():
127
  output = gr.Textbox(label = "Output", value = "", lines = 50)
128
 
129
- run.click(predict, inputs = [access_key, instruction, history, input, wrap, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed], outputs = [output, input, history], queue = False)
130
  clear.click(clear_history, [], history, queue = False)
131
- cloud.click(maintain_cloud, inputs = [], outputs = [input, output], queue = False)
132
 
133
  demo.launch(show_api = True)
 
6
 
7
  API_TOKEN = os.environ.get("API_TOKEN")
8
 
 
 
 
 
 
 
 
 
 
 
9
  API_ENDPOINTS = {
10
  "Falcon": "tiiuae/falcon-180B-chat",
11
  "Llama": "meta-llama/Llama-2-70b-chat-hf",
 
23
  CHOICES.append(model_name)
24
  CLIENTS[model_name] = InferenceClient(model_endpoint, headers = { "Authorization": f"Bearer {API_TOKEN}" })
25
 
26
+ def predict(input, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  stops = json.loads(stop_seqs)
28
+
 
 
 
 
 
 
29
  response = CLIENTS[model].text_generation(
30
+ input,
31
  temperature = temperature,
32
  max_new_tokens = max_tokens,
33
  top_p = top_p,
 
41
  return_full_text = False
42
  )
43
 
44
+ return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  with gr.Blocks() as demo:
47
  with gr.Row(variant = "panel"):
 
49
 
50
  with gr.Row():
51
  with gr.Column():
 
52
  input = gr.Textbox(label = "Input", value = DEFAULT_INPUT, lines = 2)
 
 
 
53
  run = gr.Button("▶")
54
  clear = gr.Button("🗑️")
 
55
 
56
  with gr.Column():
57
  model = gr.Dropdown(choices = CHOICES, value = next(iter(API_ENDPOINTS)), interactive = True, label = "Model")
 
60
  top_k = gr.Slider( minimum = 1, maximum = 2048, value = 50, step = 1, interactive = True, label = "Top K" )
61
  rep_p = gr.Slider( minimum = 0.01, maximum = 2, value = 1.2, step = 0.01, interactive = True, label = "Repetition Penalty" )
62
  max_tokens = gr.Slider( minimum = 1, maximum = 2048, value = 32, step = 64, interactive = True, label = "Max New Tokens" )
63
+ stop_seqs = gr.Textbox( value = "", interactive = True, label = "Stop Sequences ( JSON Array / 4 Max )" )
64
  seed = gr.Slider( minimum = 0, maximum = 9007199254740991, value = 42, step = 1, interactive = True, label = "Seed" )
65
 
66
  with gr.Row():
67
  with gr.Column():
68
  output = gr.Textbox(label = "Output", value = "", lines = 50)
69
 
70
+ run.click(predict, inputs = [input, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed], outputs = [output], queue = False)
71
  clear.click(clear_history, [], history, queue = False)
 
72
 
73
  demo.launch(show_api = True)