DongfuJiang commited on
Commit
147b3b6
·
2 Parent(s): c5c1fe7 82a954e

Merge branch 'main' of https://huggingface.co/spaces/TIGER-Lab/GenAI-Arena

Browse files
model/model_manager.py CHANGED
@@ -18,6 +18,7 @@ class ModelManager:
18
  self.model_vg_list = VIDEO_GENERATION_MODELS
19
  self.excluding_model_list = MUSEUM_UNSUPPORTED_MODELS
20
  self.desired_model_list = DESIRED_APPEAR_MODEL
 
21
  self.loaded_models = {}
22
 
23
  def load_model_pipe(self, model_name):
@@ -28,23 +29,27 @@ class ModelManager:
28
  pipe = self.loaded_models[model_name]
29
  return pipe
30
 
31
- @spaces.GPU(duration=20)
32
- def NSFW_filter(self, prompt):
33
  model_id = "meta-llama/Meta-Llama-Guard-2-8B"
34
- device = "cuda"
35
  dtype = torch.bfloat16
36
- tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_GUARD'])
37
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=device, token=os.environ['HF_GUARD'])
 
 
 
38
  chat = [{"role": "user", "content": prompt}]
39
- input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(device)
40
- output = model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
 
41
  prompt_len = input_ids.shape[-1]
42
- result = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
43
  return result
44
 
45
  @spaces.GPU(duration=120)
46
  def generate_image_ig(self, prompt, model_name):
47
  if self.NSFW_filter(prompt) == 'safe':
 
48
  pipe = self.load_model_pipe(model_name)
49
  result = pipe(prompt=prompt)
50
  else:
@@ -53,6 +58,7 @@ class ModelManager:
53
 
54
  def generate_image_ig_api(self, prompt, model_name):
55
  if self.NSFW_filter(prompt) == 'safe':
 
56
  pipe = self.load_model_pipe(model_name)
57
  result = pipe(prompt=prompt)
58
  else:
@@ -119,11 +125,11 @@ class ModelManager:
119
 
120
  @spaces.GPU(duration=200)
121
  def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
122
- if self.NSFW_filter(" ".join([textbox_source, textbox_target, textbox_instruct])) == 'safe':
123
- pipe = self.load_model_pipe(model_name)
124
- result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct)
125
- else:
126
- result = ''
127
  return result
128
 
129
  def generate_image_ie_museum(self, model_name):
@@ -187,19 +193,19 @@ class ModelManager:
187
 
188
  @spaces.GPU(duration=150)
189
  def generate_video_vg(self, prompt, model_name):
190
- if self.NSFW_filter(prompt) == 'safe':
191
- pipe = self.load_model_pipe(model_name)
192
- result = pipe(prompt=prompt)
193
- else:
194
- result = ''
195
  return result
196
 
197
  def generate_video_vg_api(self, prompt, model_name):
198
- if self.NSFW_filter(prompt) == 'safe':
199
- pipe = self.load_model_pipe(model_name)
200
- result = pipe(prompt=prompt)
201
- else:
202
- result = ''
203
  return result
204
 
205
  def generate_video_vg_museum(self, model_name):
 
18
  self.model_vg_list = VIDEO_GENERATION_MODELS
19
  self.excluding_model_list = MUSEUM_UNSUPPORTED_MODELS
20
  self.desired_model_list = DESIRED_APPEAR_MODEL
21
+ self.load_guard()
22
  self.loaded_models = {}
23
 
24
  def load_model_pipe(self, model_name):
 
29
  pipe = self.loaded_models[model_name]
30
  return pipe
31
 
32
+ def load_guard(self):
 
33
  model_id = "meta-llama/Meta-Llama-Guard-2-8B"
34
+ device = "cuda" if torch.cuda.is_available() else "cpu"
35
  dtype = torch.bfloat16
36
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_GUARD'])
37
+ self.guard = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=device, token=os.environ['HF_GUARD'])
38
+
39
+ @spaces.GPU(duration=30)
40
+ def NSFW_filter(self, prompt):
41
  chat = [{"role": "user", "content": prompt}]
42
+ input_ids = self.tokenizer.apply_chat_template(chat, return_tensors="pt").to('cuda')
43
+ self.guard.cuda()
44
+ output = self.guard.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
45
  prompt_len = input_ids.shape[-1]
46
+ result = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
47
  return result
48
 
49
  @spaces.GPU(duration=120)
50
  def generate_image_ig(self, prompt, model_name):
51
  if self.NSFW_filter(prompt) == 'safe':
52
+ print('The prompt is safe')
53
  pipe = self.load_model_pipe(model_name)
54
  result = pipe(prompt=prompt)
55
  else:
 
58
 
59
  def generate_image_ig_api(self, prompt, model_name):
60
  if self.NSFW_filter(prompt) == 'safe':
61
+ print('The prompt is safe')
62
  pipe = self.load_model_pipe(model_name)
63
  result = pipe(prompt=prompt)
64
  else:
 
125
 
126
  @spaces.GPU(duration=200)
127
  def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
