What's the Python code for utilizing this quantized model?
#35
by
enexorb
- opened
What's the Python code for utilizing this quantized model?
Here's my current code:
import os
import re
import sys
import time
import torch
import secrets
import json
import random
from concurrent.futures import ThreadPoolExecutor
# Utility functions
def sanitize_filename(filename):
    '''Return filename with every forbidden character deleted.

    The characters removed are backslash, slash, asterisk, question mark,
    colon, double quote, angle brackets and the pipe symbol. Nothing is
    substituted in their place; all other characters are kept as they are.
    '''
    forbidden_chars = r'[\\/*?:"<>|]'
    cleaned = re.sub(forbidden_chars, '', filename)
    return cleaned
def format_time(seconds):
    '''Format a duration in seconds as "<minutes>m <seconds>s".

    The seconds part is shown with two decimals. The duration is rounded to
    two decimals before it is split into minutes and seconds, so a value such
    as 59.996 is reported as "1m 0.00s" rather than "0m 60.00s".

    Args:
        seconds: Duration in seconds (int or float).

    Returns:
        The formatted string, e.g. "2m 5.50s".
    '''
    # Round first: rounding only inside the format spec could carry the
    # seconds part up to 60.00 without incrementing the minutes.
    total = round(seconds, 2)
    minutes, remainder = divmod(total, 60)
    return f'{int(minutes)}m {remainder:.2f}s'
def read_prompts():
    '''Read all prompts for image generation from prompt.txt.

    Each non-blank line of the file is one prompt; surrounding whitespace is
    stripped. The file is decoded as UTF-8 (a leading byte-order mark is
    ignored) so the result does not depend on the platform's locale encoding.

    Returns:
        A non-empty list of prompt strings, in file order.

    Exits the process with status 1 (after printing an error) when the file
    is missing, cannot be decoded, or contains no prompts.
    '''
    try:
        # 'utf-8-sig' reads plain UTF-8 too, and drops a BOM if an editor
        # wrote one, so it never ends up inside the first prompt.
        with open('prompt.txt', 'r', encoding='utf-8-sig') as file:
            stripped_lines = (line.strip() for line in file)
            prompts = [line for line in stripped_lines if line]
    except FileNotFoundError:
        print('Error: prompt.txt not found.')
        sys.exit(1)
    except ValueError as e:
        # UnicodeDecodeError is a ValueError: report undecodable files the
        # same way as other content problems instead of dumping a traceback.
        print(f'Error: {e}')
        sys.exit(1)

    if not prompts:
        print('Error: Prompt file is empty.')
        sys.exit(1)
    return prompts
def read_config():
    '''Read config.json from the working directory and return its content.

    The file is decoded as UTF-8 so that non-ASCII values (for example an
    output directory name) are read the same way on every platform.

    Returns:
        The parsed JSON document (whatever json.load produces for the file).

    Exits the process with status 1 (after printing an error) when the file
    is missing or cannot be decoded/parsed as JSON.
    '''
    config_file = 'config.json'
    try:
        with open(config_file, 'r', encoding='utf-8') as file:
            return json.load(file)
    except FileNotFoundError:
        print(f'Error: {config_file} not found.')
        sys.exit(1)
    except (json.JSONDecodeError, UnicodeDecodeError):
        # Undecodable bytes are treated like malformed JSON: same message,
        # same exit code, no traceback.
        print(f'Error: Failed to parse {config_file}.')
        sys.exit(1)
