"""Model manager: loads generation pipelines and serves generation / museum requests."""
import concurrent.futures
import random
import gradio as gr
import requests
import io, base64, json, os
import spaces
from PIL import Image
from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS, MUSEUM_UNSUPPORTED_MODELS, DESIRED_APPEAR_MODEL, load_pipeline
from .fetch_museum_results import draw_from_imagen_museum, draw2_from_imagen_museum, draw_from_videogen_museum, draw2_from_videogen_museum
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


class ModelManager:
    """Cache model pipelines and expose single/pairwise generation entry points.

    Three model families are handled, each with its own list of model names:
    image generation (``ig``), image edition (``ie``) and video generation
    (``vg``).  For every family there are methods that run a pipeline on a
    user prompt and ``*_museum*`` methods that fetch stored results through
    the ``draw*_from_*_museum`` helpers instead of running a model.
    """

    def __init__(self, enable_nsfw=True):
        """Store the model lists and load the prompt guard.

        Args:
            enable_nsfw: When true, the Llama Guard model is loaded and
                ``NSFW_filter`` classifies prompts with it.  When false no
                guard is loaded and ``NSFW_filter`` always answers ``"safe"``.
        """
        self.model_ig_list = IMAGE_GENERATION_MODELS
        self.model_ie_list = IMAGE_EDITION_MODELS
        self.model_vg_list = VIDEO_GENERATION_MODELS
        # Models excluded whenever a random pair is drawn.
        self.excluding_model_list = MUSEUM_UNSUPPORTED_MODELS
        # NOTE: stored but not read by any active code path in this class.
        self.desired_model_list = DESIRED_APPEAR_MODEL
        self.enable_nsfw = enable_nsfw
        self.load_guard(enable_nsfw)
        # model name -> pipeline returned by load_pipeline()
        self.loaded_models = {}

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _guard_device():
        """Return ``"cuda"`` when CUDA is available, otherwise ``"cpu"``."""
        return "cuda" if torch.cuda.is_available() else "cpu"

    @staticmethod
    def _museum_key(model_name):
        """Return the second ``_``-separated field of ``model_name``.

        Raises:
            IndexError: If ``model_name`` contains no underscore.
        """
        return model_name.split('_')[1]

    def _pick_model_pair(self, candidates, model_A, model_B):
        """Return the two model names to use for a side-by-side request.

        When both ``model_A`` and ``model_B`` are empty strings, two distinct
        names are sampled at random from ``candidates`` minus
        ``self.excluding_model_list``.  Otherwise the given names are used
        unchanged.

        Raises:
            ValueError: If fewer than two candidates remain for sampling.
        """
        if model_A == "" and model_B == "":
            picking_list = [item for item in candidates if item not in self.excluding_model_list]
            return random.sample(picking_list, 2)
        return [model_A, model_B]

    @staticmethod
    def _run_concurrently(jobs):
        """Run ``(callable, args)`` jobs in a thread pool; return results in order.

        Exceptions raised by a job propagate from ``future.result()``.
        """
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(fn, *args) for fn, args in jobs]
            return [future.result() for future in futures]

    def _run_prompt_filtered(self, prompt, model_name):
        """Run ``model_name`` on ``prompt`` unless the guard flags it as unsafe.

        Returns the pipeline result, or ``''`` when the guard output contains
        ``'unsafe'``.
        """
        if 'unsafe' in self.NSFW_filter(prompt):
            return ''
        print('The prompt is safe')
        pipe = self.load_model_pipe(model_name)
        return pipe(prompt=prompt)

    def _run_prompt(self, prompt, model_name):
        """Run ``model_name`` on ``prompt`` without any guard check."""
        pipe = self.load_model_pipe(model_name)
        return pipe(prompt=prompt)

    def _generate_ig_pair(self, prompt, model_names):
        """Generate images for ``prompt`` with each model in ``model_names``.

        Names starting with ``"imagenhub"`` go through ``generate_image_ig``;
        every other name goes through ``generate_image_ig_api``.
        """
        jobs = [
            (self.generate_image_ig if model.startswith("imagenhub") else self.generate_image_ig_api,
             (prompt, model))
            for model in model_names
        ]
        return self._run_concurrently(jobs)

    def _generate_ie_pair(self, textbox_source, textbox_target, textbox_instruct, source_image, model_names):
        """Run ``generate_image_ie`` once per model in ``model_names``."""
        jobs = [
            (self.generate_image_ie,
             (textbox_source, textbox_target, textbox_instruct, source_image, model))
            for model in model_names
        ]
        return self._run_concurrently(jobs)

    def _generate_vg_pair(self, prompt, model_names):
        """Generate videos for ``prompt`` with each model in ``model_names``.

        Names starting with ``"videogenhub"`` go through ``generate_video_vg``;
        every other name goes through ``generate_video_vg_api``.
        """
        jobs = [
            (self.generate_video_vg if model.startswith("videogenhub") else self.generate_video_vg_api,
             (prompt, model))
            for model in model_names
        ]
        return self._run_concurrently(jobs)

    def _draw_museum_pair(self, draw_fn, task, model_A, model_B):
        """Call ``draw_fn(task, key_A, key_B)`` and return ``(links, prompts)``.

        The museum call is synchronous, so no thread pool is involved.
        """
        result_list = draw_fn(task, self._museum_key(model_A), self._museum_key(model_B))
        return result_list[0], result_list[1]

    # ------------------------------------------------------------------
    # Model / guard loading
    # ------------------------------------------------------------------
    def load_model_pipe(self, model_name):
        """Return the pipeline for ``model_name``, loading it on first use."""
        if model_name not in self.loaded_models:
            self.loaded_models[model_name] = load_pipeline(model_name)
        return self.loaded_models[model_name]

    def load_guard(self, enable_nsfw=True):
        """Load the Llama Guard tokenizer and model, or disable the guard.

        Args:
            enable_nsfw: When true, load ``meta-llama/Llama-Guard-3-8B`` in
                bfloat16 using the token in the ``HF_GUARD`` environment
                variable.  When false, set both guard attributes to ``None``.

        Raises:
            KeyError: If the guard is enabled and ``HF_GUARD`` is not set.
        """
        if not enable_nsfw:
            self.guard_tokenizer = None
            self.guard = None
            return

        model_id = "meta-llama/Llama-Guard-3-8B"
        device = self._guard_device()
        dtype = torch.bfloat16
        token = os.environ['HF_GUARD']
        self.guard_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        self.guard = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=dtype, device_map=device, token=token)

    def NSFW_filter(self, prompt):
        """Classify ``prompt`` with the guard model.

        Returns:
            The text the guard generated after the prompt, or the literal
            ``"safe"`` when no guard is loaded.  Callers test for the
            substring ``'unsafe'``.
        """
        # Check for a disabled guard *before* touching the tokenizer or the
        # model: both are None in that case.
        if self.guard is None or self.guard_tokenizer is None:
            return "safe"

        # Same device choice as load_guard, so CPU-only hosts work too.
        device = self._guard_device()
        chat = [{"role": "user", "content": prompt}]
        input_ids = self.guard_tokenizer.apply_chat_template(chat, return_tensors="pt").to(device)
        self.guard.to(device)

        @spaces.GPU(duration=30)
        def _generate():
            return self.guard.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)

        # Generate once; the tokens after the prompt are the guard's answer.
        output = _generate()
        prompt_len = input_ids.shape[-1]
        return self.guard_tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    # ------------------------------------------------------------------
    # Image generation
    # ------------------------------------------------------------------
    @spaces.GPU(duration=120)
    def generate_image_ig(self, prompt, model_name):
        """Generate an image for ``prompt``; returns ``''`` if the guard flags it."""
        return self._run_prompt_filtered(prompt, model_name)

    def generate_image_ig_api(self, prompt, model_name):
        """Same as ``generate_image_ig`` but without the ``spaces.GPU`` wrapper."""
        return self._run_prompt_filtered(prompt, model_name)

    def generate_image_ig_museum(self, model_name):
        """Fetch a stored text-to-image result; returns ``(image_link, prompt)``."""
        result_list = draw_from_imagen_museum("t2i", self._museum_key(model_name))
        return result_list[0], result_list[1]

    def generate_image_ig_parallel_anony(self, prompt, model_A, model_B):
        """Generate with two models, picking them at random if both names are ``""``.

        Returns ``(result_A, result_B, name_A, name_B)``.
        """
        model_names = self._pick_model_pair(self.model_ig_list, model_A, model_B)
        results = self._generate_ig_pair(prompt, model_names)
        return results[0], results[1], model_names[0], model_names[1]

    def generate_image_ig_museum_parallel_anony(self, model_A, model_B):
        """Fetch stored results for two (possibly random) image generation models.

        Returns ``(link_A, link_B, name_A, name_B, prompt)``.
        """
        model_names = self._pick_model_pair(self.model_ig_list, model_A, model_B)
        image_links, prompt_list = self._draw_museum_pair(
            draw2_from_imagen_museum, "t2i", model_names[0], model_names[1])
        return image_links[0], image_links[1], model_names[0], model_names[1], prompt_list[0]

    def generate_image_ig_parallel(self, prompt, model_A, model_B):
        """Generate with the two named models; returns ``(result_A, result_B)``."""
        results = self._generate_ig_pair(prompt, [model_A, model_B])
        return results[0], results[1]

    def generate_image_ig_museum_parallel(self, model_A, model_B):
        """Fetch stored results for two named models; returns ``(link_A, link_B, prompt)``."""
        image_links, prompt_list = self._draw_museum_pair(
            draw2_from_imagen_museum, "t2i", model_A, model_B)
        return image_links[0], image_links[1], prompt_list[0]

    # ------------------------------------------------------------------
    # Image edition
    # ------------------------------------------------------------------
    @spaces.GPU(duration=200)
    def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
        """Edit ``source_image`` with ``model_name`` and return the pipeline result.

        NOTE: the NSFW prompt check is currently not applied on this path.
        """
        pipe = self.load_model_pipe(model_name)
        return pipe(src_image=source_image,
                    src_prompt=textbox_source,
                    target_prompt=textbox_target,
                    instruct_prompt=textbox_instruct)

    def generate_image_ie_museum(self, model_name):
        """Fetch a stored image-edition result.

        Returns ``(source_image, edited_image, source_caption, target_caption,
        instruction)``.
        """
        result_list = draw_from_imagen_museum("tie", self._museum_key(model_name))
        image_links = result_list[0]   # [src, model]
        prompt_list = result_list[1]   # [source_caption, target_caption, instruction]
        return image_links[0], image_links[1], prompt_list[0], prompt_list[1], prompt_list[2]

    def generate_image_ie_parallel(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
        """Edit with the two named models; returns ``(result_A, result_B)``."""
        results = self._generate_ie_pair(
            textbox_source, textbox_target, textbox_instruct, source_image, [model_A, model_B])
        return results[0], results[1]

    def generate_image_ie_museum_parallel(self, model_A, model_B):
        """Fetch stored image-edition results for two named models.

        Returns ``(source_image, image_A, image_B, source_caption,
        target_caption, instruction)``.
        """
        image_links, prompt_list = self._draw_museum_pair(
            draw2_from_imagen_museum, "tie", model_A, model_B)
        # image_links = [src, model_A, model_B]
        # prompt_list = [source_caption, target_caption, instruction]
        return image_links[0], image_links[1], image_links[2], prompt_list[0], prompt_list[1], prompt_list[2]

    def generate_image_ie_parallel_anony(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
        """Edit with two models, picking them at random if both names are ``""``.

        Returns ``(result_A, result_B, name_A, name_B)``.
        """
        model_names = self._pick_model_pair(self.model_ie_list, model_A, model_B)
        results = self._generate_ie_pair(
            textbox_source, textbox_target, textbox_instruct, source_image, model_names)
        return results[0], results[1], model_names[0], model_names[1]

    def generate_image_ie_museum_parallel_anony(self, model_A, model_B):
        """Fetch stored image-edition results for two (possibly random) models.

        Returns ``(source_image, image_A, image_B, source_caption,
        target_caption, instruction, name_A, name_B)``.
        """
        model_names = self._pick_model_pair(self.model_ie_list, model_A, model_B)
        image_links, prompt_list = self._draw_museum_pair(
            draw2_from_imagen_museum, "tie", model_names[0], model_names[1])
        # image_links = [src, model_A, model_B]
        # prompt_list = [source_caption, target_caption, instruction]
        return (image_links[0], image_links[1], image_links[2],
                prompt_list[0], prompt_list[1], prompt_list[2],
                model_names[0], model_names[1])

    # ------------------------------------------------------------------
    # Video generation
    # ------------------------------------------------------------------
    @spaces.GPU(duration=150)
    def generate_video_vg(self, prompt, model_name):
        """Generate a video for ``prompt`` with ``model_name``.

        NOTE: the NSFW prompt check is currently not applied on this path.
        """
        return self._run_prompt(prompt, model_name)

    def generate_video_vg_api(self, prompt, model_name):
        """Same as ``generate_video_vg`` but without the ``spaces.GPU`` wrapper.

        NOTE: the NSFW prompt check is currently not applied on this path.
        """
        return self._run_prompt(prompt, model_name)

    def generate_video_vg_museum(self, model_name):
        """Fetch a stored text-to-video result; returns ``(video_link, prompt)``."""
        result_list = draw_from_videogen_museum("t2v", self._museum_key(model_name))
        return result_list[0], result_list[1]

    def generate_video_vg_parallel_anony(self, prompt, model_A, model_B):
        """Generate with two models, picking them at random if both names are ``""``.

        Returns ``(result_A, result_B, name_A, name_B)``.
        """
        model_names = self._pick_model_pair(self.model_vg_list, model_A, model_B)
        results = self._generate_vg_pair(prompt, model_names)
        return results[0], results[1], model_names[0], model_names[1]

    def generate_video_vg_museum_parallel_anony(self, model_A, model_B):
        """Fetch stored results for two (possibly random) video generation models.

        Returns ``(link_A, link_B, name_A, name_B, prompt)``.

        NOTE: an override that forced one slot to come from
        ``self.desired_model_list`` used to live here but is disabled.
        """
        model_names = self._pick_model_pair(self.model_vg_list, model_A, model_B)
        video_links, prompt_list = self._draw_museum_pair(
            draw2_from_videogen_museum, "t2v", model_names[0], model_names[1])
        return video_links[0], video_links[1], model_names[0], model_names[1], prompt_list[0]

    def generate_video_vg_parallel(self, prompt, model_A, model_B):
        """Generate with the two named models; returns ``(result_A, result_B)``."""
        results = self._generate_vg_pair(prompt, [model_A, model_B])
        return results[0], results[1]

    def generate_video_vg_museum_parallel(self, model_A, model_B):
        """Fetch stored results for two named models; returns ``(link_A, link_B, prompt)``."""
        video_links, prompt_list = self._draw_museum_pair(
            draw2_from_videogen_museum, "t2v", model_A, model_B)
        return video_links[0], video_links[1], prompt_list[0]