Initial
Browse files- .env +8 -0
- .gitignore +4 -0
- Service.py +51 -0
- Training.py +24 -0
- handler.py +41 -0
- requirements.txt +9 -0
- setup.py +56 -0
- training.sh +6 -0
.env
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
SERVICE_ENDPOINT=http://0.0.0.0/api/
|
2 |
+
|
3 |
+
MODEL_NAME=runwayml/stable-diffusion-v1-5
|
4 |
+
|
5 |
+
APP_DIR=/Users/akin/Desktop/txt2img-consumer/
|
6 |
+
INSTANCE_DIR=/Users/akin/Desktop/txt2img-consumer/instance/
|
7 |
+
OUTPUT_DIR=/content/model
|
8 |
+
CLASS_DIR=/content/data/person
|
.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.idea
|
2 |
+
instance
|
3 |
+
instance.zip
|
4 |
+
__pycache__
|
Service.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import os, shutil
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
import zipfile
|
5 |
+
|
6 |
+
|
7 |
+
class Service:
|
8 |
+
def __init__(self, orderId, hash):
|
9 |
+
load_dotenv()
|
10 |
+
|
11 |
+
self.service_endpoint = os.getenv('SERVICE_ENDPOINT')
|
12 |
+
self.app_dir = os.getenv('APP_DIR')
|
13 |
+
self.instance_dir = os.getenv('INSTANCE_DIR')
|
14 |
+
self.archive_filename = self.app_dir + 'instance.zip'
|
15 |
+
|
16 |
+
self.headers = {
|
17 |
+
'Content-Type': 'application/json',
|
18 |
+
'Authorization': hash,
|
19 |
+
}
|
20 |
+
|
21 |
+
self.params = {
|
22 |
+
'orderId': orderId,
|
23 |
+
}
|
24 |
+
|
25 |
+
# Download images
|
26 |
+
def download(self):
|
27 |
+
self.scarify()
|
28 |
+
|
29 |
+
# Get images
|
30 |
+
response = requests.post(self.service_endpoint + 'download', headers=self.headers, json=self.params, verify=False)
|
31 |
+
with open(self.archive_filename, 'wb') as f:
|
32 |
+
f.write(response.content)
|
33 |
+
|
34 |
+
# Extract
|
35 |
+
with zipfile.ZipFile(self.archive_filename, "r") as zip_ref:
|
36 |
+
zip_ref.extractall(self.instance_dir)
|
37 |
+
|
38 |
+
print("Downloaded images")
|
39 |
+
|
40 |
+
# Get data
|
41 |
+
def data(self):
|
42 |
+
response = requests.post(self.service_endpoint + 'data', headers=self.headers, json=self.params, verify=False)
|
43 |
+
return response.json()
|
44 |
+
|
45 |
+
# Remove old files
|
46 |
+
def scarify(self):
|
47 |
+
if os.path.exists(self.instance_dir):
|
48 |
+
shutil.rmtree(self.instance_dir)
|
49 |
+
|
50 |
+
if os.path.exists(self.archive_filename):
|
51 |
+
os.remove(self.archive_filename)
|
Training.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import subprocess, os
|
2 |
+
|
3 |
+
|
4 |
+
class Training:
|
5 |
+
def __init__(self):
|
6 |
+
self.app_path = os.getenv('APP_DIR')
|
7 |
+
self.model_name = os.getenv('MODEL_NAME')
|
8 |
+
self.instance_dir = os.getenv('INSTANCE_DIR')
|
9 |
+
self.output_dir = os.getenv('OUTPUT_DIR')
|
10 |
+
self.class_dir = os.getenv('CLASS_DIR')
|
11 |
+
|
12 |
+
self.start()
|
13 |
+
|
14 |
+
def start(self):
|
15 |
+
args = [
|
16 |
+
self.app_path + 'training.sh',
|
17 |
+
self.model_name,
|
18 |
+
self.instance_dir,
|
19 |
+
self.output_dir,
|
20 |
+
self.class_dir
|
21 |
+
]
|
22 |
+
subprocess.run(['chmod', '+x', self.app_path + 'training.sh'], cwd=self.app_path)
|
23 |
+
process = subprocess.Popen(args, cwd=self.app_path)
|
24 |
+
process.wait()
|
handler.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
import pika
|
4 |
+
import json
|
5 |
+
import phpserialize
|
6 |
+
from Service import Service
|
7 |
+
from Training import Training
|
8 |
+
|
9 |
+
|
10 |
+
class EndpointHandler:
|
11 |
+
def __init__(self):
|
12 |
+
credentials = pika.PlainCredentials('admin', '123456')
|
13 |
+
self.connection = pika.BlockingConnection(pika.ConnectionParameters('0.0.0.0', 5672, '/', credentials))
|
14 |
+
self.channel = self.connection.channel()
|
15 |
+
self.channel.queue_declare(queue='avator', durable=True)
|
16 |
+
self.consume()
|
17 |
+
|
18 |
+
def consume(self):
|
19 |
+
print(' [*] Waiting for messages. To exit press CTRL+C')
|
20 |
+
|
21 |
+
def callback(ch, method, properties, body):
|
22 |
+
payload = json.loads(body)
|
23 |
+
service = Service(orderId=payload['orderId'], hash=payload['hash'])
|
24 |
+
|
25 |
+
# Download images
|
26 |
+
service.download()
|
27 |
+
|
28 |
+
# Get data
|
29 |
+
data = service.data()
|
30 |
+
|
31 |
+
# Start training
|
32 |
+
Training()
|
33 |
+
|
34 |
+
# ch.basic_ack(delivery_tag=method.delivery_tag)
|
35 |
+
|
36 |
+
self.channel.basic_qos(prefetch_count=1)
|
37 |
+
self.channel.basic_consume(queue='avator', on_message_callback=callback)
|
38 |
+
self.channel.start_consuming()
|
39 |
+
|
40 |
+
|
41 |
+
worker = EndpointHandler()
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pika
|
2 |
+
phpserialize
|
3 |
+
torchvision
|
4 |
+
tensorboard
|
5 |
+
modelcards
|
6 |
+
wget
|
7 |
+
bitsandbytes
|
8 |
+
xformers
|
9 |
+
natsort
|
setup.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from distutils.command.build import build as _build
|
2 |
+
import setuptools
|
3 |
+
import subprocess
|
4 |
+
|
5 |
+
print("123")
|
6 |
+
raise None
|
7 |
+
|
8 |
+
|
9 |
+
CUSTOM_COMMANDS = [
|
10 |
+
#['apt-get', 'install', 'wget'],
|
11 |
+
#['python', '-c', "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('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')"],
|
12 |
+
['wget', '-q', "https://github.com/ShivamShrirao/diffusers/raw/main/examples/dreambooth/train_dreambooth.py"],
|
13 |
+
['wget', '-q', "https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py"],
|
14 |
+
]
|
15 |
+
|
16 |
+
|
17 |
+
class Build(_build):
|
18 |
+
sub_commands = _build.sub_commands + [('CustomCommands', None)]
|
19 |
+
|
20 |
+
|
21 |
+
class CustomCommands(setuptools.Command):
|
22 |
+
def initialize_options(self):
|
23 |
+
pass
|
24 |
+
|
25 |
+
def finalize_options(self):
|
26 |
+
pass
|
27 |
+
|
28 |
+
def RunCustomCommand(self, command_list):
|
29 |
+
print('Running command: %s' % command_list)
|
30 |
+
p = subprocess.Popen(
|
31 |
+
command_list,
|
32 |
+
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
|
33 |
+
)
|
34 |
+
|
35 |
+
stdout_data, _ = p.communicate()
|
36 |
+
print( 'Command output: %s' % stdout_data)
|
37 |
+
|
38 |
+
if p.returncode != 0:
|
39 |
+
raise RuntimeError(
|
40 |
+
'Command %s failed: exit code: %s' % (command_list, p.returncode))
|
41 |
+
|
42 |
+
def run(self):
|
43 |
+
for command in CUSTOM_COMMANDS:
|
44 |
+
self.RunCustomCommand(command)
|
45 |
+
|
46 |
+
|
47 |
+
setup(
|
48 |
+
name="Txt2ImgConsumer",
|
49 |
+
version="1.0",
|
50 |
+
description="",
|
51 |
+
packages=find_packages(),
|
52 |
+
cmdclass={
|
53 |
+
'build': Build,
|
54 |
+
'CustomCommands': CustomCommands,
|
55 |
+
}
|
56 |
+
)
|
training.sh
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
model_name=$1
|
4 |
+
instance_dir=$2
|
5 |
+
output_dir=$3
|
6 |
+
class_dir=$4
|