davanstrien HF staff commited on
Commit
1b8ac52
1 Parent(s): 5a41550

Refactor code and add type annotations

Browse files
Files changed (1) hide show
  1. main.py +12 -6
main.py CHANGED
@@ -1,20 +1,20 @@
1
  import os
2
  import random
3
  from datetime import timedelta
4
- from pathlib import Path
5
  from statistics import mean
6
  from typing import Any, Iterator, Union
7
-
8
  import fasttext
9
  from cashews import cache
10
  from dotenv import load_dotenv
11
- from fastapi import FastAPI
12
  from httpx import AsyncClient, Client, Timeout
13
  from huggingface_hub import hf_hub_download
14
- from huggingface_hub.utils import logging
15
  from iso639 import Lang
16
  from starlette.responses import RedirectResponse
17
  from toolz import concat, groupby, valmap
 
18
 
19
  cache.setup("mem://")
20
 
@@ -130,6 +130,8 @@ async def get_random_rows(
130
 
131
 
132
  def load_model(repo_id: str) -> fasttext.FastText._FastText:
 
 
133
  Path("code/models").mkdir(parents=True, exist_ok=True)
134
  model_path = hf_hub_download(
135
  repo_id,
@@ -237,14 +239,18 @@ def predict_rows(
237
  def root():
238
  return RedirectResponse(url="/docs")
239
 
 
 
240
 
241
  @app.get("/predict_dataset_language/{hub_id:path}")
242
  @cache(ttl=timedelta(minutes=10))
243
  async def predict_language(
244
- hub_id: str,
245
  config: str | None = None,
246
  split: str | None = None,
247
- max_request_calls: int = 10,
 
 
248
  number_of_rows: int = 1000,
249
  ) -> dict[Any, Any] | None:
250
  is_valid = datasets_server_valid_rows(hub_id)
 
1
  import os
2
  import random
3
  from datetime import timedelta
4
+
5
  from statistics import mean
6
  from typing import Any, Iterator, Union
7
+ from typing import Annotated
8
  import fasttext
9
  from cashews import cache
10
  from dotenv import load_dotenv
11
+ from fastapi import FastAPI, Path
12
  from httpx import AsyncClient, Client, Timeout
13
  from huggingface_hub import hf_hub_download
 
14
  from iso639 import Lang
15
  from starlette.responses import RedirectResponse
16
  from toolz import concat, groupby, valmap
17
+ import logging
18
 
19
  cache.setup("mem://")
20
 
 
130
 
131
 
132
  def load_model(repo_id: str) -> fasttext.FastText._FastText:
133
+ from pathlib import Path
134
+
135
  Path("code/models").mkdir(parents=True, exist_ok=True)
136
  model_path = hf_hub_download(
137
  repo_id,
 
239
  def root():
240
  return RedirectResponse(url="/docs")
241
 
242
+ # item_id: Annotated[int, Path(title="The ID of the item to get", ge=1)], q: str
243
+
244
 
245
  @app.get("/predict_dataset_language/{hub_id:path}")
246
  @cache(ttl=timedelta(minutes=10))
247
  async def predict_language(
248
+ hub_id: Annotated[str, Path(title="The hub id of the dataset to predict")],
249
  config: str | None = None,
250
  split: str | None = None,
251
+ max_request_calls: Annotated[
252
+ int, Path(title="Max number of requests to datasets server", gt=0, le=20)
253
+ ] = 10,
254
  number_of_rows: int = 1000,
255
  ) -> dict[Any, Any] | None:
256
  is_valid = datasets_server_valid_rows(hub_id)