Omartificial-Intelligence-Space committed
Commit 891a967 · verified · 1 Parent(s): 515e29d

update app.py

Files changed (1)
  1. app.py +9 -12
app.py CHANGED
@@ -18,33 +18,30 @@ def evaluate_model(model_id, num_questions):
     model = SentenceTransformer(model_id, device=device)
     matryoshka_dimensions = [768, 512, 256, 128, 64]
 
-    # Prepare datasets
+    # Prepare datasets (only load the necessary split and limit to num_questions)
     datasets_info = [
         {
             "name": "Financial",
             "dataset_id": "Omartificial-Intelligence-Space/Arabic-finanical-rag-embedding-dataset",
-            "split": "train",
-            "size": 7000,
+            "split": "train", # Only train split
             "columns": ("question", "context"),
             "sample_size": num_questions
         },
         {
             "name": "MLQA",
             "dataset_id": "google/xtreme",
-            "subset": "MLQA.ar.ar",
+            "subset": "MLQA.ar.ar", # Validation split only
             "split": "validation",
-            "size": 500,
             "columns": ("question", "context"),
             "sample_size": num_questions
         },
         {
             "name": "ARCD",
             "dataset_id": "hsseinmz/arcd",
-            "split": "train",
-            "size": None,
+            "split": "train", # Only train split
             "columns": ("question", "context"),
             "sample_size": num_questions,
-            "last_rows": True # Take the last n rows
+            "last_rows": True # Take the last num_questions rows
         }
     ]
 
@@ -58,13 +55,13 @@ def evaluate_model(model_id, num_questions):
         else:
             dataset = load_dataset(dataset_info["dataset_id"], split=dataset_info["split"])
 
-        # Take the last n rows if specified
+        # Limit the number of samples to num_questions (500 max)
        if dataset_info.get("last_rows"):
-            dataset = dataset.select(range(len(dataset) - dataset_info["sample_size"], len(dataset)))
+            dataset = dataset.select(range(len(dataset) - dataset_info["sample_size"], len(dataset))) # Take last n rows
         else:
-            dataset = dataset.select(range(min(dataset_info["sample_size"], len(dataset))))
+            dataset = dataset.select(range(min(dataset_info["sample_size"], len(dataset)))) # Take first n rows
 
-        # Rename columns
+        # Rename columns to 'anchor' and 'positive'
         dataset = dataset.rename_column(dataset_info["columns"][0], "anchor")
         dataset = dataset.rename_column(dataset_info["columns"][1], "positive")
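
For readers skimming the diff, here is a minimal, self-contained sketch of what the dataset-preparation loop does after this change: load only the required split, cap each dataset at num_questions rows (taking the last rows for ARCD), and rename the question/context columns to anchor/positive. The surrounding `for dataset_info in datasets_info:` loop and the `subset` branch are inferred from the context lines, and the single ARCD entry plus the `min()` clamp on the last-rows slice are illustrative, not the app's exact code.

```python
from datasets import load_dataset

num_questions = 500  # stand-in for the value passed to evaluate_model

datasets_info = [
    {
        "name": "ARCD",
        "dataset_id": "hsseinmz/arcd",
        "split": "train",
        "columns": ("question", "context"),
        "sample_size": num_questions,
        "last_rows": True,  # take the last num_questions rows
    },
]

prepared = {}
for dataset_info in datasets_info:
    # Load only the requested split; pass the subset name when one is given
    # (e.g. "MLQA.ar.ar" for google/xtreme).
    if "subset" in dataset_info:
        dataset = load_dataset(
            dataset_info["dataset_id"], dataset_info["subset"], split=dataset_info["split"]
        )
    else:
        dataset = load_dataset(dataset_info["dataset_id"], split=dataset_info["split"])

    # Cap the dataset at sample_size rows, from the end or the start.
    n = min(dataset_info["sample_size"], len(dataset))
    if dataset_info.get("last_rows"):
        dataset = dataset.select(range(len(dataset) - n, len(dataset)))
    else:
        dataset = dataset.select(range(n))

    # Rename the question/context columns to the anchor/positive pair used downstream.
    dataset = dataset.rename_column(dataset_info["columns"][0], "anchor")
    dataset = dataset.rename_column(dataset_info["columns"][1], "positive")
    prepared[dataset_info["name"]] = dataset
```

Note that the patched last_rows branch slices with dataset_info["sample_size"] directly, so it assumes the dataset has at least num_questions rows; the min() above is only a defensive touch in this sketch.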
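
The evaluation step itself is outside this hunk, but the context line matryoshka_dimensions = [768, 512, 256, 128, 64] suggests the anchor/positive pairs are scored at several truncated embedding sizes. Below is a minimal sketch of the usual sentence-transformers pattern for that, assuming a release (>= 2.7) in which InformationRetrievalEvaluator accepts truncate_dim; the model name and toy pairs are placeholders, not values from app.py.

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator, SequentialEvaluator
from sentence_transformers.util import cos_sim

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")  # placeholder 768-dim model
matryoshka_dimensions = [768, 512, 256, 128, 64]

# Toy anchor/positive pairs standing in for the prepared datasets.
pairs = [
    ("What is the capital of France?", "Paris is the capital and largest city of France."),
    ("Who wrote Hamlet?", "Hamlet is a tragedy written by William Shakespeare."),
    ("What is the boiling point of water?", "Water boils at 100 degrees Celsius at sea level."),
]
queries = {str(i): anchor for i, (anchor, _) in enumerate(pairs)}
corpus = {str(i): positive for i, (_, positive) in enumerate(pairs)}
relevant_docs = {str(i): {str(i)} for i in range(len(pairs))}  # each anchor's own positive is relevant

# One retrieval evaluator per Matryoshka dimension; embeddings are truncated to `dim` before scoring.
evaluators = [
    InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=f"dim_{dim}",
        truncate_dim=dim,
        score_functions={"cosine_sim": cos_sim},
    )
    for dim in matryoshka_dimensions
]
results = SequentialEvaluator(evaluators)(model)
print(results)
```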