def append_to_log(log_file_path, image_data):
    '''Append one metadata record to the JSON log file.

    The log file holds a single JSON array. The whole array is read, the new
    record is appended, and the file is rewritten with 4-space indentation.

    Args:
        log_file_path: Path of the JSON log file; created if it is missing.
        image_data: JSON-serialisable record to append.

    Behaviour on unusual existing content:
        * missing or unparsable file -> the log starts over as an empty list
          (the unparsable content is overwritten);
        * valid JSON that is not a list -> the existing value is kept as the
          first entry instead of crashing on ``.append``.
    '''
    try:
        with open(log_file_path, 'r', encoding='utf-8') as log_file:
            logs = json.load(log_file)
    except FileNotFoundError:
        # First record for this log: nothing to load yet.
        logs = []
    except (json.JSONDecodeError, UnicodeDecodeError):
        logs = []

    if not isinstance(logs, list):
        # Keep whatever was stored rather than failing or discarding it.
        logs = [logs]

    logs.append(image_data)
    with open(log_file_path, 'w', encoding='utf-8') as log_file:
        json.dump(logs, log_file, indent=4)
# Model configurations
def get_model_config(model_type):
    '''Look up the configuration dictionary for a model type.

    Each configuration has three keys: 'model' (the model repository id),
    'lora_weights' (a LoRA repository id or None) and 'weight_name' (the LoRA
    weight file name or None).

    An unknown model type prints an error and exits the process with
    status 1.
    '''
    flux_dev_repo = 'black-forest-labs/FLUX.1-dev'
    flux_schnell_repo = 'black-forest-labs/FLUX.1-schnell'

    def entry(model, lora_weights=None, weight_name=None):
        return {
            'model': model,
            'lora_weights': lora_weights,
            'weight_name': weight_name,
        }

    known_models = {
        'flux-dev': entry(flux_dev_repo),
        'flux-dev-uncensored': entry(
            flux_dev_repo,
            'enhanceaiteam/Flux-uncensored',
            'lora.safetensors',
        ),
        'flux-schnell': entry(flux_schnell_repo),
        'flux-schnell-realism': entry(
            flux_schnell_repo,
            'hugovntr/flux-schnell-realism',
            'schnell-realism_v2.3.safetensors',
        ),
        'sd-3-medium': entry('stabilityai/stable-diffusion-3-medium-diffusers'),
        'sd-xl-base-1': entry('stabilityai/stable-diffusion-xl-base-1.0'),
    }

    try:
        return known_models[model_type]
    except KeyError:
        print(f'Error: Unknown model type "{model_type}"')
        sys.exit(1)
# Image generation functions
def generate_image_batch(start_index, batch_size, total_images, pipe, model_type, config):
    '''Generate one batch of images with the pipeline and save them.

    Prompts are taken from read_prompts() and cycled: image number i uses
    prompt i modulo the number of prompts. The batch covers image indices
    start_index .. start_index + batch_size - 1, cut off at total_images.

    Args:
        start_index: Index of the first image of this batch.
        batch_size: Maximum number of images in the batch.
        total_images: Total number of images requested for the whole run.
        pipe: Callable pipeline; it is called with keyword arguments and its
            result must expose an ``images`` attribute.
        model_type: Model type label, passed through to the log.
        config: Configuration mapping. Reads 'guidance_scale' ('min'/'max'),
            'height', 'width', 'num_inference_steps', 'max_sequence_length'
            and the optional 'seed'.

    Seeding:
        With a configured seed the batch uses ``seed + start_index``. Without
        one, a fresh non-deterministic seed is drawn and that integer is
        passed on for logging, so any image can be reproduced later.
    '''
    prompts = read_prompts()
    stop_index = min(start_index + batch_size, total_images)
    batch_prompts = [prompts[i % len(prompts)] for i in range(start_index, stop_index)]
    if not batch_prompts:
        return

    # Follow the same CUDA-or-CPU choice as main(); a hard-coded 'cuda'
    # generator fails on machines without a GPU.
    generator_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    generator = torch.Generator(device=generator_device)

    seed = config.get('seed')
    if seed is not None:
        current_seed = seed + start_index
        generator.manual_seed(current_seed)
    else:
        # A freshly constructed torch.Generator starts from PyTorch's fixed
        # default seed, so leaving it untouched would repeat the same noise
        # on every run. seed() reseeds non-deterministically and returns the
        # value it picked.
        current_seed = generator.seed()

    # Pick a guidance scale in the configured range, snapped to steps of 0.5.
    min_guidance_scale = config['guidance_scale']['min']
    max_guidance_scale = config['guidance_scale']['max']
    guidance_scale = round(random.uniform(min_guidance_scale, max_guidance_scale) * 2) / 2

    start_time = time.time()
    outputs = pipe(
        prompt=batch_prompts,
        guidance_scale=guidance_scale,
        height=config['height'],
        width=config['width'],
        num_inference_steps=config['num_inference_steps'],
        max_sequence_length=config['max_sequence_length'],
        generator=generator,
    ).images
    generation_time = time.time() - start_time

    save_generated_images(
        outputs,
        batch_prompts,
        start_index,
        total_images,
        current_seed,
        guidance_scale,
        config,
        model_type,
        generation_time,
    )
