Ffftdtd5dtft commited on
Commit
8f22092
·
verified ·
1 Parent(s): fa62195

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +449 -28
app.py CHANGED
@@ -3,9 +3,21 @@ import redis
3
  import pickle
4
  import torch
5
  from PIL import Image
6
- from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, FluxPipeline, DiffusionPipeline, DPMSolverMultistepScheduler
 
 
 
 
 
 
7
  from diffusers.utils import export_to_video
8
- from transformers import pipeline as transformers_pipeline, AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
 
9
  from audiocraft.models import musicgen
10
  import gradio as gr
11
  from huggingface_hub import snapshot_download, HfApi, HfFolder
@@ -27,24 +39,37 @@ HfFolder.save_token(hf_token)
27
 
28
  storage_client = storage.Client.from_service_account_info(gcs_credentials)
29
 
 
30
  def connect_to_redis():
31
  while True:
32
  try:
33
- redis_client = redis.Redis(host=redis_host, port=redis_port, password=redis_password)
 
 
34
  redis_client.ping()
35
  return redis_client
36
- except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError, BrokenPipeError) as e:
 
 
 
 
37
  print(f"Connection to Redis failed: {e}. Retrying in 1 second...")
38
  time.sleep(1)
39
 
 
40
  def reconnect_if_needed(redis_client):
41
  try:
42
  redis_client.ping()
43
- except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError, BrokenPipeError):
 
 
 
 
44
  print("Reconnecting to Redis...")
45
  return connect_to_redis()
46
  return redis_client
47
 
 
48
  def load_object_from_redis(key):
49
  redis_client = connect_to_redis()
50
  redis_client = reconnect_if_needed(redis_client)
@@ -55,6 +80,7 @@ def load_object_from_redis(key):
55
  print(f"Failed to load object from Redis: {e}")
56
  return None
57
 
 
58
  def save_object_to_redis(key, obj):
59
  redis_client = connect_to_redis()
60
  redis_client = reconnect_if_needed(redis_client)
@@ -63,16 +89,19 @@ def save_object_to_redis(key, obj):
63
  except redis.exceptions.RedisError as e:
64
  print(f"Failed to save object to Redis: {e}")
65
 
 
66
  def upload_to_gcs(bucket_name, blob_name, data):
67
  bucket = storage_client.bucket(bucket_name)
68
  blob = bucket.blob(blob_name)
69
  blob.upload_from_string(data)
70
 
 
71
  def download_from_gcs(bucket_name, blob_name):
72
  bucket = storage_client.bucket(bucket_name)
73
  blob = bucket.blob(blob_name)
74
  return blob.download_as_bytes()
75
 
 
76
  def get_model_or_download(model_id, redis_key, loader_func):
77
  model = load_object_from_redis(redis_key)
78
  if model:
@@ -89,6 +118,7 @@ def get_model_or_download(model_id, redis_key, loader_func):
89
  print(f"Failed to load or save model: {e}")
90
  return None
91
 
 
92
  def generate_image(prompt):
93
  redis_key = f"generated_image:{prompt}"
94
  image_bytes = load_object_from_redis(redis_key)
@@ -107,6 +137,7 @@ def generate_image(prompt):
107
  return None
108
  return image_bytes
109
 
 
110
  def edit_image_with_prompt(image_bytes, prompt, strength=0.75):
111
  redis_key = f"edited_image:{prompt}:{strength}"
112
  edited_image_bytes = load_object_from_redis(redis_key)
@@ -114,7 +145,9 @@ def edit_image_with_prompt(image_bytes, prompt, strength=0.75):
114
  try:
115
  image = Image.open(io.BytesIO(image_bytes))
116
  with tqdm(total=1, desc="Editing image") as pbar:
117
- edited_image = img2img_pipeline(prompt=prompt, image=image, strength=strength).images[0]
 
 
118
  pbar.update(1)
119
  buffered = io.BytesIO()
120
  edited_image.save(buffered, format="JPEG")
@@ -126,6 +159,7 @@ def edit_image_with_prompt(image_bytes, prompt, strength=0.75):
126
  return None
127
  return edited_image_bytes
128
 
 
129
  def generate_song(prompt, duration=10):
130
  redis_key = f"generated_song:{prompt}:{duration}"
