farhananis005 committed
Commit ec0af28 · verified · 1 Parent(s): e32f84f

LLM finetuning demo

Files changed (4)
  1. app.py +206 -0
  2. backend.py +103 -0
  3. config.py +83 -0
  4. requirements.txt +0 -0
app.py ADDED
@@ -0,0 +1,206 @@
+ import gradio as gr
+ import random
+ from threading import Thread
+ from queue import Queue
+
+ # Import our new modules
+ import config
+ import backend
+
+ # --- HELPER FUNCTIONS (Unchanged) ---
+ def get_random_question(domain):
+     data_conf = config.DATASET_CONFIG[domain]
+     dataset = data_conf["dataset"]
+
+     if not dataset:
+         return "Failed to load dataset.", "N/A"
+
+     random_index = random.randint(0, len(dataset) - 1)
+     sample = dataset[random_index]
+
+     if domain == "Math":
+         question = sample[data_conf["question_col"]]
+         answer = sample[data_conf["answer_col"]]
+     elif domain == "Bio":
+         instruction = sample[data_conf["instruction_col"]]
+         bio_input = sample[data_conf["input_col"]]
+         answer = sample[data_conf["answer_col"]]
+         if bio_input and bio_input.strip():
+             question = f"**Instruction:**\n{instruction}\n\n**Input:**\n{bio_input}"
+         else:
+             question = instruction
+
+     return question, answer
+
+ def update_domain_settings(domain):
+     models = list(config.ALL_MODELS[domain].keys())
+     def_base = next((m for m in models if "Base" in m), models[0])
+     def_ft = next((m for m in models if "Finetuned" in m), models[0])
+
+     q, a = get_random_question(domain)
+     return [
+         gr.Dropdown(choices=models, value=def_base),
+         gr.Dropdown(choices=models, value=def_ft),
+         gr.Textbox(value=q),
+         a,
+         gr.Markdown(visible=False)
+     ]
+
+ def load_next_question(domain):
+     q, a = get_random_question(domain)
+     return [gr.Textbox(value=q), a, gr.Markdown(visible=False, value="")]
+
+ def reveal_answer(hidden_answer):
+     return gr.Markdown(value=f"**Ground Truth Answer:**\n\n{hidden_answer}", visible=True)
+
+ # --- CORE LOGIC (REBUILT FOR TRUE PARALLEL STREAMING) ---
+
+ def stream_to_queue(model_id, prompt, lane, queue, key):
+     """
+     A worker function that runs in a thread.
+     It calls the streaming API and puts tokens into the queue.
+     """
+     try:
+         # call_modal_api is a generator
+         for token in backend.call_modal_api(model_id, prompt, lane):
+             queue.put((key, token))
+     except Exception as e:
+         queue.put((key, f"\n\nTHREAD ERROR: {e}"))
+     finally:
+         # When the stream is done, put a 'None' sentinel
+         queue.put((key, None))
+
+ def run_comparison(domain, question, model_1_name, model_2_name):
+     # 1. Get IDs
+     id_1 = config.ALL_MODELS[domain].get(model_1_name)
+     id_2 = config.ALL_MODELS[domain].get(model_2_name)
+
+     # 2. Ask the Smart Router
+     lane_for_m1, lane_for_m2 = backend.router.get_routing_plan(id_1, id_2)
+
+     # 3. Create the Queue and Threads
+     q = Queue()
+
+     Thread(
+         target=stream_to_queue,
+         args=(id_1, question, lane_for_m1, q, 'm1')
+     ).start()
+
+     Thread(
+         target=stream_to_queue,
+         args=(id_2, question, lane_for_m2, q, 'm2')
+     ).start()
+
+     # 4. Listen to the Queue
+     text1 = ""
+     text2 = ""
+     m1_done = False
+     m2_done = False
+
+     # Clear boxes and start
+     yield "", "", gr.Markdown(visible=False)
+
+     while not (m1_done and m2_done):
+         # Wait for the next token from *either* thread
+         try:
+             key, token = q.get()
+         except Exception as e:
+             # This should ideally not happen
+             print(f"Queue error: {e}")
+             continue
+
+         # Check for the 'None' sentinel
+         if token is None:
+             if key == 'm1':
+                 m1_done = True
+             elif key == 'm2':
+                 m2_done = True
+         else:
+             # Append the new token
+             if key == 'm1':
+                 text1 += token
+             elif key == 'm2':
+                 text2 += token
+
+         # Yield the updated full text
+         yield text1, text2, gr.Markdown(visible=False)
+
+
+ # --- UI BUILD (Unchanged) ---
+ initial_question, initial_answer = get_random_question("Math")
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         """
+         # 🔬 LLM Finetuning Arena
+         ### Comparing Finetuned vs. Base Models on Specialized Tasks
+         """
+     )
+
+     hidden_answer_state = gr.State(value=initial_answer)
+
+     with gr.Row():
+         domain_radio = gr.Radio(
+             ["Math", "Bio"], label="1. Select Domain", value="Math"
+         )
+
+     with gr.Row():
+         question_box = gr.Textbox(
+             label="2. Question Prompt (Editable)",
+             value=initial_question, lines=5, scale=4
+         )
+         next_btn = gr.Button("Load Random Question 🔄", scale=1, min_width=100)
+
+     with gr.Row():
+         model_1_dd = gr.Dropdown(
+             label="3. Select Model 1 (Left)",
+             choices=list(config.ALL_MODELS["Math"].keys()),
+             value=next((m for m in config.ALL_MODELS["Math"] if "Base" in m))
+         )
+         model_2_dd = gr.Dropdown(
+             label="4. Select Model 2 (Right)",
+             choices=list(config.ALL_MODELS["Math"].keys()),
+             value=next((m for m in config.ALL_MODELS["Math"] if "Finetuned" in m))
+         )
+
+     with gr.Row():
+         run_btn = gr.Button("🚀 Run Comparison", variant="primary", scale=3)
+         show_answer_btn = gr.Button("Show Ground Truth Answer", scale=1)
+
+     answer_display_box = gr.Markdown(label="Ground Truth Answer", visible=False)
+
+     gr.Markdown("---")
+
+     with gr.Row():
+         output_1_box = gr.Markdown(label="Output: Model 1")
+         output_2_box = gr.Markdown(label="Output: Model 2")
+
+     # --- EVENTS (Unchanged) ---
+     domain_radio.change(
+         fn=update_domain_settings,
+         inputs=[domain_radio],
+         outputs=[model_1_dd, model_2_dd, question_box, hidden_answer_state, answer_display_box]
+     )
+
+     next_btn.click(
+         fn=load_next_question,
+         inputs=[domain_radio],
+         outputs=[question_box, hidden_answer_state, answer_display_box]
+     )
+
+     show_answer_btn.click(
+         fn=reveal_answer,
+         inputs=[hidden_answer_state],
+         outputs=[answer_display_box]
+     )
+
+     run_btn.click(
+         fn=run_comparison,
+         inputs=[domain_radio, question_box, model_1_dd, model_2_dd],
+         outputs=[output_1_box, output_2_box, answer_display_box]
+     )
+
+ if __name__ == "__main__":
+     if not config.MY_AUTH_TOKEN:
+         print("⚠️ WARNING: ARENA_AUTH_TOKEN is not set.")
+     demo.launch()
backend.py ADDED
@@ -0,0 +1,103 @@
+ import requests
+ import config
+
+ class SmartRouter:
+     def __init__(self):
+         # Tracks what is currently loaded in the backend (best guess)
+         self.lane_state = {
+             "primary": None,    # URL: ...-generate-primary.modal.run
+             "secondary": None   # URL: ...-generate-secondary.modal.run
+         }
+
+     def get_routing_plan(self, model_left_id, model_right_id):
+         """
+         Decides which model goes to which lane to minimize cold starts.
+         Returns: (lane_for_left_model, lane_for_right_model)
+         """
+         primary_model = self.lane_state["primary"]
+         secondary_model = self.lane_state["secondary"]
+
+         # Score: 0 = Cache Hit (Good), 1 = Cache Miss (Bad)
+
+         # Option A: Straight (Left -> Primary, Right -> Secondary)
+         cost_straight = (0 if primary_model == model_left_id else 1) + \
+                         (0 if secondary_model == model_right_id else 1)
+
+         # Option B: Swapped (Left -> Secondary, Right -> Primary)
+         cost_swapped = (0 if secondary_model == model_left_id else 1) + \
+                        (0 if primary_model == model_right_id else 1)
+
+         if cost_swapped < cost_straight:
+             print("🔀 Smart Router: swapping lanes to optimize cache!")
+             # Update state for next time
+             self.lane_state["secondary"] = model_left_id
+             self.lane_state["primary"] = model_right_id
+             return "secondary", "primary"
+         else:
+             print("⬇️ Smart Router: keeping straight lanes.")
+             # Update state for next time
+             self.lane_state["primary"] = model_left_id
+             self.lane_state["secondary"] = model_right_id
+             return "primary", "secondary"
+
+ # Create a global instance
+ router = SmartRouter()
+
+ # --- call_modal_api: STREAMING VERSION ---
+ def call_modal_api(model_repo_id, prompt, lane):
+     """
+     Calls the Modal API on a specific lane and yields tokens as they arrive.
+     This is a GENERATOR.
+     """
+     if not model_repo_id:
+         yield "Please select a model from the dropdown."
+         return  # Stop the generator
+
+     if not config.MY_AUTH_TOKEN:
+         yield "Error: `ARENA_AUTH_TOKEN` is not set on the Gradio server."
+         return
+
+     # Construct the URL based on the lane
+     if lane == "primary":
+         endpoint = f"{config.MODAL_BASE_URL}-generate-primary.modal.run"
+     else:
+         endpoint = f"{config.MODAL_BASE_URL}-generate-secondary.modal.run"
+
+     print(f"🚀 Streaming from {model_repo_id} on [{lane.upper()}]...")
+
+     headers = {
+         "Content-Type": "application/json",
+         "Authorization": f"Bearer {config.MY_AUTH_TOKEN}"
+     }
+     payload = {"model_id": model_repo_id, "prompt": prompt}
+
+     try:
+         # stream=True keeps the connection open so chunks can be read as they arrive
+         response = requests.post(
+             endpoint,
+             json=payload,
+             timeout=300,
+             headers=headers,
+             stream=True
+         )
+         response.raise_for_status()
+
+         # Yield tokens as they arrive
+         for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
+             if chunk:
+                 yield chunk
+
+     except requests.exceptions.RequestException as e:
+         # NOTE: a Response object is falsy for 4xx/5xx statuses, so compare
+         # against None here; a plain `if e.response` would never catch a 401.
+         if e.response is not None and e.response.status_code == 401:
+             yield "Error: Authentication failed. The token is invalid."
+         elif e.response is not None:
+             # Try to get error detail from the streaming API
+             try:
+                 error_detail = e.response.json().get("detail", str(e))
+                 yield f"API Error: {e.response.status_code} - {error_detail}"
+             except Exception:
+                 yield f"API Error: {e}"
+         else:
+             yield f"API Error: {e}"
+     except Exception as e:
+         yield f"An unexpected error occurred: {e}"
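The router's payoff shows up when a user swaps the two dropdowns between runs: straight routing would reload both lanes, while swapping keeps each model where it is already warm. A quick illustration with hypothetical model IDs (running it imports `config` and therefore triggers the dataset downloads):

```python
# Illustrative session against the SmartRouter above (model IDs are made up).
import backend

# First request: no lanes are warm yet, so straight assignment wins by default.
print(backend.router.get_routing_plan("model-a", "model-b"))
# -> ('primary', 'secondary'); state is now {primary: model-a, secondary: model-b}

# The user swaps the dropdowns. Straight routing would miss both caches (cost 2);
# swapped routing hits both (cost 0), so the router crosses the lanes.
print(backend.router.get_routing_plan("model-b", "model-a"))
# -> ('secondary', 'primary'); neither lane has to reload its model
```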
config.py ADDED
@@ -0,0 +1,83 @@
+ import os
+ from dotenv import load_dotenv
+ from datasets import load_dataset
+
+ load_dotenv()
+
+ # --- CONFIGURATION ---
+ MODAL_BASE_URL = "https://mohdfanis--unsloth-model-arena-backend"  # Base URL
+ MY_AUTH_TOKEN = os.environ.get("ARENA_AUTH_TOKEN")
+
+ # --- DATASETS ---
+ print("Loading Hugging Face datasets...")
+ try:
+     math_dataset = load_dataset("microsoft/orca-math-word-problems-200k", split="train")
+     bio_dataset = load_dataset("bio-nlp-umass/bioinstruct", split="train")
+     print("✅ Datasets loaded successfully.")
+ except Exception as e:
+     print(f"❌ Failed to load datasets: {e}")
+     math_dataset, bio_dataset = [], []
+
+ DATASET_CONFIG = {
+     "Math": {
+         "dataset": math_dataset,
+         "question_col": "question",
+         "answer_col": "answer"
+     },
+     "Bio": {
+         "dataset": bio_dataset,
+         "instruction_col": "instruction",
+         "input_col": "input",
+         "answer_col": "output"
+     }
+ }
+
+ # --- MODEL DEFINITIONS ---
+ BASE_MODELS = {
+     "Base Llama-3.1 8B Instruct": "unsloth/llama-3.1-8b-instruct-bnb-4bit",
+     "Base Llama-3 8B Instruct": "unsloth/llama-3-8b-instruct-bnb-4bit",
+     "Base Llama-2 7B Chat": "unsloth/llama-2-7b-chat-bnb-4bit",
+     "Base Mistral 7B Instruct": "unsloth/mistral-7b-v0.3-instruct-bnb-4bit",
+     "Base Qwen-2 7B Instruct": "unsloth/qwen2-7B-instruct-bnb-4bit",
+     "Base Gemma-2 9B Instruct": "unsloth/gemma-2-9b-it-bnb-4bit",
+     "Base Gemma 7B Instruct": "unsloth/gemma-7b-it-bnb-4bit",
+ }
+
+ FINETUNED_MATH = {
+     "Finetuned Llama-3.1 8B (e3) - MATH": "farhananis005/lora-llama-3.1-8b-Math-e3",
+     "Finetuned Llama-3.1 8B (e1) - MATH": "farhananis005/lora-llama-3.1-8b-Math-e1",
+     "Finetuned Llama-3 8B (e3) - MATH": "farhananis005/lora-llama-3-8b-Math-e3",
+     "Finetuned Llama-3 8B (e1) - MATH": "farhananis005/lora-llama-3-8b-Math-e1",
+     "Finetuned Llama-2 7B (e3) - MATH": "farhananis005/lora-llama-2-7b-Math-e3",
+     "Finetuned Llama-2 7B (e1) - MATH": "farhananis005/lora-llama-2-7b-Math-e1",
+     "Finetuned Mistral 7B (e3) - MATH": "farhananis005/lora-mistral-7b-v0.3-Math-e3",
+     "Finetuned Mistral 7B (e1) - MATH": "farhananis005/lora-mistral-7b-v0.3-Math-e1",
+     "Finetuned Qwen-2 7B (e3) - MATH": "farhananis005/lora-qwen-2-7b-Math-e3",
+     "Finetuned Qwen-2 7B (e1) - MATH": "farhananis005/lora-qwen-2-7b-Math-e1",
+     "Finetuned Gemma-2 9B (e3) - MATH": "farhananis005/lora-gemma-2-9b-Math-e3",
+     "Finetuned Gemma-2 9B (e1) - MATH": "farhananis005/lora-gemma-2-9b-Math-e1",
+     "Finetuned Gemma 7B (e3) - MATH": "farhananis005/lora-gemma-7b-Math-e3",
+     "Finetuned Gemma 7B (e1) - MATH": "farhananis005/lora-gemma-7b-Math-e1",
+ }
+
+ FINETUNED_BIO = {
+     "Finetuned Llama-3.1 8B (e3) - BIO": "farhananis005/lora-llama-3.1-8b-Bio-e3",
+     "Finetuned Llama-3.1 8B (e1) - BIO": "farhananis005/lora-llama-3.1-8b-Bio-e1",
+     "Finetuned Llama-3 8B (e3) - BIO": "farhananis005/lora-llama-3-8b-Bio-e3",
+     "Finetuned Llama-3 8B (e1) - BIO": "farhananis005/lora-llama-3-8b-Bio-e1",
+     "Finetuned Llama-2 7B (e3) - BIO": "farhananis005/lora-llama-2-7b-Bio-e3",
+     "Finetuned Llama-2 7B (e1) - BIO": "farhananis005/lora-llama-2-7b-Bio-e1",
+     "Finetuned Mistral 7B (e3) - BIO": "farhananis005/lora-mistral-7b-v0.3-Bio-e3",
+     "Finetuned Mistral 7B (e1) - BIO": "farhananis005/lora-mistral-7b-v0.3-Bio-e1",
+     "Finetuned Qwen-2 7B (e3) - BIO": "farhananis005/lora-qwen-2-7b-Bio-e3",
+     "Finetuned Qwen-2 7B (e1) - BIO": "farhananis005/lora-qwen-2-7b-Bio-e1",
+     "Finetuned Gemma-2 9B (e3) - BIO": "farhananis005/lora-gemma-2-9b-Bio-e3",
+     "Finetuned Gemma-2 9B (e1) - BIO": "farhananis005/lora-gemma-2-9b-Bio-e1",
+     "Finetuned Gemma 7B (e3) - BIO": "farhananis005/lora-gemma-7b-Bio-e3",
+     "Finetuned Gemma 7B (e1) - BIO": "farhananis005/lora-gemma-7b-Bio-e1",
+ }
+
+ ALL_MODELS = {
+     "Math": {"-- Select Math Model --": None, **BASE_MODELS, **FINETUNED_MATH},
+     "Bio": {"-- Select Bio Model --": None, **BASE_MODELS, **FINETUNED_BIO}
+ }
requirements.txt ADDED
Binary file (3.07 kB).
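`requirements.txt` is stored as a binary blob, so its exact contents are not viewable in this diff. Judging purely from the imports across the three modules above, a minimal equivalent would need at least the following (an inference, not the actual file):

```text
gradio
requests
datasets
python-dotenv
```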