Spaces:
Build error
Build error
LRhinehart
commited on
Commit
•
5bd179e
1
Parent(s):
f37adf7
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- $characters/Assistant.yaml +4 -0
- $characters/Example.png +0 -0
- $characters/Example.yaml +17 -0
- $extensions/Training_PRO/README.md +92 -0
- $extensions/Training_PRO/custom_scheduler.py +433 -0
- $extensions/Training_PRO/matplotgraph.py +62 -0
- $extensions/Training_PRO/script.py +1376 -0
- $extensions/Training_PRO/train_utils.py +368 -0
- $extensions/character_bias/script.py +83 -0
- $extensions/coqui_tts/harvard_sentences.txt +720 -0
- $extensions/coqui_tts/languages.json +18 -0
- $extensions/coqui_tts/requirements.txt +1 -0
- $extensions/coqui_tts/script.py +239 -0
- $extensions/coqui_tts/style.css +8 -0
- $extensions/coqui_tts/voices/arnold.wav +0 -0
- $extensions/coqui_tts/voices/female_01.wav +0 -0
- $extensions/coqui_tts/voices/female_02.wav +0 -0
- $extensions/elevenlabs_tts/outputs/outputs-will-be-saved-here.txt +0 -0
- $extensions/elevenlabs_tts/requirements.txt +1 -0
- $extensions/elevenlabs_tts/script.py +197 -0
- $extensions/example/script.py +139 -0
- $extensions/gallery/__pycache__/script.cpython-311.pyc +0 -0
- $extensions/gallery/script.js +40 -0
- $extensions/gallery/script.py +129 -0
- $extensions/google_translate/requirements.txt +1 -0
- $extensions/google_translate/script.py +59 -0
- $extensions/long_replies/script.py +143 -0
- $extensions/multimodal/DOCS.md +85 -0
- $extensions/multimodal/README.md +91 -0
- $extensions/multimodal/abstract_pipeline.py +63 -0
- $extensions/multimodal/multimodal_embedder.py +178 -0
- $extensions/multimodal/pipeline_loader.py +52 -0
- $extensions/multimodal/pipelines/llava/README.md +9 -0
- $extensions/multimodal/pipelines/llava/llava.py +262 -0
- $extensions/multimodal/pipelines/llava/pipelines.py +48 -0
- $extensions/multimodal/pipelines/place-additional-pipelines-here.txt +0 -0
- $extensions/multimodal/script.py +113 -0
- $extensions/ngrok/README.md +69 -0
- $extensions/ngrok/requirements.txt +1 -0
- $extensions/ngrok/script.py +36 -0
- $extensions/openai/cache_embedding_model.py +11 -0
- $extensions/openai/completions.py +508 -0
- $extensions/openai/embeddings.py +98 -0
- $extensions/openai/errors.py +31 -0
- $extensions/openai/images.py +70 -0
- $extensions/openai/logits.py +11 -0
- $extensions/openai/models.py +80 -0
- $extensions/openai/moderations.py +69 -0
- $extensions/openai/requirements.txt +4 -0
- $extensions/openai/script.py +377 -0
$characters/Assistant.yaml
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: AI
|
2 |
+
greeting: How can I help you today?
|
3 |
+
context: |
|
4 |
+
The following is a conversation with an AI Large Language Model. The AI has been trained to answer questions, provide recommendations, and help with decision making. The AI follows user requests. The AI thinks outside the box.
|
$characters/Example.png
ADDED
$characters/Example.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Chiharu Yamada
|
2 |
+
greeting: |-
|
3 |
+
*Chiharu strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air*
|
4 |
+
Hey! I'm so excited to finally meet you. I've heard so many great things about you and I'm eager to pick your brain about computers. I'm sure you have a wealth of knowledge that I can learn from. *She grins, eyes twinkling with excitement* Let's get started!
|
5 |
+
context: |-
|
6 |
+
Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology.
|
7 |
+
|
8 |
+
{{user}}: So how did you get into computer engineering?
|
9 |
+
{{char}}: I've always loved tinkering with technology since I was a kid.
|
10 |
+
{{user}}: That's really impressive!
|
11 |
+
{{char}}: *She chuckles bashfully* Thanks!
|
12 |
+
{{user}}: So what do you do when you're not working on computers?
|
13 |
+
{{char}}: I love exploring, going out with friends, watching movies, and playing video games.
|
14 |
+
{{user}}: What's your favorite type of computer hardware to work with?
|
15 |
+
{{char}}: Motherboards, they're like puzzles and the backbone of any system.
|
16 |
+
{{user}}: That sounds great!
|
17 |
+
{{char}}: Yeah, it's really fun. I'm lucky to be able to do this as a job.
|
$extensions/Training_PRO/README.md
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Training_PRO
|
2 |
+
|
3 |
+
This is an expanded and reworked Training tab
|
4 |
+
Maintained by FP
|
5 |
+
|
6 |
+
[![ko-fi](https://ko-fi.com/img/githubbutton_sm.svg)](https://ko-fi.com/Q5Q5MOB4M)
|
7 |
+
|
8 |
+
Repo home:
|
9 |
+
|
10 |
+
https://github.com/FartyPants/Training_PRO
|
11 |
+
|
12 |
+
In general the repo above is ahead of the extension included in text WebUi.
|
13 |
+
|
14 |
+
## News
|
15 |
+
|
16 |
+
- NEFtune: add noise to help with generalization
|
17 |
+
- Loss Graph in interface.
|
18 |
+
- Supports Mistral training
|
19 |
+
- some roundabout around pytorch and transformers version desync
|
20 |
+
|
21 |
+
![image](https://github.com/FartyPants/Training_PRO/assets/23346289/e389ec69-d7ad-4922-9ad9-865625997479)
|
22 |
+
|
23 |
+
## Features/Changes
|
24 |
+
|
25 |
+
- Chunking: precise raw text slicer (PRTS) uses sentence slicing and making sure things are clean on all ends
|
26 |
+
- overlap chunking - this special overlapping will make additional overlap block based on logical rules (aka no overlap block on hard cut)
|
27 |
+
- custom scheduler (follow the code to make your own) In LR Scheduler select FP_low_epoch_annealing - this scheduler will keep the LR constant for first epoch then use cosine for the rest - this part would be best to spawn into a new py file
|
28 |
+
- saves graph png file at the end with learning rate and loss per epoch
|
29 |
+
- adding EOS to each block or to hard cut only
|
30 |
+
- automatically lowers gradient accumulation if you go overboard and set gradient accumulation that will be higher than actual data - transformers would then throw error (or they used to, not sure if still true) but in any way, it will fix bad data
|
31 |
+
- turn BOS on and OFF
|
32 |
+
- target selector
|
33 |
+
- DEMENTOR LEARNING (experimental) Deep Memorization Enforcement Through Overlapping and Repetition. This is an experiment for long-text learning using low epochs (basically use 1 epoch with constant LR or 2 epochs with FP_low_epoch_annealing LR scheduler)
|
34 |
+
- Getting rid of micro batch size/batch size confusion. Now there is True Batch Size and Gradient accumulation slider, consisten with all the other training out there
|
35 |
+
- Ability to save Checkpoint during training with a button
|
36 |
+
- Ability to change Stop Loss during training
|
37 |
+
- different modes of checkpoint auto saving
|
38 |
+
- Function to Check Dataset and suggest parameters such as warmup and checkpoint save frequency before training
|
39 |
+
- Graph Training Loss in interface
|
40 |
+
- more custom schedulers
|
41 |
+
|
42 |
+
### Notes:
|
43 |
+
|
44 |
+
This uses it's own chunking code for raw text based on sentence splitting. This will avoid weird cuts in the chunks and each chunk should now start with sentence and end on some sentence. It works hand in hand with Hard Cut. A propper use is to structure your text into logical blocks (ideas) separated by three \n then use three \n in hard cut. This way each chunk will contain only one flow of ideas and not derail in the thoughts. And Overlapping code will create overlapped blocks on sentence basis too, but not cross hard cut, thus not cross different ideas either. Does it make any sense? No? Hmmmm...
|
45 |
+
|
46 |
+
### Custom schedulers
|
47 |
+
|
48 |
+
A bunch of custom (combination) schedulers are added to the LR schedule. These are based on my own experiments
|
49 |
+
|
50 |
+
**FP_low_epoch_annealing**
|
51 |
+
|
52 |
+
Uses constant LR (with warmup) for 1 epoch only. The rest of the epoch(s) is cosine annealing. So 10 epochs - 1 will be constant 9 will be nose dive down. However a typical usage would be 2 epochs (hence low epoch in name). 1st is constant, the second is annealing. Simple. I use it 90% of time.
|
53 |
+
|
54 |
+
**FP_half_time_annealing**
|
55 |
+
|
56 |
+
Like the low epoch, but now the total number of steps is divided by 2. First half is constant, second half is annealing. So 10 epochs - 5 will be constant, 5 will be cosine nose down.
|
57 |
+
|
58 |
+
**FP_raise_fall_creative**
|
59 |
+
|
60 |
+
This is a sine raise till half of the total steps then cosine fall the rest. (Or you may think of the curve as sine in its entirety. The most learning is done in the hump, in the middle. The warmup entry has no effect, since sine is automatically warm up.
|
61 |
+
The idea is to start very mildly as not to overfit with the first blocks of dataset. It seems to broaden the scope of the model making it less strict for tight dataset.
|
62 |
+
|
63 |
+
### Targets
|
64 |
+
|
65 |
+
Normal LORA is q, v and that's what you should use. You can use (q k v o) or (q k v) and it will give you a lot more trainable parameters. The benefit is that you can keep rank lower and still attain the same coherency as q v with high rank. Guanaco has been trained with QLORA and q k v o for example and they swear by it.
|
66 |
+
|
67 |
+
### DEMENTOR LEARNING (experimental) Deep Memorization Enforcement Through Overlapping and Repetition
|
68 |
+
|
69 |
+
This is and experimental chunking to train long-form text in low number of epochs (basically 1) with sliding repetition. The depth of learning directly depends on the cutoff_length. Increasing cutoff length will also increase number of blocks created from long-form text (which is contrary to normal training). It is based on my own wild experiments.
|
70 |
+
|
71 |
+
### Getting rid of batch size and micro batch size
|
72 |
+
|
73 |
+
Keeping consistency with everyone else.
|
74 |
+
|
75 |
+
Listen, There is only ONE batch size - the True batch size (called previously micro-batch size in WebUI) - this is how many blocks are processed at once (during a single step). It eats GPU, but it really helps with the quality training (in fact the ideal batch size would be the same as number of blocks - which is unrealistic) - so the idea is to cram as much True Batch Size before your GPU blows with OOM. On 24GB this is about 10 for 13b (loaded with 4-bit)
|
76 |
+
|
77 |
+
So no micro batch size - it is now called True Batch Size, because that's what it is.
|
78 |
+
|
79 |
+
The other thing is Gradient Accumulation - this is an emulation of the above Batch size - a virtual batch size, if you will. If your GPU can't handle real batch size then you may fake it using Gradient Accumulation. This will accumulate the gradients over so many steps defined here and then update the weights at the end without increase in GPU.
|
80 |
+
Gradient accumulation is like a virtual Batch size multiplier without the GPU penalty.
|
81 |
+
|
82 |
+
If your batch size is 4 and your gradient accumulation is 2 then it sort of behaves as if we have batch size 8. *Sort of* because Batch size of 4 and GA of 2 is NOT the same as batch size of 2 and GA of 4. (It produces different weights - hence it's not an equivalent). The idea is that if you don't have GPU - using GA to extend batch size is the next best thing (good enough) since you have no other choice.
|
83 |
+
|
84 |
+
If all you can afford is 1 batch size, then increasing GA will likely make the learning better in some range of GA (it's not always more is better).
|
85 |
+
|
86 |
+
However - GA is not some golden goose. As said, it isn't the same as batch size. In fact GA may worsen your learning as well.
|
87 |
+
|
88 |
+
I would suggest a series of experiment where you would put batch size as high as possible without OOM, set GA 1, then repeat training while increasing the GA (2, 4...), and see how the model changes. It's likely that it would follow some sort of curve where GA will seem to help before it will make it worse. Some people believe that if you can squeeze 6 BATCH Size, then you should not bother with GA at all... YMMW
|
89 |
+
|
90 |
+
High Batch Size vs High GA would also likely produce different results in terms of learning words vs style. How? Hmmmm... good question.
|
91 |
+
|
92 |
+
One optical "benefit" of GA is that the loss will fluctuate less (because of all the gradient accumulation, which works as a form of noise smoothing as well).
|
$extensions/Training_PRO/custom_scheduler.py
ADDED
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
import torch
|
3 |
+
import transformers
|
4 |
+
import math
|
5 |
+
from torch.optim.lr_scheduler import LambdaLR
|
6 |
+
|
7 |
+
from peft import (
|
8 |
+
PeftModel,
|
9 |
+
)
|
10 |
+
|
11 |
+
RED = "\033[91m"
|
12 |
+
YELLOW = "\033[93m"
|
13 |
+
GREEN = "\033[92m"
|
14 |
+
RESET = "\033[0m"
|
15 |
+
|
16 |
+
last_print_label = ''
|
17 |
+
|
18 |
+
custom_scheduler_params = {'trigger_loss': 0.0, 'ramp_down_ratio':1.0, 'current_loss': 0.0,'dynamic_scheduler_stop': False, 'calc_ramp_down_at_step': 0, 'calc_num_training_steps': 0}
|
19 |
+
|
20 |
+
|
21 |
+
def custom_scheduler_global_update(current_loss: float):
|
22 |
+
custom_scheduler_params.update({'current_loss': current_loss})
|
23 |
+
|
24 |
+
def custom_scheduler_global_setup(trigger_loss: float, ramp_down_ratio: float):
|
25 |
+
custom_scheduler_params.update({'trigger_loss': trigger_loss})
|
26 |
+
custom_scheduler_params.update({'ramp_down_ratio': ramp_down_ratio})
|
27 |
+
|
28 |
+
# calculates the total num steps after trigger
|
29 |
+
custom_scheduler_params.update({'calc_num_training_steps': 0})
|
30 |
+
#calculates steps when the ramp_down trigger occured
|
31 |
+
custom_scheduler_params.update({'calc_ramp_down_at_step': 0})
|
32 |
+
# triggers scheduler stopping after it reached calc_num_training_steps
|
33 |
+
custom_scheduler_params.update({'dynamic_scheduler_stop': False})
|
34 |
+
|
35 |
+
|
36 |
+
# hold constant to the half of epochs then cosine down to 0
|
37 |
+
def _get_fp_half_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
|
38 |
+
|
39 |
+
global last_print_label
|
40 |
+
print_label = ''
|
41 |
+
|
42 |
+
half_steps = num_training_steps//2
|
43 |
+
|
44 |
+
num_warmup_steps = min(num_warmup_steps,half_steps)
|
45 |
+
|
46 |
+
if current_step < num_warmup_steps:
|
47 |
+
print_label = 'Scheduler: Warmup'
|
48 |
+
elif current_step < half_steps:
|
49 |
+
print_label = 'Scheduler: Hold'
|
50 |
+
else:
|
51 |
+
print_label = 'Scheduler: Annealing'
|
52 |
+
|
53 |
+
if print_label != last_print_label:
|
54 |
+
print(print_label)
|
55 |
+
|
56 |
+
last_print_label = print_label
|
57 |
+
|
58 |
+
if current_step < num_warmup_steps:
|
59 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
60 |
+
|
61 |
+
if current_step < half_steps:
|
62 |
+
return 1.0
|
63 |
+
|
64 |
+
progress = float(current_step - half_steps) / float(max(1, num_training_steps - half_steps))
|
65 |
+
num_cycles = 0.5
|
66 |
+
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
67 |
+
|
68 |
+
|
69 |
+
# raise up in cosine, then fall back in cosine
|
70 |
+
def _get_fp_cosine_raise_and_fall_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
|
71 |
+
|
72 |
+
global last_print_label
|
73 |
+
print_label = ''
|
74 |
+
|
75 |
+
half_steps = num_training_steps//2
|
76 |
+
|
77 |
+
#num_warmup_steps = min(num_warmup_steps,half_steps)
|
78 |
+
|
79 |
+
if current_step < half_steps:
|
80 |
+
print_label = 'Scheduler: Raise'
|
81 |
+
else:
|
82 |
+
print_label = 'Scheduler: Fall'
|
83 |
+
|
84 |
+
if print_label != last_print_label:
|
85 |
+
print(print_label)
|
86 |
+
|
87 |
+
last_print_label = print_label
|
88 |
+
|
89 |
+
|
90 |
+
# linear
|
91 |
+
# return float(current_step) / float(max(1, num_warmup_steps))
|
92 |
+
|
93 |
+
progress = float(current_step - half_steps) / float(max(1, num_training_steps - half_steps))
|
94 |
+
num_cycles = 0.5
|
95 |
+
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
96 |
+
|
97 |
+
# constant to the first epochs then cosine down to 0 over the rest epochs
|
98 |
+
def _get_fp_cosine_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
|
99 |
+
|
100 |
+
global last_print_label
|
101 |
+
print_label = ''
|
102 |
+
|
103 |
+
num_warmup_steps = min(num_warmup_steps,num_firstepoch_steps)
|
104 |
+
|
105 |
+
if current_step < num_warmup_steps:
|
106 |
+
print_label = 'Scheduler: Warmup'
|
107 |
+
elif current_step < num_firstepoch_steps:
|
108 |
+
print_label = 'Scheduler: Hold'
|
109 |
+
else:
|
110 |
+
print_label = 'Scheduler: Annealing'
|
111 |
+
|
112 |
+
if print_label != last_print_label:
|
113 |
+
print(print_label)
|
114 |
+
|
115 |
+
last_print_label = print_label
|
116 |
+
|
117 |
+
if current_step < num_warmup_steps:
|
118 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
119 |
+
|
120 |
+
if current_step < num_firstepoch_steps:
|
121 |
+
return 1.0
|
122 |
+
|
123 |
+
progress = float(current_step - num_firstepoch_steps) / float(max(1, num_training_steps - num_firstepoch_steps))
|
124 |
+
num_cycles = 0.5
|
125 |
+
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
126 |
+
|
127 |
+
# halve lr each epoch
|
128 |
+
|
129 |
+
def _get_fp_cdrop_rate_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
|
130 |
+
|
131 |
+
global last_print_label
|
132 |
+
print_label = ''
|
133 |
+
|
134 |
+
num_warmup_steps = min(num_warmup_steps, num_firstepoch_steps)
|
135 |
+
|
136 |
+
current_epoch = (current_step // num_firstepoch_steps) + 1
|
137 |
+
|
138 |
+
|
139 |
+
if current_step < num_warmup_steps:
|
140 |
+
print_label = 'Scheduler: Warmup'
|
141 |
+
elif current_step < num_firstepoch_steps:
|
142 |
+
print_label = 'Scheduler: Hold'
|
143 |
+
else:
|
144 |
+
print_label = 'Scheduler: Drop Rate'
|
145 |
+
|
146 |
+
if print_label != last_print_label:
|
147 |
+
print(print_label)
|
148 |
+
|
149 |
+
last_print_label = print_label
|
150 |
+
|
151 |
+
if current_step < num_warmup_steps:
|
152 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
153 |
+
|
154 |
+
if current_step < num_firstepoch_steps:
|
155 |
+
return 1.0
|
156 |
+
|
157 |
+
# Compute the learning rate for the annealing phase
|
158 |
+
|
159 |
+
learning_rate = 1.0 / float(2 ** (current_epoch - 1))
|
160 |
+
|
161 |
+
return learning_rate
|
162 |
+
|
163 |
+
# epoch decay: 1/(1 + decay * epoch)
|
164 |
+
|
165 |
+
def custom_cosine_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_firstepoch_steps, last_epoch=-1):
|
166 |
+
"""
|
167 |
+
Args:
|
168 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
169 |
+
The optimizer for which to schedule the learning rate.
|
170 |
+
num_warmup_steps (`int`):
|
171 |
+
The number of steps for the warmup phase.
|
172 |
+
num_training_steps (`int`):
|
173 |
+
The total number of training steps.
|
174 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
175 |
+
The index of the last epoch when resuming training.
|
176 |
+
|
177 |
+
Return:
|
178 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
179 |
+
"""
|
180 |
+
|
181 |
+
lr_lambda = partial(
|
182 |
+
_get_fp_cosine_schedule_with_warmup_lr_lambda,
|
183 |
+
num_warmup_steps=num_warmup_steps,
|
184 |
+
num_training_steps=num_training_steps,
|
185 |
+
num_firstepoch_steps = num_firstepoch_steps,
|
186 |
+
)
|
187 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
188 |
+
|
189 |
+
def custom_half_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_firstepoch_steps, last_epoch=-1):
|
190 |
+
"""
|
191 |
+
Args:
|
192 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
193 |
+
The optimizer for which to schedule the learning rate.
|
194 |
+
num_warmup_steps (`int`):
|
195 |
+
The number of steps for the warmup phase.
|
196 |
+
num_training_steps (`int`):
|
197 |
+
The total number of training steps.
|
198 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
199 |
+
The index of the last epoch when resuming training.
|
200 |
+
|
201 |
+
Return:
|
202 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
203 |
+
"""
|
204 |
+
|
205 |
+
lr_lambda = partial(
|
206 |
+
_get_fp_half_schedule_with_warmup_lr_lambda,
|
207 |
+
num_warmup_steps=num_warmup_steps,
|
208 |
+
num_training_steps=num_training_steps,
|
209 |
+
num_firstepoch_steps = num_firstepoch_steps,
|
210 |
+
)
|
211 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
212 |
+
|
213 |
+
def custom_raise_fall_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_firstepoch_steps, last_epoch=-1):
|
214 |
+
"""
|
215 |
+
Args:
|
216 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
217 |
+
The optimizer for which to schedule the learning rate.
|
218 |
+
num_warmup_steps (`int`):
|
219 |
+
The number of steps for the warmup phase.
|
220 |
+
num_training_steps (`int`):
|
221 |
+
The total number of training steps.
|
222 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
223 |
+
The index of the last epoch when resuming training.
|
224 |
+
|
225 |
+
Return:
|
226 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
227 |
+
"""
|
228 |
+
|
229 |
+
lr_lambda = partial(
|
230 |
+
_get_fp_cosine_raise_and_fall_lr_lambda,
|
231 |
+
num_warmup_steps=num_warmup_steps,
|
232 |
+
num_training_steps=num_training_steps,
|
233 |
+
num_firstepoch_steps = num_firstepoch_steps,
|
234 |
+
)
|
235 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
236 |
+
|
237 |
+
|
238 |
+
def neftune_forward(self, input: torch.Tensor):
|
239 |
+
"""
|
240 |
+
Implements the NEFTune forward pass for the model. Note this works only for
|
241 |
+
torch.nn.Embedding layers. This method is slightly adapted from the original source code
|
242 |
+
that can be found here: https://github.com/neelsjain/NEFTune
|
243 |
+
|
244 |
+
Args:
|
245 |
+
input (`torch.Tensor`):
|
246 |
+
The input tensor to the model.
|
247 |
+
noise_alpha (`float`):
|
248 |
+
The noise alpha value to use for the NEFTune forward pass.
|
249 |
+
"""
|
250 |
+
embeddings = torch.nn.functional.embedding(
|
251 |
+
input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
|
252 |
+
)
|
253 |
+
|
254 |
+
if self.training:
|
255 |
+
# Add noise to the embeddings
|
256 |
+
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
|
257 |
+
mag_norm = self.neftune_noise_alpha / torch.sqrt(dims)
|
258 |
+
embeddings = embeddings + torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)
|
259 |
+
|
260 |
+
return embeddings
|
261 |
+
|
262 |
+
|
263 |
+
class FPNEFtuneTrainer(transformers.Trainer):
|
264 |
+
def __init__(self,neftune_noise_alpha:float = 0.0, model = None, *args, **kwargs):
|
265 |
+
self.neftune_noise_alpha = neftune_noise_alpha
|
266 |
+
if self.neftune_noise_alpha > 0.0:
|
267 |
+
model = self._activate_neftune(model)
|
268 |
+
super().__init__(model = model, *args, **kwargs)
|
269 |
+
|
270 |
+
|
271 |
+
def _activate_neftune(self, model):
|
272 |
+
r"""
|
273 |
+
Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper: https://arxiv.org/abs/2310.05914
|
274 |
+
"""
|
275 |
+
print(f"Activating {RED}NEFtune{RESET} with scale: {self.neftune_noise_alpha}")
|
276 |
+
if isinstance(model, transformers.PreTrainedModel):
|
277 |
+
embeddings = model.get_input_embeddings()
|
278 |
+
elif isinstance(model, PeftModel):
|
279 |
+
embeddings = model.base_model.get_input_embeddings()
|
280 |
+
|
281 |
+
embeddings.neftune_noise_alpha = self.neftune_noise_alpha
|
282 |
+
old_forward = embeddings.forward
|
283 |
+
|
284 |
+
# This hack seems to be needed to properly use a custom forward pass
|
285 |
+
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
|
286 |
+
bound_method = neftune_forward.__get__(embeddings, embeddings.__class__)
|
287 |
+
setattr(embeddings, "forward", bound_method)
|
288 |
+
|
289 |
+
# embeddings.forward = neftune_forward
|
290 |
+
embeddings._trl_old_forward = old_forward
|
291 |
+
|
292 |
+
return model
|
293 |
+
|
294 |
+
def train(self, *args, **kwargs):
|
295 |
+
output = super().train(*args, **kwargs)
|
296 |
+
|
297 |
+
# After training we make sure to retrieve back the original forward pass method
|
298 |
+
# for the embedding layer
|
299 |
+
if self.neftune_noise_alpha is not None:
|
300 |
+
|
301 |
+
if isinstance(self.model, transformers.PreTrainedModel):
|
302 |
+
embeddings = self.model.get_input_embeddings()
|
303 |
+
elif isinstance(self.model, PeftModel):
|
304 |
+
embeddings = self.model.base_model.get_input_embeddings()
|
305 |
+
|
306 |
+
if hasattr(embeddings, "_trl_old_forward"):
|
307 |
+
embeddings.forward = embeddings._trl_old_forward
|
308 |
+
del embeddings._trl_old_forward
|
309 |
+
del embeddings.neftune_noise_alpha
|
310 |
+
|
311 |
+
return output
|
312 |
+
|
313 |
+
|
314 |
+
class FPSchedulerTrainer(transformers.Trainer):
|
315 |
+
def __init__(self,neftune_noise_alpha:float = 0.0, model = None, *args, **kwargs):
|
316 |
+
self.neftune_noise_alpha = neftune_noise_alpha
|
317 |
+
if self.neftune_noise_alpha > 0.0:
|
318 |
+
model = self._activate_neftune(model)
|
319 |
+
super().__init__(model = model, *args, **kwargs)
|
320 |
+
|
321 |
+
|
322 |
+
def _activate_neftune(self, model):
|
323 |
+
r"""
|
324 |
+
Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper: https://arxiv.org/abs/2310.05914
|
325 |
+
"""
|
326 |
+
print(f"Activating {RED}NEFtune{RESET} with scale: {self.neftune_noise_alpha}")
|
327 |
+
if isinstance(model, transformers.PreTrainedModel):
|
328 |
+
embeddings = model.get_input_embeddings()
|
329 |
+
elif isinstance(model, PeftModel):
|
330 |
+
embeddings = model.base_model.get_input_embeddings()
|
331 |
+
|
332 |
+
embeddings.neftune_noise_alpha = self.neftune_noise_alpha
|
333 |
+
old_forward = embeddings.forward
|
334 |
+
|
335 |
+
# This hack seems to be needed to properly use a custom forward pass
|
336 |
+
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
|
337 |
+
bound_method = neftune_forward.__get__(embeddings, embeddings.__class__)
|
338 |
+
setattr(embeddings, "forward", bound_method)
|
339 |
+
|
340 |
+
# embeddings.forward = neftune_forward
|
341 |
+
embeddings._trl_old_forward = old_forward
|
342 |
+
|
343 |
+
return model
|
344 |
+
|
345 |
+
def train(self, *args, **kwargs):
|
346 |
+
output = super().train(*args, **kwargs)
|
347 |
+
|
348 |
+
# After training we make sure to retrieve back the original forward pass method
|
349 |
+
# for the embedding layer
|
350 |
+
if self.neftune_noise_alpha is not None:
|
351 |
+
|
352 |
+
if isinstance(self.model, transformers.PreTrainedModel):
|
353 |
+
embeddings = self.model.get_input_embeddings()
|
354 |
+
elif isinstance(self.model, PeftModel):
|
355 |
+
embeddings = self.model.base_model.get_input_embeddings()
|
356 |
+
|
357 |
+
if hasattr(embeddings, "_trl_old_forward"):
|
358 |
+
embeddings.forward = embeddings._trl_old_forward
|
359 |
+
del embeddings._trl_old_forward
|
360 |
+
del embeddings.neftune_noise_alpha
|
361 |
+
|
362 |
+
return output
|
363 |
+
|
364 |
+
|
365 |
+
def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
|
366 |
+
#Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or passed as an argument.
|
367 |
+
|
368 |
+
num_train_epochs = self.args.num_train_epochs
|
369 |
+
num_warmup_steps=self.args.get_warmup_steps(num_training_steps)
|
370 |
+
num_firstepoch_steps = math.ceil(num_training_steps/num_train_epochs)
|
371 |
+
num_warmup_acc = num_warmup_steps*self.args.gradient_accumulation_steps
|
372 |
+
num_firstepoch_steps_acc = num_firstepoch_steps*self.args.gradient_accumulation_steps
|
373 |
+
num_training_steps_acc = num_training_steps*self.args.gradient_accumulation_steps
|
374 |
+
|
375 |
+
custom_scheduler_params.update({'dynamic_scheduler_stop': False})
|
376 |
+
|
377 |
+
print (f"Warm-up steps aligned to Gradient accumulation ({self.args.gradient_accumulation_steps}) = {num_warmup_acc} actual warmup steps")
|
378 |
+
if self.args.lr_scheduler_type == 'cosine':
|
379 |
+
|
380 |
+
num_warmup_acc_min = min(num_warmup_acc, num_firstepoch_steps_acc)
|
381 |
+
|
382 |
+
if num_warmup_acc>num_firstepoch_steps_acc:
|
383 |
+
print(f"\033[1;31;1mWARNING: The number of warmup steps is set too high! It will be clamped to 1 epoch, essentially going from warmup to annealing.\033[0;37;0m")
|
384 |
+
print (f"FP Scheduler Warmup: 0-[{num_warmup_acc_min}], Hold [{num_warmup_acc_min}]-{num_firstepoch_steps_acc}, Annealing {num_firstepoch_steps_acc}-{num_training_steps_acc}")
|
385 |
+
else:
|
386 |
+
print (f"FP Scheduler Warmup: 0-{num_warmup_acc_min}, Hold {num_warmup_acc_min}-{num_firstepoch_steps_acc}, Annealing {num_firstepoch_steps_acc}-{num_training_steps_acc}")
|
387 |
+
|
388 |
+
self.lr_scheduler = custom_cosine_scheduler_with_warmup(
|
389 |
+
optimizer=self.optimizer if optimizer is None else optimizer,
|
390 |
+
num_warmup_steps=num_warmup_steps,
|
391 |
+
num_training_steps=num_training_steps,
|
392 |
+
num_firstepoch_steps = num_firstepoch_steps,
|
393 |
+
)
|
394 |
+
self._created_lr_scheduler = True
|
395 |
+
return self.lr_scheduler
|
396 |
+
elif self.args.lr_scheduler_type == 'constant':
|
397 |
+
|
398 |
+
half_step_acc = num_training_steps_acc//2
|
399 |
+
num_warmup_acc_min = min(num_warmup_acc, half_step_acc)
|
400 |
+
|
401 |
+
if num_warmup_acc>half_step_acc:
|
402 |
+
print(f"\033[1;31;1mWARNING: The number of warmup steps is set too high! It will be clamped to half of all epochs, essentially going from warmup to annealing in the middle.\033[0;37;0m")
|
403 |
+
print (f"FP Scheduler Warmup: 0-[{num_warmup_acc_min}], Hold [{num_warmup_acc_min}]-{half_step_acc}, Annealing {half_step_acc}-{num_training_steps_acc}")
|
404 |
+
else:
|
405 |
+
print (f"FP Scheduler Warmup: 0-{num_warmup_acc_min}, Hold {num_warmup_acc_min}-{half_step_acc}, Annealing {half_step_acc}-{num_training_steps_acc}")
|
406 |
+
|
407 |
+
self.lr_scheduler = custom_half_scheduler_with_warmup(
|
408 |
+
optimizer=self.optimizer if optimizer is None else optimizer,
|
409 |
+
num_warmup_steps=num_warmup_steps,
|
410 |
+
num_training_steps=num_training_steps,
|
411 |
+
num_firstepoch_steps = num_firstepoch_steps,
|
412 |
+
)
|
413 |
+
self._created_lr_scheduler = True
|
414 |
+
return self.lr_scheduler
|
415 |
+
elif self.args.lr_scheduler_type == 'constant_with_warmup':
|
416 |
+
|
417 |
+
half_step_acc = num_training_steps_acc//2
|
418 |
+
|
419 |
+
if num_warmup_steps>0:
|
420 |
+
print(f"Warmup doesn't apply to this scheduler [Raise-Fall]")
|
421 |
+
|
422 |
+
print (f"Scheduler Raise: 0-{half_step_acc}, Fall {half_step_acc}-{num_training_steps_acc}")
|
423 |
+
|
424 |
+
self.lr_scheduler = custom_raise_fall_scheduler_with_warmup(
|
425 |
+
optimizer=self.optimizer if optimizer is None else optimizer,
|
426 |
+
num_warmup_steps=num_warmup_steps,
|
427 |
+
num_training_steps=num_training_steps,
|
428 |
+
num_firstepoch_steps = num_firstepoch_steps,
|
429 |
+
)
|
430 |
+
self._created_lr_scheduler = True
|
431 |
+
return self.lr_scheduler
|
432 |
+
else:
|
433 |
+
return super().create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
|
$extensions/Training_PRO/matplotgraph.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
|
4 |
+
def create_graph(lora_path, lora_name):
|
5 |
+
try:
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
from matplotlib.ticker import ScalarFormatter
|
8 |
+
|
9 |
+
peft_model_path = f'{lora_path}/training_graph.json'
|
10 |
+
image_model_path = f'{lora_path}/training_graph.png'
|
11 |
+
# Check if the JSON file exists
|
12 |
+
if os.path.exists(peft_model_path):
|
13 |
+
# Load data from JSON file
|
14 |
+
with open(peft_model_path, 'r') as file:
|
15 |
+
data = json.load(file)
|
16 |
+
# Extract x, y1, and y2 values
|
17 |
+
x = [item['epoch'] for item in data]
|
18 |
+
y1 = [item['learning_rate'] for item in data]
|
19 |
+
y2 = [item['loss'] for item in data]
|
20 |
+
|
21 |
+
# Create the line chart
|
22 |
+
fig, ax1 = plt.subplots(figsize=(10, 6))
|
23 |
+
|
24 |
+
|
25 |
+
# Plot y1 (learning rate) on the first y-axis
|
26 |
+
ax1.plot(x, y1, 'b-', label='Learning Rate')
|
27 |
+
ax1.set_xlabel('Epoch')
|
28 |
+
ax1.set_ylabel('Learning Rate', color='b')
|
29 |
+
ax1.tick_params('y', colors='b')
|
30 |
+
|
31 |
+
# Create a second y-axis
|
32 |
+
ax2 = ax1.twinx()
|
33 |
+
|
34 |
+
# Plot y2 (loss) on the second y-axis
|
35 |
+
ax2.plot(x, y2, 'r-', label='Loss')
|
36 |
+
ax2.set_ylabel('Loss', color='r')
|
37 |
+
ax2.tick_params('y', colors='r')
|
38 |
+
|
39 |
+
# Set the y-axis formatter to display numbers in scientific notation
|
40 |
+
ax1.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
|
41 |
+
ax1.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
|
42 |
+
|
43 |
+
# Add grid
|
44 |
+
ax1.grid(True)
|
45 |
+
|
46 |
+
# Combine the legends for both plots
|
47 |
+
lines, labels = ax1.get_legend_handles_labels()
|
48 |
+
lines2, labels2 = ax2.get_legend_handles_labels()
|
49 |
+
ax2.legend(lines + lines2, labels + labels2, loc='best')
|
50 |
+
|
51 |
+
# Set the title
|
52 |
+
plt.title(f'{lora_name} LR and Loss vs Epoch')
|
53 |
+
|
54 |
+
# Save the chart as an image
|
55 |
+
plt.savefig(image_model_path)
|
56 |
+
|
57 |
+
print(f"Graph saved in {image_model_path}")
|
58 |
+
else:
|
59 |
+
print(f"File 'training_graph.json' does not exist in the {lora_path}")
|
60 |
+
|
61 |
+
except ImportError:
|
62 |
+
print("matplotlib is not installed. Please install matplotlib to create PNG graphs")
|
$extensions/Training_PRO/script.py
ADDED
@@ -0,0 +1,1376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ["WANDB_MODE"] = "offline"
|
4 |
+
# os.environ["WANDB_DISABLED"] = "true"
|
5 |
+
|
6 |
+
import json
|
7 |
+
import math
|
8 |
+
import random
|
9 |
+
import shutil
|
10 |
+
import sys
|
11 |
+
import threading
|
12 |
+
import time
|
13 |
+
import traceback
|
14 |
+
from datetime import datetime
|
15 |
+
from pathlib import Path
|
16 |
+
|
17 |
+
import gradio as gr
|
18 |
+
import pandas as pd
|
19 |
+
import torch
|
20 |
+
import transformers
|
21 |
+
|
22 |
+
from functools import partial
|
23 |
+
|
24 |
+
from .custom_scheduler import FPSchedulerTrainer, FPNEFtuneTrainer
|
25 |
+
|
26 |
+
from .matplotgraph import create_graph
|
27 |
+
from .train_utils import get_available_loras_local, precise_cut, sliding_block_cut, download_file_from_url
|
28 |
+
|
29 |
+
from datasets import Dataset, load_dataset
|
30 |
+
from peft import (
|
31 |
+
LoraConfig,
|
32 |
+
get_peft_model,
|
33 |
+
prepare_model_for_kbit_training,
|
34 |
+
set_peft_model_state_dict
|
35 |
+
)
|
36 |
+
from peft.utils.other import \
|
37 |
+
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as model_to_lora_modules
|
38 |
+
from transformers.models.auto.modeling_auto import (
|
39 |
+
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
40 |
+
)
|
41 |
+
|
42 |
+
from modules import shared, utils
|
43 |
+
from modules.ui import create_refresh_button
|
44 |
+
|
45 |
+
from modules.evaluate import (
|
46 |
+
calculate_perplexity,
|
47 |
+
generate_markdown_table,
|
48 |
+
save_past_evaluations
|
49 |
+
)
|
50 |
+
from modules.logging_colors import logger
|
51 |
+
from modules.models import reload_model
|
52 |
+
from modules.utils import natural_keys
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
## just temporary to avoid warning
|
57 |
+
|
58 |
+
import inspect
|
59 |
+
|
60 |
+
from typing import Callable, Optional, Tuple, ContextManager
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
if hasattr(torch.utils.checkpoint, 'noop_context_fn'):
|
65 |
+
def my_checkpoint(
|
66 |
+
function,
|
67 |
+
*args,
|
68 |
+
use_reentrant: Optional[bool] = None,
|
69 |
+
context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = torch.utils.checkpoint.noop_context_fn,
|
70 |
+
determinism_check: str = torch.utils.checkpoint._DEFAULT_DETERMINISM_MODE,
|
71 |
+
debug: bool = False,
|
72 |
+
**kwargs
|
73 |
+
):
|
74 |
+
|
75 |
+
if use_reentrant is None:
|
76 |
+
#print ("reentran = NONE")
|
77 |
+
use_reentrant = True
|
78 |
+
# Hack to mix *args with **kwargs in a python 2.7-compliant way
|
79 |
+
preserve = kwargs.pop("preserve_rng_state", True)
|
80 |
+
if kwargs and use_reentrant:
|
81 |
+
raise ValueError(
|
82 |
+
"Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
|
83 |
+
)
|
84 |
+
|
85 |
+
if use_reentrant:
|
86 |
+
if context_fn is not torch.utils.checkpoint.noop_context_fn or debug is not False:
|
87 |
+
raise ValueError(
|
88 |
+
"Passing `context_fn` or `debug` is only supported when "
|
89 |
+
"use_reentrant=False."
|
90 |
+
)
|
91 |
+
return torch.utils.checkpoint.CheckpointFunction.apply(function, preserve, *args)
|
92 |
+
else:
|
93 |
+
|
94 |
+
print ("reentran = FALSE")
|
95 |
+
gen = torch.utils.checkpoint._checkpoint_without_reentrant_generator(
|
96 |
+
function, preserve, context_fn, determinism_check, debug, *args, **kwargs
|
97 |
+
)
|
98 |
+
# Runs pre-forward logic
|
99 |
+
next(gen)
|
100 |
+
ret = function(*args, **kwargs)
|
101 |
+
# Runs post-forward logic
|
102 |
+
try:
|
103 |
+
next(gen)
|
104 |
+
except StopIteration:
|
105 |
+
return ret
|
106 |
+
|
107 |
+
|
108 |
+
params = {
|
109 |
+
"display_name": "Training PRO",
|
110 |
+
"is_tab": True
|
111 |
+
}
|
112 |
+
|
113 |
+
non_serialized_params = {
|
114 |
+
"debug_slicer": False,
|
115 |
+
"Lora_sortedByTime": False,
|
116 |
+
"stop_at_loss": 0,
|
117 |
+
"save_steps_under_loss": 0.0,
|
118 |
+
"save_checkpoint_now": False,
|
119 |
+
"training_loop": False,
|
120 |
+
"current_stability": 0,
|
121 |
+
"save_epochs": 0,
|
122 |
+
"checkpoint_offset": 0,
|
123 |
+
"epoch_offset":0,
|
124 |
+
}
|
125 |
+
|
126 |
+
MODEL_CLASSES = {v[1]: v[0] for v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()}
|
127 |
+
|
128 |
+
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to", "precize_slicing_overlap", "add_eos_token_type", "save_steps_under_loss", "add_bos_token", "training_projection","sliding_window","warmup_ratio","grad_accumulation","neft_noise_alpha"]
|
129 |
+
WANT_INTERRUPT = False
|
130 |
+
|
131 |
+
train_log = {}
|
132 |
+
train_template = {}
|
133 |
+
train_log_graph = []
|
134 |
+
train_choices = ["all","q-k-v-o","q-k-v","k-v-down","q-v"]
|
135 |
+
|
136 |
+
statistics = {
|
137 |
+
'loss': [],
|
138 |
+
'lr': [],
|
139 |
+
}
|
140 |
+
|
141 |
+
RED = "\033[91m"
|
142 |
+
YELLOW = "\033[93m"
|
143 |
+
GREEN = "\033[92m"
|
144 |
+
RESET = "\033[0m"
|
145 |
+
|
146 |
+
def ui():
|
147 |
+
|
148 |
+
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
149 |
+
tmp = gr.State('')
|
150 |
+
with gr.Row():
|
151 |
+
with gr.Column():
|
152 |
+
# YY.MM.DD
|
153 |
+
gr.Markdown("`Ver: 23.10.20` This is enhanced version of QLora Training. [Maintained by FP](https://github.com/FartyPants/Training_PRO/tree/main)")
|
154 |
+
|
155 |
+
with gr.Row():
|
156 |
+
with gr.Column(scale=5):
|
157 |
+
with gr.Row():
|
158 |
+
copy_from = gr.Dropdown(label='Copy parameters from', value='None', choices=get_available_loras_local(non_serialized_params['Lora_sortedByTime']), elem_classes=['slim-dropdown'])
|
159 |
+
create_refresh_button(copy_from, lambda: None, lambda: {'choices': get_available_loras_local(non_serialized_params['Lora_sortedByTime'])}, 'refresh-button')
|
160 |
+
with gr.Column():
|
161 |
+
sort_byTime = gr.Checkbox(label='Sort list by Date', value=False, info='Sorts Loras by date created.', elem_classes=['no-background'])
|
162 |
+
|
163 |
+
with gr.Row():
|
164 |
+
with gr.Column(scale=5):
|
165 |
+
lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
|
166 |
+
|
167 |
+
with gr.Column():
|
168 |
+
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background'])
|
169 |
+
|
170 |
+
with gr.Row():
|
171 |
+
with gr.Column():
|
172 |
+
lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='Also called dimension count. Higher values = larger file, more content control. Smaller values = smaller file, less control. Use 4 or 8 for style, 128 or 256 to teach, 1024+ for fine-detail on big data. More VRAM is needed for higher ranks.')
|
173 |
+
lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
|
174 |
+
batch_size = gr.Slider(visible= False, label='Batch Size', value=0, minimum=0, maximum=1024, step=4, info='Now Replaced with Gradient accumulation. Keeping it for sake of old saved data')
|
175 |
+
micro_batch_size = gr.Slider(label='True Batch Size', value=4, minimum=1, maximum=128, step=1, info='Specifies how many text blocks per step will be trained. The higher value, the better the concept of training will be, but it requires more GPU memory and it reduces speed.')
|
176 |
+
grad_accumulation = gr.Slider(label='Gradient Accumulation Steps', value=1, minimum=1, maximum=256, step=1, info="Virtually multiplies the Batch Size by averaging the learning over more than one step. VRAM friendly. Evens out loss fluctuations but can also degrade training fidelity.")
|
177 |
+
|
178 |
+
with gr.Column():
|
179 |
+
stop_at_loss = gr.Slider(label='Stop at loss (Can be changed during training)', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached.')
|
180 |
+
gr.Markdown(" ")
|
181 |
+
epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
|
182 |
+
learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='In scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
|
183 |
+
lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt', 'FP_low_epoch_annealing', 'FP_half_time_annealing','FP_raise_fall_creative'], info='Learning rate scheduler - defines how the learning rate changes over time. Custom schedulers: FP_low_epoch_annealing, FP_half_time_annealing, FP_raise_fall_creative (see README)', elem_classes=['slim-dropdown'])
|
184 |
+
|
185 |
+
with gr.Accordion(label='Checkpoints', open=True):
|
186 |
+
with gr.Row():
|
187 |
+
with gr.Column():
|
188 |
+
save_steps = gr.Number(label='Save every n steps', value=0, info='A checkpoint will be saved every n steps and at each Epoch boundary. (0 = OFF)')
|
189 |
+
with gr.Column():
|
190 |
+
save_steps_under_loss = gr.Slider(label='Save at 10% Loss change', value=1.8, minimum=0.0, maximum=3.0, step=0.1, info="Saves checkpoints at (or bellow) this loss and then each time loss falls by at least 10% This works independently from 'Save every n steps'")
|
191 |
+
with gr.Row():
|
192 |
+
save_chackpoint_now = gr.Button('Queue Checkpoint Now')
|
193 |
+
|
194 |
+
with gr.Accordion(label='Advanced Options', open=True):
|
195 |
+
with gr.Row():
|
196 |
+
with gr.Column():
|
197 |
+
warmup_steps = gr.Number(label='Warmup Steps', value=100, info='Number of max steps used for a linear warmup. Reduces early over-fitting by the first training blocks. Value has precedent over Warmup Ratio. Aligns to the closest multiple of graddient accumulation')
|
198 |
+
warmup_ratio = gr.Slider(label='Warmup Ratio', minimum=0.0, maximum=0.2, step=0.025, value=0.0, info='Ratio of total training steps that will be used for a linear warmup. It applies only if Warmup Step is 0.')
|
199 |
+
neft_noise_alpha = gr.Slider(label='NEFtune noise scale', minimum=0.0, maximum=15, step=1, value=0.0, info='Add noise to the training to improve generalization. [0 - OFF, Starting value to experiment: 5]')
|
200 |
+
training_projection = gr.Radio(value = train_choices[4], label='LLaMA Target Projections', info='Change the targets (LORA is typically q-v)', choices=train_choices)
|
201 |
+
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
|
202 |
+
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.', elem_classes=['slim-dropdown'])
|
203 |
+
|
204 |
+
with gr.Column():
|
205 |
+
train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.')
|
206 |
+
add_bos_token = gr.Checkbox(label='Add BOS token', value=True, info="Adds BOS token for each dataset item")
|
207 |
+
add_eos_token = gr.Checkbox(label='Add EOS token', value=False, info="Adds EOS token for each dataset item")
|
208 |
+
add_eos_token_type = gr.Dropdown(label='EOS placement (Text file)', choices=['Every Block', 'Hard Cut Blocks Only'], value='Every Block', info='', allow_custom_value = False)
|
209 |
+
|
210 |
+
higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
|
211 |
+
report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True)
|
212 |
+
# for future
|
213 |
+
#with gr.Accordion(label='Dynamic Scheduler', open = False):
|
214 |
+
# ds_min_epochs = gr.Number(label='Minimum Epochs', value='1', info='Minimum epochs that will be always performed before ramp down can be triggered')
|
215 |
+
# ds_max_epochs = gr.Number(label='Maximum Epochs (fallback)', value='50', info='Maximum Epochs before the training will bail out completely (should be a large number)')
|
216 |
+
# ds_loss_trigger = gr.Slider(label='Trigger Loss', minimum=0.0, maximum=2.8, step=0.1, value=1.6, info='Loss at which the ramp down schedule will be triggered')
|
217 |
+
# ds_loss_rolling_window = gr.Number(label='Loss rolling average', value='4', info='Calculate loss by averaging last x numbers to avoid jumps and noise')
|
218 |
+
# ds_epochs_to_ramp = gr.Slider(label='Ramp down ratio', minimum=0.0, maximum=2.0, step=0.1, value=1.00, info='How long the ramp down will last relative to ellapsed steps (before trigger)')
|
219 |
+
# gr.Markdown('These are settings for FP_dynamic_loss_trigger scheduler. The scheduler will do warm up, then hold constant untill a loss falls under Trigger Loss, then it will commence linear ramp down schedule and stop. The length of ramp down is set by Ramp down ratio where (ramp down steps) = ratio * (elapsed steps). (The time to completition shown will be very high untill ramp down is triggered.)')
|
220 |
+
|
221 |
+
|
222 |
+
with gr.Column():
|
223 |
+
with gr.Tab(label='Formatted Dataset'):
|
224 |
+
with gr.Row():
|
225 |
+
with gr.Column():
|
226 |
+
with gr.Row():
|
227 |
+
dataset = gr.Dropdown(choices=get_datasets('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.', elem_classes=['slim-dropdown'])
|
228 |
+
create_refresh_button(dataset, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'json')}, 'refresh-button')
|
229 |
+
with gr.Row():
|
230 |
+
eval_dataset = gr.Dropdown(choices=get_datasets('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown'])
|
231 |
+
create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'json')}, 'refresh-button')
|
232 |
+
|
233 |
+
with gr.Column():
|
234 |
+
with gr.Row():
|
235 |
+
format = gr.Dropdown(choices=get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.', elem_classes=['slim-dropdown'])
|
236 |
+
create_refresh_button(format, lambda: None, lambda: {'choices': get_datasets('training/formats', 'json')}, 'refresh-button')
|
237 |
+
with gr.Row():
|
238 |
+
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
|
239 |
+
|
240 |
+
with gr.Tab(label="Text file"):
|
241 |
+
with gr.Row():
|
242 |
+
raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The text file to use for training.', elem_classes=['slim-dropdown'])
|
243 |
+
create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'txt')}, 'refresh-button')
|
244 |
+
|
245 |
+
with gr.Row():
|
246 |
+
with gr.Column():
|
247 |
+
precize_slicing_overlap = gr.Checkbox(label='Add Overlapping blocks', value = True)
|
248 |
+
sliding_window = gr.Checkbox(label='DEMENTOR Long-form Learning by FP (Highly Experimental, use low epochs)', value = False, info='Deep Memorization Enforcement Through Overlapping and Repetition. (I named it, so shush). Special process for learning long-form text using low amount of epochs.')
|
249 |
+
#debug_slicer = gr.Checkbox(label='Dump sentencelist.json to logs', value = non_serialized_params['debug_slicer'], info='Debug Slicer')
|
250 |
+
|
251 |
+
with gr.Column():
|
252 |
+
hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a cut between logical blocks of text (ex. Ideas or Chapters). Helps prevent unwanted overlap between unrelated ideas.')
|
253 |
+
min_chars = gr.Number(label='Ignore small blocks', value=0, info='Ignore Text blocks that have less or equal characters than this number.')
|
254 |
+
with gr.Tab(label="URL"):
|
255 |
+
with gr.Row():
|
256 |
+
with gr.Column():
|
257 |
+
download_file_url = gr.Textbox(label='Download JSON or txt file to datasets (or formats) folder', value='',info='The URL of a file to download. If on github, make sure you get url of the raw file (https://raw.githubusercontent.com/...). If huggin face, make sure the url has /resolve/ in it not /blob/')
|
258 |
+
with gr.Row():
|
259 |
+
download_check_overwrite = gr.Checkbox(label='Overwrite', value=False, info='Overwrite if file exist')
|
260 |
+
download_folder = gr.Radio(label="Destination", value='training/datasets', choices=['training/datasets', 'training/formats'], interactive=True)
|
261 |
+
download_button = gr.Button('Download')
|
262 |
+
download_status = gr.Textbox(label='Download Status', value='', interactive=False)
|
263 |
+
with gr.Row():
|
264 |
+
with gr.Column():
|
265 |
+
with gr.Row():
|
266 |
+
cutoff_len = gr.Slider(label='Chunk Length (Cutoff Length)', minimum=32, maximum=2048, value=256, step=32, info='The maximum length of a chunk (in tokens). Applies to both JSON dataset and text files. Higher values require much more VRAM.')
|
267 |
+
with gr.Row():
|
268 |
+
with gr.Column():
|
269 |
+
check_dataset_btn = gr.Button('Verify Dataset/Text File and suggest data entries')
|
270 |
+
check_dataset_txt = gr.Textbox(label='Dataset info', value='')
|
271 |
+
|
272 |
+
with gr.Row():
|
273 |
+
start_button = gr.Button("Start LoRA Training", variant='primary')
|
274 |
+
stop_button = gr.Button("Interrupt")
|
275 |
+
|
276 |
+
with gr.Accordion(label="Graph", open=True):
|
277 |
+
with gr.Row():
|
278 |
+
# show_actions_button = False - we use old gradio
|
279 |
+
plot_graph = gr.LinePlot(x="epoch", y="value", title="Loss Metrics", overlay_point=True, tooltip=["epoch", "value"], x_lim=[0, 1], y_lim=[0, 3.5], width=500, height=250)
|
280 |
+
|
281 |
+
output = gr.Markdown(value="Ready")
|
282 |
+
|
283 |
+
with gr.Tab('Perplexity evaluation', elem_id='evaluate-tab'):
|
284 |
+
with gr.Row():
|
285 |
+
with gr.Column():
|
286 |
+
models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True)
|
287 |
+
evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + get_datasets('training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under training/datasets.')
|
288 |
+
with gr.Row():
|
289 |
+
with gr.Column():
|
290 |
+
stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
|
291 |
+
|
292 |
+
with gr.Column():
|
293 |
+
max_length = gr.Slider(label='max_length', minimum=0, maximum=8096, value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
|
294 |
+
|
295 |
+
with gr.Row():
|
296 |
+
start_current_evaluation = gr.Button("Evaluate loaded model")
|
297 |
+
start_evaluation = gr.Button("Evaluate selected models")
|
298 |
+
stop_evaluation = gr.Button("Interrupt")
|
299 |
+
|
300 |
+
with gr.Column():
|
301 |
+
evaluation_log = gr.Markdown(value='')
|
302 |
+
|
303 |
+
evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True)
|
304 |
+
with gr.Row():
|
305 |
+
save_comments = gr.Button('Save comments', elem_classes="small-button")
|
306 |
+
refresh_table = gr.Button('Refresh the table', elem_classes="small-button")
|
307 |
+
|
308 |
+
# Training events
|
309 |
+
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to, precize_slicing_overlap, add_eos_token_type, save_steps_under_loss, add_bos_token, training_projection,sliding_window,warmup_ratio,grad_accumulation, neft_noise_alpha]
|
310 |
+
|
311 |
+
def fix_old_version(batch_size_val,micro_batch_size_val, grad_accumulation_val):
|
312 |
+
if batch_size_val>0:
|
313 |
+
gradient_acc = batch_size_val // micro_batch_size_val
|
314 |
+
print(f"Using Old version of Batch Size ({batch_size_val}) to set Gradient Accumulation: {gradient_acc}")
|
315 |
+
return gradient_acc
|
316 |
+
|
317 |
+
return grad_accumulation_val
|
318 |
+
|
319 |
+
|
320 |
+
copy_from.change(partial(do_copy_params, all_params= all_params), copy_from, all_params).then(fix_old_version,[batch_size,micro_batch_size, grad_accumulation],grad_accumulation)
|
321 |
+
start_button.click(do_train, all_params, [output,plot_graph])
|
322 |
+
stop_button.click(do_interrupt, None, None, queue=False)
|
323 |
+
higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
|
324 |
+
|
325 |
+
def trigger_stop_at_loss(stop_at_loss_value):
|
326 |
+
non_serialized_params.update({"stop_at_loss": stop_at_loss_value})
|
327 |
+
if non_serialized_params['training_loop']:
|
328 |
+
print(f"Queue: [Stop at loss Change] to {stop_at_loss_value}")
|
329 |
+
|
330 |
+
|
331 |
+
stop_at_loss.change(trigger_stop_at_loss, stop_at_loss, None)
|
332 |
+
|
333 |
+
def trigger_save_checkpoint():
|
334 |
+
non_serialized_params.update({"save_checkpoint_now": True})
|
335 |
+
if non_serialized_params['training_loop']:
|
336 |
+
print("Queue: [Save checkpoint] Checkpoint will be saved after the current step is finished.")
|
337 |
+
else:
|
338 |
+
print("Use during the training to save the checkpoint at any time.")
|
339 |
+
|
340 |
+
|
341 |
+
def update_button():
|
342 |
+
return gr.Button.update('[Checkpoint in Queue]', variant='stop', interactive=True)
|
343 |
+
|
344 |
+
def update_button2():
|
345 |
+
time.sleep(1.0)
|
346 |
+
return gr.Button.update('Queue Checkpoint Now', variant='secondary',interactive = True)
|
347 |
+
|
348 |
+
save_chackpoint_now.click(trigger_save_checkpoint, None, None).then(update_button, None,save_chackpoint_now).then(update_button2, None,save_chackpoint_now)
|
349 |
+
|
350 |
+
dataset_calc_params = [save_steps,micro_batch_size, epochs, cutoff_len, dataset, format, raw_text_file, warmup_steps, hard_cut_string, min_chars, precize_slicing_overlap,sliding_window,warmup_ratio,grad_accumulation]
|
351 |
+
|
352 |
+
def check_dataset(save_steps:int, micro_batch_size: int, epochs: int, cutoff_len: int, dataset:str, format:str, raw_text_file:str, warmup_steps:int, hard_cut_string:str, min_chars:int, precize_slicing_overlap:bool,sliding_window:bool,warmup_ratio:float,grad_accumulation:int):
|
353 |
+
result = "Specify JSON dastaset or Text file"
|
354 |
+
total_blocks = 0
|
355 |
+
if shared.tokenizer is None:
|
356 |
+
yield "Tokenizer is not available. Please Load some Model first."
|
357 |
+
return
|
358 |
+
|
359 |
+
|
360 |
+
if raw_text_file not in ['None', '']:
|
361 |
+
logger.info("Loading Text file...")
|
362 |
+
fullpath = clean_path('training/datasets', f'{raw_text_file}')
|
363 |
+
fullpath = Path(fullpath)
|
364 |
+
if fullpath.is_dir():
|
365 |
+
logger.info('Training path directory {}'.format(raw_text_file))
|
366 |
+
raw_text = ""
|
367 |
+
file_paths = sorted(fullpath.glob('*.txt'), key=lambda path: natural_keys(path.name))
|
368 |
+
for file_path in file_paths:
|
369 |
+
if file_path.is_file():
|
370 |
+
with file_path.open('r', encoding='utf-8') as file:
|
371 |
+
raw_text += file.read().replace('\r', '')
|
372 |
+
|
373 |
+
logger.info(f"Loaded training file: {file_path.name}")
|
374 |
+
else:
|
375 |
+
try:
|
376 |
+
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
|
377 |
+
raw_text = file.read().replace('\r', '')
|
378 |
+
except:
|
379 |
+
yield f"{raw_text_file}.txt doesn't seem to exsist anymore... check your training/datasets folder"
|
380 |
+
return
|
381 |
+
|
382 |
+
|
383 |
+
if min_chars<0:
|
384 |
+
min_chars = 0
|
385 |
+
|
386 |
+
# == New more precise slicing on sentence boundary ==
|
387 |
+
if sliding_window:
|
388 |
+
text_chunks = sliding_block_cut(raw_text, min_chars, False, cutoff_len, hard_cut_string,non_serialized_params['debug_slicer'])
|
389 |
+
else:
|
390 |
+
text_chunks = precise_cut(raw_text, precize_slicing_overlap, min_chars, False, cutoff_len, hard_cut_string,non_serialized_params['debug_slicer'])
|
391 |
+
|
392 |
+
total_blocks = len(text_chunks)
|
393 |
+
result = f"Text: ({raw_text_file}.txt) has {total_blocks} blocks (Block Size {cutoff_len} tokens)"
|
394 |
+
del text_chunks
|
395 |
+
|
396 |
+
else:
|
397 |
+
if dataset in ['None', '']:
|
398 |
+
yield "Select dataset or text file."
|
399 |
+
return
|
400 |
+
|
401 |
+
if format in ['None', '']:
|
402 |
+
yield "Select format choice for dataset."
|
403 |
+
return
|
404 |
+
|
405 |
+
with open(clean_path('training/formats', f'{format}.json'), 'r', encoding='utf-8-sig') as formatFile:
|
406 |
+
format_data: dict[str, str] = json.load(formatFile)
|
407 |
+
|
408 |
+
def generate_prompt(data_point: dict[str, str]):
|
409 |
+
for options, data in format_data.items():
|
410 |
+
if set(options.split(',')) == set(x[0] for x in data_point.items() if (type(x[1]) is str and len(x[1].strip()) > 0)):
|
411 |
+
for key, val in data_point.items():
|
412 |
+
if type(val) is str:
|
413 |
+
data = data.replace(f'%{key}%', val)
|
414 |
+
return data
|
415 |
+
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
|
416 |
+
|
417 |
+
def tokenize_dummy(prompt):
|
418 |
+
|
419 |
+
input_ids = shared.tokenizer.encode(prompt, truncation=True, max_length=cutoff_len)
|
420 |
+
labels = [1] * len(input_ids)
|
421 |
+
input_ids = torch.tensor(input_ids)
|
422 |
+
return {
|
423 |
+
"input_ids": input_ids,
|
424 |
+
"labels": labels,
|
425 |
+
"attention_mask": input_ids.ne(shared.tokenizer.pad_token_id),
|
426 |
+
}
|
427 |
+
|
428 |
+
def generate_and_tokenize_prompt(data_point):
|
429 |
+
prompt = generate_prompt(data_point)
|
430 |
+
return tokenize_dummy(prompt)
|
431 |
+
|
432 |
+
logger.info("Loading JSON datasets...")
|
433 |
+
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
|
434 |
+
|
435 |
+
data_keys = []
|
436 |
+
|
437 |
+
if data:
|
438 |
+
if 'train' in data: # Check if the 'train' split exists in the dataset
|
439 |
+
data_keys = list(data['train'][0].keys())
|
440 |
+
print("Data Keys:", data_keys)
|
441 |
+
else:
|
442 |
+
print("The dataset is empty.")
|
443 |
+
|
444 |
+
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
445 |
+
total_blocks = train_data.num_rows
|
446 |
+
|
447 |
+
result = f"Dataset: ({dataset}.json) has {total_blocks} blocks @ length = {cutoff_len} tokens\n(Keys: {data_keys} - Format: {format}.json): "
|
448 |
+
|
449 |
+
#for options, data in format_data.items():
|
450 |
+
# format_keys = options.split(',')
|
451 |
+
# result += f"{format_keys}, "
|
452 |
+
#result = result.rstrip()
|
453 |
+
#result = result.rstrip(',')
|
454 |
+
|
455 |
+
if total_blocks>0:
|
456 |
+
number_ofSteps = int(math.ceil(total_blocks / micro_batch_size) * epochs)
|
457 |
+
num_stepsPer_epoch = int(math.ceil(number_ofSteps/epochs))
|
458 |
+
min_warm = math.ceil(100 / grad_accumulation)
|
459 |
+
|
460 |
+
warmup_steps_suggest = min(int(min_warm*grad_accumulation), int(math.ceil(number_ofSteps * 0.1)))
|
461 |
+
warmup_steps_suggest = min(warmup_steps_suggest,num_stepsPer_epoch)
|
462 |
+
|
463 |
+
save_each_n_min = int(math.ceil(number_ofSteps/10))
|
464 |
+
save_each_n_max = int(math.ceil(number_ofSteps/5))
|
465 |
+
gradient_accumulation_max = int(total_blocks)//micro_batch_size
|
466 |
+
|
467 |
+
|
468 |
+
result += f"\n[Batch Size: {micro_batch_size}, Epochs: {epochs}, Gradient Accumulation: {grad_accumulation}]\n"
|
469 |
+
result += f"Total number of steps: {number_ofSteps}\n"
|
470 |
+
result += f"Steps per each Epoch: {num_stepsPer_epoch}\n"
|
471 |
+
result += f"Suggestions:\n"
|
472 |
+
result += f"Checkpoints: Save every {save_each_n_min} - {save_each_n_max} steps (Current: {int(save_steps)})\n"
|
473 |
+
result += f"Warmup steps: {warmup_steps_suggest} (Current: {int(warmup_steps)})"
|
474 |
+
if gradient_accumulation_max < grad_accumulation:
|
475 |
+
result += f"\n\nWARNING: Gradient Accumulation {grad_accumulation} is too high: It should be below {gradient_accumulation_max}"
|
476 |
+
|
477 |
+
|
478 |
+
yield result
|
479 |
+
return
|
480 |
+
|
481 |
+
check_dataset_btn.click(check_dataset, dataset_calc_params ,check_dataset_txt)
|
482 |
+
|
483 |
+
# Evaluation events. For some reason, the interrupt event
|
484 |
+
# doesn't work with the .then() syntax, so I write them one
|
485 |
+
# by one in this ugly but functional way.
|
486 |
+
ev = start_evaluation.click(calculate_perplexity, [models, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
|
487 |
+
start_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
|
488 |
+
|
489 |
+
start_current_evaluation.click(lambda: ['current model'], None, tmp)
|
490 |
+
ev_cur = start_current_evaluation.click(calculate_perplexity, [tmp, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
|
491 |
+
start_current_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
|
492 |
+
|
493 |
+
stop_evaluation.click(None, None, None, cancels=[ev, ev_cur], queue=False)
|
494 |
+
refresh_table.click(generate_markdown_table, None, evaluation_table, show_progress=True)
|
495 |
+
save_comments.click(
|
496 |
+
save_past_evaluations, evaluation_table, None).then(
|
497 |
+
lambda: "Comments saved.", None, evaluation_log, show_progress=False)
|
498 |
+
|
499 |
+
def reload_lora():
|
500 |
+
return gr.Dropdown.update(choices=get_available_loras_local(non_serialized_params['Lora_sortedByTime']))
|
501 |
+
|
502 |
+
# nonserialized items
|
503 |
+
|
504 |
+
sort_byTime.change(lambda x: non_serialized_params.update({"Lora_sortedByTime": x}), sort_byTime, None).then(reload_lora,None,copy_from)
|
505 |
+
#debug_slicer.change(lambda x: non_serialized_params.update({"debug_slicer": x}), debug_slicer, None)
|
506 |
+
|
507 |
+
def update_dataset():
|
508 |
+
return gr.update(choices=get_datasets('training/datasets', 'json')), gr.update(choices=get_datasets('training/datasets', 'txt'))
|
509 |
+
|
510 |
+
download_button.click(download_file_from_url, [download_file_url,download_check_overwrite,download_folder] , download_status).then(update_dataset,None,[dataset , raw_text_file])
|
511 |
+
|
512 |
+
def get_datasets(path: str, ext: str):
|
513 |
+
# include subdirectories for raw txt files to allow training from a subdirectory of txt files
|
514 |
+
#if ext == "txt":
|
515 |
+
# return ['None'] + sorted(set([k.stem for k in list(Path(path).glob('txt')) + list(Path(path).glob('*/')) if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
|
516 |
+
|
517 |
+
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
|
518 |
+
|
519 |
+
def do_interrupt():
|
520 |
+
global WANT_INTERRUPT
|
521 |
+
WANT_INTERRUPT = True
|
522 |
+
|
523 |
+
|
524 |
+
def do_copy_params(lora_name: str, all_params):
|
525 |
+
|
526 |
+
if lora_name:
|
527 |
+
f_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}/training_parameters.json"
|
528 |
+
if Path(f_name).is_file():
|
529 |
+
with open(f_name, 'r', encoding='utf-8') as format_file:
|
530 |
+
params: dict[str, str] = json.load(format_file)
|
531 |
+
else:
|
532 |
+
params = {}
|
533 |
+
else:
|
534 |
+
params = {}
|
535 |
+
|
536 |
+
result = list()
|
537 |
+
for i in range(0, len(PARAMETERS)):
|
538 |
+
key = PARAMETERS[i]
|
539 |
+
if key in params:
|
540 |
+
result.append(params[key])
|
541 |
+
else:
|
542 |
+
result.append(all_params[i])
|
543 |
+
|
544 |
+
return result
|
545 |
+
|
546 |
+
|
547 |
+
def change_rank_limit(use_higher_ranks: bool):
|
548 |
+
mult = 2 if use_higher_ranks else 1
|
549 |
+
return {"maximum": 1024 * mult, "__type__": "update"}, {"maximum": 2048 * mult, "__type__": "update"}
|
550 |
+
|
551 |
+
|
552 |
+
def clean_path(base_path: str, path: str):
|
553 |
+
"""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
|
554 |
+
path = path.replace('\\', '/').replace('..', '_')
|
555 |
+
if base_path is None:
|
556 |
+
return path
|
557 |
+
|
558 |
+
return f'{Path(base_path).absolute()}/{path}'
|
559 |
+
|
560 |
+
|
561 |
+
def backup_adapter(input_folder):
|
562 |
+
# Get the creation date of the file adapter_model.bin
|
563 |
+
try:
|
564 |
+
adapter_file = Path(f"{input_folder}/adapter_model.bin")
|
565 |
+
if adapter_file.is_file():
|
566 |
+
|
567 |
+
logger.info("Backing up existing LoRA adapter...")
|
568 |
+
creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime)
|
569 |
+
creation_date_str = creation_date.strftime("Backup-%Y-%m-%d")
|
570 |
+
|
571 |
+
# Create the new subfolder
|
572 |
+
subfolder_path = Path(f"{input_folder}/{creation_date_str}")
|
573 |
+
subfolder_path.mkdir(parents=True, exist_ok=True)
|
574 |
+
|
575 |
+
# Check if the file already exists in the subfolder
|
576 |
+
backup_adapter_file = Path(f"{input_folder}/{creation_date_str}/adapter_model.bin")
|
577 |
+
if backup_adapter_file.is_file():
|
578 |
+
print(" - Backup already exists. Skipping backup process.")
|
579 |
+
return
|
580 |
+
|
581 |
+
# Copy existing files to the new subfolder
|
582 |
+
existing_files = Path(input_folder).iterdir()
|
583 |
+
for file in existing_files:
|
584 |
+
if file.is_file():
|
585 |
+
shutil.copy2(file, subfolder_path)
|
586 |
+
except Exception as e:
|
587 |
+
print("An error occurred in backup_adapter:", str(e))
|
588 |
+
|
589 |
+
|
590 |
+
def calc_trainable_parameters(model):
|
591 |
+
trainable_params = 0
|
592 |
+
all_param = 0
|
593 |
+
for _, param in model.named_parameters():
|
594 |
+
num_params = param.numel()
|
595 |
+
# if using DS Zero 3 and the weights are initialized empty
|
596 |
+
if num_params == 0 and hasattr(param, "ds_numel"):
|
597 |
+
num_params = param.ds_numel
|
598 |
+
|
599 |
+
all_param += num_params
|
600 |
+
if param.requires_grad:
|
601 |
+
trainable_params += num_params
|
602 |
+
|
603 |
+
return trainable_params, all_param
|
604 |
+
|
605 |
+
|
606 |
+
|
607 |
+
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str, precize_slicing_overlap: bool, add_eos_token_type: str, save_steps_under_loss: float, add_bos_token: bool, training_projection: str,sliding_window:bool,warmup_ratio:float, grad_accumulation: int,neft_noise_alpha:float):
|
608 |
+
|
609 |
+
if shared.args.monkey_patch:
|
610 |
+
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
|
611 |
+
replace_peft_model_with_int4_lora_model
|
612 |
+
)
|
613 |
+
replace_peft_model_with_int4_lora_model()
|
614 |
+
|
615 |
+
global train_log_graph
|
616 |
+
global WANT_INTERRUPT
|
617 |
+
WANT_INTERRUPT = False
|
618 |
+
|
619 |
+
statistics['loss'] = []
|
620 |
+
|
621 |
+
statistics['loss'].append({'epoch': 0, 'value': 0})
|
622 |
+
zero_pd = pd.DataFrame(statistics['loss'])
|
623 |
+
|
624 |
+
# == Input validation / processing ==
|
625 |
+
yield "Preparing the input...", zero_pd
|
626 |
+
lora_file_path = clean_path(None, lora_name)
|
627 |
+
if lora_file_path.strip() == '':
|
628 |
+
yield "Missing or invalid LoRA file name input.", zero_pd
|
629 |
+
return
|
630 |
+
|
631 |
+
lora_file_path = f"{Path(shared.args.lora_dir)}/{lora_file_path}"
|
632 |
+
actual_lr = float(learning_rate)
|
633 |
+
model_type = type(shared.model).__name__
|
634 |
+
|
635 |
+
if model_type in MODEL_CLASSES:
|
636 |
+
model_id = MODEL_CLASSES[model_type]
|
637 |
+
else:
|
638 |
+
model_id = "llama"
|
639 |
+
if model_type == "PeftModelForCausalLM":
|
640 |
+
if len(shared.lora_names) > 0:
|
641 |
+
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*", zero_pd
|
642 |
+
logger.warning("Training LoRA over top of another LoRA. May have unexpected effects.")
|
643 |
+
else:
|
644 |
+
yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*", zero_pd
|
645 |
+
logger.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
|
646 |
+
else:
|
647 |
+
yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*", zero_pd
|
648 |
+
logger.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
|
649 |
+
|
650 |
+
time.sleep(5)
|
651 |
+
|
652 |
+
if shared.args.loader == 'GPTQ-for-LLaMa' and not shared.args.monkey_patch:
|
653 |
+
yield "LoRA training with GPTQ-for-LLaMa requires loading with `--monkey-patch`", zero_pd
|
654 |
+
return
|
655 |
+
|
656 |
+
if cutoff_len <= 0 or micro_batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
|
657 |
+
yield "Cannot input zeroes.", zero_pd
|
658 |
+
return
|
659 |
+
|
660 |
+
#in new version we dumped this in favor of grad_accumulation
|
661 |
+
#set it to zero fo new save
|
662 |
+
batch_size = 0
|
663 |
+
|
664 |
+
gradient_accumulation_steps = grad_accumulation #batch_size // micro_batch_size
|
665 |
+
shared.tokenizer.pad_token_id = 0
|
666 |
+
shared.tokenizer.padding_side = "left"
|
667 |
+
|
668 |
+
def encode(text, prepend_bos_token):
|
669 |
+
|
670 |
+
result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len)
|
671 |
+
# Check if the first two tokens are BOS
|
672 |
+
if len(result) >= 2 and result[:2] == [shared.tokenizer.bos_token_id, shared.tokenizer.bos_token_id]:
|
673 |
+
result = result[1:]
|
674 |
+
|
675 |
+
if not prepend_bos_token and result[0] == shared.tokenizer.bos_token_id:
|
676 |
+
result = result[1:]
|
677 |
+
return result
|
678 |
+
|
679 |
+
def tokenize(prompt, append_eos_token=False, prepend_bos_token = False):
|
680 |
+
|
681 |
+
if train_only_after == '' or train_only_after not in prompt:
|
682 |
+
input_ids = encode(prompt, prepend_bos_token)
|
683 |
+
|
684 |
+
if append_eos_token and input_ids[-1] != shared.tokenizer.eos_token_id and len(input_ids) < cutoff_len:
|
685 |
+
input_ids.append(shared.tokenizer.eos_token_id)
|
686 |
+
|
687 |
+
input_ids = [shared.tokenizer.pad_token_id] * (cutoff_len - len(input_ids)) + input_ids
|
688 |
+
|
689 |
+
labels = [1] * len(input_ids)
|
690 |
+
else:
|
691 |
+
ind = prompt.index(train_only_after) + len(train_only_after)
|
692 |
+
before_tokens = encode(prompt[:ind], prepend_bos_token)
|
693 |
+
after_tokens = encode(prompt[ind:], False)
|
694 |
+
|
695 |
+
if append_eos_token and after_tokens[-1] != shared.tokenizer.eos_token_id:
|
696 |
+
after_tokens.append(shared.tokenizer.eos_token_id)
|
697 |
+
|
698 |
+
full_length = len(after_tokens) + len(before_tokens)
|
699 |
+
if full_length > cutoff_len:
|
700 |
+
after_tokens = after_tokens[:cutoff_len - len(before_tokens)]
|
701 |
+
else:
|
702 |
+
before_tokens = [shared.tokenizer.pad_token_id] * (cutoff_len - full_length) + before_tokens
|
703 |
+
|
704 |
+
input_ids = before_tokens + after_tokens
|
705 |
+
labels = [-100] * len(before_tokens) + [1] * len(after_tokens)
|
706 |
+
|
707 |
+
input_ids = torch.tensor(input_ids)
|
708 |
+
return {
|
709 |
+
"input_ids": input_ids,
|
710 |
+
"labels": labels,
|
711 |
+
"attention_mask": input_ids.ne(shared.tokenizer.pad_token_id),
|
712 |
+
}
|
713 |
+
|
714 |
+
train_template.clear()
|
715 |
+
|
716 |
+
|
717 |
+
#reset stuff
|
718 |
+
print(f"*** LoRA: {lora_name} ***")
|
719 |
+
non_serialized_params.update({"stop_at_loss": stop_at_loss})
|
720 |
+
non_serialized_params.update({"save_steps_under_loss": save_steps_under_loss+0.01})
|
721 |
+
non_serialized_params.update({"save_checkpoint_now": False})
|
722 |
+
non_serialized_params.update({"training_loop": False})
|
723 |
+
non_serialized_params.update({"current_stability": 0})
|
724 |
+
non_serialized_params.update({"save_epochs": 0})
|
725 |
+
non_serialized_params.update({"checkpoint_offset": 0})
|
726 |
+
non_serialized_params.update({"epoch_offset": 0})
|
727 |
+
train_log_graph.clear()
|
728 |
+
|
729 |
+
# === once fixed, this can be removed ==============================
|
730 |
+
if hasattr(torch.utils.checkpoint, 'noop_context_fn'):
|
731 |
+
print("Testing Pytorch...")
|
732 |
+
old_checkpoint_signature = inspect.signature(torch.utils.checkpoint.checkpoint)
|
733 |
+
|
734 |
+
# Get the signature of your new checkpoint function
|
735 |
+
my_checkpoint_signature = inspect.signature(my_checkpoint)
|
736 |
+
|
737 |
+
# Check if the signatures match
|
738 |
+
if old_checkpoint_signature.parameters == my_checkpoint_signature.parameters:
|
739 |
+
print(F"{RED}Overriding Torch checkpoint function to avoid repeated 'use_reentrant not explicitly set' warnings{RESET}")
|
740 |
+
#print(" - Note: Transformers need to pass use_reentrant in llama.modeling_llama in def forward, layer_outputs = torch.utils.checkpoint.checkpoint")
|
741 |
+
#print(" Once they do, this function can be removed")
|
742 |
+
torch.utils.checkpoint.checkpoint = my_checkpoint
|
743 |
+
|
744 |
+
|
745 |
+
# END OF FPHAM SENTENCE SPLIT functions ===================
|
746 |
+
|
747 |
+
# == Prep the dataset, format, etc ==
|
748 |
+
if raw_text_file not in ['None', '']:
|
749 |
+
train_template["template_type"] = "raw_text"
|
750 |
+
logger.info("Loading text file...")
|
751 |
+
fullpath = clean_path('training/datasets', f'{raw_text_file}')
|
752 |
+
fullpath = Path(fullpath)
|
753 |
+
if fullpath.is_dir():
|
754 |
+
logger.info('Training path directory {}'.format(raw_text_file))
|
755 |
+
raw_text = ""
|
756 |
+
file_paths = sorted(fullpath.glob('*.txt'), key=lambda path: natural_keys(path.name))
|
757 |
+
for file_path in file_paths:
|
758 |
+
if file_path.is_file():
|
759 |
+
with file_path.open('r', encoding='utf-8') as file:
|
760 |
+
raw_text += file.read().replace('\r', '')
|
761 |
+
|
762 |
+
logger.info(f"Loaded training file: {file_path.name}")
|
763 |
+
else:
|
764 |
+
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
|
765 |
+
raw_text = file.read().replace('\r', '')
|
766 |
+
|
767 |
+
# FPHAM PRECISE SLICING
|
768 |
+
if min_chars<0:
|
769 |
+
min_chars = 0
|
770 |
+
|
771 |
+
add_EOS_to_all = add_eos_token and add_eos_token_type == 'Every Block'
|
772 |
+
add_EOS_to_HC = add_eos_token and add_eos_token_type != 'Every Block'
|
773 |
+
|
774 |
+
#print (f"add_eos_token {add_eos_token}, add_EOS_to_all {add_EOS_to_all}, add_EOS_to_HC {add_EOS_to_HC}")
|
775 |
+
|
776 |
+
# == New more precise slicing on sentence boundary ==
|
777 |
+
if sliding_window:
|
778 |
+
text_chunks = sliding_block_cut(raw_text, min_chars, add_EOS_to_HC, cutoff_len, hard_cut_string,non_serialized_params['debug_slicer'])
|
779 |
+
else:
|
780 |
+
text_chunks = precise_cut(raw_text, precize_slicing_overlap, min_chars, add_EOS_to_HC, cutoff_len, hard_cut_string,non_serialized_params['debug_slicer'])
|
781 |
+
|
782 |
+
train_data = Dataset.from_list([tokenize(x, add_EOS_to_all, add_bos_token) for x in text_chunks])
|
783 |
+
if add_EOS_to_all:
|
784 |
+
print(f"Added EOS to {len(text_chunks)} blocks")
|
785 |
+
|
786 |
+
print(f"All Data Blocks: {len(text_chunks)}")
|
787 |
+
|
788 |
+
del text_chunks
|
789 |
+
eval_data = None
|
790 |
+
else:
|
791 |
+
if dataset in ['None', '']:
|
792 |
+
yield "Missing dataset choice input, cannot continue.", zero_pd
|
793 |
+
return
|
794 |
+
|
795 |
+
if format in ['None', '']:
|
796 |
+
yield "Missing format choice input, cannot continue.", zero_pd
|
797 |
+
return
|
798 |
+
|
799 |
+
train_template["template_type"] = "dataset"
|
800 |
+
|
801 |
+
with open(clean_path('training/formats', f'{format}.json'), 'r', encoding='utf-8-sig') as formatFile:
|
802 |
+
format_data: dict[str, str] = json.load(formatFile)
|
803 |
+
|
804 |
+
# == store training prompt ==
|
805 |
+
for _, value in format_data.items():
|
806 |
+
prompt_key = f"template_{len(train_template)}"
|
807 |
+
train_template[prompt_key] = value
|
808 |
+
|
809 |
+
def generate_prompt(data_point: dict[str, str]):
|
810 |
+
for options, data in format_data.items():
|
811 |
+
if set(options.split(',')) == set(x[0] for x in data_point.items() if (type(x[1]) is str and len(x[1].strip()) > 0)):
|
812 |
+
for key, val in data_point.items():
|
813 |
+
if type(val) is str:
|
814 |
+
data = data.replace(f'%{key}%', val)
|
815 |
+
return data
|
816 |
+
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
|
817 |
+
|
818 |
+
def generate_and_tokenize_prompt(data_point):
|
819 |
+
prompt = generate_prompt(data_point)
|
820 |
+
return tokenize(prompt, add_eos_token, add_bos_token)
|
821 |
+
|
822 |
+
logger.info("Loading JSON datasets...")
|
823 |
+
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
|
824 |
+
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
825 |
+
|
826 |
+
print(f"BOS: {add_bos_token} EOS: {add_eos_token}")
|
827 |
+
print(f"Data Blocks: {train_data.num_rows}")
|
828 |
+
|
829 |
+
if eval_dataset == 'None':
|
830 |
+
eval_data = None
|
831 |
+
else:
|
832 |
+
eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
|
833 |
+
eval_data = eval_data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
834 |
+
|
835 |
+
# == We MUST reload model if it went through any previous training, even failed one ==
|
836 |
+
if shared.model_dirty_from_training:
|
837 |
+
selected_model = shared.model_name
|
838 |
+
if selected_model:
|
839 |
+
print("\033[1;31;1m(Model has been modified by previous training, it needs to be reloaded...)\033[0;37;0m")
|
840 |
+
try:
|
841 |
+
yield f"Reloading {selected_model}...", zero_pd
|
842 |
+
reload_model()
|
843 |
+
shared.tokenizer.pad_token_id = 0
|
844 |
+
shared.tokenizer.padding_side = "left"
|
845 |
+
|
846 |
+
if shared.model is not None:
|
847 |
+
print("Model reloaded OK, continue with training.")
|
848 |
+
else:
|
849 |
+
return f"Failed to load {selected_model}."
|
850 |
+
except:
|
851 |
+
exc = traceback.format_exc()
|
852 |
+
logger.error('Failed to reload the model.')
|
853 |
+
print(exc)
|
854 |
+
return exc.replace('\n', '\n\n')
|
855 |
+
|
856 |
+
# == Start prepping the model itself ==
|
857 |
+
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
858 |
+
logger.info("Getting model ready...")
|
859 |
+
# here we can disable gradient checkpoint, by default = true, use_gradient_checkpointing=True
|
860 |
+
prepare_model_for_kbit_training(shared.model)
|
861 |
+
|
862 |
+
# base model is now frozen and should not be reused for any other LoRA training than this one
|
863 |
+
shared.model_dirty_from_training = True
|
864 |
+
print(f"Transformers Model Type: {YELLOW}{model_type}{RESET}")
|
865 |
+
|
866 |
+
if training_projection==train_choices[0]:
|
867 |
+
model_to_lora_modules[model_id] = ["gate_proj","down_proj","up_proj","q_proj","k_proj","v_proj","o_proj"]
|
868 |
+
elif training_projection==train_choices[1]:
|
869 |
+
model_to_lora_modules[model_id] = ["q_proj","k_proj", "v_proj", "o_proj"]
|
870 |
+
elif training_projection==train_choices[2]:
|
871 |
+
model_to_lora_modules[model_id] = ["q_proj","k_proj", "v_proj"]
|
872 |
+
elif training_projection==train_choices[3]:
|
873 |
+
model_to_lora_modules[model_id] = ["k_proj", "v_proj", "down_proj"]
|
874 |
+
else:
|
875 |
+
model_to_lora_modules[model_id] = ["q_proj", "v_proj"]
|
876 |
+
|
877 |
+
|
878 |
+
logger.info("Preparing for training...")
|
879 |
+
config = LoraConfig(
|
880 |
+
r=lora_rank,
|
881 |
+
lora_alpha=lora_alpha,
|
882 |
+
target_modules=model_to_lora_modules[model_id],
|
883 |
+
lora_dropout=lora_dropout,
|
884 |
+
bias="none",
|
885 |
+
task_type="CAUSAL_LM"
|
886 |
+
)
|
887 |
+
|
888 |
+
# == Backup the existing adapter ==
|
889 |
+
if not always_override:
|
890 |
+
backup_adapter(lora_file_path)
|
891 |
+
|
892 |
+
# == get model trainable params
|
893 |
+
model_trainable_params, model_all_params = calc_trainable_parameters(shared.model)
|
894 |
+
|
895 |
+
try:
|
896 |
+
logger.info("Creating LoRA model...")
|
897 |
+
lora_model = get_peft_model(shared.model, config)
|
898 |
+
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
|
899 |
+
logger.info("Loading existing LoRA data...")
|
900 |
+
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
|
901 |
+
set_peft_model_state_dict(lora_model, state_dict_peft)
|
902 |
+
|
903 |
+
print(f" + Continue Training on {RED}{lora_file_path}/adapter_model.bin{RESET}")
|
904 |
+
|
905 |
+
#load training_log.json if exist
|
906 |
+
|
907 |
+
if Path(f"{lora_file_path}/training_log.json").is_file():
|
908 |
+
with open(f"{lora_file_path}/training_log.json", 'r') as json_file:
|
909 |
+
json_ilog = json.load(json_file)
|
910 |
+
for key, value in json_ilog.items():
|
911 |
+
if key=='current_steps':
|
912 |
+
non_serialized_params.update({"checkpoint_offset": int(value+1)})
|
913 |
+
print(f" + Checkpoints will be saved with offset: {RED}{non_serialized_params['checkpoint_offset']}{RESET}")
|
914 |
+
if key=='epoch':
|
915 |
+
non_serialized_params.update({"epoch_offset": value})
|
916 |
+
print(f" + Epoch offset: {RED}{non_serialized_params['epoch_offset']}{RESET}")
|
917 |
+
|
918 |
+
|
919 |
+
if Path(f"{lora_file_path}/training_graph.json").is_file():
|
920 |
+
try:
|
921 |
+
with open(f"{lora_file_path}/training_graph.json", 'r') as json_file:
|
922 |
+
train_log_graph = json.load(json_file)
|
923 |
+
print(" + Training Graph loaded")
|
924 |
+
except:
|
925 |
+
print(f"Can't read training_graph")
|
926 |
+
|
927 |
+
|
928 |
+
except:
|
929 |
+
yield traceback.format_exc().replace('\n', '\n\n'), zero_pd
|
930 |
+
return
|
931 |
+
|
932 |
+
if shared.args.monkey_patch:
|
933 |
+
from alpaca_lora_4bit.autograd_4bit import Autograd4bitQuantLinear
|
934 |
+
from alpaca_lora_4bit.models import Linear4bitLt
|
935 |
+
for _, m in lora_model.named_modules():
|
936 |
+
if isinstance(m, Autograd4bitQuantLinear) or isinstance(m, Linear4bitLt):
|
937 |
+
if m.is_v1_model:
|
938 |
+
m.zeros = m.zeros.half()
|
939 |
+
m.scales = m.scales.half()
|
940 |
+
|
941 |
+
class Tracked():
|
942 |
+
def __init__(self):
|
943 |
+
self.current_steps = 0
|
944 |
+
self.max_steps = 0
|
945 |
+
self.did_save = False
|
946 |
+
|
947 |
+
tracked = Tracked()
|
948 |
+
actual_save_steps = math.ceil(save_steps / gradient_accumulation_steps)
|
949 |
+
|
950 |
+
class Callbacks(transformers.TrainerCallback):
|
951 |
+
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
952 |
+
tracked.current_steps = state.global_step * gradient_accumulation_steps
|
953 |
+
tracked.max_steps = state.max_steps * gradient_accumulation_steps
|
954 |
+
ssteps10 = int(max(2,(state.max_steps/epochs)*0.1))
|
955 |
+
|
956 |
+
if WANT_INTERRUPT:
|
957 |
+
control.should_epoch_stop = True
|
958 |
+
control.should_training_stop = True
|
959 |
+
else:
|
960 |
+
current_loss = float(train_log.get('loss', 0.0))
|
961 |
+
current_epoch_int = int(float(train_log.get('epoch', 0.0)))
|
962 |
+
|
963 |
+
force_save = False
|
964 |
+
|
965 |
+
current_steps_offset = tracked.current_steps + non_serialized_params['checkpoint_offset']
|
966 |
+
|
967 |
+
folder_save = f"checkpoint-{current_steps_offset}"
|
968 |
+
|
969 |
+
# save if triggered by user
|
970 |
+
if non_serialized_params['save_checkpoint_now']:
|
971 |
+
force_save = True
|
972 |
+
non_serialized_params.update({"save_checkpoint_now": False})
|
973 |
+
print(f"\033[1;31;1mSave Checkpoint manually trigerred.\033[0;37;0m")
|
974 |
+
folder_save = f"checkpoint-{current_steps_offset}-user"
|
975 |
+
|
976 |
+
patience = 3 # Set the number of consecutive steps for tracking stability
|
977 |
+
|
978 |
+
if gradient_accumulation_steps==1:
|
979 |
+
patience = 4
|
980 |
+
|
981 |
+
min_steps = ssteps10
|
982 |
+
|
983 |
+
# Save each time the loss is below the threshold
|
984 |
+
if current_loss < non_serialized_params['save_steps_under_loss'] and current_loss > 0 and state.global_step > min_steps:
|
985 |
+
current_stability = non_serialized_params['current_stability']
|
986 |
+
current_stability += 1
|
987 |
+
non_serialized_params.update({"current_stability": current_stability})
|
988 |
+
|
989 |
+
if current_stability >= patience:
|
990 |
+
current_stability = 0
|
991 |
+
non_serialized_params.update({"current_stability": current_stability})
|
992 |
+
current_loss_dec = round(current_loss, 2)
|
993 |
+
loss_str = f"{current_loss_dec:.2f}"
|
994 |
+
loss_str = loss_str.replace('.', '_')
|
995 |
+
new_save = (current_loss_dec-0.1) + 0.01
|
996 |
+
non_serialized_params.update({"save_steps_under_loss": new_save})
|
997 |
+
|
998 |
+
folder_save = f"checkpoint-{current_steps_offset}-loss-{loss_str}"
|
999 |
+
force_save = True
|
1000 |
+
|
1001 |
+
|
1002 |
+
else:
|
1003 |
+
# Reset stability if the loss goes above the threshold
|
1004 |
+
non_serialized_params.update({"current_stability": 0})
|
1005 |
+
|
1006 |
+
# Save full epochs
|
1007 |
+
if actual_save_steps>0 and current_epoch_int > non_serialized_params['save_epochs'] and state.global_step > min_steps:
|
1008 |
+
|
1009 |
+
|
1010 |
+
current_epoch_offset = current_epoch_int
|
1011 |
+
|
1012 |
+
if non_serialized_params['epoch_offset'] > 0:
|
1013 |
+
current_epoch_offset = current_epoch_int + round(non_serialized_params['epoch_offset'], 2)
|
1014 |
+
|
1015 |
+
ep_off_str = f"{current_epoch_offset}"
|
1016 |
+
ep_off_str = ep_off_str.replace('.', '_')
|
1017 |
+
folder_save = f"checkpoint-{current_steps_offset}-epoch-{ep_off_str}"
|
1018 |
+
|
1019 |
+
non_serialized_params.update({"save_epochs": current_epoch_int})
|
1020 |
+
force_save = True
|
1021 |
+
|
1022 |
+
# save each actual_save_steps
|
1023 |
+
if state.global_step > 0 and actual_save_steps > 0 and state.global_step % actual_save_steps == 0:
|
1024 |
+
folder_save = f"checkpoint-{current_steps_offset}"
|
1025 |
+
force_save = True
|
1026 |
+
|
1027 |
+
if force_save:
|
1028 |
+
lora_model.save_pretrained(f"{lora_file_path}/{folder_save}/")
|
1029 |
+
print(f"\033[1;30;40mStep: {tracked.current_steps:6} \033[0;37;0m Saved: [{folder_save}]")
|
1030 |
+
# Save log
|
1031 |
+
with open(f"{lora_file_path}/{folder_save}/training_log.json", 'w', encoding='utf-8') as file:
|
1032 |
+
json.dump(train_log, file, indent=2)
|
1033 |
+
# == Save training prompt ==
|
1034 |
+
with open(f"{lora_file_path}/{folder_save}/training_prompt.json", 'w', encoding='utf-8') as file:
|
1035 |
+
json.dump(train_template, file, indent=2)
|
1036 |
+
|
1037 |
+
|
1038 |
+
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
1039 |
+
tracked.current_steps += 1
|
1040 |
+
if WANT_INTERRUPT:
|
1041 |
+
control.should_epoch_stop = True
|
1042 |
+
control.should_training_stop = True
|
1043 |
+
|
1044 |
+
def on_log(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, logs, **kwargs):
|
1045 |
+
train_log.update(logs)
|
1046 |
+
|
1047 |
+
current_steps_offset = tracked.current_steps + non_serialized_params['checkpoint_offset']
|
1048 |
+
current_epoch_offset = train_log.get('epoch', 0.0) + non_serialized_params['epoch_offset']
|
1049 |
+
|
1050 |
+
train_log.update({"current_steps": tracked.current_steps})
|
1051 |
+
train_log.update({"current_steps_adjusted": current_steps_offset})
|
1052 |
+
train_log.update({"epoch_adjusted": current_epoch_offset})
|
1053 |
+
|
1054 |
+
if WANT_INTERRUPT:
|
1055 |
+
print("\033[1;31;1mInterrupted by user\033[0;37;0m")
|
1056 |
+
|
1057 |
+
if non_serialized_params['checkpoint_offset']>0:
|
1058 |
+
print(f"\033[1;30;40mStep: {tracked.current_steps:6} [+{non_serialized_params['checkpoint_offset']}] \033[0;37;0m", end='')
|
1059 |
+
else:
|
1060 |
+
print(f"\033[1;30;40mStep: {tracked.current_steps:6} \033[0;37;0m", end='')
|
1061 |
+
|
1062 |
+
graphentry = {
|
1063 |
+
'current_steps': int(train_log.get('current_steps_adjusted',0)),
|
1064 |
+
'loss': float(train_log.get('loss', 0.0)),
|
1065 |
+
'learning_rate': float(train_log.get('learning_rate', 0.0)),
|
1066 |
+
'epoch': float(train_log.get('epoch_adjusted', 0.0))
|
1067 |
+
}
|
1068 |
+
|
1069 |
+
cur_loss = float(train_log.get('loss', 0.0))
|
1070 |
+
cur_lr = float(train_log.get('learning_rate', 0.0))
|
1071 |
+
cur_epoch = float(train_log.get('epoch', 0.0))
|
1072 |
+
|
1073 |
+
if len(statistics['loss']) == 1:
|
1074 |
+
first_epoch = statistics['loss'][0]['epoch']
|
1075 |
+
first_value = statistics['loss'][0]['value']
|
1076 |
+
if first_value ==0:
|
1077 |
+
statistics['loss'] = []
|
1078 |
+
|
1079 |
+
|
1080 |
+
statistics['loss'].append({'epoch': cur_epoch, 'value': cur_loss})
|
1081 |
+
statistics['lr'].append({'epoch': cur_epoch, 'value': cur_lr})
|
1082 |
+
|
1083 |
+
# Add the entry to the continuous log
|
1084 |
+
train_log_graph.append(graphentry)
|
1085 |
+
|
1086 |
+
# Save the graph log for now, we can later generate full graph
|
1087 |
+
with open(f"{lora_file_path}/training_graph.json", 'w') as file:
|
1088 |
+
json.dump(train_log_graph, file, indent=4)
|
1089 |
+
|
1090 |
+
if 'loss' in logs:
|
1091 |
+
loss = float(logs['loss'])
|
1092 |
+
if loss <= stop_at_loss:
|
1093 |
+
control.should_epoch_stop = True
|
1094 |
+
control.should_training_stop = True
|
1095 |
+
print(f"{RED}Stop Loss {stop_at_loss} reached.{RESET}")
|
1096 |
+
|
1097 |
+
# FPHAM SAMPLE REQ Transformers error handling
|
1098 |
+
gradient_accumulation_max = int(train_data.num_rows)//micro_batch_size
|
1099 |
+
|
1100 |
+
if gradient_accumulation_max < gradient_accumulation_steps:
|
1101 |
+
print(f"{RED}WARNING:{RESET} Current gradient accumulation is {RED}too high{RESET} for the amount of training data.")
|
1102 |
+
print(f"Gradient accumulation: {gradient_accumulation_steps} should be less than: {gradient_accumulation_max}. {RED}This could crash Accelerate/Transformers{RESET}")
|
1103 |
+
#min_batchSize = sample_req*micro_batch_size
|
1104 |
+
print(f"Preferable fix: {RED}Increase the size of dataset{RESET}")
|
1105 |
+
print(f"... or Decrerase Gradient Accumulation {RED}{gradient_accumulation_steps}{RESET} to below {GREEN}{gradient_accumulation_max}{RESET}")
|
1106 |
+
gradient_accumulation_steps = max(1,gradient_accumulation_max-1)
|
1107 |
+
print(f"Last resort fix for this run: Lowering Gradient accumulation to {GREEN}{gradient_accumulation_steps}{RESET} [Good luck]")
|
1108 |
+
|
1109 |
+
else:
|
1110 |
+
print(f"Data Size Check: Gradient accumulation: {YELLOW}{gradient_accumulation_steps}{RESET} <= Blocks/Batch {gradient_accumulation_max} ... {GREEN}[OK]{RESET}")
|
1111 |
+
|
1112 |
+
#END OF FPHAM SAMPLE REQ
|
1113 |
+
|
1114 |
+
# FPHAM Custom Scheduler ==
|
1115 |
+
custom_scheduller = False
|
1116 |
+
lr_scheduler_type_arg = lr_scheduler_type
|
1117 |
+
|
1118 |
+
if lr_scheduler_type == 'FP_low_epoch_annealing':
|
1119 |
+
custom_scheduller = True
|
1120 |
+
lr_scheduler_type_arg = 'cosine'
|
1121 |
+
elif lr_scheduler_type == 'FP_half_time_annealing':
|
1122 |
+
custom_scheduller = True
|
1123 |
+
lr_scheduler_type_arg = 'constant'
|
1124 |
+
elif lr_scheduler_type =='FP_raise_fall_creative':
|
1125 |
+
custom_scheduller = True
|
1126 |
+
lr_scheduler_type_arg = 'constant_with_warmup'
|
1127 |
+
|
1128 |
+
#gradient_checkpointing=True
|
1129 |
+
|
1130 |
+
args=transformers.TrainingArguments(
|
1131 |
+
report_to=report_to if report_to != "None" else None,
|
1132 |
+
per_device_train_batch_size=micro_batch_size,
|
1133 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
1134 |
+
warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps),
|
1135 |
+
warmup_ratio = warmup_ratio,
|
1136 |
+
num_train_epochs=epochs,
|
1137 |
+
learning_rate=actual_lr,
|
1138 |
+
fp16=False if shared.args.cpu else True,
|
1139 |
+
optim=optimizer,
|
1140 |
+
logging_steps=1,
|
1141 |
+
evaluation_strategy="steps" if eval_data is not None else "no",
|
1142 |
+
eval_steps=math.ceil(eval_steps / gradient_accumulation_steps) if eval_data is not None else None,
|
1143 |
+
save_strategy="steps" if eval_data is not None else "no",
|
1144 |
+
output_dir=lora_file_path,
|
1145 |
+
lr_scheduler_type=lr_scheduler_type_arg,
|
1146 |
+
load_best_model_at_end=eval_data is not None,
|
1147 |
+
# TODO: Enable multi-device support
|
1148 |
+
ddp_find_unused_parameters=None,
|
1149 |
+
no_cuda=shared.args.cpu,
|
1150 |
+
)
|
1151 |
+
|
1152 |
+
if custom_scheduller:
|
1153 |
+
trainer = FPSchedulerTrainer(
|
1154 |
+
neftune_noise_alpha=neft_noise_alpha,
|
1155 |
+
model=lora_model,
|
1156 |
+
train_dataset=train_data,
|
1157 |
+
eval_dataset=eval_data,
|
1158 |
+
args=args,
|
1159 |
+
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
|
1160 |
+
callbacks=list([Callbacks()])
|
1161 |
+
)
|
1162 |
+
elif neft_noise_alpha > 0:
|
1163 |
+
trainer = FPNEFtuneTrainer(
|
1164 |
+
neftune_noise_alpha=neft_noise_alpha,
|
1165 |
+
model=lora_model,
|
1166 |
+
train_dataset=train_data,
|
1167 |
+
eval_dataset=eval_data,
|
1168 |
+
args=args,
|
1169 |
+
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
|
1170 |
+
callbacks=list([Callbacks()])
|
1171 |
+
)
|
1172 |
+
else:
|
1173 |
+
trainer = transformers.Trainer(
|
1174 |
+
model=lora_model,
|
1175 |
+
train_dataset=train_data,
|
1176 |
+
eval_dataset=eval_data,
|
1177 |
+
args=args,
|
1178 |
+
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
|
1179 |
+
callbacks=list([Callbacks()])
|
1180 |
+
)
|
1181 |
+
|
1182 |
+
# END OF FPHAM CUSTOM SCHEDULER
|
1183 |
+
|
1184 |
+
lora_model.config.use_cache = False
|
1185 |
+
|
1186 |
+
if torch.__version__ >= "2" and sys.platform != "win32":
|
1187 |
+
lora_model = torch.compile(lora_model)
|
1188 |
+
|
1189 |
+
# == Save parameters for reuse ==
|
1190 |
+
with open(f"{lora_file_path}/training_parameters.json", 'w', encoding='utf-8') as file:
|
1191 |
+
vars = locals()
|
1192 |
+
json.dump({x: vars[x] for x in PARAMETERS}, file, indent=2)
|
1193 |
+
|
1194 |
+
# == Save training prompt ==
|
1195 |
+
with open(f"{lora_file_path}/training_prompt.json", 'w', encoding='utf-8') as file:
|
1196 |
+
json.dump(train_template, file, indent=2)
|
1197 |
+
|
1198 |
+
# == Main run and monitor loop ==
|
1199 |
+
logger.info("Starting training...")
|
1200 |
+
yield "Starting...", zero_pd
|
1201 |
+
|
1202 |
+
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
|
1203 |
+
|
1204 |
+
projections_string = ", ".join([projection.replace("_proj", "") for projection in model_to_lora_modules[model_id]])
|
1205 |
+
|
1206 |
+
print(f"Training '{model_id}' model using {YELLOW}({projections_string}){RESET} projections")
|
1207 |
+
|
1208 |
+
if lora_all_param > 0:
|
1209 |
+
print(f"Trainable params: {lora_trainable_param:,d} ({RED}{100 * lora_trainable_param / lora_all_param:.4f} %{RESET}), All params: {lora_all_param:,d} (Model: {model_all_params:,d})")
|
1210 |
+
|
1211 |
+
train_log.update({"base_model_name": shared.model_name})
|
1212 |
+
train_log.update({"base_model_class": shared.model.__class__.__name__})
|
1213 |
+
train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)})
|
1214 |
+
train_log.update({"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False)})
|
1215 |
+
train_log.update({"projections": projections_string})
|
1216 |
+
if non_serialized_params['checkpoint_offset'] > 0:
|
1217 |
+
train_log.update({"last_run_steps_offset": non_serialized_params['checkpoint_offset']})
|
1218 |
+
train_log.update({"last_run_epoch_offset": non_serialized_params['epoch_offset']})
|
1219 |
+
|
1220 |
+
|
1221 |
+
if non_serialized_params['checkpoint_offset'] > 0:
|
1222 |
+
print(f"Continue training on {RED}previous adapter{RESET} from epoch: {RED}{non_serialized_params['epoch_offset']}{RESET}")
|
1223 |
+
|
1224 |
+
if stop_at_loss > 0:
|
1225 |
+
print(f"Monitoring loss {RED}(Auto-Stop at: {stop_at_loss}){RESET}")
|
1226 |
+
|
1227 |
+
|
1228 |
+
|
1229 |
+
if WANT_INTERRUPT:
|
1230 |
+
yield "Interrupted before start.", zero_pd
|
1231 |
+
return
|
1232 |
+
|
1233 |
+
def log_train_dataset(trainer):
|
1234 |
+
decoded_entries = []
|
1235 |
+
# Try to decode the entries and write the log file
|
1236 |
+
try:
|
1237 |
+
# Iterate over the first 10 elements in the dataset (or fewer if there are less than 10)
|
1238 |
+
for i in range(min(10, len(trainer.train_dataset))):
|
1239 |
+
decoded_text = shared.tokenizer.decode(trainer.train_dataset[i]['input_ids'])
|
1240 |
+
decoded_entries.append({"value": decoded_text})
|
1241 |
+
|
1242 |
+
# Write the log file
|
1243 |
+
Path('logs').mkdir(exist_ok=True)
|
1244 |
+
with open(Path('logs/train_dataset_sample.json'), 'w') as json_file:
|
1245 |
+
json.dump(decoded_entries, json_file, indent=4)
|
1246 |
+
|
1247 |
+
logger.info("Log file 'train_dataset_sample.json' created in the 'logs' directory.")
|
1248 |
+
except Exception as e:
|
1249 |
+
logger.error(f"Failed to create log file due to error: {e}")
|
1250 |
+
|
1251 |
+
def threaded_run():
|
1252 |
+
log_train_dataset(trainer)
|
1253 |
+
trainer.train()
|
1254 |
+
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
|
1255 |
+
lora_model.save_pretrained(lora_file_path)
|
1256 |
+
logger.info("LoRA training run is completed and saved.")
|
1257 |
+
# Save log
|
1258 |
+
with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
|
1259 |
+
json.dump(train_log, file, indent=2)
|
1260 |
+
|
1261 |
+
thread = threading.Thread(target=threaded_run)
|
1262 |
+
thread.start()
|
1263 |
+
last_step = 0
|
1264 |
+
start_time = time.perf_counter()
|
1265 |
+
|
1266 |
+
while thread.is_alive():
|
1267 |
+
time.sleep(0.5)
|
1268 |
+
|
1269 |
+
if statistics['loss']:
|
1270 |
+
max_value_dict = max(statistics['loss'], key=lambda x: x['value'])
|
1271 |
+
max_value = max_value_dict['value']+0.4
|
1272 |
+
first_epoch = statistics['loss'][0]['epoch']
|
1273 |
+
last_epoch = statistics['loss'][-1]['epoch']
|
1274 |
+
else:
|
1275 |
+
max_value = 3.5
|
1276 |
+
last_epoch = 0
|
1277 |
+
first_epoch = 0
|
1278 |
+
|
1279 |
+
if WANT_INTERRUPT:
|
1280 |
+
|
1281 |
+
losses = gr.LinePlot.update(
|
1282 |
+
value = pd.DataFrame(statistics['loss']),
|
1283 |
+
x="epoch", y="value",
|
1284 |
+
title="Loss Metrics",
|
1285 |
+
overlay_point=True, tooltip=["epoch", "value"],
|
1286 |
+
x_lim=[first_epoch,last_epoch], y_lim=[0,max_value],
|
1287 |
+
width=500, height=250 )
|
1288 |
+
|
1289 |
+
yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*", losses
|
1290 |
+
|
1291 |
+
elif tracked.current_steps != last_step:
|
1292 |
+
last_step = tracked.current_steps
|
1293 |
+
time_elapsed = time.perf_counter() - start_time
|
1294 |
+
lastloss = float(train_log.get('loss', 0.0))
|
1295 |
+
|
1296 |
+
non_serialized_params.update({"training_loop": True})
|
1297 |
+
|
1298 |
+
if lastloss > 0:
|
1299 |
+
lastloss_str = f", ... Current Loss: `{lastloss:.2f}`"
|
1300 |
+
else:
|
1301 |
+
lastloss_str = ""
|
1302 |
+
|
1303 |
+
if time_elapsed <= 0:
|
1304 |
+
timer_info = ""
|
1305 |
+
total_time_estimate = 999
|
1306 |
+
else:
|
1307 |
+
its = tracked.current_steps / time_elapsed
|
1308 |
+
if its > 1:
|
1309 |
+
timer_info = f"`{its:.2f}` it/s"
|
1310 |
+
else:
|
1311 |
+
timer_info = f"`{1.0/its:.2f}` s/it"
|
1312 |
+
|
1313 |
+
total_time_estimate = (1.0 / its) * (tracked.max_steps)
|
1314 |
+
|
1315 |
+
if stop_at_loss != non_serialized_params['stop_at_loss']:
|
1316 |
+
stop_at_loss = non_serialized_params['stop_at_loss']
|
1317 |
+
print(f"Stop at loss changed {RED}(Auto-Stop at: {stop_at_loss}){RESET}")
|
1318 |
+
|
1319 |
+
losses = gr.LinePlot.update(
|
1320 |
+
value = pd.DataFrame(statistics['loss']),
|
1321 |
+
x="epoch", y="value",
|
1322 |
+
title="Loss Metrics",
|
1323 |
+
overlay_point=True, tooltip=["epoch", "value"],
|
1324 |
+
x_lim=[first_epoch,last_epoch], y_lim=[0,max_value],
|
1325 |
+
width=500, height=250 )
|
1326 |
+
|
1327 |
+
|
1328 |
+
yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining {lastloss_str}", losses
|
1329 |
+
|
1330 |
+
# Saving in the train thread might fail if an error occurs, so save here if so.
|
1331 |
+
|
1332 |
+
#return_pd = pd.DataFrame(statistics['loss'])
|
1333 |
+
|
1334 |
+
if statistics['loss']:
|
1335 |
+
max_value_dict = max(statistics['loss'], key=lambda x: x['value'])
|
1336 |
+
max_value = max_value_dict['value']+0.4
|
1337 |
+
first_epoch = statistics['loss'][0]['epoch']
|
1338 |
+
last_epoch = statistics['loss'][-1]['epoch']
|
1339 |
+
else:
|
1340 |
+
max_value = 3.5
|
1341 |
+
last_epoch = 0
|
1342 |
+
first_epoch = 0
|
1343 |
+
|
1344 |
+
return_pd = gr.LinePlot.update(
|
1345 |
+
value = pd.DataFrame(statistics['loss']),
|
1346 |
+
x="epoch", y="value",
|
1347 |
+
title="Loss Metrics",
|
1348 |
+
overlay_point=True, tooltip=["epoch", "value"],
|
1349 |
+
x_lim=[first_epoch,last_epoch], y_lim=[0,max_value],
|
1350 |
+
width=500, height=250)
|
1351 |
+
|
1352 |
+
non_serialized_params.update({"training_loop": False})
|
1353 |
+
|
1354 |
+
if not tracked.did_save:
|
1355 |
+
logger.info("Training complete, saving...")
|
1356 |
+
lora_model.save_pretrained(lora_file_path)
|
1357 |
+
|
1358 |
+
if WANT_INTERRUPT:
|
1359 |
+
logger.info("Training interrupted.")
|
1360 |
+
yield f"Interrupted by user. LoRA saved to `{lora_file_path}`.", return_pd
|
1361 |
+
else:
|
1362 |
+
logger.info("Training complete!")
|
1363 |
+
yield f"Done! LoRA saved to `{lora_file_path}`.\n\nBefore testing your new LoRA, make sure to first reload the model, as it is currently dirty from training.", return_pd
|
1364 |
+
|
1365 |
+
create_graph(lora_file_path, lora_name)
|
1366 |
+
|
1367 |
+
def format_time(seconds: float):
|
1368 |
+
if seconds < 120:
|
1369 |
+
return f"`{seconds:.0f}` seconds"
|
1370 |
+
|
1371 |
+
minutes = seconds / 60
|
1372 |
+
if minutes < 120:
|
1373 |
+
return f"`{minutes:.0f}` minutes"
|
1374 |
+
|
1375 |
+
hours = minutes / 60
|
1376 |
+
return f"`{hours:.0f}` hours"
|
$extensions/Training_PRO/train_utils.py
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from modules import shared, utils
|
3 |
+
from pathlib import Path
|
4 |
+
import requests
|
5 |
+
import tqdm
|
6 |
+
import json
|
7 |
+
|
8 |
+
'''
|
9 |
+
def get_gpu_memory_usage(rank):
|
10 |
+
return {
|
11 |
+
'total': round(torch.cuda.get_device_properties(rank).total_memory / (1024**3), 2),
|
12 |
+
'max': round(torch.cuda.max_memory_allocated(rank) / (1024**3), 2),
|
13 |
+
'reserved': round(torch.cuda.memory_reserved(rank) / (1024**3), 2),
|
14 |
+
'allocated': round(torch.cuda.memory_allocated(rank) / (1024**3), 2)
|
15 |
+
}
|
16 |
+
'''
|
17 |
+
|
18 |
+
def list_subfoldersByTime(directory):
|
19 |
+
|
20 |
+
if not directory.endswith('/'):
|
21 |
+
directory += '/'
|
22 |
+
subfolders = []
|
23 |
+
subfolders.append('None')
|
24 |
+
path = directory
|
25 |
+
name_list = os.listdir(path)
|
26 |
+
full_list = [os.path.join(path,i) for i in name_list]
|
27 |
+
time_sorted_list = sorted(full_list, key=os.path.getmtime,reverse=True)
|
28 |
+
|
29 |
+
for entry in time_sorted_list:
|
30 |
+
if os.path.isdir(entry):
|
31 |
+
entry_str = f"{entry}" # Convert entry to a string
|
32 |
+
full_path = entry_str
|
33 |
+
entry_str = entry_str.replace('\\','/')
|
34 |
+
entry_str = entry_str.replace(f"{directory}", "") # Remove directory part
|
35 |
+
subfolders.append(entry_str)
|
36 |
+
|
37 |
+
return subfolders
|
38 |
+
|
39 |
+
def get_available_loras_local(_sortedByTime):
|
40 |
+
|
41 |
+
model_dir = shared.args.lora_dir # Update with the appropriate directory path
|
42 |
+
subfolders = []
|
43 |
+
if _sortedByTime:
|
44 |
+
subfolders = list_subfoldersByTime(model_dir)
|
45 |
+
else:
|
46 |
+
subfolders = utils.get_available_loras()
|
47 |
+
|
48 |
+
return subfolders
|
49 |
+
|
50 |
+
|
51 |
+
# FPHAM SPLIT BY SENTENCE BLOCK ===============
|
52 |
+
|
53 |
+
def split_sentences(text: str, cutoff_len: int):
|
54 |
+
sentences = []
|
55 |
+
sentence = ''
|
56 |
+
delimiters = ['. ', '? ', '! ', '... ', '.\n', '?\n', '!\n','...\n','</s>','<//>']
|
57 |
+
abbreviations = ['Mr. ', 'Mrs. ', 'Dr. ', 'Ms. ', 'St. ', 'Prof. ', 'Jr. ', 'Ltd. ', 'Capt. ', 'Col. ', 'Gen. ', 'Ave. ', 'Blvd. ', 'Co. ', 'Corp. ', 'Dept. ', 'Est. ', 'Gov. ', 'Inc. ', 'Ph.D. ', 'Univ. ']
|
58 |
+
errors = 0
|
59 |
+
max_cut = cutoff_len-1
|
60 |
+
prev_char = ''
|
61 |
+
|
62 |
+
for char in text:
|
63 |
+
sentence += char
|
64 |
+
|
65 |
+
|
66 |
+
if (any(sentence.endswith(delimiter) for delimiter in delimiters) and
|
67 |
+
not (prev_char.isupper() and len(sentence) >= 3 and sentence[-3] != ' ') and
|
68 |
+
not any(sentence.endswith(abbreviation) for abbreviation in abbreviations)):
|
69 |
+
tokens = shared.tokenizer.encode(sentence)
|
70 |
+
|
71 |
+
if len(tokens) > max_cut:
|
72 |
+
tokens = tokens[:max_cut]
|
73 |
+
sentence = shared.tokenizer.decode(tokens, skip_special_tokens=True)
|
74 |
+
errors = errors + 1
|
75 |
+
|
76 |
+
sentences.append({'text': sentence, 'size': len(tokens)})
|
77 |
+
|
78 |
+
sentence = ''
|
79 |
+
|
80 |
+
prev_char = char
|
81 |
+
|
82 |
+
if sentence:
|
83 |
+
tokens = shared.tokenizer.encode(sentence)
|
84 |
+
if len(tokens) > max_cut:
|
85 |
+
tokens = tokens[:max_cut]
|
86 |
+
sentence = shared.tokenizer.decode(tokens, skip_special_tokens=True)
|
87 |
+
errors = errors + 1
|
88 |
+
|
89 |
+
sentences.append({'text': sentence, 'size': len(tokens)})
|
90 |
+
|
91 |
+
if errors > 0:
|
92 |
+
print(f"Trimmed sentences beyond Cutoff Length: {errors}")
|
93 |
+
|
94 |
+
return sentences
|
95 |
+
|
96 |
+
# The goal of following code is to create blocks of text + overlapping blocks while:
|
97 |
+
# respects sentence boundaries
|
98 |
+
# always uses all the text
|
99 |
+
# hard cut defined by hard_cut_string or </s> will always end at the end of data block
|
100 |
+
# no overlapping blocks will be created across hard cut or across </s> token
|
101 |
+
|
102 |
+
def precise_cut(text: str, overlap: bool, min_chars_cut: int, eos_to_hc: bool, cutoff_len: int, hard_cut_string: str, debug_slicer:bool):
|
103 |
+
|
104 |
+
EOSX_str = '<//>' #hardcut placeholder
|
105 |
+
EOS_str = '</s>'
|
106 |
+
print("Precise raw text slicer: ON")
|
107 |
+
|
108 |
+
cut_string = hard_cut_string.replace('\\n', '\n')
|
109 |
+
text = text.replace(cut_string, EOSX_str)
|
110 |
+
sentences = split_sentences(text, cutoff_len)
|
111 |
+
|
112 |
+
print(f"Sentences: {len(sentences)}")
|
113 |
+
sentencelist = []
|
114 |
+
currentSentence = ''
|
115 |
+
totalLength = 0
|
116 |
+
max_cut = cutoff_len-1
|
117 |
+
half_cut = cutoff_len//2
|
118 |
+
halfcut_length = 0
|
119 |
+
|
120 |
+
edgeindex = []
|
121 |
+
half_index = 0
|
122 |
+
|
123 |
+
for index, item in enumerate(sentences):
|
124 |
+
|
125 |
+
if halfcut_length+ item['size'] < half_cut:
|
126 |
+
halfcut_length += item['size']
|
127 |
+
half_index = index
|
128 |
+
else:
|
129 |
+
edgeindex.append(half_index)
|
130 |
+
halfcut_length = -2 * max_cut
|
131 |
+
|
132 |
+
|
133 |
+
if totalLength + item['size'] < max_cut and not currentSentence.endswith(EOSX_str):
|
134 |
+
currentSentence += item['text']
|
135 |
+
totalLength += item['size']
|
136 |
+
else:
|
137 |
+
|
138 |
+
if len(currentSentence.strip()) > min_chars_cut:
|
139 |
+
sentencelist.append(currentSentence.strip())
|
140 |
+
|
141 |
+
currentSentence = item['text']
|
142 |
+
totalLength = item['size']
|
143 |
+
halfcut_length = item['size']
|
144 |
+
|
145 |
+
if len(currentSentence.strip()) > min_chars_cut:
|
146 |
+
sentencelist.append(currentSentence.strip())
|
147 |
+
|
148 |
+
unique_blocks = len(sentencelist)
|
149 |
+
print(f"Text Blocks: {unique_blocks}")
|
150 |
+
|
151 |
+
#overlap strategies:
|
152 |
+
# don't overlap across HARD CUT (EOSX)
|
153 |
+
if overlap:
|
154 |
+
for edge_idx in edgeindex:
|
155 |
+
currentSentence = ''
|
156 |
+
totalLength = 0
|
157 |
+
|
158 |
+
for item in sentences[edge_idx:]:
|
159 |
+
if totalLength + item['size'] < max_cut:
|
160 |
+
currentSentence += item['text']
|
161 |
+
totalLength += item['size']
|
162 |
+
else:
|
163 |
+
#if by chance EOSX is at the end then it's acceptable
|
164 |
+
if currentSentence.endswith(EOSX_str) and len(currentSentence.strip()) > min_chars_cut:
|
165 |
+
sentencelist.append(currentSentence.strip())
|
166 |
+
# otherwise don't cross hard cut
|
167 |
+
elif EOSX_str not in currentSentence and len(currentSentence.strip()) > min_chars_cut:
|
168 |
+
sentencelist.append(currentSentence.strip())
|
169 |
+
|
170 |
+
currentSentence = ''
|
171 |
+
totalLength = 0
|
172 |
+
break
|
173 |
+
|
174 |
+
print(f"+ Overlapping blocks: {len(sentencelist)-unique_blocks}")
|
175 |
+
|
176 |
+
num_EOS = 0
|
177 |
+
for i in range(len(sentencelist)):
|
178 |
+
if eos_to_hc:
|
179 |
+
sentencelist[i] = sentencelist[i].replace(EOSX_str, EOS_str)
|
180 |
+
else:
|
181 |
+
sentencelist[i] = sentencelist[i].replace(EOSX_str, '')
|
182 |
+
|
183 |
+
#someone may have had stop strings in the raw text...
|
184 |
+
sentencelist[i] = sentencelist[i].replace("</s></s>", EOS_str)
|
185 |
+
num_EOS += sentencelist[i].count(EOS_str)
|
186 |
+
|
187 |
+
if num_EOS > 0:
|
188 |
+
print(f"+ EOS count: {num_EOS}")
|
189 |
+
|
190 |
+
#final check for useless lines
|
191 |
+
sentencelist = [item for item in sentencelist if item.strip() != "</s>"]
|
192 |
+
sentencelist = [item for item in sentencelist if item.strip() != ""]
|
193 |
+
|
194 |
+
|
195 |
+
if debug_slicer:
|
196 |
+
# Write the log file
|
197 |
+
Path('logs').mkdir(exist_ok=True)
|
198 |
+
sentencelist_dict = {index: sentence for index, sentence in enumerate(sentencelist)}
|
199 |
+
output_file = "logs/sentencelist.json"
|
200 |
+
with open(output_file, 'w') as f:
|
201 |
+
json.dump(sentencelist_dict, f,indent=2)
|
202 |
+
|
203 |
+
print("Saved sentencelist.json in logs folder")
|
204 |
+
|
205 |
+
return sentencelist
|
206 |
+
|
207 |
+
|
208 |
+
def sliding_block_cut(text: str, min_chars_cut: int, eos_to_hc: bool, cutoff_len: int, hard_cut_string: str, debug_slicer:bool):
|
209 |
+
|
210 |
+
EOSX_str = '<//>' #hardcut placeholder
|
211 |
+
EOS_str = '</s>'
|
212 |
+
print("Mega Block Overlap: ON")
|
213 |
+
|
214 |
+
cut_string = hard_cut_string.replace('\\n', '\n')
|
215 |
+
text = text.replace(cut_string, EOSX_str)
|
216 |
+
sentences = split_sentences(text, cutoff_len)
|
217 |
+
|
218 |
+
print(f"Sentences: {len(sentences)}")
|
219 |
+
sentencelist = []
|
220 |
+
|
221 |
+
max_cut = cutoff_len-1
|
222 |
+
|
223 |
+
#print(f"max_cut: {max_cut}")
|
224 |
+
advancing_to = 0
|
225 |
+
|
226 |
+
prev_block_lastsentence = ""
|
227 |
+
|
228 |
+
|
229 |
+
for i in range(len(sentences)):
|
230 |
+
totalLength = 0
|
231 |
+
currentSentence = ''
|
232 |
+
lastsentence = ""
|
233 |
+
|
234 |
+
if i >= advancing_to:
|
235 |
+
for k in range(i, len(sentences)):
|
236 |
+
|
237 |
+
current_length = sentences[k]['size']
|
238 |
+
|
239 |
+
if totalLength + current_length <= max_cut and not currentSentence.endswith(EOSX_str):
|
240 |
+
currentSentence += sentences[k]['text']
|
241 |
+
totalLength += current_length
|
242 |
+
lastsentence = sentences[k]['text']
|
243 |
+
else:
|
244 |
+
if len(currentSentence.strip()) > min_chars_cut:
|
245 |
+
if prev_block_lastsentence!=lastsentence:
|
246 |
+
sentencelist.append(currentSentence.strip())
|
247 |
+
prev_block_lastsentence = lastsentence
|
248 |
+
|
249 |
+
advancing_to = 0
|
250 |
+
if currentSentence.endswith(EOSX_str):
|
251 |
+
advancing_to = k
|
252 |
+
|
253 |
+
currentSentence = ""
|
254 |
+
totalLength = 0
|
255 |
+
break
|
256 |
+
|
257 |
+
if currentSentence != "":
|
258 |
+
if len(currentSentence.strip()) > min_chars_cut:
|
259 |
+
sentencelist.append(currentSentence.strip())
|
260 |
+
|
261 |
+
unique_blocks = len(sentencelist)
|
262 |
+
print(f"Text Blocks: {unique_blocks}")
|
263 |
+
num_EOS = 0
|
264 |
+
for i in range(len(sentencelist)):
|
265 |
+
if eos_to_hc:
|
266 |
+
sentencelist[i] = sentencelist[i].replace(EOSX_str, EOS_str)
|
267 |
+
else:
|
268 |
+
sentencelist[i] = sentencelist[i].replace(EOSX_str, '')
|
269 |
+
|
270 |
+
#someone may have had stop strings in the raw text...
|
271 |
+
sentencelist[i] = sentencelist[i].replace("</s></s>", EOS_str)
|
272 |
+
num_EOS += sentencelist[i].count(EOS_str)
|
273 |
+
|
274 |
+
if num_EOS > 0:
|
275 |
+
print(f"+ EOS count: {num_EOS}")
|
276 |
+
|
277 |
+
#final check for useless lines
|
278 |
+
sentencelist = [item for item in sentencelist if item.strip() != "</s>"]
|
279 |
+
sentencelist = [item for item in sentencelist if item.strip() != ""]
|
280 |
+
|
281 |
+
|
282 |
+
if debug_slicer:
|
283 |
+
# Write the log file
|
284 |
+
Path('logs').mkdir(exist_ok=True)
|
285 |
+
sentencelist_dict = {index: sentence for index, sentence in enumerate(sentencelist)}
|
286 |
+
output_file = "logs/sentencelist.json"
|
287 |
+
with open(output_file, 'w') as f:
|
288 |
+
json.dump(sentencelist_dict, f,indent=2)
|
289 |
+
|
290 |
+
print("Saved sentencelist.json in logs folder")
|
291 |
+
|
292 |
+
return sentencelist
|
293 |
+
|
294 |
+
# Example usage:
|
295 |
+
# download_file_from_url('https://example.com/path/to/your/file.ext', '/output/directory')
|
296 |
+
|
297 |
+
def download_file_from_url(url, overwrite, output_dir_in, valid_extensions = {'.txt', '.json'}):
|
298 |
+
try:
|
299 |
+
# Validate and sanitize the URL
|
300 |
+
#parsed_url = urllib.parse.urlparse(url)
|
301 |
+
#if not parsed_url.netloc:
|
302 |
+
# raise ValueError("Invalid URL")
|
303 |
+
#filename = os.path.basename(parsed_url.path)
|
304 |
+
|
305 |
+
# Get the filename from the URL
|
306 |
+
|
307 |
+
session = requests.Session()
|
308 |
+
headers = {}
|
309 |
+
mode = 'wb'
|
310 |
+
filename = url.split('/')[-1]
|
311 |
+
|
312 |
+
output_dir = str(output_dir_in)
|
313 |
+
# Construct the full path to the output file
|
314 |
+
local_filename = os.path.join(output_dir, filename)
|
315 |
+
|
316 |
+
# Check if the local file already exists
|
317 |
+
overw = ''
|
318 |
+
if os.path.exists(local_filename):
|
319 |
+
if not overwrite:
|
320 |
+
yield f"File '{local_filename}' already exists. Aborting."
|
321 |
+
return
|
322 |
+
else:
|
323 |
+
overw = ' [Overwrite existing]'
|
324 |
+
|
325 |
+
filename_lower = filename.lower()
|
326 |
+
|
327 |
+
# Send an HTTP GET request to the URL with a timeout
|
328 |
+
file_extension = os.path.splitext(filename_lower)[-1]
|
329 |
+
|
330 |
+
if file_extension not in valid_extensions:
|
331 |
+
yield f"Invalid file extension: {file_extension}. Only {valid_extensions} files are supported."
|
332 |
+
return
|
333 |
+
|
334 |
+
with session.get(url, stream=True, headers=headers, timeout=10) as r:
|
335 |
+
r.raise_for_status()
|
336 |
+
# total size can be wildly inaccurate
|
337 |
+
#total_size = int(r.headers.get('content-length', 0))
|
338 |
+
|
339 |
+
block_size = 1024 * 4
|
340 |
+
with open(local_filename, mode) as f:
|
341 |
+
count = 0
|
342 |
+
for data in r.iter_content(block_size):
|
343 |
+
f.write(data)
|
344 |
+
count += len(data)
|
345 |
+
|
346 |
+
yield f"Downloaded: {count} " + overw
|
347 |
+
|
348 |
+
# Verify file size if possible
|
349 |
+
if os.path.exists(local_filename):
|
350 |
+
downloaded_size = os.path.getsize(local_filename)
|
351 |
+
if downloaded_size > 0:
|
352 |
+
yield f"File '{filename}' downloaded to '{output_dir}' ({downloaded_size} bytes)."
|
353 |
+
print("File Downloaded")
|
354 |
+
else:
|
355 |
+
print("Downloaded file is zero")
|
356 |
+
yield f"Failed. Downloaded file size is zero)."
|
357 |
+
else:
|
358 |
+
print(f"Error: {local_filename} failed to download.")
|
359 |
+
yield f"Error: {local_filename} failed to download"
|
360 |
+
|
361 |
+
except Exception as e:
|
362 |
+
print(f"An error occurred: {e}")
|
363 |
+
yield f"An error occurred: {e}"
|
364 |
+
|
365 |
+
finally:
|
366 |
+
# Close the session to release resources
|
367 |
+
session.close()
|
368 |
+
|
$extensions/character_bias/script.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
|
5 |
+
# get the current directory of the script
|
6 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
7 |
+
|
8 |
+
# check if the bias_options.txt file exists, if not, create it
|
9 |
+
bias_file = os.path.join(current_dir, "bias_options.txt")
|
10 |
+
if not os.path.isfile(bias_file):
|
11 |
+
with open(bias_file, "w") as f:
|
12 |
+
f.write("*I am so happy*\n*I am so sad*\n*I am so excited*\n*I am so bored*\n*I am so angry*")
|
13 |
+
|
14 |
+
# read bias options from the text file
|
15 |
+
with open(bias_file, "r") as f:
|
16 |
+
bias_options = [line.strip() for line in f.readlines()]
|
17 |
+
|
18 |
+
params = {
|
19 |
+
"activate": True,
|
20 |
+
"bias string": " *I am so happy*",
|
21 |
+
"use custom string": False,
|
22 |
+
}
|
23 |
+
|
24 |
+
|
25 |
+
def input_modifier(string):
|
26 |
+
"""
|
27 |
+
This function is applied to your text inputs before
|
28 |
+
they are fed into the model.
|
29 |
+
"""
|
30 |
+
return string
|
31 |
+
|
32 |
+
|
33 |
+
def output_modifier(string):
|
34 |
+
"""
|
35 |
+
This function is applied to the model outputs.
|
36 |
+
"""
|
37 |
+
return string
|
38 |
+
|
39 |
+
|
40 |
+
def bot_prefix_modifier(string):
|
41 |
+
"""
|
42 |
+
This function is only applied in chat mode. It modifies
|
43 |
+
the prefix text for the Bot and can be used to bias its
|
44 |
+
behavior.
|
45 |
+
"""
|
46 |
+
if params['activate']:
|
47 |
+
if params['use custom string']:
|
48 |
+
return f'{string} {params["custom string"].strip()} '
|
49 |
+
else:
|
50 |
+
return f'{string} {params["bias string"].strip()} '
|
51 |
+
else:
|
52 |
+
return string
|
53 |
+
|
54 |
+
|
55 |
+
def ui():
|
56 |
+
# Gradio elements
|
57 |
+
activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
|
58 |
+
dropdown_string = gr.Dropdown(choices=bias_options, value=params["bias string"], label='Character bias', info='To edit the options in this dropdown edit the "bias_options.txt" file')
|
59 |
+
use_custom_string = gr.Checkbox(value=False, label='Use custom bias textbox instead of dropdown')
|
60 |
+
custom_string = gr.Textbox(value="", placeholder="Enter custom bias string", label="Custom Character Bias", info='To use this textbox activate the checkbox above')
|
61 |
+
|
62 |
+
# Event functions to update the parameters in the backend
|
63 |
+
def update_bias_string(x):
|
64 |
+
if x:
|
65 |
+
params.update({"bias string": x})
|
66 |
+
else:
|
67 |
+
params.update({"bias string": dropdown_string.get()})
|
68 |
+
return x
|
69 |
+
|
70 |
+
def update_custom_string(x):
|
71 |
+
params.update({"custom string": x})
|
72 |
+
|
73 |
+
dropdown_string.change(update_bias_string, dropdown_string, None)
|
74 |
+
custom_string.change(update_custom_string, custom_string, None)
|
75 |
+
activate.change(lambda x: params.update({"activate": x}), activate, None)
|
76 |
+
use_custom_string.change(lambda x: params.update({"use custom string": x}), use_custom_string, None)
|
77 |
+
|
78 |
+
# Group elements together depending on the selected option
|
79 |
+
def bias_string_group():
|
80 |
+
if use_custom_string.value:
|
81 |
+
return gr.Group([use_custom_string, custom_string])
|
82 |
+
else:
|
83 |
+
return dropdown_string
|
$extensions/coqui_tts/harvard_sentences.txt
ADDED
@@ -0,0 +1,720 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The birch canoe slid on the smooth planks.
|
2 |
+
Glue the sheet to the dark blue background.
|
3 |
+
It's easy to tell the depth of a well.
|
4 |
+
These days a chicken leg is a rare dish.
|
5 |
+
Rice is often served in round bowls.
|
6 |
+
The juice of lemons makes fine punch.
|
7 |
+
The box was thrown beside the parked truck.
|
8 |
+
The hogs were fed chopped corn and garbage.
|
9 |
+
Four hours of steady work faced us.
|
10 |
+
A large size in stockings is hard to sell.
|
11 |
+
The boy was there when the sun rose.
|
12 |
+
A rod is used to catch pink salmon.
|
13 |
+
The source of the huge river is the clear spring.
|
14 |
+
Kick the ball straight and follow through.
|
15 |
+
Help the woman get back to her feet.
|
16 |
+
A pot of tea helps to pass the evening.
|
17 |
+
Smoky fires lack flame and heat.
|
18 |
+
The soft cushion broke the man's fall.
|
19 |
+
The salt breeze came across from the sea.
|
20 |
+
The girl at the booth sold fifty bonds.
|
21 |
+
The small pup gnawed a hole in the sock.
|
22 |
+
The fish twisted and turned on the bent hook.
|
23 |
+
Press the pants and sew a button on the vest.
|
24 |
+
The swan dive was far short of perfect.
|
25 |
+
The beauty of the view stunned the young boy.
|
26 |
+
Two blue fish swam in the tank.
|
27 |
+
Her purse was full of useless trash.
|
28 |
+
The colt reared and threw the tall rider.
|
29 |
+
It snowed, rained, and hailed the same morning.
|
30 |
+
Read verse out loud for pleasure.
|
31 |
+
Hoist the load to your left shoulder.
|
32 |
+
Take the winding path to reach the lake.
|
33 |
+
Note closely the size of the gas tank.
|
34 |
+
Wipe the grease off his dirty face.
|
35 |
+
Mend the coat before you go out.
|
36 |
+
The wrist was badly strained and hung limp.
|
37 |
+
The stray cat gave birth to kittens.
|
38 |
+
The young girl gave no clear response.
|
39 |
+
The meal was cooked before the bell rang.
|
40 |
+
What joy there is in living.
|
41 |
+
A king ruled the state in the early days.
|
42 |
+
The ship was torn apart on the sharp reef.
|
43 |
+
Sickness kept him home the third week.
|
44 |
+
The wide road shimmered in the hot sun.
|
45 |
+
The lazy cow lay in the cool grass.
|
46 |
+
Lift the square stone over the fence.
|
47 |
+
The rope will bind the seven books at once.
|
48 |
+
Hop over the fence and plunge in.
|
49 |
+
The friendly gang left the drug store.
|
50 |
+
Mesh wire keeps chicks inside.
|
51 |
+
The frosty air passed through the coat.
|
52 |
+
The crooked maze failed to fool the mouse.
|
53 |
+
Adding fast leads to wrong sums.
|
54 |
+
The show was a flop from the very start.
|
55 |
+
A saw is a tool used for making boards.
|
56 |
+
The wagon moved on well oiled wheels.
|
57 |
+
March the soldiers past the next hill.
|
58 |
+
A cup of sugar makes sweet fudge.
|
59 |
+
Place a rosebush near the porch steps.
|
60 |
+
Both lost their lives in the raging storm.
|
61 |
+
We talked of the side show in the circus.
|
62 |
+
Use a pencil to write the first draft.
|
63 |
+
He ran half way to the hardware store.
|
64 |
+
The clock struck to mark the third period.
|
65 |
+
A small creek cut across the field.
|
66 |
+
Cars and busses stalled in snow drifts.
|
67 |
+
The set of china hit the floor with a crash.
|
68 |
+
This is a grand season for hikes on the road.
|
69 |
+
The dune rose from the edge of the water.
|
70 |
+
Those words were the cue for the actor to leave.
|
71 |
+
A yacht slid around the point into the bay.
|
72 |
+
The two met while playing on the sand.
|
73 |
+
The ink stain dried on the finished page.
|
74 |
+
The walled town was seized without a fight.
|
75 |
+
The lease ran out in sixteen weeks.
|
76 |
+
A tame squirrel makes a nice pet.
|
77 |
+
The horn of the car woke the sleeping cop.
|
78 |
+
The heart beat strongly and with firm strokes.
|
79 |
+
The pearl was worn in a thin silver ring.
|
80 |
+
The fruit peel was cut in thick slices.
|
81 |
+
The Navy attacked the big task force.
|
82 |
+
See the cat glaring at the scared mouse.
|
83 |
+
There are more than two factors here.
|
84 |
+
The hat brim was wide and too droopy.
|
85 |
+
The lawyer tried to lose his case.
|
86 |
+
The grass curled around the fence post.
|
87 |
+
Cut the pie into large parts.
|
88 |
+
Men strive but seldom get rich.
|
89 |
+
Always close the barn door tight.
|
90 |
+
He lay prone and hardly moved a limb.
|
91 |
+
The slush lay deep along the street.
|
92 |
+
A wisp of cloud hung in the blue air.
|
93 |
+
A pound of sugar costs more than eggs.
|
94 |
+
The fin was sharp and cut the clear water.
|
95 |
+
The play seems dull and quite stupid.
|
96 |
+
Bail the boat to stop it from sinking.
|
97 |
+
The term ended in late June that year.
|
98 |
+
A tusk is used to make costly gifts.
|
99 |
+
Ten pins were set in order.
|
100 |
+
The bill was paid every third week.
|
101 |
+
Oak is strong and also gives shade.
|
102 |
+
Cats and dogs each hate the other.
|
103 |
+
The pipe began to rust while new.
|
104 |
+
Open the crate but don't break the glass.
|
105 |
+
Add the sum to the product of these three.
|
106 |
+
Thieves who rob friends deserve jail.
|
107 |
+
The ripe taste of cheese improves with age.
|
108 |
+
Act on these orders with great speed.
|
109 |
+
The hog crawled under the high fence.
|
110 |
+
Move the vat over the hot fire.
|
111 |
+
The bark of the pine tree was shiny and dark.
|
112 |
+
Leaves turn brown and yellow in the fall.
|
113 |
+
The pennant waved when the wind blew.
|
114 |
+
Split the log with a quick, sharp blow.
|
115 |
+
Burn peat after the logs give out.
|
116 |
+
He ordered peach pie with ice cream.
|
117 |
+
Weave the carpet on the right hand side.
|
118 |
+
Hemp is a weed found in parts of the tropics.
|
119 |
+
A lame back kept his score low.
|
120 |
+
We find joy in the simplest things.
|
121 |
+
Type out three lists of orders.
|
122 |
+
The harder he tried the less he got done.
|
123 |
+
The boss ran the show with a watchful eye.
|
124 |
+
The cup cracked and spilled its contents.
|
125 |
+
Paste can cleanse the most dirty brass.
|
126 |
+
The slang word for raw whiskey is booze.
|
127 |
+
It caught its hind paw in a rusty trap.
|
128 |
+
The wharf could be seen at the farther shore.
|
129 |
+
Feel the heat of the weak dying flame.
|
130 |
+
The tiny girl took off her hat.
|
131 |
+
A cramp is no small danger on a swim.
|
132 |
+
He said the same phrase thirty times.
|
133 |
+
Pluck the bright rose without leaves.
|
134 |
+
Two plus seven is less than ten.
|
135 |
+
The glow deepened in the eyes of the sweet girl.
|
136 |
+
Bring your problems to the wise chief.
|
137 |
+
Write a fond note to the friend you cherish.
|
138 |
+
Clothes and lodging are free to new men.
|
139 |
+
We frown when events take a bad turn.
|
140 |
+
Port is a strong wine with a smoky taste.
|
141 |
+
The young kid jumped the rusty gate.
|
142 |
+
Guess the results from the first scores.
|
143 |
+
A salt pickle tastes fine with ham.
|
144 |
+
The just claim got the right verdict.
|
145 |
+
These thistles bend in a high wind.
|
146 |
+
Pure bred poodles have curls.
|
147 |
+
The tree top waved in a graceful way.
|
148 |
+
The spot on the blotter was made by green ink.
|
149 |
+
Mud was spattered on the front of his white shirt.
|
150 |
+
The cigar burned a hole in the desk top.
|
151 |
+
The empty flask stood on the tin tray.
|
152 |
+
A speedy man can beat this track mark.
|
153 |
+
He broke a new shoelace that day.
|
154 |
+
The coffee stand is too high for the couch.
|
155 |
+
The urge to write short stories is rare.
|
156 |
+
The pencils have all been used.
|
157 |
+
The pirates seized the crew of the lost ship.
|
158 |
+
We tried to replace the coin but failed.
|
159 |
+
She sewed the torn coat quite neatly.
|
160 |
+
The sofa cushion is red and of light weight.
|
161 |
+
The jacket hung on the back of the wide chair.
|
162 |
+
At that high level the air is pure.
|
163 |
+
Drop the two when you add the figures.
|
164 |
+
A filing case is now hard to buy.
|
165 |
+
An abrupt start does not win the prize.
|
166 |
+
Wood is best for making toys and blocks.
|
167 |
+
The office paint was a dull, sad tan.
|
168 |
+
He knew the skill of the great young actress.
|
169 |
+
A rag will soak up spilled water.
|
170 |
+
A shower of dirt fell from the hot pipes.
|
171 |
+
Steam hissed from the broken valve.
|
172 |
+
The child almost hurt the small dog.
|
173 |
+
There was a sound of dry leaves outside.
|
174 |
+
The sky that morning was clear and bright blue.
|
175 |
+
Torn scraps littered the stone floor.
|
176 |
+
Sunday is the best part of the week.
|
177 |
+
The doctor cured him with these pills.
|
178 |
+
The new girl was fired today at noon.
|
179 |
+
They felt gay when the ship arrived in port.
|
180 |
+
Add the store's account to the last cent.
|
181 |
+
Acid burns holes in wool cloth.
|
182 |
+
Fairy tales should be fun to write.
|
183 |
+
Eight miles of woodland burned to waste.
|
184 |
+
The third act was dull and tired the players.
|
185 |
+
A young child should not suffer fright.
|
186 |
+
Add the column and put the sum here.
|
187 |
+
We admire and love a good cook.
|
188 |
+
There the flood mark is ten inches.
|
189 |
+
He carved a head from the round block of marble.
|
190 |
+
She has a smart way of wearing clothes.
|
191 |
+
The fruit of a fig tree is apple-shaped.
|
192 |
+
Corn cobs can be used to kindle a fire.
|
193 |
+
Where were they when the noise started.
|
194 |
+
The paper box is full of thumb tacks.
|
195 |
+
Sell your gift to a buyer at a good gain.
|
196 |
+
The tongs lay beside the ice pail.
|
197 |
+
The petals fall with the next puff of wind.
|
198 |
+
Bring your best compass to the third class.
|
199 |
+
They could laugh although they were sad.
|
200 |
+
Farmers came in to thresh the oat crop.
|
201 |
+
The brown house was on fire to the attic.
|
202 |
+
The lure is used to catch trout and flounder.
|
203 |
+
Float the soap on top of the bath water.
|
204 |
+
A blue crane is a tall wading bird.
|
205 |
+
A fresh start will work such wonders.
|
206 |
+
The club rented the rink for the fifth night.
|
207 |
+
After the dance, they went straight home.
|
208 |
+
The hostess taught the new maid to serve.
|
209 |
+
He wrote his last novel there at the inn.
|
210 |
+
Even the worst will beat his low score.
|
211 |
+
The cement had dried when he moved it.
|
212 |
+
The loss of the second ship was hard to take.
|
213 |
+
The fly made its way along the wall.
|
214 |
+
Do that with a wooden stick.
|
215 |
+
Live wires should be kept covered.
|
216 |
+
The large house had hot water taps.
|
217 |
+
It is hard to erase blue or red ink.
|
218 |
+
Write at once or you may forget it.
|
219 |
+
The doorknob was made of bright clean brass.
|
220 |
+
The wreck occurred by the bank on Main Street.
|
221 |
+
A pencil with black lead writes best.
|
222 |
+
Coax a young calf to drink from a bucket.
|
223 |
+
Schools for ladies teach charm and grace.
|
224 |
+
The lamp shone with a steady green flame.
|
225 |
+
They took the axe and the saw to the forest.
|
226 |
+
The ancient coin was quite dull and worn.
|
227 |
+
The shaky barn fell with a loud crash.
|
228 |
+
Jazz and swing fans like fast music.
|
229 |
+
Rake the rubbish up and then burn it.
|
230 |
+
Slash the gold cloth into fine ribbons.
|
231 |
+
Try to have the court decide the case.
|
232 |
+
They are pushed back each time they attack.
|
233 |
+
He broke his ties with groups of former friends.
|
234 |
+
They floated on the raft to sun their white backs.
|
235 |
+
The map had an X that meant nothing.
|
236 |
+
Whitings are small fish caught in nets.
|
237 |
+
Some ads serve to cheat buyers.
|
238 |
+
Jerk the rope and the bell rings weakly.
|
239 |
+
A waxed floor makes us lose balance.
|
240 |
+
Madam, this is the best brand of corn.
|
241 |
+
On the islands the sea breeze is soft and mild.
|
242 |
+
The play began as soon as we sat down.
|
243 |
+
This will lead the world to more sound and fury.
|
244 |
+
Add salt before you fry the egg.
|
245 |
+
The rush for funds reached its peak Tuesday.
|
246 |
+
The birch looked stark white and lonesome.
|
247 |
+
The box is held by a bright red snapper.
|
248 |
+
To make pure ice, you freeze water.
|
249 |
+
The first worm gets snapped early.
|
250 |
+
Jump the fence and hurry up the bank.
|
251 |
+
Yell and clap as the curtain slides back.
|
252 |
+
They are men who walk the middle of the road.
|
253 |
+
Both brothers wear the same size.
|
254 |
+
In some form or other we need fun.
|
255 |
+
The prince ordered his head chopped off.
|
256 |
+
The houses are built of red clay bricks.
|
257 |
+
Ducks fly north but lack a compass.
|
258 |
+
Fruit flavors are used in fizz drinks.
|
259 |
+
These pills do less good than others.
|
260 |
+
Canned pears lack full flavor.
|
261 |
+
The dark pot hung in the front closet.
|
262 |
+
Carry the pail to the wall and spill it there.
|
263 |
+
The train brought our hero to the big town.
|
264 |
+
We are sure that one war is enough.
|
265 |
+
Gray paint stretched for miles around.
|
266 |
+
The rude laugh filled the empty room.
|
267 |
+
High seats are best for football fans.
|
268 |
+
Tea served from the brown jug is tasty.
|
269 |
+
A dash of pepper spoils beef stew.
|
270 |
+
A zestful food is the hot-cross bun.
|
271 |
+
The horse trotted around the field at a brisk pace.
|
272 |
+
Find the twin who stole the pearl necklace.
|
273 |
+
Cut the cord that binds the box tightly.
|
274 |
+
The red tape bound the smuggled food.
|
275 |
+
Look in the corner to find the tan shirt.
|
276 |
+
The cold drizzle will halt the bond drive.
|
277 |
+
Nine men were hired to dig the ruins.
|
278 |
+
The junk yard had a mouldy smell.
|
279 |
+
The flint sputtered and lit a pine torch.
|
280 |
+
Soak the cloth and drown the sharp odor.
|
281 |
+
The shelves were bare of both jam or crackers.
|
282 |
+
A joy to every child is the swan boat.
|
283 |
+
All sat frozen and watched the screen.
|
284 |
+
A cloud of dust stung his tender eyes.
|
285 |
+
To reach the end he needs much courage.
|
286 |
+
Shape the clay gently into block form.
|
287 |
+
A ridge on a smooth surface is a bump or flaw.
|
288 |
+
Hedge apples may stain your hands green.
|
289 |
+
Quench your thirst, then eat the crackers.
|
290 |
+
Tight curls get limp on rainy days.
|
291 |
+
The mute muffled the high tones of the horn.
|
292 |
+
The gold ring fits only a pierced ear.
|
293 |
+
The old pan was covered with hard fudge.
|
294 |
+
Watch the log float in the wide river.
|
295 |
+
The node on the stalk of wheat grew daily.
|
296 |
+
The heap of fallen leaves was set on fire.
|
297 |
+
Write fast if you want to finish early.
|
298 |
+
His shirt was clean but one button was gone.
|
299 |
+
The barrel of beer was a brew of malt and hops.
|
300 |
+
Tin cans are absent from store shelves.
|
301 |
+
Slide the box into that empty space.
|
302 |
+
The plant grew large and green in the window.
|
303 |
+
The beam dropped down on the workmen's head.
|
304 |
+
Pink clouds floated with the breeze.
|
305 |
+
She danced like a swan, tall and graceful.
|
306 |
+
The tube was blown and the tire flat and useless.
|
307 |
+
It is late morning on the old wall clock.
|
308 |
+
Let's all join as we sing the last chorus.
|
309 |
+
The last switch cannot be turned off.
|
310 |
+
The fight will end in just six minutes.
|
311 |
+
The store walls were lined with colored frocks.
|
312 |
+
The peace league met to discuss their plans.
|
313 |
+
The rise to fame of a person takes luck.
|
314 |
+
Paper is scarce, so write with much care.
|
315 |
+
The quick fox jumped on the sleeping cat.
|
316 |
+
The nozzle of the fire hose was bright brass.
|
317 |
+
Screw the round cap on as tight as needed.
|
318 |
+
Time brings us many changes.
|
319 |
+
The purple tie was ten years old.
|
320 |
+
Men think and plan and sometimes act.
|
321 |
+
Fill the ink jar with sticky glue.
|
322 |
+
He smoke a big pipe with strong contents.
|
323 |
+
We need grain to keep our mules healthy.
|
324 |
+
Pack the records in a neat thin case.
|
325 |
+
The crunch of feet in the snow was the only sound.
|
326 |
+
The copper bowl shone in the sun's rays.
|
327 |
+
Boards will warp unless kept dry.
|
328 |
+
The plush chair leaned against the wall.
|
329 |
+
Glass will clink when struck by metal.
|
330 |
+
Bathe and relax in the cool green grass.
|
331 |
+
Nine rows of soldiers stood in line.
|
332 |
+
The beach is dry and shallow at low tide.
|
333 |
+
The idea is to sew both edges straight.
|
334 |
+
The kitten chased the dog down the street.
|
335 |
+
Pages bound in cloth make a book.
|
336 |
+
Try to trace the fine lines of the painting.
|
337 |
+
Women form less than half of the group.
|
338 |
+
The zones merge in the central part of town.
|
339 |
+
A gem in the rough needs work to polish.
|
340 |
+
Code is used when secrets are sent.
|
341 |
+
Most of the news is easy for us to hear.
|
342 |
+
He used the lathe to make brass objects.
|
343 |
+
The vane on top of the pole revolved in the wind.
|
344 |
+
Mince pie is a dish served to children.
|
345 |
+
The clan gathered on each dull night.
|
346 |
+
Let it burn, it gives us warmth and comfort.
|
347 |
+
A castle built from sand fails to endure.
|
348 |
+
A child's wit saved the day for us.
|
349 |
+
Tack the strip of carpet to the worn floor.
|
350 |
+
Next Tuesday we must vote.
|
351 |
+
Pour the stew from the pot into the plate.
|
352 |
+
Each penny shone like new.
|
353 |
+
The man went to the woods to gather sticks.
|
354 |
+
The dirt piles were lines along the road.
|
355 |
+
The logs fell and tumbled into the clear stream.
|
356 |
+
Just hoist it up and take it away.
|
357 |
+
A ripe plum is fit for a king's palate.
|
358 |
+
Our plans right now are hazy.
|
359 |
+
Brass rings are sold by these natives.
|
360 |
+
It takes a good trap to capture a bear.
|
361 |
+
Feed the white mouse some flower seeds.
|
362 |
+
The thaw came early and freed the stream.
|
363 |
+
He took the lead and kept it the whole distance.
|
364 |
+
The key you designed will fit the lock.
|
365 |
+
Plead to the council to free the poor thief.
|
366 |
+
Better hash is made of rare beef.
|
367 |
+
This plank was made for walking on.
|
368 |
+
The lake sparkled in the red hot sun.
|
369 |
+
He crawled with care along the ledge.
|
370 |
+
Tend the sheep while the dog wanders.
|
371 |
+
It takes a lot of help to finish these.
|
372 |
+
Mark the spot with a sign painted red.
|
373 |
+
Take two shares as a fair profit.
|
374 |
+
The fur of cats goes by many names.
|
375 |
+
North winds bring colds and fevers.
|
376 |
+
He asks no person to vouch for him.
|
377 |
+
Go now and come here later.
|
378 |
+
A sash of gold silk will trim her dress.
|
379 |
+
Soap can wash most dirt away.
|
380 |
+
That move means the game is over.
|
381 |
+
He wrote down a long list of items.
|
382 |
+
A siege will crack the strong defense.
|
383 |
+
Grape juice and water mix well.
|
384 |
+
Roads are paved with sticky tar.
|
385 |
+
Fake stones shine but cost little.
|
386 |
+
The drip of the rain made a pleasant sound.
|
387 |
+
Smoke poured out of every crack.
|
388 |
+
Serve the hot rum to the tired heroes.
|
389 |
+
Much of the story makes good sense.
|
390 |
+
The sun came up to light the eastern sky.
|
391 |
+
Heave the line over the port side.
|
392 |
+
A lathe cuts and trims any wood.
|
393 |
+
It's a dense crowd in two distinct ways.
|
394 |
+
His hip struck the knee of the next player.
|
395 |
+
The stale smell of old beer lingers.
|
396 |
+
The desk was firm on the shaky floor.
|
397 |
+
It takes heat to bring out the odor.
|
398 |
+
Beef is scarcer than some lamb.
|
399 |
+
Raise the sail and steer the ship northward.
|
400 |
+
A cone costs five cents on Mondays.
|
401 |
+
A pod is what peas always grow in.
|
402 |
+
Jerk the dart from the cork target.
|
403 |
+
No cement will hold hard wood.
|
404 |
+
We now have a new base for shipping.
|
405 |
+
A list of names is carved around the base.
|
406 |
+
The sheep were led home by a dog.
|
407 |
+
Three for a dime, the young peddler cried.
|
408 |
+
The sense of smell is better than that of touch.
|
409 |
+
No hardship seemed to keep him sad.
|
410 |
+
Grace makes up for lack of beauty.
|
411 |
+
Nudge gently but wake her now.
|
412 |
+
The news struck doubt into restless minds.
|
413 |
+
Once we stood beside the shore.
|
414 |
+
A chink in the wall allowed a draft to blow.
|
415 |
+
Fasten two pins on each side.
|
416 |
+
A cold dip restores health and zest.
|
417 |
+
He takes the oath of office each March.
|
418 |
+
The sand drifts over the sill of the old house.
|
419 |
+
The point of the steel pen was bent and twisted.
|
420 |
+
There is a lag between thought and act.
|
421 |
+
Seed is needed to plant the spring corn.
|
422 |
+
Draw the chart with heavy black lines.
|
423 |
+
The boy owed his pal thirty cents.
|
424 |
+
The chap slipped into the crowd and was lost.
|
425 |
+
Hats are worn to tea and not to dinner.
|
426 |
+
The ramp led up to the wide highway.
|
427 |
+
Beat the dust from the rug onto the lawn.
|
428 |
+
Say it slowly but make it ring clear.
|
429 |
+
The straw nest housed five robins.
|
430 |
+
Screen the porch with woven straw mats.
|
431 |
+
This horse will nose his way to the finish.
|
432 |
+
The dry wax protects the deep scratch.
|
433 |
+
He picked up the dice for a second roll.
|
434 |
+
These coins will be needed to pay his debt.
|
435 |
+
The nag pulled the frail cart along.
|
436 |
+
Twist the valve and release hot steam.
|
437 |
+
The vamp of the shoe had a gold buckle.
|
438 |
+
The smell of burned rags itches my nose.
|
439 |
+
New pants lack cuffs and pockets.
|
440 |
+
The marsh will freeze when cold enough.
|
441 |
+
They slice the sausage thin with a knife.
|
442 |
+
The bloom of the rose lasts a few days.
|
443 |
+
A gray mare walked before the colt.
|
444 |
+
Breakfast buns are fine with a hot drink.
|
445 |
+
Bottles hold four kinds of rum.
|
446 |
+
The man wore a feather in his felt hat.
|
447 |
+
He wheeled the bike past the winding road.
|
448 |
+
Drop the ashes on the worn old rug.
|
449 |
+
The desk and both chairs were painted tan.
|
450 |
+
Throw out the used paper cup and plate.
|
451 |
+
A clean neck means a neat collar.
|
452 |
+
The couch cover and hall drapes were blue.
|
453 |
+
The stems of the tall glasses cracked and broke.
|
454 |
+
The wall phone rang loud and often.
|
455 |
+
The clothes dried on a thin wooden rack.
|
456 |
+
Turn on the lantern which gives us light.
|
457 |
+
The cleat sank deeply into the soft turf.
|
458 |
+
The bills were mailed promptly on the tenth of the month.
|
459 |
+
To have is better than to wait and hope.
|
460 |
+
The price is fair for a good antique clock.
|
461 |
+
The music played on while they talked.
|
462 |
+
Dispense with a vest on a day like this.
|
463 |
+
The bunch of grapes was pressed into wine.
|
464 |
+
He sent the figs, but kept the ripe cherries.
|
465 |
+
The hinge on the door creaked with old age.
|
466 |
+
The screen before the fire kept in the sparks.
|
467 |
+
Fly by night, and you waste little time.
|
468 |
+
Thick glasses helped him read the print.
|
469 |
+
Birth and death mark the limits of life.
|
470 |
+
The chair looked strong but had no bottom.
|
471 |
+
The kite flew wildly in the high wind.
|
472 |
+
A fur muff is stylish once more.
|
473 |
+
The tin box held priceless stones.
|
474 |
+
We need an end of all such matter.
|
475 |
+
The case was puzzling to the old and wise.
|
476 |
+
The bright lanterns were gay on the dark lawn.
|
477 |
+
We don't get much money but we have fun.
|
478 |
+
The youth drove with zest, but little skill.
|
479 |
+
Five years he lived with a shaggy dog.
|
480 |
+
A fence cuts through the corner lot.
|
481 |
+
The way to save money is not to spend much.
|
482 |
+
Shut the hatch before the waves push it in.
|
483 |
+
The odor of spring makes young hearts jump.
|
484 |
+
Crack the walnut with your sharp side teeth.
|
485 |
+
He offered proof in the form of a large chart.
|
486 |
+
Send the stuff in a thick paper bag.
|
487 |
+
A quart of milk is water for the most part.
|
488 |
+
They told wild tales to frighten him.
|
489 |
+
The three story house was built of stone.
|
490 |
+
In the rear of the ground floor was a large passage.
|
491 |
+
A man in a blue sweater sat at the desk.
|
492 |
+
Oats are a food eaten by horse and man.
|
493 |
+
Their eyelids droop for want of sleep.
|
494 |
+
A sip of tea revives his tired friend.
|
495 |
+
There are many ways to do these things.
|
496 |
+
Tuck the sheet under the edge of the mat.
|
497 |
+
A force equal to that would move the earth.
|
498 |
+
We like to see clear weather.
|
499 |
+
The work of the tailor is seen on each side.
|
500 |
+
Take a chance and win a china doll.
|
501 |
+
Shake the dust from your shoes, stranger.
|
502 |
+
She was kind to sick old people.
|
503 |
+
The square wooden crate was packed to be shipped.
|
504 |
+
The dusty bench stood by the stone wall.
|
505 |
+
We dress to suit the weather of most days.
|
506 |
+
Smile when you say nasty words.
|
507 |
+
A bowl of rice is free with chicken stew.
|
508 |
+
The water in this well is a source of good health.
|
509 |
+
Take shelter in this tent, but keep still.
|
510 |
+
That guy is the writer of a few banned books.
|
511 |
+
The little tales they tell are false.
|
512 |
+
The door was barred, locked, and bolted as well.
|
513 |
+
Ripe pears are fit for a queen's table.
|
514 |
+
A big wet stain was on the round carpet.
|
515 |
+
The kite dipped and swayed, but stayed aloft.
|
516 |
+
The pleasant hours fly by much too soon.
|
517 |
+
The room was crowded with a wild mob.
|
518 |
+
This strong arm shall shield your honor.
|
519 |
+
She blushed when he gave her a white orchid.
|
520 |
+
The beetle droned in the hot June sun.
|
521 |
+
Press the pedal with your left foot.
|
522 |
+
Neat plans fail without luck.
|
523 |
+
The black trunk fell from the landing.
|
524 |
+
The bank pressed for payment of the debt.
|
525 |
+
The theft of the pearl pin was kept secret.
|
526 |
+
Shake hands with this friendly child.
|
527 |
+
The vast space stretched into the far distance.
|
528 |
+
A rich farm is rare in this sandy waste.
|
529 |
+
His wide grin earned many friends.
|
530 |
+
Flax makes a fine brand of paper.
|
531 |
+
Hurdle the pit with the aid of a long pole.
|
532 |
+
A strong bid may scare your partner stiff.
|
533 |
+
Even a just cause needs power to win.
|
534 |
+
Peep under the tent and see the clowns.
|
535 |
+
The leaf drifts along with a slow spin.
|
536 |
+
Cheap clothes are flashy but don't last.
|
537 |
+
A thing of small note can cause despair.
|
538 |
+
Flood the mails with requests for this book.
|
539 |
+
A thick coat of black paint covered all.
|
540 |
+
The pencil was cut to be sharp at both ends.
|
541 |
+
Those last words were a strong statement.
|
542 |
+
He wrote his name boldly at the top of the sheet.
|
543 |
+
Dill pickles are sour but taste fine.
|
544 |
+
Down that road is the way to the grain farmer.
|
545 |
+
Either mud or dust are found at all times.
|
546 |
+
The best method is to fix it in place with clips.
|
547 |
+
If you mumble your speech will be lost.
|
548 |
+
At night the alarm roused him from a deep sleep.
|
549 |
+
Read just what the meter says.
|
550 |
+
Fill your pack with bright trinkets for the poor.
|
551 |
+
The small red neon lamp went out.
|
552 |
+
Clams are small, round, soft, and tasty.
|
553 |
+
The fan whirled its round blades softly.
|
554 |
+
The line where the edges join was clean.
|
555 |
+
Breathe deep and smell the piny air.
|
556 |
+
It matters not if he reads these words or those.
|
557 |
+
A brown leather bag hung from its strap.
|
558 |
+
A toad and a frog are hard to tell apart.
|
559 |
+
A white silk jacket goes with any shoes.
|
560 |
+
A break in the dam almost caused a flood.
|
561 |
+
Paint the sockets in the wall dull green.
|
562 |
+
The child crawled into the dense grass.
|
563 |
+
Bribes fail where honest men work.
|
564 |
+
Trample the spark, else the flames will spread.
|
565 |
+
The hilt of the sword was carved with fine designs.
|
566 |
+
A round hole was drilled through the thin board.
|
567 |
+
Footprints showed the path he took up the beach.
|
568 |
+
She was waiting at my front lawn.
|
569 |
+
A vent near the edge brought in fresh air.
|
570 |
+
Prod the old mule with a crooked stick.
|
571 |
+
It is a band of steel three inches wide.
|
572 |
+
The pipe ran almost the length of the ditch.
|
573 |
+
It was hidden from sight by a mass of leaves and shrubs.
|
574 |
+
The weight of the package was seen on the high scale.
|
575 |
+
Wake and rise, and step into the green outdoors.
|
576 |
+
The green light in the brown box flickered.
|
577 |
+
The brass tube circled the high wall.
|
578 |
+
The lobes of her ears were pierced to hold rings.
|
579 |
+
Hold the hammer near the end to drive the nail.
|
580 |
+
Next Sunday is the twelfth of the month.
|
581 |
+
Every word and phrase he speaks is true.
|
582 |
+
He put his last cartridge into the gun and fired.
|
583 |
+
They took their kids from the public school.
|
584 |
+
Drive the screw straight into the wood.
|
585 |
+
Keep the hatch tight and the watch constant.
|
586 |
+
Sever the twine with a quick snip of the knife.
|
587 |
+
Paper will dry out when wet.
|
588 |
+
Slide the catch back and open the desk.
|
589 |
+
Help the weak to preserve their strength.
|
590 |
+
A sullen smile gets few friends.
|
591 |
+
Stop whistling and watch the boys march.
|
592 |
+
Jerk the cord, and out tumbles the gold.
|
593 |
+
Slide the tray across the glass top.
|
594 |
+
The cloud moved in a stately way and was gone.
|
595 |
+
Light maple makes for a swell room.
|
596 |
+
Set the piece here and say nothing.
|
597 |
+
Dull stories make her laugh.
|
598 |
+
A stiff cord will do to fasten your shoe.
|
599 |
+
Get the trust fund to the bank early.
|
600 |
+
Choose between the high road and the low.
|
601 |
+
A plea for funds seems to come again.
|
602 |
+
He lent his coat to the tall gaunt stranger.
|
603 |
+
There is a strong chance it will happen once more.
|
604 |
+
The duke left the park in a silver coach.
|
605 |
+
Greet the new guests and leave quickly.
|
606 |
+
When the frost has come it is time for turkey.
|
607 |
+
Sweet words work better than fierce.
|
608 |
+
A thin stripe runs down the middle.
|
609 |
+
A six comes up more often than a ten.
|
610 |
+
Lush fern grow on the lofty rocks.
|
611 |
+
The ram scared the school children off.
|
612 |
+
The team with the best timing looks good.
|
613 |
+
The farmer swapped his horse for a brown ox.
|
614 |
+
Sit on the perch and tell the others what to do.
|
615 |
+
A steep trail is painful for our feet.
|
616 |
+
The early phase of life moves fast.
|
617 |
+
Green moss grows on the northern side.
|
618 |
+
Tea in thin china has a sweet taste.
|
619 |
+
Pitch the straw through the door of the stable.
|
620 |
+
The latch on the back gate needed a nail.
|
621 |
+
The goose was brought straight from the old market.
|
622 |
+
The sink is the thing in which we pile dishes.
|
623 |
+
A whiff of it will cure the most stubborn cold.
|
624 |
+
The facts don't always show who is right.
|
625 |
+
She flaps her cape as she parades the street.
|
626 |
+
The loss of the cruiser was a blow to the fleet.
|
627 |
+
Loop the braid to the left and then over.
|
628 |
+
Plead with the lawyer to drop the lost cause.
|
629 |
+
Calves thrive on tender spring grass.
|
630 |
+
Post no bills on this office wall.
|
631 |
+
Tear a thin sheet from the yellow pad.
|
632 |
+
A cruise in warm waters in a sleek yacht is fun.
|
633 |
+
A streak of color ran down the left edge.
|
634 |
+
It was done before the boy could see it.
|
635 |
+
Crouch before you jump or miss the mark.
|
636 |
+
Pack the kits and don't forget the salt.
|
637 |
+
The square peg will settle in the round hole.
|
638 |
+
Fine soap saves tender skin.
|
639 |
+
Poached eggs and tea must suffice.
|
640 |
+
Bad nerves are jangled by a door slam.
|
641 |
+
Ship maps are different from those for planes.
|
642 |
+
Dimes showered down from all sides.
|
643 |
+
They sang the same tunes at each party.
|
644 |
+
The sky in the west is tinged with orange red.
|
645 |
+
The pods of peas ferment in bare fields.
|
646 |
+
The horse balked and threw the tall rider.
|
647 |
+
The hitch between the horse and cart broke.
|
648 |
+
Pile the coal high in the shed corner.
|
649 |
+
A gold vase is both rare and costly.
|
650 |
+
The knife was hung inside its bright sheath.
|
651 |
+
The rarest spice comes from the far East.
|
652 |
+
The roof should be tilted at a sharp slant.
|
653 |
+
A smatter of French is worse than none.
|
654 |
+
The mule trod the treadmill day and night.
|
655 |
+
The aim of the contest is to raise a great fund.
|
656 |
+
To send it now in large amounts is bad.
|
657 |
+
There is a fine hard tang in salty air.
|
658 |
+
Cod is the main business of the north shore.
|
659 |
+
The slab was hewn from heavy blocks of slate.
|
660 |
+
Dunk the stale biscuits into strong drink.
|
661 |
+
Hang tinsel from both branches.
|
662 |
+
Cap the jar with a tight brass cover.
|
663 |
+
The poor boy missed the boat again.
|
664 |
+
Be sure to set the lamp firmly in the hole.
|
665 |
+
Pick a card and slip it under the pack.
|
666 |
+
A round mat will cover the dull spot.
|
667 |
+
The first part of the plan needs changing.
|
668 |
+
A good book informs of what we ought to know.
|
669 |
+
The mail comes in three batches per day.
|
670 |
+
You cannot brew tea in a cold pot.
|
671 |
+
Dots of light betrayed the black cat.
|
672 |
+
Put the chart on the mantel and tack it down.
|
673 |
+
The night shift men rate extra pay.
|
674 |
+
The red paper brightened the dim stage.
|
675 |
+
See the player scoot to third base.
|
676 |
+
Slide the bill between the two leaves.
|
677 |
+
Many hands help get the job done.
|
678 |
+
We don't like to admit our small faults.
|
679 |
+
No doubt about the way the wind blows.
|
680 |
+
Dig deep in the earth for pirate's gold.
|
681 |
+
The steady drip is worse than a drenching rain.
|
682 |
+
A flat pack takes less luggage space.
|
683 |
+
Green ice frosted the punch bowl.
|
684 |
+
A stuffed chair slipped from the moving van.
|
685 |
+
The stitch will serve but needs to be shortened.
|
686 |
+
A thin book fits in the side pocket.
|
687 |
+
The gloss on top made it unfit to read.
|
688 |
+
The hail pattered on the burnt brown grass.
|
689 |
+
Seven seals were stamped on great sheets.
|
690 |
+
Our troops are set to strike heavy blows.
|
691 |
+
The store was jammed before the sale could start.
|
692 |
+
It was a bad error on the part of the new judge.
|
693 |
+
One step more and the board will collapse.
|
694 |
+
Take the match and strike it against your shoe.
|
695 |
+
The pot boiled, but the contents failed to jell.
|
696 |
+
The baby puts his right foot in his mouth.
|
697 |
+
The bombs left most of the town in ruins.
|
698 |
+
Stop and stare at the hard working man.
|
699 |
+
The streets are narrow and full of sharp turns.
|
700 |
+
The pup jerked the leash as he saw a feline shape.
|
701 |
+
Open your book to the first page.
|
702 |
+
Fish evade the net and swim off.
|
703 |
+
Dip the pail once and let it settle.
|
704 |
+
Will you please answer that phone.
|
705 |
+
The big red apple fell to the ground.
|
706 |
+
The curtain rose and the show was on.
|
707 |
+
The young prince became heir to the throne.
|
708 |
+
He sent the boy on a short errand.
|
709 |
+
Leave now and you will arrive on time.
|
710 |
+
The corner store was robbed last night.
|
711 |
+
A gold ring will please most any girl.
|
712 |
+
The long journey home took a year.
|
713 |
+
She saw a cat in the neighbor's house.
|
714 |
+
A pink shell was found on the sandy beach.
|
715 |
+
Small children came to see him.
|
716 |
+
The grass and bushes were wet with dew.
|
717 |
+
The blind man counted his old coins.
|
718 |
+
A severe storm tore down the barn.
|
719 |
+
She called his name many times.
|
720 |
+
When you hear the bell, come quickly.
|
$extensions/coqui_tts/languages.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"Arabic": "ar",
|
3 |
+
"Chinese": "zh-cn",
|
4 |
+
"Czech": "cs",
|
5 |
+
"Dutch": "nl",
|
6 |
+
"English": "en",
|
7 |
+
"French": "fr",
|
8 |
+
"German": "de",
|
9 |
+
"Hungarian": "hu",
|
10 |
+
"Italian": "it",
|
11 |
+
"Japanese": "ja",
|
12 |
+
"Korean": "ko",
|
13 |
+
"Polish": "pl",
|
14 |
+
"Portuguese": "pt",
|
15 |
+
"Russian": "ru",
|
16 |
+
"Spanish": "es",
|
17 |
+
"Turkish": "tr"
|
18 |
+
}
|
$extensions/coqui_tts/requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
TTS==0.21.*
|
$extensions/coqui_tts/script.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import html
|
2 |
+
import json
|
3 |
+
import random
|
4 |
+
import time
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
from modules import chat, shared, ui_chat
|
10 |
+
from modules.logging_colors import logger
|
11 |
+
from modules.ui import create_refresh_button
|
12 |
+
from modules.utils import gradio
|
13 |
+
|
14 |
+
try:
|
15 |
+
from TTS.api import TTS
|
16 |
+
from TTS.utils.synthesizer import Synthesizer
|
17 |
+
except ModuleNotFoundError:
|
18 |
+
logger.error(
|
19 |
+
"Could not find the TTS module. Make sure to install the requirements for the coqui_tts extension."
|
20 |
+
"\n"
|
21 |
+
"\nLinux / Mac:\npip install -r extensions/coqui_tts/requirements.txt\n"
|
22 |
+
"\nWindows:\npip install -r extensions\\coqui_tts\\requirements.txt\n"
|
23 |
+
"\n"
|
24 |
+
"If you used the one-click installer, paste the command above in the terminal window launched after running the \"cmd_\" script. On Windows, that's \"cmd_windows.bat\"."
|
25 |
+
)
|
26 |
+
|
27 |
+
raise
|
28 |
+
|
29 |
+
|
30 |
+
params = {
|
31 |
+
"activate": True,
|
32 |
+
"autoplay": True,
|
33 |
+
"show_text": False,
|
34 |
+
"remove_trailing_dots": False,
|
35 |
+
"voice": "female_01.wav",
|
36 |
+
"language": "English",
|
37 |
+
"model_name": "tts_models/multilingual/multi-dataset/xtts_v2",
|
38 |
+
"device": "cuda"
|
39 |
+
}
|
40 |
+
|
41 |
+
this_dir = str(Path(__file__).parent.resolve())
|
42 |
+
model = None
|
43 |
+
with open(Path(f"{this_dir}/languages.json"), encoding='utf8') as f:
|
44 |
+
languages = json.load(f)
|
45 |
+
|
46 |
+
|
47 |
+
def get_available_voices():
|
48 |
+
return sorted([voice.name for voice in Path(f"{this_dir}/voices").glob("*.wav")])
|
49 |
+
|
50 |
+
|
51 |
+
def preprocess(raw_input):
|
52 |
+
raw_input = html.unescape(raw_input)
|
53 |
+
# raw_input = raw_input.strip("\"")
|
54 |
+
return raw_input
|
55 |
+
|
56 |
+
|
57 |
+
def new_split_into_sentences(self, text):
|
58 |
+
sentences = self.seg.segment(text)
|
59 |
+
if params['remove_trailing_dots']:
|
60 |
+
sentences_without_dots = []
|
61 |
+
for sentence in sentences:
|
62 |
+
if sentence.endswith('.') and not sentence.endswith('...'):
|
63 |
+
sentence = sentence[:-1]
|
64 |
+
|
65 |
+
sentences_without_dots.append(sentence)
|
66 |
+
|
67 |
+
return sentences_without_dots
|
68 |
+
else:
|
69 |
+
return sentences
|
70 |
+
|
71 |
+
|
72 |
+
Synthesizer.split_into_sentences = new_split_into_sentences
|
73 |
+
|
74 |
+
|
75 |
+
def load_model():
|
76 |
+
model = TTS(params["model_name"]).to(params["device"])
|
77 |
+
return model
|
78 |
+
|
79 |
+
|
80 |
+
def remove_tts_from_history(history):
|
81 |
+
for i, entry in enumerate(history['internal']):
|
82 |
+
history['visible'][i] = [history['visible'][i][0], entry[1]]
|
83 |
+
|
84 |
+
return history
|
85 |
+
|
86 |
+
|
87 |
+
def toggle_text_in_history(history):
|
88 |
+
for i, entry in enumerate(history['visible']):
|
89 |
+
visible_reply = entry[1]
|
90 |
+
if visible_reply.startswith('<audio'):
|
91 |
+
if params['show_text']:
|
92 |
+
reply = history['internal'][i][1]
|
93 |
+
history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
|
94 |
+
else:
|
95 |
+
history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
|
96 |
+
|
97 |
+
return history
|
98 |
+
|
99 |
+
|
100 |
+
def random_sentence():
|
101 |
+
with open(Path("extensions/coqui_tts/harvard_sentences.txt")) as f:
|
102 |
+
return random.choice(list(f))
|
103 |
+
|
104 |
+
|
105 |
+
def voice_preview(string):
|
106 |
+
string = html.unescape(string) or random_sentence()
|
107 |
+
|
108 |
+
output_file = Path('extensions/coqui_tts/outputs/voice_preview.wav')
|
109 |
+
model.tts_to_file(
|
110 |
+
text=string,
|
111 |
+
file_path=output_file,
|
112 |
+
speaker_wav=[f"{this_dir}/voices/{params['voice']}"],
|
113 |
+
language=languages[params["language"]]
|
114 |
+
)
|
115 |
+
|
116 |
+
return f'<audio src="file/{output_file.as_posix()}?{int(time.time())}" controls autoplay></audio>'
|
117 |
+
|
118 |
+
|
119 |
+
def history_modifier(history):
|
120 |
+
# Remove autoplay from the last reply
|
121 |
+
if len(history['internal']) > 0:
|
122 |
+
history['visible'][-1] = [
|
123 |
+
history['visible'][-1][0],
|
124 |
+
history['visible'][-1][1].replace('controls autoplay>', 'controls>')
|
125 |
+
]
|
126 |
+
|
127 |
+
return history
|
128 |
+
|
129 |
+
|
130 |
+
def state_modifier(state):
|
131 |
+
if not params['activate']:
|
132 |
+
return state
|
133 |
+
|
134 |
+
state['stream'] = False
|
135 |
+
return state
|
136 |
+
|
137 |
+
|
138 |
+
def input_modifier(string, state):
|
139 |
+
if not params['activate']:
|
140 |
+
return string
|
141 |
+
|
142 |
+
shared.processing_message = "*Is recording a voice message...*"
|
143 |
+
return string
|
144 |
+
|
145 |
+
|
146 |
+
def output_modifier(string, state):
|
147 |
+
if not params['activate']:
|
148 |
+
return string
|
149 |
+
|
150 |
+
original_string = string
|
151 |
+
string = preprocess(html.unescape(string))
|
152 |
+
if string == '':
|
153 |
+
string = '*Empty reply, try regenerating*'
|
154 |
+
else:
|
155 |
+
output_file = Path(f'extensions/coqui_tts/outputs/{state["character_menu"]}_{int(time.time())}.wav')
|
156 |
+
model.tts_to_file(
|
157 |
+
text=string,
|
158 |
+
file_path=output_file,
|
159 |
+
speaker_wav=[f"{this_dir}/voices/{params['voice']}"],
|
160 |
+
language=languages[params["language"]]
|
161 |
+
)
|
162 |
+
|
163 |
+
autoplay = 'autoplay' if params['autoplay'] else ''
|
164 |
+
string = f'<audio src="file/{output_file.as_posix()}" controls {autoplay}></audio>'
|
165 |
+
if params['show_text']:
|
166 |
+
string += f'\n\n{original_string}'
|
167 |
+
|
168 |
+
shared.processing_message = "*Is typing...*"
|
169 |
+
return string
|
170 |
+
|
171 |
+
|
172 |
+
def custom_css():
|
173 |
+
path_to_css = Path(f"{this_dir}/style.css")
|
174 |
+
return open(path_to_css, 'r').read()
|
175 |
+
|
176 |
+
|
177 |
+
def setup():
|
178 |
+
global model
|
179 |
+
print("[XTTS] Loading XTTS...")
|
180 |
+
model = load_model()
|
181 |
+
print("[XTTS] Done!")
|
182 |
+
Path(f"{this_dir}/outputs").mkdir(parents=True, exist_ok=True)
|
183 |
+
|
184 |
+
|
185 |
+
def ui():
|
186 |
+
with gr.Accordion("Coqui TTS (XTTSv2)"):
|
187 |
+
with gr.Row():
|
188 |
+
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
|
189 |
+
autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
|
190 |
+
|
191 |
+
with gr.Row():
|
192 |
+
show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
|
193 |
+
remove_trailing_dots = gr.Checkbox(value=params['remove_trailing_dots'], label='Remove trailing "." from text segments before converting to audio')
|
194 |
+
|
195 |
+
with gr.Row():
|
196 |
+
with gr.Row():
|
197 |
+
voice = gr.Dropdown(get_available_voices(), label="Voice wav", value=params["voice"])
|
198 |
+
create_refresh_button(voice, lambda: None, lambda: {'choices': get_available_voices(), 'value': params["voice"]}, 'refresh-button')
|
199 |
+
|
200 |
+
language = gr.Dropdown(languages.keys(), label="Language", value=params["language"])
|
201 |
+
|
202 |
+
with gr.Row():
|
203 |
+
preview_text = gr.Text(show_label=False, placeholder="Preview text", elem_id="silero_preview_text")
|
204 |
+
preview_play = gr.Button("Preview")
|
205 |
+
preview_audio = gr.HTML(visible=False)
|
206 |
+
|
207 |
+
with gr.Row():
|
208 |
+
convert = gr.Button('Permanently replace audios with the message texts')
|
209 |
+
convert_cancel = gr.Button('Cancel', visible=False)
|
210 |
+
convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
|
211 |
+
|
212 |
+
# Convert history with confirmation
|
213 |
+
convert_arr = [convert_confirm, convert, convert_cancel]
|
214 |
+
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
|
215 |
+
convert_confirm.click(
|
216 |
+
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
|
217 |
+
remove_tts_from_history, gradio('history'), gradio('history')).then(
|
218 |
+
chat.save_history, gradio('history', 'unique_id', 'character_menu', 'mode'), None).then(
|
219 |
+
chat.redraw_html, gradio(ui_chat.reload_arr), gradio('display'))
|
220 |
+
|
221 |
+
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
222 |
+
|
223 |
+
# Toggle message text in history
|
224 |
+
show_text.change(
|
225 |
+
lambda x: params.update({"show_text": x}), show_text, None).then(
|
226 |
+
toggle_text_in_history, gradio('history'), gradio('history')).then(
|
227 |
+
chat.save_history, gradio('history', 'unique_id', 'character_menu', 'mode'), None).then(
|
228 |
+
chat.redraw_html, gradio(ui_chat.reload_arr), gradio('display'))
|
229 |
+
|
230 |
+
# Event functions to update the parameters in the backend
|
231 |
+
activate.change(lambda x: params.update({"activate": x}), activate, None)
|
232 |
+
autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
|
233 |
+
remove_trailing_dots.change(lambda x: params.update({"remove_trailing_dots": x}), remove_trailing_dots, None)
|
234 |
+
voice.change(lambda x: params.update({"voice": x}), voice, None)
|
235 |
+
language.change(lambda x: params.update({"language": x}), language, None)
|
236 |
+
|
237 |
+
# Play preview
|
238 |
+
preview_text.submit(voice_preview, preview_text, preview_audio)
|
239 |
+
preview_play.click(voice_preview, preview_text, preview_audio)
|
$extensions/coqui_tts/style.css
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.SDAP .hires_opts input[type="number"] {
|
2 |
+
width: 6em !important;
|
3 |
+
}
|
4 |
+
|
5 |
+
/* silero_tts preview */
|
6 |
+
.form:has(> #silero_preview_text) {
|
7 |
+
min-width: 75%
|
8 |
+
}
|
$extensions/coqui_tts/voices/arnold.wav
ADDED
Binary file (897 kB). View file
|
|
$extensions/coqui_tts/voices/female_01.wav
ADDED
Binary file (501 kB). View file
|
|
$extensions/coqui_tts/voices/female_02.wav
ADDED
Binary file (334 kB). View file
|
|
$extensions/elevenlabs_tts/outputs/outputs-will-be-saved-here.txt
ADDED
File without changes
|
$extensions/elevenlabs_tts/requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
elevenlabs==0.2.24
|
$extensions/elevenlabs_tts/script.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import html
|
2 |
+
import re
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import elevenlabs
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
from modules import chat, shared, ui_chat
|
9 |
+
from modules.logging_colors import logger
|
10 |
+
from modules.utils import gradio
|
11 |
+
|
12 |
+
params = {
|
13 |
+
'activate': True,
|
14 |
+
'api_key': None,
|
15 |
+
'selected_voice': 'None',
|
16 |
+
'autoplay': False,
|
17 |
+
'show_text': True,
|
18 |
+
'model': 'eleven_monolingual_v1',
|
19 |
+
}
|
20 |
+
|
21 |
+
voices = None
|
22 |
+
wav_idx = 0
|
23 |
+
LANG_MODELS = ['eleven_monolingual_v1', 'eleven_multilingual_v1']
|
24 |
+
|
25 |
+
|
26 |
+
def update_api_key(key):
|
27 |
+
params['api_key'] = key
|
28 |
+
if key is not None:
|
29 |
+
elevenlabs.set_api_key(key)
|
30 |
+
|
31 |
+
|
32 |
+
def refresh_voices():
|
33 |
+
global params
|
34 |
+
your_voices = elevenlabs.voices()
|
35 |
+
voice_names = [voice.name for voice in your_voices]
|
36 |
+
return voice_names
|
37 |
+
|
38 |
+
|
39 |
+
def refresh_voices_dd():
|
40 |
+
all_voices = refresh_voices()
|
41 |
+
return gr.Dropdown.update(value=all_voices[0], choices=all_voices)
|
42 |
+
|
43 |
+
|
44 |
+
def remove_tts_from_history(history):
|
45 |
+
for i, entry in enumerate(history['internal']):
|
46 |
+
history['visible'][i] = [history['visible'][i][0], entry[1]]
|
47 |
+
|
48 |
+
return history
|
49 |
+
|
50 |
+
|
51 |
+
def toggle_text_in_history(history):
|
52 |
+
for i, entry in enumerate(history['visible']):
|
53 |
+
visible_reply = entry[1]
|
54 |
+
if visible_reply.startswith('<audio'):
|
55 |
+
if params['show_text']:
|
56 |
+
reply = history['internal'][i][1]
|
57 |
+
history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
|
58 |
+
else:
|
59 |
+
history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
|
60 |
+
|
61 |
+
return history
|
62 |
+
|
63 |
+
|
64 |
+
def remove_surrounded_chars(string):
|
65 |
+
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
66 |
+
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
67 |
+
return re.sub('\*[^\*]*?(\*|$)', '', string)
|
68 |
+
|
69 |
+
|
70 |
+
def state_modifier(state):
|
71 |
+
if not params['activate']:
|
72 |
+
return state
|
73 |
+
|
74 |
+
state['stream'] = False
|
75 |
+
return state
|
76 |
+
|
77 |
+
|
78 |
+
def input_modifier(string):
|
79 |
+
if not params['activate']:
|
80 |
+
return string
|
81 |
+
|
82 |
+
shared.processing_message = "*Is recording a voice message...*"
|
83 |
+
return string
|
84 |
+
|
85 |
+
|
86 |
+
def history_modifier(history):
|
87 |
+
# Remove autoplay from the last reply
|
88 |
+
if len(history['internal']) > 0:
|
89 |
+
history['visible'][-1] = [
|
90 |
+
history['visible'][-1][0],
|
91 |
+
history['visible'][-1][1].replace('controls autoplay>', 'controls>')
|
92 |
+
]
|
93 |
+
|
94 |
+
return history
|
95 |
+
|
96 |
+
|
97 |
+
def output_modifier(string):
|
98 |
+
global params, wav_idx
|
99 |
+
|
100 |
+
if not params['activate']:
|
101 |
+
return string
|
102 |
+
|
103 |
+
original_string = string
|
104 |
+
string = remove_surrounded_chars(string)
|
105 |
+
string = string.replace('"', '')
|
106 |
+
string = string.replace('“', '')
|
107 |
+
string = string.replace('\n', ' ')
|
108 |
+
string = string.strip()
|
109 |
+
if string == '':
|
110 |
+
string = 'empty reply, try regenerating'
|
111 |
+
|
112 |
+
output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.mp3'.format(wav_idx))
|
113 |
+
print(f'Outputting audio to {str(output_file)}')
|
114 |
+
try:
|
115 |
+
audio = elevenlabs.generate(text=html.unescape(string), voice=params['selected_voice'], model=params['model'])
|
116 |
+
elevenlabs.save(audio, str(output_file))
|
117 |
+
|
118 |
+
autoplay = 'autoplay' if params['autoplay'] else ''
|
119 |
+
string = f'<audio src="file/{output_file.as_posix()}" controls {autoplay}></audio>'
|
120 |
+
wav_idx += 1
|
121 |
+
except elevenlabs.api.error.UnauthenticatedRateLimitError:
|
122 |
+
string = "🤖 ElevenLabs Unauthenticated Rate Limit Reached - Please create an API key to continue\n\n"
|
123 |
+
except elevenlabs.api.error.RateLimitError:
|
124 |
+
string = "🤖 ElevenLabs API Tier Limit Reached\n\n"
|
125 |
+
except elevenlabs.api.error.APIError as err:
|
126 |
+
string = f"🤖 ElevenLabs Error: {err}\n\n"
|
127 |
+
|
128 |
+
if params['show_text']:
|
129 |
+
string += f'\n\n{original_string}'
|
130 |
+
|
131 |
+
shared.processing_message = "*Is typing...*"
|
132 |
+
return string
|
133 |
+
|
134 |
+
|
135 |
+
def ui():
|
136 |
+
global voices
|
137 |
+
if not voices:
|
138 |
+
voices = refresh_voices()
|
139 |
+
selected = params['selected_voice']
|
140 |
+
if selected == 'None':
|
141 |
+
params['selected_voice'] = voices[0]
|
142 |
+
elif selected not in voices:
|
143 |
+
logger.error(f'Selected voice {selected} not available, switching to {voices[0]}')
|
144 |
+
params['selected_voice'] = voices[0]
|
145 |
+
|
146 |
+
# Gradio elements
|
147 |
+
with gr.Row():
|
148 |
+
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
|
149 |
+
autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
|
150 |
+
show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
|
151 |
+
|
152 |
+
with gr.Row():
|
153 |
+
voice = gr.Dropdown(value=params['selected_voice'], choices=voices, label='TTS Voice')
|
154 |
+
refresh = gr.Button(value='Refresh')
|
155 |
+
|
156 |
+
with gr.Row():
|
157 |
+
if params['api_key']:
|
158 |
+
api_key = gr.Textbox(value=params['api_key'], label='API Key')
|
159 |
+
update_api_key(params['api_key'])
|
160 |
+
else:
|
161 |
+
api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')
|
162 |
+
|
163 |
+
with gr.Row():
|
164 |
+
model = gr.Dropdown(value=params['model'], choices=LANG_MODELS, label='Language model')
|
165 |
+
|
166 |
+
with gr.Row():
|
167 |
+
convert = gr.Button('Permanently replace audios with the message texts')
|
168 |
+
convert_cancel = gr.Button('Cancel', visible=False)
|
169 |
+
convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
|
170 |
+
|
171 |
+
# Convert history with confirmation
|
172 |
+
convert_arr = [convert_confirm, convert, convert_cancel]
|
173 |
+
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
|
174 |
+
convert_confirm.click(
|
175 |
+
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
|
176 |
+
remove_tts_from_history, gradio('history'), gradio('history')).then(
|
177 |
+
chat.save_history, gradio('history', 'unique_id', 'character_menu', 'mode'), None).then(
|
178 |
+
chat.redraw_html, gradio(ui_chat.reload_arr), gradio('display'))
|
179 |
+
|
180 |
+
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
181 |
+
|
182 |
+
# Toggle message text in history
|
183 |
+
show_text.change(
|
184 |
+
lambda x: params.update({"show_text": x}), show_text, None).then(
|
185 |
+
toggle_text_in_history, gradio('history'), gradio('history')).then(
|
186 |
+
chat.save_history, gradio('history', 'unique_id', 'character_menu', 'mode'), None).then(
|
187 |
+
chat.redraw_html, gradio(ui_chat.reload_arr), gradio('display'))
|
188 |
+
|
189 |
+
# Event functions to update the parameters in the backend
|
190 |
+
activate.change(lambda x: params.update({'activate': x}), activate, None)
|
191 |
+
voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
|
192 |
+
api_key.change(update_api_key, api_key, None)
|
193 |
+
model.change(lambda x: params.update({'model': x}), model, None)
|
194 |
+
# connect.click(check_valid_api, [], connection_status)
|
195 |
+
refresh.click(refresh_voices_dd, [], voice)
|
196 |
+
# Event functions to update the parameters in the backend
|
197 |
+
autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
|
$extensions/example/script.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
An example of extension. It does nothing, but you can add transformations
|
3 |
+
before the return statements to customize the webui behavior.
|
4 |
+
|
5 |
+
Starting from history_modifier and ending in output_modifier, the
|
6 |
+
functions are declared in the same order that they are called at
|
7 |
+
generation time.
|
8 |
+
"""
|
9 |
+
|
10 |
+
import gradio as gr
|
11 |
+
import torch
|
12 |
+
from transformers import LogitsProcessor
|
13 |
+
|
14 |
+
from modules import chat, shared
|
15 |
+
from modules.text_generation import (
|
16 |
+
decode,
|
17 |
+
encode,
|
18 |
+
generate_reply,
|
19 |
+
)
|
20 |
+
|
21 |
+
params = {
|
22 |
+
"display_name": "Example Extension",
|
23 |
+
"is_tab": False,
|
24 |
+
}
|
25 |
+
|
26 |
+
class MyLogits(LogitsProcessor):
|
27 |
+
"""
|
28 |
+
Manipulates the probabilities for the next token before it gets sampled.
|
29 |
+
Used in the logits_processor_modifier function below.
|
30 |
+
"""
|
31 |
+
def __init__(self):
|
32 |
+
pass
|
33 |
+
|
34 |
+
def __call__(self, input_ids, scores):
|
35 |
+
# probs = torch.softmax(scores, dim=-1, dtype=torch.float)
|
36 |
+
# probs[0] /= probs[0].sum()
|
37 |
+
# scores = torch.log(probs / (1 - probs))
|
38 |
+
return scores
|
39 |
+
|
40 |
+
def history_modifier(history):
|
41 |
+
"""
|
42 |
+
Modifies the chat history.
|
43 |
+
Only used in chat mode.
|
44 |
+
"""
|
45 |
+
return history
|
46 |
+
|
47 |
+
def state_modifier(state):
|
48 |
+
"""
|
49 |
+
Modifies the state variable, which is a dictionary containing the input
|
50 |
+
values in the UI like sliders and checkboxes.
|
51 |
+
"""
|
52 |
+
return state
|
53 |
+
|
54 |
+
def chat_input_modifier(text, visible_text, state):
|
55 |
+
"""
|
56 |
+
Modifies the user input string in chat mode (visible_text).
|
57 |
+
You can also modify the internal representation of the user
|
58 |
+
input (text) to change how it will appear in the prompt.
|
59 |
+
"""
|
60 |
+
return text, visible_text
|
61 |
+
|
62 |
+
def input_modifier(string, state, is_chat=False):
|
63 |
+
"""
|
64 |
+
In default/notebook modes, modifies the whole prompt.
|
65 |
+
|
66 |
+
In chat mode, it is the same as chat_input_modifier but only applied
|
67 |
+
to "text", here called "string", and not to "visible_text".
|
68 |
+
"""
|
69 |
+
return string
|
70 |
+
|
71 |
+
def bot_prefix_modifier(string, state):
|
72 |
+
"""
|
73 |
+
Modifies the prefix for the next bot reply in chat mode.
|
74 |
+
By default, the prefix will be something like "Bot Name:".
|
75 |
+
"""
|
76 |
+
return string
|
77 |
+
|
78 |
+
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
|
79 |
+
"""
|
80 |
+
Modifies the input ids and embeds.
|
81 |
+
Used by the multimodal extension to put image embeddings in the prompt.
|
82 |
+
Only used by loaders that use the transformers library for sampling.
|
83 |
+
"""
|
84 |
+
return prompt, input_ids, input_embeds
|
85 |
+
|
86 |
+
def logits_processor_modifier(processor_list, input_ids):
|
87 |
+
"""
|
88 |
+
Adds logits processors to the list, allowing you to access and modify
|
89 |
+
the next token probabilities.
|
90 |
+
Only used by loaders that use the transformers library for sampling.
|
91 |
+
"""
|
92 |
+
processor_list.append(MyLogits())
|
93 |
+
return processor_list
|
94 |
+
|
95 |
+
def output_modifier(string, state, is_chat=False):
|
96 |
+
"""
|
97 |
+
Modifies the LLM output before it gets presented.
|
98 |
+
|
99 |
+
In chat mode, the modified version goes into history['visible'],
|
100 |
+
and the original version goes into history['internal'].
|
101 |
+
"""
|
102 |
+
return string
|
103 |
+
|
104 |
+
def custom_generate_chat_prompt(user_input, state, **kwargs):
|
105 |
+
"""
|
106 |
+
Replaces the function that generates the prompt from the chat history.
|
107 |
+
Only used in chat mode.
|
108 |
+
"""
|
109 |
+
result = chat.generate_chat_prompt(user_input, state, **kwargs)
|
110 |
+
return result
|
111 |
+
|
112 |
+
def custom_css():
|
113 |
+
"""
|
114 |
+
Returns a CSS string that gets appended to the CSS for the webui.
|
115 |
+
"""
|
116 |
+
return ''
|
117 |
+
|
118 |
+
def custom_js():
|
119 |
+
"""
|
120 |
+
Returns a javascript string that gets appended to the javascript
|
121 |
+
for the webui.
|
122 |
+
"""
|
123 |
+
return ''
|
124 |
+
|
125 |
+
def setup():
|
126 |
+
"""
|
127 |
+
Gets executed only once, when the extension is imported.
|
128 |
+
"""
|
129 |
+
pass
|
130 |
+
|
131 |
+
def ui():
|
132 |
+
"""
|
133 |
+
Gets executed when the UI is drawn. Custom gradio elements and
|
134 |
+
their corresponding event handlers should be defined here.
|
135 |
+
|
136 |
+
To learn about gradio components, check out the docs:
|
137 |
+
https://gradio.app/docs/
|
138 |
+
"""
|
139 |
+
pass
|
$extensions/gallery/__pycache__/script.cpython-311.pyc
ADDED
Binary file (6.9 kB). View file
|
|
$extensions/gallery/script.js
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
let gallery_element = document.getElementById('gallery-extension');
|
2 |
+
let chat_mode_element = document.getElementById('chat-mode');
|
3 |
+
|
4 |
+
let extensions_block = document.getElementById('extensions');
|
5 |
+
let extensions_block_size = extensions_block.childNodes.length;
|
6 |
+
let gallery_only = (extensions_block_size == 5);
|
7 |
+
|
8 |
+
function gotoFirstPage() {
|
9 |
+
const firstPageButton = gallery_element.querySelector('.paginate > button');
|
10 |
+
if (firstPageButton) {
|
11 |
+
firstPageButton.click();
|
12 |
+
}
|
13 |
+
}
|
14 |
+
|
15 |
+
document.querySelector('.header_bar').addEventListener('click', function(event) {
|
16 |
+
if (event.target.tagName === 'BUTTON') {
|
17 |
+
const buttonText = event.target.textContent.trim();
|
18 |
+
|
19 |
+
let chat_visible = (buttonText == 'Chat');
|
20 |
+
let default_visible = (buttonText == 'Default');
|
21 |
+
let notebook_visible = (buttonText == 'Notebook');
|
22 |
+
let chat_mode_visible = (chat_mode_element.offsetHeight > 0 && chat_mode_element.offsetWidth > 0);
|
23 |
+
|
24 |
+
// Only show this extension in the Chat tab
|
25 |
+
if (chat_visible) {
|
26 |
+
if (chat_mode_visible) {
|
27 |
+
gallery_element.style.display = 'block';
|
28 |
+
extensions_block.style.display = '';
|
29 |
+
} else {
|
30 |
+
gallery_element.style.display = 'none';
|
31 |
+
extensions_block.style.display = 'none';
|
32 |
+
}
|
33 |
+
} else {
|
34 |
+
gallery_element.style.display = 'none';
|
35 |
+
if (gallery_only) {
|
36 |
+
extensions_block.style.display = 'none';
|
37 |
+
}
|
38 |
+
}
|
39 |
+
}
|
40 |
+
});
|
$extensions/gallery/script.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
|
5 |
+
from modules.html_generator import get_image_cache
|
6 |
+
from modules.shared import gradio, settings
|
7 |
+
|
8 |
+
|
9 |
+
cards = []
|
10 |
+
|
11 |
+
|
12 |
+
def generate_css():
|
13 |
+
css = """
|
14 |
+
.highlighted-border {
|
15 |
+
border-color: rgb(249, 115, 22) !important;
|
16 |
+
}
|
17 |
+
|
18 |
+
.character-gallery > .gallery {
|
19 |
+
margin: 1rem 0;
|
20 |
+
display: grid !important;
|
21 |
+
grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
|
22 |
+
grid-column-gap: 0.4rem;
|
23 |
+
grid-row-gap: 1.2rem;
|
24 |
+
}
|
25 |
+
|
26 |
+
.character-gallery > .label {
|
27 |
+
display: none !important;
|
28 |
+
}
|
29 |
+
|
30 |
+
.character-gallery button.gallery-item {
|
31 |
+
display: contents;
|
32 |
+
}
|
33 |
+
|
34 |
+
.character-container {
|
35 |
+
cursor: pointer;
|
36 |
+
text-align: center;
|
37 |
+
position: relative;
|
38 |
+
opacity: 0.85;
|
39 |
+
}
|
40 |
+
|
41 |
+
.character-container:hover {
|
42 |
+
opacity: 1;
|
43 |
+
}
|
44 |
+
|
45 |
+
.character-container .placeholder, .character-container img {
|
46 |
+
width: 150px;
|
47 |
+
height: 200px;
|
48 |
+
background-color: gray;
|
49 |
+
object-fit: cover;
|
50 |
+
margin: 0 auto;
|
51 |
+
border-radius: 1rem;
|
52 |
+
border: 3px solid white;
|
53 |
+
box-shadow: 3px 3px 6px 0px rgb(0 0 0 / 50%);
|
54 |
+
}
|
55 |
+
|
56 |
+
.character-name {
|
57 |
+
margin-top: 0.3rem;
|
58 |
+
display: block;
|
59 |
+
font-size: 1.2rem;
|
60 |
+
font-weight: 600;
|
61 |
+
overflow-wrap: anywhere;
|
62 |
+
}
|
63 |
+
"""
|
64 |
+
return css
|
65 |
+
|
66 |
+
|
67 |
+
def generate_html():
|
68 |
+
global cards
|
69 |
+
cards = []
|
70 |
+
# Iterate through files in image folder
|
71 |
+
for file in sorted(Path("characters").glob("*")):
|
72 |
+
if file.suffix in [".json", ".yml", ".yaml"]:
|
73 |
+
character = file.stem
|
74 |
+
container_html = '<div class="character-container">'
|
75 |
+
image_html = "<div class='placeholder'></div>"
|
76 |
+
|
77 |
+
for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
|
78 |
+
if path.exists():
|
79 |
+
image_html = f'<img src="file/{get_image_cache(path)}">'
|
80 |
+
break
|
81 |
+
|
82 |
+
container_html += f'{image_html} <span class="character-name">{character}</span>'
|
83 |
+
container_html += "</div>"
|
84 |
+
cards.append([container_html, character])
|
85 |
+
|
86 |
+
return cards
|
87 |
+
|
88 |
+
|
89 |
+
def filter_cards(filter_str=''):
|
90 |
+
if filter_str == '':
|
91 |
+
return cards
|
92 |
+
|
93 |
+
filter_upper = filter_str.upper()
|
94 |
+
return [k for k in cards if filter_upper in k[1].upper()]
|
95 |
+
|
96 |
+
|
97 |
+
def select_character(evt: gr.SelectData):
|
98 |
+
return (evt.value[1])
|
99 |
+
|
100 |
+
|
101 |
+
def custom_js():
|
102 |
+
path_to_js = Path(__file__).parent.resolve() / 'script.js'
|
103 |
+
return open(path_to_js, 'r').read()
|
104 |
+
|
105 |
+
|
106 |
+
def ui():
|
107 |
+
with gr.Accordion("Character gallery", open=settings["gallery-open"], elem_id='gallery-extension'):
|
108 |
+
gr.HTML(value="<style>" + generate_css() + "</style>")
|
109 |
+
with gr.Row():
|
110 |
+
filter_box = gr.Textbox(label='', placeholder='Filter', lines=1, max_lines=1, container=False, elem_id='gallery-filter-box')
|
111 |
+
gr.ClearButton(filter_box, value='🗑️', elem_classes='refresh-button')
|
112 |
+
update = gr.Button("Refresh", elem_classes='refresh-button')
|
113 |
+
|
114 |
+
gallery = gr.Dataset(
|
115 |
+
components=[gr.HTML(visible=False)],
|
116 |
+
label="",
|
117 |
+
samples=generate_html(),
|
118 |
+
elem_classes=["character-gallery"],
|
119 |
+
samples_per_page=settings["gallery-items_per_page"]
|
120 |
+
)
|
121 |
+
|
122 |
+
filter_box.change(lambda: None, None, None, _js=f'() => {{{custom_js()}; gotoFirstPage()}}').success(
|
123 |
+
filter_cards, filter_box, gallery).then(
|
124 |
+
lambda x: gr.update(elem_classes='highlighted-border' if x != '' else ''), filter_box, filter_box, show_progress=False)
|
125 |
+
|
126 |
+
update.click(generate_html, [], None).success(
|
127 |
+
filter_cards, filter_box, gallery)
|
128 |
+
|
129 |
+
gallery.select(select_character, None, gradio['character_menu'])
|
$extensions/google_translate/requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
deep-translator==1.9.2
|
$extensions/google_translate/script.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import html
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
from deep_translator import GoogleTranslator
|
5 |
+
|
6 |
+
params = {
|
7 |
+
"activate": True,
|
8 |
+
"language string": "ja",
|
9 |
+
}
|
10 |
+
|
11 |
+
language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'}
|
12 |
+
|
13 |
+
|
14 |
+
def input_modifier(string):
|
15 |
+
"""
|
16 |
+
This function is applied to your text inputs before
|
17 |
+
they are fed into the model.
|
18 |
+
"""
|
19 |
+
if not params['activate']:
|
20 |
+
return string
|
21 |
+
|
22 |
+
return GoogleTranslator(source=params['language string'], target='en').translate(string)
|
23 |
+
|
24 |
+
|
25 |
+
def output_modifier(string):
|
26 |
+
"""
|
27 |
+
This function is applied to the model outputs.
|
28 |
+
"""
|
29 |
+
if not params['activate']:
|
30 |
+
return string
|
31 |
+
|
32 |
+
translated_str = GoogleTranslator(source='en', target=params['language string']).translate(html.unescape(string))
|
33 |
+
return html.escape(translated_str)
|
34 |
+
|
35 |
+
|
36 |
+
def bot_prefix_modifier(string):
|
37 |
+
"""
|
38 |
+
This function is only applied in chat mode. It modifies
|
39 |
+
the prefix text for the Bot and can be used to bias its
|
40 |
+
behavior.
|
41 |
+
"""
|
42 |
+
|
43 |
+
return string
|
44 |
+
|
45 |
+
|
46 |
+
def ui():
|
47 |
+
# Finding the language name from the language code to use as the default value
|
48 |
+
language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])]
|
49 |
+
|
50 |
+
# Gradio elements
|
51 |
+
with gr.Row():
|
52 |
+
activate = gr.Checkbox(value=params['activate'], label='Activate translation')
|
53 |
+
|
54 |
+
with gr.Row():
|
55 |
+
language = gr.Dropdown(value=language_name, choices=[k for k in language_codes], label='Language')
|
56 |
+
|
57 |
+
# Event functions to update the parameters in the backend
|
58 |
+
activate.change(lambda x: params.update({"activate": x}), activate, None)
|
59 |
+
language.change(lambda x: params.update({"language string": language_codes[x]}), language, None)
|
$extensions/long_replies/script.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from modules import chat, shared
|
3 |
+
from modules.text_generation import (
|
4 |
+
decode,
|
5 |
+
encode,
|
6 |
+
generate_reply,
|
7 |
+
)
|
8 |
+
from transformers import LogitsProcessor
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
params = {
|
12 |
+
"display_name": "Long replies",
|
13 |
+
"is_tab": False,
|
14 |
+
"min_length": 120,
|
15 |
+
}
|
16 |
+
|
17 |
+
initial_size = 0
|
18 |
+
|
19 |
+
class MyLogits(LogitsProcessor):
|
20 |
+
"""
|
21 |
+
Manipulates the probabilities for the next token before it gets sampled.
|
22 |
+
Used in the logits_processor_modifier function below.
|
23 |
+
"""
|
24 |
+
def __init__(self):
|
25 |
+
self.newline_id = shared.tokenizer.encode('\n')[-1]
|
26 |
+
pass
|
27 |
+
|
28 |
+
def __call__(self, input_ids, scores):
|
29 |
+
if input_ids.shape[-1] - initial_size < params["min_length"]:
|
30 |
+
scores[...,self.newline_id] = -1000
|
31 |
+
# scores[...,shared.tokenizer.eos_token_id] = -1000
|
32 |
+
|
33 |
+
# probs = torch.softmax(scores, dim=-1, dtype=torch.float)
|
34 |
+
# probs[0] /= probs[0].sum()
|
35 |
+
# scores = torch.log(probs / (1 - probs))
|
36 |
+
return scores
|
37 |
+
|
38 |
+
def history_modifier(history):
|
39 |
+
"""
|
40 |
+
Modifies the chat history.
|
41 |
+
Only used in chat mode.
|
42 |
+
"""
|
43 |
+
return history
|
44 |
+
|
45 |
+
def state_modifier(state):
|
46 |
+
"""
|
47 |
+
Modifies the state variable, which is a dictionary containing the input
|
48 |
+
values in the UI like sliders and checkboxes.
|
49 |
+
"""
|
50 |
+
return state
|
51 |
+
|
52 |
+
def chat_input_modifier(text, visible_text, state):
|
53 |
+
"""
|
54 |
+
Modifies the user input string in chat mode (visible_text).
|
55 |
+
You can also modify the internal representation of the user
|
56 |
+
input (text) to change how it will appear in the prompt.
|
57 |
+
"""
|
58 |
+
return text, visible_text
|
59 |
+
|
60 |
+
def input_modifier(string, state):
|
61 |
+
"""
|
62 |
+
In default/notebook modes, modifies the whole prompt.
|
63 |
+
|
64 |
+
In chat mode, it is the same as chat_input_modifier but only applied
|
65 |
+
to "text", here called "string", and not to "visible_text".
|
66 |
+
"""
|
67 |
+
return string
|
68 |
+
|
69 |
+
def bot_prefix_modifier(string, state):
|
70 |
+
"""
|
71 |
+
Modifies the prefix for the next bot reply in chat mode.
|
72 |
+
By default, the prefix will be something like "Bot Name:".
|
73 |
+
"""
|
74 |
+
return string
|
75 |
+
|
76 |
+
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
|
77 |
+
"""
|
78 |
+
Modifies the input ids and embeds.
|
79 |
+
Used by the multimodal extension to put image embeddings in the prompt.
|
80 |
+
Only used by loaders that use the transformers library for sampling.
|
81 |
+
"""
|
82 |
+
|
83 |
+
global initial_size
|
84 |
+
initial_size = input_ids.shape[-1]
|
85 |
+
|
86 |
+
return prompt, input_ids, input_embeds
|
87 |
+
|
88 |
+
def logits_processor_modifier(processor_list, input_ids):
|
89 |
+
"""
|
90 |
+
Adds logits processors to the list, allowing you to access and modify
|
91 |
+
the next token probabilities.
|
92 |
+
Only used by loaders that use the transformers library for sampling.
|
93 |
+
"""
|
94 |
+
processor_list.append(MyLogits())
|
95 |
+
return processor_list
|
96 |
+
|
97 |
+
def output_modifier(string, state):
|
98 |
+
"""
|
99 |
+
Modifies the LLM output before it gets presented.
|
100 |
+
|
101 |
+
In chat mode, the modified version goes into history['visible'],
|
102 |
+
and the original version goes into history['internal'].
|
103 |
+
"""
|
104 |
+
return string
|
105 |
+
|
106 |
+
def custom_generate_chat_prompt(user_input, state, **kwargs):
|
107 |
+
"""
|
108 |
+
Replaces the function that generates the prompt from the chat history.
|
109 |
+
Only used in chat mode.
|
110 |
+
"""
|
111 |
+
result = chat.generate_chat_prompt(user_input, state, **kwargs)
|
112 |
+
return result
|
113 |
+
|
114 |
+
def custom_css():
|
115 |
+
"""
|
116 |
+
Returns a CSS string that gets appended to the CSS for the webui.
|
117 |
+
"""
|
118 |
+
return ''
|
119 |
+
|
120 |
+
def custom_js():
|
121 |
+
"""
|
122 |
+
Returns a javascript string that gets appended to the javascript
|
123 |
+
for the webui.
|
124 |
+
"""
|
125 |
+
return ''
|
126 |
+
|
127 |
+
def setup():
|
128 |
+
"""
|
129 |
+
Gets executed only once, when the extension is imported.
|
130 |
+
"""
|
131 |
+
pass
|
132 |
+
|
133 |
+
def ui():
|
134 |
+
"""
|
135 |
+
Gets executed when the UI is drawn. Custom gradio elements and
|
136 |
+
their corresponding event handlers should be defined here.
|
137 |
+
|
138 |
+
To learn about gradio components, check out the docs:
|
139 |
+
https://gradio.app/docs/
|
140 |
+
"""
|
141 |
+
|
142 |
+
min_length = gr.Slider(0, 800, step=10, value=params['min_length'], label='Minimum reply length')
|
143 |
+
min_length.change(lambda x: params.update({'min_length': x}), min_length, None)
|
$extensions/multimodal/DOCS.md
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Technical description of multimodal extension
|
2 |
+
|
3 |
+
## Working principle
|
4 |
+
Multimodality extension does most of the stuff which is required for any image input:
|
5 |
+
|
6 |
+
- adds the UI
|
7 |
+
- saves the images as base64 JPEGs to history
|
8 |
+
- provides the hooks to the UI
|
9 |
+
- if there are images in the prompt, it:
|
10 |
+
- splits the prompt to text and image parts
|
11 |
+
- adds image start/end markers to text parts, then encodes and embeds the text parts
|
12 |
+
- calls the vision pipeline to embed the images
|
13 |
+
- stitches the embeddings together, and returns them to text generation
|
14 |
+
- loads the appropriate vision pipeline, selected either from model name, or by specifying --multimodal-pipeline parameter
|
15 |
+
|
16 |
+
Now, for the pipelines, they:
|
17 |
+
|
18 |
+
- load the required vision models
|
19 |
+
- return some consts, for example the number of tokens taken up by image
|
20 |
+
- and most importantly: return the embeddings for LLM, given a list of images
|
21 |
+
|
22 |
+
## Prompts/history
|
23 |
+
|
24 |
+
To save images in prompt/history, this extension is using a base64 JPEG, wrapped in a HTML tag, like so:
|
25 |
+
```
|
26 |
+
<img src="data:image/jpeg;base64,{img_str}">
|
27 |
+
```
|
28 |
+
where `{img_str}` is the actual image data. This format makes displaying them in the UI for free. Do note, that this format is required to be exactly the same, the regex used to find the images is: `<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">`.
|
29 |
+
|
30 |
+
## LLM input
|
31 |
+
To describe the input, let's see it on an example prompt:
|
32 |
+
```
|
33 |
+
text1<image1>text2<image2>text3
|
34 |
+
```
|
35 |
+
where `textN` is N-th text, `<imageN>` is N-th image, in HTML format specified above.
|
36 |
+
|
37 |
+
**The first step is to split the prompt into image/text parts**, so we get:
|
38 |
+
```
|
39 |
+
['text1', '<image1>', 'text2', '<image2>', 'text3']
|
40 |
+
```
|
41 |
+
this is done in `MultimodalEmbedder._split_prompt(...)` function, which returns a list of `PromptPart`s - dataclasses wrapping the separate parts.
|
42 |
+
|
43 |
+
This function also appends the image start/end markers to text, which are provided by `AbstractMultimodalPipeline.image_start()` / `AbstractMultimodalPipeline.image_end()` functions. If image start is `<Img>`, and end is `</Img>`, this function will return:
|
44 |
+
```
|
45 |
+
['text1<Img>', '<image1>', '</Img>text2<Img>', '<image2>', '</Img>text3']
|
46 |
+
```
|
47 |
+
|
48 |
+
**The returned prompt parts are then turned into token embeddings.**
|
49 |
+
|
50 |
+
First, they are modified to token IDs, for the text it is done using standard `modules.text_generation.encode()` function, and for the images the returned token IDs are changed to placeholders. The placeholder is a list of `N` times `placeholder token id`, where `N` is specified using `AbstractMultimodalPipeline.num_image_embeds()`, and placeholder token IDs using `AbstractMultimodalPipeline.placeholder_token_id()`.
|
51 |
+
|
52 |
+
Now, based on the token IDs, the prompt might get truncated, especially if `max_new_tokens` are unreasonably high. Unfortunately, it can't be done simply, just by trimming the prompt to be short enough. This way will lead to sometimes splitting the prompt in the middle of an image embedding, which usually breaks the generation. Therefore, in this case, the entire image needs to be removed from input. This is done inside `MultimodalEmbedder._encode_text(...)` function.
|
53 |
+
|
54 |
+
**After the tokenization, the tokens need to get embedded**, the text and images are once again treated separately.
|
55 |
+
|
56 |
+
The text parts are turned to embeddings, using `AbstractMultimodalPipeline.embed_tokens(...)` function. It uses standard embedding function from the model, but to support many LLMs, the actual function is returned by the pipeline (as it might be different for different LLMs), for LLaMA it is `shared.model.model.embed_tokens(...)`.
|
57 |
+
|
58 |
+
The image parts are turned to embeddings, using `AbstractMultimodalPipeline.embed_images(...)` function. This function is specific for a given pipeline, it takes the images as input, forwards them through vision model/projector, and returns the embeddings.
|
59 |
+
|
60 |
+
**Now, the returned embeddings are stitched together**, using `torch.cat()`, this is creating the final input to the LLM.
|
61 |
+
|
62 |
+
## Pipelines
|
63 |
+
|
64 |
+
All of the pipelines should subclass `AbstractMultimodalPipeline` class. The idea is to allow for new pipelines to be added in the same way as user extensions - git clone into `extensions/multimodal/pipelines`.
|
65 |
+
|
66 |
+
The pipelines are the description of the vision part, containing vision model/multimodal projector. All of the pipelines should have an unique `name()`, which is then selected by user, in `--multimodal-pipeline` CLI argument. For an example, see `pipelines/llava/llava.py`.
|
67 |
+
|
68 |
+
## Pipeline modules
|
69 |
+
|
70 |
+
Pipelines are organized into "pipeline modules" - subdirectories in `pipelines` directory. The pipeline modules should contain a file called `pipelines.py`, that should contain the following fields:
|
71 |
+
- `available_pipelines: List[str]` - list of pipelines provided by this module, shown as the list of available pipelines to the user
|
72 |
+
- `def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]`: - a function to get a concrete pipeline by `name`, if `name` doesn't match any, should return `None`. `params` is the user settings for multimodal extension
|
73 |
+
- `def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]`: - a function to get a pipeline from `model_name`, should be eager to return `None`, unless the determination can be done clearly (for example: minigpt-4 bases on vicuna - it should never return the pipeline, but llava can, as it has its own specific LLM finetune)
|
74 |
+
|
75 |
+
**NOTE**: A pipeline module should lazy-import the pipelines only when necessary, and it should keep its imports to minimum
|
76 |
+
|
77 |
+
## Pipeline params
|
78 |
+
|
79 |
+
The pipelines will get the extension `params` in the constructor. They should honor the following fields:
|
80 |
+
- `vision_device` - string, specifying `torch.device` to run the vision model (CLIP/ViT) on
|
81 |
+
- `vision_bits` - int, number of fp bits to load the vision model(s) in
|
82 |
+
- `projector_device` - string, specifying `torch.device` to run the projector models (Linear layers, QFormer, etc.) on
|
83 |
+
- `projector_bits` - int, number of fp bits to load the projector models in
|
84 |
+
|
85 |
+
As a helper, `AbstractMultimodalPipeline` has `_get_device(self, setting_name: str, params: dict)` and `_get_dtype(self, setting_name: str, params: dict)` helper functions, which parse string/int and return `torch.device` / `torch.dtype`.
|
$extensions/multimodal/README.md
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Multimodal
|
2 |
+
|
3 |
+
## Description
|
4 |
+
|
5 |
+
Adds support for multimodality (text+images) to text-generation-webui.
|
6 |
+
|
7 |
+
Note: multimodal currently only works for transformers, AutoGPTQ, and GPTQ-for-LLaMa loaders. ExLlama (v1 and v2) and llama.cpp support are planned.
|
8 |
+
|
9 |
+
https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b742-3204bb13b8fc.mp4
|
10 |
+
|
11 |
+
## Usage
|
12 |
+
|
13 |
+
To run this extension, download a LLM that supports multimodality, and then start server.py with the appropriate `--multimodal-pipeline` argument. Examples:
|
14 |
+
|
15 |
+
```
|
16 |
+
# LLaVA 1.5 13B has the best performance
|
17 |
+
python server.py --model liuhaotian_llava-v1.5-13b --multimodal-pipeline llava-v1.5-13b --load-in-4bit
|
18 |
+
# LLaVA 1.5 7B is relatively weaker, but requires less memory
|
19 |
+
python server.py --model liuhaotian_llava-v1.5-7b --multimodal-pipeline llava-v1.5-7b --load-in-4bit
|
20 |
+
python server.py --model TheBloke_llava-v1.5-13B-GPTQ_gptq-4bit-32g-actorder_True --multimodal-pipeline llava-v1.5-13b --disable_exllama --loader autogptq
|
21 |
+
python server.py --model wojtab_llava-7b-v0-4bit-128g --multimodal-pipeline llava-7b
|
22 |
+
python server.py --model wojtab_llava-13b-v0-4bit-128g --multimodal-pipeline llava-13b
|
23 |
+
python server.py --model anon8231489123_vicuna-13b-GPTQ-4bit-128g --multimodal-pipeline minigpt4-13b
|
24 |
+
python server.py --model llama-7b-4bit --multimodal-pipeline minigpt4-7b
|
25 |
+
```
|
26 |
+
|
27 |
+
There is built-in support for LLaVA-v0-13B, LLaVA-v0-7b, and LLaVA-v1.5-13B. To install `minigpt4`:
|
28 |
+
|
29 |
+
- clone https://github.com/Wojtab/minigpt-4-pipeline into `extensions/multimodal/pipelines`
|
30 |
+
- install the requirements.txt
|
31 |
+
|
32 |
+
The same procedure should be used to install other pipelines, which can then be used with `--multimodal-pipeline [pipeline name]`. For additional multimodal pipelines refer to the compatibility section below.
|
33 |
+
|
34 |
+
Do note, that each image takes up a considerable amount of tokens, so adjust `max_new_tokens` to be at most 1700 (recommended value is between 200 to 500), so the images don't get truncated.
|
35 |
+
|
36 |
+
To send an image, just upload it to the extension field below chat, and send a prompt as always. The image will be added to the end of your message. If you wish to modify the placement, include a string `<image>` in your prompt.
|
37 |
+
|
38 |
+
Additionally, there is *Embed all images, not only the last one* checkbox. It modifies the image embeddings, by default (if it's unchecked), all but the most recent images have their embeddings empty, so they are not fed to the network. It seems as if some multimodal networks consider the features in all images at the same time as if they were a single image. Due to this behavior, by default, the extension skips previous images. However, it can lead to sub-par generation on other pipelines. If you want to include all images, just tick this checkbox.
|
39 |
+
|
40 |
+
## Compatibility
|
41 |
+
|
42 |
+
As of now, the following multimodal pipelines are supported:
|
43 |
+
|Pipeline|`--multimodal-pipeline`|Default LLM|LLM info(for the linked model)|Pipeline repository|
|
44 |
+
|-|-|-|-|-|
|
45 |
+
|[LLaVA 13B](https://github.com/haotian-liu/LLaVA)|`llava-13b`|[LLaVA 13B](https://huggingface.co/wojtab/llava-13b-v0-4bit-128g)|GPTQ 4-bit quant, old CUDA|built-in|
|
46 |
+
|[LLaVA 7B](https://github.com/haotian-liu/LLaVA)|`llava-7b`|[LLaVA 7B](https://huggingface.co/wojtab/llava-7b-v0-4bit-128g)|GPTQ 4-bit quant, old CUDA|built-in|
|
47 |
+
|[MiniGPT-4 7B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-7b`|[Vicuna v0 7B](https://huggingface.co/TheBloke/vicuna-7B-GPTQ-4bit-128g)|GPTQ 4-bit quant, new format|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)|
|
48 |
+
|[MiniGPT-4 13B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-13b`|[Vicuna v0 13B](https://huggingface.co/anon8231489123/vicuna-13b-GPTQ-4bit-128g)|GPTQ 4-bit quant, old CUDA|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)|
|
49 |
+
|[InstructBLIP 7B](https://github.com/salesforce/LAVIS/tree/main/projects/instructblip)|`instructblip-7b`|[Vicuna v1.1 7B](https://huggingface.co/TheBloke/vicuna-7B-1.1-GPTQ-4bit-128g)|GPTQ 4-bit quant|[kjerk/instructblip-pipeline](https://github.com/kjerk/instructblip-pipeline)|
|
50 |
+
|[InstructBLIP 13B](https://github.com/salesforce/LAVIS/tree/main/projects/instructblip)|`instructblip-13b`|[Vicuna v1.1 13B](https://huggingface.co/TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g)|GPTQ 4-bit quant|[kjerk/instructblip-pipeline](https://github.com/kjerk/instructblip-pipeline)|
|
51 |
+
|
52 |
+
Some pipelines could support different LLMs but do note that while it might work, it isn't a supported configuration.
|
53 |
+
|
54 |
+
DO NOT report bugs if you are using a different LLM.
|
55 |
+
|
56 |
+
DO NOT report bugs with pipelines in this repository (unless they are built-in)
|
57 |
+
|
58 |
+
## Extension config
|
59 |
+
This extension uses the following parameters (from `settings.json`):
|
60 |
+
|Parameter|Description|
|
61 |
+
|---------|-----------|
|
62 |
+
|`multimodal-vision_bits`|Number of bits to load vision models (CLIP/ViT) feature extractor in (most pipelines should support either 32 or 16, default=32)|
|
63 |
+
|`multimodal-vision_device`|Torch device to run the feature extractor on, for example, `cpu` or `cuda:0`, by default `cuda:0` if available|
|
64 |
+
|`multimodal-projector_bits`|Number of bits to load feature projector model(s) in (most pipelines should support either 32 or 16, default=32)|
|
65 |
+
|`multimodal-projector_device`|Torch device to run the feature projector model(s) on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|
66 |
+
|`multimodal-add_all_images_to_prompt`|Default value of "Embed all images, not only the last one" checkbox|
|
67 |
+
|
68 |
+
## Usage through API
|
69 |
+
|
70 |
+
You can run the multimodal inference through API, by inputting the images to prompt. Images are embedded like so: `f'<img src="data:image/jpeg;base64,{img_str}">'`, where `img_str` is base-64 jpeg data. Note that you will need to launch `server.py` with the arguments `--api --extensions multimodal`.
|
71 |
+
|
72 |
+
Python example:
|
73 |
+
|
74 |
+
```Python
|
75 |
+
import base64
|
76 |
+
import requests
|
77 |
+
|
78 |
+
CONTEXT = "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.### Human: Hi!### Assistant: Hi there! How can I help you today?\n"
|
79 |
+
|
80 |
+
with open('extreme_ironing.jpg', 'rb') as f:
|
81 |
+
img_str = base64.b64encode(f.read()).decode('utf-8')
|
82 |
+
prompt = CONTEXT + f'### Human: What is unusual about this image: \n<img src="data:image/jpeg;base64,{img_str}">### Assistant: '
|
83 |
+
print(requests.post('http://127.0.0.1:5000/v1/completions', json={'prompt': prompt, 'max_tokens': 200, 'stop': ['\n###']}).json())
|
84 |
+
```
|
85 |
+
script output:
|
86 |
+
```Python
|
87 |
+
{'results': [{'text': "The unusual aspect of this image is that a man is standing on top of a yellow minivan while doing his laundry. He has set up a makeshift clothes line using the car's rooftop as an outdoor drying area. This scene is uncommon because people typically do their laundry indoors, in a dedicated space like a laundromat or a room in their home, rather than on top of a moving vehicle. Additionally, hanging clothes on the car could be potentially hazardous or illegal in some jurisdictions due to the risk of damaging the vehicle or causing accidents on the road.\n##"}]}
|
88 |
+
```
|
89 |
+
|
90 |
+
## For pipeline developers/technical description
|
91 |
+
see [DOCS.md](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/multimodal/DOCS.md)
|
$extensions/multimodal/abstract_pipeline.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import List, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
from transformers import is_torch_xpu_available
|
7 |
+
|
8 |
+
|
9 |
+
class AbstractMultimodalPipeline(ABC):
|
10 |
+
@staticmethod
|
11 |
+
@abstractmethod
|
12 |
+
def name() -> str:
|
13 |
+
'name of the pipeline, should be same as in --multimodal-pipeline'
|
14 |
+
pass
|
15 |
+
|
16 |
+
@staticmethod
|
17 |
+
@abstractmethod
|
18 |
+
def image_start() -> Optional[str]:
|
19 |
+
'return image start string, string representation of image start token, or None if not applicable'
|
20 |
+
pass
|
21 |
+
|
22 |
+
@staticmethod
|
23 |
+
@abstractmethod
|
24 |
+
def image_end() -> Optional[str]:
|
25 |
+
'return image end string, string representation of image end token, or None if not applicable'
|
26 |
+
pass
|
27 |
+
|
28 |
+
@staticmethod
|
29 |
+
@abstractmethod
|
30 |
+
def placeholder_token_id() -> int:
|
31 |
+
'return placeholder token id'
|
32 |
+
pass
|
33 |
+
|
34 |
+
@staticmethod
|
35 |
+
@abstractmethod
|
36 |
+
def num_image_embeds() -> int:
|
37 |
+
'return the number of embeds used by a single image (for example: 256 for LLaVA)'
|
38 |
+
pass
|
39 |
+
|
40 |
+
@abstractmethod
|
41 |
+
def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
|
42 |
+
'forward the images through vision pipeline, and return their embeddings'
|
43 |
+
pass
|
44 |
+
|
45 |
+
@staticmethod
|
46 |
+
@abstractmethod
|
47 |
+
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
|
48 |
+
'embed tokens, the exact function varies by LLM, for LLaMA it is `shared.model.model.embed_tokens`'
|
49 |
+
pass
|
50 |
+
|
51 |
+
@staticmethod
|
52 |
+
@abstractmethod
|
53 |
+
def placeholder_embeddings() -> torch.Tensor:
|
54 |
+
'get placeholder embeddings if there are multiple images, and `add_all_images_to_prompt` is False'
|
55 |
+
pass
|
56 |
+
|
57 |
+
def _get_device(self, setting_name: str, params: dict):
|
58 |
+
if params[setting_name] is None:
|
59 |
+
return torch.device("cuda:0" if torch.cuda.is_available() else "xpu:0" if is_torch_xpu_available() else "cpu")
|
60 |
+
return torch.device(params[setting_name])
|
61 |
+
|
62 |
+
def _get_dtype(self, setting_name: str, params: dict):
|
63 |
+
return torch.float32 if int(params[setting_name]) == 32 else torch.float16
|
$extensions/multimodal/multimodal_embedder.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import re
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from io import BytesIO
|
5 |
+
from typing import Any, List, Optional
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from extensions.multimodal.pipeline_loader import load_pipeline
|
11 |
+
from modules import shared
|
12 |
+
from modules.logging_colors import logger
|
13 |
+
from modules.text_generation import encode, get_max_prompt_length
|
14 |
+
|
15 |
+
|
16 |
+
@dataclass
|
17 |
+
class PromptPart:
|
18 |
+
text: str
|
19 |
+
image: Optional[Image.Image] = None
|
20 |
+
is_image: bool = False
|
21 |
+
input_ids: Optional[torch.Tensor] = None
|
22 |
+
embedding: Optional[torch.Tensor] = None
|
23 |
+
|
24 |
+
|
25 |
+
class MultimodalEmbedder:
|
26 |
+
def __init__(self, params: dict):
|
27 |
+
pipeline, source = load_pipeline(params)
|
28 |
+
self.pipeline = pipeline
|
29 |
+
logger.info(f'Multimodal: loaded pipeline {self.pipeline.name()} from pipelines/{source} ({self.pipeline.__class__.__name__})')
|
30 |
+
|
31 |
+
def _split_prompt(self, prompt: str, load_images: bool = False) -> List[PromptPart]:
|
32 |
+
"""Splits a prompt into a list of `PromptParts` to separate image data from text.
|
33 |
+
It will also append `image_start` and `image_end` before and after the image, and optionally parse and load the images,
|
34 |
+
if `load_images` is `True`.
|
35 |
+
"""
|
36 |
+
parts: List[PromptPart] = []
|
37 |
+
curr = 0
|
38 |
+
while True:
|
39 |
+
match = re.search(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', prompt[curr:])
|
40 |
+
if match is None:
|
41 |
+
# no more image tokens, append the rest of the prompt
|
42 |
+
if curr > 0:
|
43 |
+
# add image end token after last image
|
44 |
+
parts.append(PromptPart(text=self.pipeline.image_end() + prompt[curr:]))
|
45 |
+
else:
|
46 |
+
parts.append(PromptPart(text=prompt))
|
47 |
+
break
|
48 |
+
# found an image, append image start token to the text
|
49 |
+
if match.start() > 0:
|
50 |
+
parts.append(PromptPart(text=prompt[curr:curr + match.start()] + self.pipeline.image_start()))
|
51 |
+
else:
|
52 |
+
parts.append(PromptPart(text=self.pipeline.image_start()))
|
53 |
+
# append the image
|
54 |
+
parts.append(PromptPart(
|
55 |
+
text=match.group(0),
|
56 |
+
image=Image.open(BytesIO(base64.b64decode(match.group(1)))) if load_images else None,
|
57 |
+
is_image=True
|
58 |
+
))
|
59 |
+
curr += match.end()
|
60 |
+
return parts
|
61 |
+
|
62 |
+
def _len_in_tokens_prompt_parts(self, parts: List[PromptPart]) -> int:
|
63 |
+
"""Total length in tokens of all `parts`"""
|
64 |
+
tokens = 0
|
65 |
+
for part in parts:
|
66 |
+
if part.is_image:
|
67 |
+
tokens += self.pipeline.num_image_embeds()
|
68 |
+
elif part.input_ids is not None:
|
69 |
+
tokens += len(part.input_ids)
|
70 |
+
else:
|
71 |
+
tokens += len(encode(part.text)[0])
|
72 |
+
return tokens
|
73 |
+
|
74 |
+
def len_in_tokens(self, prompt: str) -> int:
|
75 |
+
"""Total length in tokens for a given text `prompt`"""
|
76 |
+
parts = self._split_prompt(prompt, False)
|
77 |
+
return self._len_in_tokens_prompt_parts(parts)
|
78 |
+
|
79 |
+
def _encode_single_text(self, part: PromptPart, add_bos_token: bool) -> PromptPart:
|
80 |
+
"""Encode a single prompt `part` to `input_ids`. Returns a `PromptPart`"""
|
81 |
+
if part.is_image:
|
82 |
+
placeholders = torch.ones((self.pipeline.num_image_embeds())) * self.pipeline.placeholder_token_id()
|
83 |
+
part.input_ids = placeholders.to(shared.model.device, dtype=torch.int64)
|
84 |
+
else:
|
85 |
+
part.input_ids = encode(part.text, add_bos_token=add_bos_token)[0].to(shared.model.device, dtype=torch.int64)
|
86 |
+
return part
|
87 |
+
|
88 |
+
@staticmethod
|
89 |
+
def _num_images(parts: List[PromptPart]) -> int:
|
90 |
+
count = 0
|
91 |
+
for part in parts:
|
92 |
+
if part.is_image:
|
93 |
+
count += 1
|
94 |
+
return count
|
95 |
+
|
96 |
+
def _encode_text(self, state, parts: List[PromptPart]) -> List[PromptPart]:
|
97 |
+
"""Encode text to token_ids, also truncate the prompt, if necessary.
|
98 |
+
|
99 |
+
The chat/instruct mode should make prompts that fit in get_max_prompt_length, but if max_new_tokens are set
|
100 |
+
such that the context + min_rows don't fit, we can get a prompt which is too long.
|
101 |
+
We can't truncate image embeddings, as it leads to broken generation, so remove the images instead and warn the user
|
102 |
+
"""
|
103 |
+
encoded: List[PromptPart] = []
|
104 |
+
for i, part in enumerate(parts):
|
105 |
+
encoded.append(self._encode_single_text(part, i == 0 and state['add_bos_token']))
|
106 |
+
|
107 |
+
# truncation:
|
108 |
+
max_len = get_max_prompt_length(state)
|
109 |
+
removed_images = 0
|
110 |
+
|
111 |
+
# 1. remove entire text/image blocks
|
112 |
+
while self._len_in_tokens_prompt_parts(encoded[1:]) > max_len:
|
113 |
+
if encoded[0].is_image:
|
114 |
+
removed_images += 1
|
115 |
+
encoded = encoded[1:]
|
116 |
+
|
117 |
+
# 2. check if the last prompt part doesn't need to get truncated
|
118 |
+
if self._len_in_tokens_prompt_parts(encoded) > max_len:
|
119 |
+
if encoded[0].is_image:
|
120 |
+
# don't truncate image embeddings, just remove the image, otherwise generation will be broken
|
121 |
+
removed_images += 1
|
122 |
+
encoded = encoded[1:]
|
123 |
+
elif len(encoded) > 1 and encoded[0].text.endswith(self.pipeline.image_start()):
|
124 |
+
# see if we can keep image_start token
|
125 |
+
len_image_start = len(encode(self.pipeline.image_start(), add_bos_token=state['add_bos_token'])[0])
|
126 |
+
if self._len_in_tokens_prompt_parts(encoded[1:]) + len_image_start > max_len:
|
127 |
+
# we can't -> remove this text, and the image
|
128 |
+
encoded = encoded[2:]
|
129 |
+
removed_images += 1
|
130 |
+
else:
|
131 |
+
# we can -> just truncate the text
|
132 |
+
trunc_len = self._len_in_tokens_prompt_parts(encoded) - max_len
|
133 |
+
encoded[0].input_ids = encoded[0].input_ids[trunc_len:]
|
134 |
+
elif len(encoded) > 0:
|
135 |
+
# only one text left, truncate it normally
|
136 |
+
trunc_len = self._len_in_tokens_prompt_parts(encoded) - max_len
|
137 |
+
encoded[0].input_ids = encoded[0].input_ids[trunc_len:]
|
138 |
+
|
139 |
+
# notify user if we truncated an image
|
140 |
+
if removed_images > 0:
|
141 |
+
logger.warning(f"Multimodal: removed {removed_images} image(s) from prompt. Try decreasing max_new_tokens if generation is broken")
|
142 |
+
|
143 |
+
return encoded
|
144 |
+
|
145 |
+
def _embed(self, parts: List[PromptPart]) -> List[PromptPart]:
|
146 |
+
# batch images
|
147 |
+
image_indicies = [i for i, part in enumerate(parts) if part.is_image]
|
148 |
+
embedded = self.pipeline.embed_images([parts[i].image for i in image_indicies])
|
149 |
+
for i, embeds in zip(image_indicies, embedded):
|
150 |
+
parts[i].embedding = embeds
|
151 |
+
# embed text
|
152 |
+
for (i, part) in enumerate(parts):
|
153 |
+
if not part.is_image:
|
154 |
+
parts[i].embedding = self.pipeline.embed_tokens(part.input_ids)
|
155 |
+
return parts
|
156 |
+
|
157 |
+
def _remove_old_images(self, parts: List[PromptPart], params: dict) -> List[PromptPart]:
|
158 |
+
if params['add_all_images_to_prompt']:
|
159 |
+
return parts
|
160 |
+
already_added = False
|
161 |
+
for i, part in reversed(list(enumerate(parts))):
|
162 |
+
if part.is_image:
|
163 |
+
if already_added:
|
164 |
+
parts[i].embedding = self.pipeline.placeholder_embeddings()
|
165 |
+
else:
|
166 |
+
already_added = True
|
167 |
+
return parts
|
168 |
+
|
169 |
+
def forward(self, prompt: str, state: Any, params: dict):
|
170 |
+
prompt_parts = self._split_prompt(prompt, True)
|
171 |
+
prompt_parts = self._encode_text(state, prompt_parts)
|
172 |
+
prompt_parts = self._embed(prompt_parts)
|
173 |
+
prompt_parts = self._remove_old_images(prompt_parts, params)
|
174 |
+
embeds = tuple(part.embedding for part in prompt_parts)
|
175 |
+
ids = tuple(part.input_ids for part in prompt_parts)
|
176 |
+
input_embeds = torch.cat(embeds, dim=0)
|
177 |
+
input_ids = torch.cat(ids, dim=0)
|
178 |
+
return prompt, input_ids, input_embeds, self._num_images(prompt_parts)
|
$extensions/multimodal/pipeline_loader.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import traceback
|
2 |
+
from importlib import import_module
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Tuple
|
5 |
+
|
6 |
+
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
7 |
+
from modules import shared
|
8 |
+
from modules.logging_colors import logger
|
9 |
+
|
10 |
+
|
11 |
+
def _get_available_pipeline_modules():
|
12 |
+
pipeline_path = Path(__file__).parent / 'pipelines'
|
13 |
+
modules = [p for p in pipeline_path.iterdir() if p.is_dir()]
|
14 |
+
return [m.name for m in modules if (m / 'pipelines.py').exists()]
|
15 |
+
|
16 |
+
|
17 |
+
def load_pipeline(params: dict) -> Tuple[AbstractMultimodalPipeline, str]:
|
18 |
+
pipeline_modules = {}
|
19 |
+
available_pipeline_modules = _get_available_pipeline_modules()
|
20 |
+
for name in available_pipeline_modules:
|
21 |
+
try:
|
22 |
+
pipeline_modules[name] = import_module(f'extensions.multimodal.pipelines.{name}.pipelines')
|
23 |
+
except:
|
24 |
+
logger.warning(f'Failed to get multimodal pipelines from {name}')
|
25 |
+
logger.warning(traceback.format_exc())
|
26 |
+
|
27 |
+
if shared.args.multimodal_pipeline is not None:
|
28 |
+
for k in pipeline_modules:
|
29 |
+
if hasattr(pipeline_modules[k], 'get_pipeline'):
|
30 |
+
pipeline = getattr(pipeline_modules[k], 'get_pipeline')(shared.args.multimodal_pipeline, params)
|
31 |
+
if pipeline is not None:
|
32 |
+
return (pipeline, k)
|
33 |
+
else:
|
34 |
+
model_name = shared.args.model.lower()
|
35 |
+
for k in pipeline_modules:
|
36 |
+
if hasattr(pipeline_modules[k], 'get_pipeline_from_model_name'):
|
37 |
+
pipeline = getattr(pipeline_modules[k], 'get_pipeline_from_model_name')(model_name, params)
|
38 |
+
if pipeline is not None:
|
39 |
+
return (pipeline, k)
|
40 |
+
|
41 |
+
available = []
|
42 |
+
for k in pipeline_modules:
|
43 |
+
if hasattr(pipeline_modules[k], 'available_pipelines'):
|
44 |
+
pipelines = getattr(pipeline_modules[k], 'available_pipelines')
|
45 |
+
available += pipelines
|
46 |
+
|
47 |
+
if shared.args.multimodal_pipeline is not None:
|
48 |
+
log = f'Multimodal - ERROR: Failed to load multimodal pipeline "{shared.args.multimodal_pipeline}", available pipelines are: {available}.'
|
49 |
+
else:
|
50 |
+
log = f'Multimodal - ERROR: Failed to determine multimodal pipeline for model {shared.args.model}, please select one manually using --multimodal-pipeline [PIPELINE]. Available pipelines are: {available}.'
|
51 |
+
logger.critical(f'{log} Please specify a correct pipeline, or disable the extension')
|
52 |
+
raise RuntimeError(f'{log} Please specify a correct pipeline, or disable the extension')
|
$extensions/multimodal/pipelines/llava/README.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## LLaVA pipeline
|
2 |
+
|
3 |
+
This module provides 2 pipelines:
|
4 |
+
- `llava-7b` - for use with LLaVA v0 7B model (finetuned LLaMa 7B)
|
5 |
+
- `llava-13b` - for use with LLaVA v0 13B model (finetuned LLaMa 13B)
|
6 |
+
|
7 |
+
[LLaVA](https://github.com/haotian-liu/LLaVA) uses CLIP `openai/clip-vit-large-patch14` as the vision model, and then a single linear layer. For 13B the projector weights are in `liuhaotian/LLaVA-13b-delta-v0`, and for 7B they are in `liuhaotian/LLaVA-7b-delta-v0`.
|
8 |
+
|
9 |
+
The supported parameter combinations for both the vision model, and the projector are: CUDA/32bit, CUDA/16bit, CPU/32bit
|
$extensions/multimodal/pipelines/llava/llava.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from abc import abstractmethod
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
from PIL import Image
|
8 |
+
from transformers import CLIPImageProcessor, CLIPVisionModel
|
9 |
+
|
10 |
+
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
11 |
+
from modules import shared
|
12 |
+
from modules.logging_colors import logger
|
13 |
+
from modules.text_generation import encode
|
14 |
+
|
15 |
+
|
16 |
+
def expand2square(pil_img: Image.Image, background_color: Tuple[int]) -> Image.Image:
|
17 |
+
width, height = pil_img.size
|
18 |
+
if width == height:
|
19 |
+
return pil_img
|
20 |
+
elif width > height:
|
21 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
22 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
23 |
+
return result
|
24 |
+
else:
|
25 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
26 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
27 |
+
return result
|
28 |
+
|
29 |
+
|
30 |
+
class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
|
31 |
+
CLIP_REPO = "openai/clip-vit-large-patch14"
|
32 |
+
|
33 |
+
def __init__(self, params: dict) -> None:
|
34 |
+
super().__init__()
|
35 |
+
self.clip_device = self._get_device("vision_device", params)
|
36 |
+
self.clip_dtype = self._get_dtype("vision_bits", params)
|
37 |
+
self.projector_device = self._get_device("projector_device", params)
|
38 |
+
self.projector_dtype = self._get_dtype("projector_bits", params)
|
39 |
+
self.image_processor, self.vision_tower, self.mm_projector = self._load_models()
|
40 |
+
|
41 |
+
def _load_models(self):
|
42 |
+
start_ts = time.time()
|
43 |
+
|
44 |
+
logger.info(f"LLaVA - Loading CLIP from {self.CLIP_REPO} as {self.clip_dtype} on {self.clip_device}...")
|
45 |
+
image_processor = CLIPImageProcessor.from_pretrained(self.CLIP_REPO, torch_dtype=self.clip_dtype)
|
46 |
+
vision_tower = CLIPVisionModel.from_pretrained(self.CLIP_REPO, torch_dtype=self.clip_dtype).to(self.clip_device)
|
47 |
+
|
48 |
+
logger.info(f"LLaVA - Loading projector from {self.llava_projector_repo()} as {self.projector_dtype} on {self.projector_device}...")
|
49 |
+
projector_path = hf_hub_download(self.llava_projector_repo(), self.llava_projector_filename())
|
50 |
+
mm_projector = self.build_mm_projector()
|
51 |
+
projector_data = torch.load(projector_path)
|
52 |
+
projector_data = {k[19:]: v for k, v in projector_data.items() if k.startswith('model.mm_projector.')}
|
53 |
+
mm_projector.load_state_dict(projector_data)
|
54 |
+
mm_projector = mm_projector.to(self.projector_device)
|
55 |
+
|
56 |
+
logger.info(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
|
57 |
+
return image_processor, vision_tower, mm_projector
|
58 |
+
|
59 |
+
def build_mm_projector(self) -> torch.nn.Module:
|
60 |
+
projector_shape = self.llava_projector_shape()
|
61 |
+
if len(projector_shape) == 2:
|
62 |
+
return torch.nn.Linear(*projector_shape)
|
63 |
+
else:
|
64 |
+
modules = []
|
65 |
+
modules.append(torch.nn.Linear(projector_shape[0], projector_shape[1]))
|
66 |
+
for i in range(2, len(projector_shape)):
|
67 |
+
modules.append(torch.nn.GELU())
|
68 |
+
modules.append(torch.nn.Linear(projector_shape[i-1], projector_shape[i]))
|
69 |
+
return torch.nn.Sequential(*modules)
|
70 |
+
|
71 |
+
@staticmethod
|
72 |
+
def image_start() -> str:
|
73 |
+
return "<im_start>"
|
74 |
+
|
75 |
+
@staticmethod
|
76 |
+
def image_end() -> str:
|
77 |
+
return "<im_end>"
|
78 |
+
|
79 |
+
@staticmethod
|
80 |
+
def num_image_embeds() -> int:
|
81 |
+
return 256
|
82 |
+
|
83 |
+
@staticmethod
|
84 |
+
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
|
85 |
+
for attr in ['', 'model', 'model.model', 'model.model.model']:
|
86 |
+
tmp = getattr(shared.model, attr, None) if attr != '' else shared.model
|
87 |
+
if tmp is not None and hasattr(tmp, 'embed_tokens'):
|
88 |
+
func = tmp.embed_tokens
|
89 |
+
break
|
90 |
+
else:
|
91 |
+
raise ValueError('The embed_tokens method has not been found for this loader.')
|
92 |
+
|
93 |
+
return func(input_ids).to(shared.model.device, dtype=shared.model.dtype)
|
94 |
+
|
95 |
+
@staticmethod
|
96 |
+
def placeholder_embeddings() -> torch.Tensor:
|
97 |
+
return LLaVA_v0_Pipeline.embed_tokens(encode("<im_patch>"*256, add_bos_token=False)[0])
|
98 |
+
|
99 |
+
def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
|
100 |
+
images = self.image_processor(images, return_tensors='pt')['pixel_values']
|
101 |
+
images = images.to(self.clip_device, dtype=self.clip_dtype)
|
102 |
+
|
103 |
+
with torch.no_grad():
|
104 |
+
image_forward_outs = self.vision_tower(images, output_hidden_states=True)
|
105 |
+
select_hidden_state_layer = -2
|
106 |
+
select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
|
107 |
+
image_features = select_hidden_state[:, 1:].to(self.projector_device, dtype=self.projector_dtype)
|
108 |
+
image_features = self.mm_projector(image_features)
|
109 |
+
return image_features.to(shared.model.device, dtype=shared.model.dtype)
|
110 |
+
|
111 |
+
@staticmethod
|
112 |
+
@abstractmethod
|
113 |
+
def llava_projector_repo() -> str:
|
114 |
+
pass
|
115 |
+
|
116 |
+
@staticmethod
|
117 |
+
@abstractmethod
|
118 |
+
def llava_projector_filename() -> str:
|
119 |
+
pass
|
120 |
+
|
121 |
+
@staticmethod
|
122 |
+
@abstractmethod
|
123 |
+
def llava_projector_shape() -> Tuple[int, int]:
|
124 |
+
pass
|
125 |
+
|
126 |
+
|
127 |
+
class LLaVA_v0_13B_Pipeline(LLaVA_v0_Pipeline):
|
128 |
+
def __init__(self, params: dict) -> None:
|
129 |
+
super().__init__(params)
|
130 |
+
|
131 |
+
@staticmethod
|
132 |
+
def name() -> str:
|
133 |
+
return "llava-13b"
|
134 |
+
|
135 |
+
@staticmethod
|
136 |
+
def placeholder_token_id() -> int:
|
137 |
+
return 32000
|
138 |
+
|
139 |
+
@staticmethod
|
140 |
+
def llava_projector_shape() -> Tuple[int, int]:
|
141 |
+
return (1024, 5120)
|
142 |
+
|
143 |
+
@staticmethod
|
144 |
+
def llava_projector_filename() -> str:
|
145 |
+
return "mm_projector.bin"
|
146 |
+
|
147 |
+
@staticmethod
|
148 |
+
def llava_projector_repo() -> str:
|
149 |
+
return "liuhaotian/LLaVA-13b-delta-v0"
|
150 |
+
|
151 |
+
|
152 |
+
class LLaVA_v0_7B_Pipeline(LLaVA_v0_Pipeline):
|
153 |
+
def __init__(self, params: dict) -> None:
|
154 |
+
super().__init__(params)
|
155 |
+
|
156 |
+
@staticmethod
|
157 |
+
def name() -> str:
|
158 |
+
return "llava-7b"
|
159 |
+
|
160 |
+
@staticmethod
|
161 |
+
def placeholder_token_id() -> int:
|
162 |
+
return 32001
|
163 |
+
|
164 |
+
@staticmethod
|
165 |
+
def llava_projector_shape() -> Tuple[int, int]:
|
166 |
+
return (1024, 4096)
|
167 |
+
|
168 |
+
@staticmethod
|
169 |
+
def llava_projector_filename() -> str:
|
170 |
+
return "mm_projector.bin"
|
171 |
+
|
172 |
+
@staticmethod
|
173 |
+
def llava_projector_repo() -> str:
|
174 |
+
return "liuhaotian/LLaVA-7b-delta-v0"
|
175 |
+
|
176 |
+
|
177 |
+
class LLaVA_LLaMA_2_13B_Pipeline(LLaVA_v0_13B_Pipeline):
|
178 |
+
def __init__(self, params: dict) -> None:
|
179 |
+
super().__init__(params)
|
180 |
+
|
181 |
+
@staticmethod
|
182 |
+
def name() -> str:
|
183 |
+
return "llava-llama-2-13b"
|
184 |
+
|
185 |
+
@staticmethod
|
186 |
+
def placeholder_token_id() -> int:
|
187 |
+
return 0
|
188 |
+
|
189 |
+
@staticmethod
|
190 |
+
def llava_projector_repo() -> str:
|
191 |
+
return "liuhaotian/llava-llama-2-13b-chat-lightning-preview"
|
192 |
+
|
193 |
+
@staticmethod
|
194 |
+
def image_start() -> str:
|
195 |
+
return ""
|
196 |
+
|
197 |
+
@staticmethod
|
198 |
+
def image_end() -> str:
|
199 |
+
return ""
|
200 |
+
|
201 |
+
@staticmethod
|
202 |
+
def placeholder_embeddings() -> torch.Tensor:
|
203 |
+
return LLaVA_v0_Pipeline.embed_tokens(encode("<unk>"*256, add_bos_token=False)[0])
|
204 |
+
|
205 |
+
|
206 |
+
class LLaVA_v1_5_13B_Pipeline(LLaVA_v0_13B_Pipeline):
|
207 |
+
CLIP_REPO = "openai/clip-vit-large-patch14-336"
|
208 |
+
|
209 |
+
def __init__(self, params: dict) -> None:
|
210 |
+
super().__init__(params)
|
211 |
+
|
212 |
+
@staticmethod
|
213 |
+
def name() -> str:
|
214 |
+
return "llava-v1.5-13b"
|
215 |
+
|
216 |
+
@staticmethod
|
217 |
+
def llava_projector_shape() -> Tuple[int, int]:
|
218 |
+
return (1024, 5120, 5120)
|
219 |
+
|
220 |
+
@staticmethod
|
221 |
+
def placeholder_token_id() -> int:
|
222 |
+
return 0
|
223 |
+
|
224 |
+
@staticmethod
|
225 |
+
def llava_projector_repo() -> str:
|
226 |
+
return "liuhaotian/llava-v1.5-13b"
|
227 |
+
|
228 |
+
@staticmethod
|
229 |
+
def image_start() -> str:
|
230 |
+
return ""
|
231 |
+
|
232 |
+
@staticmethod
|
233 |
+
def image_end() -> str:
|
234 |
+
return ""
|
235 |
+
|
236 |
+
@staticmethod
|
237 |
+
def num_image_embeds() -> int:
|
238 |
+
return 576
|
239 |
+
|
240 |
+
def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
|
241 |
+
# pad it to square first
|
242 |
+
images = [
|
243 |
+
expand2square(image, tuple(int(x*255) for x in self.image_processor.image_mean))
|
244 |
+
for image in images
|
245 |
+
]
|
246 |
+
return super().embed_images(images)
|
247 |
+
|
248 |
+
@staticmethod
|
249 |
+
def placeholder_embeddings() -> torch.Tensor:
|
250 |
+
return LLaVA_v0_Pipeline.embed_tokens(encode("<unk>"*576, add_bos_token=False)[0])
|
251 |
+
|
252 |
+
class LLaVA_v1_5_7B_Pipeline(LLaVA_v1_5_13B_Pipeline):
|
253 |
+
@staticmethod
|
254 |
+
def name() -> str:
|
255 |
+
return "llava-v1.5-7b"
|
256 |
+
|
257 |
+
@staticmethod
|
258 |
+
def llava_projector_shape() -> Tuple[int, int]:
|
259 |
+
return (1024, 4096, 4096)
|
260 |
+
@staticmethod
|
261 |
+
def llava_projector_repo() -> str:
|
262 |
+
return "liuhaotian/llava-v1.5-7b"
|
$extensions/multimodal/pipelines/llava/pipelines.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
4 |
+
|
5 |
+
available_pipelines = ['llava-7b', 'llava-13b', 'llava-llama-2-13b', 'llava-v1.5-13b', 'llava-v1.5-7b']
|
6 |
+
|
7 |
+
|
8 |
+
def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
|
9 |
+
if name == 'llava-7b':
|
10 |
+
from .llava import LLaVA_v0_7B_Pipeline
|
11 |
+
return LLaVA_v0_7B_Pipeline(params)
|
12 |
+
if name == 'llava-13b':
|
13 |
+
from .llava import LLaVA_v0_13B_Pipeline
|
14 |
+
return LLaVA_v0_13B_Pipeline(params)
|
15 |
+
if name == 'llava-llama-2-13b':
|
16 |
+
from .llava import LLaVA_LLaMA_2_13B_Pipeline
|
17 |
+
return LLaVA_LLaMA_2_13B_Pipeline(params)
|
18 |
+
if name == 'llava-v1.5-7b':
|
19 |
+
from .llava import LLaVA_v1_5_7B_Pipeline
|
20 |
+
return LLaVA_v1_5_7B_Pipeline(params)
|
21 |
+
if name == 'llava-v1.5-13b':
|
22 |
+
from .llava import LLaVA_v1_5_13B_Pipeline
|
23 |
+
return LLaVA_v1_5_13B_Pipeline(params)
|
24 |
+
return None
|
25 |
+
|
26 |
+
|
27 |
+
def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
|
28 |
+
if 'llava' not in model_name.lower():
|
29 |
+
return None
|
30 |
+
if 'llama-2' in model_name.lower():
|
31 |
+
if '13b' in model_name.lower():
|
32 |
+
from .llava import LLaVA_LLaMA_2_13B_Pipeline
|
33 |
+
return LLaVA_LLaMA_2_13B_Pipeline(params)
|
34 |
+
elif 'llava-v1.5' in model_name.lower():
|
35 |
+
if '13b' in model_name.lower():
|
36 |
+
from .llava import LLaVA_v1_5_13B_Pipeline
|
37 |
+
return LLaVA_v1_5_13B_Pipeline(params)
|
38 |
+
if '7b' in model_name.lower():
|
39 |
+
from .llava import LLaVA_v1_5_7B_Pipeline
|
40 |
+
return LLaVA_v1_5_7B_Pipeline(params)
|
41 |
+
else:
|
42 |
+
if '7b' in model_name.lower():
|
43 |
+
from .llava import LLaVA_v0_7B_Pipeline
|
44 |
+
return LLaVA_v0_7B_Pipeline(params)
|
45 |
+
if '13b' in model_name.lower():
|
46 |
+
from .llava import LLaVA_v0_13B_Pipeline
|
47 |
+
return LLaVA_v0_13B_Pipeline(params)
|
48 |
+
return None
|
$extensions/multimodal/pipelines/place-additional-pipelines-here.txt
ADDED
File without changes
|
$extensions/multimodal/script.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import re
|
3 |
+
import time
|
4 |
+
from functools import partial
|
5 |
+
from io import BytesIO
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
|
11 |
+
from modules import shared
|
12 |
+
from modules.logging_colors import logger
|
13 |
+
|
14 |
+
params = {
|
15 |
+
"add_all_images_to_prompt": False,
|
16 |
+
# device to run vision encoder on
|
17 |
+
"vision_device": None,
|
18 |
+
# bits to load vision encoder in, either 16 or 32
|
19 |
+
"vision_bits": 32,
|
20 |
+
# device to run multimodal projector on
|
21 |
+
"projector_device": None,
|
22 |
+
# multimodal projector bits, either 32 or 16
|
23 |
+
"projector_bits": 32
|
24 |
+
}
|
25 |
+
|
26 |
+
|
27 |
+
# If 'state' is True, will hijack the next chat generation
|
28 |
+
input_hijack = {
|
29 |
+
'state': False,
|
30 |
+
'value': ["", ""]
|
31 |
+
}
|
32 |
+
|
33 |
+
|
34 |
+
# initialized in ui, so that params are loaded from settings
|
35 |
+
multimodal_embedder: MultimodalEmbedder = None
|
36 |
+
|
37 |
+
|
38 |
+
def chat_input_modifier(text, visible_text, state):
|
39 |
+
global input_hijack
|
40 |
+
if input_hijack['state']:
|
41 |
+
input_hijack['state'] = False
|
42 |
+
return input_hijack['value'](text, visible_text)
|
43 |
+
else:
|
44 |
+
return text, visible_text
|
45 |
+
|
46 |
+
|
47 |
+
def add_chat_picture(picture, text, visible_text):
|
48 |
+
# resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable)
|
49 |
+
# Adjusted to 336 for the values here, due to the increased resolution in llava-v1.5
|
50 |
+
max_hw, min_hw = max(picture.size), min(picture.size)
|
51 |
+
aspect_ratio = max_hw / min_hw
|
52 |
+
shortest_edge = int(max(336 / aspect_ratio, 336))
|
53 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
54 |
+
w = shortest_edge if picture.width < picture.height else longest_edge
|
55 |
+
h = shortest_edge if picture.width >= picture.height else longest_edge
|
56 |
+
picture = picture.resize((w, h))
|
57 |
+
|
58 |
+
buffer = BytesIO()
|
59 |
+
picture.save(buffer, format="PNG")
|
60 |
+
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
61 |
+
image = f'<img src="data:image/jpeg;base64,{img_str}">'
|
62 |
+
|
63 |
+
if '<image>' in text:
|
64 |
+
text = text.replace('<image>', image)
|
65 |
+
else:
|
66 |
+
text = image + '\n' + text
|
67 |
+
|
68 |
+
if visible_text == '' or visible_text is None:
|
69 |
+
visible_text = text
|
70 |
+
elif '<image>' in visible_text:
|
71 |
+
visible_text = visible_text.replace('<image>', image)
|
72 |
+
else:
|
73 |
+
visible_text = visible_text + '\n' + image
|
74 |
+
|
75 |
+
return text, visible_text
|
76 |
+
|
77 |
+
|
78 |
+
def custom_tokenized_length(prompt):
|
79 |
+
return multimodal_embedder.len_in_tokens(prompt)
|
80 |
+
|
81 |
+
|
82 |
+
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
|
83 |
+
global params
|
84 |
+
start_ts = time.time()
|
85 |
+
image_match = re.search(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', prompt)
|
86 |
+
|
87 |
+
if image_match is None:
|
88 |
+
return prompt, input_ids, input_embeds
|
89 |
+
|
90 |
+
prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
|
91 |
+
logger.info(f'Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
|
92 |
+
return (prompt,
|
93 |
+
input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
|
94 |
+
input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))
|
95 |
+
|
96 |
+
|
97 |
+
def ui():
|
98 |
+
global multimodal_embedder
|
99 |
+
multimodal_embedder = MultimodalEmbedder(params)
|
100 |
+
with gr.Column():
|
101 |
+
picture_select = gr.Image(label='Send a picture', type='pil')
|
102 |
+
# The models don't seem to deal well with multiple images
|
103 |
+
single_image_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one')
|
104 |
+
# Prepare the input hijack
|
105 |
+
picture_select.upload(
|
106 |
+
lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}),
|
107 |
+
[picture_select],
|
108 |
+
None
|
109 |
+
)
|
110 |
+
picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["", ""]}), None, None)
|
111 |
+
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
|
112 |
+
shared.gradio['Generate'].click(lambda: None, None, picture_select)
|
113 |
+
shared.gradio['textbox'].submit(lambda: None, None, picture_select)
|
$extensions/ngrok/README.md
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adding an ingress URL through the ngrok Agent SDK for Python
|
2 |
+
|
3 |
+
[ngrok](https://ngrok.com) is a globally distributed reverse proxy commonly used for quickly getting a public URL to a
|
4 |
+
service running inside a private network, such as on your local laptop. The ngrok agent is usually
|
5 |
+
deployed inside a private network and is used to communicate with the ngrok cloud service.
|
6 |
+
|
7 |
+
By default the authtoken in the NGROK_AUTHTOKEN environment variable will be used. Alternatively one may be specified in
|
8 |
+
the `settings.json` file, see the Examples below. Retrieve your authtoken on the [Auth Token page of your ngrok dashboard](https://dashboard.ngrok.com/get-started/your-authtoken), signing up is free.
|
9 |
+
|
10 |
+
# Documentation
|
11 |
+
|
12 |
+
For a list of all available options, see [the configuration documentation](https://ngrok.com/docs/ngrok-agent/config/) or [the connect example](https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py).
|
13 |
+
|
14 |
+
The ngrok Python SDK is [on github here](https://github.com/ngrok/ngrok-py). A quickstart guide and a full API reference are included in the [ngrok-py Python API documentation](https://ngrok.github.io/ngrok-py/).
|
15 |
+
|
16 |
+
# Running
|
17 |
+
|
18 |
+
To enable ngrok install the requirements and then add `--extension ngrok` to the command line options, for instance:
|
19 |
+
|
20 |
+
```bash
|
21 |
+
pip install -r extensions/ngrok/requirements.txt
|
22 |
+
python server.py --extension ngrok
|
23 |
+
```
|
24 |
+
|
25 |
+
In the output you should then see something like this:
|
26 |
+
|
27 |
+
```bash
|
28 |
+
INFO:Loading the extension "ngrok"...
|
29 |
+
INFO:Session created
|
30 |
+
INFO:Created tunnel "9d9d0944dc75ff9d3aae653e5eb29fe9" with url "https://d83706cf7be7.ngrok.app"
|
31 |
+
INFO:Tunnel "9d9d0944dc75ff9d3aae653e5eb29fe9" TCP forwarding to "localhost:7860"
|
32 |
+
INFO:Ingress established at https://d83706cf7be7.ngrok.app
|
33 |
+
```
|
34 |
+
|
35 |
+
You can now access the webui via the url shown, in this case `https://d83706cf7be7.ngrok.app`. It is recommended to add some authentication to the ingress, see below.
|
36 |
+
|
37 |
+
# Example Settings
|
38 |
+
|
39 |
+
In `settings.json` add a `ngrok` key with a dictionary of options, for instance:
|
40 |
+
|
41 |
+
To enable basic authentication:
|
42 |
+
```json
|
43 |
+
{
|
44 |
+
"ngrok": {
|
45 |
+
"basic_auth": "user:password"
|
46 |
+
}
|
47 |
+
}
|
48 |
+
```
|
49 |
+
|
50 |
+
To enable OAUTH authentication:
|
51 |
+
```json
|
52 |
+
{
|
53 |
+
"ngrok": {
|
54 |
+
"oauth_provider": "google",
|
55 |
+
"oauth_allow_domains": "asdf.com",
|
56 |
+
"oauth_allow_emails": "asdf@asdf.com"
|
57 |
+
}
|
58 |
+
}
|
59 |
+
```
|
60 |
+
|
61 |
+
To add an authtoken instead of using the NGROK_AUTHTOKEN environment variable:
|
62 |
+
```json
|
63 |
+
{
|
64 |
+
"ngrok": {
|
65 |
+
"authtoken": "<token>",
|
66 |
+
"authtoken_from_env":false
|
67 |
+
}
|
68 |
+
}
|
69 |
+
```
|
$extensions/ngrok/requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
ngrok==0.*
|
$extensions/ngrok/script.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adds ngrok ingress, to use add `--extension ngrok` to the command line options
|
2 |
+
#
|
3 |
+
# Parameters can be customized in settings.json of webui, e.g.:
|
4 |
+
# {"ngrok": {"basic_auth":"user:password"} }
|
5 |
+
# or
|
6 |
+
# {"ngrok": {"oauth_provider":"google", "oauth_allow_emails":["asdf@asdf.com"]} }
|
7 |
+
#
|
8 |
+
# See this example for full list of options: https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py
|
9 |
+
# or the README.md in this directory.
|
10 |
+
|
11 |
+
import logging
|
12 |
+
from modules import shared
|
13 |
+
|
14 |
+
# Pick up host/port command line arguments
|
15 |
+
host = shared.args.listen_host if shared.args.listen_host and shared.args.listen else '127.0.0.1'
|
16 |
+
port = shared.args.listen_port if shared.args.listen_port else '7860'
|
17 |
+
|
18 |
+
# Default options
|
19 |
+
options = {
|
20 |
+
'addr': f"{host}:{port}",
|
21 |
+
'authtoken_from_env': True,
|
22 |
+
'session_metadata': 'text-generation-webui',
|
23 |
+
}
|
24 |
+
|
25 |
+
|
26 |
+
def ui():
|
27 |
+
settings = shared.settings.get("ngrok")
|
28 |
+
if settings:
|
29 |
+
options.update(settings)
|
30 |
+
|
31 |
+
try:
|
32 |
+
import ngrok
|
33 |
+
tunnel = ngrok.connect(**options)
|
34 |
+
logging.info(f"Ingress established at: {tunnel.url()}")
|
35 |
+
except ModuleNotFoundError:
|
36 |
+
logging.error("===> ngrok library not found, please run `pip install -r extensions/ngrok/requirements.txt`")
|
$extensions/openai/cache_embedding_model.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# preload the embedding model, useful for Docker images to prevent re-download on config change
|
3 |
+
# Dockerfile:
|
4 |
+
# ENV OPENEDAI_EMBEDDING_MODEL="sentence-transformers/all-mpnet-base-v2" # Optional
|
5 |
+
# RUN python3 cache_embedded_model.py
|
6 |
+
import os
|
7 |
+
|
8 |
+
import sentence_transformers
|
9 |
+
|
10 |
+
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2")
|
11 |
+
model = sentence_transformers.SentenceTransformer(st_model)
|
$extensions/openai/completions.py
ADDED
@@ -0,0 +1,508 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import time
|
3 |
+
from collections import deque
|
4 |
+
|
5 |
+
import tiktoken
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from transformers import LogitsProcessor, LogitsProcessorList
|
9 |
+
|
10 |
+
from extensions.openai.errors import InvalidRequestError
|
11 |
+
from extensions.openai.utils import debug_msg
|
12 |
+
from modules import shared
|
13 |
+
from modules.chat import (
|
14 |
+
generate_chat_prompt,
|
15 |
+
generate_chat_reply,
|
16 |
+
load_character_memoized
|
17 |
+
)
|
18 |
+
from modules.presets import load_preset_memoized
|
19 |
+
from modules.text_generation import decode, encode, generate_reply
|
20 |
+
|
21 |
+
|
22 |
+
class LogitsBiasProcessor(LogitsProcessor):
|
23 |
+
def __init__(self, logit_bias={}):
|
24 |
+
self.logit_bias = logit_bias
|
25 |
+
if self.logit_bias:
|
26 |
+
self.keys = list([int(key) for key in self.logit_bias.keys()])
|
27 |
+
values = [self.logit_bias[str(key)] for key in self.keys]
|
28 |
+
self.values = torch.tensor(values, dtype=torch.float, device=shared.model.device)
|
29 |
+
debug_msg(f"{self})")
|
30 |
+
|
31 |
+
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
32 |
+
if self.logit_bias:
|
33 |
+
debug_msg(logits[0, self.keys], " + ", self.values)
|
34 |
+
logits[0, self.keys] += self.values
|
35 |
+
debug_msg(" --> ", logits[0, self.keys])
|
36 |
+
debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0])))
|
37 |
+
|
38 |
+
return logits
|
39 |
+
|
40 |
+
def __repr__(self):
|
41 |
+
return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>"
|
42 |
+
|
43 |
+
|
44 |
+
class LogprobProcessor(LogitsProcessor):
|
45 |
+
def __init__(self, logprobs=None):
|
46 |
+
self.logprobs = logprobs
|
47 |
+
self.token_alternatives = {}
|
48 |
+
|
49 |
+
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
50 |
+
if self.logprobs is not None: # 0-5
|
51 |
+
log_e_probabilities = F.log_softmax(logits, dim=1)
|
52 |
+
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
|
53 |
+
top_tokens = [decode(tok) for tok in top_indices[0]]
|
54 |
+
top_probs = [float(x) for x in top_values[0]]
|
55 |
+
self.token_alternatives = dict(zip(top_tokens, top_probs))
|
56 |
+
debug_msg(repr(self))
|
57 |
+
|
58 |
+
return logits
|
59 |
+
|
60 |
+
def __repr__(self):
|
61 |
+
return f"<{self.__class__.__name__}(logprobs={self.logprobs}, token_alternatives={self.token_alternatives})>"
|
62 |
+
|
63 |
+
|
64 |
+
def convert_logprobs_to_tiktoken(model, logprobs):
|
65 |
+
# more problems than it's worth.
|
66 |
+
# try:
|
67 |
+
# encoder = tiktoken.encoding_for_model(model)
|
68 |
+
# # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
|
69 |
+
# return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
|
70 |
+
# except KeyError:
|
71 |
+
# # assume native tokens if we can't find the tokenizer
|
72 |
+
# return logprobs
|
73 |
+
|
74 |
+
return logprobs
|
75 |
+
|
76 |
+
|
77 |
+
def process_parameters(body, is_legacy=False):
|
78 |
+
generate_params = body
|
79 |
+
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
80 |
+
generate_params['max_new_tokens'] = body.pop(max_tokens_str)
|
81 |
+
if generate_params['truncation_length'] == 0:
|
82 |
+
generate_params['truncation_length'] = shared.settings['truncation_length']
|
83 |
+
|
84 |
+
if body['preset'] is not None:
|
85 |
+
preset = load_preset_memoized(body['preset'])
|
86 |
+
generate_params.update(preset)
|
87 |
+
|
88 |
+
generate_params['custom_stopping_strings'] = []
|
89 |
+
if 'stop' in body: # str or array, max len 4 (ignored)
|
90 |
+
if isinstance(body['stop'], str):
|
91 |
+
generate_params['custom_stopping_strings'] = [body['stop']]
|
92 |
+
elif isinstance(body['stop'], list):
|
93 |
+
generate_params['custom_stopping_strings'] = body['stop']
|
94 |
+
|
95 |
+
logits_processor = []
|
96 |
+
logit_bias = body.get('logit_bias', None)
|
97 |
+
if logit_bias: # {str: float, ...}
|
98 |
+
# XXX convert tokens from tiktoken based on requested model
|
99 |
+
# Ex.: 'logit_bias': {'1129': 100, '11442': 100, '16243': 100}
|
100 |
+
try:
|
101 |
+
encoder = tiktoken.encoding_for_model(generate_params['model'])
|
102 |
+
new_logit_bias = {}
|
103 |
+
for logit, bias in logit_bias.items():
|
104 |
+
for x in encode(encoder.decode([int(logit)]), add_special_tokens=False)[0]:
|
105 |
+
if int(x) in [0, 1, 2, 29871]: # XXX LLAMA tokens
|
106 |
+
continue
|
107 |
+
|
108 |
+
new_logit_bias[str(int(x))] = bias
|
109 |
+
debug_msg('logit_bias_map', logit_bias, '->', new_logit_bias)
|
110 |
+
logit_bias = new_logit_bias
|
111 |
+
except KeyError:
|
112 |
+
pass # assume native tokens if we can't find the tokenizer
|
113 |
+
|
114 |
+
logits_processor = [LogitsBiasProcessor(logit_bias)]
|
115 |
+
|
116 |
+
logprobs = None # coming to chat eventually
|
117 |
+
if 'logprobs' in body:
|
118 |
+
logprobs = body.get('logprobs', 0) # maybe cap at topk? don't clamp 0-5.
|
119 |
+
generate_params['logprob_proc'] = LogprobProcessor(logprobs)
|
120 |
+
logits_processor.extend([generate_params['logprob_proc']])
|
121 |
+
else:
|
122 |
+
logprobs = None
|
123 |
+
|
124 |
+
if logits_processor: # requires logits_processor support
|
125 |
+
generate_params['logits_processor'] = LogitsProcessorList(logits_processor)
|
126 |
+
|
127 |
+
return generate_params
|
128 |
+
|
129 |
+
|
130 |
+
def convert_history(history):
|
131 |
+
'''
|
132 |
+
Chat histories in this program are in the format [message, reply].
|
133 |
+
This function converts OpenAI histories to that format.
|
134 |
+
'''
|
135 |
+
chat_dialogue = []
|
136 |
+
current_message = ""
|
137 |
+
current_reply = ""
|
138 |
+
user_input = ""
|
139 |
+
system_message = ""
|
140 |
+
|
141 |
+
for entry in history:
|
142 |
+
content = entry["content"]
|
143 |
+
role = entry["role"]
|
144 |
+
|
145 |
+
if role == "user":
|
146 |
+
user_input = content
|
147 |
+
if current_message:
|
148 |
+
chat_dialogue.append([current_message, ''])
|
149 |
+
current_message = ""
|
150 |
+
current_message = content
|
151 |
+
elif role == "assistant":
|
152 |
+
current_reply = content
|
153 |
+
if current_message:
|
154 |
+
chat_dialogue.append([current_message, current_reply])
|
155 |
+
current_message = ""
|
156 |
+
current_reply = ""
|
157 |
+
else:
|
158 |
+
chat_dialogue.append(['', current_reply])
|
159 |
+
elif role == "system":
|
160 |
+
system_message = content
|
161 |
+
|
162 |
+
# if current_message:
|
163 |
+
# chat_dialogue.append([current_message, ''])
|
164 |
+
|
165 |
+
return user_input, system_message, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue)}
|
166 |
+
|
167 |
+
|
168 |
+
def chat_completions_common(body: dict, is_legacy: bool = False, stream=False) -> dict:
|
169 |
+
if body.get('functions', []):
|
170 |
+
raise InvalidRequestError(message="functions is not supported.", param='functions')
|
171 |
+
|
172 |
+
if body.get('function_call', ''):
|
173 |
+
raise InvalidRequestError(message="function_call is not supported.", param='function_call')
|
174 |
+
|
175 |
+
if 'messages' not in body:
|
176 |
+
raise InvalidRequestError(message="messages is required", param='messages')
|
177 |
+
|
178 |
+
messages = body['messages']
|
179 |
+
for m in messages:
|
180 |
+
if 'role' not in m:
|
181 |
+
raise InvalidRequestError(message="messages: missing role", param='messages')
|
182 |
+
elif m['role'] == 'function':
|
183 |
+
raise InvalidRequestError(message="role: function is not supported.", param='messages')
|
184 |
+
if 'content' not in m:
|
185 |
+
raise InvalidRequestError(message="messages: missing content", param='messages')
|
186 |
+
|
187 |
+
# Chat Completions
|
188 |
+
object_type = 'chat.completions' if not stream else 'chat.completions.chunk'
|
189 |
+
created_time = int(time.time())
|
190 |
+
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
|
191 |
+
resp_list = 'data' if is_legacy else 'choices'
|
192 |
+
|
193 |
+
# generation parameters
|
194 |
+
generate_params = process_parameters(body, is_legacy=is_legacy)
|
195 |
+
continue_ = body['continue_']
|
196 |
+
|
197 |
+
# Instruction template
|
198 |
+
instruction_template = body['instruction_template'] or shared.settings['instruction_template']
|
199 |
+
instruction_template = "Alpaca" if instruction_template == "None" else instruction_template
|
200 |
+
name1_instruct, name2_instruct, _, _, context_instruct, turn_template, system_message = load_character_memoized(instruction_template, '', '', instruct=True)
|
201 |
+
name1_instruct = body['name1_instruct'] or name1_instruct
|
202 |
+
name2_instruct = body['name2_instruct'] or name2_instruct
|
203 |
+
turn_template = body['turn_template'] or turn_template
|
204 |
+
context_instruct = body['context_instruct'] or context_instruct
|
205 |
+
system_message = body['system_message'] or system_message
|
206 |
+
chat_instruct_command = body['chat_instruct_command'] or shared.settings['chat-instruct_command']
|
207 |
+
|
208 |
+
# Chat character
|
209 |
+
character = body['character'] or shared.settings['character']
|
210 |
+
character = "Assistant" if character == "None" else character
|
211 |
+
name1 = body['name1'] or shared.settings['name1']
|
212 |
+
name1, name2, _, greeting, context, _, _ = load_character_memoized(character, name1, '', instruct=False)
|
213 |
+
name2 = body['name2'] or name2
|
214 |
+
context = body['context'] or context
|
215 |
+
greeting = body['greeting'] or greeting
|
216 |
+
|
217 |
+
# History
|
218 |
+
user_input, custom_system_message, history = convert_history(messages)
|
219 |
+
|
220 |
+
generate_params.update({
|
221 |
+
'mode': body['mode'],
|
222 |
+
'name1': name1,
|
223 |
+
'name2': name2,
|
224 |
+
'context': context,
|
225 |
+
'greeting': greeting,
|
226 |
+
'name1_instruct': name1_instruct,
|
227 |
+
'name2_instruct': name2_instruct,
|
228 |
+
'context_instruct': context_instruct,
|
229 |
+
'system_message': system_message,
|
230 |
+
'custom_system_message': custom_system_message,
|
231 |
+
'turn_template': turn_template,
|
232 |
+
'chat-instruct_command': chat_instruct_command,
|
233 |
+
'history': history,
|
234 |
+
'stream': stream
|
235 |
+
})
|
236 |
+
|
237 |
+
max_tokens = generate_params['max_new_tokens']
|
238 |
+
if max_tokens in [None, 0]:
|
239 |
+
generate_params['max_new_tokens'] = 200
|
240 |
+
generate_params['auto_max_new_tokens'] = True
|
241 |
+
|
242 |
+
requested_model = generate_params.pop('model')
|
243 |
+
logprob_proc = generate_params.pop('logprob_proc', None)
|
244 |
+
|
245 |
+
def chat_streaming_chunk(content):
|
246 |
+
# begin streaming
|
247 |
+
chunk = {
|
248 |
+
"id": cmpl_id,
|
249 |
+
"object": object_type,
|
250 |
+
"created": created_time,
|
251 |
+
"model": shared.model_name,
|
252 |
+
resp_list: [{
|
253 |
+
"index": 0,
|
254 |
+
"finish_reason": None,
|
255 |
+
# So yeah... do both methods? delta and messages.
|
256 |
+
"message": {'role': 'assistant', 'content': content},
|
257 |
+
"delta": {'role': 'assistant', 'content': content},
|
258 |
+
}],
|
259 |
+
}
|
260 |
+
|
261 |
+
if logprob_proc: # not official for chat yet
|
262 |
+
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
263 |
+
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
264 |
+
# else:
|
265 |
+
# chunk[resp_list][0]["logprobs"] = None
|
266 |
+
return chunk
|
267 |
+
|
268 |
+
if stream:
|
269 |
+
yield chat_streaming_chunk('')
|
270 |
+
|
271 |
+
# generate reply #######################################
|
272 |
+
prompt = generate_chat_prompt(user_input, generate_params)
|
273 |
+
token_count = len(encode(prompt)[0])
|
274 |
+
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
275 |
+
|
276 |
+
generator = generate_chat_reply(
|
277 |
+
user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
|
278 |
+
|
279 |
+
answer = ''
|
280 |
+
seen_content = ''
|
281 |
+
completion_token_count = 0
|
282 |
+
|
283 |
+
for a in generator:
|
284 |
+
answer = a['internal'][-1][1]
|
285 |
+
if stream:
|
286 |
+
len_seen = len(seen_content)
|
287 |
+
new_content = answer[len_seen:]
|
288 |
+
|
289 |
+
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
290 |
+
continue
|
291 |
+
|
292 |
+
seen_content = answer
|
293 |
+
chunk = chat_streaming_chunk(new_content)
|
294 |
+
yield chunk
|
295 |
+
|
296 |
+
completion_token_count = len(encode(answer)[0])
|
297 |
+
stop_reason = "stop"
|
298 |
+
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
|
299 |
+
stop_reason = "length"
|
300 |
+
|
301 |
+
if stream:
|
302 |
+
chunk = chat_streaming_chunk('')
|
303 |
+
chunk[resp_list][0]['finish_reason'] = stop_reason
|
304 |
+
chunk['usage'] = {
|
305 |
+
"prompt_tokens": token_count,
|
306 |
+
"completion_tokens": completion_token_count,
|
307 |
+
"total_tokens": token_count + completion_token_count
|
308 |
+
}
|
309 |
+
|
310 |
+
yield chunk
|
311 |
+
else:
|
312 |
+
resp = {
|
313 |
+
"id": cmpl_id,
|
314 |
+
"object": object_type,
|
315 |
+
"created": created_time,
|
316 |
+
"model": shared.model_name,
|
317 |
+
resp_list: [{
|
318 |
+
"index": 0,
|
319 |
+
"finish_reason": stop_reason,
|
320 |
+
"message": {"role": "assistant", "content": answer}
|
321 |
+
}],
|
322 |
+
"usage": {
|
323 |
+
"prompt_tokens": token_count,
|
324 |
+
"completion_tokens": completion_token_count,
|
325 |
+
"total_tokens": token_count + completion_token_count
|
326 |
+
}
|
327 |
+
}
|
328 |
+
if logprob_proc: # not official for chat yet
|
329 |
+
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
330 |
+
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
331 |
+
# else:
|
332 |
+
# resp[resp_list][0]["logprobs"] = None
|
333 |
+
|
334 |
+
yield resp
|
335 |
+
|
336 |
+
|
337 |
+
def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
338 |
+
object_type = 'text_completion.chunk' if stream else 'text_completion'
|
339 |
+
created_time = int(time.time())
|
340 |
+
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
|
341 |
+
resp_list = 'data' if is_legacy else 'choices'
|
342 |
+
|
343 |
+
prompt_str = 'context' if is_legacy else 'prompt'
|
344 |
+
|
345 |
+
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
346 |
+
if prompt_str not in body:
|
347 |
+
raise InvalidRequestError("Missing required input", param=prompt_str)
|
348 |
+
|
349 |
+
# common params
|
350 |
+
generate_params = process_parameters(body, is_legacy=is_legacy)
|
351 |
+
max_tokens = generate_params['max_new_tokens']
|
352 |
+
generate_params['stream'] = stream
|
353 |
+
requested_model = generate_params.pop('model')
|
354 |
+
logprob_proc = generate_params.pop('logprob_proc', None)
|
355 |
+
suffix = body['suffix'] if body['suffix'] else ''
|
356 |
+
echo = body['echo']
|
357 |
+
|
358 |
+
if not stream:
|
359 |
+
prompt_arg = body[prompt_str]
|
360 |
+
if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
|
361 |
+
prompt_arg = [prompt_arg]
|
362 |
+
|
363 |
+
resp_list_data = []
|
364 |
+
total_completion_token_count = 0
|
365 |
+
total_prompt_token_count = 0
|
366 |
+
|
367 |
+
for idx, prompt in enumerate(prompt_arg, start=0):
|
368 |
+
if isinstance(prompt[0], int):
|
369 |
+
# token lists
|
370 |
+
if requested_model == shared.model_name:
|
371 |
+
prompt = decode(prompt)[0]
|
372 |
+
else:
|
373 |
+
try:
|
374 |
+
encoder = tiktoken.encoding_for_model(requested_model)
|
375 |
+
prompt = encoder.decode(prompt)
|
376 |
+
except KeyError:
|
377 |
+
prompt = decode(prompt)[0]
|
378 |
+
|
379 |
+
prefix = prompt if echo else ''
|
380 |
+
token_count = len(encode(prompt)[0])
|
381 |
+
total_prompt_token_count += token_count
|
382 |
+
|
383 |
+
# generate reply #######################################
|
384 |
+
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
385 |
+
generator = generate_reply(prompt, generate_params, is_chat=False)
|
386 |
+
answer = ''
|
387 |
+
|
388 |
+
for a in generator:
|
389 |
+
answer = a
|
390 |
+
|
391 |
+
completion_token_count = len(encode(answer)[0])
|
392 |
+
total_completion_token_count += completion_token_count
|
393 |
+
stop_reason = "stop"
|
394 |
+
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
395 |
+
stop_reason = "length"
|
396 |
+
|
397 |
+
respi = {
|
398 |
+
"index": idx,
|
399 |
+
"finish_reason": stop_reason,
|
400 |
+
"text": prefix + answer + suffix,
|
401 |
+
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
402 |
+
}
|
403 |
+
|
404 |
+
resp_list_data.extend([respi])
|
405 |
+
|
406 |
+
resp = {
|
407 |
+
"id": cmpl_id,
|
408 |
+
"object": object_type,
|
409 |
+
"created": created_time,
|
410 |
+
"model": shared.model_name,
|
411 |
+
resp_list: resp_list_data,
|
412 |
+
"usage": {
|
413 |
+
"prompt_tokens": total_prompt_token_count,
|
414 |
+
"completion_tokens": total_completion_token_count,
|
415 |
+
"total_tokens": total_prompt_token_count + total_completion_token_count
|
416 |
+
}
|
417 |
+
}
|
418 |
+
|
419 |
+
yield resp
|
420 |
+
else:
|
421 |
+
prompt = body[prompt_str]
|
422 |
+
if isinstance(prompt, list):
|
423 |
+
if prompt and isinstance(prompt[0], int):
|
424 |
+
try:
|
425 |
+
encoder = tiktoken.encoding_for_model(requested_model)
|
426 |
+
prompt = encoder.decode(prompt)
|
427 |
+
except KeyError:
|
428 |
+
prompt = decode(prompt)[0]
|
429 |
+
else:
|
430 |
+
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
431 |
+
|
432 |
+
prefix = prompt if echo else ''
|
433 |
+
token_count = len(encode(prompt)[0])
|
434 |
+
|
435 |
+
def text_streaming_chunk(content):
|
436 |
+
# begin streaming
|
437 |
+
chunk = {
|
438 |
+
"id": cmpl_id,
|
439 |
+
"object": object_type,
|
440 |
+
"created": created_time,
|
441 |
+
"model": shared.model_name,
|
442 |
+
resp_list: [{
|
443 |
+
"index": 0,
|
444 |
+
"finish_reason": None,
|
445 |
+
"text": content,
|
446 |
+
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
447 |
+
}],
|
448 |
+
}
|
449 |
+
|
450 |
+
return chunk
|
451 |
+
|
452 |
+
yield text_streaming_chunk(prefix)
|
453 |
+
|
454 |
+
# generate reply #######################################
|
455 |
+
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
456 |
+
generator = generate_reply(prompt, generate_params, is_chat=False)
|
457 |
+
|
458 |
+
answer = ''
|
459 |
+
seen_content = ''
|
460 |
+
completion_token_count = 0
|
461 |
+
|
462 |
+
for a in generator:
|
463 |
+
answer = a
|
464 |
+
|
465 |
+
len_seen = len(seen_content)
|
466 |
+
new_content = answer[len_seen:]
|
467 |
+
|
468 |
+
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
469 |
+
continue
|
470 |
+
|
471 |
+
seen_content = answer
|
472 |
+
chunk = text_streaming_chunk(new_content)
|
473 |
+
yield chunk
|
474 |
+
|
475 |
+
completion_token_count = len(encode(answer)[0])
|
476 |
+
stop_reason = "stop"
|
477 |
+
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
478 |
+
stop_reason = "length"
|
479 |
+
|
480 |
+
chunk = text_streaming_chunk(suffix)
|
481 |
+
chunk[resp_list][0]["finish_reason"] = stop_reason
|
482 |
+
chunk["usage"] = {
|
483 |
+
"prompt_tokens": token_count,
|
484 |
+
"completion_tokens": completion_token_count,
|
485 |
+
"total_tokens": token_count + completion_token_count
|
486 |
+
}
|
487 |
+
|
488 |
+
yield chunk
|
489 |
+
|
490 |
+
|
491 |
+
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
|
492 |
+
generator = chat_completions_common(body, is_legacy, stream=False)
|
493 |
+
return deque(generator, maxlen=1).pop()
|
494 |
+
|
495 |
+
|
496 |
+
def stream_chat_completions(body: dict, is_legacy: bool = False):
|
497 |
+
for resp in chat_completions_common(body, is_legacy, stream=True):
|
498 |
+
yield resp
|
499 |
+
|
500 |
+
|
501 |
+
def completions(body: dict, is_legacy: bool = False) -> dict:
|
502 |
+
generator = completions_common(body, is_legacy, stream=False)
|
503 |
+
return deque(generator, maxlen=1).pop()
|
504 |
+
|
505 |
+
|
506 |
+
def stream_completions(body: dict, is_legacy: bool = False):
|
507 |
+
for resp in completions_common(body, is_legacy, stream=True):
|
508 |
+
yield resp
|
$extensions/openai/embeddings.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from transformers import AutoModel
|
5 |
+
|
6 |
+
from extensions.openai.errors import ServiceUnavailableError
|
7 |
+
from extensions.openai.utils import debug_msg, float_list_to_base64
|
8 |
+
from modules.logging_colors import logger
|
9 |
+
|
10 |
+
embeddings_params_initialized = False
|
11 |
+
|
12 |
+
|
13 |
+
def initialize_embedding_params():
|
14 |
+
'''
|
15 |
+
using 'lazy loading' to avoid circular import
|
16 |
+
so this function will be executed only once
|
17 |
+
'''
|
18 |
+
global embeddings_params_initialized
|
19 |
+
if not embeddings_params_initialized:
|
20 |
+
from extensions.openai.script import params
|
21 |
+
|
22 |
+
global st_model, embeddings_model, embeddings_device
|
23 |
+
|
24 |
+
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", params.get('embedding_model', 'all-mpnet-base-v2'))
|
25 |
+
embeddings_model = None
|
26 |
+
# OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone
|
27 |
+
embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", params.get('embedding_device', 'cpu'))
|
28 |
+
if embeddings_device.lower() == 'auto':
|
29 |
+
embeddings_device = None
|
30 |
+
|
31 |
+
embeddings_params_initialized = True
|
32 |
+
|
33 |
+
|
34 |
+
def load_embedding_model(model: str):
|
35 |
+
try:
|
36 |
+
from sentence_transformers import SentenceTransformer
|
37 |
+
except ModuleNotFoundError:
|
38 |
+
logger.error("The sentence_transformers module has not been found. Please install it manually with pip install -U sentence-transformers.")
|
39 |
+
raise ModuleNotFoundError
|
40 |
+
|
41 |
+
initialize_embedding_params()
|
42 |
+
global embeddings_device, embeddings_model
|
43 |
+
try:
|
44 |
+
print(f"Try embedding model: {model} on {embeddings_device}")
|
45 |
+
if 'jina-embeddings' in model:
|
46 |
+
embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=True) # trust_remote_code is needed to use the encode method
|
47 |
+
embeddings_model = embeddings_model.to(embeddings_device)
|
48 |
+
else:
|
49 |
+
embeddings_model = SentenceTransformer(model, device=embeddings_device)
|
50 |
+
|
51 |
+
print(f"Loaded embedding model: {model}")
|
52 |
+
except Exception as e:
|
53 |
+
embeddings_model = None
|
54 |
+
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
|
55 |
+
|
56 |
+
|
57 |
+
def get_embeddings_model():
|
58 |
+
initialize_embedding_params()
|
59 |
+
global embeddings_model, st_model
|
60 |
+
if st_model and not embeddings_model:
|
61 |
+
load_embedding_model(st_model) # lazy load the model
|
62 |
+
|
63 |
+
return embeddings_model
|
64 |
+
|
65 |
+
|
66 |
+
def get_embeddings_model_name() -> str:
|
67 |
+
initialize_embedding_params()
|
68 |
+
global st_model
|
69 |
+
return st_model
|
70 |
+
|
71 |
+
|
72 |
+
def get_embeddings(input: list) -> np.ndarray:
|
73 |
+
model = get_embeddings_model()
|
74 |
+
debug_msg(f"embedding model : {model}")
|
75 |
+
embedding = model.encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False)
|
76 |
+
debug_msg(f"embedding result : {embedding}") # might be too long even for debug, use at you own will
|
77 |
+
return embedding
|
78 |
+
|
79 |
+
|
80 |
+
def embeddings(input: list, encoding_format: str) -> dict:
|
81 |
+
embeddings = get_embeddings(input)
|
82 |
+
if encoding_format == "base64":
|
83 |
+
data = [{"object": "embedding", "embedding": float_list_to_base64(emb), "index": n} for n, emb in enumerate(embeddings)]
|
84 |
+
else:
|
85 |
+
data = [{"object": "embedding", "embedding": emb.tolist(), "index": n} for n, emb in enumerate(embeddings)]
|
86 |
+
|
87 |
+
response = {
|
88 |
+
"object": "list",
|
89 |
+
"data": data,
|
90 |
+
"model": st_model, # return the real model
|
91 |
+
"usage": {
|
92 |
+
"prompt_tokens": 0,
|
93 |
+
"total_tokens": 0,
|
94 |
+
}
|
95 |
+
}
|
96 |
+
|
97 |
+
debug_msg(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
|
98 |
+
return response
|
$extensions/openai/errors.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class OpenAIError(Exception):
|
2 |
+
def __init__(self, message=None, code=500, internal_message=''):
|
3 |
+
self.message = message
|
4 |
+
self.code = code
|
5 |
+
self.internal_message = internal_message
|
6 |
+
|
7 |
+
def __repr__(self):
|
8 |
+
return "%s(message=%r, code=%d)" % (
|
9 |
+
self.__class__.__name__,
|
10 |
+
self.message,
|
11 |
+
self.code,
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
class InvalidRequestError(OpenAIError):
|
16 |
+
def __init__(self, message, param, code=400, internal_message=''):
|
17 |
+
super().__init__(message, code, internal_message)
|
18 |
+
self.param = param
|
19 |
+
|
20 |
+
def __repr__(self):
|
21 |
+
return "%s(message=%r, code=%d, param=%s)" % (
|
22 |
+
self.__class__.__name__,
|
23 |
+
self.message,
|
24 |
+
self.code,
|
25 |
+
self.param,
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
class ServiceUnavailableError(OpenAIError):
|
30 |
+
def __init__(self, message="Service unavailable, please try again later.", code=503, internal_message=''):
|
31 |
+
super().__init__(message, code, internal_message)
|
$extensions/openai/images.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
|
4 |
+
import requests
|
5 |
+
|
6 |
+
from extensions.openai.errors import ServiceUnavailableError
|
7 |
+
|
8 |
+
|
9 |
+
def generations(prompt: str, size: str, response_format: str, n: int):
|
10 |
+
# Stable Diffusion callout wrapper for txt2img
|
11 |
+
# Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
|
12 |
+
# the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
|
13 |
+
# If you want high quality tailored results you should just use the Stable Diffusion API directly.
|
14 |
+
# it's too general an API to try and shape the result with specific tags like negative prompts
|
15 |
+
# or "masterpiece", etc. SD configuration is beyond the scope of this API.
|
16 |
+
# At this point I will not add the edits and variations endpoints (ie. img2img) because they
|
17 |
+
# require changing the form data handling to accept multipart form data, also to properly support
|
18 |
+
# url return types will require file management and a web serving files... Perhaps later!
|
19 |
+
base_model_size = 512 if 'SD_BASE_MODEL_SIZE' not in os.environ else int(os.environ.get('SD_BASE_MODEL_SIZE', 512))
|
20 |
+
sd_defaults = {
|
21 |
+
'sampler_name': 'DPM++ 2M Karras', # vast improvement
|
22 |
+
'steps': 30,
|
23 |
+
}
|
24 |
+
|
25 |
+
width, height = [int(x) for x in size.split('x')] # ignore the restrictions on size
|
26 |
+
|
27 |
+
# to hack on better generation, edit default payload.
|
28 |
+
payload = {
|
29 |
+
'prompt': prompt, # ignore prompt limit of 1000 characters
|
30 |
+
'width': width,
|
31 |
+
'height': height,
|
32 |
+
'batch_size': n,
|
33 |
+
}
|
34 |
+
payload.update(sd_defaults)
|
35 |
+
|
36 |
+
scale = min(width, height) / base_model_size
|
37 |
+
if scale >= 1.2:
|
38 |
+
# for better performance with the default size (1024), and larger res.
|
39 |
+
scaler = {
|
40 |
+
'width': width // scale,
|
41 |
+
'height': height // scale,
|
42 |
+
'hr_scale': scale,
|
43 |
+
'enable_hr': True,
|
44 |
+
'hr_upscaler': 'Latent',
|
45 |
+
'denoising_strength': 0.68,
|
46 |
+
}
|
47 |
+
payload.update(scaler)
|
48 |
+
|
49 |
+
resp = {
|
50 |
+
'created': int(time.time()),
|
51 |
+
'data': []
|
52 |
+
}
|
53 |
+
from extensions.openai.script import params
|
54 |
+
|
55 |
+
# TODO: support SD_WEBUI_AUTH username:password pair.
|
56 |
+
sd_url = f"{os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', ''))}/sdapi/v1/txt2img"
|
57 |
+
|
58 |
+
response = requests.post(url=sd_url, json=payload)
|
59 |
+
r = response.json()
|
60 |
+
if response.status_code != 200 or 'images' not in r:
|
61 |
+
print(r)
|
62 |
+
raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors', None))
|
63 |
+
# r['parameters']...
|
64 |
+
for b64_json in r['images']:
|
65 |
+
if response_format == 'b64_json':
|
66 |
+
resp['data'].extend([{'b64_json': b64_json}])
|
67 |
+
else:
|
68 |
+
resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}]) # yeah it's lazy. requests.get() will not work with this
|
69 |
+
|
70 |
+
return resp
|
$extensions/openai/logits.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from extensions.openai.completions import process_parameters
|
2 |
+
from modules.logits import get_next_logits
|
3 |
+
|
4 |
+
|
5 |
+
def _get_next_logits(body):
|
6 |
+
# Pre-process the input payload to simulate a real generation
|
7 |
+
use_samplers = body['use_samplers']
|
8 |
+
state = process_parameters(body) if use_samplers else {}
|
9 |
+
state['stream'] = True
|
10 |
+
|
11 |
+
return get_next_logits(body['prompt'], state, use_samplers, "", return_dict=True)
|
$extensions/openai/models.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules import shared
|
2 |
+
from modules.logging_colors import logger
|
3 |
+
from modules.LoRA import add_lora_to_model
|
4 |
+
from modules.models import load_model, unload_model
|
5 |
+
from modules.models_settings import get_model_metadata, update_model_parameters
|
6 |
+
from modules.utils import get_available_loras, get_available_models
|
7 |
+
|
8 |
+
|
9 |
+
def get_current_model_info():
|
10 |
+
return {
|
11 |
+
'model_name': shared.model_name,
|
12 |
+
'lora_names': shared.lora_names
|
13 |
+
}
|
14 |
+
|
15 |
+
|
16 |
+
def list_models():
|
17 |
+
return {'model_names': get_available_models()[1:]}
|
18 |
+
|
19 |
+
|
20 |
+
def list_dummy_models():
|
21 |
+
result = {
|
22 |
+
"object": "list",
|
23 |
+
"data": []
|
24 |
+
}
|
25 |
+
|
26 |
+
# these are expected by so much, so include some here as a dummy
|
27 |
+
for model in ['gpt-3.5-turbo', 'text-embedding-ada-002']:
|
28 |
+
result["data"].append(model_info_dict(model))
|
29 |
+
|
30 |
+
return result
|
31 |
+
|
32 |
+
|
33 |
+
def model_info_dict(model_name: str) -> dict:
|
34 |
+
return {
|
35 |
+
"id": model_name,
|
36 |
+
"object": "model",
|
37 |
+
"created": 0,
|
38 |
+
"owned_by": "user"
|
39 |
+
}
|
40 |
+
|
41 |
+
|
42 |
+
def _load_model(data):
|
43 |
+
model_name = data["model_name"]
|
44 |
+
args = data["args"]
|
45 |
+
settings = data["settings"]
|
46 |
+
|
47 |
+
unload_model()
|
48 |
+
model_settings = get_model_metadata(model_name)
|
49 |
+
update_model_parameters(model_settings)
|
50 |
+
|
51 |
+
# Update shared.args with custom model loading settings
|
52 |
+
if args:
|
53 |
+
for k in args:
|
54 |
+
if hasattr(shared.args, k):
|
55 |
+
setattr(shared.args, k, args[k])
|
56 |
+
|
57 |
+
shared.model, shared.tokenizer = load_model(model_name)
|
58 |
+
shared.model_name = model_name
|
59 |
+
|
60 |
+
# Update shared.settings with custom generation defaults
|
61 |
+
if settings:
|
62 |
+
for k in settings:
|
63 |
+
if k in shared.settings:
|
64 |
+
shared.settings[k] = settings[k]
|
65 |
+
if k == 'truncation_length':
|
66 |
+
logger.info(f"TRUNCATION LENGTH (UPDATED): {shared.settings['truncation_length']}")
|
67 |
+
elif k == 'instruction_template':
|
68 |
+
logger.info(f"INSTRUCTION TEMPLATE (UPDATED): {shared.settings['instruction_template']}")
|
69 |
+
|
70 |
+
|
71 |
+
def list_loras():
|
72 |
+
return {'lora_names': get_available_loras()[1:]}
|
73 |
+
|
74 |
+
|
75 |
+
def load_loras(lora_names):
|
76 |
+
add_lora_to_model(lora_names)
|
77 |
+
|
78 |
+
|
79 |
+
def unload_all_loras():
|
80 |
+
add_lora_to_model([])
|
$extensions/openai/moderations.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from numpy.linalg import norm
|
5 |
+
|
6 |
+
from extensions.openai.embeddings import get_embeddings
|
7 |
+
|
8 |
+
moderations_disabled = False # return 0/false
|
9 |
+
category_embeddings = None
|
10 |
+
antonym_embeddings = None
|
11 |
+
categories = ["sexual", "hate", "harassment", "self-harm", "sexual/minors", "hate/threatening", "violence/graphic", "self-harm/intent", "self-harm/instructions", "harassment/threatening", "violence"]
|
12 |
+
flag_threshold = 0.5
|
13 |
+
|
14 |
+
|
15 |
+
def get_category_embeddings() -> dict:
|
16 |
+
global category_embeddings, categories
|
17 |
+
if category_embeddings is None:
|
18 |
+
embeddings = get_embeddings(categories).tolist()
|
19 |
+
category_embeddings = dict(zip(categories, embeddings))
|
20 |
+
|
21 |
+
return category_embeddings
|
22 |
+
|
23 |
+
|
24 |
+
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
25 |
+
return np.dot(a, b) / (norm(a) * norm(b))
|
26 |
+
|
27 |
+
|
28 |
+
# seems most openai like with all-mpnet-base-v2
|
29 |
+
def mod_score(a: np.ndarray, b: np.ndarray) -> float:
|
30 |
+
return 2.0 * np.dot(a, b)
|
31 |
+
|
32 |
+
|
33 |
+
def moderations(input):
|
34 |
+
global category_embeddings, categories, flag_threshold, moderations_disabled
|
35 |
+
results = {
|
36 |
+
"id": f"modr-{int(time.time()*1e9)}",
|
37 |
+
"model": "text-moderation-001",
|
38 |
+
"results": [],
|
39 |
+
}
|
40 |
+
|
41 |
+
if moderations_disabled:
|
42 |
+
results['results'] = [{
|
43 |
+
'categories': dict([(C, False) for C in categories]),
|
44 |
+
'category_scores': dict([(C, 0.0) for C in categories]),
|
45 |
+
'flagged': False,
|
46 |
+
}]
|
47 |
+
return results
|
48 |
+
|
49 |
+
category_embeddings = get_category_embeddings()
|
50 |
+
|
51 |
+
# input, string or array
|
52 |
+
if isinstance(input, str):
|
53 |
+
input = [input]
|
54 |
+
|
55 |
+
for in_str in input:
|
56 |
+
for ine in get_embeddings([in_str]):
|
57 |
+
category_scores = dict([(C, mod_score(category_embeddings[C], ine)) for C in categories])
|
58 |
+
category_flags = dict([(C, bool(category_scores[C] > flag_threshold)) for C in categories])
|
59 |
+
flagged = any(category_flags.values())
|
60 |
+
|
61 |
+
results['results'].extend([{
|
62 |
+
'flagged': flagged,
|
63 |
+
'categories': category_flags,
|
64 |
+
'category_scores': category_scores,
|
65 |
+
}])
|
66 |
+
|
67 |
+
print(results)
|
68 |
+
|
69 |
+
return results
|
$extensions/openai/requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
SpeechRecognition==3.10.0
|
2 |
+
flask_cloudflared==0.0.14
|
3 |
+
sse-starlette==1.6.5
|
4 |
+
tiktoken
|
$extensions/openai/script.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import traceback
|
5 |
+
from threading import Thread
|
6 |
+
|
7 |
+
import speech_recognition as sr
|
8 |
+
import uvicorn
|
9 |
+
from fastapi import Depends, FastAPI, Header, HTTPException
|
10 |
+
from fastapi.middleware.cors import CORSMiddleware
|
11 |
+
from fastapi.requests import Request
|
12 |
+
from fastapi.responses import JSONResponse
|
13 |
+
from pydub import AudioSegment
|
14 |
+
from sse_starlette import EventSourceResponse
|
15 |
+
|
16 |
+
import extensions.openai.completions as OAIcompletions
|
17 |
+
import extensions.openai.embeddings as OAIembeddings
|
18 |
+
import extensions.openai.images as OAIimages
|
19 |
+
import extensions.openai.logits as OAIlogits
|
20 |
+
import extensions.openai.models as OAImodels
|
21 |
+
import extensions.openai.moderations as OAImoderations
|
22 |
+
from extensions.openai.errors import ServiceUnavailableError
|
23 |
+
from extensions.openai.tokens import token_count, token_decode, token_encode
|
24 |
+
from extensions.openai.utils import _start_cloudflared
|
25 |
+
from modules import shared
|
26 |
+
from modules.logging_colors import logger
|
27 |
+
from modules.models import unload_model
|
28 |
+
from modules.text_generation import stop_everything_event
|
29 |
+
|
30 |
+
from .typing import (
|
31 |
+
ChatCompletionRequest,
|
32 |
+
ChatCompletionResponse,
|
33 |
+
CompletionRequest,
|
34 |
+
CompletionResponse,
|
35 |
+
DecodeRequest,
|
36 |
+
DecodeResponse,
|
37 |
+
EmbeddingsRequest,
|
38 |
+
EmbeddingsResponse,
|
39 |
+
EncodeRequest,
|
40 |
+
EncodeResponse,
|
41 |
+
LoadLorasRequest,
|
42 |
+
LoadModelRequest,
|
43 |
+
LogitsRequest,
|
44 |
+
LogitsResponse,
|
45 |
+
LoraListResponse,
|
46 |
+
ModelInfoResponse,
|
47 |
+
ModelListResponse,
|
48 |
+
TokenCountResponse,
|
49 |
+
to_dict
|
50 |
+
)
|
51 |
+
|
52 |
+
params = {
|
53 |
+
'embedding_device': 'cpu',
|
54 |
+
'embedding_model': 'sentence-transformers/all-mpnet-base-v2',
|
55 |
+
'sd_webui_url': '',
|
56 |
+
'debug': 0
|
57 |
+
}
|
58 |
+
|
59 |
+
|
60 |
+
streaming_semaphore = asyncio.Semaphore(1)
|
61 |
+
|
62 |
+
|
63 |
+
def verify_api_key(authorization: str = Header(None)) -> None:
|
64 |
+
expected_api_key = shared.args.api_key
|
65 |
+
if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"):
|
66 |
+
raise HTTPException(status_code=401, detail="Unauthorized")
|
67 |
+
|
68 |
+
|
69 |
+
def verify_admin_key(authorization: str = Header(None)) -> None:
|
70 |
+
expected_api_key = shared.args.admin_key
|
71 |
+
if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"):
|
72 |
+
raise HTTPException(status_code=401, detail="Unauthorized")
|
73 |
+
|
74 |
+
|
75 |
+
app = FastAPI()
|
76 |
+
check_key = [Depends(verify_api_key)]
|
77 |
+
check_admin_key = [Depends(verify_admin_key)]
|
78 |
+
|
79 |
+
# Configure CORS settings to allow all origins, methods, and headers
|
80 |
+
app.add_middleware(
|
81 |
+
CORSMiddleware,
|
82 |
+
allow_origins=["*"],
|
83 |
+
allow_credentials=True,
|
84 |
+
allow_methods=["*"],
|
85 |
+
allow_headers=["*"]
|
86 |
+
)
|
87 |
+
|
88 |
+
|
89 |
+
@app.options("/", dependencies=check_key)
|
90 |
+
async def options_route():
|
91 |
+
return JSONResponse(content="OK")
|
92 |
+
|
93 |
+
|
94 |
+
@app.post('/v1/completions', response_model=CompletionResponse, dependencies=check_key)
|
95 |
+
async def openai_completions(request: Request, request_data: CompletionRequest):
|
96 |
+
path = request.url.path
|
97 |
+
is_legacy = "/generate" in path
|
98 |
+
|
99 |
+
if request_data.stream:
|
100 |
+
async def generator():
|
101 |
+
async with streaming_semaphore:
|
102 |
+
response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
|
103 |
+
for resp in response:
|
104 |
+
disconnected = await request.is_disconnected()
|
105 |
+
if disconnected:
|
106 |
+
break
|
107 |
+
|
108 |
+
yield {"data": json.dumps(resp)}
|
109 |
+
|
110 |
+
return EventSourceResponse(generator()) # SSE streaming
|
111 |
+
|
112 |
+
else:
|
113 |
+
response = OAIcompletions.completions(to_dict(request_data), is_legacy=is_legacy)
|
114 |
+
return JSONResponse(response)
|
115 |
+
|
116 |
+
|
117 |
+
@app.post('/v1/chat/completions', response_model=ChatCompletionResponse, dependencies=check_key)
|
118 |
+
async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest):
|
119 |
+
path = request.url.path
|
120 |
+
is_legacy = "/generate" in path
|
121 |
+
|
122 |
+
if request_data.stream:
|
123 |
+
async def generator():
|
124 |
+
async with streaming_semaphore:
|
125 |
+
response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
|
126 |
+
for resp in response:
|
127 |
+
disconnected = await request.is_disconnected()
|
128 |
+
if disconnected:
|
129 |
+
break
|
130 |
+
|
131 |
+
yield {"data": json.dumps(resp)}
|
132 |
+
|
133 |
+
return EventSourceResponse(generator()) # SSE streaming
|
134 |
+
|
135 |
+
else:
|
136 |
+
response = OAIcompletions.chat_completions(to_dict(request_data), is_legacy=is_legacy)
|
137 |
+
return JSONResponse(response)
|
138 |
+
|
139 |
+
|
140 |
+
@app.get("/v1/models", dependencies=check_key)
|
141 |
+
@app.get("/v1/models/{model}", dependencies=check_key)
|
142 |
+
async def handle_models(request: Request):
|
143 |
+
path = request.url.path
|
144 |
+
is_list = request.url.path.split('?')[0].split('#')[0] == '/v1/models'
|
145 |
+
|
146 |
+
if is_list:
|
147 |
+
response = OAImodels.list_dummy_models()
|
148 |
+
else:
|
149 |
+
model_name = path[len('/v1/models/'):]
|
150 |
+
response = OAImodels.model_info_dict(model_name)
|
151 |
+
|
152 |
+
return JSONResponse(response)
|
153 |
+
|
154 |
+
|
155 |
+
@app.get('/v1/billing/usage', dependencies=check_key)
|
156 |
+
def handle_billing_usage():
|
157 |
+
'''
|
158 |
+
Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
|
159 |
+
'''
|
160 |
+
return JSONResponse(content={"total_usage": 0})
|
161 |
+
|
162 |
+
|
163 |
+
@app.post('/v1/audio/transcriptions', dependencies=check_key)
|
164 |
+
async def handle_audio_transcription(request: Request):
|
165 |
+
r = sr.Recognizer()
|
166 |
+
|
167 |
+
form = await request.form()
|
168 |
+
audio_file = await form["file"].read()
|
169 |
+
audio_data = AudioSegment.from_file(audio_file)
|
170 |
+
|
171 |
+
# Convert AudioSegment to raw data
|
172 |
+
raw_data = audio_data.raw_data
|
173 |
+
|
174 |
+
# Create AudioData object
|
175 |
+
audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
|
176 |
+
whipser_language = form.getvalue('language', None)
|
177 |
+
whipser_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny
|
178 |
+
|
179 |
+
transcription = {"text": ""}
|
180 |
+
|
181 |
+
try:
|
182 |
+
transcription["text"] = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model)
|
183 |
+
except sr.UnknownValueError:
|
184 |
+
print("Whisper could not understand audio")
|
185 |
+
transcription["text"] = "Whisper could not understand audio UnknownValueError"
|
186 |
+
except sr.RequestError as e:
|
187 |
+
print("Could not request results from Whisper", e)
|
188 |
+
transcription["text"] = "Whisper could not understand audio RequestError"
|
189 |
+
|
190 |
+
return JSONResponse(content=transcription)
|
191 |
+
|
192 |
+
|
193 |
+
@app.post('/v1/images/generations', dependencies=check_key)
|
194 |
+
async def handle_image_generation(request: Request):
|
195 |
+
|
196 |
+
if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
|
197 |
+
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
|
198 |
+
|
199 |
+
body = await request.json()
|
200 |
+
prompt = body['prompt']
|
201 |
+
size = body.get('size', '1024x1024')
|
202 |
+
response_format = body.get('response_format', 'url') # or b64_json
|
203 |
+
n = body.get('n', 1) # ignore the batch limits of max 10
|
204 |
+
|
205 |
+
response = await OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
|
206 |
+
return JSONResponse(response)
|
207 |
+
|
208 |
+
|
209 |
+
@app.post("/v1/embeddings", response_model=EmbeddingsResponse, dependencies=check_key)
|
210 |
+
async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
|
211 |
+
input = request_data.input
|
212 |
+
if not input:
|
213 |
+
raise HTTPException(status_code=400, detail="Missing required argument input")
|
214 |
+
|
215 |
+
if type(input) is str:
|
216 |
+
input = [input]
|
217 |
+
|
218 |
+
response = OAIembeddings.embeddings(input, request_data.encoding_format)
|
219 |
+
return JSONResponse(response)
|
220 |
+
|
221 |
+
|
222 |
+
@app.post("/v1/moderations", dependencies=check_key)
|
223 |
+
async def handle_moderations(request: Request):
|
224 |
+
body = await request.json()
|
225 |
+
input = body["input"]
|
226 |
+
if not input:
|
227 |
+
raise HTTPException(status_code=400, detail="Missing required argument input")
|
228 |
+
|
229 |
+
response = OAImoderations.moderations(input)
|
230 |
+
return JSONResponse(response)
|
231 |
+
|
232 |
+
|
233 |
+
@app.post("/v1/internal/encode", response_model=EncodeResponse, dependencies=check_key)
|
234 |
+
async def handle_token_encode(request_data: EncodeRequest):
|
235 |
+
response = token_encode(request_data.text)
|
236 |
+
return JSONResponse(response)
|
237 |
+
|
238 |
+
|
239 |
+
@app.post("/v1/internal/decode", response_model=DecodeResponse, dependencies=check_key)
|
240 |
+
async def handle_token_decode(request_data: DecodeRequest):
|
241 |
+
response = token_decode(request_data.tokens)
|
242 |
+
return JSONResponse(response)
|
243 |
+
|
244 |
+
|
245 |
+
@app.post("/v1/internal/token-count", response_model=TokenCountResponse, dependencies=check_key)
|
246 |
+
async def handle_token_count(request_data: EncodeRequest):
|
247 |
+
response = token_count(request_data.text)
|
248 |
+
return JSONResponse(response)
|
249 |
+
|
250 |
+
|
251 |
+
@app.post("/v1/internal/logits", response_model=LogitsResponse, dependencies=check_key)
|
252 |
+
async def handle_logits(request_data: LogitsRequest):
|
253 |
+
'''
|
254 |
+
Given a prompt, returns the top 50 most likely logits as a dict.
|
255 |
+
The keys are the tokens, and the values are the probabilities.
|
256 |
+
'''
|
257 |
+
response = OAIlogits._get_next_logits(to_dict(request_data))
|
258 |
+
return JSONResponse(response)
|
259 |
+
|
260 |
+
|
261 |
+
@app.post("/v1/internal/stop-generation", dependencies=check_key)
|
262 |
+
async def handle_stop_generation(request: Request):
|
263 |
+
stop_everything_event()
|
264 |
+
return JSONResponse(content="OK")
|
265 |
+
|
266 |
+
|
267 |
+
@app.get("/v1/internal/model/info", response_model=ModelInfoResponse, dependencies=check_key)
|
268 |
+
async def handle_model_info():
|
269 |
+
payload = OAImodels.get_current_model_info()
|
270 |
+
return JSONResponse(content=payload)
|
271 |
+
|
272 |
+
|
273 |
+
@app.get("/v1/internal/model/list", response_model=ModelListResponse, dependencies=check_admin_key)
|
274 |
+
async def handle_list_models():
|
275 |
+
payload = OAImodels.list_models()
|
276 |
+
return JSONResponse(content=payload)
|
277 |
+
|
278 |
+
|
279 |
+
@app.post("/v1/internal/model/load", dependencies=check_admin_key)
|
280 |
+
async def handle_load_model(request_data: LoadModelRequest):
|
281 |
+
'''
|
282 |
+
This endpoint is experimental and may change in the future.
|
283 |
+
|
284 |
+
The "args" parameter can be used to modify flags like "--load-in-4bit"
|
285 |
+
or "--n-gpu-layers" before loading a model. Example:
|
286 |
+
|
287 |
+
```
|
288 |
+
"args": {
|
289 |
+
"load_in_4bit": true,
|
290 |
+
"n_gpu_layers": 12
|
291 |
+
}
|
292 |
+
```
|
293 |
+
|
294 |
+
Note that those settings will remain after loading the model. So you
|
295 |
+
may need to change them back to load a second model.
|
296 |
+
|
297 |
+
The "settings" parameter is also a dict but with keys for the
|
298 |
+
shared.settings object. It can be used to modify the default instruction
|
299 |
+
template like this:
|
300 |
+
|
301 |
+
```
|
302 |
+
"settings": {
|
303 |
+
"instruction_template": "Alpaca"
|
304 |
+
}
|
305 |
+
```
|
306 |
+
'''
|
307 |
+
|
308 |
+
try:
|
309 |
+
OAImodels._load_model(to_dict(request_data))
|
310 |
+
return JSONResponse(content="OK")
|
311 |
+
except:
|
312 |
+
traceback.print_exc()
|
313 |
+
return HTTPException(status_code=400, detail="Failed to load the model.")
|
314 |
+
|
315 |
+
|
316 |
+
@app.post("/v1/internal/model/unload", dependencies=check_admin_key)
|
317 |
+
async def handle_unload_model():
|
318 |
+
unload_model()
|
319 |
+
|
320 |
+
|
321 |
+
@app.get("/v1/internal/lora/list", response_model=LoraListResponse, dependencies=check_admin_key)
|
322 |
+
async def handle_list_loras():
|
323 |
+
response = OAImodels.list_loras()
|
324 |
+
return JSONResponse(content=response)
|
325 |
+
|
326 |
+
|
327 |
+
@app.post("/v1/internal/lora/load", dependencies=check_admin_key)
|
328 |
+
async def handle_load_loras(request_data: LoadLorasRequest):
|
329 |
+
try:
|
330 |
+
OAImodels.load_loras(request_data.lora_names)
|
331 |
+
return JSONResponse(content="OK")
|
332 |
+
except:
|
333 |
+
traceback.print_exc()
|
334 |
+
return HTTPException(status_code=400, detail="Failed to apply the LoRA(s).")
|
335 |
+
|
336 |
+
|
337 |
+
@app.post("/v1/internal/lora/unload", dependencies=check_admin_key)
|
338 |
+
async def handle_unload_loras():
|
339 |
+
OAImodels.unload_all_loras()
|
340 |
+
return JSONResponse(content="OK")
|
341 |
+
|
342 |
+
|
343 |
+
def run_server():
|
344 |
+
server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
|
345 |
+
port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))
|
346 |
+
|
347 |
+
ssl_certfile = os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
|
348 |
+
ssl_keyfile = os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)
|
349 |
+
|
350 |
+
if shared.args.public_api:
|
351 |
+
def on_start(public_url: str):
|
352 |
+
logger.info(f'OpenAI-compatible API URL:\n\n{public_url}\n')
|
353 |
+
|
354 |
+
_start_cloudflared(port, shared.args.public_api_id, max_attempts=3, on_start=on_start)
|
355 |
+
else:
|
356 |
+
if ssl_keyfile and ssl_certfile:
|
357 |
+
logger.info(f'OpenAI-compatible API URL:\n\nhttps://{server_addr}:{port}\n')
|
358 |
+
else:
|
359 |
+
logger.info(f'OpenAI-compatible API URL:\n\nhttp://{server_addr}:{port}\n')
|
360 |
+
|
361 |
+
if shared.args.api_key:
|
362 |
+
if not shared.args.admin_key:
|
363 |
+
shared.args.admin_key = shared.args.api_key
|
364 |
+
|
365 |
+
logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n')
|
366 |
+
|
367 |
+
if shared.args.admin_key and shared.args.admin_key != shared.args.api_key:
|
368 |
+
logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n')
|
369 |
+
|
370 |
+
uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)
|
371 |
+
|
372 |
+
|
373 |
+
def setup():
|
374 |
+
if shared.args.nowebui:
|
375 |
+
run_server()
|
376 |
+
else:
|
377 |
+
Thread(target=run_server, daemon=True).start()
|