Singularity666 committed on
Commit
beed73c
·
verified ·
1 Parent(s): a906161

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +35 -102
main.py CHANGED
@@ -1,117 +1,50 @@
 
 
1
  import os
2
- import json
3
- import shutil
4
- from pathlib import Path
5
  import torch
6
- import gradio as gr
7
- from diffusers import StableDiffusionPipeline, DDIMScheduler
8
- from transformers import CLIPTextModel, CLIPTokenizer
9
- from PIL import Image
10
  from torch import autocast
11
-
12
# --- Base model and output location ---
MODEL_NAME = "runwayml/stable-diffusion-v1-5"
OUTPUT_DIR = "/output_model"

# --- Prompt templates; "{identifier}" is filled in with str.format ---
INSTANCE_PROMPT = "photo of {identifier} person"
CLASS_PROMPT = "photo of a person"

# --- Training settings forwarded to the training script ---
SEED = 1337
RESOLUTION = 512
TRAIN_BATCH_SIZE = 1
LEARNING_RATE = 1e-6
MAX_TRAIN_STEPS = 800

# --- Sampling defaults ---
GUIDANCE_SCALE = 8.0
NUM_INFERENCE_STEPS = 50
24
-
25
- # Function to fine-tune the model
26
def fine_tune_model(instance_data_dir, identifier):
    """Fine-tune the base model on one subject by running ``train_dreambooth.py``.

    Writes ``concepts_list.json`` into the current working directory and then
    runs the training script as a child process with the module-level
    training settings.

    Args:
        instance_data_dir: Directory containing the subject's training images.
        identifier: Token substituted into ``INSTANCE_PROMPT``.

    Raises:
        subprocess.CalledProcessError: If the training script exits non-zero.
    """
    # Local import so this function does not rely on a file-level import.
    import subprocess

    instance_prompt = INSTANCE_PROMPT.format(identifier=identifier)
    concepts_list = [
        {
            "instance_prompt": instance_prompt,
            "class_prompt": CLASS_PROMPT,
            "instance_data_dir": instance_data_dir,
            # Placeholder location for regularization images.
            "class_data_dir": "/sample_data/person",
        }
    ]

    with open("concepts_list.json", "w") as f:
        json.dump(concepts_list, f, indent=4)

    # Pass an argument list instead of a shell string: the prompt embeds the
    # identifier typed into the UI, so it must never be interpreted by a shell.
    command = [
        "python3",
        "train_dreambooth.py",
        f"--pretrained_model_name_or_path={MODEL_NAME}",
        f"--output_dir={OUTPUT_DIR}",
        "--revision=fp16",
        "--with_prior_preservation",
        "--prior_loss_weight=1.0",
        f"--seed={SEED}",
        f"--resolution={RESOLUTION}",
        f"--train_batch_size={TRAIN_BATCH_SIZE}",
        "--train_text_encoder",
        "--mixed_precision=fp16",
        "--use_8bit_adam",
        "--gradient_accumulation_steps=1",
        f"--learning_rate={LEARNING_RATE}",
        f"--max_train_steps={MAX_TRAIN_STEPS}",
        f"--save_sample_prompt={instance_prompt}",
        "--concepts_list=concepts_list.json",
    ]
    # check=True surfaces a failed training run instead of silently ignoring it.
    subprocess.run(command, check=True)
61
-
62
- # Function for inference
63
def generate_images(prompt, negative_prompt, num_samples, model_path, height=512, width=512, num_inference_steps=50, guidance_scale=7.5):
    """Load the pipeline at ``model_path`` and sample images for ``prompt``.

    The pipeline is loaded in float16 on CUDA with the safety checker
    disabled, its scheduler is replaced by a DDIM scheduler built from the
    existing scheduler config, and sampling uses a CUDA generator seeded
    with ``SEED``.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        num_samples: Number of images to generate for the prompt.
        model_path: Path or model id given to ``from_pretrained``.
        height: Output height in pixels.
        width: Output width in pixels.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale forwarded to the pipeline.

    Returns:
        The ``.images`` attribute of the pipeline output.
    """
    pipe = StableDiffusionPipeline.from_pretrained(model_path, safety_checker=None, torch_dtype=torch.float16).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.enable_xformers_memory_efficient_attention()
    g_cuda = torch.Generator(device='cuda').manual_seed(SEED)

    with torch.autocast("cuda"), torch.inference_mode():
        images = pipe(
            prompt,
            # UI number/slider widgets deliver floats; sizes and counts must
            # be integers, so coerce them here.
            height=int(height),
            width=int(width),
            negative_prompt=negative_prompt,
            num_images_per_prompt=int(num_samples),
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            generator=g_cuda,
        ).images

    return images