131
  song_bytes = load_object_from_redis(redis_key)
@@ -142,13 +176,16 @@ def generate_song(prompt, duration=10):
142
  return None
143
  return song_bytes
144
 
 
145
  def generate_text(prompt):
146
  redis_key = f"generated_text:{prompt}"
147
  text = load_object_from_redis(redis_key)
148
  if not text:
149
  try:
150
  with tqdm(total=1, desc="Generating text") as pbar:
151
- text = text_gen_pipeline(prompt, max_new_tokens=256)[0]["generated_text"].strip()
 
 
152
  pbar.update(1)
153
  save_object_to_redis(redis_key, text)
154
  upload_to_gcs(gcs_bucket_name, redis_key, text.encode())
@@ -157,6 +194,7 @@ def generate_text(prompt):
157
  return None
158
  return text
159
 
 
160
  def generate_flux_image(prompt):
161
  redis_key = f"generated_flux_image:{prompt}"
162
  flux_image_bytes = load_object_from_redis(redis_key)
@@ -168,7 +206,7 @@ def generate_flux_image(prompt):
168
  guidance_scale=0.0,
169
  num_inference_steps=4,
170
  max_length=256,
171
- generator=torch.Generator("cpu").manual_seed(0)
172
  ).images[0]
173
  pbar.update(1)
174
  buffered = io.BytesIO()
@@ -181,13 +219,16 @@ def generate_flux_image(prompt):
181
  return None
182
  return flux_image_bytes
183
 
 
184
  def generate_code(prompt):
185
  redis_key = f"generated_code:{prompt}"
186
  code = load_object_from_redis(redis_key)
187
  if not code:
188
  try:
189
  with tqdm(total=1, desc="Generating code") as pbar:
190
- inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to(starcoder_model.device)
 
 
191
  outputs = starcoder_model.generate(inputs, max_new_tokens=256)
192
  code = starcoder_tokenizer.decode(outputs[0])
193
  pbar.update(1)
@@ -198,17 +239,23 @@ def generate_code(prompt):
198
  return None
199
  return code
200
 
 
201
  def test_model_meta_llama():
202
  redis_key = "meta_llama_test_response"
203
  response = load_object_from_redis(redis_key)
204
  if not response:
205
  try:
206
  messages = [
207
- {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
208
- {"role": "user", "content": "Who are you?"}
 
 
 
209
  ]
210
  with tqdm(total=1, desc="Testing Meta-Llama") as pbar:
211
- response = meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"].strip()
 
 
212
  pbar.update(1)
213
  save_object_to_redis(redis_key, response)
214
  upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
@@ -217,28 +264,402 @@ def test_model_meta_llama():
217
  return None
218
  return response
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
221
 
222
- text_to_image_pipeline = get_model_or_download("stabilityai/stable-diffusion-2", "text_to_image_model", StableDiffusionPipeline.from_pretrained)
223
- img2img_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "img2img_model", StableDiffusionImg2ImgPipeline.from_pretrained)
224
- flux_pipeline = get_model_or_download("black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained)
225
- text_gen_pipeline = transformers_pipeline("text-generation", model="google/gemma-2-9b", tokenizer="google/gemma-2-9b")
226
- music_gen = load_object_from_redis("music_gen") or musicgen.MusicGen.get_pretrained('melody').to(device)
227
- meta_llama_pipeline = get_model_or_download("meta-llama/Meta-Llama-3.1-8B-Instruct", "meta_llama_model", transformers_pipeline)
228
- starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder").to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")
230
 