def save_generated_images(outputs, batch_prompts, start_index, total_images, current_seed, guidance_scale, config, model_type, generation_time):
    '''Save the generated images and log their metadata.

    Images are written as PNG files into ``<output_dir>/<YYYY-MM-DD>/`` and
    one record per image is appended to ``generation_log.json`` in the same
    directory.

    Args:
        outputs: Generated images; each must provide ``save(path)``.
        batch_prompts: Prompt used for each image, in the same order.
        start_index: Index of the first image of the batch.
        total_images: Total number of images requested; images whose index
            reaches this value are not saved.
        current_seed: Seed value recorded in the log.
        guidance_scale: Guidance scale recorded in the log.
        config: Configuration mapping. Reads 'output_dir', 'height', 'width',
            'num_inference_steps' and 'max_sequence_length'.
        model_type: Model type label recorded in the log.
        generation_time: Duration in seconds of the pipeline call that
            produced the whole batch; the same value is reported for every
            image of the batch.
    '''
    dated_output_dir = os.path.join(config['output_dir'], time.strftime('%Y-%m-%d'))
    os.makedirs(dated_output_dir, exist_ok=True)

    # Identical for every image of the batch, so computed once.
    log_file_path = os.path.join(dated_output_dir, 'generation_log.json')
    formatted_time = format_time(generation_time)

    for offset, (image, prompt) in enumerate(zip(outputs, batch_prompts)):
        index = start_index + offset
        if index >= total_images:
            break

        # File name: first ten words of the prompt, the image index, and a
        # random suffix so that repeated prompts never overwrite each other.
        # The prefix is built outside the f-string because reusing the same
        # quote character inside a replacement field is a syntax error
        # before Python 3.12.
        prompt_prefix = '_'.join(prompt.split()[:10])
        random_str = secrets.token_hex(4)
        file_name = sanitize_filename(f'{prompt_prefix}_{index}_{random_str}') + '.png'

        image.save(os.path.join(dated_output_dir, file_name))
        print(f'Image {index + 1}/{total_images} saved as {file_name} in {formatted_time}')

        image_data = {
            'prompt': prompt,
            'seed': current_seed,
            'guidance_scale': guidance_scale,
            'height': config['height'],
            'width': config['width'],
            'num_inference_steps': config['num_inference_steps'],
            'max_sequence_length': config['max_sequence_length'],
            'model_type': model_type,
            'image_file': file_name,
            'generation_time': formatted_time,
        }
        append_to_log(log_file_path, image_data)
