vtiw commited on
Commit
6d1e318
1 Parent(s): 509ee5f

split text to batches

Browse files
Files changed (1) hide show
  1. app.py +24 -3
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import gradio as gr
 
 
2
  from lang_list import (
3
  LANGUAGE_NAME_TO_CODE,
4
  T2TT_TARGET_LANGUAGE_NAMES,
@@ -15,6 +17,19 @@ processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
15
  # translated_text_from_text = processor.decode(output_tokens[0].tolist(), skip_special_tokens=True)
16
  # print(translated_text_from_text)
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def run_t2tt(file_uploader , input_text: str, source_language: str, target_language: str) -> (str, bytes):
20
  if file_uploader is not None:
@@ -22,9 +37,15 @@ def run_t2tt(file_uploader , input_text: str, source_language: str, target_langu
22
  input_text=file.read()
23
  source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
24
  target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
25
- text_inputs = processor(text = input_text, src_lang=source_language_code , return_tensors="pt")
26
- output_tokens = model.generate(**text_inputs, tgt_lang=target_language_code)
27
- output = processor.decode(output_tokens[0].tolist(), skip_special_tokens=True)
 
 
 
 
 
 
28
  _output_name = "result.txt"
29
  open(_output_name, 'w').write(output)
30
  return str(output), _output_name
 
1
  import gradio as gr
2
+ import nltk
3
+ nltk.download('punkt')
4
  from lang_list import (
5
  LANGUAGE_NAME_TO_CODE,
6
  T2TT_TARGET_LANGUAGE_NAMES,
 
17
  # translated_text_from_text = processor.decode(output_tokens[0].tolist(), skip_special_tokens=True)
18
  # print(translated_text_from_text)
19
 
20
+ def split_text_into_batches(text, max_tokens_per_batch):
21
+ sentences = nltk.sent_tokenize(text) # Tokenize text into sentences
22
+ batches = []
23
+ current_batch = ""
24
+ for sentence in sentences:
25
+ if len(current_batch) + len(sentence) + 1 <= max_tokens_per_batch: # Add 1 for space
26
+ current_batch += sentence + " " # Add sentence to current batch
27
+ else:
28
+ batches.append(current_batch.strip()) # Add current batch to batches list
29
+ current_batch = sentence + " " # Start a new batch with the current sentence
30
+ if current_batch:
31
+ batches.append(current_batch.strip()) # Add the last batch
32
+ return batches
33
 
34
  def run_t2tt(file_uploader , input_text: str, source_language: str, target_language: str) -> (str, bytes):
35
  if file_uploader is not None:
 
37
  input_text=file.read()
38
  source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
39
  target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
40
+ max_tokens_per_batch= 256
41
+ batches = split_text_into_batches(input_text, max_tokens_per_batch)
42
+ translated_text = ""
43
+ for batch in batches:
44
+ text_inputs = processor(text=batch, src_lang=source_language_code, return_tensors="pt")
45
+ output_tokens = model.generate(**text_inputs, tgt_lang=target_language_code)
46
+ translated_batch = processor.decode(output_tokens[0].tolist(), skip_special_tokens=True)
47
+ translated_text += translated_batch + " "
48
+ output=translated_text.strip()
49
  _output_name = "result.txt"
50
  open(_output_name, 'w').write(output)
51
  return str(output), _output_name