What's the Python code for utilizing this quantized model?

#35
by enexorb - opened

What's the Python code for utilizing this quantized model?

Here's my current code:

import os
import re
import sys
import time
import torch
import secrets
import json
import random
from concurrent.futures import ThreadPoolExecutor

# Utility functions

def sanitize_filename(filename):
    '''Return filename with every forbidden file-name character removed.

    The forbidden characters are the backslash, forward slash, asterisk,
    question mark, colon, double quote, angle brackets and pipe. All other
    characters are kept in their original order.
    '''
    forbidden_chars = re.compile(r'[\\/*?:"<>|]')
    cleaned = forbidden_chars.sub('', filename)
    return cleaned

def format_time(seconds):
    '''Render a duration in seconds as "<minutes>m <seconds>s".

    Minutes are shown as a whole number; the leftover seconds are shown
    with two decimal places.
    '''
    whole_minutes, leftover = divmod(seconds, 60)
    return f'{int(whole_minutes)}m {leftover:.2f}s'

def read_prompts():
    '''Load the prompts used for image generation from prompt.txt.

    Every line is stripped of surrounding whitespace and blank lines are
    dropped. If the file is missing, or a ValueError occurs (including the
    case where no prompts remain), an error is printed and the process
    exits with status 1.

    Returns:
        The list of non-empty prompt strings, in file order.
    '''
    try:
        with open('prompt.txt', 'r') as handle:
            stripped_lines = (raw_line.strip() for raw_line in handle)
            prompts = [text for text in stripped_lines if text]
        if not prompts:
            raise ValueError('Prompt file is empty.')
        return prompts
    except FileNotFoundError:
        print('Error: prompt.txt not found.')
        sys.exit(1)
    except ValueError as e:
        print(f'Error: {e}')
        sys.exit(1)

def read_config():
    '''Load the generation parameters stored in config.json.

    Prints an error and exits with status 1 when the file does not exist
    or does not contain valid JSON.

    Returns:
        The parsed JSON content of the config file.
    '''
    config_file = 'config.json'
    try:
        with open(config_file, 'r') as handle:
            raw_text = handle.read()
        return json.loads(raw_text)
    except FileNotFoundError:
        print(f'Error: {config_file} not found.')
        sys.exit(1)
    except json.JSONDecodeError:
        print(f'Error: Failed to parse {config_file}.')
        sys.exit(1)

def append_to_log(log_file_path, image_data):
    '''Append one image metadata record to a JSON log file.

    The log file holds a JSON list of records. The whole file is read,
    the new record is appended, and the file is rewritten.

    Args:
        log_file_path: Path of the JSON log file; created if missing.
        image_data: JSON-serializable record to append.

    A missing file or a file with invalid JSON is treated as an empty
    log. If the file holds valid JSON that is not a list (for example a
    single object), that value is kept as the first entry of the new list
    instead of crashing on ``.append``.
    '''
    # Open directly instead of checking os.path.exists first: the file can
    # disappear between the check and the open.
    try:
        with open(log_file_path, 'r') as log_file:
            logs = json.load(log_file)
    except (FileNotFoundError, json.JSONDecodeError):
        logs = []

    if not isinstance(logs, list):
        logs = [logs]

    logs.append(image_data)

    with open(log_file_path, 'w') as log_file:
        json.dump(logs, log_file, indent=4)

# Model configurations

def get_model_config(model_type):
    '''Look up the pipeline settings for a supported model type.

    Each configuration is a dict with the keys 'model' (base model
    repository), 'lora_weights' (LoRA repository or None) and
    'weight_name' (LoRA weight file name or None).

    Prints an error and exits with status 1 when model_type is unknown.
    '''
    def build_entry(repository, lora_repository=None, lora_file=None):
        # Keep the key order identical for every entry.
        return {
            'model': repository,
            'lora_weights': lora_repository,
            'weight_name': lora_file,
        }

    model_configs = {
        'flux-dev': build_entry('black-forest-labs/FLUX.1-dev'),
        'flux-dev-uncensored': build_entry(
            'black-forest-labs/FLUX.1-dev',
            'enhanceaiteam/Flux-uncensored',
            'lora.safetensors',
        ),
        'flux-schnell': build_entry('black-forest-labs/FLUX.1-schnell'),
        'flux-schnell-realism': build_entry(
            'black-forest-labs/FLUX.1-schnell',
            'hugovntr/flux-schnell-realism',
            'schnell-realism_v2.3.safetensors',
        ),
        'sd-3-medium': build_entry('stabilityai/stable-diffusion-3-medium-diffusers'),
        'sd-xl-base-1': build_entry('stabilityai/stable-diffusion-xl-base-1.0'),
    }

    selected = model_configs.get(model_type)
    if selected is None:
        print(f'Error: Unknown model type "{model_type}"')
        sys.exit(1)
    return selected

