flux3 / app.py
salomonsky's picture
Update app.py
c05358b verified
raw
history blame
7.19 kB
import asyncio
import logging
import os
import random
from pathlib import Path

import numpy as np
import streamlit as st
from gradio_client import Client, handle_file
from huggingface_hub import InferenceClient, AsyncInferenceClient
from PIL import Image
# Upper bound for randomly drawn seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# API tokens come from the environment; either may be unset (None).
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_TOKEN_UPSCALER = os.environ.get("HF_TOKEN_UPSCALER")

# Pass the token explicitly instead of leaving HF_TOKEN unused; an empty
# string is normalised to None so the library falls back to its own default.
client = AsyncInferenceClient(token=HF_TOKEN or None)
llm_client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1", token=HF_TOKEN or None)

# Directory where generated images and their prompt files are stored.
DATA_PATH = Path("./data")
DATA_PATH.mkdir(exist_ok=True)
def enable_lora(lora_add, basemodel):
    """Return ``lora_add`` when it is truthy, otherwise ``basemodel``."""
    return lora_add or basemodel
async def generate_image(combined_prompt, model, width, height, scales, steps, seed):
    """Request one image from the module-level async inference ``client``.

    Args:
        combined_prompt: Text prompt sent to the model.
        model: Model identifier passed to ``text_to_image``.
        width, height: Requested image size.
        scales: Value passed as ``guidance_scale``.
        steps: Value passed as ``num_inference_steps``.
        seed: Seed to use; ``-1`` means "draw a random one in [0, MAX_SEED]".

    Returns:
        ``(image, seed)`` on success, where ``seed`` is the integer actually
        used. On any failure, ``(error_message, None)`` where the message
        starts with ``"Error al generar imagen:"``; this function does not
        raise.
    """
    try:
        if seed == -1:
            seed = random.randint(0, MAX_SEED)
        seed = int(seed)
        image = await client.text_to_image(
            prompt=combined_prompt,
            height=height,
            width=width,
            guidance_scale=scales,
            num_inference_steps=steps,
            model=model,
            # Forward the seed so the value we report is the one that shaped
            # the image; previously it was computed but never sent.
            seed=seed,
        )
        return image, seed
    except Exception as e:
        # Callers detect failure by the "Error" prefix of the first element.
        return f"Error al generar imagen: {e}", None
def get_upscale_finegrain(prompt, img_path, upscale_factor):
    """Upscale an image through the ``finegrain/finegrain-image-enhancer`` Space.

    Args:
        prompt: Prompt forwarded to the enhancer.
        img_path: Path of the image to upload (wrapped with ``handle_file``).
        upscale_factor: Value forwarded as ``upscale_factor``.

    Returns:
        The second element of the Space's result when the result is a
        sequence with more than one element, otherwise ``None``. ``None`` is
        also returned (and the error logged) when the remote call fails.
    """
    try:
        # Named distinctly so it does not shadow the module-level async client.
        upscaler = Client("finegrain/finegrain-image-enhancer", hf_token=HF_TOKEN_UPSCALER)
        result = upscaler.predict(
            input_image=handle_file(img_path), prompt=prompt, negative_prompt="",
            seed=42, upscale_factor=upscale_factor, controlnet_scale=0.6,
            controlnet_decay=1, condition_scale=6, tile_width=112,
            tile_height=144, denoise_strength=0.35, num_inference_steps=18,
            solver="DDIM", api_name="/process"
        )
    except Exception:
        # Upscaling is best-effort: keep returning None, but leave a trace
        # instead of discarding the failure silently.
        logging.getLogger(__name__).exception("Upscaling with finegrain-image-enhancer failed")
        return None

    # NOTE(review): presumably result is a (before, after) pair of file paths;
    # a tuple is accepted as well as a list in case the client returns one.
    if isinstance(result, (list, tuple)) and len(result) > 1:
        return result[1]
    return None
def _save_as_jpeg(img, path):
    """Save ``img`` to ``path`` as JPEG, converting modes JPEG cannot store."""
    if img.mode not in ("RGB", "L", "CMYK"):
        img = img.convert("RGB")
    img.save(path, format="JPEG")


