khulnasoft committed on
Commit 26ec8ac
1 Parent(s): 1092afe

Update models_server.py

Files changed (1)
models_server.py +101 -254
models_server.py CHANGED
@@ -1,258 +1,105 @@
-    start = time.time()
-
-    pipe = pipes[model_id]["model"]
-
-    if "device" in pipes[model_id]:
-        try:
-            pipe.to(pipes[model_id]["device"])
-        except:
-            pipe.device = torch.device(pipes[model_id]["device"])
-            pipe.model.to(pipes[model_id]["device"])
-
-    result = None
-    try:
-        # text to video
-        if model_id == "damo-vilab/text-to-video-ms-1.7b":
-            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
-            # pipe.enable_model_cpu_offload()
-            prompt = data["text"]
-            video_frames = pipe(prompt, num_inference_steps=50, num_frames=40).frames
-            file_name = str(uuid.uuid4())[:4]
-            video_path = export_to_video(video_frames, f"public/videos/{file_name}.mp4")
-
-            new_file_name = str(uuid.uuid4())[:4]
-            os.system(f"ffmpeg -i {video_path} -vcodec libx264 public/videos/{new_file_name}.mp4")
-
-            if os.path.exists(f"public/videos/{new_file_name}.mp4"):
-                result = {"path": f"/videos/{new_file_name}.mp4"}
-            else:
-                result = {"path": f"/videos/{file_name}.mp4"}
-
-        # controlnet
-        if model_id.startswith("lllyasviel/sd-controlnet-"):
-            pipe.controlnet.to('cpu')
-            pipe.controlnet = pipes[model_id]["control"].to(pipes[model_id]["device"])
-            pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
-            control_image = load_image(data["img_url"])
-            # generator = torch.manual_seed(66)
-            out_image: Image = pipe(data["text"], num_inference_steps=20, image=control_image).images[0]
-            file_name = str(uuid.uuid4())[:4]
-            out_image.save(f"public/images/{file_name}.png")
-            result = {"path": f"/images/{file_name}.png"}
-
-        if model_id.endswith("-control"):
-            image = load_image(data["img_url"])
-            if "scribble" in model_id:
-                control = pipe(image, scribble = True)
-            elif "canny" in model_id:
-                control = pipe(image, low_threshold=100, high_threshold=200)
-            else:
-                control = pipe(image)
-            file_name = str(uuid.uuid4())[:4]
-            control.save(f"public/images/{file_name}.png")
-            result = {"path": f"/images/{file_name}.png"}
-
-        # image to image
-        if model_id == "lambdalabs/sd-image-variations-diffusers":
-            im = load_image(data["img_url"])
-            file_name = str(uuid.uuid4())[:4]
-            with open(f"public/images/{file_name}.png", "wb") as f:
-                f.write(data)
-            tform = transforms.Compose([
-                transforms.ToTensor(),
-                transforms.Resize(
-                    (224, 224),
-                    interpolation=transforms.InterpolationMode.BICUBIC,
-                    antialias=False,
-                ),
-                transforms.Normalize(
-                    [0.48145466, 0.4578275, 0.40821073],
-                    [0.26862954, 0.26130258, 0.27577711]),
-            ])
-            inp = tform(im).to(pipes[model_id]["device"]).unsqueeze(0)
-            out = pipe(inp, guidance_scale=3)
-            out["images"][0].save(f"public/images/{file_name}.jpg")
-            result = {"path": f"/images/{file_name}.jpg"}
-
-        # image to text
-        if model_id == "Salesforce/blip-image-captioning-large":
-            raw_image = load_image(data["img_url"]).convert('RGB')
-            text = data["text"]
-            inputs = pipes[model_id]["processor"](raw_image, return_tensors="pt").to(pipes[model_id]["device"])
-            out = pipe.generate(**inputs)
-            caption = pipes[model_id]["processor"].decode(out[0], skip_special_tokens=True)
-            result = {"generated text": caption}
-        if model_id == "ydshieh/vit-gpt2-coco-en":
-            img_url = data["img_url"]
-            generated_text = pipe(img_url)[0]['generated_text']
-            result = {"generated text": generated_text}
-        if model_id == "nlpconnect/vit-gpt2-image-captioning":
-            image = load_image(data["img_url"]).convert("RGB")
-            pixel_values = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt").pixel_values
-            pixel_values = pixel_values.to(pipes[model_id]["device"])
-            generated_ids = pipe.generate(pixel_values, **{"max_length": 200, "num_beams": 1})
-            generated_text = pipes[model_id]["tokenizer"].batch_decode(generated_ids, skip_special_tokens=True)[0]
-            result = {"generated text": generated_text}
-        # image to text: OCR
-        if model_id == "microsoft/trocr-base-printed" or model_id == "microsoft/trocr-base-handwritten":
-            image = load_image(data["img_url"]).convert("RGB")
-            pixel_values = pipes[model_id]["processor"](image, return_tensors="pt").pixel_values
-            pixel_values = pixel_values.to(pipes[model_id]["device"])
-            generated_ids = pipe.generate(pixel_values)
-            generated_text = pipes[model_id]["processor"].batch_decode(generated_ids, skip_special_tokens=True)[0]
-            result = {"generated text": generated_text}
-
-        # text to image
-        if model_id == "runwayml/stable-diffusion-v1-5":
-            file_name = str(uuid.uuid4())[:4]
-            text = data["text"]
-            out = pipe(prompt=text)
-            out["images"][0].save(f"public/images/{file_name}.jpg")
-            result = {"path": f"/images/{file_name}.jpg"}
-
-        # object detection
-        if model_id == "google/owlvit-base-patch32" or model_id == "facebook/detr-resnet-101":
-            img_url = data["img_url"]
-            open_types = ["cat", "couch", "person", "car", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird"]
-            result = pipe(img_url, candidate_labels=open_types)
-
-        # VQA
-        if model_id == "dandelin/vilt-b32-finetuned-vqa":
-            question = data["text"]
-            img_url = data["img_url"]
-            result = pipe(question=question, image=img_url)
-
-        # DQA
-        if model_id == "impira/layoutlm-document-qa":
-            question = data["text"]
-            img_url = data["img_url"]
-            result = pipe(img_url, question)
-
-        # depth-estimation
-        if model_id == "Intel/dpt-large":
-            output = pipe(data["img_url"])
-            image = output['depth']
-            name = str(uuid.uuid4())[:4]
-            image.save(f"public/images/{name}.jpg")
-            result = {"path": f"/images/{name}.jpg"}
-
-        if model_id == "Intel/dpt-hybrid-midas" and model_id == "Intel/dpt-large":
-            image = load_image(data["img_url"])
-            inputs = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt")
-            with torch.no_grad():
-                outputs = pipe(**inputs)
-                predicted_depth = outputs.predicted_depth
-            prediction = torch.nn.functional.interpolate(
-                predicted_depth.unsqueeze(1),
-                size=image.size[::-1],
-                mode="bicubic",
-                align_corners=False,
-            )
-            output = prediction.squeeze().cpu().numpy()
-            formatted = (output * 255 / np.max(output)).astype("uint8")
-            image = Image.fromarray(formatted)
-            name = str(uuid.uuid4())[:4]
-            image.save(f"public/images/{name}.jpg")
-            result = {"path": f"/images/{name}.jpg"}
-
-        # TTS
-        if model_id == "espnet/kan-bayashi_ljspeech_vits":
-            text = data["text"]
-            wav = pipe(text)["wav"]
-            name = str(uuid.uuid4())[:4]
-            sf.write(f"public/audios/{name}.wav", wav.cpu().numpy(), pipe.fs, "PCM_16")
-            result = {"path": f"/audios/{name}.wav"}
-
-        if model_id == "microsoft/speecht5_tts":
-            text = data["text"]
-            inputs = pipes[model_id]["processor"](text=text, return_tensors="pt")
-            embeddings_dataset = pipes[model_id]["embeddings_dataset"]
-            speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(pipes[model_id]["device"])
-            pipes[model_id]["vocoder"].to(pipes[model_id]["device"])
-            speech = pipe.generate_speech(inputs["input_ids"].to(pipes[model_id]["device"]), speaker_embeddings, vocoder=pipes[model_id]["vocoder"])
-            name = str(uuid.uuid4())[:4]
-            sf.write(f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000)
-            result = {"path": f"/audios/{name}.wav"}
-
-        # ASR
-        if model_id == "openai/whisper-base" or model_id == "microsoft/speecht5_asr":
-            audio_url = data["audio_url"]
-            result = { "text": pipe(audio_url)["text"]}
-
-        # audio to audio
-        if model_id == "JorisCos/DCCRNet_Libri1Mix_enhsingle_16k":
-            audio_url = data["audio_url"]
-            wav, sr = torchaudio.load(audio_url)
-            with torch.no_grad():
-                result_wav = pipe(wav.to(pipes[model_id]["device"]))
-            name = str(uuid.uuid4())[:4]
-            sf.write(f"public/audios/{name}.wav", result_wav.cpu().squeeze().numpy(), sr)
-            result = {"path": f"/audios/{name}.wav"}
-
-        if model_id == "microsoft/speecht5_vc":
-            audio_url = data["audio_url"]
-            wav, sr = torchaudio.load(audio_url)
-            inputs = pipes[model_id]["processor"](audio=wav, sampling_rate=sr, return_tensors="pt")
-            embeddings_dataset = pipes[model_id]["embeddings_dataset"]
-            speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
-            pipes[model_id]["vocoder"].to(pipes[model_id]["device"])
-            speech = pipe.generate_speech(inputs["input_ids"].to(pipes[model_id]["device"]), speaker_embeddings, vocoder=pipes[model_id]["vocoder"])
-            name = str(uuid.uuid4())[:4]
-            sf.write(f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000)
-            result = {"path": f"/audios/{name}.wav"}
-
-        # segmentation
-        if model_id == "facebook/detr-resnet-50-panoptic":
-            result = []
-            segments = pipe(data["img_url"])
-            image = load_image(data["img_url"])
-
-            colors = []
-            for i in range(len(segments)):
-                colors.append((random.randint(100, 255), random.randint(100, 255), random.randint(100, 255), 50))
-
-            for segment in segments:
-                mask = segment["mask"]
-                mask = mask.convert('L')
-                layer = Image.new('RGBA', mask.size, colors[i])
-                image.paste(layer, (0, 0), mask)
-            name = str(uuid.uuid4())[:4]
-            image.save(f"public/images/{name}.jpg")
-            result = {"path": f"/images/{name}.jpg"}
-
-        if model_id == "facebook/maskformer-swin-base-coco" or model_id == "facebook/maskformer-swin-large-ade":
-            image = load_image(data["img_url"])
-            inputs = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt").to(pipes[model_id]["device"])
-            outputs = pipe(**inputs)
-            result = pipes[model_id]["feature_extractor"].post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
-            predicted_panoptic_map = result["segmentation"].cpu().numpy()
-            predicted_panoptic_map = Image.fromarray(predicted_panoptic_map.astype(np.uint8))
-            name = str(uuid.uuid4())[:4]
-            predicted_panoptic_map.save(f"public/images/{name}.jpg")
-            result = {"path": f"/images/{name}.jpg"}
-
-    except Exception as e:
-        print(e)
-        traceback.print_exc()
-        result = {"error": {"message": "Error when running the model inference."}}
-
-    if "device" in pipes[model_id]:
-        try:
-            pipe.to("cpu")
-            torch.cuda.empty_cache()
-        except:
-            pipe.device = torch.device("cpu")
-            pipe.model.to("cpu")
-            torch.cuda.empty_cache()
-
-    pipes[model_id]["using"] = False
-
-    if result is None:
-        result = {"error": {"message": "model not found"}}
-
-    end = time.time()
-    during = end - start
-    print(f"[ complete {model_id} ] {during}s")
-    print(f"[ result {model_id} ] {result}")
-
-    return result
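Note: every branch in the removed handler follows the same lifecycle: move the pipe to its assigned device, run inference, then push everything back to CPU and clear the CUDA cache so other models can share the GPU. A minimal sketch of that acquire/release pattern (pipes and model_id are from the code above; run_model and run_inference are hypothetical names, not part of this file):

    import torch

    def run_model(pipes, model_id, run_inference):
        # Move the pipe to its assigned device before inference.
        entry = pipes[model_id]
        pipe = entry["model"]
        if "device" in entry:
            pipe.to(entry["device"])  # some pipes need pipe.model.to(...) instead, hence the try/except above
        try:
            return run_inference(pipe)
        finally:
            # Mirror the cleanup block above: always return the pipe to CPU
            # and release GPU memory, even if inference raised.
            if "device" in entry:
                pipe.to("cpu")
                torch.cuda.empty_cache()
            entry["using"] = False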
 
+import argparse
+import logging
+import random
+import uuid
+import numpy as np
+from transformers import pipeline
+from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
+from diffusers.utils import load_image, export_to_video
+from transformers import (
+    SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5ForSpeechToSpeech,
+    BlipProcessor, BlipForConditionalGeneration, TrOCRProcessor, VisionEncoderDecoderModel,
+    ViTImageProcessor, AutoTokenizer, AutoImageProcessor, TimesformerForVideoClassification,
+    MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation, DPTForDepthEstimation, DPTFeatureExtractor
+)
+from datasets import load_dataset
+from PIL import Image
+from torchvision import transforms
+import torch
+import torchaudio
+from speechbrain.pretrained import WaveformEnhancement
+import joblib
+from huggingface_hub import hf_hub_url, cached_download
+from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector, CannyDetector, MidasDetector
+import warnings
+import time
+from espnet2.bin.tts_inference import Text2Speech
+import soundfile as sf
+from asteroid.models import BaseModel
+import traceback
+import os
+import yaml
+
+warnings.filterwarnings("ignore")
+
+def setup_logger():
+    logger = logging.getLogger(__name__)
+    logger.setLevel(logging.INFO)
+    handler = logging.StreamHandler()
+    handler.setLevel(logging.INFO)
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+    return logger
+
+logger = setup_logger()
+
+def load_config(config_path):
+    with open(config_path, "r") as file:
+        return yaml.load(file, Loader=yaml.FullLoader)
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", type=str, default="config.yaml")
+    return parser.parse_args()
+
+args = parse_args()
+
+# Ensure the config is always set when not running as the main script
+if __name__ != "__main__":
+    args.config = "config.gradio.yaml"
+
+config = load_config(args.config)
+
+local_deployment = config["local_deployment"]
+if config["inference_mode"] == "huggingface":
+    local_deployment = "none"
+
+PROXY = {"https": config["proxy"]} if config["proxy"] else None
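For reference, the startup code above reads exactly three keys from the YAML file. A minimal sketch of what load_config is expected to return (the key names are taken from the lines above; the values are illustrative assumptions, not the project's actual defaults):

    # Hypothetical config.yaml contents, shown as the dict load_config() returns.
    example_config = {
        "inference_mode": "local",    # "huggingface" forces local_deployment to "none"
        "local_deployment": "full",   # selects which pipes load_pipes() builds
        "proxy": None,                # e.g. "http://localhost:7890"; a falsy value leaves PROXY as None
    }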
+
+start = time.time()
+
+local_models = ""  # Changed to empty string
+
+def load_pipes(local_deployment):
+    standard_pipes = {}
+    other_pipes = {}
+    controlnet_sd_pipes = {}
+
+    if local_deployment in ["full"]:
+        other_pipes = {
+            "damo-vilab/text-to-video-ms-1.7b": {
+                "model": DiffusionPipeline.from_pretrained(f"{local_models}damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"),
+                "device": "cuda:0"
+            },
+            "JorisCos/DCCRNet_Libri1Mix_enhsingle_16k": {
+                "model": BaseModel.from_pretrained("JorisCos/DCCRNet_Libri1Mix_enhsingle_16k"),
+                "device": "cuda:0"
+            },
+            "microsoft/speecht5_vc": {
+                "processor": SpeechT5Processor.from_pretrained(f"{local_models}microsoft/speecht5_vc"),
+                "model": SpeechT5ForSpeechToSpeech.from_pretrained(f"{local_models}microsoft/speecht5_vc"),
+                "vocoder": SpeechT5HifiGan.from_pretrained(f"{local_models}microsoft/speecht5_hifigan"),
+                "embeddings_dataset": load_dataset(f"{local_models}Matthijs/cmu-arctic-xvectors", split="validation"),
+                "device": "cuda:0"
+            },
+            "facebook/maskformer-swin-base-coco": {
+                "feature_extractor": MaskFormerFeatureExtractor.from_pretrained(f"{local_models}facebook/maskformer-swin-base-coco"),
+                "model": MaskFormerForInstanceSegmentation.from_pretrained(f"{local_models}facebook/maskformer-swin-base-coco"),
+                "device": "cuda:0"
+            },
+            "Intel/dpt-hybrid-midas": {
+                "model": DPTForDepthEstimation.from_pretrained(f"{local_models}Intel/dpt-hybrid-midas", low_cpu_mem_usage=True),
+                "feature_extractor": DPTFeatureExtractor.from_pretrained(f"{local_models}Intel/dpt-hybrid-midas"),
+                "device": "cuda:0"
+            }