animikhaich committed d50bd1e
1 Parent(s): 5978ae3

Incomplete Update
engine/audio_generator.py CHANGED
@@ -1 +1,162 @@
- # TODO: Add from model server
+ import os
+ import warnings
+ 
+ warnings.simplefilter("ignore")
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+ 
+ import argparse
+ import io
+ 
+ import torch
+ import numpy as np
+ import psutil
+ import uvicorn
+ from audiocraft.models import musicgen
+ from scipy.io.wavfile import write as wav_write
+ from torch.cuda import memory_allocated, memory_reserved
+ from fastapi import FastAPI, HTTPException
+ from fastapi.responses import JSONResponse, StreamingResponse
+ from pydantic import BaseModel
+ from typing import List, Optional
+ 
+ try:
+     from logger import logging
+ except ImportError:
+     import logging
+ 
+ 
+ class GenerateAudio:
+     def __init__(self, model="musicgen-stereo-small"):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model_name = self.get_model_name(model)
+         self.model = self.get_model(self.model_name, self.device)
+ 
+     @staticmethod
+     def get_model(model, device):
+         # Load the pretrained MusicGen model onto the requested device
+         try:
+             model = musicgen.MusicGen.get_pretrained(model, device=device)
+             logging.info(f"Loaded model: {model}")
+             return model
+         except Exception as e:
+             logging.error(f"Failed to load model: {e}")
+             raise ValueError(f"Failed to load model: {e}")
+ 
+     @staticmethod
+     def get_model_name(model_name):
+         # Prepend the "facebook/" namespace if it is missing
+         if model_name.startswith("facebook/"):
+             return model_name
+         return f"facebook/{model_name}"
+ 
+     def generate_audio(self, prompts, duration=30):
+         try:
+             self.model.set_generation_params(duration=duration)
+             result = self.model.generate(prompts, progress=False)
+             result = result.squeeze().cpu().numpy().T
+             sample_rate = self.model.sample_rate
+             logging.info(
+                 f"Generated audio with shape: {result.shape}, sample rate: {sample_rate} Hz"
+             )
+             return sample_rate, result
+         except Exception as e:
+             logging.error(f"Failed to generate audio: {e}")
+             raise ValueError(f"Failed to generate audio: {e}")
+ 
+ 
+ # Parse command line arguments
+ parser = argparse.ArgumentParser(description="Music Generation Server")
+ parser.add_argument(
+     "--model", type=str, default="musicgen-stereo-small", help="Pretrained model name"
+ )
+ parser.add_argument(
+     "--device", type=str, default="cuda", help="Device to load the model on"
+ )
+ parser.add_argument(
+     "--duration", type=int, default=10, help="Duration of generated music in seconds"
+ )
+ parser.add_argument(
+     "--host", type=str, default="0.0.0.0", help="Host to run the server on"
+ )
+ parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
+ 
+ args = parser.parse_args()
+ 
+ 
+ # Initialize the FastAPI app
+ app = FastAPI()
+ 
+ # Build the model name based on the provided arguments
+ if args.model.startswith("facebook/"):
+     args.model_name = args.model
+ else:
+     args.model_name = f"facebook/{args.model}"
+ 
+ 
+ logging.info(f"Initializing Model Server with Settings: {args}")
+ 
+ # Load the model with the provided arguments
+ try:
+     musicgen_model = musicgen.MusicGen.get_pretrained(
+         args.model_name, device=args.device
+     )
+     model_loaded = True
+     logging.info(f"Model Loaded: {args.model_name}")
+ except Exception as e:
+     logging.error(f"Failed to load model: {e}")
+     musicgen_model = None
+     model_loaded = False
+ 
+ 
+ class MusicRequest(BaseModel):
+     prompts: List[str]
+     duration: Optional[int] = 10  # Default duration is 10 seconds if not provided
+ 
+ 
+ @app.get("/generate_music")
+ def generate_music(request: MusicRequest):
+     # Expects a JSON body matching MusicRequest; streams back a WAV file
+     if not model_loaded:
+         raise HTTPException(status_code=500, detail="Model is not loaded.")
+ 
+     try:
+         logging.info(
+             f"Generating music with prompts: {request.prompts}, duration: {request.duration} seconds"
+         )
+ 
+         musicgen_model.set_generation_params(duration=request.duration)
+         result = musicgen_model.generate(request.prompts, progress=False)
+         result = result.squeeze().cpu().numpy().T
+ 
+         sample_rate = musicgen_model.sample_rate
+ 
+         logging.info(
+             f"Music generated with shape: {result.shape}, sample rate: {sample_rate} Hz"
+         )
+ 
+         buffer = io.BytesIO()
+         wav_write(buffer, sample_rate, result)
+         buffer.seek(0)
+         return StreamingResponse(buffer, media_type="audio/wav")
+     except Exception as e:
+         logging.error(f"Failed to generate music: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+ 
+ 
+ @app.get("/health")
+ def health_check():
+     # Report basic host and model status for monitoring
+     cpu_usage = psutil.cpu_percent(interval=1)
+     ram_usage = psutil.virtual_memory().percent
+     stats = {
+         "server_running": True,
+         "model_loaded": model_loaded,
+         "cpu_usage_percent": cpu_usage,
+         "ram_usage_percent": ram_usage,
+     }
+     if args.device == "cuda" and torch.cuda.is_available():
+         gpu_memory_allocated = memory_allocated()
+         gpu_memory_reserved = memory_reserved()
+         stats.update(
+             {
+                 "gpu_memory_allocated": gpu_memory_allocated,
+                 "gpu_memory_reserved": gpu_memory_reserved,
+             }
+         )
+ 
+     logging.info(f"Health Check: {stats}")
+ 
+     return JSONResponse(content=stats)
+ 
+ 
+ if __name__ == "__main__":
+     uvicorn.run(app, host=args.host, port=args.port, reload=False, workers=1)
engine/video_descriptor.py CHANGED
@@ -1,8 +1,7 @@
+ import os
  from warnings import simplefilter
  
  simplefilter("ignore")
- import os
- 
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
  import json
  import time
@@ -78,6 +77,9 @@ class DescribeVideo:
  
          return json.loads(cleaned_response.text.strip("```json\n"))
  
+     def __call__(self, video_path):
+         return self.describe_video(video_path)
+ 
      def reset_safety_settings(self):
          logging.info("Resetting safety settings")
          self.is_safety_set = False
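The added __call__ simply delegates to describe_video, so a DescribeVideo instance can be invoked like a function. A small usage sketch (the constructor arguments and video path are assumptions, not shown in this diff):

# Illustrative only: constructor defaults and the file path are assumed.
from engine.video_descriptor import DescribeVideo

descriptor = DescribeVideo()                # hypothetical default construction
description = descriptor("clips/demo.mp4")  # equivalent to descriptor.describe_video("clips/demo.mp4")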