Ffftdtd5dtft commited on
Commit
ae48414
·
verified ·
1 Parent(s): 7d31abf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -65
app.py CHANGED
@@ -16,7 +16,7 @@ import time
16
  # Obtener las variables de entorno
17
  hf_token = os.getenv("HF_TOKEN")
18
  redis_host = os.getenv("REDIS_HOST")
19
- redis_port = os.getenv("REDIS_PORT")
20
  redis_password = os.getenv("REDIS_PASSWORD")
21
 
22
  HfFolder.save_token(hf_token)
@@ -65,88 +65,123 @@ def get_model_or_download(model_id, redis_key, loader_func):
65
  if model:
66
  print(f"Model loaded from Redis: {redis_key}")
67
  return model
68
- model = loader_func(model_id, torch_dtype=torch.float16)
69
- save_object_to_redis(redis_key, model)
70
- print(f"Model downloaded and saved to Redis: {redis_key}")
 
 
 
71
  return model
72
 
73
  def generate_image(prompt):
74
  redis_key = f"generated_image_{prompt}"
75
  image = load_object_from_redis(redis_key)
76
  if not image:
77
- image = text_to_image_pipeline(prompt).images[0]
78
- save_object_to_redis(redis_key, image)
 
 
 
 
79
  return image
80
 
81
  def edit_image_with_prompt(image, prompt, strength=0.75):
82
  redis_key = f"edited_image_{prompt}_{strength}"
83
  edited_image = load_object_from_redis(redis_key)
84
  if not edited_image:
85
- edited_image = img2img_pipeline(prompt=prompt, init_image=image.convert("RGB"), strength=strength).images[0]
86
- save_object_to_redis(redis_key, edited_image)
 
 
 
 
87
  return edited_image
88
 
89
  def generate_song(prompt, duration=10):
90
  redis_key = f"generated_song_{prompt}_{duration}"
91
  song = load_object_from_redis(redis_key)
92
  if not song:
93
- song = music_gen.generate(prompt, duration=duration)
94
- save_object_to_redis(redis_key, song)
 
 
 
 
95
  return song
96
 
97
  def generate_text(prompt):
98
  redis_key = f"generated_text_{prompt}"
99
  text = load_object_from_redis(redis_key)
100
  if not text:
101
- text = text_gen_pipeline([{"role": "user", "content": prompt}], max_new_tokens=256)[0]["generated_text"].strip()
102
- save_object_to_redis(redis_key, text)
 
 
 
 
103
  return text
104
 
105
  def generate_flux_image(prompt):
106
  redis_key = f"generated_flux_image_{prompt}"
107
  flux_image = load_object_from_redis(redis_key)
108
  if not flux_image:
109
- flux_image = flux_pipeline(
110
- prompt,
111
- guidance_scale=0.0,
112
- num_inference_steps=4,
113
- max_sequence_length=256,
114
- generator=torch.Generator("cpu").manual_seed(0)
115
- ).images[0]
116
- save_object_to_redis(redis_key, flux_image)
 
 
 
 
117
  return flux_image
118
 
119
  def generate_code(prompt):
120
  redis_key = f"generated_code_{prompt}"
121
  code = load_object_from_redis(redis_key)
122
  if not code:
123
- inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to("cuda")
124
- outputs = starcoder_model.generate(inputs)
125
- code = starcoder_tokenizer.decode(outputs[0])
126
- save_object_to_redis(redis_key, code)
 
 
 
 
127
  return code
128
 
129
  def generate_video(prompt):
130
  redis_key = f"generated_video_{prompt}"
131
  video = load_object_from_redis(redis_key)
132
  if not video:
133
- pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16)
134
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
135
- pipe.enable_model_cpu_offload()
136
- video = export_to_video(pipe(prompt, num_inference_steps=25).frames)
137
- save_object_to_redis(redis_key, video)
 
 
 
 
138
  return video
139
 
140
  def test_model_meta_llama():
141
  redis_key = "meta_llama_test_response"
142
  response = load_object_from_redis(redis_key)
143
  if not response:
144
- messages = [
145
- {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
146
- {"role": "user", "content": "Who are you?"}
147
- ]
148
- response = meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"].strip()
149
- save_object_to_redis(redis_key, response)
 
 
 
 
150
  return response
151
 
152
  def train_model(model, dataset, epochs, batch_size, learning_rate):
@@ -158,9 +193,12 @@ def train_model(model, dataset, epochs, batch_size, learning_rate):
158
  learning_rate=learning_rate,
159
  )
160
  trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
161
- trainer.train()
162
- save_object_to_redis("trained_model", model)
163
- save_object_to_redis("training_results", output_dir.getvalue())
 
 
 
164
 
165
  def run_task(task_queue):
166
  while True:
@@ -168,7 +206,10 @@ def run_task(task_queue):
168
  if task is None:
169
  break
170
  func, args, kwargs = task
171
- func(*args, **kwargs)
 
 
 
172
 
173
  task_queue = multiprocessing.Queue()
174
  num_processes = multiprocessing.cpu_count()