async def gen(prompt, basemodel, width, height, scales, steps, seed, upscale_factor, process_upscale, lora_model, process_lora):
    """Generate an image for ``prompt``, store it in DATA_PATH and optionally upscale it.

    The prompt is first expanded with ``improve_prompt``; the image is then
    generated, saved as ``image_<seed>.jpg`` next to ``prompt_<seed>.txt``,
    and, when ``process_upscale`` is set, replaced by
    ``upscale_image_<seed>.jpg`` if upscaling succeeds.

    Returns:
        ``[image_path, prompt_file_path]`` (both ``str``) on success, or
        ``[error_message, None, combined_prompt]`` when generation fails.
    """
    model = enable_lora(lora_model, basemodel) if process_lora else basemodel

    improved_prompt = await improve_prompt(prompt)
    # improve_prompt reports failures as a string instead of raising; do not
    # feed that error text to the image model as if it were prompt content.
    if isinstance(improved_prompt, str) and improved_prompt.startswith("Error mejorando el prompt"):
        combined_prompt = prompt
    else:
        combined_prompt = f"{prompt} {improved_prompt}"

    if seed == -1:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    progress_bar = st.progress(0)
    image, seed = await generate_image(combined_prompt, model, width, height, scales, steps, seed)
    progress_bar.progress(50)

    # generate_image signals failure with an "Error..." string in place of the image.
    if isinstance(image, str) and image.startswith("Error"):
        progress_bar.empty()
        return [image, None, combined_prompt]

    image_path = DATA_PATH / f"image_{seed}.jpg"
    _save_as_jpeg(image, image_path)

    prompt_file_path = DATA_PATH / f"prompt_{seed}.txt"
    with open(prompt_file_path, "w") as prompt_file:
        prompt_file.write(combined_prompt)

    if not process_upscale:
        progress_bar.progress(100)
        return [str(image_path), str(prompt_file_path)]

    upscale_image_path = get_upscale_finegrain(combined_prompt, image_path, upscale_factor)
    if not upscale_image_path:
        # Upscaling failed: fall back to the original image.
        progress_bar.empty()
        return [str(image_path), str(prompt_file_path)]

    upscaled_target = DATA_PATH / f"upscale_image_{seed}.jpg"
    # Context manager releases the file handle once the copy is written.
    with Image.open(upscale_image_path) as upscale_image:
        _save_as_jpeg(upscale_image, upscaled_target)
    progress_bar.progress(100)
    # The upscaled file supersedes the original render.
    image_path.unlink()
    return [str(upscaled_target), str(prompt_file_path)]
async def improve_prompt(prompt):
    """Ask the module-level ``llm_client`` to expand ``prompt`` into a txt2img prompt.

    Returns:
        The generated text with surrounding whitespace stripped. On any
        failure, a string starting with ``"Error mejorando el prompt:"`` is
        returned instead; this function does not raise.

    NOTE(review): ``llm_client.text_generation`` is called synchronously, so
    this coroutine blocks the event loop for the duration of the request.
    """
    try:
        instruction = ("With this idea, describe in English a detailed txt2img prompt in a single paragraph of up to 200 characters maximum, developing atmosphere, characters, lighting, and cameras.")
        formatted_prompt = f"{prompt}: {instruction}"
        response = llm_client.text_generation(formatted_prompt, max_new_tokens=200)

        # Dispatch on the response type. The previous `'generated_text' in
        # response` check was a substring test when the response is a plain
        # string, and then failed on `response['generated_text']`.
        if isinstance(response, str):
            improved_text = response
        elif isinstance(response, dict):
            improved_text = response["generated_text"]
        else:
            improved_text = response.generated_text
        return improved_text.strip()
    except Exception as e:
        return f"Error mejorando el prompt: {e}"
def get_storage():
    """List the stored JPEG images and report their total size.

    Returns:
        A tuple ``(paths, usage)`` where ``paths`` is the list of resolved
        paths (as ``str``) of every ``*.jpg`` file in DATA_PATH, sorted by
        file name, and ``usage`` is a ``"Uso total: <x>GB"`` summary string.
    """
    names = []
    total_bytes = 0
    # glob order depends on the filesystem; sorting keeps the gallery
    # numbering stable from one rerun to the next.
    for file in sorted(DATA_PATH.glob("*.jpg")):
        if file.is_file():
            names.append(str(file.resolve()))
            total_bytes += file.stat().st_size
    return names, f"Uso total: {total_bytes/(1024.0 ** 3):.3f}GB"
def get_prompts():
    """Index the ``*.txt`` files in DATA_PATH by identifier.

    Returns:
        A dict mapping each file's stem, with every ``"prompt_"`` occurrence
        removed, to the file's ``Path``.
    """
    prompts_by_key = {}
    for txt_path in DATA_PATH.glob("*.txt"):
        if not txt_path.is_file():
            continue
        key = txt_path.stem.replace("prompt_", "")
        prompts_by_key[key] = txt_path
    return prompts_by_key
