voidDescriptor committed
Commit 041e17c
Parent: b1c60d2

Delete app.py

Files changed (1)
  1. app.py +0 -314
app.py DELETED
@@ -1,314 +0,0 @@
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import bitsandbytes as bnb

print(f"bitsandbytes version: {bnb.__version__}")
print(f"CUDA is available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"Current CUDA device: {torch.cuda.current_device()}")
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")

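# ConversationManager holds all mutable state for the two-model loop: a cache
# of loaded (model, tokenizer) pairs, the running transcript, and the control
# flags (pause, task completion) toggled from the UI.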
class ConversationManager:
    def __init__(self):
        self.models = {}
        self.conversation = []
        self.delay = 3
        self.is_paused = False
        self.current_model = None
        self.initial_prompt = ""
        self.task_complete = False

    def load_model(self, model_name):
        if not model_name:
            print("Error: Empty model name provided")
            return None

        if model_name in self.models:
            return self.models[model_name]

        try:
            print(f"Attempting to load model: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            try:
                # Try 8-bit quantization first to reduce memory use.
                # (Newer transformers releases prefer passing a BitsAndBytesConfig
                # via quantization_config instead of the load_in_8bit flag.)
                model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)
            except RuntimeError as e:
                print(f"8-bit quantization not available, falling back to full precision: {e}")
                if torch.cuda.is_available():
                    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
                else:
                    model = AutoModelForCausalLM.from_pretrained(model_name)
        except Exception as e:
            print(f"Failed to load model {model_name}: {e}")
            print(f"Error type: {type(e).__name__}")
            return None

        self.models[model_name] = (model, tokenizer)
        print(f"Successfully loaded model: {model_name}")
        return self.models[model_name]

    def generate_response(self, model_name, prompt):
        loaded = self.load_model(model_name)
        if loaded is None:
            return f"Error: could not load model {model_name}."
        model, tokenizer = loaded

        formatted_prompt = f"Human: {prompt.strip()}\n\nAssistant:"

        inputs = tokenizer(formatted_prompt, return_tensors="pt", max_length=1024, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}  # keep tensors on the model's device
        with torch.no_grad():
            # max_new_tokens bounds the reply length regardless of prompt length
            # (max_length would count the prompt tokens as well).
            outputs = model.generate(**inputs, max_new_tokens=200, num_return_sequences=1, do_sample=True)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    def add_to_conversation(self, model_name, response):
        self.conversation.append((model_name, response))
        # A model signals that it thinks the task is done by ending its reply
        # with "Task complete?" (see the prompt built in chat() below).
        if "task complete?" in response.lower():
            self.task_complete = True

    def get_conversation_history(self):
        return "\n".join([f"{model}: {msg}" for model, msg in self.conversation])

    def clear_conversation(self):
        # Resets the transcript and also drops the model cache, so the next
        # conversation reloads its models from scratch.
        self.conversation = []
        self.initial_prompt = ""
        self.models = {}
        self.current_model = None
        self.task_complete = False

    def rewind_conversation(self, steps):
        self.conversation = self.conversation[:-steps]
        self.task_complete = False

    def rewind_and_insert(self, steps, inserted_response):
        if steps > 0:
            self.conversation = self.conversation[:-steps]
        if inserted_response.strip():
            last_model = self.conversation[-1][0] if self.conversation else "User"
            next_model = "Model 1" if last_model == "Model 2" or last_model == "User" else "Model 2"
            self.conversation.append((next_model, inserted_response))
            self.current_model = last_model
        self.task_complete = False

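# A single module-level manager instance shared by every Gradio callback.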
manager = ConversationManager()

def get_model(dropdown, custom):
    # A non-empty custom model name overrides the dropdown selection.
    return custom if custom and custom.strip() else dropdown

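# chat() is a generator: it streams (history, status) pairs back to Gradio,
# alternating turns between the two models until one ends a reply with
# "Task complete?", the user pauses, or an error occurs.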
def chat(model1, model2, user_input, history, inserted_response=""):
    try:
        print(f"Starting chat with models: {model1}, {model2}")
        print(f"User input: {user_input}")

        # Caveat: .value reads the textboxes' initial values, not live UI state;
        # passing the custom-model boxes as event inputs would be more robust.
        model1 = get_model(model1, model1_custom.value)
        model2 = get_model(model2, model2_custom.value)

        print(f"Selected models: {model1}, {model2}")

        if not manager.load_model(model1) or not manager.load_model(model2):
            yield "Error: Failed to load one or both models. Please check the model names and try again.", ""
            return

        if not manager.conversation:
            # Reset state before seeding; clearing afterwards would wipe the
            # initial prompt we are about to store.
            manager.clear_conversation()
            manager.initial_prompt = user_input
            manager.add_to_conversation("User", user_input)

        models = [model1, model2]
        current_model_index = 0 if manager.current_model in ["User", "Model 2"] else 1

        while not manager.task_complete:
            if manager.is_paused:
                yield history, "Conversation paused."
                return

            model = models[current_model_index]
            manager.current_model = model

            if inserted_response and current_model_index == 0:
                response = inserted_response
                inserted_response = ""
            else:
                conversation_history = manager.get_conversation_history()
                prompt = f"{conversation_history}\n\nPlease continue the conversation. If you believe the task is complete, end your response with 'Task complete?'"
                response = manager.generate_response(model, prompt)

            manager.add_to_conversation(model, response)
            history = manager.get_conversation_history()

            # Countdown between turns so the user can read (and interrupt) each reply.
            for i in range(manager.delay, 0, -1):
                yield history, f"{model} is writing... {i}"
                time.sleep(1)

            yield history, ""

            if manager.task_complete:
                yield history, "Models believe the task is complete. Are you satisfied with the result? (Yes/No)"
                return

            current_model_index = (current_model_index + 1) % 2

        yield history, "Conversation completed."
    except Exception as e:
        # chat() is a generator, so errors must be yielded, not returned.
        print(f"Error in chat function: {type(e).__name__}: {e}")
        yield f"An error occurred: {str(e)}", ""

def user_satisfaction(satisfied, history):
    if satisfied.strip().lower() == "yes":
        return history, "Task completed successfully."
    else:
        manager.task_complete = False
        return history, "Continuing the conversation..."

def pause_conversation():
    manager.is_paused = True
    return "Conversation paused. Press Resume to continue."

def resume_conversation():
    manager.is_paused = False
    return "Conversation resumed."

def edit_response(edited_text):
    if manager.conversation:
        manager.conversation[-1] = (manager.current_model, edited_text)
        manager.task_complete = False
    return manager.get_conversation_history()

def restart_conversation(model1, model2, user_input):
    manager.clear_conversation()
    # chat() is a generator, so delegate to it to keep streaming updates.
    yield from chat(model1, model2, user_input, "")

def rewind_and_insert(steps, inserted_response, history):
    manager.rewind_and_insert(int(steps), inserted_response)
    return manager.get_conversation_history(), ""

open_source_models = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "bigcode/starcoder2-15b",
    "bigcode/starcoder2-3b",
    "tiiuae/falcon-7b",
    "EleutherAI/gpt-neox-20b",
    "google/flan-ul2",
    "stabilityai/stablelm-zephyr-3b",
    "HuggingFaceH4/zephyr-7b-beta",
    "microsoft/phi-2",
    "google/gemma-7b-it",
    "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
    "mosaicml/mpt-7b-chat",
    "databricks/dolly-v2-12b",
    "thebloke/Wizard-Vicuna-13B-Uncensored-HF",
    "bigscience/bloom-560m",
]

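# Gradio interface: model pickers, transcript view, and conversation controls.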
with gr.Blocks() as demo:
    gr.Markdown("# ConversAI Playground")

    with gr.Row():
        with gr.Column(scale=1):
            model1_dropdown = gr.Dropdown(choices=open_source_models, label="Model 1")
            model1_custom = gr.Textbox(label="Custom Model 1")
        with gr.Column(scale=1):
            model2_dropdown = gr.Dropdown(choices=open_source_models, label="Model 2")
            model2_custom = gr.Textbox(label="Custom Model 2")

    user_input = gr.Textbox(label="Initial prompt", lines=2)
    chat_history = gr.Textbox(label="Conversation", lines=20)
    current_response = gr.Textbox(label="Current model response", lines=3)

    with gr.Row():
        pause_btn = gr.Button("Pause")
        edit_btn = gr.Button("Edit")
        rewind_btn = gr.Button("Rewind")
        resume_btn = gr.Button("Resume")
        restart_btn = gr.Button("Restart")
        clear_btn = gr.Button("Clear")

    with gr.Row():
        rewind_steps = gr.Slider(0, 10, 1, label="Steps to rewind")
        inserted_response = gr.Textbox(label="Insert response after rewind", lines=2)

    delay_slider = gr.Slider(0, 10, 3, label="Response Delay (seconds)")

    user_satisfaction_input = gr.Textbox(label="Are you satisfied with the result? (Yes/No)", visible=False)

    gr.Markdown("""
    ## Button Descriptions
    - **Pause**: Temporarily stops the conversation. The current model will finish its response.
    - **Edit**: Allows you to modify the last response in the conversation.
    - **Rewind**: Removes the specified number of most recent responses from the conversation.
    - **Resume**: Continues the conversation from where it was paused.
    - **Restart**: Begins a new conversation with the same or different models, keeping the initial prompt.
    - **Clear**: Resets everything, including loaded models, conversation history, and the initial prompt.
    """)

    def on_chat_update(history, response):
        # Show the satisfaction prompt (and hide Start) once the models claim completion.
        if response and "Models believe the task is complete" in response:
            return gr.update(visible=True), gr.update(visible=False)
        return gr.update(visible=False), gr.update(visible=True)

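    # Wire the UI events to their handlers.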
    start_btn = gr.Button("Start Conversation")
    chat_output = start_btn.click(
        chat,
        inputs=[
            model1_dropdown,
            model2_dropdown,
            user_input,
            chat_history,
        ],
        outputs=[chat_history, current_response],
    )

    chat_output.then(
        on_chat_update,
        inputs=[chat_history, current_response],
        outputs=[user_satisfaction_input, start_btn],
    )

    user_satisfaction_input.submit(
        user_satisfaction,
        inputs=[user_satisfaction_input, chat_history],
        outputs=[chat_history, current_response],
    ).then(
        chat,
        inputs=[
            model1_dropdown,
            model2_dropdown,
            user_input,
            chat_history,
        ],
        outputs=[chat_history, current_response],
    )

    pause_btn.click(pause_conversation, outputs=[current_response])
    # Clear the pause flag first, then re-enter the chat loop; binding chat
    # directly would see is_paused still set and exit immediately.
    resume_btn.click(resume_conversation, outputs=[current_response]).then(
        chat,
        inputs=[
            model1_dropdown,
            model2_dropdown,
            user_input,
            chat_history,
            inserted_response,
        ],
        outputs=[chat_history, current_response],
    )
    edit_btn.click(edit_response, inputs=[current_response], outputs=[chat_history])
    rewind_btn.click(rewind_and_insert, inputs=[rewind_steps, inserted_response, chat_history], outputs=[chat_history, current_response])
    restart_btn.click(
        restart_conversation,
        inputs=[
            model1_dropdown,
            model2_dropdown,
            user_input,
        ],
        outputs=[chat_history, current_response],
    )

    def clear_all():
        # clear_conversation() returns nothing, so wrap it to blank out the
        # three output boxes that the Clear button targets.
        manager.clear_conversation()
        return "", "", ""

    clear_btn.click(clear_all, outputs=[chat_history, current_response, user_input])
    delay_slider.change(lambda x: setattr(manager, "delay", x), inputs=[delay_slider])

if __name__ == "__main__":
    demo.launch()