82
-
83
- # Gradio UI function
84
def inference_ui(identifier, prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale):
    """Gradio callback: build the full prompt and generate images.

    The prompt sent to the model is ``INSTANCE_PROMPT`` formatted with
    ``identifier``, followed by ``", "`` and the user's prompt. Images are
    generated from the model stored at ``OUTPUT_DIR``.

    Args:
        identifier: Subject token used to fill the prompt templates.
        prompt: User prompt; may contain a literal ``{identifier}`` placeholder.
        negative_prompt: Negative prompt forwarded to ``generate_images``.
        num_samples: Number of images to generate.
        height: Output height in pixels.
        width: Output width in pixels.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale forwarded to ``generate_images``.

    Returns:
        The images returned by ``generate_images``.
    """
    model_path = OUTPUT_DIR
    # The UI's default prompt carries a literal "{identifier}" placeholder;
    # fill it in so the raw braces are not sent to the model. str.replace is
    # used rather than str.format so other braces in user text cannot raise.
    prompt = prompt.replace("{identifier}", identifier)
    prompt = INSTANCE_PROMPT.format(identifier=identifier) + ", " + prompt
    images = generate_images(prompt, negative_prompt, num_samples, model_path, height, width, num_inference_steps, guidance_scale)
    return images
89
 
90
- # Define Gradio interface
91
def create_gradio_ui():
    """Build and launch the Gradio Blocks UI for fine-tuning and inference.

    The left column collects an identifier and training images and triggers
    fine-tuning; the right column collects sampling settings and shows the
    generated images in a gallery.
    """
    # Local imports so this function does not rely on file-level imports.
    import shutil
    import tempfile

    def run_fine_tuning(files, identifier):
        """Stage the uploaded images in one directory and run fine-tuning."""
        # fine_tune_model takes a directory path, while the upload widget
        # supplies a list of files, so copy them into a fresh directory first.
        staging_dir = tempfile.mkdtemp(prefix="instance_images_")
        for uploaded in files or []:
            # Uploads are expected to expose their temp path as `.name`;
            # fall back to the value itself if it is already a path.
            shutil.copy(getattr(uploaded, "name", uploaded), staging_dir)
        fine_tune_model(staging_dir, identifier)
        # The click handler writes its return value into the output textbox.
        return f"Fine-tuning finished. Model saved to {OUTPUT_DIR}"

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                identifier = gr.Textbox(label="Identifier", placeholder="Enter a unique identifier")
                image_upload = gr.File(label="Upload Images", file_count="multiple", type="file")
                finetune_button = gr.Button(value="Fine-Tune Model")
                finetune_output = gr.Textbox(label="Fine-Tuning Output")

            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="photo of {identifier} person in a marriage hall")
                negative_prompt = gr.Textbox(label="Negative Prompt", value="")
                num_samples = gr.Number(label="Number of Samples", value=4)
                guidance_scale = gr.Number(label="Guidance Scale", value=8)
                height = gr.Number(label="Height", value=512)
                width = gr.Number(label="Width", value=512)
                num_inference_steps = gr.Slider(label="Steps", value=50)
                generate_button = gr.Button(value="Generate Images")
                gallery = gr.Gallery()

        # The original wired this button to `finetune_model`, a name that is
        # never defined (the function is `fine_tune_model`).
        finetune_button.click(run_fine_tuning, inputs=[image_upload, identifier], outputs=finetune_output)
        generate_button.click(inference_ui, inputs=[identifier, prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale], outputs=gallery)

    demo.launch()
115
 
# Launch the Gradio UI only when this file is executed as a script.
if __name__ == "__main__":
    create_gradio_ui()
 
 
1
+ # main.py
2
+
3
  import os
 
 
 
4
  import torch
 
 
 
 