# Main function
def main(num_images, use_multithreading, model_type):
    '''Load the selected pipeline and generate num_images images.

    Args:
        num_images: Total number of images to generate.
        use_multithreading: When true, the batches are dispatched to a thread
            pool (at most 4 workers); otherwise they run one after another.
        model_type: Key understood by get_model_config().

    The configuration file is re-read before every sequential batch, so
    edits to config.json take effect while the script is running. In
    multithreaded mode it is read once, because all batches are submitted
    together.

    Exits the process with status 1 when the configured batch_size is not a
    positive integer.
    '''
    total_start_time = time.time()
    model_config = get_model_config(model_type)
    model = model_config['model']
    lora_weights = model_config['lora_weights']
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(f'Beginning image generation.\nModel type: {model_type}\nDevice: {device}')

    # The pipelines are deliberately NOT moved to the GPU here: the diffusers
    # memory guide says not to move a pipeline to CUDA before enabling CPU
    # offload, since offloading handles device placement itself.
    if model_type == 'sd-3-medium':
        from diffusers import StableDiffusion3Pipeline
        pipe = StableDiffusion3Pipeline.from_pretrained(model, torch_dtype=torch.float16)
    elif model_type == 'sd-xl-base-1':
        from diffusers import DiffusionPipeline
        pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16, use_safetensors=True, variant='fp16')
    else:
        from diffusers import AutoPipelineForText2Image
        pipe = AutoPipelineForText2Image.from_pretrained(model, torch_dtype=torch.bfloat16)
        if lora_weights is not None:
            pipe.load_lora_weights(lora_weights, weight_name=model_config['weight_name'])

    if use_cuda:
        # Model offload and sequential offload are alternatives, so only one
        # is enabled. Sequential offload is kept because it was the last one
        # requested by the previous code.
        # NOTE(review): enable_model_cpu_offload() is usually faster if the
        # GPU has enough memory for one whole sub-model - confirm preference.
        pipe.enable_sequential_cpu_offload()
    else:
        # CPU offload needs an accelerator; on a CPU-only host just run there.
        pipe.to(device)
    pipe.safety_checker = None

    start_idx = 0
    while start_idx < num_images:
        config = read_config()
        batch_size = config.get('batch_size', 2)
        if not isinstance(batch_size, int) or batch_size < 1:
            # A batch size of zero or less would never advance start_idx.
            print(f'Error: batch_size must be a positive integer, got {batch_size!r}.')
            sys.exit(1)

        if use_multithreading:
            batch_starts = range(start_idx, num_images, batch_size)
            # At least one worker (ThreadPoolExecutor rejects 0), at most 4.
            worker_count = max(1, min(4, len(batch_starts)))
            # NOTE(review): every worker calls the same pipeline object;
            # confirm that concurrent calls are safe for the chosen pipeline.
            with ThreadPoolExecutor(max_workers=worker_count) as executor:
                # list() consumes the results so that an exception raised in
                # a worker is re-raised here instead of being dropped.
                list(executor.map(
                    lambda idx: generate_image_batch(idx, batch_size, num_images, pipe, model_type, config),
                    batch_starts,
                ))
            # Every remaining batch has been processed by the pool; looping
            # again would generate the later batches a second time.
            break

        generate_image_batch(start_idx, batch_size, num_images, pipe, model_type, config)
        start_idx += batch_size

    total_end_time = time.time()
    total_generation_time = total_end_time - total_start_time
    print(f'\nTotal time for image generation: {format_time(total_generation_time)}')
# Entry point
if __name__ == '__main__':
    import argparse

    def _str_to_bool(value):
        '''Parse a command-line word such as "true" or "no" into a bool.

        argparse's ``type=bool`` calls bool() on the raw string, which is
        True for every non-empty value - including "False". This converter
        interprets the text instead and rejects anything it does not know.
        '''
        normalized = value.strip().lower()
        if normalized in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if normalized in ('false', 'f', 'no', 'n', '0', 'off'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser(description='Image generation script.')
    parser.add_argument('--num_images', type=int, default=1, help='Number of images to generate.')
    # nargs='?' with const=True lets the option also be used as a bare flag
    # (--use_multithreading), while "--use_multithreading True/False" keeps
    # working and now means what it says.
    parser.add_argument('--use_multithreading', type=_str_to_bool, nargs='?', const=True, default=False, help='Use multi-threading for image generation.')
    parser.add_argument('--model_type', type=str, default='flux-dev', choices=['flux-dev', 'flux-dev-uncensored', 'flux-schnell', 'flux-schnell-realism', 'sd-3-medium', 'sd-xl-base-1'], help='The type of model to use.')
    args = parser.parse_args()
    main(args.num_images, args.use_multithreading, args.model_type)