Jayem-11 committed on
Commit
07f39f7
1 Parent(s): 15c02f3
Files changed (2) hide show
  1. main.py +80 -4
  2. requirements.txt +7 -1
main.py CHANGED
@@ -1,7 +1,83 @@
1
- from fastapi import FastAPI
 
 
 
 
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
5
- @app.get("/hello")
6
- def hello():
7
- return {"Hello": "World"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile
2
+ from moviepy.editor import *
3
+ from transformers import AutoTokenizer , AutoModelForSeq2SeqLM , pipeline
4
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
5
+ from transformers import WhisperFeatureExtractor, WhisperTokenizer
6
+ import librosa
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
 
12
app = FastAPI()

# Run inference on GPU when available; every tensor handed to the model at
# request time must live on this same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Whisper components: generic feature extractor / tokenizer from the base
# checkpoint, plus the fine-tuned Swahili ASR model.
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Swahili", task="transcribe")
processor = WhisperProcessor.from_pretrained("Jayem-11/whisper-small-swahili-3")
asr_model = WhisperForConditionalGeneration.from_pretrained('Jayem-11/whisper-small-swahili-3')
# BUG FIX: the model was never moved to `device`, while inputs ARE moved
# there in /predict — on a CUDA host generate() would fail with a
# device-mismatch error. (Also removed a bare no-op `device` expression.)
asr_model.to(device)
forced_decoder_ids = processor.get_decoder_prompt_ids(language="sw", task="transcribe")

# mT5 tokenizer plus the fine-tuned Swahili summarization model.
t5_tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
summary_model = AutoModelForSeq2SeqLM.from_pretrained("Jayem-11/mt5-summarize-sw")
27
@app.get("/")
async def read_root():
    """Health-check endpoint confirming the service is up.

    Returns a small JSON object.  BUG FIX: the original returned the SET
    literal ``{"Successful"}``, which FastAPI serializes as the JSON array
    ``["Successful"]`` — a JSON object was clearly intended.
    """
    return {"status": "Successful"}
30
+
31
+
32
+
33
def extract_and_resample_audio(file):
    """Extract the audio track from an uploaded MP4 and resample it to 16 kHz.

    Parameters
    ----------
    file : bytes
        Raw video-file contents as received from the client.

    Returns
    -------
    numpy.ndarray
        Mono audio signal resampled to 16000 Hz (the rate Whisper expects).
    """
    import os  # local import: only needed for scratch-file cleanup

    # NOTE(review): fixed filenames mean concurrent requests clobber each
    # other's scratch files — consider tempfile.NamedTemporaryFile.
    with open('vid.mp4', 'wb') as f:
        f.write(file)

    video = VideoFileClip("vid.mp4")
    try:
        # Write the audio track to a temporary WAV so librosa can load it.
        video.audio.write_audiofile("temp_audio.wav")
    finally:
        # BUG FIX: the clip was never closed, leaking file handles and
        # ffmpeg subprocesses on every request.
        video.close()

    # Load then resample to the 16 kHz rate Whisper's extractor expects.
    audio_data, sr = librosa.load("temp_audio.wav")
    audio_resampled = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
    print("Done resampling")

    # Best-effort cleanup of the scratch files.
    for scratch in ("vid.mp4", "temp_audio.wav"):
        try:
            os.remove(scratch)
        except OSError:
            pass

    return audio_resampled
55
+
56
@app.post("/predict")
async def predict(file: UploadFile):
    """Transcribe the audio of an uploaded Swahili video and summarize it.

    Parameters
    ----------
    file : UploadFile
        An uploaded video file (MP4) whose audio track will be processed.

    Returns
    -------
    dict
        ``{'summary': <summarization-pipeline output>}``.
    """
    audio_resampled = extract_and_resample_audio(await file.read())

    # Log-mel features for Whisper; add a leading batch dimension.
    input_feats = feature_extractor(audio_resampled, sampling_rate=16000).input_features[0]
    input_feats = torch.from_numpy(np.expand_dims(input_feats, axis=0))

    # BUG FIX: `forced_decoder_ids` (language="sw", task="transcribe") was
    # computed at module level but never passed to generate(), so the
    # language/task prompt was not enforced.  no_grad() avoids building an
    # autograd graph during inference.
    with torch.no_grad():
        output = asr_model.generate(
            input_features=input_feats.to(device),
            forced_decoder_ids=forced_decoder_ids,
            max_new_tokens=255,
        ).cpu().numpy()

    sample_text = tokenizer.batch_decode(output, skip_special_tokens=True)

    # NOTE(review): constructing the pipeline on every request is wasteful;
    # it could be built once at module import.
    summarizer = pipeline("summarization", model=summary_model, tokenizer=t5_tokenizer)
    summary = summarizer(
        sample_text,
        max_length=215,
    )

    return {'summary': summary}
requirements.txt CHANGED
@@ -1,2 +1,8 @@
1
  fastapi
2
- uvicorn
 
 
 
 
 
 
 
1
  fastapi
2
+ uvicorn
3
+ transformers
4
+ moviepy
5
+ librosa
6
+ numpy
7
+ torch
8
+ python-multipart