Commit
·
42c8298
1
Parent(s):
d655f51
add fix and push ft
Browse files
app.py
CHANGED
@@ -47,7 +47,7 @@ def get_ir_evaluator(eval_ds):
|
|
47 |
|
48 |
|
49 |
@spaces.GPU(duration=3600)
|
50 |
-
def train(hf_token, dataset_id, model_id, num_epochs, dev):
|
51 |
|
52 |
ds = load_dataset(dataset_id, split="train", token=hf_token)
|
53 |
ds = ds.shuffle(seed=42)
|
@@ -110,6 +110,8 @@ def train(hf_token, dataset_id, model_id, num_epochs, dev):
|
|
110 |
print(ir_evaluator.primary_metric)
|
111 |
print(ft_metrics[ir_evaluator.primary_metric])
|
112 |
|
|
|
|
|
113 |
|
114 |
metrics = pd.DataFrame([base_metrics, ft_metrics]).T
|
115 |
print(metrics)
|
@@ -119,5 +121,5 @@ def train(hf_token, dataset_id, model_id, num_epochs, dev):
|
|
119 |
## logs to UI
|
120 |
# https://github.com/gradio-app/gradio/issues/2362#issuecomment-1424446778
|
121 |
|
122 |
-
demo = gr.Interface(fn=
|
123 |
demo.launch()
|
|
|
47 |
|
48 |
|
49 |
@spaces.GPU(duration=3600)
|
50 |
+
def train(hf_token, dataset_id, model_id, num_epochs, dev=True):
|
51 |
|
52 |
ds = load_dataset(dataset_id, split="train", token=hf_token)
|
53 |
ds = ds.shuffle(seed=42)
|
|
|
110 |
print(ir_evaluator.primary_metric)
|
111 |
print(ft_metrics[ir_evaluator.primary_metric])
|
112 |
|
113 |
+
if not dev: model.push_to_hub("fine-tuned-sentence-transformer", private=True, token=hf_token)
|
114 |
+
|
115 |
|
116 |
metrics = pd.DataFrame([base_metrics, ft_metrics]).T
|
117 |
print(metrics)
|
|
|
121 |
## logs to UI
|
122 |
# https://github.com/gradio-app/gradio/issues/2362#issuecomment-1424446778
|
123 |
|
124 |
+
demo = gr.Interface(fn=train, inputs=["text", "text", "text", "number", "bool"], outputs=["text"]) # "dataframe"
|
125 |
demo.launch()
|