File size: 8,153 Bytes
113dc2c
0dec378
 
27e4a6a
d4fba6d
4816388
 
2fc432b
4816388
6b3d1c3
d95dbe9
32fdddd
980ffaa
 
52a0784
6b3d1c3
27e4a6a
4816388
27e4a6a
6b3d1c3
 
 
 
 
 
 
 
f0f180b
 
 
 
 
6b3d1c3
f0f180b
 
 
 
 
1a52ee5
68ef0f8
f0f180b
 
 
d2a5152
f0f180b
c79e0ac
f0f180b
 
481dde5
6b3d1c3
 
 
 
 
 
 
 
 
 
980ffaa
6b3d1c3
 
 
 
 
 
 
 
 
980ffaa
6b3d1c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
980ffaa
 
 
 
6b3d1c3
980ffaa
6b3d1c3
 
 
99a5876
f0f180b
8c1a558
 
 
f0f180b
8c1a558
 
754753e
8c1a558
 
 
6b3d1c3
 
4816388
8c1a558
 
 
4816388
8c1a558
 
 
 
 
9ce8f90
8c1a558
 
 
077426a
53635c2
 
980ffaa
 
 
 
6b3d1c3
 
980ffaa
6b3d1c3
 
 
 
53635c2
 
 
 
 
 
 
 
6b3d1c3
 
980ffaa
6b3d1c3
 
 
 
 
 
 
 
53635c2
6b3d1c3
8c1a558
6b3d1c3
 
 
 
 
 
 
 
 
 
8c1a558
 
6b3d1c3
8c1a558
6b3d1c3
 
8c1a558
6b3d1c3
8c1a558
 
6b3d1c3
 
 
 
 
 
 
 
99a5876
 
6b3d1c3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import os
import numpy as np
import random
from pathlib import Path
from PIL import Image
import streamlit as st
from huggingface_hub import InferenceClient, AsyncInferenceClient
from gradio_client import Client, handle_file
import asyncio
from concurrent.futures import ThreadPoolExecutor

# Largest value of a signed 32-bit integer; used as the upper bound for random seeds.
MAX_SEED = np.iinfo(np.int32).max
# Access tokens read from the environment; each is None when its variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_TOKEN_UPSCALER = os.environ.get("HF_TOKEN_UPSCALER")
# Async inference client used for text-to-image requests (constructed with defaults).
client = AsyncInferenceClient()
# Synchronous inference client bound to the Mixtral instruct model, used to expand prompts.
llm_client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
# Directory holding generated images and prompt files; created at import time if missing.
DATA_PATH = Path("./data")
DATA_PATH.mkdir(exist_ok=True)