# Image generation functions

def generate_image_batch(start_index, batch_size, total_images, pipe, model_type, config):
    '''Generate one batch of images with the given pipeline and save them.

    Args:
        start_index: Global index of the first image in this batch.
        batch_size: Maximum number of images to generate in this batch.
        total_images: Total number of images requested for the whole run;
            the batch is truncated so it never goes past this count.
        pipe: Callable pipeline; called with keyword arguments and expected
            to return an object with an ``images`` attribute.
        model_type: Model type label, recorded in the log.
        config: Dict with 'guidance_scale' ({'min', 'max'}), 'height',
            'width', 'num_inference_steps', 'max_sequence_length' and an
            optional 'seed'.

    Prompts are read from prompt.txt and cycled through by image index.
    Nothing is generated when the batch would be empty.
    '''
    prompts = read_prompts()
    batch_end = min(start_index + batch_size, total_images)
    batch_prompts = [
        prompts[index % len(prompts)]
        for index in range(start_index, batch_end)
    ]

    if not batch_prompts:
        return

    # Do not hard-code 'cuda': creating a CUDA generator fails on machines
    # without CUDA, while main() explicitly supports falling back to CPU.
    generator_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    generator = torch.Generator(device=generator_device)

    min_guidance_scale = config['guidance_scale']['min']
    max_guidance_scale = config['guidance_scale']['max']
    # Pick a random guidance scale and snap it to the nearest 0.5.
    guidance_scale = round(random.uniform(min_guidance_scale, max_guidance_scale) * 2) / 2

    seed = config.get('seed')
    if seed is not None:
        # Offset by the batch start so each batch gets a distinct,
        # reproducible seed.
        current_seed = seed + start_index
        generator.manual_seed(current_seed)
    else:
        # A freshly constructed torch.Generator starts from a fixed default
        # seed, so leaving it unseeded would give every batch the same
        # noise. Seed it non-deterministically and keep the value so the
        # log entry can reproduce the image.
        current_seed = generator.seed()

    start_time = time.time()

    outputs = pipe(
        prompt=batch_prompts,
        guidance_scale=guidance_scale,
        height=config['height'],
        width=config['width'],
        num_inference_steps=config['num_inference_steps'],
        max_sequence_length=config['max_sequence_length'],
        generator=generator
    ).images

    end_time = time.time()
    save_generated_images(outputs, batch_prompts, start_index, total_images, current_seed, guidance_scale, config, model_type, end_time - start_time)

def save_generated_images(outputs, batch_prompts, start_index, total_images, current_seed, guidance_scale, config, model_type, generation_time):
    '''Save a batch of generated images and log their metadata.

    Images are written as PNG files into ``<output_dir>/<YYYY-MM-DD>/`` and
    one record per image is appended to ``generation_log.json`` in that
    same directory.

    Args:
        outputs: Generated images; each must provide ``save(path)``.
        batch_prompts: Prompts used for the batch, aligned with outputs.
        start_index: Global index of the first image in the batch.
        total_images: Total number of images requested; images whose
            global index reaches this count are not saved.
        current_seed: Seed value recorded in the log.
        guidance_scale: Guidance scale recorded in the log.
        config: Dict providing 'output_dir', 'height', 'width',
            'num_inference_steps' and 'max_sequence_length'.
        model_type: Model type label recorded in the log.
        generation_time: Time in seconds taken to generate the whole
            batch; the same value is reported for every image in it.
    '''
    output_dir = config['output_dir']
    current_date = time.strftime('%Y-%m-%d')
    dated_output_dir = os.path.join(output_dir, current_date)
    os.makedirs(dated_output_dir, exist_ok=True)
    # The log path does not depend on the image, so compute it once.
    log_file_path = os.path.join(dated_output_dir, 'generation_log.json')

    for i, out in enumerate(outputs):
        index = start_index + i
        if index >= total_images:
            break

        # File name: first ten words of the prompt, the global index and a
        # random suffix to avoid collisions. The join is done outside the
        # f-string because reusing the same quote character inside an
        # f-string expression is a SyntaxError before Python 3.12.
        prompt_words = '_'.join(batch_prompts[i].split()[:10])
        random_str = secrets.token_hex(4)
        unsanitized_name = f'{prompt_words}_{index}_{random_str}'
        file_name = sanitize_filename(unsanitized_name) + '.png'
        full_path = os.path.join(dated_output_dir, file_name)
        out.save(full_path)

        print(f'Image {index + 1}/{total_images} saved as {file_name} in {format_time(generation_time)}')

        image_data = {
            'prompt': batch_prompts[i],
            'seed': current_seed,
            'guidance_scale': guidance_scale,
            'height': config['height'],
            'width': config['width'],
            'num_inference_steps': config['num_inference_steps'],
            'max_sequence_length': config['max_sequence_length'],
            'model_type': model_type,
            'image_file': file_name,
            'generation_time': format_time(generation_time)
        }

        append_to_log(log_file_path, image_data)

