aifeifei798's picture
Update app.py
2c166e9 verified
raw
history blame
8.58 kB
import base64
import os
from mistralai import Mistral
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
from huggingface_hub import hf_hub_download
from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
from openai import OpenAI
import config
# Mistral client; predict() uses it for image chat with "pixtral-12b-2409".
api_key = os.environ.get("MISTRAL_API_KEY")
client = Mistral(api_key=api_key)

# OpenAI-compatible client aimed at the Hugging Face Inference API endpoint
# below, authenticated with the HF_TOKEN environment variable; predict() uses
# it for text-only chat with the model chosen in the UI dropdown.
client_more_ai = OpenAI(
    api_key=os.environ.get('HF_TOKEN'),
    base_url="https://api-inference.huggingface.co/v1/",
)
# Load FLUX.1-schnell in bfloat16, on CUDA when available, otherwise on CPU.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tiny autoencoder (TAEF1) that the pipeline below uses as its VAE.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
# NOTE(review): good_vae is not referenced anywhere else in this file, yet it is
# still downloaded and moved to `device`. Confirm nothing external relies on it
# before removing it to save memory.
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="vae", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype, vae=taef1).to(device)

# LoRA adapters from one Hub repo. Each adapter's weights live in
# "<adapter name>.safetensors", and the dict value is its blend weight.
_LORA_REPO = "aifeifei798/feifei-flux-lora-v1"
_LORA_ADAPTERS = {
    "feifei": 0.65,
    "FLUX-dev-lora-add_details": 0.35,
    "Shadow-Projection": 0.35,
}
for _adapter_name in _LORA_ADAPTERS:
    pipe.load_lora_weights(
        hf_hub_download(_LORA_REPO, f"{_adapter_name}.safetensors"),
        adapter_name=_adapter_name,
    )
pipe.set_adapters(list(_LORA_ADAPTERS), adapter_weights=list(_LORA_ADAPTERS.values()))
# The keyword is `adapter_names` (plural). A misspelled keyword would be
# swallowed by fuse_lora's **kwargs and silently ignored.
pipe.fuse_lora(adapter_names=list(_LORA_ADAPTERS), lora_scale=1.0)
# With the adapters fused into the base weights, drop the separate LoRA layers
# and release cached GPU memory.
pipe.unload_lora_weights()
torch.cuda.empty_cache()
# Upper bound for seeds: the seed slider's maximum and the range infer() draws
# random seeds from. Equals the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Upper bound, in pixels, for the width and height sliders.
MAX_IMAGE_SIZE = 2048

# Custom CSS handed to gr.Blocks; "#col-container" matches the elem_id of the
# chat column.
css = """
#col-container {
width: auto;
height: 750px;
}
"""
def _build_prompt(prompt, quality_select, styles_Radio, style_list):
    """Return *prompt* with optional quality tags and selected style text appended.

    Args:
        prompt: Base prompt text entered by the user.
        quality_select: If truthy, append a fixed list of quality tags.
        styles_Radio: Selected style names (may be None or empty).
        style_list: Style definitions, each a mapping with "name" and "prompt"
            keys (infer() passes ``config.style_list``).

    Returns:
        The augmented prompt string.
    """
    if quality_select:
        prompt += ", masterpiece, best quality, very aesthetic, absurdres"
    for style_name in styles_Radio or ():
        for style in style_list:
            if style["name"] == style_name:
                # The template's "{prompt}" placeholder is dropped and the rest is
                # appended as-is. NOTE(review): no separator is inserted, so the
                # templates presumably start with "{prompt}" — confirm in config.
                prompt += style["prompt"].replace("{prompt}", "")
    return prompt


