benjolo committed on
Commit fa14146
1 Parent(s): 974359f

Update backend/main.py

Files changed (1)
  1. backend/main.py +23 -9
backend/main.py CHANGED
@@ -133,12 +133,12 @@ static_files = {
     },
 }
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large", force_download=True)
+# processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large", force_download=True)
 #cache_dir="/.cache"
 
 # PM - hardcoding temporarily as my GPU doesnt have enough vram
 # model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large").to("cpu")
-model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large", force_download=True).to(device)
+# model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large", force_download=True).to(device)
 
 
 bytes_data = bytearray()
@@ -148,6 +148,18 @@ vocoder_name = "vocoder_v2" if model_name == "seamlessM4T_v2_large" else "vocode
 clients = {}
 rooms = {}
 
+import torch
+from transformers import pipeline
+translator = pipeline("automatic-speech-recognition",
+                      "facebook/seamless-m4t-v2-large",
+                      torch_dtype=torch.float32,
+                      device="cpu")
+
+converter = pipeline("translation",
+                     "facebook/seamless-m4t-v2-large",
+                     torch_dtype=torch.float32,
+                     device="cpu")
+
 
 def get_collection_users():
     return app.database["user_records"]
@@ -297,16 +309,18 @@ async def incoming_audio(sid, data, call_id):
     tgt_sid = next(id for id in rooms[call_id] if id != sid)
     tgt_lang = clients[tgt_sid].target_language
     # following example from https://github.com/facebookresearch/seamless_communication/blob/main/docs/m4t/README.md#transformers-usage
-    output_tokens = processor(audios=resampled_audio, src_lang=src_lang, return_tensors="pt")
-    model_output = model.generate(**output_tokens, tgt_lang=src_lang, generate_speech=False)[0].tolist()[0]
-    asr_text = processor.decode(model_output, skip_special_tokens=True)
+    # output_tokens = processor(audios=resampled_audio, src_lang=src_lang, return_tensors="pt")
+    # model_output = model.generate(**output_tokens, tgt_lang=src_lang, generate_speech=False)[0].tolist()[0]
+    # asr_text = processor.decode(model_output, skip_special_tokens=True)
+    asr_text = translator(resampled_audio, generate_kwargs={"tgt_lang": src_lang})['text']
     print(f"ASR TEXT = {asr_text}")
     # ASR TEXT => ORIGINAL TEXT
 
-    t2t_tokens = processor(text=asr_text, src_lang=src_lang, tgt_lang=tgt_lang, return_tensors="pt")
-    print(f"FIRST TYPE = {type(output_tokens)}, SECOND TYPE = {type(t2t_tokens)}")
-    translated_data = model.generate(**t2t_tokens, tgt_lang=tgt_lang, generate_speech=False)[0].tolist()[0]
-    translated_text = processor.decode(translated_data, skip_special_tokens=True)
+    # t2t_tokens = processor(text=asr_text, src_lang=src_lang, tgt_lang=tgt_lang, return_tensors="pt")
+    # print(f"FIRST TYPE = {type(output_tokens)}, SECOND TYPE = {type(t2t_tokens)}")
+    # translated_data = model.generate(**t2t_tokens, tgt_lang=tgt_lang, generate_speech=False)[0].tolist()[0]
+    # translated_text = processor.decode(translated_data, skip_special_tokens=True)
+    translated_text = converter(asr_text, src_lang=src_lang, tgt_lang=tgt_lang)
     print(f"TRANSLATED TEXT = {translated_text}")
 
     # BO -> send translated_text to mongodb as caption record update based on call_id
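For reference, a minimal standalone sketch of the flow this commit switches to. The two pipeline() constructions and the call patterns (generate_kwargs={"tgt_lang": ...} for ASR, src_lang=/tgt_lang= for translation) are taken from the diff above; the input file, the torchaudio resampling step, and the language codes "eng"/"spa" are illustrative assumptions, not part of the commit.

# Sketch of the pipeline-based ASR -> text-translation flow added in this commit.
# Only the pipeline setup and call signatures mirror the diff; file path, sample
# rate handling, and language codes below are assumptions for illustration.
import torch
import torchaudio
from transformers import pipeline

translator = pipeline("automatic-speech-recognition",
                      "facebook/seamless-m4t-v2-large",
                      torch_dtype=torch.float32,
                      device="cpu")

converter = pipeline("translation",
                     "facebook/seamless-m4t-v2-large",
                     torch_dtype=torch.float32,
                     device="cpu")

# Load a local clip (hypothetical file) and resample to 16 kHz mono,
# the rate the SeamlessM4T feature extractor expects.
waveform, sr = torchaudio.load("sample.wav")
audio = torchaudio.functional.resample(waveform, sr, 16_000).mean(dim=0).numpy()

# Speech -> text in the source language, as in incoming_audio().
asr_text = translator(audio, generate_kwargs={"tgt_lang": "eng"})["text"]

# Text -> text translation; the translation pipeline returns a list of dicts.
translated = converter(asr_text, src_lang="eng", tgt_lang="spa")
print(asr_text, translated)

On a machine with a working GPU, device="cpu" could be replaced by the device selected earlier in the file; keeping everything on CPU here simply mirrors what the commit hardcodes.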