Salt committed on
Commit e95f4b9
1 parent: 5828964

Upload 3 files

Files changed (3)
  1. constants.py +109 -0
  2. pipelines.py +26 -0
  3. server.py +777 -0
constants.py ADDED
@@ -0,0 +1,109 @@
+ # Constants
+ # Also try: 'Qiliang/bart-large-cnn-samsum-ElectrifAi_v10'
+ DEFAULT_SUMMARIZATION_MODEL = "Qiliang/bart-large-cnn-samsum-ChatGPT_v3"
+ # Also try: 'joeddav/distilbert-base-uncased-go-emotions-student'
+ DEFAULT_CLASSIFICATION_MODEL = "nateraw/bert-base-uncased-emotion"
+ # Also try: 'Salesforce/blip-image-captioning-base'
+ DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large"
+ DEFAULT_KEYPHRASE_MODEL = "ml6team/keyphrase-extraction-distilbert-inspec"
+ DEFAULT_PROMPT_MODEL = "FredZhang7/anime-anything-promptgen-v2"
+ DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
+ DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
+ DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
+ DEFAULT_REMOTE_SD_PORT = 7860
+ SILERO_SAMPLES_PATH = "tts_samples"
+ SILERO_SAMPLE_TEXT = "The quick brown fox jumps over the lazy dog"
+ # ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd']
+ DEFAULT_SUMMARIZE_PARAMS = {
+     "temperature": 1.0,
+     "repetition_penalty": 1.0,
+     "max_length": 500,
+     "min_length": 200,
+     "length_penalty": 1.5,
+     "bad_words": [
+         "\n",
+         '"',
+         "*",
+         "[",
+         "]",
+         "{",
+         "}",
+         ":",
+         "(",
+         ")",
+         "<",
+         ">",
+         "Â",
+         "The text ends",
+         "The story ends",
+         "The text is",
+         "The story is",
+     ],
+ }
+ 
+ PROMPT_PREFIX = "best quality, absurdres, "
+ NEGATIVE_PROMPT = """lowres, bad anatomy, error body, error hair, error arm,
+ error hands, bad hands, error fingers, bad fingers, missing fingers,
+ error legs, bad legs, multiple legs, missing legs, error lighting,
+ error shadow, error reflection, text, error, extra digit, fewer digits,
+ cropped, worst quality, low quality, normal quality, jpeg artifacts,
+ signature, watermark, username, blurry"""
+ 
+ 
+ # Key phrases to look for in text (unused for now)
+ INDICATOR_LIST = [
+     "female",
+     "girl",
+     "male",
+     "boy",
+     "woman",
+     "man",
+     "hair",
+     "eyes",
+     "skin",
+     "wears",
+     "appearance",
+     "costume",
+     "clothes",
+     "body",
+     "tall",
+     "short",
+     "chubby",
+     "thin",
+     "expression",
+     "angry",
+     "sad",
+     "blush",
+     "smile",
+     "happy",
+     "depressed",
+     "long",
+     "cold",
+     "breasts",
+     "chest",
+     "tail",
+     "ears",
+     "fur",
+     "race",
+     "species",
+     "wearing",
+     "shoes",
+     "boots",
+     "shirt",
+     "panties",
+     "bra",
+     "skirt",
+     "dress",
+     "kimono",
+     "wings",
+     "horns",
+     "pants",
+     "shorts",
+     "leggings",
+     "sandals",
+     "hat",
+     "glasses",
+     "sweater",
+     "hoodie",
+     "sweatshirt",
+ ]
pipelines.py ADDED
@@ -0,0 +1,26 @@
+ from transformers import (
+     AutoModelForTokenClassification,
+     AutoTokenizer,
+     TokenClassificationPipeline,
+ )
+ from transformers.pipelines import AggregationStrategy
+ import numpy as np
+ 
+ 
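+ # Token-classification pipeline that loads model and tokenizer from a single
+ # Hugging Face model id and post-processes predictions into unique keyphrases.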
+ class KeyphraseExtractionPipeline(TokenClassificationPipeline):
+     def __init__(self, model, *args, **kwargs):
+         super().__init__(
+             model=AutoModelForTokenClassification.from_pretrained(model),
+             tokenizer=AutoTokenizer.from_pretrained(model),
+             *args,
+             **kwargs
+         )
+ 
+     def postprocess(self, model_outputs):
+         # Use SIMPLE aggregation for RoBERTa-type models and FIRST otherwise,
+         # so subword tokens are merged back into whole phrases
+         results = super().postprocess(
+             model_outputs=model_outputs,
+             aggregation_strategy=AggregationStrategy.SIMPLE
+             if self.model.config.model_type == "roberta"
+             else AggregationStrategy.FIRST,
+         )
+         # Deduplicate phrases and strip surrounding whitespace
+         return np.unique([result.get("word").strip() for result in results])
server.py ADDED
@@ -0,0 +1,777 @@
+ from functools import wraps
+ from flask import (
+     Flask,
+     jsonify,
+     request,
+     render_template_string,
+     abort,
+     send_from_directory,
+     send_file,
+ )
+ from flask_cors import CORS
+ import markdown
+ import argparse
+ from transformers import AutoTokenizer, AutoProcessor, pipeline
+ from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
+ from transformers import BlipForConditionalGeneration, GPT2Tokenizer
+ import unicodedata
+ import torch
+ import time
+ import os
+ import gc
+ from PIL import Image
+ import base64
+ from io import BytesIO
+ from random import randint
+ import webuiapi
+ import hashlib
+ from constants import *
+ from colorama import Fore, Style, init as colorama_init
+ 
+ colorama_init()
+ 
+ 
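+ # argparse action that strips quotes from a comma-separated value
+ # (e.g. --enable-modules="caption,summarize") and splits it into a list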
+ class SplitArgs(argparse.Action):
+     def __call__(self, parser, namespace, values, option_string=None):
+         setattr(
+             namespace, self.dest, values.replace('"', "").replace("'", "").split(",")
+         )
+ 
+ 
+ # Script arguments
+ parser = argparse.ArgumentParser(
+     prog="TavernAI Extras", description="Web API for transformers models"
+ )
+ parser.add_argument(
+     "--port", type=int, help="Specify the port on which the application is hosted"
+ )
+ parser.add_argument(
+     "--listen", action="store_true", help="Host the app on the local network"
+ )
+ parser.add_argument(
+     "--share", action="store_true", help="Share the app on CloudFlare tunnel"
+ )
+ parser.add_argument("--cpu", action="store_true", help="Run the models on the CPU")
+ parser.add_argument("--summarization-model", help="Load a custom summarization model")
+ parser.add_argument(
+     "--classification-model", help="Load a custom text classification model"
+ )
+ parser.add_argument("--captioning-model", help="Load a custom captioning model")
+ parser.add_argument(
+     "--keyphrase-model", help="Load a custom keyphrase extraction model"
+ )
+ parser.add_argument("--prompt-model", help="Load a custom prompt generation model")
+ parser.add_argument("--embedding-model", help="Load a custom text embedding model")
+ 
+ sd_group = parser.add_mutually_exclusive_group()
+ 
+ local_sd = sd_group.add_argument_group("sd-local")
+ local_sd.add_argument("--sd-model", help="Load a custom SD image generation model")
+ local_sd.add_argument("--sd-cpu", help="Force the SD pipeline to run on the CPU")
+ 
+ remote_sd = sd_group.add_argument_group("sd-remote")
+ remote_sd.add_argument(
+     "--sd-remote", action="store_true", help="Use a remote backend for SD"
+ )
+ remote_sd.add_argument(
+     "--sd-remote-host", type=str, help="Specify the host of the remote SD backend"
+ )
+ remote_sd.add_argument(
+     "--sd-remote-port", type=int, help="Specify the port of the remote SD backend"
+ )
+ remote_sd.add_argument(
+     "--sd-remote-ssl", action="store_true", help="Use SSL for the remote SD backend"
+ )
+ remote_sd.add_argument(
+     "--sd-remote-auth",
+     type=str,
+     help="Specify the username:password for the remote SD backend (if required)",
+ )
+ 
+ parser.add_argument(
+     "--enable-modules",
+     action=SplitArgs,
+     default=[],
+     help="Override a list of enabled modules",
+ )
+ 
+ args = parser.parse_args()
+ 
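+ # Resolve effective settings: CLI flags take precedence over the DEFAULT_*
+ # values imported from constants.py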
+ port = args.port if args.port else 5100
+ host = "0.0.0.0" if args.listen else "localhost"
+ summarization_model = (
+     args.summarization_model
+     if args.summarization_model
+     else DEFAULT_SUMMARIZATION_MODEL
+ )
+ classification_model = (
+     args.classification_model
+     if args.classification_model
+     else DEFAULT_CLASSIFICATION_MODEL
+ )
+ captioning_model = (
+     args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
+ )
+ keyphrase_model = (
+     args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL
+ )
+ prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL
+ embedding_model = (
+     args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL
+ )
+ 
+ sd_use_remote = not args.sd_model
+ sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
+ sd_remote_host = args.sd_remote_host if args.sd_remote_host else DEFAULT_REMOTE_SD_HOST
+ sd_remote_port = args.sd_remote_port if args.sd_remote_port else DEFAULT_REMOTE_SD_PORT
+ sd_remote_ssl = args.sd_remote_ssl
+ sd_remote_auth = args.sd_remote_auth
+ 
+ modules = args.enable_modules if args.enable_modules else []
+ 
+ if len(modules) == 0:
+     print(
+         f"{Fore.RED}{Style.BRIGHT}You did not select any modules to run! Choose them by adding an --enable-modules option"
+     )
+     print(f"Example: --enable-modules=caption,summarize{Style.RESET_ALL}")
+ 
+ # Models init
+ device_string = "cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu"
+ device = torch.device(device_string)
+ torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
+ 
+ if "caption" in modules:
+     print("Initializing an image captioning model...")
+     captioning_processor = AutoProcessor.from_pretrained(captioning_model)
+     if "blip" in captioning_model:
+         captioning_transformer = BlipForConditionalGeneration.from_pretrained(
+             captioning_model, torch_dtype=torch_dtype
+         ).to(device)
+     else:
+         captioning_transformer = AutoModelForCausalLM.from_pretrained(
+             captioning_model, torch_dtype=torch_dtype
+         ).to(device)
+ 
+ if "summarize" in modules:
+     print("Initializing a text summarization model...")
+     summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
+     summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
+         summarization_model, torch_dtype=torch_dtype
+     ).to(device)
+ 
+ if "classify" in modules:
+     print("Initializing a sentiment classification pipeline...")
+     classification_pipe = pipeline(
+         "text-classification",
+         model=classification_model,
+         top_k=None,
+         device=device,
+         torch_dtype=torch_dtype,
+     )
+ 
+ if "keywords" in modules:
+     print("Initializing a keyword extraction pipeline...")
+     import pipelines
+ 
+     keyphrase_pipe = pipelines.KeyphraseExtractionPipeline(keyphrase_model)
+ 
+ if "prompt" in modules:
+     print("Initializing a prompt generator")
+     gpt_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
+     gpt_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+     gpt_model = AutoModelForCausalLM.from_pretrained(prompt_model)
+     prompt_generator = pipeline(
+         "text-generation", model=gpt_model, tokenizer=gpt_tokenizer
+     )
+ 
+ if "sd" in modules and not sd_use_remote:
+     from diffusers import StableDiffusionPipeline
+     from diffusers import EulerAncestralDiscreteScheduler
+ 
+     print("Initializing Stable Diffusion pipeline")
+     sd_device_string = (
+         "cuda" if torch.cuda.is_available() and not args.sd_cpu else "cpu"
+     )
+     sd_device = torch.device(sd_device_string)
+     sd_torch_dtype = torch.float32 if sd_device_string == "cpu" else torch.float16
+     sd_pipe = StableDiffusionPipeline.from_pretrained(
+         sd_model, custom_pipeline="lpw_stable_diffusion", torch_dtype=sd_torch_dtype
+     ).to(sd_device)
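+     # Replace the safety checker with a no-op that returns images unmodified,
+     # so outputs are never filtered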
+     sd_pipe.safety_checker = lambda images, clip_input: (images, False)
+     sd_pipe.enable_attention_slicing()
+     # pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config)
+     sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
+         sd_pipe.scheduler.config
+     )
+ elif "sd" in modules and sd_use_remote:
+     print("Initializing Stable Diffusion connection")
+     try:
+         sd_remote = webuiapi.WebUIApi(
+             host=sd_remote_host, port=sd_remote_port, use_https=sd_remote_ssl
+         )
+         if sd_remote_auth:
+             username, password = sd_remote_auth.split(":")
+             sd_remote.set_auth(username, password)
+         sd_remote.util_wait_for_ready()
+     except Exception:
+         # Could not reach the remote backend: drop "sd" from the enabled modules
+         print(
+             f"{Fore.RED}{Style.BRIGHT}Could not connect to remote SD backend at http{'s' if sd_remote_ssl else ''}://{sd_remote_host}:{sd_remote_port}! Disabling SD module...{Style.RESET_ALL}"
+         )
+         modules.remove("sd")
+ 
+ if "tts" in modules:
+     if not os.path.exists(SILERO_SAMPLES_PATH):
+         os.makedirs(SILERO_SAMPLES_PATH)
+     print("Initializing Silero TTS server")
+     from silero_api_server import tts
+ 
+     tts_service = tts.SileroTtsService(SILERO_SAMPLES_PATH)
+     if len(os.listdir(SILERO_SAMPLES_PATH)) == 0:
+         print("Generating Silero TTS samples...")
+         tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
+         tts_service.generate_samples()
+ 
+ if "chromadb" in modules:
+     print("Initializing ChromaDB")
+     import chromadb
+     import posthog
+     from chromadb.config import Settings
+     from sentence_transformers import SentenceTransformer
+ 
+     # Disable ChromaDB telemetry
+     posthog.capture = lambda *args, **kwargs: None
+     chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
+     chromadb_embedder = SentenceTransformer(embedding_model)
+     chromadb_embed_fn = chromadb_embedder.encode
+ 
+ 
+ # Flask init
+ app = Flask(__name__)
+ CORS(app)  # allow cross-domain requests
+ app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024
+ 
+ 
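+ # Decorator that rejects a request with HTTP 403 when the endpoint's module
+ # was not enabled via --enable-modules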
+ def require_module(name):
+     def wrapper(fn):
+         @wraps(fn)
+         def decorated_view(*args, **kwargs):
+             if name not in modules:
+                 abort(403, "Module is disabled by config")
+             return fn(*args, **kwargs)
+ 
+         return decorated_view
+ 
+     return wrapper
+ 
+ 
+ # AI stuff
+ def classify_text(text: str) -> list:
+     output = classification_pipe(
+         text,
+         truncation=True,
+         max_length=classification_pipe.model.config.max_position_embeddings,
+     )[0]
+     return sorted(output, key=lambda x: x["score"], reverse=True)
+ 
+ 
+ def caption_image(raw_image: Image.Image, max_new_tokens: int = 20) -> str:
+     inputs = captioning_processor(raw_image.convert("RGB"), return_tensors="pt").to(
+         device, torch_dtype
+     )
+     outputs = captioning_transformer.generate(**inputs, max_new_tokens=max_new_tokens)
+     caption = captioning_processor.decode(outputs[0], skip_special_tokens=True)
+     return caption
+ 
+ 
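+ # If the input is too long for the model, recursively halve the text (and the
+ # length limits) and summarize each half separately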
+ def summarize_chunks(text: str, params: dict) -> str:
+     try:
+         return summarize(text, params)
+     except IndexError:
+         print(
+             "Sequence length too large for model, cutting text in half and calling again"
+         )
+         new_params = params.copy()
+         new_params["max_length"] = new_params["max_length"] // 2
+         new_params["min_length"] = new_params["min_length"] // 2
+         return summarize_chunks(
+             text[: (len(text) // 2)], new_params
+         ) + summarize_chunks(text[(len(text) // 2) :], new_params)
+ 
+ 
+ def summarize(text: str, params: dict) -> str:
+     # Tokenize input
+     inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
+     token_count = len(inputs[0])
+ 
+     bad_words_ids = [
+         summarization_tokenizer(bad_word, add_special_tokens=False).input_ids
+         for bad_word in params["bad_words"]
+     ]
+     summary_ids = summarization_transformer.generate(
+         inputs["input_ids"],
+         num_beams=2,
+         max_new_tokens=max(token_count, int(params["max_length"])),
+         min_new_tokens=min(token_count, int(params["min_length"])),
+         repetition_penalty=float(params["repetition_penalty"]),
+         temperature=float(params["temperature"]),
+         length_penalty=float(params["length_penalty"]),
+         bad_words_ids=bad_words_ids,
+     )
+     summary = summarization_tokenizer.batch_decode(
+         summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
+     )[0]
+     summary = normalize_string(summary)
+     return summary
+ 
+ 
+ def normalize_string(text: str) -> str:
+     # Collapse whitespace and apply Unicode NFKC normalization
+     return " ".join(unicodedata.normalize("NFKC", text).strip().split())
+ 
+ 
+ def extract_keywords(text: str) -> list:
+     # Replace punctuation that confuses the keyphrase model with spaces
+     punctuation = "(){}[]\n\r<>"
+     trans = str.maketrans(punctuation, " " * len(punctuation))
+     text = text.translate(trans)
+     text = normalize_string(text)
+     return list(keyphrase_pipe(text))
+ 
+ 
+ def generate_prompt(keywords: list, length: int = 100, num: int = 4) -> list:
+     prompt = ", ".join(keywords)
+     outs = prompt_generator(
+         prompt,
+         max_length=length,
+         num_return_sequences=num,
+         do_sample=True,
+         repetition_penalty=1.2,
+         temperature=0.7,
+         top_k=4,
+         early_stopping=True,
+     )
+     return [out["generated_text"] for out in outs]
+ 
+ 
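+ # Dispatch either to the remote Stable Diffusion Web UI (via webuiapi) or to
+ # the local diffusers pipeline, depending on launch flags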
+ def generate_image(data: dict) -> Image.Image:
+     prompt = normalize_string(f'{data["prompt_prefix"]} {data["prompt"]}')
+ 
+     if sd_use_remote:
+         image = sd_remote.txt2img(
+             prompt=prompt,
+             negative_prompt=data["negative_prompt"],
+             sampler_name=data["sampler"],
+             steps=data["steps"],
+             cfg_scale=data["scale"],
+             width=data["width"],
+             height=data["height"],
+             restore_faces=data["restore_faces"],
+             enable_hr=data["enable_hr"],
+             save_images=True,
+             send_images=True,
+             do_not_save_grid=False,
+             do_not_save_samples=False,
+         ).image
+     else:
+         image = sd_pipe(
+             prompt=prompt,
+             negative_prompt=data["negative_prompt"],
+             num_inference_steps=data["steps"],
+             guidance_scale=data["scale"],
+             width=data["width"],
+             height=data["height"],
+         ).images[0]
+ 
+     image.save("./debug.png")
+     return image
+ 
+ 
+ def image_to_base64(image: Image.Image, quality: int = 75) -> str:
+     buffered = BytesIO()
+     image.save(buffered, format="JPEG", quality=quality)
+     img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+     return img_str
+ 
+ 
+ # Request time measuring
+ @app.before_request
+ def before_request():
+     request.start_time = time.time()
+ 
+ 
+ @app.after_request
+ def after_request(response):
+     duration = time.time() - request.start_time
+     response.headers["X-Request-Duration"] = str(duration)
+     return response
+ 
+ 
+ @app.route("/", methods=["GET"])
+ def index():
+     with open("./README.md", "r", encoding="utf8") as f:
+         content = f.read()
+     return render_template_string(markdown.markdown(content, extensions=["tables"]))
+ 
+ 
+ @app.route("/api/extensions", methods=["GET"])
+ def get_extensions():
+     extensions = {
+         "extensions": [
+             {
+                 "name": "not-supported",
+                 "metadata": {
+                     "display_name": """<span style="white-space:break-spaces;">Extensions serving using Extensions API is no longer supported. Please update the mod from: <a href="https://github.com/Cohee1207/SillyTavern">https://github.com/Cohee1207/SillyTavern</a></span>""",
+                     "requires": [],
+                     "assets": [],
+                 },
+             }
+         ]
+     }
+     return jsonify(extensions)
+ 
+ 
+ @app.route("/api/caption", methods=["POST"])
+ @require_module("caption")
+ def api_caption():
+     data = request.get_json()
+ 
+     if "image" not in data or not isinstance(data["image"], str):
+         abort(400, '"image" is required')
+ 
+     image = Image.open(BytesIO(base64.b64decode(data["image"])))
+     image = image.convert("RGB")
+     image.thumbnail((512, 512))
+     caption = caption_image(image)
+     thumbnail = image_to_base64(image)
+     print("Caption:", caption, sep="\n")
+     gc.collect()
+     return jsonify({"caption": caption, "thumbnail": thumbnail})
+ 
+ 
+ @app.route("/api/summarize", methods=["POST"])
+ @require_module("summarize")
+ def api_summarize():
+     data = request.get_json()
+ 
+     if "text" not in data or not isinstance(data["text"], str):
+         abort(400, '"text" is required')
+ 
+     params = DEFAULT_SUMMARIZE_PARAMS.copy()
+ 
+     if "params" in data and isinstance(data["params"], dict):
+         params.update(data["params"])
+ 
+     print("Summary input:", data["text"], sep="\n")
+     summary = summarize_chunks(data["text"], params)
+     print("Summary output:", summary, sep="\n")
+     gc.collect()
+     return jsonify({"summary": summary})
+ 
+ 
+ @app.route("/api/classify", methods=["POST"])
+ @require_module("classify")
+ def api_classify():
+     data = request.get_json()
+ 
+     if "text" not in data or not isinstance(data["text"], str):
+         abort(400, '"text" is required')
+ 
+     print("Classification input:", data["text"], sep="\n")
+     classification = classify_text(data["text"])
+     print("Classification output:", classification, sep="\n")
+     gc.collect()
+     return jsonify({"classification": classification})
+ 
+ 
+ @app.route("/api/classify/labels", methods=["GET"])
+ @require_module("classify")
+ def api_classify_labels():
+     classification = classify_text("")
+     labels = [x["label"] for x in classification]
+     return jsonify({"labels": labels})
+ 
+ 
+ @app.route("/api/keywords", methods=["POST"])
+ @require_module("keywords")
+ def api_keywords():
+     data = request.get_json()
+ 
+     if "text" not in data or not isinstance(data["text"], str):
+         abort(400, '"text" is required')
+ 
+     print("Keywords input:", data["text"], sep="\n")
+     keywords = extract_keywords(data["text"])
+     print("Keywords output:", keywords, sep="\n")
+     return jsonify({"keywords": keywords})
+ 
+ 
+ @app.route("/api/prompt", methods=["POST"])
+ @require_module("prompt")
+ def api_prompt():
+     data = request.get_json()
+ 
+     if "text" not in data or not isinstance(data["text"], str):
+         abort(400, '"text" is required')
+ 
+     keywords = extract_keywords(data["text"])
+ 
+     if "name" in data and isinstance(data["name"], str):
+         keywords.insert(0, data["name"])
+ 
+     print("Prompt input:", data["text"], sep="\n")
+     prompts = generate_prompt(keywords)
+     print("Prompt output:", prompts, sep="\n")
+     return jsonify({"prompts": prompts})
+ 
+ 
+ @app.route("/api/image", methods=["POST"])
+ @require_module("sd")
+ def api_image():
+     required_fields = {
+         "prompt": str,
+     }
+ 
+     optional_fields = {
+         "steps": 30,
+         "scale": 6,
+         "sampler": "DDIM",
+         "width": 512,
+         "height": 512,
+         "restore_faces": False,
+         "enable_hr": False,
+         "prompt_prefix": PROMPT_PREFIX,
+         "negative_prompt": NEGATIVE_PROMPT,
+     }
+ 
+     data = request.get_json()
+ 
+     # Check required fields
+     for field, field_type in required_fields.items():
+         if field not in data or not isinstance(data[field], field_type):
+             abort(400, f'"{field}" is required')
+ 
+     # Set optional fields to default values if not provided
+     for field, default_value in optional_fields.items():
+         type_match = (
+             (int, float)
+             if isinstance(default_value, (int, float))
+             else type(default_value)
+         )
+         if field not in data or not isinstance(data[field], type_match):
+             data[field] = default_value
+ 
+     try:
+         print("SD inputs:", data, sep="\n")
+         image = generate_image(data)
+         base64image = image_to_base64(image, quality=90)
+         return jsonify({"image": base64image})
+     except RuntimeError as e:
+         abort(400, str(e))
+ 
+ 
+ @app.route("/api/image/model", methods=["POST"])
+ @require_module("sd")
+ def api_image_model_set():
+     data = request.get_json()
+ 
+     if not sd_use_remote:
+         abort(400, "Changing the model for local SD is not supported.")
+     if "model" not in data or not isinstance(data["model"], str):
+         abort(400, '"model" is required')
+ 
+     old_model = sd_remote.util_get_current_model()
+     sd_remote.util_set_model(data["model"], find_closest=False)
+     # sd_remote.util_set_model(data['model'])
+     sd_remote.util_wait_for_ready()
+     new_model = sd_remote.util_get_current_model()
+ 
+     return jsonify({"previous_model": old_model, "current_model": new_model})
+ 
+ 
+ @app.route("/api/image/model", methods=["GET"])
+ @require_module("sd")
+ def api_image_model_get():
+     model = sd_model
+ 
+     if sd_use_remote:
+         model = sd_remote.util_get_current_model()
+ 
+     return jsonify({"model": model})
+ 
+ 
+ @app.route("/api/image/models", methods=["GET"])
+ @require_module("sd")
+ def api_image_models():
+     models = [sd_model]
+ 
+     if sd_use_remote:
+         models = sd_remote.util_get_model_names()
+ 
+     return jsonify({"models": models})
+ 
+ 
+ @app.route("/api/image/samplers", methods=["GET"])
+ @require_module("sd")
+ def api_image_samplers():
+     samplers = ["Euler a"]
+ 
+     if sd_use_remote:
+         samplers = [sampler["name"] for sampler in sd_remote.get_samplers()]
+ 
+     return jsonify({"samplers": samplers})
+ 
+ 
+ @app.route("/api/modules", methods=["GET"])
+ def get_modules():
+     return jsonify({"modules": modules})
+ 
+ 
+ @app.route("/api/tts/speakers", methods=["GET"])
+ @require_module("tts")
+ def tts_speakers():
+     voices = [
+         {
+             "name": speaker,
+             "voice_id": speaker,
+             "preview_url": f"{str(request.url_root)}api/tts/sample/{speaker}",
+         }
+         for speaker in tts_service.get_speakers()
+     ]
+     return jsonify(voices)
+ 
+ 
+ @app.route("/api/tts/generate", methods=["POST"])
+ @require_module("tts")
+ def tts_generate():
+     voice = request.get_json()
+     if "text" not in voice or not isinstance(voice["text"], str):
+         abort(400, '"text" is required')
+     if "speaker" not in voice or not isinstance(voice["speaker"], str):
+         abort(400, '"speaker" is required')
+     # Remove asterisks
+     voice["text"] = voice["text"].replace("*", "")
+     try:
+         audio = tts_service.generate(voice["speaker"], voice["text"])
+         return send_file(audio, mimetype="audio/x-wav")
+     except Exception as e:
+         print(e)
+         abort(500, voice["speaker"])
+ 
+ 
+ @app.route("/api/tts/sample/<speaker>", methods=["GET"])
+ @require_module("tts")
+ def tts_play_sample(speaker: str):
+     return send_from_directory(SILERO_SAMPLES_PATH, f"{speaker}.wav")
+ 
+ 
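+ # Each chat gets its own ChromaDB collection, named after an MD5 hash of the
+ # chat id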
+ @app.route("/api/chromadb", methods=["POST"])
+ @require_module("chromadb")
+ def chromadb_add_messages():
+     data = request.get_json()
+     if "chat_id" not in data or not isinstance(data["chat_id"], str):
+         abort(400, '"chat_id" is required')
+     if "messages" not in data or not isinstance(data["messages"], list):
+         abort(400, '"messages" is required')
+ 
+     chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
+     collection = chromadb_client.get_or_create_collection(
+         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
+     )
+ 
+     documents = [m["content"] for m in data["messages"]]
+     ids = [m["id"] for m in data["messages"]]
+     metadatas = [
+         {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
+         for m in data["messages"]
+     ]
+ 
+     collection.upsert(
+         ids=ids,
+         documents=documents,
+         metadatas=metadatas,
+     )
+ 
+     return jsonify({"count": len(ids)})
+ 
+ 
+ @app.route("/api/chromadb/purge", methods=["POST"])
+ @require_module("chromadb")
+ def chromadb_purge():
+     data = request.get_json()
+     if "chat_id" not in data or not isinstance(data["chat_id"], str):
+         abort(400, '"chat_id" is required')
+ 
+     chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
+     collection = chromadb_client.get_or_create_collection(
+         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
+     )
+ 
+     deleted = collection.delete()
+     print("ChromaDB embeddings deleted", len(deleted))
+     return "Ok", 200
+ 
+ 
+ @app.route("/api/chromadb/query", methods=["POST"])
+ @require_module("chromadb")
+ def chromadb_query():
+     data = request.get_json()
+     if "chat_id" not in data or not isinstance(data["chat_id"], str):
+         abort(400, '"chat_id" is required')
+     if "query" not in data or not isinstance(data["query"], str):
+         abort(400, '"query" is required')
+ 
+     if "n_results" not in data or not isinstance(data["n_results"], int):
+         n_results = 1
+     else:
+         n_results = data["n_results"]
+ 
+     chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
+     collection = chromadb_client.get_or_create_collection(
+         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
+     )
+ 
+     n_results = min(collection.count(), n_results)
+     query_result = collection.query(
+         query_texts=[data["query"]],
+         n_results=n_results,
+     )
+ 
+     documents = query_result["documents"][0]
+     ids = query_result["ids"][0]
+     metadatas = query_result["metadatas"][0]
+     distances = query_result["distances"][0]
+ 
+     messages = [
+         {
+             "id": ids[i],
+             "date": metadatas[i]["date"],
+             "role": metadatas[i]["role"],
+             "meta": metadatas[i]["meta"],
+             "content": documents[i],
+             "distance": distances[i],
+         }
+         for i in range(len(ids))
+     ]
+ 
+     return jsonify(messages)
+ 
+ 
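+ # Newer flask-cloudflared versions take a metrics port as a second positional
+ # argument; inspect the signature so either variant of _run_cloudflared works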
+ if args.share:
+     from flask_cloudflared import _run_cloudflared
+     import inspect
+ 
+     sig = inspect.signature(_run_cloudflared)
+     param_count = sum(
+         1
+         for param in sig.parameters.values()
+         if param.kind == param.POSITIONAL_OR_KEYWORD
+     )
+     if param_count > 1:
+         metrics_port = randint(8100, 9000)
+         cloudflare = _run_cloudflared(port, metrics_port)
+     else:
+         cloudflare = _run_cloudflared(port)
+     print("Running on", cloudflare)
+ 
+ app.run(host=host, port=port)