231
- gen_image_tab = gr.Interface(fn=generate_image, inputs=gr.Textbox(label="Prompt:"), outputs=gr.Image(type="pil"), title="Generate Image")
232
- edit_image_tab = gr.Interface(fn=edit_image_with_prompt, inputs=[gr.Image(type="pil", label="Image:"), gr.Textbox(label="Prompt:"), gr.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:")], outputs=gr.Image(type="pil"), title="Edit Image")
233
- generate_song_tab = gr.Interface(fn=generate_song, inputs=[gr.Textbox(label="Prompt:"), gr.Slider(5, 60, 10, step=1, label="Duration (s):")], outputs=gr.Audio(type="numpy"), title="Generate Songs")
234
- generate_text_tab = gr.Interface(fn=generate_text, inputs=gr.Textbox(label="Prompt:"), outputs=gr.Textbox(label="Generated Text:"), title="Generate Text")
235
- generate_flux_image_tab = gr.Interface(fn=generate_flux_image, inputs=gr.Textbox(label="Prompt:"), outputs=gr.Image(type="pil"), title="Generate FLUX Images")
236
- generate_code_tab = gr.Interface(fn=generate_code, inputs=gr.Textbox(label="Prompt:"), outputs=gr.Textbox(label="Generated Code:"), title="Generate Code")
237
- model_meta_llama_test_tab = gr.Interface(fn=test_model_meta_llama, inputs=None, outputs=gr.Textbox(label="Model Output:"), title="Test Meta-Llama")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  app = gr.TabbedInterface(
240
- [gen_image_tab, edit_image_tab, generate_song_tab, generate_text_tab, generate_flux_image_tab, generate_code_tab, model_meta_llama_test_tab],
241
- ["Generate Image", "Edit Image", "Generate Song", "Generate Text", "Generate FLUX Image", "Generate Code", "Test Meta-Llama"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  )
243
 
244
  app.launch(share=True)
 
3
  import pickle
4
  import torch
5
  from PIL import Image
6
+ from diffusers import (
7
+ StableDiffusionPipeline,
8
+ StableDiffusionImg2ImgPipeline,
9
+ FluxPipeline,
10
+ DiffusionPipeline,
11
+ DPMSolverMultistepScheduler,
12
+ )
13
  from diffusers.utils import export_to_video
14
+ from transformers import (
15
+ pipeline as transformers_pipeline,
16
+ AutoModelForCausalLM,
17
+ AutoTokenizer,
18
+ GPT2Tokenizer,
19
+ GPT2Model,
20
+ )
21
  from audiocraft.models import musicgen
22
  import gradio as gr
23
  from huggingface_hub import snapshot_download, HfApi, HfFolder
 
39
 
40
  storage_client = storage.Client.from_service_account_info(gcs_credentials)
41
 
42
+
43
  def connect_to_redis():
44
  while True:
45
  try:
46
+ redis_client = redis.Redis(
47
+ host=redis_host, port=redis_port, password=redis_password
48
+ )
49
  redis_client.ping()
50
  return redis_client
51
+ except (
52
+ redis.exceptions.ConnectionError,
53
+ redis.exceptions.TimeoutError,
54
+ BrokenPipeError,
55
+ ) as e:
56
  print(f"Connection to Redis failed: {e}. Retrying in 1 second...")
57
  time.sleep(1)
58
 
59
+
60
  def reconnect_if_needed(redis_client):
61
  try:
62
  redis_client.ping()
63
+ except (
64
+ redis.exceptions.ConnectionError,
65
+ redis.exceptions.TimeoutError,
66
+ BrokenPipeError,
67
+ ):
68
  print("Reconnecting to Redis...")
69
  return connect_to_redis()
70
  return redis_client
71
 
72
+
73
  def load_object_from_redis(key):
74
  redis_client = connect_to_redis()
75
  redis_client = reconnect_if_needed(redis_client)
 
80
  print(f"Failed to load object from Redis: {e}")
81
  return None
82
 
83
+
84
  def save_object_to_redis(key, obj):
85
  redis_client = connect_to_redis()
86
  redis_client = reconnect_if_needed(redis_client)
 
89
  except redis.exceptions.RedisError as e:
90
  print(f"Failed to save object to Redis: {e}")
91
 
92
+
93
  def upload_to_gcs(bucket_name, blob_name, data):
94
  bucket = storage_client.bucket(bucket_name)
95
  blob = bucket.blob(blob_name)
96
  blob.upload_from_string(data)
97
 
98
+
99
  def download_from_gcs(bucket_name, blob_name):
100
  bucket = storage_client.bucket(bucket_name)
101
  blob = bucket.blob(blob_name)
102
  return blob.download_as_bytes()
103
 
104
+
105
  def get_model_or_download(model_id, redis_key, loader_func):
106
  model = load_object_from_redis(redis_key)
107
  if model:
 
118
  print(f"Failed to load or save model: {e}")
119
  return None
120
 
121
+
122
  def generate_image(prompt):
123
  redis_key = f"generated_image:{prompt}"
124
  image_bytes = load_object_from_redis(redis_key)
 
137
  return None
138
  return image_bytes
139
 
140
+
141
  def edit_image_with_prompt(image_bytes, prompt, strength=0.75):
142
  redis_key = f"edited_image:{prompt}:{strength}"
143
  edited_image_bytes = load_object_from_redis(redis_key)
 
145
  try:
146
  image = Image.open(io.BytesIO(image_bytes))
147
  with tqdm(total=1, desc="Editing image") as pbar:
148
+ edited_image = img2img_pipeline(
149
+ prompt=prompt, image=image, strength=strength
150
+ ).images[0]
151
  pbar.update(1)
152
  buffered = io.BytesIO()
153
  edited_image.save(buffered, format="JPEG")
 
159
  return None
160
  return edited_image_bytes
161
 
162
+
163
  def generate_song(prompt, duration=10):
164
  redis_key = f"generated_song:{prompt}:{duration}"
165
  song_bytes = load_object_from_redis(redis_key)
 
176
  return None
177
  return song_bytes
178
 
179
+
180
  def generate_text(prompt):
181
  redis_key = f"generated_text:{prompt}"
182
  text = load_object_from_redis(redis_key)
183
  if not text:
184
  try:
185
  with tqdm(total=1, desc="Generating text") as pbar:
186
+ text = text_gen_pipeline(prompt, max_new_tokens=256)[0][
187
+ "generated_text"
188
+ ].strip()
189
  pbar.update(1)
190
  save_object_to_redis(redis_key, text)
191
  upload_to_gcs(gcs_bucket_name, redis_key, text.encode())
 
194
  return None
195
  return text
196
 
197
+
198
  def generate_flux_image(prompt):
199
  redis_key = f"generated_flux_image:{prompt}"
200
  flux_image_bytes = load_object_from_redis(redis_key)
 
206
  guidance_scale=0.0,
207
  num_inference_steps=4,
208
  max_length=256,
209
+ generator=torch.Generator("cpu").manual_seed(0),
210
  ).images[0]
211
  pbar.update(1)
212
  buffered = io.BytesIO()
 
219
  return None
220
  return flux_image_bytes
221
 
222
+
223
  def generate_code(prompt):
224
  redis_key = f"generated_code:{prompt}"
225
  code = load_object_from_redis(redis_key)
226
  if not code:
227
  try:
228
  with tqdm(total=1, desc="Generating code") as pbar:
229
+ inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to(
230
+ starcoder_model.device
231
+ )
232
  outputs = starcoder_model.generate(inputs, max_new_tokens=256)
233
  code = starcoder_tokenizer.decode(outputs[0])
234
  pbar.update(1)
 
239
  return None
240
  return code
241
 
242
+
243
  def test_model_meta_llama():
244
  redis_key = "meta_llama_test_response"
245
  response = load_object_from_redis(redis_key)
246
  if not response:
247
  try:
248
  messages = [
249
+ {
250
+ "role": "system",
251
+ "content": "You are a pirate chatbot who always responds in pirate speak!",
252
+ },
253
+ {"role": "user", "content": "Who are you?"},
254
  ]
255
  with tqdm(total=1, desc="Testing Meta-Llama") as pbar:
256
+ response = meta_llama_pipeline(messages, max_new_tokens=256)[0][
257
+ "generated_text"
258
+ ].strip()
259
  pbar.update(1)
260
  save_object_to_redis(redis_key, response)
261
  upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
 
264
  return None
265
  return response
266
 
267
+
268
+ def generate_image_sdxl(prompt):
269
+ redis_key = f"generated_image_sdxl:{prompt}"
270
+ image_bytes = load_object_from_redis(redis_key)
271
+ if not image_bytes:
272
+ try:
273
+ with tqdm(total=1, desc="Generating SDXL image") as pbar:
274
+ image = base(
275
+ prompt=prompt,
276
+ num_inference_steps=40,
277
+ denoising_end=0.8,
278
+ output_type="latent",
279
+ ).images
280
+ image = refiner(
281
+ prompt=prompt,
282
+ num_inference_steps=40,
283
+ denoising_start=0.8,
284
+ image=image,
285
+ ).images[0]
286
+ pbar.update(1)
287
+ buffered = io.BytesIO()
288
+ image.save(buffered, format="JPEG")
289
+ image_bytes = buffered.getvalue()
290
+ save_object_to_redis(redis_key, image_bytes)
291
+ upload_to_gcs(gcs_bucket_name, redis_key, image_bytes)
292
+ except Exception as e:
293
+ print(f"Failed to generate SDXL image: {e}")
294
+ return None
295
+ return image_bytes
296
+
297
+
298
+ def generate_musicgen_melody(prompt):
299
+ redis_key = f"generated_musicgen_melody:{prompt}"
300
+ song_bytes = load_object_from_redis(redis_key)
301
+ if not song_bytes:
302
+ try:
303
+ with tqdm(total=1, desc="Generating MusicGen melody") as pbar:
304
+ melody, sr = torchaudio.load("./assets/bach.mp3")
305
+ wav = music_gen_melody.generate_with_chroma(
306
+ [prompt], melody[None].expand(3, -1, -1), sr
307
+ )
308
+ pbar.update(1)
309
+ song_bytes = wav[0].getvalue()
310
+ save_object_to_redis(redis_key, song_bytes)
311
+ upload_to_gcs(gcs_bucket_name, redis_key, song_bytes)
312
+ except Exception as e:
313
+ print(f"Failed to generate MusicGen melody: {e}")
314
+ return None
315
+ return song_bytes
316
+
317
+
318
+ def generate_musicgen_large(prompt):
319
+ redis_key = f"generated_musicgen_large:{prompt}"
320
+ song_bytes = load_object_from_redis(redis_key)
321
+ if not song_bytes:
322
+ try:
323
+ with tqdm(total=1, desc="Generating MusicGen large") as pbar:
324
+ wav = music_gen_large.generate([prompt])
325
+ pbar.update(1)
326
+ song_bytes = wav[0].getvalue()
327
+ save_object_to_redis(redis_key, song_bytes)
328
+ upload_to_gcs(gcs_bucket_name, redis_key, song_bytes)
329
+ except Exception as e:
330
+ print(f"Failed to generate MusicGen large: {e}")
331
+ return None
332
+ return song_bytes
333
+
334
+
335
+ def transcribe_audio(audio_sample):
336
+ redis_key = f"transcribed_audio:{hash(audio_sample.tobytes())}"
337
+ text = load_object_from_redis(redis_key)
338
+ if not text:
339
+ try:
340
+ with tqdm(total=1, desc="Transcribing audio") as pbar:
341
+ text = whisper_pipeline(audio_sample.copy(), batch_size=8)["text"]
342
+ pbar.update(1)
343
+ save_object_to_redis(redis_key, text)
344
+ upload_to_gcs(gcs_bucket_name, redis_key, text.encode())
345
+ except Exception as e:
346
+ print(f"Failed to transcribe audio: {e}")
347
+ return None
348
+ return text
349
+
350
+
351
+ def generate_mistral_instruct(prompt):
352
+ redis_key = f"generated_mistral_instruct:{prompt}"
353
+ response = load_object_from_redis(redis_key)
354
+ if not response:
355
+ try:
356
+ conversation = [{"role": "user", "content": prompt}]
357
+ with tqdm(total=1, desc="Generating Mistral Instruct response") as pbar:
358
+ inputs = mistral_instruct_tokenizer.apply_chat_template(
359
+ conversation,
360
+ tools=tools,
361
+ add_generation_prompt=True,
362
+ return_dict=True,
363
+ return_tensors="pt",
364
+ )
365
+ inputs.to(mistral_instruct_model.device)
366
+ outputs = mistral_instruct_model.generate(
367
+ **inputs, max_new_tokens=1000
368
+ )
369
+ response = mistral_instruct_tokenizer.decode(
370
+ outputs[0], skip_special_tokens=True
371
+ )
372
+ pbar.update(1)
373
+ save_object_to_redis(redis_key, response)
374
+ upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
375
+ except Exception as e:
376
+ print(f"Failed to generate Mistral Instruct response: {e}")
377
+ return None
378
+ return response
379
+
380
+
381
+ def generate_mistral_nemo(prompt):
382
+ redis_key = f"generated_mistral_nemo:{prompt}"
383
+ response = load_object_from_redis(redis_key)
384
+ if not response:
385
+ try:
386
+ conversation = [{"role": "user", "content": prompt}]
387
+ with tqdm(total=1, desc="Generating Mistral Nemo response") as pbar:
388
+ inputs = mistral_nemo_tokenizer.apply_chat_template(
389
+ conversation,
390
+ tools=tools,
391
+ add_generation_prompt=True,
392
+ return_dict=True,
393
+ return_tensors="pt",
394
+ )
395
+ inputs.to(mistral_nemo_model.device)
396
+ outputs = mistral_nemo_model.generate(**inputs, max_new_tokens=1000)
397
+ response = mistral_nemo_tokenizer.decode(
398
+ outputs[0], skip_special_tokens=True
399
+ )
400
+ pbar.update(1)
401
+ save_object_to_redis(redis_key, response)
402
+ upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
403
+ except Exception as e:
404
+ print(f"Failed to generate Mistral Nemo response: {e}")
405
+ return None
406
+ return response
407
+
408
+
409
+ def generate_gpt2_xl(prompt):
410
+ redis_key = f"generated_gpt2_xl:{prompt}"
411
+ response = load_object_from_redis(redis_key)
412
+ if not response:
413
+ try:
414
+ with tqdm(total=1, desc="Generating GPT-2 XL response") as pbar:
415
+ inputs = gpt2_xl_tokenizer(prompt, return_tensors="pt")
416
+ outputs = gpt2_xl_model(**inputs)
417
+ response = gpt2_xl_tokenizer.decode(
418
+ outputs[0][0], skip_special_tokens=True
419
+ )
420
+ pbar.update(1)
421
+ save_object_to_redis(redis_key, response)
422
+ upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
423
+ except Exception as e:
424
+ print(f"Failed to generate GPT-2 XL response: {e}")
425
+ return None
426
+ return response
427
+
428
+
429
+ def answer_question_minicpm(image_bytes, question):
430
+ redis_key = f"minicpm_answer:{hash(image_bytes)}:{question}"
431
+ answer = load_object_from_redis(redis_key)
432
+ if not answer:
433
+ try:
434
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
435
+ with tqdm(total=1, desc="Answering question with MiniCPM") as pbar:
436
+ msgs = [{"role": "user", "content": [image, question]}]
437
+ answer = minicpm_model.chat(
438
+ image=None, msgs=msgs, tokenizer=minicpm_tokenizer
439
+ )
440
+ pbar.update(1)
441
+ save_object_to_redis(redis_key, answer)
442
+ upload_to_gcs(gcs_bucket_name, redis_key, answer.encode())
443
+ except Exception as e:
444
+ print(f"Failed to answer question with MiniCPM: {e}")
445
+ return None
446
+ return answer
447
+
448
+
449
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
450
 
451
+ text_to_image_pipeline = get_model_or_download(
452
+ "stabilityai/stable-diffusion-2", "text_to_image_model", StableDiffusionPipeline.from_pretrained
453
+ )
454
+ img2img_pipeline = get_model_or_download(
455
+ "CompVis/stable-diffusion-v1-4",
456
+ "img2img_model",
457
+ StableDiffusionImg2ImgPipeline.from_pretrained,
458
+ )
459
+ flux_pipeline = get_model_or_download(
460
+ "black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained
461
+ )
462
+ text_gen_pipeline = transformers_pipeline(
463
+ "text-generation", model="google/gemma-2-9b", tokenizer="google/gemma-2-9b"
464
+ )
465
+ music_gen = load_object_from_redis("music_gen") or musicgen.MusicGen.get_pretrained(
466
+ "melody"
467
+ ).to(device)
468
+ meta_llama_pipeline = get_model_or_download(
469
+ "meta-llama/Meta-Llama-3.1-8B-Instruct", "meta_llama_model", transformers_pipeline
470
+ )
471
+ starcoder_model = AutoModelForCausalLM.from_pretrained(
472
+ "bigcode/starcoder"
473
+ ).to(device)
474
  starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")
475
 
476
+ base = DiffusionPipeline.from_pretrained(
477
+ "stabilityai/stable-diffusion-xl-base-1.0",
478
+ torch_dtype=torch.float16,
479
+ variant="fp16",
480
+ use_safetensors=True,
481
+ ).to(device)
482
+ refiner = DiffusionPipeline.from_pretrained(
483
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
484
+ text_encoder_2=base.text_encoder_2,
485
+ vae=base.vae,
486
+ torch_dtype=torch.float16,
487
+ use_safetensors=True,
488
+ variant="fp16",
489
+ ).to(device)
490
+ music_gen_melody = musicgen.MusicGen.get_pretrained("melody").to(device)
491
+ music_gen_melody.set_generation_params(duration=8)
492
+ music_gen_large = musicgen.MusicGen.get_pretrained("large").to(device)
493
+ music_gen_large.set_generation_params(duration=8)
494
+ whisper_pipeline = transformers_pipeline(
495
+ "automatic-speech-recognition",
496
+ model="openai/whisper-small",
497
+ chunk_length_s=30,
498
+ device=device,
499
+ )
500
+ mistral_instruct_model = AutoModelForCausalLM.from_pretrained(
501
+ "mistralai/Mistral-Large-Instruct-2407",
502
+ torch_dtype=torch.bfloat16,
503
+ device_map="auto",
504
+ )
505
+ mistral_instruct_tokenizer = AutoTokenizer.from_pretrained(
506
+ "mistralai/Mistral-Large-Instruct-2407"
507
+ )
508
+ mistral_nemo_model = AutoModelForCausalLM.from_pretrained(
509
+ "mistralai/Mistral-Nemo-Instruct-2407",
510
+ torch_dtype=torch.bfloat16,
511
+ device_map="auto",
512
+ )
513
+ mistral_nemo_tokenizer = AutoTokenizer.from_pretrained(
514
+ "mistralai/Mistral-Nemo-Instruct-2407"
515
+ )
516
+ gpt2_xl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-xl")
517
+ gpt2_xl_model = GPT2Model.from_pretrained("gpt2-xl")
518
+ minicpm_model = AutoModel.from_pretrained(
519
+ "openbmb/MiniCPM-V-2_6",
520
+ trust_remote_code=True,
521
+ attn_implementation="sdpa",
522
+ torch_dtype=torch.bfloat16,
523
+ ).eval().cuda()
524
+ minicpm_tokenizer = AutoTokenizer.from_pretrained(
525
+ "openbmb/MiniCPM-V-2_6", trust_remote_code=True
526
+ )
527
+
528
+ tools = [] # Define any tools needed for Mistral models
529
+
530
+ gen_image_tab = gr.Interface(
531
+ fn=generate_image, inputs=gr.Textbox(label="Prompt:"), outputs=gr.Image(type="pil"), title="Generate Image"
532
+ )
533
+ edit_image_tab = gr.Interface(
534
+ fn=edit_image_with_prompt,
535
+ inputs=[
536
+ gr.Image(type="pil", label="Image:"),
537
+ gr.Textbox(label="Prompt:"),
538
+ gr.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:"),
539
+ ],
540
+ outputs=gr.Image(type="pil"),
541
+ title="Edit Image",
542
+ )
543
+ generate_song_tab = gr.Interface(
544
+ fn=generate_song,
545
+ inputs=[
546
+ gr.Textbox(label="Prompt:"),
547
+ gr.Slider(5, 60, 10, step=1, label="Duration (s):"),
548
+ ],
549
+ outputs=gr.Audio(type="numpy"),
550
+ title="Generate Songs",
551
+ )
552
+ generate_text_tab = gr.Interface(
553
+ fn=generate_text,
554
+ inputs=gr.Textbox(label="Prompt:"),
555
+ outputs=gr.Textbox(label="Generated Text:"),
556
+ title="Generate Text",
557
+ )
558
+ generate_flux_image_tab = gr.Interface(
559
+ fn=generate_flux_image,
560
+ inputs=gr.Textbox(label="Prompt:"),
561
+ outputs=gr.Image(type="pil"),
562
+ title="Generate FLUX Images",
563
+ )
564
+ generate_code_tab = gr.Interface(
565
+ fn=generate_code,
566
+ inputs=gr.Textbox(label="Prompt:"),
567
+ outputs=gr.Textbox(label="Generated Code:"),
568
+ title="Generate Code",
569
+ )
570
+ model_meta_llama_test_tab = gr.Interface(
571
+ fn=test_model_meta_llama,
572
+ inputs=None,
573
+ outputs=gr.Textbox(label="Model Output:"),
574
+ title="Test Meta-Llama",
575
+ )
576
+ generate_image_sdxl_tab = gr.Interface(
577
+ fn=generate_image_sdxl,
578
+ inputs=gr.Textbox(label="Prompt:"),
579
+ outputs=gr.Image(type="pil"),
580
+ title="Generate SDXL Image",
581
+ )
582
+ generate_musicgen_melody_tab = gr.Interface(
583
+ fn=generate_musicgen_melody,
584
+ inputs=gr.Textbox(label="Prompt:"),
585
+ outputs=gr.Audio(type="numpy"),
586
+ title="Generate MusicGen Melody",
587
+ )
588
+ generate_musicgen_large_tab = gr.Interface(
589
+ fn=generate_musicgen_large,
590
+ inputs=gr.Textbox(label="Prompt:"),
591
+ outputs=gr.Audio(type="numpy"),
592
+ title="Generate MusicGen Large",
593
+ )
594
+ transcribe_audio_tab = gr.Interface(
595
+ fn=transcribe_audio,
596
+ inputs=gr.Audio(type="numpy", label="Audio Sample:"),
597
+ outputs=gr.Textbox(label="Transcribed Text:"),
598
+ title="Transcribe Audio",
599
+ )
600
+ generate_mistral_instruct_tab = gr.Interface(
601
+ fn=generate_mistral_instruct,
602
+ inputs=gr.Textbox(label="Prompt:"),
603
+ outputs=gr.Textbox(label="Mistral Instruct Response:"),
604
+ title="Generate Mistral Instruct Response",
605
+ )
606
+ generate_mistral_nemo_tab = gr.Interface(
607
+ fn=generate_mistral_nemo,
608
+ inputs=gr.Textbox(label="Prompt:"),
609
+ outputs=gr.Textbox(label="Mistral Nemo Response:"),
610
+ title="Generate Mistral Nemo Response",
611
+ )
612
+ generate_gpt2_xl_tab = gr.Interface(
613
+ fn=generate_gpt2_xl,
614
+ inputs=gr.Textbox(label="Prompt:"),
615
+ outputs=gr.Textbox(label="GPT-2 XL Response:"),
616
+ title="Generate GPT-2 XL Response",
617
+ )
618
+ answer_question_minicpm_tab = gr.Interface(
619
+ fn=answer_question_minicpm,
620
+ inputs=[
621
+ gr.Image(type="pil", label="Image:"),
622
+ gr.Textbox(label="Question:"),
623
+ ],
624
+ outputs=gr.Textbox(label="MiniCPM Answer:"),
625
+ title="Answer Question with MiniCPM",
626
+ )
627
 
628
  app = gr.TabbedInterface(
629
+ [
630
+ gen_image_tab,
631
+ edit_image_tab,
632
+ generate_song_tab,
633
+ generate_text_tab,
634
+ generate_flux_image_tab,
635
+ generate_code_tab,
636
+ model_meta_llama_test_tab,
637
+ generate_image_sdxl_tab,
638
+ generate_musicgen_melody_tab,
639
+ generate_musicgen_large_tab,
640
+ transcribe_audio_tab,
641
+ generate_mistral_instruct_tab,
642
+ generate_mistral_nemo_tab,
643
+ generate_gpt2_xl_tab,
644
+ answer_question_minicpm_tab,
645
+ ],
646
+ [
647
+ "Generate Image",
648
+ "Edit Image",
649
+ "Generate Song",
650
+ "Generate Text",
651
+ "Generate FLUX Image",
652
+ "Generate Code",
653
+ "Test Meta-Llama",
654
+ "Generate SDXL Image",
655
+ "Generate MusicGen Melody",
656
+ "Generate MusicGen Large",
657
+ "Transcribe Audio",
658
+ "Generate Mistral Instruct Response",
659
+ "Generate Mistral Nemo Response",
660
+ "Generate GPT-2 XL Response",
661
+ "Answer Question with MiniCPM",
662
+ ],
663
  )
664
 
665
  app.launch(share=True)