@@ -179,33 +220,16 @@ for _ in range(num_processes):
179
  p.start()
180
  processes.append(p)
181
 
182
- device = "cuda" if torch.cuda.is_available() else "cpu"
183
  text_to_image_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "text_to_image_model", StableDiffusionPipeline.from_pretrained).to(device)
184
- img2img_pipeline = get_model_or_download("runwayml/stable-diffusion-inpainting", "img2img_model", StableDiffusionImg2ImgPipeline.from_pretrained).to(device)
185
- flux_pipeline = get_model_or_download("black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained)
186
- flux_pipeline.enable_model_cpu_offload()
187
- music_gen = load_object_from_redis("music_gen") or MusicGen.get_pretrained('melody', use_auth_token=hf_token)
188
- save_object_to_redis("music_gen", music_gen)
189
- text_gen_pipeline = load_object_from_redis("text_gen_pipeline") or transformers_pipeline(
190
- "text-generation",
191
- model="google/gemini-2-2b-it",
192
- model_kwargs={"torch_dtype": torch.bfloat16},
193
- device=device,
194
- use_auth_token=hf_token,
195
- )
196
- save_object_to_redis("text_gen_pipeline", text_gen_pipeline)
197
- starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-15b", use_auth_token=hf_token)
198
- starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder2-15b", device_map="auto", torch_dtype=torch.bfloat16, use_auth_token=hf_token)
199
- meta_llama_pipeline = transformers_pipeline(
200
- "text-generation",
201
- model="meta-llama/Meta-Llama-3.1-8B-Instruct",
202
- model_kwargs={"torch_dtype": torch.bfloat16},
203
- device_map="auto",
204
- use_auth_token=hf_token
205
- )
206
-
207
- gen_image_tab = gr.Interface(generate_image, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Image(type="pil"), title="Generate Images")
208
- edit_image_tab = gr.Interface(edit_image_with_prompt, [gr.inputs.Image(type="pil", label="Image:"), gr.inputs.Textbox(label="Prompt:"), gr.inputs.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:")], gr.outputs.Image(type="pil"), title="Edit Images")
209
  generate_song_tab = gr.Interface(generate_song, [gr.inputs.Textbox(label="Prompt:"), gr.inputs.Slider(5, 60, 10, step=1, label="Duration (s):")], gr.outputs.Audio(type="numpy"), title="Generate Songs")
210
  generate_text_tab = gr.Interface(generate_text, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Textbox(label="Generated Text:"), title="Generate Text")
211
  generate_flux_image_tab = gr.Interface(generate_flux_image, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Image(type="pil"), title="Generate FLUX Images")
 
16
  # Obtener las variables de entorno
17
  hf_token = os.getenv("HF_TOKEN")
18
  redis_host = os.getenv("REDIS_HOST")
19
+ redis_port = int(os.getenv("REDIS_PORT", 6379)) # Valor predeterminado si no se proporciona
20
  redis_password = os.getenv("REDIS_PASSWORD")
21
 
22
  HfFolder.save_token(hf_token)
 
65
  if model:
66
  print(f"Model loaded from Redis: {redis_key}")
67
  return model
68
+ try:
69
+ model = loader_func(model_id, torch_dtype=torch.float16)
70
+ save_object_to_redis(redis_key, model)
71
+ print(f"Model downloaded and saved to Redis: {redis_key}")
72
+ except Exception as e:
73
+ print(f"Failed to load or save model: {e}")
74
  return model
75
 
76
  def generate_image(prompt):
77
  redis_key = f"generated_image_{prompt}"
78
  image = load_object_from_redis(redis_key)
79
  if not image:
80
+ try:
81
+ image = text_to_image_pipeline(prompt).images[0]
82
+ save_object_to_redis(redis_key, image)
83
+ except Exception as e:
84
+ print(f"Failed to generate image: {e}")
85
+ return None
86
  return image
87
 
88
  def edit_image_with_prompt(image, prompt, strength=0.75):
89
  redis_key = f"edited_image_{prompt}_{strength}"
90
  edited_image = load_object_from_redis(redis_key)
91
  if not edited_image:
92
+ try:
93
+ edited_image = img2img_pipeline(prompt=prompt, init_image=image.convert("RGB"), strength=strength).images[0]
94
+ save_object_to_redis(redis_key, edited_image)
95
+ except Exception as e:
96
+ print(f"Failed to edit image: {e}")
97
+ return None
98
  return edited_image
99
 
100
  def generate_song(prompt, duration=10):
101
  redis_key = f"generated_song_{prompt}_{duration}"
102
  song = load_object_from_redis(redis_key)
103
  if not song:
104
+ try:
105
+ song = music_gen.generate(prompt, duration=duration)
106
+ save_object_to_redis(redis_key, song)
107
+ except Exception as e:
108
+ print(f"Failed to generate song: {e}")
109
+ return None
110
  return song
111
 
112
  def generate_text(prompt):
113
  redis_key = f"generated_text_{prompt}"
114
  text = load_object_from_redis(redis_key)
115
  if not text:
116
+ try:
117
+ text = text_gen_pipeline([{"role": "user", "content": prompt}], max_new_tokens=256)[0]["generated_text"].strip()
118
+ save_object_to_redis(redis_key, text)
119
+ except Exception as e:
120
+ print(f"Failed to generate text: {e}")
121
+ return None
122
  return text
123
 
124
  def generate_flux_image(prompt):
125
  redis_key = f"generated_flux_image_{prompt}"
126
  flux_image = load_object_from_redis(redis_key)
127
  if not flux_image:
128
+ try:
129
+ flux_image = flux_pipeline(
130
+ prompt,
131
+ guidance_scale=0.0,
132
+ num_inference_steps=4,
133
+ max_sequence_length=256,
134
+ generator=torch.Generator("cpu").manual_seed(0)
135
+ ).images[0]
136
+ save_object_to_redis(redis_key, flux_image)
137
+ except Exception as e:
138
+ print(f"Failed to generate flux image: {e}")
139
+ return None
140
  return flux_image
141
 
142
  def generate_code(prompt):
143
  redis_key = f"generated_code_{prompt}"
144
  code = load_object_from_redis(redis_key)
145
  if not code:
146
+ try:
147
+ inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to("cuda")
148
+ outputs = starcoder_model.generate(inputs)
149
+ code = starcoder_tokenizer.decode(outputs[0])
150
+ save_object_to_redis(redis_key, code)
151
+ except Exception as e:
152
+ print(f"Failed to generate code: {e}")
153
+ return None
154
  return code
155
 
156
  def generate_video(prompt):
157
  redis_key = f"generated_video_{prompt}"
158
  video = load_object_from_redis(redis_key)
159
  if not video:
160
+ try:
161
+ pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16)
162
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
163
+ pipe.enable_model_cpu_offload()
164
+ video = export_to_video(pipe(prompt, num_inference_steps=25).frames)
165
+ save_object_to_redis(redis_key, video)
166
+ except Exception as e:
167
+ print(f"Failed to generate video: {e}")
168
+ return None
169
  return video
170
 
171
  def test_model_meta_llama():
172
  redis_key = "meta_llama_test_response"
173
  response = load_object_from_redis(redis_key)
174
  if not response:
175
+ try:
176
+ messages = [
177
+ {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
178
+ {"role": "user", "content": "Who are you?"}
179
+ ]
180
+ response = meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"].strip()
181
+ save_object_to_redis(redis_key, response)
182
+ except Exception as e:
183
+ print(f"Failed to test Meta-Llama: {e}")
184
+ return None
185
  return response
186
 
187
  def train_model(model, dataset, epochs, batch_size, learning_rate):
 
193
  learning_rate=learning_rate,
194
  )
195
  trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
196
+ try:
197
+ trainer.train()
198
+ save_object_to_redis("trained_model", model)
199
+ save_object_to_redis("training_results", output_dir.getvalue())
200
+ except Exception as e:
201
+ print(f"Failed to train model: {e}")
202
 
203
  def run_task(task_queue):
204
  while True:
 
206
  if task is None:
207
  break
208
  func, args, kwargs = task
209
+ try:
210
+ func(*args, **kwargs)
211
+ except Exception as e:
212
+ print(f"Failed to run task: {e}")
213
 
214
  task_queue = multiprocessing.Queue()
215
  num_processes = multiprocessing.cpu_count()
 
220
  p.start()
221
  processes.append(p)
222
 
223
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
224
  text_to_image_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "text_to_image_model", StableDiffusionPipeline.from_pretrained).to(device)
225
+ img2img_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "img2img_model", StableDiffusionImg2ImgPipeline.from_pretrained).to(device)
226
+ flux_pipeline = get_model_or_download("CompVis/stable-diffusion-flux", "flux_model", FluxPipeline.from_pretrained).to(device)
227
+ text_gen_pipeline = transformers_pipeline("text-generation", model="bigcode/starcoder", tokenizer="bigcode/starcoder", device=0)
228
+ music_gen = load_object_from_redis("music_gen") or MusicGen.from_pretrained('melody')
229
+ meta_llama_pipeline = get_model_or_download("meta/meta-llama-7b", "meta_llama_model", transformers_pipeline)
230
+
231
+ gen_image_tab = gr.Interface(generate_image, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Image(type="pil"), title="Generate Image")
232
+ edit_image_tab = gr.Interface(edit_image_with_prompt, [gr.inputs.Image(type="pil", label="Image:"), gr.inputs.Textbox(label="Prompt:"), gr.inputs.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:")], gr.outputs.Image(type="pil"), title="Edit Image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  generate_song_tab = gr.Interface(generate_song, [gr.inputs.Textbox(label="Prompt:"), gr.inputs.Slider(5, 60, 10, step=1, label="Duration (s):")], gr.outputs.Audio(type="numpy"), title="Generate Songs")
234
  generate_text_tab = gr.Interface(generate_text, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Textbox(label="Generated Text:"), title="Generate Text")
235
  generate_flux_image_tab = gr.Interface(generate_flux_image, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Image(type="pil"), title="Generate FLUX Images")