128
+ # if self.NSFW_filter(" ".join([textbox_source, textbox_target, textbox_instruct])) == 'safe':
129
+ pipe = self.load_model_pipe(model_name)
130
+ result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct)
131
+ # else:
132
+ # result = ''
133
  return result
134
 
135
  def generate_image_ie_museum(self, model_name):
 
193
 
194
  @spaces.GPU(duration=150)
195
  def generate_video_vg(self, prompt, model_name):
196
+ # if self.NSFW_filter(prompt) == 'safe':
197
+ pipe = self.load_model_pipe(model_name)
198
+ result = pipe(prompt=prompt)
199
+ # else:
200
+ # result = ''
201
  return result
202
 
203
  def generate_video_vg_api(self, prompt, model_name):
204
+ # if self.NSFW_filter(prompt) == 'safe':
205
+ pipe = self.load_model_pipe(model_name)
206
+ result = pipe(prompt=prompt)
207
+ # else:
208
+ # result = ''
209
  return result
210
 
211
  def generate_video_vg_museum(self, model_name):
model/model_registry.py CHANGED
@@ -285,6 +285,20 @@ register_model_info(
285
  "https://github.com/hpcaitech/Open-Sora",
286
  "A community-driven opensource implementation of Sora.",
287
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
  register_model_info(
290
  ["videogenhub_T2VTurbo_generation"],
 
285
  "https://github.com/hpcaitech/Open-Sora",
286
  "A community-driven opensource implementation of Sora.",
287
  )
288
+
289
+ register_model_info(
290
+ ["videogenhub_OpenSora12_generation"],
291
+ "OpenSora v1.2",
292
+ "https://github.com/hpcaitech/Open-Sora",
293
+ "A community-driven opensource implementation of Sora. v1.2",
294
+ )
295
+
296
+ register_model_info(
297
+ ["videogenhub_CogVideoX_generation"],
298
+ "CogVideoX",
299
+ "https://github.com/THUDM/CogVideo",
300
+ "Text-to-Video Diffusion Models with An Expert Transformer.",
301
+ )
302
 
303
  register_model_info(
304
  ["videogenhub_T2VTurbo_generation"],
model/models/__init__.py CHANGED
@@ -19,7 +19,7 @@ VIDEO_GENERATION_MODELS = ['fal_AnimateDiff_text2video',
19
  'fal_AnimateDiffTurbo_text2video',
20
  'videogenhub_LaVie_generation',
21
  'videogenhub_VideoCrafter2_generation',
22
- 'videogenhub_ModelScope_generation',
23
  'videogenhub_OpenSora_generation', 'videogenhub_T2VTurbo_generation','fal_StableVideoDiffusion_text2video']
24
  MUSEUM_UNSUPPORTED_MODELS = ['videogenhub_OpenSoraPlan_generation']
25
  DESIRED_APPEAR_MODEL = ['videogenhub_T2VTurbo_generation','fal_StableVideoDiffusion_text2video']
 
19
  'fal_AnimateDiffTurbo_text2video',
20
  'videogenhub_LaVie_generation',
21
  'videogenhub_VideoCrafter2_generation',
22
+ 'videogenhub_ModelScope_generation', 'videogenhub_CogVideoX_generation', 'videogenhub_OpenSora12_generation',
23
  'videogenhub_OpenSora_generation', 'videogenhub_T2VTurbo_generation','fal_StableVideoDiffusion_text2video']
24
  MUSEUM_UNSUPPORTED_MODELS = ['videogenhub_OpenSoraPlan_generation']
25
  DESIRED_APPEAR_MODEL = ['videogenhub_T2VTurbo_generation','fal_StableVideoDiffusion_text2video']
requirements.txt CHANGED
@@ -7,7 +7,7 @@ h5py
7
  xformers~=0.0.20
8
  numpy>=1.23.5
9
  pandas<2.0.0
10
- peft
11
  torch==2.2
12
  torchvision
13
  torchaudio
@@ -28,7 +28,6 @@ setuptools>=59.5.0
28
  transformers
29
  torchmetrics>=0.6.0
30
  lpips
31
- dreamsim
32
  image-reward
33
  kornia>=0.6
34
  diffusers>=0.18.0
@@ -49,7 +48,7 @@ statsmodels
49
  plotly
50
  git+https://github.com/TIGER-AI-Lab/ImagenHub.git#egg=imagen-hub
51
  fal_client
52
- -e git+https://github.com/TIGER-AI-Lab/VideoGenHub.git@arena#egg=videogen-hub
53
  open_clip_torch
54
  decord
55
  huggingface_hub
 
7
  xformers~=0.0.20
8
  numpy>=1.23.5
9
  pandas<2.0.0
10
+ peft>=0.12
11
  torch==2.2
12
  torchvision
13
  torchaudio
 
28
  transformers
29
  torchmetrics>=0.6.0
30
  lpips
 
31
  image-reward
32
  kornia>=0.6
33
  diffusers>=0.18.0
 
48
  plotly
49
  git+https://github.com/TIGER-AI-Lab/ImagenHub.git#egg=imagen-hub
50
  fal_client
51
+ git+https://github.com/TIGER-AI-Lab/VideoGenHub.git@arena#egg=videogen-hub
52
  open_clip_torch
53
  decord
54
  huggingface_hub