archit11 committed
Commit fa811ea
1 Parent(s): b758188

Update app.py

Files changed (1)
  1. app.py +29 -12
app.py CHANGED
@@ -5,27 +5,31 @@ import json
 import requests
 
 import gradio as gr
-import spaces
 import torch
 import transformers
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
+# Description for the Gradio Interface
 DESCRIPTION = """\
 # Zero GPU Model Comparison Arena
 Select two different models from the dropdowns and see how they perform on the same input.
 """
 
-MAX_MAX_NEW_TOKENS = 1024
-DEFAULT_MAX_NEW_TOKENS = 256
+# Constants
+MAX_MAX_NEW_TOKENS = 256
+DEFAULT_MAX_NEW_TOKENS = 128
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
+# Device configuration
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
+# Model options
 MODEL_OPTIONS = [
     "sarvamai/OpenHathi-7B-Hi-v0.1-Base",
     "TokenBender/Navarna_v0_1_OpenHermes_Hindi"
 ]
 
+# Load models and tokenizers
 models = {}
 tokenizers = {}
 
@@ -42,6 +46,7 @@ for model_id in MODEL_OPTIONS:
     if tokenizers[model_id].pad_token_id is None:
         tokenizers[model_id].pad_token_id = tokenizers[model_id].eos_token_id
 
+# Function to log comparisons
 def log_comparison(model1_name: str, model2_name: str, question: str, answer1: str, answer2: str, winner: str = None):
     log_data = {
         "question": question,
@@ -60,13 +65,27 @@ def log_comparison(model1_name: str, model2_name: str, question: str, answer1: s
     except requests.RequestException as e:
         print(f"Error sending log to server: {e}")
 
-@spaces.GPU(duration=90)
+# Function to prepare input
+def prepare_input(model_id: str, message: str, chat_history: List[Tuple[str, str]]):
+    tokenizer = tokenizers[model_id]
+    # Prepare inputs for the model
+    inputs = tokenizer(
+        [x[1] for x in chat_history] + [message],
+        return_tensors="pt",
+        truncation=True,
+        padding=True,
+        max_length=MAX_INPUT_TOKEN_LENGTH,
+    )
+    return inputs
+
+# Function to generate responses from models
+@spaces.GPU(duration=120)
 def generate(
     model_id: str,
     message: str,
     chat_history: List[Tuple[str, str]],
     max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
-    temperature: float = 0.7,
+    temperature: float = 0.4,
     top_p: float = 0.95,
 ) -> Iterator[str]:
     model = models[model_id]
@@ -99,6 +118,7 @@ def generate(
         outputs.append(text)
         yield "".join(outputs)
 
+# Function to compare two models
 def compare_models(
     model1_name: str,
     model2_name: str,
@@ -123,13 +143,13 @@ def compare_models(
 
     return chat_history1, chat_history2, chat_history1, chat_history2
 
-
-
+# Function to log the voting result
 def vote_better(model1_name, model2_name, question, answer1, answer2, choice):
     winner = model1_name if choice == "Model 1" else model2_name
     log_comparison(model1_name, model2_name, question, answer1, answer2, winner)
     return f"You voted that {winner} performs better. This has been logged."
 
+# Gradio UI setup
 with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
 
@@ -174,9 +194,6 @@ with gr.Blocks(css="style.css") as demo:
         outputs=[vote_output]
     )
 
+# Main function to run the Gradio app
 if __name__ == "__main__":
-    # Start Flask server in a separate thread
-
-
-    # Start Gradio app with public link
-    demo.queue(max_size=3).launch(share=True)
+    demo.queue(max_size=3).launch(share=True)
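
The hunks above show only the edges of generate: its signature and the tail that appends each streamed chunk and yields the running concatenation (outputs.append(text) / yield "".join(outputs)). That tail matches the usual threaded TextIteratorStreamer recipe from transformers. Below is a minimal sketch of that pattern, assuming the unchanged body of generate follows the standard recipe; the helper name stream_generate and its exact arguments are illustrative and not taken from the repository.

# Hypothetical sketch of the streaming pattern generate() appears to use;
# the real body lies outside the hunks above, so details here are assumptions.
from threading import Thread

from transformers import TextIteratorStreamer


def stream_generate(model, tokenizer, prompt: str,
                    max_new_tokens: int = 128,
                    temperature: float = 0.4,
                    top_p: float = 0.95):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    # skip_prompt keeps the echoed prompt out of the streamed text
    streamer = TextIteratorStreamer(
        tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    # model.generate blocks, so it runs in a worker thread while the caller
    # consumes partial text from the streamer as it arrives
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

The defaults in the sketch mirror the values this commit introduces (DEFAULT_MAX_NEW_TOKENS = 128, temperature 0.4, top_p 0.95).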