davanstrien committed
Commit 253bbca
1 Parent(s): 2c17b5e

language detection app draft

Files changed (1)
  1. app.py +219 -0
app.py ADDED
@@ -0,0 +1,219 @@
import gradio as gr
from httpx import Client
import random
import os
import fasttext
from huggingface_hub import hf_hub_download
from typing import Union
from typing import Iterator
from dotenv import load_dotenv
from toolz import groupby, valmap, concat
from statistics import mean
from httpx import Timeout
from huggingface_hub.utils import logging

logger = logging.get_logger(__name__)
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")


BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
DEFAULT_FAST_TEXT_MODEL = "laurievb/OpenLID"
headers = {
    "authorization": f"Bearer {HF_TOKEN}",
}
timeout = Timeout(60, read=120)
client = Client(headers=headers, timeout=timeout)
# Non-exhaustive list of columns that might contain text usable for language detection.
# Columns are listed in priority order, i.e. if a column named "text" exists it is used first.
TARGET_COLUMN_NAMES = (
    "text",
    "input",
    "tokens",
    "prompt",
    "instruction",
    "sentence_1",
    "question",
    "sentence2",
    "answer",
    "sentence",
    "response",
    "context",
    "query",
)


def datasets_server_valid_rows(hub_id: str):
    resp = client.get(f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={hub_id}")
    resp.raise_for_status()
    return resp.json()["viewer"]


def get_first_config_and_split_name(hub_id: str):
    resp = client.get(f"{BASE_DATASETS_SERVER_URL}/splits?dataset={hub_id}")
    resp.raise_for_status()
    data = resp.json()
    return data["splits"][0]["config"], data["splits"][0]["split"]


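# For reference, a trimmed sketch of the /splits payload the helper above reads
# (field values are illustrative; extra fields returned by the API are omitted):
#
#   {"splits": [{"dataset": "my-dataset", "config": "default", "split": "train"}, ...]}
#
# The first entry's "config" and "split" are what get returned.
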
def get_dataset_info(hub_id: str, config: str | None = None):
    if config is None:
        config = get_first_config_and_split_name(hub_id)
        if config is None:
            return None
        else:
            config = config[0]
    resp = client.get(
        f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
    )
    resp.raise_for_status()
    return resp.json()


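# get_dataset_info returns the /info payload as-is; the parts consumed later in
# predict_language look roughly like this (trimmed, illustrative):
#
#   {"dataset_info": {"features": {"text": {...}, "label": {...}},
#                     "splits": {"train": {"num_examples": 25000}}}}
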
def get_random_rows(
    hub_id,
    total_length,
    number_of_rows,
    max_request_calls,
    config="default",
    split="train",
):
    rows = []
    rows_per_call = min(
        number_of_rows // max_request_calls, total_length // max_request_calls
    )
    rows_per_call = min(rows_per_call, 100)  # ensure rows_per_call is not more than 100
    rows_per_call = max(rows_per_call, 1)  # avoid a zero step for very small datasets
    for _ in range(min(max_request_calls, number_of_rows // rows_per_call)):
        offset = random.randint(0, total_length - rows_per_call)
        url = f"{BASE_DATASETS_SERVER_URL}/rows?dataset={hub_id}&config={config}&split={split}&offset={offset}&length={rows_per_call}"
        response = client.get(url)

        if response.status_code == 200:
            data = response.json()
            batch_rows = data.get("rows")
            rows.extend(batch_rows)
        else:
            print(f"Failed to fetch data: {response.status_code}")
            print(url)
        if len(rows) >= number_of_rows:
            break
    return [row.get("row") for row in rows]


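# Each /rows response looks roughly like the following (trimmed, illustrative),
# which is why the helper above keeps only the nested "row" dicts:
#
#   {"rows": [{"row_idx": 0, "row": {"text": "...", "label": 0}}, ...]}
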
def load_model(repo_id: str) -> fasttext.FastText._FastText:
    model_path = hf_hub_download(repo_id, filename="model.bin")
    return fasttext.load_model(model_path)


# def predict_language_for_rows(rows: list[dict], target_column_names: list[str] | str):
#     pass


def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
    for row in rows:
        if isinstance(row, str):
            # split on newlines and drop empty lines
            lines = row.split("\n")
            for line in lines:
                if line:
                    yield line
        elif isinstance(row, list):
            try:
                line = " ".join(row)
                if len(line) < min_length:
                    continue
                else:
                    yield line
            except TypeError:
                continue


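# A quick illustration of the cleaning behaviour above:
#
#   >>> list(yield_clean_rows(["Hello\n\nWorld", ["a", "b", "c"]]))
#   ['Hello', 'World', 'a b c']
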
FASTTEXT_PREFIX_LENGTH = 9  # fasttext labels are formatted like "__label__eng_Latn"

# model = load_model(DEFAULT_FAST_TEXT_MODEL)

model = fasttext.load_model(
    hf_hub_download("facebook/fasttext-language-identification", "model.bin")
)


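# Raw fasttext output is a pair of (labels, probabilities); model_predict below strips
# the "__label__" prefix and zips the two together (scores here are illustrative):
#
#   >>> model.predict("This is an English sentence.", k=1)
#   (('__label__eng_Latn',), array([0.98]))
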
def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
    predictions = model.predict(inputs, k=k)
    return [
        # cast the numpy float to a plain float so the result serialises cleanly as JSON
        {"label": label[FASTTEXT_PREFIX_LENGTH:], "score": float(prob)}
        for label, prob in zip(predictions[0], predictions[1])
    ]


def get_label(x):
    return x.get("label")


def get_mean_score(preds):
    return mean([pred.get("score") for pred in preds])


def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
    """Return the keys whose count is at least `threshold_percent` of the total count."""
    total = sum(counts_dict.values())
    threshold = total * threshold_percent
    return {k for k, v in counts_dict.items() if v >= threshold}


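# Example: with 8 "eng_Latn" predictions and 1 "fra_Latn" prediction, the threshold is
# 9 * 0.2 = 1.8, so only "eng_Latn" survives:
#
#   >>> filter_by_frequency({"eng_Latn": 8, "fra_Latn": 1}, threshold_percent=0.2)
#   {'eng_Latn'}
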
def predict_rows(rows, target_column, language_threshold_percent=0.2):
    rows = (row.get(target_column) for row in rows)
    rows = (row for row in rows if row is not None)
    rows = list(yield_clean_rows(rows))
    predictions = [model_predict(row) for row in rows]
    predictions = [pred for pred in predictions if pred is not None]
    predictions = list(concat(predictions))
    predictions_by_lang = groupby(get_label, predictions)
    language_counts = valmap(len, predictions_by_lang)
    keys_to_keep = filter_by_frequency(
        language_counts, threshold_percent=language_threshold_percent
    )
    filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
    return {
        "predictions": dict(valmap(get_mean_score, filtered_dict)),
        "pred": predictions,
    }


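# The returned dict maps each surviving language to its mean confidence and also exposes
# the raw per-line predictions, e.g. (illustrative values):
#
#   {"predictions": {"eng_Latn": 0.97}, "pred": [{"label": "eng_Latn", "score": 0.99}, ...]}
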
def predict_language(
    hub_id: str,
    config: str | None = None,
    split: str | None = None,
    max_request_calls: int = 10,
):
    is_valid = datasets_server_valid_rows(hub_id)
    if not is_valid:
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
    if not config:
        config, split = get_first_config_and_split_name(hub_id)
    info = get_dataset_info(hub_id, config)
    if info is None:
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
    if dataset_info := info.get("dataset_info"):
        total_rows_for_split = dataset_info.get("splits").get(split).get("num_examples")
        features = dataset_info.get("features")
        column_names = set(features.keys())
        logger.info(f"Column names: {column_names}")
        if not set(column_names).intersection(TARGET_COLUMN_NAMES):
            raise gr.Error(
                f"Dataset {hub_id} does not contain any of the target columns {TARGET_COLUMN_NAMES}"
            )
        for column in TARGET_COLUMN_NAMES:
            if column in column_names:
                target_column = column
                logger.info(f"Using column {target_column} for language detection")
                break
        random_rows = get_random_rows(
            hub_id, total_rows_for_split, 1000, max_request_calls, config, split
        )
        logger.info(f"Predicting language for {len(random_rows)} rows")
        return predict_rows(random_rows, target_column)


interface = gr.Interface(predict_language, inputs="text", outputs="json")
interface.queue()
interface.launch()
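
Once the app is running, the same prediction can be exercised programmatically with gradio_client; this is only a sketch, and the local URL and the "imdb" dataset id are placeholder examples:

    from gradio_client import Client as GradioClient

    gradio_api = GradioClient("http://127.0.0.1:7860/")
    result = gradio_api.predict("imdb", api_name="/predict")
    # result is the JSON payload, e.g. {"predictions": {"eng_Latn": 0.97}, "pred": [...]}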