davanstrien HF staff commited on
Commit
41869c7
1 Parent(s): f1bc1ad

Add fastapi.responses and starlette.responses imports

Browse files
Files changed (1) hide show
  1. main.py +28 -20
main.py CHANGED
@@ -3,7 +3,7 @@ import random
3
  from pathlib import Path
4
  from statistics import mean
5
  from typing import Any, Iterator, Union
6
-
7
  import fasttext
8
  from dotenv import load_dotenv
9
  from fastapi import FastAPI
@@ -11,6 +11,7 @@ from httpx import AsyncClient, Client, Timeout
11
  from huggingface_hub import hf_hub_download
12
  from huggingface_hub.utils import logging
13
  from toolz import concat, groupby, valmap
 
14
 
15
  app = FastAPI()
16
  logger = logging.get_logger(__name__)
@@ -19,16 +20,17 @@ HF_TOKEN = os.getenv("HF_TOKEN")
19
 
20
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
21
 
 
 
22
  BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
23
- DEFAULT_FAST_TEXT_MODEL = "laurievb/OpenLID"
24
  headers = {
25
  "authorization": f"Bearer ${HF_TOKEN}",
26
  }
27
  timeout = Timeout(60, read=120)
28
  client = Client(headers=headers, timeout=timeout)
29
  async_client = AsyncClient(headers=headers, timeout=timeout)
30
- # non exhaustive list of columns that might contain text which can be used for language detection
31
- # we prefer to use columns in this order i.e. if there is a column named "text" we will use it first
32
  TARGET_COLUMN_NAMES = {
33
  "text",
34
  "input",
@@ -116,10 +118,20 @@ async def get_random_rows(
116
 
117
 
118
  def load_model(repo_id: str) -> fasttext.FastText._FastText:
119
- model_path = hf_hub_download(repo_id, filename="model.bin")
 
 
 
 
 
 
 
120
  return fasttext.load_model(model_path)
121
 
122
 
 
 
 
123
  def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
124
  for row in rows:
125
  if isinstance(row, str):
@@ -139,21 +151,6 @@ def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterat
139
  continue
140
 
141
 
142
- FASTTEXT_PREFIX_LENGTH = 9 # fasttext labels are formatted like "__label__eng_Latn"
143
-
144
- # model = load_model(DEFAULT_FAST_TEXT_MODEL)
145
- Path("code/models").mkdir(parents=True, exist_ok=True)
146
- model = fasttext.load_model(
147
- hf_hub_download(
148
- "facebook/fasttext-language-identification",
149
- "model.bin",
150
- cache_dir="code/models",
151
- local_dir="code/models",
152
- local_dir_use_symlinks=False,
153
- )
154
- )
155
-
156
-
157
  def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
158
  predictions = model.predict(inputs, k=k)
159
  return [
@@ -196,6 +193,17 @@ def predict_rows(rows, target_column, language_threshold_percent=0.2):
196
  }
197
 
198
 
 
 
 
 
 
 
 
 
 
 
 
199
  @app.get("/predict_dataset_language/{hub_id}")
200
  async def predict_language(
201
  hub_id: str,
 
3
  from pathlib import Path
4
  from statistics import mean
5
  from typing import Any, Iterator, Union
6
+ from fastapi.responses import HTMLResponse
7
  import fasttext
8
  from dotenv import load_dotenv
9
  from fastapi import FastAPI
 
11
  from huggingface_hub import hf_hub_download
12
  from huggingface_hub.utils import logging
13
  from toolz import concat, groupby, valmap
14
+ from starlette.responses import RedirectResponse
15
 
16
  app = FastAPI()
17
  logger = logging.get_logger(__name__)
 
20
 
21
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
22
 
23
+ FASTTEXT_PREFIX_LENGTH = 9 # fasttext labels are formatted like "__label__eng_Latn"
24
+
25
  BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
26
+ DEFAULT_FAST_TEXT_MODEL = "facebook/fasttext-language-identification"
27
  headers = {
28
  "authorization": f"Bearer ${HF_TOKEN}",
29
  }
30
  timeout = Timeout(60, read=120)
31
  client = Client(headers=headers, timeout=timeout)
32
  async_client = AsyncClient(headers=headers, timeout=timeout)
33
+
 
34
  TARGET_COLUMN_NAMES = {
35
  "text",
36
  "input",
 
118
 
119
 
120
  def load_model(repo_id: str) -> fasttext.FastText._FastText:
121
+ Path("code/models").mkdir(parents=True, exist_ok=True)
122
+ model_path = hf_hub_download(
123
+ repo_id,
124
+ "model.bin",
125
+ cache_dir="code/models",
126
+ local_dir="code/models",
127
+ local_dir_use_symlinks=False,
128
+ )
129
  return fasttext.load_model(model_path)
130
 
131
 
132
+ model = load_model(DEFAULT_FAST_TEXT_MODEL)
133
+
134
+
135
  def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
136
  for row in rows:
137
  if isinstance(row, str):
 
151
  continue
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
155
  predictions = model.predict(inputs, k=k)
156
  return [
 
193
  }
194
 
195
 
196
+ # @app.get("/", response_class=HTMLResponse)
197
+ # async def read_index():
198
+ # html_content = Path("index.html").read_text()
199
+ # return HTMLResponse(content=html_content)
200
+
201
+
202
+ @app.get("/", include_in_schema=False)
203
+ def root():
204
+ return RedirectResponse(url="/docs")
205
+
206
+
207
  @app.get("/predict_dataset_language/{hub_id}")
208
  async def predict_language(
209
  hub_id: str,