5
  from torch import autocast
6
+ from diffusers import StableDiffusionPipeline, DDIMScheduler
7
+ from huggingface_hub import HfApi
8
+ from app import launch_gradio_app
9
+ from dreambooth import train_dreambooth
10
+
11
def fine_tune_model(instance_images, class_images, instance_prompt, class_prompt, num_train_steps=800):
    """Run DreamBooth training and return the directory it writes to.

    Calls ``train_dreambooth`` with the ``runwayml/stable-diffusion-v1-5``
    base model and the fixed output directory ``dreambooth_model``.

    Args:
        instance_images: Passed to ``train_dreambooth`` as ``instance_data_dir``.
        class_images: Passed to ``train_dreambooth`` as ``class_data_dir``.
        instance_prompt: Prompt describing the instance images.
        class_prompt: Prompt describing the class images.
        num_train_steps: Number of training steps (default 800).

    Returns:
        The output directory name, ``"dreambooth_model"``.
    """
    output_dir = "dreambooth_model"

    # Collect the keyword arguments first so the call itself stays short.
    training_kwargs = {
        "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
        "instance_data_dir": instance_images,
        "class_data_dir": class_images,
        "output_dir": output_dir,
        "instance_prompt": instance_prompt,
        "class_prompt": class_prompt,
        "num_train_steps": num_train_steps,
    }
    train_dreambooth(**training_kwargs)

    return output_dir
26
+
27
def load_model(model_path):
    """Load a Stable Diffusion pipeline from ``model_path`` ready for sampling.

    The pipeline is loaded in float16 with the safety checker disabled and
    moved to CUDA. Its scheduler is replaced by a DDIM scheduler built from
    the existing scheduler config, and xformers memory-efficient attention is
    enabled.

    Args:
        model_path: Path or model id given to ``from_pretrained``.

    Returns:
        The configured pipeline.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        model_path,
        safety_checker=None,
        torch_dtype=torch.float16,
    )
    pipe = pipe.to("cuda")

    ddim_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.scheduler = ddim_scheduler

    pipe.enable_xformers_memory_efficient_attention()
    return pipe
32
 
33
def generate_images(pipe, prompt, negative_prompt, num_samples, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, seed=None):
    """Sample images from an already-loaded pipeline.

    Args:
        pipe: Callable pipeline whose result exposes an ``.images`` attribute.
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        num_samples: Number of images per prompt; coerced to ``int``.
        height: Output height in pixels; coerced to ``int``.
        width: Output width in pixels; coerced to ``int``.
        num_inference_steps: Number of denoising steps; coerced to ``int``.
        guidance_scale: Guidance scale forwarded unchanged.
        seed: Optional seed for the CUDA generator. When ``None`` (the
            default) a freshly constructed generator is used, exactly as
            before this parameter existed.

    Returns:
        The ``.images`` attribute of the pipeline output.
    """
    generator = torch.Generator(device='cuda')
    if seed is not None:
        # Seeding lets callers reproduce or deliberately vary the samples.
        generator.manual_seed(int(seed))

    with torch.autocast("cuda"), torch.inference_mode():
        images = pipe(
            prompt,
            height=int(height),
            width=int(width),
            negative_prompt=negative_prompt,
            num_images_per_prompt=int(num_samples),
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            generator=generator,
        ).images
    return images
43
 
44
def push_to_huggingface(model_path, repo_name):
    """Upload the folder at ``model_path`` to the Hub repository ``repo_name``.

    Args:
        model_path: Local folder to upload.
        repo_name: Target repository id, e.g. ``"username/repo"``.
    """
    api = HfApi()
    # upload_folder fails if the target repository does not exist yet, so
    # create it first; exist_ok=True makes this a no-op on later pushes.
    api.create_repo(repo_id=repo_name, exist_ok=True)
    api.upload_folder(folder_path=model_path, repo_id=repo_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
# Script entry point: hand the training, loading, generation and upload
# callbacks, plus the target Hub repository id, to the Gradio app launcher.
if __name__ == "__main__":
    repo_name = "your-huggingface-username/dreambooth-app"
    launch_gradio_app(
        fine_tune_model,
        load_model,
        generate_images,
        push_to_huggingface,
        repo_name,
    )