basakerdogan committed on
Commit
86a1044
1 Parent(s): 4787649

Delete main.py

Browse files
Files changed (1) hide show
  1. main.py +0 -239
main.py DELETED
@@ -1,239 +0,0 @@
1
- import os
2
- from typing import Iterator
3
- import gradio as gr
4
- from text_generation import Client
5
-
6
# Model repository queried through the hosted inference endpoint below.
model_id = 'basakerdogan/Cyber-Jarvis-4Bit'

API_URL = "https://api-inference.huggingface.co/models/" + model_id
# Access token read from the environment; stays False when the variable is unset.
HF_TOKEN = os.environ.get('HF_READ_TOKEN', False)

# Attach the Authorization header only when a token is actually configured.
# Formatting the unset default would otherwise send the literal "Bearer False".
client = Client(
    API_URL,
    headers={'Authorization': f"Bearer {HF_TOKEN}"} if HF_TOKEN else None,
)
# Marker strings that end a streamed generation when they appear in a token.
EOS_STRING = "</s>"
EOT_STRING = "<EOT>"
17
-
18
def get_prompt(message, chat_history, system_prompt):
    """Build the `[INST]`-style prompt string for the whole conversation.

    Args:
        message: The new user message to answer.
        chat_history: Iterable of ``(user_input, response)`` pairs for the
            turns that came before ``message``.
        system_prompt: Text placed inside the ``<<SYS>>`` section.

    Returns:
        The system section, every previous turn and the new message joined
        into a single prompt string.
    """
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']

    # The first user text of the conversation is inserted verbatim; every
    # later user text (including ``message`` when history exists) is stripped.
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        # A turn without an answer yet may carry None instead of a string.
        answer = (response or '').strip()
        texts.append(f"{user_input} [/INST] {answer} </s><s>[INST] ")
    message = message.strip() if do_strip else message
    texts.append(f"{message} [/INST]")
    return ''.join(texts)
29
-
30
def run(message, chat_history, system_prompt, max_new_tokens=1024, temperature=0.1, top_p=0.9, top_k=50):
    """Stream a model answer, yielding the accumulated text after each token.

    Args:
        message: The new user message.
        chat_history: Previous ``(user_input, response)`` pairs.
        system_prompt: System text forwarded to ``get_prompt``.
        max_new_tokens: Forwarded to ``client.generate_stream``.
        temperature: Forwarded to ``client.generate_stream``.
        top_p: Forwarded to ``client.generate_stream``.
        top_k: Forwarded to ``client.generate_stream``.

    Yields:
        The answer text generated so far, growing by one token per step.
        Streaming stops as soon as a token contains ``EOS_STRING`` or
        ``EOT_STRING``; that token is not added to the output.
    """
    prompt = get_prompt(message, chat_history, system_prompt)

    generate_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
    )
    stream = client.generate_stream(prompt, **generate_kwargs)

    end_tokens = (EOS_STRING, EOT_STRING)
    output = ''
    for response in stream:
        token_text = response.token.text
        if any(end_token in token_text for end_token in end_tokens):
            return output
        output += token_text
        yield output
    return output
49
-
50
# Default text shown in the "System prompt" box and placed in the <<SYS>> section.
# NOTE(review): the wording reads like a description of the Space rather than
# an instruction to the model - confirm this is the intended system prompt.
DEFAULT_SYSTEM_PROMPT = """
>This Space features Cyber-Jarvis, an AI model inspired by Iron Man's J.A.R.V.I.S. Cyber-Jarvis is designed with cybersecurity in mind, drawing inspiration from the iconic AI assistant while focusing on security
"""
# Upper bound of the "Max New Tokens" slider, also enforced in generate().
MAX_MAX_NEW_TOKENS = 4096
# Initial value of the "Max New Tokens" slider.
DEFAULT_MAX_NEW_TOKENS = 256
# Threshold used by check_input_token_length().
MAX_INPUT_TOKEN_LENGTH = 4000

# Markdown rendered at the top of the page.
DESCRIPTION = "Cyber-Jarvis AI"
58
-
59
def clear_and_save_textbox(message):
    """Return an empty string for the textbox and the message to remember.

    Args:
        message: The text currently in the input box.

    Returns:
        A ``('', message)`` tuple: the first item clears the textbox, the
        second is stored as the saved input.
    """
    return '', message
60
-
61
def display_input(message, history=None):
    """Append the pending user message to the chat history.

    Args:
        message: The user message to show.
        history: Existing list of ``(user, response)`` turns. When omitted a
            fresh list is used (a shared mutable default would leak turns
            between unrelated calls).

    Returns:
        The history list with ``(message, '')`` appended. A list passed in by
        the caller is extended in place and returned.
    """
    history = [] if history is None else history
    history.append((message, ''))
    return history
64
-
65
def delete_prev_fn(history=None):
    """Remove the last turn from the history and return its user message.

    Args:
        history: List of ``(user, response)`` turns. When omitted a fresh
            empty list is used instead of a shared mutable default.

    Returns:
        A ``(history, message)`` tuple: the history without its last turn and
        the user text of the removed turn, or ``''`` when there was no turn
        or its user text was empty/None.
    """
    history = [] if history is None else history
    try:
        message, _ = history.pop()
    except IndexError:
        # Nothing to remove: the history was already empty.
        message = ''
    return history, message or ''
