Ffftdtd5dtft committed: Update app.py
app.py
CHANGED
@@ -5,7 +5,7 @@ import torch
 from PIL import Image
 from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, FluxPipeline, DiffusionPipeline, DPMSolverMultistepScheduler
 from diffusers.utils import export_to_video
-from transformers import pipeline as transformers_pipeline, AutoModelForCausalLM, AutoTokenizer
+from transformers import pipeline as transformers_pipeline, AutoModelForCausalLM, AutoTokenizer
 from audiocraft.models import musicgen
 import gradio as gr
 from huggingface_hub import snapshot_download, HfApi, HfFolder
@@ -84,6 +84,7 @@ def get_model_or_download(model_id, redis_key, loader_func):
         save_object_to_redis(redis_key, model)
         model_bytes = pickle.dumps(model)
         upload_to_gcs(gcs_bucket_name, redis_key, model_bytes)
+        return model
     except Exception as e:
         print(f"Failed to load or save model: {e}")
     return None
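Review note: the added `return model` is the substantive fix in this hunk. Previously the function cached a freshly loaded model in Redis and GCS but apparently fell through to `return None`, so callers saw no model even after a successful load. A minimal, self-contained sketch of the intended pattern, with an in-memory dict standing in for Redis and the GCS upload omitted (the app's real helpers are defined outside this diff):

import pickle

_cache = {}  # stand-in for Redis; the app uses real load/save helpers

def load_object_from_redis(key):
    raw = _cache.get(key)
    return pickle.loads(raw) if raw is not None else None

def save_object_to_redis(key, obj):
    _cache[key] = pickle.dumps(obj)

def get_model_or_download(model_id, redis_key, loader_func):
    model = load_object_from_redis(redis_key)
    if model is not None:
        return model  # cache hit
    try:
        model = loader_func(model_id)           # e.g. StableDiffusionPipeline.from_pretrained
        save_object_to_redis(redis_key, model)
        return model                            # the line this commit adds
    except Exception as e:
        print(f"Failed to load or save model: {e}")
        return None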
@@ -113,7 +114,7 @@ def edit_image_with_prompt(image_bytes, prompt, strength=0.75):
     try:
         image = Image.open(io.BytesIO(image_bytes))
         with tqdm(total=1, desc="Editing image") as pbar:
-            edited_image = img2img_pipeline(prompt=prompt,
+            edited_image = img2img_pipeline(prompt=prompt, image=image, strength=strength).images[0]
             pbar.update(1)
         buffered = io.BytesIO()
         edited_image.save(buffered, format="JPEG")
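The repaired call now lines up with diffusers' img2img interface: the pipeline takes the prompt, an init image, and a `strength` in [0, 1] (higher values stray further from the input), and exposes results on `.images`. A standalone sketch of the same call; the file name and prompt are placeholders:

import io
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")  # the checkpoint the app loads
init = Image.open("input.jpg").convert("RGB").resize((512, 512))
out = pipe(prompt="a watercolor painting", image=init, strength=0.75).images[0]
buf = io.BytesIO()
out.save(buf, format="JPEG")  # the app returns these bytes to the Gradio UI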
@@ -131,7 +132,7 @@ def generate_song(prompt, duration=10):
     if not song_bytes:
         try:
             with tqdm(total=1, desc="Generating song") as pbar:
-                song = music_gen
+                song = music_gen(prompt, duration=duration)
                 pbar.update(1)
             song_bytes = song[0].getvalue()
             save_object_to_redis(redis_key, song_bytes)
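A caution on the patched line: in audiocraft's documented API a MusicGen model is not called like a function, and `generate` returns a waveform tensor, not the file-like objects that the unchanged `song[0].getvalue()` expects. A sketch against the documented interface (prompt text is a placeholder); note too that `get_pretrained` accepts a `device` argument, which may be safer than the `.to(device)` chained onto it later in this diff:

from audiocraft.models import musicgen
from audiocraft.data.audio import audio_write

model = musicgen.MusicGen.get_pretrained('melody')   # same checkpoint the app uses
model.set_generation_params(duration=10)             # duration is a generation param, not a call kwarg
wav = model.generate(['an upbeat synth melody'])     # tensor of shape [batch, channels, samples]
audio_write('song', wav[0].cpu(), model.sample_rate, strategy='loudness')  # writes song.wav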
@@ -166,7 +167,7 @@ def generate_flux_image(prompt):
             prompt,
             guidance_scale=0.0,
             num_inference_steps=4,
-
+            max_length=256,
             generator=torch.Generator("cpu").manual_seed(0)
         ).images[0]
         pbar.update(1)
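One naming caution: in diffusers' FluxPipeline the prompt-length cap is spelled `max_sequence_length`; `max_length` is not part of the documented call and would likely be rejected. For comparison, the documented FLUX.1-schnell recipe (prompt is a placeholder):

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # optional: trade speed for VRAM
image = pipe(
    "a tiny astronaut hatching from an egg on the moon",
    guidance_scale=0.0,            # schnell is guidance-distilled; guidance stays off
    num_inference_steps=4,
    max_sequence_length=256,       # the documented spelling of the added kwarg
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]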
@@ -197,24 +198,6 @@ def generate_code(prompt):
         return None
     return code
 
-def generate_video(prompt):
-    redis_key = f"generated_video:{prompt}"
-    video = load_object_from_redis(redis_key)
-    if not video:
-        try:
-            with tqdm(total=1, desc="Generating video") as pbar:
-                pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16)
-                pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
-                pipe.enable_model_cpu_offload()
-                video = export_to_video(pipe(prompt, num_inference_steps=25).frames)
-                pbar.update(1)
-            save_object_to_redis(redis_key, video)
-            upload_to_gcs(gcs_bucket_name, redis_key, video.encode())
-        except Exception as e:
-            print(f"Failed to generate video: {e}")
-            return None
-    return video
-
 def test_model_meta_llama():
     redis_key = "meta_llama_test_response"
     response = load_object_from_redis(redis_key)
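On the deleted `generate_video`: diffusers' `export_to_video` writes an .mp4 and returns its filesystem path as a string, so `video.encode()` uploaded the UTF-8 bytes of the path, not the video, and the Redis cache likewise held only a path. If the feature is ever restored, reading the file back is presumably what was intended; a small sketch:

from diffusers.utils import export_to_video

def frames_to_bytes(frames):
    video_path = export_to_video(frames)  # returns the output path (str), not the data
    with open(video_path, "rb") as f:
        return f.read()                   # bytes suitable for caching or upload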
@@ -234,55 +217,15 @@ def test_model_meta_llama():
         return None
     return response
 
-def train_model(model, dataset, epochs, batch_size, learning_rate):
-    output_dir = io.BytesIO()
-    training_args = TrainingArguments(
-        output_dir=output_dir,
-        num_train_epochs=epochs,
-        per_device_train_batch_size=batch_size,
-        learning_rate=learning_rate,
-    )
-    trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
-    try:
-        with tqdm(total=epochs, desc="Training model") as pbar:
-            trainer.train()
-            pbar.update(epochs)
-        save_object_to_redis("trained_model", model)
-        save_object_to_redis("training_results", output_dir.getvalue())
-        upload_to_gcs(gcs_bucket_name, "trained_model", pickle.dumps(model))
-        upload_to_gcs(gcs_bucket_name, "training_results", output_dir.getvalue())
-    except Exception as e:
-        print(f"Failed to train model: {e}")
-
-def run_task(task_queue):
-    while True:
-        task = task_queue.get()
-        if task is None:
-            break
-        func, args, kwargs = task
-        try:
-            func(*args, **kwargs)
-        except Exception as e:
-            print(f"Failed to run task: {e}")
-
-task_queue = multiprocessing.Queue()
-num_processes = multiprocessing.cpu_count()
-
-processes = []
-for _ in range(num_processes):
-    p = multiprocessing.Process(target=run_task, args=(task_queue,))
-    p.start()
-    processes.append(p)
-
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 text_to_image_pipeline = get_model_or_download("stabilityai/stable-diffusion-2", "text_to_image_model", StableDiffusionPipeline.from_pretrained)
 img2img_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "img2img_model", StableDiffusionImg2ImgPipeline.from_pretrained)
 flux_pipeline = get_model_or_download("black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained)
 text_gen_pipeline = transformers_pipeline("text-generation", model="google/gemma-2-9b", tokenizer="google/gemma-2-9b")
-music_gen = load_object_from_redis("music_gen") or musicgen.MusicGen.get_pretrained('melody')
+music_gen = load_object_from_redis("music_gen") or musicgen.MusicGen.get_pretrained('melody').to(device)
 meta_llama_pipeline = get_model_or_download("meta-llama/Meta-Llama-3.1-8B-Instruct", "meta_llama_model", transformers_pipeline)
-starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder")
+starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder").to(device)
 starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")
 
 gen_image_tab = gr.Interface(fn=generate_image, inputs=gr.Textbox(label="Prompt:"), outputs=gr.Image(type="pil"), title="Generate Image")
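A footnote on the deleted `train_model`: it passed an `io.BytesIO()` as `TrainingArguments(output_dir=...)`, but `output_dir` must be a directory path, so checkpoint writes would have failed before any caching happened. A hypothetical corrected shape, should the helper ever be revived:

import tempfile
from transformers import Trainer, TrainingArguments

def train_model(model, dataset, epochs, batch_size, learning_rate):
    # checkpoints need a real directory, not an in-memory buffer
    with tempfile.TemporaryDirectory() as tmp:
        args = TrainingArguments(output_dir=tmp,
                                 num_train_epochs=epochs,
                                 per_device_train_batch_size=batch_size,
                                 learning_rate=learning_rate)
        Trainer(model=model, args=args, train_dataset=dataset).train()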
@@ -298,9 +241,4 @@ app = gr.TabbedInterface(
     ["Generate Image", "Edit Image", "Generate Song", "Generate Text", "Generate FLUX Image", "Generate Code", "Test Meta-Llama"]
 )
 
-app.launch(share=True)
-
-for _ in range(num_processes):
-    task_queue.put(None)
-for p in processes:
-    p.join()
+app.launch(share=True)
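The trailing removal drops the multiprocessing worker pool along with its shutdown handshake (one `None` sentinel per worker, then `join`). Since `app.launch(share=True)` blocks, the old shutdown loop was unreachable until the server exited, which presumably motivated removing the pool outright. For reference, a minimal self-contained version of the sentinel pattern the code was attempting:

import multiprocessing

def run_task(task_queue):
    while True:
        task = task_queue.get()
        if task is None:          # sentinel: tells this worker to exit
            break
        func, args, kwargs = task
        try:
            func(*args, **kwargs)
        except Exception as e:
            print(f"Failed to run task: {e}")

if __name__ == "__main__":
    task_queue = multiprocessing.Queue()
    workers = [multiprocessing.Process(target=run_task, args=(task_queue,))
               for _ in range(multiprocessing.cpu_count())]
    for w in workers:
        w.start()
    task_queue.put((print, ("hello from the pool",), {}))  # one sample task
    for w in workers:
        task_queue.put(None)      # one sentinel per worker
    for w in workers:
        w.join()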