# Inference endpoint handler for the FLUX.1-dev ControlNet image upscaler.
import os
import torch
from PIL import Image
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from io import BytesIO
import logging
class EndpointHandler:
    """Inference-endpoint handler that upscales images with FLUX.1-dev + ControlNet.

    Loads the ControlNet weights from ``model_dir`` and wires them into a
    ``FluxControlNetPipeline``, with memory-saving options enabled for
    constrained GPUs.
    """

    def __init__(self, model_dir="huyai123/Flux.1-dev-Image-Upscaler"):
        """Load the ControlNet model and pipeline.

        Args:
            model_dir: HF Hub repo id (or local path) of the ControlNet weights.

        Raises:
            ValueError: if the HF_TOKEN environment variable is not set.
        """
        # Cap the CUDA allocator's split size to reduce fragmentation on
        # small-memory GPUs.
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

        # Gated repo: an access token is mandatory.
        hf_token = os.getenv('HF_TOKEN')
        if not hf_token:
            raise ValueError("HF_TOKEN environment variable is not set")
        logging.basicConfig(level=logging.INFO)
        logging.info("Using HF_TOKEN")

        # Clear GPU memory before loading large weights.
        torch.cuda.empty_cache()

        # Load model and pipeline. NOTE: `use_auth_token` is deprecated in
        # diffusers/huggingface_hub; the supported kwarg is `token`.
        self.controlnet = FluxControlNetModel.from_pretrained(
            model_dir, torch_dtype=torch.float16, token=hf_token
        )
        self.pipe = FluxControlNetPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
            token=hf_token,
        )

        # BUG FIX: do NOT call self.pipe.to("cuda") here —
        # enable_sequential_cpu_offload() manages device placement itself and
        # must not be combined with an explicit full-pipeline move to CUDA.
        self.pipe.enable_attention_slicing("auto")
        self.pipe.enable_sequential_cpu_offload()

        # BUG FIX: the diffusers API is
        # enable_xformers_memory_efficient_attention();
        # enable_memory_efficient_attention() does not exist and raised
        # AttributeError. Guarded because xformers may not be installed.
        try:
            self.pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            logging.info("xformers unavailable; continuing without it")

    def preprocess(self, data):
        """Open the request's control image and shrink it to 512x512.

        Args:
            data: request payload; must contain a "control_image" entry
                (a path or file-like object accepted by PIL).

        Returns:
            A 512x512 RGB PIL image.

        Raises:
            ValueError: if "control_image" is missing or falsy.
        """
        image_file = data.get("control_image", None)
        if not image_file:
            raise ValueError("Missing control_image in input.")
        # Force RGB so palette/alpha inputs don't break the fp16 pipeline.
        image = Image.open(image_file).convert("RGB")
        return image.resize((512, 512))  # Resize to reduce memory usage

    def postprocess(self, output):
        """Encode a PIL image as PNG and return a rewound BytesIO buffer."""
        buffer = BytesIO()
        output.save(buffer, format="PNG")
        buffer.seek(0)
        return buffer

    def inference(self, data):
        """Run the upscaling pipeline on the request payload.

        Args:
            data: dict with "control_image" (required) and "prompt" (optional).

        Returns:
            BytesIO containing the generated image encoded as PNG.
        """
        control_image = self.preprocess(data)
        torch.cuda.empty_cache()
        output_image = self.pipe(
            prompt=data.get("prompt", ""),
            control_image=control_image,
            controlnet_conditioning_scale=0.5,
            num_inference_steps=10,
            # Generate at the (resized) control image's resolution.
            height=control_image.size[1],
            width=control_image.size[0],
        ).images[0]
        return self.postprocess(output_image)
if __name__ == "__main__":
    # Manual smoke test: needs HF_TOKEN, a CUDA device, and a real image path.
    payload = {'control_image': 'path/to/your/image.png', 'prompt': 'Your prompt here'}
    result = EndpointHandler().inference(payload)
    print(result)