akinv commited on
Commit
3d0a1a1
1 Parent(s): a2e65f1
Files changed (8) hide show
  1. .env +8 -0
  2. .gitignore +4 -0
  3. Service.py +51 -0
  4. Training.py +24 -0
  5. handler.py +41 -0
  6. requirements.txt +9 -0
  7. setup.py +56 -0
  8. training.sh +6 -0
.env ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ SERVICE_ENDPOINT=http://0.0.0.0/api/
2
+
3
+ MODEL_NAME=runwayml/stable-diffusion-v1-5
4
+
5
+ APP_DIR=/Users/akin/Desktop/txt2img-consumer/
6
+ INSTANCE_DIR=/Users/akin/Desktop/txt2img-consumer/instance/
7
+ OUTPUT_DIR=/content/model
8
+ CLASS_DIR=/content/data/person
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ .idea
2
+ instance
3
+ instance.zip
4
+ __pycache__
Service.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import os, shutil
3
+ from dotenv import load_dotenv
4
+ import zipfile
5
+
6
+
7
+ class Service:
8
+ def __init__(self, orderId, hash):
9
+ load_dotenv()
10
+
11
+ self.service_endpoint = os.getenv('SERVICE_ENDPOINT')
12
+ self.app_dir = os.getenv('APP_DIR')
13
+ self.instance_dir = os.getenv('INSTANCE_DIR')
14
+ self.archive_filename = self.app_dir + 'instance.zip'
15
+
16
+ self.headers = {
17
+ 'Content-Type': 'application/json',
18
+ 'Authorization': hash,
19
+ }
20
+
21
+ self.params = {
22
+ 'orderId': orderId,
23
+ }
24
+
25
+ # Download images
26
+ def download(self):
27
+ self.scarify()
28
+
29
+ # Get images
30
+ response = requests.post(self.service_endpoint + 'download', headers=self.headers, json=self.params, verify=False)
31
+ with open(self.archive_filename, 'wb') as f:
32
+ f.write(response.content)
33
+
34
+ # Extract
35
+ with zipfile.ZipFile(self.archive_filename, "r") as zip_ref:
36
+ zip_ref.extractall(self.instance_dir)
37
+
38
+ print("Downloaded images")
39
+
40
+ # Get data
41
+ def data(self):
42
+ response = requests.post(self.service_endpoint + 'data', headers=self.headers, json=self.params, verify=False)
43
+ return response.json()
44
+
45
+ # Remove old files
46
+ def scarify(self):
47
+ if os.path.exists(self.instance_dir):
48
+ shutil.rmtree(self.instance_dir)
49
+
50
+ if os.path.exists(self.archive_filename):
51
+ os.remove(self.archive_filename)
Training.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess, os
2
+
3
+
4
+ class Training:
5
+ def __init__(self):
6
+ self.app_path = os.getenv('APP_DIR')
7
+ self.model_name = os.getenv('MODEL_NAME')
8
+ self.instance_dir = os.getenv('INSTANCE_DIR')
9
+ self.output_dir = os.getenv('OUTPUT_DIR')
10
+ self.class_dir = os.getenv('CLASS_DIR')
11
+
12
+ self.start()
13
+
14
+ def start(self):
15
+ args = [
16
+ self.app_path + 'training.sh',
17
+ self.model_name,
18
+ self.instance_dir,
19
+ self.output_dir,
20
+ self.class_dir
21
+ ]
22
+ subprocess.run(['chmod', '+x', self.app_path + 'training.sh'], cwd=self.app_path)
23
+ process = subprocess.Popen(args, cwd=self.app_path)
24
+ process.wait()
handler.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import pika
4
+ import json
5
+ import phpserialize
6
+ from Service import Service
7
+ from Training import Training
8
+
9
+
10
+ class EndpointHandler:
11
+ def __init__(self):
12
+ credentials = pika.PlainCredentials('admin', '123456')
13
+ self.connection = pika.BlockingConnection(pika.ConnectionParameters('0.0.0.0', 5672, '/', credentials))
14
+ self.channel = self.connection.channel()
15
+ self.channel.queue_declare(queue='avator', durable=True)
16
+ self.consume()
17
+
18
+ def consume(self):
19
+ print(' [*] Waiting for messages. To exit press CTRL+C')
20
+
21
+ def callback(ch, method, properties, body):
22
+ payload = json.loads(body)
23
+ service = Service(orderId=payload['orderId'], hash=payload['hash'])
24
+
25
+ # Download images
26
+ service.download()
27
+
28
+ # Get data
29
+ data = service.data()
30
+
31
+ # Start training
32
+ Training()
33
+
34
+ # ch.basic_ack(delivery_tag=method.delivery_tag)
35
+
36
+ self.channel.basic_qos(prefetch_count=1)
37
+ self.channel.basic_consume(queue='avator', on_message_callback=callback)
38
+ self.channel.start_consuming()
39
+
40
+
41
+ worker = EndpointHandler()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ pika
2
+ phpserialize
3
+ torchvision
4
+ tensorboard
5
+ modelcards
6
+ wget
7
+ bitsandbytes
8
+ xformers
9
+ natsort
setup.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from distutils.command.build import build as _build
2
+ import setuptools
3
+ import subprocess
4
+
5
+ print("123")
6
+ raise None
7
+
8
+
9
+ CUSTOM_COMMANDS = [
10
+ #['apt-get', 'install', 'wget'],
11
+ #['python', '-c', "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('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')"],
12
+ ['wget', '-q', "https://github.com/ShivamShrirao/diffusers/raw/main/examples/dreambooth/train_dreambooth.py"],
13
+ ['wget', '-q', "https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py"],
14
+ ]
15
+
16
+
17
+ class Build(_build):
18
+ sub_commands = _build.sub_commands + [('CustomCommands', None)]
19
+
20
+
21
+ class CustomCommands(setuptools.Command):
22
+ def initialize_options(self):
23
+ pass
24
+
25
+ def finalize_options(self):
26
+ pass
27
+
28
+ def RunCustomCommand(self, command_list):
29
+ print('Running command: %s' % command_list)
30
+ p = subprocess.Popen(
31
+ command_list,
32
+ stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
33
+ )
34
+
35
+ stdout_data, _ = p.communicate()
36
+ print( 'Command output: %s' % stdout_data)
37
+
38
+ if p.returncode != 0:
39
+ raise RuntimeError(
40
+ 'Command %s failed: exit code: %s' % (command_list, p.returncode))
41
+
42
+ def run(self):
43
+ for command in CUSTOM_COMMANDS:
44
+ self.RunCustomCommand(command)
45
+
46
+
47
+ setup(
48
+ name="Txt2ImgConsumer",
49
+ version="1.0",
50
+ description="",
51
+ packages=find_packages(),
52
+ cmdclass={
53
+ 'build': Build,
54
+ 'CustomCommands': CustomCommands,
55
+ }
56
+ )
training.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ model_name=$1
4
+ instance_dir=$2
5
+ output_dir=$3
6
+ class_dir=$4