Ngadou committed on
Commit
b0090a7
1 Parent(s): 17c3241

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -5
app.py CHANGED
@@ -3,20 +3,32 @@ import time
3
  import openai
4
  import json
5
  import os
6
- from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
 
7
 
8
- asr_pipeline = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-large-robust-ft-libri-960h")
 
 
 
9
 
10
  openai.api_key = os.environ.get('OPENAI_KEY')
11
 
12
  def classify_audio(audio):
13
  # Transcribe the audio to text
14
- audio_transcript = asr_pipeline(audio)["text"]
15
- audio_transcript = audio_transcript.lower()
 
 
 
 
 
 
 
 
16
 
17
  messages = [
18
  {"role": "system", "content": "Is this chat a scam, spam or is safe? Only answer in JSON format with 'classification': '' as string and 'reasons': '' as the most plausible reasons why. The reason should be explaning to the potential victim why the conversation is probably a scam"},
19
- {"role": "user", "content": audio_transcript},
20
  ]
21
 
22
  # Call the OpenAI API to generate a response
 
3
  import openai
4
  import json
5
  import os
6
+ from transformers import pipeline
7
+ from transformers import AutoProcessor, AutoModelForCTC
8
 
9
+ processor = AutoProcessor.from_pretrained("facebook/wav2vec2-large-robust-ft-libri-960h")
10
+ model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-large-robust-ft-libri-960h")
11
+
12
+ # asr_pipeline = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-large-robust-ft-libri-960h")
13
 
14
  openai.api_key = os.environ.get('OPENAI_KEY')
15
 
16
  def classify_audio(audio):
17
  # Transcribe the audio to text
18
+ # audio_transcript = asr_pipeline(audio)["text"]
19
+ # audio_transcript = audio_transcript.lower()
20
+
21
+ input_values = processor(audio, return_tensors="pt", padding="longest").input_values
22
+ # retrieve logits
23
+ logits = model(input_values).logits
24
+
25
+ # take argmax and decode
26
+ predicted_ids = torch.argmax(logits, dim=-1)
27
+ transcription = processor.batch_decode(predicted_ids)
28
 
29
  messages = [
30
  {"role": "system", "content": "Is this chat a scam, spam or is safe? Only answer in JSON format with 'classification': '' as string and 'reasons': '' as the most plausible reasons why. The reason should be explaning to the potential victim why the conversation is probably a scam"},
31
+ {"role": "user", "content": transcription},
32
  ]
33
 
34
  # Call the OpenAI API to generate a response