71
-
72
def generate(message, history_with_input, system_prompt, max_new_tokens, temperature, top_p, top_k):
    """Stream chat histories that include the growing answer to ``message``.

    Args:
        message: The user message being answered.
        history_with_input: Chat history whose last entry is the pending turn
            for ``message``; that entry is dropped before prompting.
        system_prompt: System text forwarded to ``run``.
        max_new_tokens: Forwarded to ``run``; must not exceed
            ``MAX_MAX_NEW_TOKENS``.
        temperature: Forwarded to ``run``.
        top_p: Forwarded to ``run``.
        top_k: Forwarded to ``run``.

    Yields:
        The previous history plus ``(message, partial_answer)``. At least one
        value is always yielded; the answer is ``''`` when the model produced
        nothing.

    Raises:
        ValueError: If ``max_new_tokens`` exceeds ``MAX_MAX_NEW_TOKENS``.
    """
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        raise ValueError(
            f"max_new_tokens ({max_new_tokens}) exceeds the allowed maximum ({MAX_MAX_NEW_TOKENS})."
        )

    history = history_with_input[:-1]
    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
    try:
        first_response = next(generator)
    except StopIteration:
        # The stream ended before producing any text: show an empty answer.
        yield history + [(message, '')]
        return
    yield history + [(message, first_response)]
    for response in generator:
        yield history + [(message, response)]
85
-
86
def process_example(message):
    """Run a full generation for ``message`` and return the final chat state.

    The generation uses an empty history, ``DEFAULT_SYSTEM_PROMPT`` and fixed
    sampling settings (1024 new tokens, temperature 1, top-p 0.95, top-k 50).

    Args:
        message: The example prompt to answer.

    Returns:
        A ``('', history)`` tuple: an empty string and the last history
        yielded by ``generate``.
    """
    generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
    # Pre-bind the result so it is defined even if nothing is yielded.
    final_history = []
    for final_history in generator:
        pass
    return '', final_history
91
-
92
def check_input_token_length(message, chat_history, system_prompt):
    """Reject a request whose accumulated input is too long.

    No tokenizer is available in this module, so the length is measured in
    characters over the system prompt, every previous turn and the new
    message. The previous formula added the number of history turns to the
    message length and ignored both the history text and the system prompt.

    Args:
        message: The new user message.
        chat_history: List of ``(user, response)`` turns; it may already end
            with the pending ``(message, '')`` turn.
        system_prompt: System text that is sent along with the conversation.

    Raises:
        gr.Error: If the measured length exceeds ``MAX_INPUT_TOKEN_LENGTH``.
    """
    turns = list(chat_history or [])
    # The event chains append the pending (message, '') turn before this check
    # runs; drop it so the new message is not counted twice.
    if turns and turns[-1][0] == message and not turns[-1][1]:
        turns.pop()

    input_token_length = len(system_prompt or '') + len(message or '')
    input_token_length += sum(
        len(user_text or '') + len(response or '') for user_text, response in turns
    )
    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        raise gr.Error(f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.")
96
-
97
# Build the chat UI and wire its events.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a flattened listing - confirm it against the intended layout.
with gr.Blocks(theme='Taithrah/Minimal') as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Group():
        # NOTE(review): label and placeholder mention "RickyAI"/"Ricky" while
        # DESCRIPTION says "Cyber-Jarvis AI" - confirm the intended branding.
        chatbot = gr.Chatbot(label='RickyAI based on Mistral-7B-Instruct-v0.1')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Hi, Ricky',
                scale=10
            )
            submit_button = gr.Button('Submit', variant='primary', scale=1, min_width=0)

    with gr.Row():
        retry_button = gr.Button('Retry', variant='secondary')
        undo_button = gr.Button('Undo', variant='secondary')
        clear_button = gr.Button('Clear', variant='secondary')

    # Holds the last submitted message between the steps of an event chain.
    saved_input = gr.State()

    with gr.Accordion(label='Advanced options', open=False):
        system_prompt = gr.Textbox(label='System prompt', value=DEFAULT_SYSTEM_PROMPT, lines=5, interactive=False)
        max_new_tokens = gr.Slider(label='Max New Tokens', minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
        temperature = gr.Slider(label='Temperature', minimum=0.1, maximum=4.0, step=0.1, value=0.1)
        top_p = gr.Slider(label='Top-P (nucleus sampling)', minimum=0.05, maximum=1.0, step=0.05, value=0.9)
        top_k = gr.Slider(label='Top-K', minimum=1, maximum=1000, step=1, value=10)

    # Enter in the textbox: clear the box, show the message, validate the
    # input length, then stream the answer only if validation succeeded.
    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Submit button: same chain as pressing Enter in the textbox.
    button_event_preprocess = submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Retry: drop the last turn, show its message again and regenerate.
    # This chain does not run check_input_token_length.
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Undo: drop the last turn and put its message back into the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    # Clear: reset both the chat history and the saved message.
    clear_button.click(
        fn=lambda: ([], ''),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

# Enable the request queue and start the app.
demo.queue(max_size=32).launch(show_api=False)