archit11 committed on
Commit 56fb754
1 Parent(s): 6d5995f

Create app.py

Files changed (1)
app.py +169 -0
app.py ADDED
@@ -0,0 +1,169 @@
import os
from threading import Thread
from typing import Iterator, List, Tuple

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

DESCRIPTION = """\
# Zero GPU Model Comparison Arena
Compare two language models using Hugging Face's Zero GPU initiative.
Select two different models from the dropdowns and see how they perform on the same input.
"""

MAX_MAX_NEW_TOKENS = 1024
DEFAULT_MAX_NEW_TOKENS = 256
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

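# The four candidate models. All of them are loaded once at startup in 8-bit
# (load_in_8bit requires the bitsandbytes package), so generate() only has to look them up.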
MODEL_OPTIONS = [
    "google/gemma-2b-it",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "meta-llama/Llama-2-7b-chat-hf",
    "tiiuae/falcon-7b-instruct",
]

models = {}
tokenizers = {}

for model_id in MODEL_OPTIONS:
    tokenizers[model_id] = AutoTokenizer.from_pretrained(model_id)
    models[model_id] = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        load_in_8bit=True,
    )
    models[model_id].eval()

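# generate() is the only function that needs the GPU: the @spaces.GPU decorator requests a
# Zero GPU device for up to 90 seconds per call, and the function streams text back as it is produced.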
@spaces.GPU(duration=90)
def generate(
    model_id: str,
    message: str,
    chat_history: List[Tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.7,
    top_p: float = 0.95,
) -> Iterator[str]:
    model = models[model_id]
    tokenizer = tokenizers[model_id]

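    # Rebuild the running conversation in the role/content format expected by apply_chat_template.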
    conversation = []
    for user, assistant in chat_history:
        conversation.extend([
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

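    # Run model.generate in a background thread and stream partial text through TextIteratorStreamer.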
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        num_beams=1,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

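# compare_models runs the same prompt through both selected models, one after the other,
# and appends the (message, answer) pair to each model's chat history.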
def compare_models(
    model1_name: str,
    model2_name: str,
    message: str,
    chat_history1: List[Tuple[str, str]],
    chat_history2: List[Tuple[str, str]],
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
    if model1_name == model2_name:
        error = "Error: Please select two different models."
        return chat_history1 + [(message, error)], chat_history2 + [(message, error)]

    # generate() yields the cumulative text so far, so keep only the last value from each stream.
    output1 = ""
    for partial in generate(model1_name, message, chat_history1, max_new_tokens, temperature, top_p):
        output1 = partial
    output2 = ""
    for partial in generate(model2_name, message, chat_history2, max_new_tokens, temperature, top_p):
        output2 = partial

    chat_history1.append((message, output1))
    chat_history2.append((message, output2))

    log_results(model1_name, model2_name, message, output1, output2)

    return chat_history1, chat_history2

def log_results(model1_name: str, model2_name: str, question: str, answer1: str, answer2: str, winner: str = None):
    log_data = {
        "question": question,
        "model1": {"name": model1_name, "answer": answer1},
        "model2": {"name": model2_name, "answer": answer2},
        "winner": winner,
    }

    # Here you would implement the actual logging logic, e.g. sending to a server or writing to a file.
    print("Logged:", log_data)

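# vote_better records which model the user preferred. The hidden Textbox wired into each
# voting button below passes the constant label "Model 1" or "Model 2" as `choice`.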
def vote_better(model1_name, model2_name, question, history1, history2, choice):
    # The Chatbot components pass their full histories; log only the latest answer from each.
    answer1 = history1[-1][1] if history1 else ""
    answer2 = history2[-1][1] if history2 else ""
    winner = model1_name if choice == "Model 1" else model2_name
    log_results(model1_name, model2_name, question, answer1, answer2, winner)
    return f"You voted that {winner} performs better. This has been logged."

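# UI layout: one column per model (dropdown + chatbot), a shared prompt box and sampling
# sliders, a compare button, and two voting buttons.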
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            model1_dropdown = gr.Dropdown(choices=MODEL_OPTIONS, label="Model 1", value=MODEL_OPTIONS[0])
            chatbot1 = gr.Chatbot(label="Model 1 Output")
        with gr.Column():
            model2_dropdown = gr.Dropdown(choices=MODEL_OPTIONS, label="Model 2", value=MODEL_OPTIONS[1])
            chatbot2 = gr.Chatbot(label="Model 2 Output")

    text_input = gr.Textbox(label="Input Text", lines=3)

    with gr.Row():
        max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, value=DEFAULT_MAX_NEW_TOKENS)
        temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=0.7)
        top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, value=0.95)

    compare_btn = gr.Button("Compare Models")

    with gr.Row():
        better1_btn = gr.Button("Model 1 is Better")
        better2_btn = gr.Button("Model 2 is Better")

    vote_output = gr.Textbox(label="Voting Result")

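    # Event wiring: comparing refreshes both chatbots; each voting button logs the preference.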
    compare_btn.click(
        compare_models,
        inputs=[model1_dropdown, model2_dropdown, text_input, chatbot1, chatbot2, max_new_tokens, temperature, top_p],
        outputs=[chatbot1, chatbot2],  # compare_models returns the two updated histories
    )

    better1_btn.click(
        vote_better,
        inputs=[model1_dropdown, model2_dropdown, text_input, chatbot1, chatbot2, gr.Textbox(value="Model 1", visible=False)],
        outputs=[vote_output],
    )

    better2_btn.click(
        vote_better,
        inputs=[model1_dropdown, model2_dropdown, text_input, chatbot1, chatbot2, gr.Textbox(value="Model 2", visible=False)],
        outputs=[vote_output],
    )

if __name__ == "__main__":
    demo.queue(max_size=10).launch()