DSDUDEd committed on
Commit 0516e0c · verified · 1 Parent(s): 58b06a1

Update app.py

Files changed (1)
  app.py +89 -33
app.py CHANGED
@@ -1,48 +1,107 @@
  import gradio as gr
  from diffusers import DiffusionPipeline
  from transformers import AutoTokenizer, AutoModelForCausalLM
- import torch

  # ------------------------------
  # Unified Super Model Class
  # ------------------------------
  class CASS3Beta:
      def __init__(self):
-         # Load image models
-         self.image_pipes = {
-             "Lucy": DiffusionPipeline.from_pretrained("decart-ai/Lucy-Edit-Dev").to("cuda"),
-             "Wan2.2": DiffusionPipeline.from_pretrained("Wan-AI/Wan2.2-Animate-14B").to("cuda"),
-             "OpenJourney": DiffusionPipeline.from_pretrained("prompthero/openjourney").to("cuda"),
-             "StableXL": DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0").to("cuda"),
-             "Wan2.1": DiffusionPipeline.from_pretrained("samuelchristlie/Wan2.1-VACE-1.3B-GGUF").to("cuda")
-         }

-         # Load text models
-         self.tokenizer_qwen = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
-         self.model_qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B").to("cuda")
-         self.model_isaac = AutoModelForCausalLM.from_pretrained(
-             "PerceptronAI/Isaac-0.1", trust_remote_code=True, torch_dtype=torch.float16
-         ).to("cuda")

-     # Generate image from a single model
-     def generate_image(self, prompt, model_name="StableXL"):
-         pipe = self.image_pipes[model_name]
-         return pipe(prompt).images[0]

-     # Generate text from a single model
-     def generate_text(self, prompt, model_name="Qwen"):
          if model_name == "Qwen":
-             inputs = self.tokenizer_qwen(prompt, return_tensors="pt").to(self.model_qwen.device)
-             outputs = self.model_qwen.generate(**inputs, max_new_tokens=50)
-             return self.tokenizer_qwen.decode(outputs[0], skip_special_tokens=True)
          elif model_name == "Isaac":
-             inputs = self.model_isaac.prepare_inputs_for_generation(prompt)
-             outputs = self.model_isaac.generate(**inputs, max_new_tokens=50)
-             return self.model_isaac.decode(outputs)

-     # Generate outputs from all models at once
      def generate_all(self, prompt):
-         images = {name: pipe(prompt).images[0] for name, pipe in self.image_pipes.items()}
          texts = {
              "Qwen": self.generate_text(prompt, "Qwen"),
              "Isaac": self.generate_text(prompt, "Isaac")
@@ -59,13 +118,10 @@ cass3 = CASS3Beta()
  # ------------------------------
  def run_cass3(prompt):
      images, texts = cass3.generate_all(prompt)
-
-     # Return images in fixed order
      image_list = [images[name] for name in ["Lucy", "Wan2.2", "OpenJourney", "StableXL", "Wan2.1"]]
-
      # Combine text outputs
      text_output = f"Qwen:\n{texts['Qwen']}\n\nIsaac:\n{texts['Isaac']}"
-
      return image_list, text_output

  iface = gr.Interface(
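The version above loads all five diffusion pipelines and both text models eagerly in __init__, so the entire model set must fit in memory before the app can serve a single request. The rewritten version below switches to load-on-first-use. As a minimal sketch of that cache-on-first-request pattern, detached from the diff (the names _CACHE and load_pipe are illustrative, not part of app.py):

import torch
from diffusers import DiffusionPipeline

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
_CACHE = {}

def load_pipe(name, repo_id):
    # The first request for a model pays the full load cost;
    # every later request reuses the cached pipeline.
    if name not in _CACHE:
        _CACHE[name] = DiffusionPipeline.from_pretrained(
            repo_id, torch_dtype=TORCH_DTYPE
        ).to(DEVICE)
    return _CACHE[name]

# e.g. load_pipe("StableXL", "stabilityai/stable-diffusion-xl-base-1.0")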
 
  import gradio as gr
+ import torch
  from diffusers import DiffusionPipeline
  from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # ------------------------------
+ # Utility: Device Detection
+ # ------------------------------
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

  # ------------------------------
  # Unified Super Model Class
  # ------------------------------
  class CASS3Beta:
      def __init__(self):
+         self.image_pipes = {}
+         self.text_models = {}

+     # Lazy-load image model
+     def load_image_model(self, model_name):
+         if model_name in self.image_pipes:
+             return self.image_pipes[model_name]
+
+         if model_name == "Lucy":
+             pipe = DiffusionPipeline.from_pretrained(
+                 "decart-ai/Lucy-Edit-Dev",
+                 trust_remote_code=True,
+                 torch_dtype=TORCH_DTYPE
+             ).to(DEVICE)
+         elif model_name == "Wan2.2":
+             pipe = DiffusionPipeline.from_pretrained(
+                 "Wan-AI/Wan2.2-Animate-14B",
+                 trust_remote_code=True,
+                 torch_dtype=TORCH_DTYPE
+             ).to(DEVICE)
+         elif model_name == "OpenJourney":
+             pipe = DiffusionPipeline.from_pretrained(
+                 "prompthero/openjourney",
+                 torch_dtype=TORCH_DTYPE
+             ).to(DEVICE)
+         elif model_name == "StableXL":
+             pipe = DiffusionPipeline.from_pretrained(
+                 "stabilityai/stable-diffusion-xl-base-1.0",
+                 torch_dtype=TORCH_DTYPE
+             ).to(DEVICE)
+         elif model_name == "Wan2.1":
+             pipe = DiffusionPipeline.from_pretrained(
+                 "samuelchristlie/Wan2.1-VACE-1.3B-GGUF",
+                 torch_dtype=TORCH_DTYPE
+             ).to(DEVICE)
+         else:
+             raise ValueError(f"Unknown image model: {model_name}")
+
+         self.image_pipes[model_name] = pipe
+         return pipe
+
+     # Lazy-load text model
+     def load_text_model(self, model_name):
+         if model_name in self.text_models:
+             return self.text_models[model_name]

          if model_name == "Qwen":
+             tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
+             model = AutoModelForCausalLM.from_pretrained(
+                 "Qwen/Qwen3-0.6B",
+                 torch_dtype=TORCH_DTYPE
+             ).to(DEVICE)
          elif model_name == "Isaac":
+             tokenizer = None  # Isaac handles tokenization internally
+             model = AutoModelForCausalLM.from_pretrained(
+                 "PerceptronAI/Isaac-0.1",
+                 trust_remote_code=True,
+                 torch_dtype=TORCH_DTYPE
+             ).to(DEVICE)
+         else:
+             raise ValueError(f"Unknown text model: {model_name}")

+         self.text_models[model_name] = (tokenizer, model)
+         return tokenizer, model
+
+     # Generate a single image
+     def generate_image(self, prompt, model_name):
+         pipe = self.load_image_model(model_name)
+         return pipe(prompt).images[0]
+
+     # Generate text
+     def generate_text(self, prompt, model_name):
+         tokenizer, model = self.load_text_model(model_name)
+         if tokenizer:
+             inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
+             outputs = model.generate(**inputs, max_new_tokens=50)
+             return tokenizer.decode(outputs[0], skip_special_tokens=True)
+         else:
+             # Isaac model
+             inputs = model.prepare_inputs_for_generation(prompt)
+             outputs = model.generate(**inputs, max_new_tokens=50)
+             return model.decode(outputs)
+
+     # Generate outputs from all models
      def generate_all(self, prompt):
+         image_names = ["Lucy", "Wan2.2", "OpenJourney", "StableXL", "Wan2.1"]
+         images = {name: self.generate_image(prompt, name) for name in image_names}
+
          texts = {
              "Qwen": self.generate_text(prompt, "Qwen"),
              "Isaac": self.generate_text(prompt, "Isaac")

  # ------------------------------
  def run_cass3(prompt):
      images, texts = cass3.generate_all(prompt)
+     # List of images in fixed order
      image_list = [images[name] for name in ["Lucy", "Wan2.2", "OpenJourney", "StableXL", "Wan2.1"]]
      # Combine text outputs
      text_output = f"Qwen:\n{texts['Qwen']}\n\nIsaac:\n{texts['Isaac']}"
      return image_list, text_output

  iface = gr.Interface(
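The hunk is truncated inside the gr.Interface( call, so the actual input and output components are not visible here. For orientation only, a hedged sketch of how run_cass3 is commonly wired into a Gradio app; the component choices below are assumptions, not necessarily what app.py contains:

import gradio as gr

# run_cass3 is the function defined in the diff above; it returns
# (a list of five PIL images, one combined text string).
iface = gr.Interface(
    fn=run_cass3,
    inputs=gr.Textbox(label="Prompt"),           # assumed input component
    outputs=[gr.Gallery(label="Images"),         # matches the image list
             gr.Textbox(label="Text outputs")],  # matches the combined string
)

if __name__ == "__main__":
    iface.launch()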