def run_async(func):
    """Run a blocking callable on a worker thread and return its result.

    A private event loop drives a single-worker thread pool. Both are torn
    down before returning, so repeated calls do not accumulate idle worker
    threads or unclosed event loops.

    Args:
        func: Zero-argument callable to execute.

    Returns:
        Whatever ``func`` returns.

    Raises:
        Exception: Any exception raised by ``func`` is propagated unchanged.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        # The context manager joins the worker thread on exit.
        with ThreadPoolExecutor(max_workers=1) as executor:
            return loop.run_until_complete(loop.run_in_executor(executor, func))
    finally:
        # Unset first so nothing later picks up a closed loop as the current one.
        asyncio.set_event_loop(None)
        loop.close()

async def generate_image(combined_prompt, model, width, height, scales, steps, seed):
    """Request one image from the text-to-image endpoint.

    Args:
        combined_prompt: Prompt text sent to the model.
        model: Model id forwarded to the inference client.
        width: Requested image width.
        height: Requested image height.
        scales: Guidance scale.
        steps: Number of inference steps.
        seed: Generation seed; ``-1`` picks a random one in ``[0, MAX_SEED]``.

    Returns:
        ``(image, seed)`` on success. On any failure, ``(message, None)``
        where ``message`` is a string starting with
        ``"Error al generar imagen"``.
    """
    try:
        if seed == -1:
            seed = random.randint(0, MAX_SEED)
        seed = int(seed)
        image = await client.text_to_image(
            prompt=combined_prompt,
            height=height,
            width=width,
            guidance_scale=scales,
            num_inference_steps=steps,
            model=model,
            # Forward the seed; otherwise the value returned to the caller
            # (and used in file names) has no relation to the image produced.
            seed=seed,
        )
        return image, seed
    except Exception as e:
        return f"Error al generar imagen: {e}", None

def get_upscale_finegrain(prompt, img_path, upscale_factor):
    """Upscale an image through the ``finegrain/finegrain-image-enhancer`` Space.

    Args:
        prompt: Prompt text forwarded to the Space.
        img_path: Path of the image to upload.
        upscale_factor: Upscale factor forwarded to the Space.

    Returns:
        The second element of the Space's result when the result is a list
        or tuple with more than one item; otherwise ``None``. Any failure is
        reported with a Streamlit warning and also yields ``None``.
    """
    try:
        # Named distinctly so it does not shadow the module-level inference client.
        upscaler = Client("finegrain/finegrain-image-enhancer", hf_token=HF_TOKEN_UPSCALER)
        result = upscaler.predict(
            input_image=handle_file(img_path), prompt=prompt, upscale_factor=upscale_factor
        )
    except Exception as e:
        # Upscaling stays best-effort, but the reason for the fallback is surfaced.
        st.warning(f"No se pudo escalar la imagen: {e}")
        return None

    # Accept tuples as well as lists for multi-valued results.
    if isinstance(result, (list, tuple)) and len(result) > 1:
        return result[1]
    return None

def save_prompt(prompt_text, seed):
    """Write the prompt text to ``prompt_<seed>.txt`` inside ``DATA_PATH``.

    Args:
        prompt_text: Text to store.
        seed: Seed used to build the file name.

    Returns:
        The path of the written file, or ``None`` after showing a Streamlit
        error when writing fails.
    """
    try:
        destination = DATA_PATH / f"prompt_{seed}.txt"
        with open(destination, "w") as handle:
            handle.write(prompt_text)
    except Exception as exc:
        st.error(f"Error al guardar el prompt: {exc}")
        return None
    return destination

async def gen(prompt, basemodel, width, height, scales, steps, seed, upscale_factor, process_upscale, process_enhancer):
    """Generate an image, store it with its prompt, and optionally upscale it.

    Args:
        prompt: User prompt.
        basemodel: Model id used for image generation.
        width: Requested image width.
        height: Requested image height.
        scales: Guidance scale.
        steps: Number of inference steps.
        seed: Generation seed; ``-1`` picks a random one.
        upscale_factor: Factor forwarded to the upscaler.
        process_upscale: When true, upscale the generated image.
        process_enhancer: When true, expand the prompt with the LLM first.

    Returns:
        ``[image_path, prompt_file_path]`` on success, where either item is
        ``None`` if the corresponding file could not be saved. When image
        generation fails, ``[error_message, None, combined_prompt]``.
    """
    combined_prompt = prompt
    if process_enhancer:
        improved_prompt = await improve_prompt(prompt)
        # improve_prompt reports failures as an error string; never feed that
        # text to the image model as if it were part of the prompt.
        if not improved_prompt.startswith("Error mejorando el prompt"):
            combined_prompt = f"{prompt} {improved_prompt}"

    if seed == -1:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)
    progress_bar = st.progress(0)
    image, seed = await generate_image(combined_prompt, basemodel, width, height, scales, steps, seed)
    progress_bar.progress(50)

    if isinstance(image, str) and image.startswith("Error"):
        progress_bar.empty()
        return [image, None, combined_prompt]

    image_path = save_image(image, seed)
    prompt_file_path = save_prompt(combined_prompt, seed)
    # Return None rather than the literal string "None" when the prompt was not saved.
    prompt_result = str(prompt_file_path) if prompt_file_path is not None else None

    if image_path is None:
        # save_image already reported the failure; there is nothing to upscale or show.
        progress_bar.empty()
        return [None, prompt_result]

    if process_upscale:
        upscale_image_path = get_upscale_finegrain(combined_prompt, image_path, upscale_factor)
        if upscale_image_path:
            upscaled_target = DATA_PATH / f"upscale_image_{seed}.jpg"
            with Image.open(upscale_image_path) as upscale_image:
                # JPEG cannot store alpha or palette modes, so normalise to RGB first.
                upscale_image.convert("RGB").save(upscaled_target, format="JPEG")
            progress_bar.progress(100)
            # The upscaled file replaces the original-resolution image.
            image_path.unlink()
            return [str(upscaled_target), prompt_result]
        progress_bar.empty()
        return [str(image_path), prompt_result]

    progress_bar.progress(100)
    return [str(image_path), prompt_result]

async def improve_prompt(prompt):
    """Ask the LLM to expand a short idea into a detailed txt2img prompt.

    Args:
        prompt: The user's idea.

    Returns:
        The generated text, stripped and capped at 300 characters. On any
        failure, a string starting with ``"Error mejorando el prompt"``.
    """
    try:
        instruction_en = "With this idea, describe in English a detailed txt2img prompt in 500 characters at most, add illumination, atmosphere, cinematic elements, and characters..."
        # The accented characters were mis-encoded in this literal, which sent garbage to the model.
        instruction_es = "Con esta idea, describe en español un prompt detallado de txt2img en un máximo de 500 caracteres, añadiendo iluminación, atmósfera, elementos cinematográficos y personajes..."
        formatted_prompt = f"{prompt}: {instruction_en} {instruction_es}"
        response = llm_client.text_generation(formatted_prompt, max_new_tokens=300)
        # A plain-string response must not be probed with `in`: that is a
        # substring test and would misfire if the text contained "generated_text".
        if isinstance(response, str):
            generated = response
        else:
            generated = response['generated_text']
        return generated.strip()[:300]
    except Exception as e:
        return f"Error mejorando el prompt: {e}"

def save_image(image, seed):
    """Save the image as ``image_<seed>.jpg`` inside ``DATA_PATH``.

    Args:
        image: Object exposing ``save(path, format=...)``.
        seed: Seed used to build the file name.

    Returns:
        The path of the saved file, or ``None`` after showing a Streamlit
        error when saving fails.
    """
    try:
        destination = DATA_PATH / f"image_{seed}.jpg"
        image.save(destination, format="JPEG")
    except Exception as exc:
        st.error(f"Error al guardar la imagen: {exc}")
        return None
    return destination

def get_storage():
    """List stored JPEG files, newest first, with a total-size label.

    Returns:
        A tuple ``(paths, label)``: the resolved paths of the ``*.jpg`` files
        in ``DATA_PATH`` sorted by modification time (most recent first), and
        a string reporting their combined size in GB.
    """
    images = sorted(
        (entry for entry in DATA_PATH.glob("*.jpg") if entry.is_file()),
        key=lambda entry: entry.stat().st_mtime,
        reverse=True,
    )
    total_bytes = 0
    for entry in images:
        total_bytes += entry.stat().st_size
    resolved = [str(entry.resolve()) for entry in images]
    label = f"Uso total: {total_bytes/(1024.0 ** 3):.3f}GB"
    return resolved, label

def get_prompts():
    """Map each stored prompt file to a key derived from its name.

    Returns:
        A dict whose keys are the stems of the ``*.txt`` files in
        ``DATA_PATH`` with ``"prompt_"`` removed, and whose values are the
        corresponding paths.
    """
    mapping = {}
    for entry in DATA_PATH.glob("*.txt"):
        if entry.is_file():
            mapping[entry.stem.replace("prompt_", "")] = entry
    return mapping

def delete_image(image_path):
    """Delete an image file and report the outcome in the Streamlit UI.

    Args:
        image_path: Path of the file to remove.

    Shows a success message when the file was removed, and an error message
    when it does not exist or removal fails.
    """
    try:
        target = Path(image_path)
        if not target.exists():
            st.error("El archivo de imagen no existe.")
            return
        target.unlink()
        st.success(f"Imagen {image_path} borrada.")
    except Exception as exc:
        st.error(f"Error al borrar la imagen: {exc}")

def _seed_from_stem(stem):
    """Return the trailing seed of a stored file stem (``image_<seed>`` or ``upscale_image_<seed>``)."""
    return stem.rsplit("_", 1)[-1]


def _show_generation_result(result):
    """Display the outcome of ``gen``: the image and its prompt, or the error message."""
    image_path = result[0]
    prompt_file = result[1]

    # gen() signals a failed generation with an error message in the first slot;
    # show it instead of treating it as a file path.
    if isinstance(image_path, str) and image_path.startswith("Error"):
        st.error(image_path)
        return

    st.write(f"Image paths: {image_path}")

    if not image_path:
        return

    if Path(image_path).exists():
        st.image(image_path, caption="Imagen Generada")
    else:
        st.error("El archivo de imagen no existe.")

    if prompt_file and Path(prompt_file).exists():
        prompt_text = Path(prompt_file).read_text()
        st.write(f"Prompt utilizado: {prompt_text}")
    else:
        st.write("El archivo del prompt no está disponible.")


def _render_gallery():
    """Render the stored images in six columns with their prompts and delete buttons."""
    files, usage = get_storage()
    st.text(usage)
    cols = st.columns(6)
    prompts = get_prompts()

    for idx, file in enumerate(files):
        with cols[idx % 6]:
            # Key on the trailing seed so upscaled files also find their prompt.
            prompt_file = prompts.get(_seed_from_stem(Path(file).stem))
            prompt_text = Path(prompt_file).read_text() if prompt_file else "No disponible"

            # Close the file handle as soon as the image has been handed to Streamlit.
            with Image.open(file) as image:
                st.image(image, caption=f"Imagen {idx+1}")
            st.write(f"Prompt: {prompt_text}")

            if st.button(f"Borrar Imagen {idx+1}", key=f"delete_{idx}"):
                try:
                    os.remove(file)
                    if prompt_file:
                        os.remove(prompt_file)
                    st.success(f"Imagen {idx+1} y su prompt fueron borrados.")
                except Exception as e:
                    st.error(f"Error al borrar la imagen o prompt: {e}")


def main():
    """Build the Streamlit page: sidebar controls, generation, and the image gallery."""
    st.set_page_config(layout="wide")
    st.title("FLUX with prompt enhancer and upscaler")

    prompt = st.sidebar.text_input("Descripción de la imagen", max_chars=200)
    process_enhancer = st.sidebar.checkbox("Mejorar Prompt", value=True)
    basemodel = st.sidebar.selectbox("Modelo Base", ["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-DEV"])
    format_option = st.sidebar.selectbox("Formato", ["9:16", "16:9"])
    process_upscale = st.sidebar.checkbox("Procesar Escalador", value=True)
    upscale_factor = st.sidebar.selectbox("Factor de Escala", [2, 4, 8], index=0)
    scales = st.sidebar.slider("Escalado", 1, 20, 10)
    steps = st.sidebar.slider("Pasos", 1, 100, 20)
    seed = st.sidebar.number_input("Semilla", value=-1)

    # Portrait for 9:16, landscape otherwise.
    if format_option == "9:16":
        width, height = 720, 1280
    else:
        width, height = 1280, 720

    if st.sidebar.button("Generar Imagen"):
        with st.spinner("Mejorando y generando imagen..."):
            result = asyncio.run(gen(prompt, basemodel, width, height, scales, steps, seed, upscale_factor, process_upscale, process_enhancer))
        _show_generation_result(result)

    _render_gallery()

# Entry point: build the page when this file is executed as the main module.
if __name__ == "__main__":
    main()