Jesse Karmani commited on
Commit
2661eab
·
1 Parent(s): a0c5750

Test some default inputs and outputs

Browse files
Files changed (1) hide show
  1. app.py +6 -29
app.py CHANGED
@@ -68,13 +68,9 @@ def get_dataloader(
68
 
69
 
70
  def main(
71
- sentences_path: Optional[str],
72
- sentences_dir: Optional[str],
73
- files_extension: str,
74
- output_path: str,
75
  source_lang: Optional[str],
76
  target_lang: Optional[str],
77
- starting_batch_size: int,
78
  model_name: str = "facebook/m2m100_1.2B",
79
  lora_weights_name_or_path: str = None,
80
  force_auto_device_map: bool = False,
@@ -93,6 +89,8 @@ def main(
93
  trust_remote_code: bool = False,
94
  ):
95
  accelerator = Accelerator()
 
 
96
 
97
  if force_auto_device_map and starting_batch_size >= 64:
98
  print(
@@ -102,16 +100,6 @@ def main(
102
  f"inference. You should consider using a smaller batch size, i.e '--starting_batch_size 8'"
103
  )
104
 
105
- if sentences_path is None and sentences_dir is None:
106
- raise ValueError(
107
- "You must specify either --sentences_path or --sentences_dir. Use --help for more details."
108
- )
109
-
110
- if sentences_path is not None and sentences_dir is not None:
111
- raise ValueError(
112
- "You must specify either --sentences_path or --sentences_dir, not both. Use --help for more details."
113
- )
114
-
115
  if precision is None:
116
  quantization = None
117
  dtype = None
@@ -346,20 +334,9 @@ def main(
346
  os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
347
  inference(sentences_path=sentences_path, output_path=output_path)
348
 
349
- if sentences_dir is not None:
350
- print(
351
- f"Translating all files in {sentences_dir}, with extension {files_extension}"
352
- )
353
- os.makedirs(os.path.abspath(output_path), exist_ok=True)
354
- for filename in glob.glob(
355
- os.path.join(
356
- sentences_dir, f"*.{files_extension}" if files_extension else "*"
357
- )
358
- ):
359
- output_filename = os.path.join(output_path, os.path.basename(filename))
360
- inference(sentences_path=filename, output_path=output_filename)
361
-
362
  print(f"Translation done.\n")
 
 
363
 
364
 
365
  # if __name__ == "__main__":
@@ -560,5 +537,5 @@ def main(
560
  # trust_remote_code=args.trust_remote_code,
561
  # )
562
 
563
- demo = gradio.Interface(fn=main, inputs="textbox", outputs="textbox")
564
  demo.launch(share=True)
 
68
 
69
 
70
  def main(
 
 
 
 
71
  source_lang: Optional[str],
72
  target_lang: Optional[str],
73
+ starting_batch_size: int = 8,
74
  model_name: str = "facebook/m2m100_1.2B",
75
  lora_weights_name_or_path: str = None,
76
  force_auto_device_map: bool = False,
 
89
  trust_remote_code: bool = False,
90
  ):
91
  accelerator = Accelerator()
92
+ sentences_path = "sample_text/en.txt"
93
+ output_path = "sample_text/en2es.translation.m2m100_12B.txt"
94
 
95
  if force_auto_device_map and starting_batch_size >= 64:
96
  print(
 
100
  f"inference. You should consider using a smaller batch size, i.e '--starting_batch_size 8'"
101
  )
102
 
 
 
 
 
 
 
 
 
 
 
103
  if precision is None:
104
  quantization = None
105
  dtype = None
 
334
  os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
335
  inference(sentences_path=sentences_path, output_path=output_path)
336
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  print(f"Translation done.\n")
338
+ with open(output_path, "r", encoding="utf-8") as f:
339
+ return f.read()
340
 
341
 
342
  # if __name__ == "__main__":
 
537
  # trust_remote_code=args.trust_remote_code,
538
  # )
539
 
540
+ demo = gradio.Interface(fn=main, inputs=["textbox", "textbox"], outputs="textbox")
541
  demo.launch(share=True)