Ffftdtd5dtft committed
Commit fa62195 · verified · Parent: d7f7784

Update app.py

Files changed (1):
1. app.py +8 -70
app.py CHANGED
@@ -5,7 +5,7 @@ import torch
 from PIL import Image
 from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, FluxPipeline, DiffusionPipeline, DPMSolverMultistepScheduler
 from diffusers.utils import export_to_video
-from transformers import pipeline as transformers_pipeline, AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
+from transformers import pipeline as transformers_pipeline, AutoModelForCausalLM, AutoTokenizer
 from audiocraft.models import musicgen
 import gradio as gr
 from huggingface_hub import snapshot_download, HfApi, HfFolder
@@ -84,6 +84,7 @@ def get_model_or_download(model_id, redis_key, loader_func):
         save_object_to_redis(redis_key, model)
         model_bytes = pickle.dumps(model)
         upload_to_gcs(gcs_bucket_name, redis_key, model_bytes)
+        return model
     except Exception as e:
         print(f"Failed to load or save model: {e}")
         return None
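Note: the added return model fixes a real bug. On a cache miss the helper loaded, cached, and uploaded the model but never returned it, so callers always got None. For context, a minimal sketch of how the surrounding helper plausibly looks (the body outside this hunk is not shown in the diff, so the exact shape is an assumption):

    def get_model_or_download(model_id, redis_key, loader_func):
        # Serve from the Redis cache when available (assumed caching behavior).
        model = load_object_from_redis(redis_key)
        if model:
            return model
        try:
            model = loader_func(model_id)
            save_object_to_redis(redis_key, model)
            model_bytes = pickle.dumps(model)
            upload_to_gcs(gcs_bucket_name, redis_key, model_bytes)
            return model  # the fix: return on the success path too
        except Exception as e:
            print(f"Failed to load or save model: {e}")
            return None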
@@ -113,7 +114,7 @@ def edit_image_with_prompt(image_bytes, prompt, strength=0.75):
     try:
         image = Image.open(io.BytesIO(image_bytes))
         with tqdm(total=1, desc="Editing image") as pbar:
-            edited_image = img2img_pipeline(prompt=prompt, init_image=image.convert("RGB"), strength=strength).images[0]
+            edited_image = img2img_pipeline(prompt=prompt, image=image, strength=strength).images[0]
             pbar.update(1)
         buffered = io.BytesIO()
         edited_image.save(buffered, format="JPEG")
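Note: recent diffusers releases removed the deprecated init_image argument from StableDiffusionImg2ImgPipeline in favor of image, so this rename tracks the current API. A self-contained sketch of the call as it now works (file names here are illustrative):

    import torch
    from PIL import Image
    from diffusers import StableDiffusionImg2ImgPipeline

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    init = Image.open("input.jpg").convert("RGB")  # the pipeline expects an RGB image
    edited = pipe(prompt="a watercolor version", image=init, strength=0.75).images[0]
    edited.save("edited.jpg")

One caveat: the new line also drops the .convert("RGB") call; keeping the conversion before the pipeline call is safer for inputs with an alpha channel.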
@@ -131,7 +132,7 @@ def generate_song(prompt, duration=10):
     if not song_bytes:
         try:
             with tqdm(total=1, desc="Generating song") as pbar:
-                song = music_gen.generate([prompt], duration=[duration])
+                song = music_gen(prompt, duration=duration)
                 pbar.update(1)
             song_bytes = song[0].getvalue()
             save_object_to_redis(redis_key, song_bytes)
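Note: neither the old call nor the new one matches the audiocraft API as I know it. MusicGen.generate takes no duration keyword and the model object is not callable, so music_gen(prompt, duration=duration) would likely raise a TypeError. Duration is set via set_generation_params, and generate returns a waveform tensor rather than a file-like object, so song[0].getvalue() would also fail. A hedged sketch of the documented usage:

    from audiocraft.models import musicgen
    from audiocraft.data.audio import audio_write

    model = musicgen.MusicGen.get_pretrained('melody')
    model.set_generation_params(duration=10)  # length in seconds
    wav = model.generate(['an upbeat synth-pop track'])  # tensor of shape [batch, channels, samples]
    audio_write('song', wav[0].cpu(), model.sample_rate, strategy='loudness')  # writes song.wav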
@@ -166,7 +167,7 @@ def generate_flux_image(prompt):
             prompt,
             guidance_scale=0.0,
             num_inference_steps=4,
-            max_sequence_length=256,
+            max_length=256,
             generator=torch.Generator("cpu").manual_seed(0)
         ).images[0]
         pbar.update(1)
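Note: this change may be a regression rather than a fix. Diffusers' FluxPipeline documents max_sequence_length (the T5 prompt length) as the supported keyword, and the FLUX.1-schnell model card uses max_sequence_length=256; max_length is not a documented __call__ argument. The model-card reference snippet, for comparison:

    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
    pipe.enable_model_cpu_offload()  # optional: trade speed for lower VRAM use
    image = pipe(
        "a tiny astronaut hatching from an egg on the moon",
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]
    image.save("flux-schnell.png")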
@@ -197,24 +198,6 @@ def generate_code(prompt):
         return None
     return code
 
-def generate_video(prompt):
-    redis_key = f"generated_video:{prompt}"
-    video = load_object_from_redis(redis_key)
-    if not video:
-        try:
-            with tqdm(total=1, desc="Generating video") as pbar:
-                pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16)
-                pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
-                pipe.enable_model_cpu_offload()
-                video = export_to_video(pipe(prompt, num_inference_steps=25).frames)
-                pbar.update(1)
-            save_object_to_redis(redis_key, video)
-            upload_to_gcs(gcs_bucket_name, redis_key, video.encode())
-        except Exception as e:
-            print(f"Failed to generate video: {e}")
-            return None
-    return video
-
 def test_model_meta_llama():
     redis_key = "meta_llama_test_response"
     response = load_object_from_redis(redis_key)
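Note: with generate_video removed, the DiffusionPipeline, DPMSolverMultistepScheduler, and export_to_video imports at the top of app.py no longer have any callers and could be dropped in a follow-up commit.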
@@ -234,55 +217,15 @@ def test_model_meta_llama():
         return None
     return response
 
-def train_model(model, dataset, epochs, batch_size, learning_rate):
-    output_dir = io.BytesIO()
-    training_args = TrainingArguments(
-        output_dir=output_dir,
-        num_train_epochs=epochs,
-        per_device_train_batch_size=batch_size,
-        learning_rate=learning_rate,
-    )
-    trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
-    try:
-        with tqdm(total=epochs, desc="Training model") as pbar:
-            trainer.train()
-            pbar.update(epochs)
-        save_object_to_redis("trained_model", model)
-        save_object_to_redis("training_results", output_dir.getvalue())
-        upload_to_gcs(gcs_bucket_name, "trained_model", pickle.dumps(model))
-        upload_to_gcs(gcs_bucket_name, "training_results", output_dir.getvalue())
-    except Exception as e:
-        print(f"Failed to train model: {e}")
-
-def run_task(task_queue):
-    while True:
-        task = task_queue.get()
-        if task is None:
-            break
-        func, args, kwargs = task
-        try:
-            func(*args, **kwargs)
-        except Exception as e:
-            print(f"Failed to run task: {e}")
-
-task_queue = multiprocessing.Queue()
-num_processes = multiprocessing.cpu_count()
-
-processes = []
-for _ in range(num_processes):
-    p = multiprocessing.Process(target=run_task, args=(task_queue,))
-    p.start()
-    processes.append(p)
-
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 text_to_image_pipeline = get_model_or_download("stabilityai/stable-diffusion-2", "text_to_image_model", StableDiffusionPipeline.from_pretrained)
 img2img_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "img2img_model", StableDiffusionImg2ImgPipeline.from_pretrained)
 flux_pipeline = get_model_or_download("black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained)
 text_gen_pipeline = transformers_pipeline("text-generation", model="google/gemma-2-9b", tokenizer="google/gemma-2-9b")
-music_gen = load_object_from_redis("music_gen") or musicgen.MusicGen.get_pretrained('melody')
+music_gen = load_object_from_redis("music_gen") or musicgen.MusicGen.get_pretrained('melody').to(device)
 meta_llama_pipeline = get_model_or_download("meta-llama/Meta-Llama-3.1-8B-Instruct", "meta_llama_model", transformers_pipeline)
-starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder")
+starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder").to(device)
 starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")
 
 gen_image_tab = gr.Interface(fn=generate_image, inputs=gr.Textbox(label="Prompt:"), outputs=gr.Image(type="pil"), title="Generate Image")
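Note: the .to(device) on the StarCoder model is fine (transformers models are torch nn.Modules and .to returns the model), but as far as I know audiocraft's MusicGen is not an nn.Module and exposes no .to() method; the device is instead passed at load time, so the MusicGen change above would likely raise AttributeError on a cache miss. A hedged alternative, assuming get_pretrained's device parameter:

    music_gen = load_object_from_redis("music_gen") or musicgen.MusicGen.get_pretrained('melody', device=str(device))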
@@ -298,9 +241,4 @@ app = gr.TabbedInterface(
     ["Generate Image", "Edit Image", "Generate Song", "Generate Text", "Generate FLUX Image", "Generate Code", "Test Meta-Llama"]
 )
 
-app.launch(share=True)
-
-for _ in range(num_processes):
-    task_queue.put(None)
-for p in processes:
-    p.join()
+app.launch(share=True)
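Note: dropping this teardown is consistent with the removal above of the multiprocessing task queue and worker processes it was joining; after this commit, app.launch(share=True) is the script's final statement.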