def run_gen():
    """Run ``gen`` to completion on a private event loop and return its result.

    Uses the improved prompt stored in ``st.session_state`` when present,
    otherwise the raw sidebar ``prompt``; every other argument is read from
    the module-level sidebar values.
    """
    prompt_to_use = st.session_state.get('improved_prompt', prompt)
    loop = asyncio.new_event_loop()
    try:
        asyncio.set_event_loop(loop)
        return loop.run_until_complete(gen(prompt_to_use, basemodel, width, height, scales, steps, seed, upscale_factor, process_upscale, lora_model, process_lora))
    finally:
        # A new loop is created on every call; close it so loops do not
        # accumulate, and do not leave a closed loop installed as current.
        asyncio.set_event_loop(None)
        loop.close()
# ---------------------------------------------------------------------------
# Streamlit page: sidebar controls, the two action buttons, and the gallery.
# ---------------------------------------------------------------------------
st.set_page_config(layout="wide")
st.title("Generador de Imágenes FLUX y Escalador con IA")

# Sidebar inputs. These module-level names are read by run_gen().
prompt = st.sidebar.text_input("Descripción de la imagen")
basemodel = st.sidebar.selectbox("Modelo Base", ["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-DEV"])
lora_model = st.sidebar.selectbox("LORA Realismo", ["Shakker-Labs/FLUX.1-dev-LoRA-add-details", "XLabs-AI/flux-RealismLora"])
format_option = st.sidebar.selectbox("Formato", ["9:16", "16:9"])
process_lora = st.sidebar.checkbox("Procesar LORA")
process_upscale = st.sidebar.checkbox("Procesar Escalador")

# Slider ranges depend on the chosen aspect ratio (portrait vs. landscape).
if format_option == "9:16":
    width = st.sidebar.slider("Ancho", 512, 720, 720, step=8)
    height = st.sidebar.slider("Alto", 912, 1280, 1280, step=8)
else:
    width = st.sidebar.slider("Ancho", 512, 1280, 1280, step=8)
    height = st.sidebar.slider("Alto", 512, 720, 720, step=8)

upscale_factor = st.sidebar.selectbox("Factor de Escala", [2, 4, 8], index=0)
scales = st.sidebar.slider("Escalado", 1, 20, 10)
steps = st.sidebar.slider("Pasos", 1, 100, 20)
seed = st.sidebar.number_input("Semilla", value=-1)

if st.sidebar.button("Mejorar prompt"):
    improved_prompt = asyncio.run(improve_prompt(prompt))
    if improved_prompt.startswith("Error mejorando el prompt"):
        # improve_prompt returns its failure as text; do not store it, or
        # run_gen() would later use the error message as the prompt.
        st.error(improved_prompt)
    else:
        st.session_state.improved_prompt = improved_prompt
        st.write(f"{improved_prompt}")

if st.sidebar.button("Generar Imagen"):
    with st.spinner("Generando imagen..."):
        result = run_gen()
        image_paths = result[0]
        prompt_file = result[1]

        if isinstance(image_paths, str) and image_paths.startswith("Error"):
            # gen() puts the error message in slot 0 on failure; show it
            # rather than probing the filesystem with it as a path.
            st.error(image_paths)
        else:
            st.write(f"Image paths: {image_paths}")
            if image_paths:
                if Path(image_paths).exists():
                    st.image(image_paths, caption="Imagen Generada")
                else:
                    st.error("El archivo de imagen no existe.")

        if prompt_file and Path(prompt_file).exists():
            prompt_text = Path(prompt_file).read_text()
            st.write(f"Prompt utilizado: {prompt_text}")
        else:
            st.write("El archivo del prompt no está disponible.")

# Gallery of every stored image, laid out in six columns.
files, usage = get_storage()
st.text(usage)
cols = st.columns(6)
prompts = get_prompts()
for idx, file in enumerate(files):
    with cols[idx % 6]:
        # Images are saved as image_<seed>.jpg or upscale_image_<seed>.jpg,
        # while prompts are keyed by <seed>; strip either prefix so that
        # upscaled images also find their prompt.
        seed_key = Path(file).stem.replace("upscale_image_", "").replace("image_", "")
        prompt_file = prompts.get(seed_key, None)
        prompt_text = Path(prompt_file).read_text() if prompt_file else "No disponible"
        # Pass the path directly so no PIL file handle is left open per image.
        st.image(file, caption=f"Imagen {idx+1}")
        st.write(f"Prompt: {prompt_text}")