# Hugging Face Space (status: Running)
import os | |
import numpy as np | |
import random | |
from pathlib import Path | |
from PIL import Image | |
import streamlit as st | |
from huggingface_hub import InferenceClient, AsyncInferenceClient | |
from gradio_client import Client, handle_file | |
import asyncio | |
# Upper bound for randomly drawn seeds (int32 max).
MAX_SEED = np.iinfo(np.int32).max
# Tokens come from the environment; each is None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_TOKEN_UPSCALER = os.environ.get("HF_TOKEN_UPSCALER")
# Pass HF_TOKEN explicitly so the value read above is actually used. When it
# is None the clients fall back to their default token lookup, as before.
client = AsyncInferenceClient(token=HF_TOKEN)
llm_client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1", token=HF_TOKEN)
# Generated images (*.jpg) and their prompts (*.txt) are stored here.
DATA_PATH = Path("./data")
DATA_PATH.mkdir(parents=True, exist_ok=True)
def enable_lora(lora_add, basemodel):
    """Return the model id to run.

    The LoRA repo id wins when it is non-empty; otherwise fall back to the
    base model id.
    """
    return lora_add or basemodel
async def generate_image(combined_prompt, model, width, height, scales, steps, seed):
    """Generate one image from text through the Hugging Face Inference API.

    Args:
        combined_prompt: Full text prompt sent to the model.
        model: Model repo id (base model or LoRA repo) used for inference.
        width: Output width in pixels.
        height: Output height in pixels.
        scales: Guidance scale.
        steps: Number of inference steps.
        seed: Random seed; ``-1`` draws a random seed in ``[0, MAX_SEED]``.

    Returns:
        ``(image, seed)`` on success, where ``seed`` is the seed actually used.
        On any failure, ``(error_message, None)``. The message starts with
        ``"Error"``, which is how ``gen`` detects the failure.
    """
    try:
        if seed == -1:
            seed = random.randint(0, MAX_SEED)
        seed = int(seed)
        image = await client.text_to_image(
            prompt=combined_prompt,
            height=height,
            width=width,
            guidance_scale=scales,
            num_inference_steps=steps,
            model=model,
            # Forward the seed so a given seed reproduces the same image.
            # Previously it was only used to name the output files.
            seed=seed,
        )
        return image, seed
    except Exception as e:
        # Best-effort by design: report the failure as a string instead of
        # raising, because the caller checks for an "Error..." string.
        return f"Error al generar imagen: {e}", None
def get_upscale_finegrain(prompt, img_path, upscale_factor):
    """Upscale an image with the ``finegrain/finegrain-image-enhancer`` Space.

    Args:
        prompt: Text prompt forwarded to the enhancer.
        img_path: Local path of the image to upscale.
        upscale_factor: Upscale factor forwarded to the Space.

    Returns:
        The second output of the Space's ``/process`` endpoint (used by the
        caller as the upscaled image path). Returns ``None`` when the call
        fails (the error is shown in the Streamlit UI) or when the result has
        fewer than two outputs.
    """
    try:
        # Named ``upscaler`` so it does not shadow the module-level ``client``.
        upscaler = Client("finegrain/finegrain-image-enhancer", hf_token=HF_TOKEN_UPSCALER)
        result = upscaler.predict(
            input_image=handle_file(img_path), prompt=prompt, negative_prompt="",
            seed=42, upscale_factor=upscale_factor, controlnet_scale=0.6,
            controlnet_decay=1, condition_scale=6, tile_width=112,
            tile_height=144, denoise_strength=0.35, num_inference_steps=18,
            solver="DDIM", api_name="/process"
        )
    except Exception as e:
        st.error(f"Error en el escalado: {e}")
        return None
    # gradio_client returns a *tuple* for endpoints with several outputs. The
    # old list-only check made every successful upscale look like a failure.
    if isinstance(result, (list, tuple)) and len(result) > 1:
        return result[1]
    return None