# Main function

def _load_batch_settings():
    '''Read config.json and return (config, batch_size) with a validated batch size.'''
    config = read_config()
    batch_size = config.get('batch_size', 2)
    # A zero or negative batch size would make the batching loop never
    # advance (or make range() fail), so reject it up front.
    if not isinstance(batch_size, int) or batch_size < 1:
        print(f'Error: batch_size must be a positive integer, got {batch_size!r}.')
        sys.exit(1)
    return config, batch_size


def main(num_images, use_multithreading, model_type):
    '''Load the requested pipeline and generate num_images images.

    Args:
        num_images: Total number of images to generate.
        use_multithreading: When true, batches are submitted to a thread
            pool; otherwise they are generated one after another.
        model_type: One of the keys accepted by get_model_config().

    In sequential mode config.json is re-read before every batch. In
    multithreaded mode it is read once, since all batches are submitted
    together.
    '''
    total_start_time = time.time()

    model_config = get_model_config(model_type)
    model = model_config['model']
    lora_weights = model_config['lora_weights']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f'Beginning image generation.\nModel type: {model_type}\nDevice: {device}')

    # Imports are local so only the pipeline class actually needed is loaded.
    if model_type == 'sd-3-medium':
        from diffusers import StableDiffusion3Pipeline
        pipe = StableDiffusion3Pipeline.from_pretrained(model, torch_dtype=torch.float16).to(device)
    elif model_type == 'sd-xl-base-1':
        from diffusers import DiffusionPipeline
        pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16, use_safetensors=True, variant='fp16')
    else:
        from diffusers import AutoPipelineForText2Image
        pipe = AutoPipelineForText2Image.from_pretrained(model, torch_dtype=torch.bfloat16).to(device)

    if lora_weights is not None:
        pipe.load_lora_weights(lora_weights, weight_name=model_config['weight_name'])

    # NOTE(review): model offload and sequential offload look like
    # alternative strategies rather than options to combine, and the
    # pipeline has usually already been moved to `device` above. Confirm
    # against the diffusers memory-optimization docs which one is intended.
    pipe.enable_model_cpu_offload()
    pipe.enable_sequential_cpu_offload()
    pipe.safety_checker = None

    if use_multithreading:
        # Submit every batch exactly once. Wrapping the pool in the
        # per-batch loop would resubmit all remaining batches on each pass
        # and generate far more images than requested.
        config, batch_size = _load_batch_settings()
        batch_starts = range(0, num_images, batch_size)
        if batch_starts:
            # At least one worker: ThreadPoolExecutor rejects max_workers=0,
            # which num_images // batch_size yields for a partial batch.
            worker_count = min(4, len(batch_starts))
            # NOTE(review): all workers share one pipeline object and one
            # log file; presumably neither is safe for concurrent use —
            # confirm before relying on this mode.
            with ThreadPoolExecutor(max_workers=worker_count) as executor:
                results = executor.map(
                    lambda idx: generate_image_batch(idx, batch_size, num_images, pipe, model_type, config),
                    batch_starts,
                )
                # Consume the results so an exception raised in a worker
                # is re-raised here instead of being silently dropped.
                for _ in results:
                    pass
    else:
        start_idx = 0
        while start_idx < num_images:
            config, batch_size = _load_batch_settings()
            generate_image_batch(start_idx, batch_size, num_images, pipe, model_type, config)
            start_idx += batch_size

    total_end_time = time.time()
    total_generation_time = total_end_time - total_start_time
    print(f'\nTotal time for image generation: {format_time(total_generation_time)}')

# Entry point

if __name__ == '__main__':
    import argparse

    def _parse_bool(value):
        '''Convert a command-line string such as "true" or "0" to a bool.'''
        normalized = value.strip().lower()
        if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if normalized in ('', '0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser(description='Image generation script.')
    parser.add_argument('--num_images', type=int, default=1, help='Number of images to generate.')
    # type=bool would treat every non-empty string (including "False") as
    # True. Parse the text explicitly, and let the bare flag mean True.
    parser.add_argument('--use_multithreading', type=_parse_bool, nargs='?', const=True, default=False, help='Use multi-threading for image generation.')
    parser.add_argument('--model_type', type=str, default='flux-dev', choices=['flux-dev', 'flux-dev-uncensored', 'flux-schnell', 'flux-schnell-realism', 'sd-3-medium', 'sd-xl-base-1'], help='The type of model to use.')
    args = parser.parse_args()

    main(args.num_images, args.use_multithreading, args.model_type)

Sign up or log in to comment