zwx00 committed on
Commit
a10adb4
1 Parent(s): ba99fa4

flask serve

Files changed (1)
  1. serve.py +176 -0
serve.py ADDED
@@ -0,0 +1,176 @@
+ import logging
+ import sys
+ import threading
+ import uuid
+ from abc import ABC
+ from tempfile import TemporaryFile
+
+ import diffusers
+ import torch
+ from compel import Compel, ReturnedEmbeddingsType
+ from diffusers import StableDiffusionXLPipeline
+ from flask import Flask, request, jsonify
+ from google.cloud import storage
+
+ logger = logging.getLogger(__name__)
+ logger.info("Diffusers version %s", diffusers.__version__)
+
+ class DiffusersHandler(ABC):
+     """
+     Diffusers handler class for text-to-image generation.
+     """
+
+     def __init__(self):
+         self.initialized = False
+         self.working = False
+
+     def initialize(self, properties):
+         """Load and initialize the Stable Diffusion XL pipeline.
+
+         Args:
+             properties (dict): Model configuration; "gpu_id" selects the
+                 CUDA device the pipeline is placed on.
+         """
+         logger.info("Loading diffusion model")
+
+         device_str = (
+             "cuda:" + str(properties.get("gpu_id"))
+             if torch.cuda.is_available() and properties.get("gpu_id") is not None
+             else "cpu"
+         )
+         logger.info("Using device %s", device_str)
+
+         self.device = torch.device(device_str)
+         # The model path is taken from the first command-line argument.
+         self.pipe = StableDiffusionXLPipeline.from_pretrained(
+             sys.argv[1],
+             torch_dtype=torch.float16,
+             use_safetensors=True,
+         )
+
+         logger.info("moving model to device: %s", device_str)
+         self.pipe.to(self.device)
+
+         logger.info("Diffusion model from path %s loaded successfully", sys.argv[1])
+
+         self.initialized = True
+
+     def preprocess(self, raw_requests):
+         """Basic text preprocessing of the user's prompt.
+
+         Args:
+             raw_requests (list): A list of request dicts; only the first
+                 entry is used.
+
+         Returns:
+             dict: The generation parameters for a single request.
+         """
+         logger.info("Received requests: '%s'", raw_requests)
+         self.working = True
+
+         processed_request = {
+             "prompt": raw_requests[0]["prompt"],
+             "negative_prompt": raw_requests[0].get("negative_prompt"),
+             "width": raw_requests[0].get("width"),
+             "height": raw_requests[0].get("height"),
+             "num_inference_steps": raw_requests[0].get("num_inference_steps", 30),
+             "guidance_scale": raw_requests[0].get("guidance_scale", 7.5),
+         }
+
+         logger.info("Processed request: '%s'", processed_request)
+         return processed_request
+
+     def inference(self, request):
+         """Generates the image relevant to the received text.
+
+         Args:
+             request (dict): The generation parameters from preprocess().
+
+         Returns:
+             list: The generated images for the input prompt.
+         """
+         # Compel builds weighted prompt embeddings for both SDXL text
+         # encoders; the pipeline consumes these instead of a raw prompt.
+         compel = Compel(
+             tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
+             text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
+             returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+             requires_pooled=[False, True],
+         )
+
+         self.prompt = request.pop("prompt")
+         conditioning, pooled = compel(self.prompt)
+
+         inferences = self.pipe(
+             prompt_embeds=conditioning,
+             pooled_prompt_embeds=pooled,
+             **request
+         ).images
+
+         logger.info("Generated image: '%s'", inferences)
+         return inferences
+
+     def postprocess(self, inference_outputs):
+         """Uploads the generated images to Cloud Storage.
+
+         Args:
+             inference_outputs (list): The generated images.
+
+         Returns:
+             list: Public URLs of the uploaded images.
+         """
+         bucket_name = "outputs-storage-prod"
+         client = storage.Client()
+         self.working = False
+         bucket = client.get_bucket(bucket_name)
+         outputs = []
+         for image in inference_outputs:
+             image_name = str(uuid.uuid4())
+
+             blob = bucket.blob(image_name + '.png')
+
+             with TemporaryFile() as tmp:
+                 image.save(tmp, format="png")
+                 tmp.seek(0)
+                 blob.upload_from_file(tmp, content_type='image/png')
+
+             # generate txt file with the image name and the prompt inside
+             # blob = bucket.blob(image_name + '.txt')
+             # blob.upload_from_string(self.prompt)
+
+             outputs.append('https://storage.googleapis.com/' + bucket_name + '/' + image_name + '.png')
+         return outputs
+
+
+ app = Flask(__name__)
+
+ # Initialize one handler per available GPU on startup.
+ gpu_count = torch.cuda.device_count()
+ if gpu_count == 0:
+     raise ValueError("No GPUs available!")
+
+ handlers = [DiffusersHandler() for _ in range(gpu_count)]
+ for i in range(gpu_count):
+     handlers[i].initialize({"gpu_id": i})
+
+ handler_lock = threading.Lock()
+ handler_index = 0
+
+ @app.route('/generate', methods=['POST'])
+ def generate_image():
+     global handler_index
+     try:
+         # Extract the raw request from the HTTP POST body.
+         raw_request = request.json
+
+         with handler_lock:
+             selected_handler = handlers[handler_index]
+             handler_index = (handler_index + 1) % gpu_count  # Rotate to the next handler
+
+         processed_request = selected_handler.preprocess([raw_request])
+         inferences = selected_handler.inference(processed_request)
+         outputs = selected_handler.postprocess(inferences)
+
+         return jsonify({"image_urls": outputs})
+     except Exception as e:
+         logger.error("Error during image generation: %s", str(e))
+         return jsonify({"error": "Failed to generate image", "details": str(e)}), 500
+
+
+ if __name__ == '__main__':
+     app.run(host='0.0.0.0', port=3000, threaded=True)
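
For reference, a minimal client-side sketch of how this endpoint could be exercised, assuming the server was started with `python serve.py <model_path>` and is reachable on localhost port 3000 (the host name and prompt text below are illustrative assumptions, not part of this commit):

    import requests

    # Illustrative payload; "prompt" is required, the rest mirror the
    # optional keys read in preprocess() with their defaults.
    payload = {
        "prompt": "a photo of an astronaut riding a horse",
        "negative_prompt": "blurry, low quality",
        "num_inference_steps": 30,
        "guidance_scale": 7.5,
    }
    resp = requests.post("http://localhost:3000/generate", json=payload)
    resp.raise_for_status()
    print(resp.json()["image_urls"])

On success the response carries a JSON list of public Cloud Storage URLs for the uploaded PNGs; on failure the handler returns HTTP 500 with the error details.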