async def gen(prompt, basemodel, width, height, scales, steps, seed, upscale_factor, process_upscale, lora_model, process_lora):
    """Run the full pipeline: improve the prompt, render, save, optionally upscale.

    Args:
        prompt: User prompt. It is expanded by ``improve_prompt`` and both
            texts are joined into the final prompt.
        basemodel: Base model repo id.
        width: Output width in pixels.
        height: Output height in pixels.
        scales: Guidance scale forwarded to ``generate_image``.
        steps: Number of inference steps.
        seed: Seed; ``-1`` draws a random one.
        upscale_factor: Factor passed to the finegrain upscaler.
        process_upscale: Whether to run the upscaler after generation.
        lora_model: LoRA repo id. It replaces ``basemodel`` when
            ``process_lora`` is true and ``lora_model`` is non-empty.
        process_lora: Whether to use ``lora_model``.

    Returns:
        On success, ``[image_path, prompt_path]`` as string paths inside
        ``DATA_PATH``. The image is the upscaled JPEG when upscaling succeeded.
        On generation failure, ``[error_message, None, combined_prompt]``.
    """
    model = enable_lora(lora_model, basemodel) if process_lora else basemodel
    improved_prompt = await improve_prompt(prompt)
    combined_prompt = f"{prompt} {improved_prompt}"
    if seed == -1:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    progress_bar = st.progress(0)
    image, seed = await generate_image(combined_prompt, model, width, height, scales, steps, seed)
    progress_bar.progress(50)
    if isinstance(image, str) and image.startswith("Error"):
        progress_bar.empty()
        return [image, None, combined_prompt]

    image_path = DATA_PATH / f"image_{seed}.jpg"
    # JPEG has no alpha channel; convert so RGBA/palette outputs don't raise.
    image.convert("RGB").save(image_path, format="JPEG")
    prompt_file_path = DATA_PATH / f"prompt_{seed}.txt"
    # Explicit UTF-8: prompts may contain non-ASCII text.
    prompt_file_path.write_text(combined_prompt, encoding="utf-8")

    if not process_upscale:
        progress_bar.progress(100)
        return [str(image_path), str(prompt_file_path)]

    upscale_image_path = get_upscale_finegrain(combined_prompt, image_path, upscale_factor)
    if not upscale_image_path:
        # Upscaling failed: fall back to the original image.
        progress_bar.empty()
        return [str(image_path), str(prompt_file_path)]

    upscaled_path = DATA_PATH / f"upscale_image_{seed}.jpg"
    # Context manager closes the file handle of the downloaded upscale.
    with Image.open(upscale_image_path) as upscale_image:
        upscale_image.convert("RGB").save(upscaled_path, format="JPEG")
    progress_bar.progress(100)
    # Keep only the upscaled version on disk.
    image_path.unlink()
    return [str(upscaled_path), str(prompt_file_path)]
async def improve_prompt(prompt):
    """Ask the LLM to expand *prompt* into a detailed txt2img prompt.

    Args:
        prompt: The user's short image idea.

    Returns:
        The generated text with surrounding whitespace removed, or ``""`` if
        the request fails (the error is shown in the Streamlit UI).
    """
    instruction = ("With this idea, describe in English a detailed txt2img prompt in a single paragraph of up to 200 characters maximum, developing atmosphere, characters, lighting, and cameras.")
    formatted_prompt = f"{prompt}: {instruction}"
    try:
        response = llm_client.text_generation(formatted_prompt, max_new_tokens=200)
    except Exception as e:
        st.error(f"Error mejorando el prompt: {e}")
        return ""
    # text_generation returns a plain str by default. The old
    # ``'generated_text' in response`` was a *substring* test on that str.
    # It then crashed indexing the str whenever the model emitted that word.
    if isinstance(response, str):
        return response.strip()
    if isinstance(response, dict):
        return str(response.get("generated_text", "")).strip()
    # Detailed responses expose the text as an attribute.
    return str(getattr(response, "generated_text", "")).strip()
def get_storage():
    """List stored JPEG images and report their combined disk usage.

    Returns:
        A pair ``(paths, usage_label)``. ``paths`` holds the absolute path
        strings of every regular ``*.jpg`` file in ``DATA_PATH``.
        ``usage_label`` is their total size in GB (1024**3 bytes).
    """
    jpg_entries = [entry for entry in DATA_PATH.glob("*.jpg") if entry.is_file()]
    resolved_names = [str(entry.resolve()) for entry in jpg_entries]
    total_bytes = sum(entry.stat().st_size for entry in jpg_entries)
    return resolved_names, f"Uso total: {total_bytes/(1024.0 ** 3):.3f}GB"
def get_prompts():
    """Map each saved prompt's key to its ``*.txt`` path in ``DATA_PATH``.

    The key is the file stem with ``"prompt_"`` removed, i.e. the seed for
    files named ``prompt_<seed>.txt``.
    """
    mapping = {}
    for txt_path in DATA_PATH.glob("*.txt"):
        if not txt_path.is_file():
            continue
        mapping[txt_path.stem.replace("prompt_", "")] = txt_path
    return mapping