@spaces.GPU()
def infer(prompt, quality_select, styles_Radio, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True), guidance_scale=3.5):
    """Generate one image with the fused FLUX pipeline.

    Args:
        prompt: User prompt text.
        quality_select: Append quality tags to the prompt when truthy.
        styles_Radio: Selected style names whose templates are appended.
        seed: RNG seed, ignored when *randomize_seed* is true.
        randomize_seed: Draw a fresh seed from ``[0, MAX_SEED]``.
        width, height: Output size in pixels.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker. Gradio injects it by parameter
            position, so its place in the signature must not change.
        guidance_scale: Forwarded unchanged to the pipeline.

    Returns:
        ``(image, seed)``: the generated PIL image and the integer seed used,
        so the UI can display the seed.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders / API clients may deliver floats. torch.Generator.manual_seed
    # and the diffusers step schedule require ints.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)

    full_prompt = _build_prompt(prompt, quality_select, styles_Radio, config.style_list)

    image = pipe(
        prompt=full_prompt,
        width=int(width),
        height=int(height),
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        guidance_scale=guidance_scale,
        output_type="pil",
    ).images[0]
    return image, seed
def encode_image(image_path):
    """Read the file at *image_path* and return its bytes as a base64 string.

    Best effort: on any failure an error is printed and None is returned, so
    callers must check for None.
    """
    try:
        with open(image_path, "rb") as handle:
            raw_bytes = handle.read()
        encoded = base64.b64encode(raw_bytes).decode("utf-8")
    except FileNotFoundError:
        print(f"Error: The file {image_path} was not found.")
        return None
    except Exception as exc:  # any other read/encode failure is reported, not raised
        print(f"Error: {exc}")
        return None
    return encoded
def _image_data_url(path, base64_image):
    """Build a ``data:`` URL for *base64_image*, guessing the MIME type from *path*.

    Falls back to ``image/jpeg`` when the extension is unknown or does not map
    to an image type.
    """
    import mimetypes  # only needed on the image branch

    mime_type, _ = mimetypes.guess_type(str(path))
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"
    return f"data:{mime_type};base64,{base64_image}"


def _stream_pixtral(message_text, message_file):
    """Ask Pixtral (via the Mistral client) about one image; yield the growing reply."""
    base64_image = encode_image(message_file)
    if base64_image is None:
        yield "Error: Failed to encode the image."
        return
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": message_text},
                {
                    "type": "image_url",
                    "image_url": _image_data_url(message_file, base64_image),
                },
            ],
        }
    ]
    partial_message = ""
    for chunk in client.chat.stream(model="pixtral-12b-2409", messages=messages):
        choices = chunk.data.choices
        # Skip chunks that carry no choices or no text delta.
        if choices and choices[0].delta.content is not None:
            partial_message += choices[0].delta.content
            yield partial_message


def _stream_text_model(message_text, model_name):
    """Send a text-only prompt to *model_name* via the HF client; yield the growing reply."""
    stream = client_more_ai.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": str(message_text)}],
        temperature=0.5,
        max_tokens=1024,
        top_p=0.7,
        stream=True,
    )
    partial_message = ""
    for chunk in stream:
        # Skip chunks that carry no choices or no text delta.
        if chunk.choices and chunk.choices[0].delta.content is not None:
            partial_message += chunk.choices[0].delta.content
            yield partial_message


def predict(message, history, additional_dropdown):
    """Chat handler for gr.ChatInterface; yields the cumulative reply text.

    Args:
        message: Multimodal message dict with "text" and "files" keys.
        history: Prior conversation, as passed by ChatInterface. It is accepted
            but not forwarded, so every request is single-turn.
        additional_dropdown: Model id for text-only messages.

    If files are attached, only the first one is sent to Pixtral and any others
    are ignored. Otherwise the text goes to the selected model.
    """
    message_text = message.get("text", "")
    message_files = message.get("files", [])
    if message_files:
        yield from _stream_pixtral(message_text, message_files[0])
    else:
        yield from _stream_text_model(message_text, additional_dropdown)
# UI layout: a narrow left column with the image generator (plus a Styles tab)
# and a wider right column that hosts the chat interface.
with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Tab("Generator"):
                prompt = gr.Text(
                    label="Prompt",
                    show_label=False,
                    placeholder="Enter your prompt",
                    max_lines = 12,
                    container=False
                )
                run_button = gr.Button("Run")
                result = gr.Image(label="Result", show_label=False, interactive=False)
                with gr.Accordion("Advanced Settings", open=False):
                    # Seed used when "Randomize seed" is unchecked. infer() also
                    # writes the seed it actually used back into this slider.
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1024,
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1024,
                        )
                    with gr.Row():
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=4,
                        )
                        # Forwarded to infer() as guidance_scale.
                        # NOTE(review): FLUX.1-schnell may ignore guidance — confirm.
                        guidancescale = gr.Slider(
                            label="Guidance scale",
                            minimum=0,
                            maximum=10,
                            step=0.1,
                            value=3.5,
                        )
            with gr.Tab("Styles"):
                # When checked, infer() appends fixed quality tags to the prompt.
                quality_select = gr.Checkbox(label="high quality")
                styles_name = [style["name"] for style in config.style_list]
                # Multi-select dropdown (despite the variable name) of style names
                # whose templates infer() appends to the prompt.
                styles_Radio = gr.Dropdown(styles_name,label="Styles",multiselect=True)
        with gr.Column(scale=3,elem_id="col-container"):
            # Chat panel. The dropdown value reaches predict() as additional_dropdown
            # and picks the text-only model.
            gr.ChatInterface(
                predict,
                type="messages",
                multimodal=True,
                additional_inputs =[gr.Dropdown(
                    ["CohereForAI/c4ai-command-r-plus-08-2024",
                     "meta-llama/Meta-Llama-3.1-70B-Instruct",
                     "Qwen/Qwen2.5-72B-Instruct",
                     "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
                     "NousResearch/Hermes-3-Llama-3.1-8B",
                     "mistralai/Mistral-Nemo-Instruct-2407",
                     "microsoft/Phi-3.5-mini-instruct"],
                    value="meta-llama/Meta-Llama-3.1-70B-Instruct",
                    show_label=False,
                )]
            )
    # Run infer() on the Run button or Enter in the prompt box. The inputs map
    # positionally onto infer()'s parameters (Gradio inserts its Progress
    # argument itself). The outputs are the image and the seed actually used.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn = infer,
        inputs = [prompt, quality_select, styles_Radio, seed, randomize_seed, width, height, num_inference_steps, guidancescale],
        outputs = [result, seed]
    )
if __name__ == "__main__":
    # Configure Gradio's request queue on the app, then start serving it.
    queued_demo = demo.queue()
    queued_demo.launch()