cointegrated committed on
Commit
9c3fd77
·
1 Parent(s): f0ff84c

fix HuggingFaceDatasetSaver

Browse files
Files changed (1) hide show
  1. app.py +40 -3
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  import fasttext
5
  import os
6
  import urllib
 
7
  from transformers import MBartForConditionalGeneration, MBart50Tokenizer
8
 
9
 
@@ -12,10 +13,46 @@ MODEL_URL_MUL_MYV = 'slone/mbart-large-51-mul-myv-v1'
12
  MODEL_URL_LANGID = 'https://huggingface.co/slone/fastText-LID-323/resolve/main/lid.323.ftz'
13
  MODEL_PATH_LANGID = 'lid.323.ftz'
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  HF_TOKEN = os.getenv('HF_TOKEN')
16
- hf_writer = gr.HuggingFaceDatasetSaver(
17
- HF_TOKEN,
18
- dataset_name="slone/myv-translation-2022-demo-flags",
 
19
  private=True,
20
  )
21
 
 
4
  import fasttext
5
  import os
6
  import urllib
7
+ import huggingface_hub
8
  from transformers import MBartForConditionalGeneration, MBart50Tokenizer
9
 
10
 
 
13
  MODEL_URL_LANGID = 'https://huggingface.co/slone/fastText-LID-323/resolve/main/lid.323.ftz'
14
  MODEL_PATH_LANGID = 'lid.323.ftz'
15
 
16
+
17
class DSaverFixed(gr.HuggingFaceDatasetSaver):
    """Flagging callback that stores flagged samples in a Hub dataset repo.

    Overrides ``__init__`` and ``setup`` of ``gr.HuggingFaceDatasetSaver`` so
    that the dataset repo is created through the ``repo_id`` argument of
    ``huggingface_hub.create_repo``. Everything else is inherited from the
    base class.
    """

    def __init__(
        self,
        hf_token: str,
        dataset_name: str,
        organization=None,
        private=False,
    ):
        """Store the configuration; no network access happens here.

        The base-class initializer is not called; only the attributes below
        are set.

        Args:
            hf_token: Hub access token used to create and clone the repo.
            dataset_name: Name of the dataset repo. May already be a full
                ``"namespace/name"`` id.
            organization: Optional namespace that owns the repo. When given
                (and ``dataset_name`` has no namespace of its own) the repo
                id becomes ``"<organization>/<dataset_name>"``.
            private: Whether a newly created repo should be private.
        """
        self.hf_token = hf_token
        self.dataset_name = dataset_name
        self.organization_name = organization
        self.dataset_private = private

    def _repo_id(self) -> str:
        """Return the repo id, prefixed with the organization when one is set."""
        # Previously ``organization`` was stored but never used, so the repo
        # was addressed by its bare name and landed in the token owner's
        # namespace instead of the requested organization.
        if self.organization_name and "/" not in self.dataset_name:
            return f"{self.organization_name}/{self.dataset_name}"
        return self.dataset_name

    def setup(self, components, flagging_dir):
        """Create the dataset repo if needed and clone it locally.

        Args:
            components: Interface components whose values will be flagged;
                stored as ``self.components``.
            flagging_dir: Local directory under which the repo is cloned;
                stored as ``self.flagging_dir``.
        """
        # exist_ok=True makes this call safe to repeat on every app start.
        path_to_dataset_repo = huggingface_hub.create_repo(
            repo_id=self._repo_id(),
            token=self.hf_token,
            private=self.dataset_private,
            repo_type="dataset",
            exist_ok=True,
        )
        self.path_to_dataset_repo = path_to_dataset_repo
        self.components = components
        self.flagging_dir = flagging_dir
        self.dataset_dir = os.path.join(flagging_dir, self.dataset_name)
        self.repo = huggingface_hub.Repository(
            local_dir=self.dataset_dir,
            clone_from=path_to_dataset_repo,
            use_auth_token=self.hf_token,
        )
        # Bring the local clone up to date before anything is appended to it.
        self.repo.git_pull(lfs=True)
        # NOTE(review): the inherited flag() in gradio 3.x appends rows to
        # ``self.log_file`` ("data.csv" in the clone); without this attribute
        # flagging raises AttributeError. Confirm against the installed
        # gradio version.
        self.log_file = os.path.join(self.dataset_dir, "data.csv")
        self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json")
49
+
50
+
51
  HF_TOKEN = os.getenv('HF_TOKEN')
52
+ hf_writer = DSaverFixed(
53
+ hf_token=HF_TOKEN,
54
+ dataset_name="myv-translation-2022-demo-flags",
55
+ organization="slone",
56
  private=True,
57
  )
58