def run_gen():
    """Run ``gen`` synchronously with the current sidebar settings.

    The improved prompt stored in ``st.session_state`` is used when present;
    otherwise the raw sidebar prompt is used.

    Returns:
        Whatever ``gen`` returns.
    """
    # NOTE(review): the stored improved prompt is not tied to the prompt it
    # was generated from. Editing the prompt afterwards still reuses it.
    prompt_to_use = st.session_state.get('improved_prompt', prompt)
    # asyncio.run creates a fresh event loop and closes it afterwards. The
    # previous new_event_loop() leaked one unclosed loop per generation.
    return asyncio.run(gen(prompt_to_use, basemodel, width, height, scales, steps, seed, upscale_factor, process_upscale, lora_model, process_lora))
st.set_page_config(layout="wide")

# Sidebar controls. These module-level values are read as globals by run_gen().
prompt = st.sidebar.text_input("Descripción de la imagen")
basemodel = st.sidebar.selectbox(
    "Modelo Base",
    # The selected id is sent verbatim to the Inference API as the model name,
    # so use the canonical repo casing ("FLUX.1-dev", not "FLUX.1-DEV").
    ["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-dev"],
)
lora_model = st.sidebar.selectbox(
    "LORA Realismo",
    ["Shakker-Labs/FLUX.1-dev-LoRA-add-details", "XLabs-AI/flux-RealismLora"],
)
format_option = st.sidebar.selectbox("Formato", ["9:16", "16:9"])
process_lora = st.sidebar.checkbox("Procesar LORA")
process_upscale = st.sidebar.checkbox("Procesar Escalador")

# Width/height slider ranges depend on the chosen aspect ratio.
if format_option == "9:16":
    width = st.sidebar.slider("Ancho", 512, 720, 720, step=8)
    height = st.sidebar.slider("Alto", 912, 1280, 1280, step=8)
else:
    width = st.sidebar.slider("Ancho", 512, 1280, 1280, step=8)
    height = st.sidebar.slider("Alto", 512, 720, 720, step=8)

upscale_factor = st.sidebar.selectbox("Factor de Escala", [2, 4, 8], index=0)
# Forwarded to the model as guidance_scale.
scales = st.sidebar.slider("Escalado", 1, 20, 10)
# Forwarded to the model as num_inference_steps.
steps = st.sidebar.slider("Pasos", 1, 100, 20)
# -1 requests a randomly drawn seed.
seed = st.sidebar.number_input("Semilla", value=-1)
if st.sidebar.button("Mejorar prompt"):
    # Run the async LLM call to completion. The result is kept in the session
    # so a later run_gen() uses it instead of the raw prompt.
    improved_prompt = asyncio.run(improve_prompt(prompt))
    st.session_state["improved_prompt"] = improved_prompt
    st.write(str(improved_prompt))
if st.sidebar.button("Generar Imagen"):
    with st.spinner("Generando imagen..."):
        result = run_gen()
        image_paths = result[0]
        prompt_file = result[1]
        st.write(f"Image paths: {image_paths}")
        # On failure gen() puts the error message (starting with "Error") in
        # the first slot instead of a path. Show it rather than the misleading
        # "file does not exist" message.
        if isinstance(image_paths, str) and image_paths.startswith("Error"):
            st.error(image_paths)
        elif image_paths:
            if Path(image_paths).exists():
                st.image(image_paths, caption="Imagen Generada")
            else:
                st.error("El archivo de imagen no existe.")
        if prompt_file and Path(prompt_file).exists():
            # Explicit UTF-8 to match how prompt files are written.
            prompt_text = Path(prompt_file).read_text(encoding="utf-8")
            st.write(f"Prompt utilizado: {prompt_text}")
        else:
            st.write("El archivo del prompt no está disponible.")
# Gallery of previously generated images, six per row.
files, usage = get_storage()
st.text(usage)
cols = st.columns(6)
prompts = get_prompts()
for idx, file in enumerate(files):
    with cols[idx % 6]:
        # Images are saved as image_<seed>.jpg or upscale_image_<seed>.jpg, and
        # prompts as prompt_<seed>.txt, so the seed is the last "_" field.
        # The old .replace("image_", "") turned "upscale_image_<seed>" into
        # "upscale_<seed>", so upscaled images never found their prompt.
        seed_key = Path(file).stem.rsplit("_", 1)[-1]
        prompt_file = prompts.get(seed_key)
        prompt_text = Path(prompt_file).read_text(encoding="utf-8") if prompt_file else "No disponible"
        # Close the file handle once Streamlit has consumed the image.
        with Image.open(file) as image:
            st.image(image, caption=f"Imagen {idx+1}")
        st.write(f"Prompt: {prompt_text}")