voidDescriptor committed
Commit 193dc4a
1 parent: 0f5e907

Delete app.py

Files changed (1)
app.py +0 -301
app.py DELETED
@@ -1,301 +0,0 @@
- import gradio as gr
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import time
- import os
-
- class ConversationManager:
-     def __init__(self):
-         self.models = {}
-         self.conversation = []
-         self.delay = 3
-         self.is_paused = False
-         self.current_model = None
-         self.initial_prompt = ""
-         self.task_complete = False  # set when a model signals task completion
-
-     def load_model(self, model_name):
-         if model_name in self.models:
-             return self.models[model_name]
-
-         try:
-             print(f"Attempting to load model: {model_name}")
-             tokenizer = AutoTokenizer.from_pretrained(model_name)
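-             # 8-bit loading assumes the bitsandbytes and accelerate packages
-             # are installed; recent transformers releases prefer receiving this
-             # setting as quantization_config=BitsAndBytesConfig(load_in_8bit=True).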
-             model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)
-             self.models[model_name] = (model, tokenizer)
-             print(f"Successfully loaded model: {model_name}")
-             return self.models[model_name]
-         except Exception as e:
-             print(f"Failed to load model {model_name} ({type(e).__name__}): {e}")
-             return None
-
-     def generate_response(self, model_name, prompt):
-         loaded = self.load_model(model_name)
-         if loaded is None:
-             return f"[Error: could not load {model_name}]"
-         model, tokenizer = loaded
-
-         # Format the prompt based on the model
-         if "llama" in model_name.lower():
-             formatted_prompt = self.format_llama2_prompt(prompt)
-         else:
-             formatted_prompt = self.format_general_prompt(prompt)
-
-         inputs = tokenizer(formatted_prompt, return_tensors="pt", max_length=1024, truncation=True)
-         inputs = inputs.to(model.device)  # keep tensors on the model's device under device_map="auto"
-         with torch.no_grad():
-             outputs = model.generate(**inputs, max_new_tokens=200, num_return_sequences=1, do_sample=True)
-         # Decode only the newly generated tokens, not the echoed prompt.
-         new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
-         return tokenizer.decode(new_tokens, skip_special_tokens=True)
-
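-     # Note: do_sample=True samples from the token distribution at each step,
-     # so replies vary between runs; do_sample=False gives greedy, repeatable output.
-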
-     def format_llama2_prompt(self, prompt):
-         B_INST, E_INST = "[INST]", "[/INST]"
-         B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-         system_prompt = "You are a helpful AI assistant. Please provide a concise and relevant response."
-
-         formatted_prompt = f"{B_INST} {B_SYS}{system_prompt}{E_SYS}{prompt.strip()} {E_INST}"
-         return formatted_prompt
-
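-     # For the prompt "Hi there", the string above renders as:
-     # [INST] <<SYS>>
-     # You are a helpful AI assistant. Please provide a concise and relevant response.
-     # <</SYS>>
-     #
-     # Hi there [/INST]
-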
-     def format_general_prompt(self, prompt):
-         # A general format that might work for other models
-         return f"Human: {prompt.strip()}\n\nAssistant:"
-
-     def add_to_conversation(self, model_name, response):
-         self.conversation.append((model_name, response))
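-         # Convention: a model proposes stopping by ending its turn with the
-         # literal marker "Task complete?"; chat() then asks the user to confirm.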
-         if "task complete?" in response.lower():  # check for the completion marker
-             self.task_complete = True
-
-     def get_conversation_history(self):
-         return "\n".join(f"{model}: {msg}" for model, msg in self.conversation)
-
-     def clear_conversation(self):
-         self.conversation = []
-         self.initial_prompt = ""
-         self.models = {}
-         self.current_model = None
-         self.task_complete = False  # Reset task completion status
-
-     def rewind_conversation(self, steps):
-         if steps > 0:  # guard: a [:-0] slice would empty the whole list
-             self.conversation = self.conversation[:-steps]
-         self.task_complete = False  # Reset task completion status after rewinding
-
-     def rewind_and_insert(self, steps, inserted_response):
-         if steps > 0:
-             self.conversation = self.conversation[:-steps]
-         if inserted_response.strip():
-             last_model = self.conversation[-1][0] if self.conversation else "User"
-             next_model = "Model 1" if last_model in ("User", "Model 2") else "Model 2"
-             self.conversation.append((next_model, inserted_response))
-             self.current_model = last_model
-         self.task_complete = False  # Reset task completion status after rewinding and inserting
-
- manager = ConversationManager()
-
- def get_model(dropdown, custom):
-     # Prefer a custom model name when one was typed in.
-     return custom.strip() if custom.strip() else dropdown
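- # e.g. get_model("tiiuae/falcon-7b", "") -> "tiiuae/falcon-7b"
- #      get_model("tiiuae/falcon-7b", "my-org/my-model") -> "my-org/my-model"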
-
- def chat(model1, model2, custom1, custom2, user_input, history, inserted_response=""):
-     try:
-         # The custom-model textboxes are wired in as event inputs; reading
-         # model1_custom.value here would only ever see the textbox default.
-         model1 = get_model(model1, custom1)
-         model2 = get_model(model2, custom2)
-
-         if not manager.load_model(model1) or not manager.load_model(model2):
-             yield "Error: Failed to load one or both models. Please check the model names and try again.", ""
-             return
-
-         if not manager.conversation:
-             manager.clear_conversation()
-             manager.initial_prompt = user_input  # set after clearing, which resets it
-             manager.add_to_conversation("User", user_input)
-
-         models = [model1, model2]
-         # Model 1 speaks first, and again whenever Model 2 (or the user) spoke last.
-         current_model_index = 0 if manager.current_model in (None, "User", model2) else 1
-
-         while not manager.task_complete:
-             if manager.is_paused:
-                 yield history, "Conversation paused."
-                 return
-
-             model = models[current_model_index]
-             manager.current_model = model
-
-             if inserted_response and current_model_index == 0:
-                 response = inserted_response
-                 inserted_response = ""
-             else:
-                 conversation_history = manager.get_conversation_history()
-                 prompt = f"{conversation_history}\n\nPlease continue the conversation. If you believe the task is complete, end your response with 'Task complete?'"
-                 response = manager.generate_response(model, prompt)
-
-             manager.add_to_conversation(model, response)
-             history = manager.get_conversation_history()
-
-             # chat() is a generator, so each yield streams an update to the UI.
-             for i in range(manager.delay, 0, -1):
-                 yield history, f"{model} is writing... {i}"
-                 time.sleep(1)
-
-             yield history, ""
-
-             if manager.task_complete:
-                 yield history, "Models believe the task is complete. Are you satisfied with the result? (Yes/No)"
-                 return
-
-             current_model_index = (current_model_index + 1) % 2
-
-         yield history, "Conversation completed."
-     except Exception as e:
-         print(f"Error in chat function: {e}")
-         # Errors must be yielded, not returned, from a generator.
-         yield f"An error occurred: {e}", ""
-
- def user_satisfaction(satisfied, history):
-     if satisfied.strip().lower() == "yes":
-         return history, "Task completed successfully."
-     else:
-         manager.task_complete = False
-         return history, "Continuing the conversation..."
-
- def pause_conversation():
-     manager.is_paused = True
-     return "Conversation paused. Press Resume to continue."
-
- def resume_conversation():
-     manager.is_paused = False
-     return "Conversation resumed."
-
- def edit_response(edited_text):
-     if manager.conversation:
-         manager.conversation[-1] = (manager.current_model, edited_text)
-         manager.task_complete = False  # Reset task completion status after editing
-     return manager.get_conversation_history()
-
- def restart_conversation(model1, model2, custom1, custom2, user_input):
-     manager.clear_conversation()
-     # chat() is a generator, so re-yield its updates instead of returning it.
-     yield from chat(model1, model2, custom1, custom2, user_input, "")
-
- def rewind_and_insert(steps, inserted_response, history):
-     manager.rewind_and_insert(int(steps), inserted_response)
-     return manager.get_conversation_history(), ""
-
- # This list should be populated with the exact model names when available
- open_source_models = [
-     "meta-llama/Llama-2-7b-chat-hf",
-     "meta-llama/Llama-2-13b-chat-hf",
-     "meta-llama/Llama-2-70b-chat-hf",
-     "mistralai/Mixtral-8x7B-Instruct-v0.1",
-     "bigcode/starcoder2-15b",
-     "bigcode/starcoder2-3b",
-     "tiiuae/falcon-7b",
-     "tiiuae/falcon-40b",
-     "EleutherAI/gpt-neox-20b",
-     "google/flan-ul2",
-     "stabilityai/stablelm-zephyr-3b",
-     "HuggingFaceH4/zephyr-7b-beta",
-     "microsoft/phi-2",
-     "google/gemma-7b-it",
- ]
-
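- # Note: the Llama 2 and Gemma checkpoints are gated on the Hugging Face Hub
- # and require accepting their licenses first; the 40B/70B models need
- # substantial GPU memory even with 8-bit loading.
-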
- with gr.Blocks() as demo:
-     gr.Markdown("# ConversAI Playground")
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             model1_dropdown = gr.Dropdown(choices=open_source_models, label="Model 1")
-             model1_custom = gr.Textbox(label="Custom Model 1")
-         with gr.Column(scale=1):
-             model2_dropdown = gr.Dropdown(choices=open_source_models, label="Model 2")
-             model2_custom = gr.Textbox(label="Custom Model 2")
-
-     user_input = gr.Textbox(label="Initial prompt", lines=2)
-     chat_history = gr.Textbox(label="Conversation", lines=20)
-     current_response = gr.Textbox(label="Current model response", lines=3)
-
-     with gr.Row():
-         pause_btn = gr.Button("Pause")
-         edit_btn = gr.Button("Edit")
-         rewind_btn = gr.Button("Rewind")
-         resume_btn = gr.Button("Resume")
-         restart_btn = gr.Button("Restart")
-         clear_btn = gr.Button("Clear")
-
-     with gr.Row():
-         rewind_steps = gr.Slider(0, 10, 1, label="Steps to rewind")
-         inserted_response = gr.Textbox(label="Insert response after rewind", lines=2)
-
-     delay_slider = gr.Slider(0, 10, 3, label="Response Delay (seconds)")
-
-     user_satisfaction_input = gr.Textbox(label="Are you satisfied with the result? (Yes/No)", visible=False)
-
-     gr.Markdown("""
-     ## Button Descriptions
-     - **Pause**: Temporarily stops the conversation. The current model will finish its response.
-     - **Edit**: Allows you to modify the last response in the conversation.
-     - **Rewind**: Removes the specified number of last responses from the conversation.
-     - **Resume**: Continues the conversation from where it was paused.
-     - **Restart**: Begins a new conversation with the same or different models, keeping the initial prompt.
-     - **Clear**: Resets everything, including loaded models, conversation history, and initial prompt.
-     """)
-
-     def on_chat_update(history, response):
-         # Reveal the satisfaction prompt (and hide Start) once the models
-         # propose that the task is complete.
-         if "Models believe the task is complete" in response:
-             return gr.update(visible=True), gr.update(visible=False)
-         return gr.update(visible=False), gr.update(visible=True)
-
-     start_btn = gr.Button("Start Conversation")
-     chat_output = start_btn.click(
-         chat,
-         inputs=[model1_dropdown, model2_dropdown, model1_custom, model2_custom, user_input, chat_history],
-         outputs=[chat_history, current_response],
-     )
-
-     chat_output.then(
-         on_chat_update,
-         inputs=[chat_history, current_response],
-         outputs=[user_satisfaction_input, start_btn],
-     )
-
-     user_satisfaction_input.submit(
-         user_satisfaction,
-         inputs=[user_satisfaction_input, chat_history],
-         outputs=[chat_history, current_response],
-     ).then(
-         chat,
-         inputs=[model1_dropdown, model2_dropdown, model1_custom, model2_custom, user_input, chat_history],
-         outputs=[chat_history, current_response],
-     )
-
-     pause_btn.click(pause_conversation, outputs=[current_response])
-     resume_btn.click(
-         chat,
-         inputs=[model1_dropdown, model2_dropdown, model1_custom, model2_custom, user_input, chat_history, inserted_response],
-         outputs=[chat_history, current_response],
-     )
-     edit_btn.click(edit_response, inputs=[current_response], outputs=[chat_history])
-     rewind_btn.click(rewind_and_insert, inputs=[rewind_steps, inserted_response, chat_history], outputs=[chat_history, current_response])
-     restart_btn.click(
-         restart_conversation,
-         inputs=[model1_dropdown, model2_dropdown, model1_custom, model2_custom, user_input],
-         outputs=[chat_history, current_response],
-     )
-
-     def on_clear():
-         # clear_conversation() returns nothing, so blank the outputs explicitly.
-         manager.clear_conversation()
-         return "", "", ""
-
-     clear_btn.click(on_clear, outputs=[chat_history, current_response, user_input])
-     delay_slider.change(lambda x: setattr(manager, "delay", int(x)), inputs=[delay_slider])
-
- if __name__ == "__main__":
-     demo.launch()