import base64 |
import copy |
import re |
import time |
from collections import deque |
from io import BytesIO |
import requests |
import tiktoken |
import torch |
import torch.nn.functional as F |
from PIL import Image |
from transformers import LogitsProcessor, LogitsProcessorList |
from extensions.openai.errors import InvalidRequestError |
from extensions.openai.utils import debug_msg |
from modules import shared |
from modules.chat import ( |
generate_chat_prompt, |
generate_chat_reply, |
load_character_memoized, |
load_instruction_template_memoized |
) |
from modules.presets import load_preset_memoized |
from modules.text_generation import ( |
decode, |
encode, |
generate_reply, |
get_reply_from_output_ids |
) |
class LogitsBiasProcessor(LogitsProcessor):
    """Logits processor that adds a fixed bias to selected token logits.

    ``logit_bias`` maps token ids to the value added to that token's logit
    on every call. Keys may be strings (as in the OpenAI ``logit_bias``
    request field) or ints; they are converted with ``int()``.
    """

    def __init__(self, logit_bias=None):
        # ``None`` stands in for "no bias"; a literal ``{}`` default would be
        # a single dict object shared by every instance.
        self.logit_bias = {} if logit_bias is None else logit_bias
        if self.logit_bias:
            # Keys and values are read in the same iteration order, so they
            # stay aligned without converting each key back to ``str`` (which
            # failed for int keys and for strings such as " 5").
            self.keys = [int(key) for key in self.logit_bias]
            values = list(self.logit_bias.values())
            self.values = torch.tensor(values, dtype=torch.float, device=shared.model.device)
            debug_msg(f"{self})")

    def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
        """Add the configured biases in place to row 0 of ``logits`` and return it."""
        if self.logit_bias:
            debug_msg(logits[0, self.keys], " + ", self.values)
            # Only the first row is biased; other rows are left untouched.
            logits[0, self.keys] += self.values
            debug_msg(" --> ", logits[0, self.keys])
            debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0])))

        return logits

    def __repr__(self):
        return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>"
class LogprobProcessor(LogitsProcessor):
    """Logits processor that records the most likely tokens at each call.

    When ``logprobs`` is not ``None``, every call stores the
    ``logprobs + 1`` highest log-probabilities of row 0 in
    ``token_alternatives`` as ``{decoded token text: log-probability}``.
    The logits themselves are passed through unchanged.
    """

    def __init__(self, logprobs=None):
        self.token_alternatives = {}
        self.logprobs = logprobs

    def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
        """Refresh ``token_alternatives`` from ``logits`` and return ``logits``."""
        if self.logprobs is None:
            return logits

        # Natural-log probabilities along dim 1.
        log_probs = F.log_softmax(logits, dim=1)
        best_values, best_ids = torch.topk(log_probs, k=self.logprobs + 1)

        decoded = [get_reply_from_output_ids([token_id]) for token_id in best_ids[0]]
        scores = [float(value) for value in best_values[0]]
        self.token_alternatives = dict(zip(decoded, scores))

        debug_msg(repr(self))
        return logits

    def __repr__(self):
        return f"<{self.__class__.__name__}(logprobs={self.logprobs}, token_alternatives={self.token_alternatives})>"
def convert_logprobs_to_tiktoken(model, logprobs):
    """Return ``logprobs`` unchanged.

    No conversion is performed; ``model`` is accepted but not used.
    """
    unconverted = logprobs
    return unconverted
def process_parameters(body, is_legacy=False):
    """Turn an API request body into generation parameters.

    The request dict is modified in place and the very same object is
    returned: the token-limit field (``'length'`` for legacy requests,
    ``'max_tokens'`` otherwise) is moved to ``'max_new_tokens'``, defaults
    are resolved, an optional preset is merged in, stop strings are
    normalised to a list, and logits processors are attached when the
    request asks for ``logit_bias`` or ``logprobs``.
    """
    params = body  # alias, not a copy: the caller's dict is updated

    token_limit_key = 'length' if is_legacy else 'max_tokens'
    params['max_new_tokens'] = body.pop(token_limit_key)

    # A truncation length of 0 means "use the configured setting".
    if params['truncation_length'] == 0:
        params['truncation_length'] = shared.settings['truncation_length']

    # Temperature 0 turns sampling off and keeps only the top token.
    if params['temperature'] == 0:
        params.update({'do_sample': False, 'top_k': 1})

    preset_name = body['preset']
    if preset_name is not None:
        params.update(load_preset_memoized(preset_name))

    # 'stop' may be a single string or a list; anything else means none.
    stop = body.get('stop')
    if isinstance(stop, str):
        stop_strings = [stop]
    elif isinstance(stop, list):
        stop_strings = stop
    else:
        stop_strings = []
    params['custom_stopping_strings'] = stop_strings

    processors = []

    bias = body.get('logit_bias', None)
    if bias:
        processors.append(LogitsBiasProcessor(bias))

    if 'logprobs' in body:
        logprob_proc = LogprobProcessor(body.get('logprobs', 0))
        # Stored under its own key as well as in the processor list.
        params['logprob_proc'] = logprob_proc
        processors.append(logprob_proc)

    if processors:
        params['logits_processor'] = LogitsProcessorList(processors)

    return params
def _image_url_to_img_tag(image_url):
    '''
    Load the image referenced by image_url and return it as an inline
    JPEG <img> tag.

    A value containing "base64" is treated as base64 data (with an
    optional data-URI prefix); anything else is fetched over HTTP.
    Raises InvalidRequestError when the HTTP fetch or the decoding of the
    fetched bytes fails.
    '''
    if "base64" in image_url:
        image_url = re.sub('^data:image/.+;base64,', '', image_url)
        img = Image.open(BytesIO(base64.b64decode(image_url)))
    else:
        # NOTE(review): the URL comes from the request and is fetched
        # server-side without a timeout or host restriction; consider
        # hardening if the API is reachable by untrusted clients.
        try:
            my_res = requests.get(image_url)
            img = Image.open(BytesIO(my_res.content))
        except Exception as err:
            # Raising a bare string is itself a TypeError in Python 3, so
            # report the failure through the API's own error type instead.
            raise InvalidRequestError(message='Image cannot be loaded from the URL!', param='messages') from err

    # Palette and alpha images are converted to RGB before the JPEG save.
    if img.mode in ("RGBA", "P"):
        img = img.convert("RGB")

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
    return f'<img src="data:image/jpeg;base64,{img_str}">'


def convert_history(history):
    '''
    Chat histories in this program are in the format [message, reply].
    This function converts OpenAI histories to that format.

    Returns a tuple (user_input, system_message, history_dict) where
    user_input is the content of the last "user" entry, system_message is
    the content of the last "system" entry, and history_dict holds the
    paired turns under 'internal' plus an independent deep copy under
    'visible'. The trailing user message is returned only as user_input;
    it is not appended to the paired turns.

    Entries whose content is a list of parts are flattened first: a list
    with both an image and text becomes an image entry followed by a text
    entry (both with role "user"); a list with only text keeps its text
    and its own role. Lists with an image but no text are skipped.
    '''
    chat_dialogue = []
    current_message = ""
    user_input = ""
    system_message = ""

    # Multimodal: flatten list-style content into plain entries.
    if any(isinstance(entry.get('content'), list) for entry in history):
        new_history = []
        for entry in history:
            parts = entry.get('content')
            if not isinstance(parts, list):
                new_history.append(entry)
                continue

            image_url = None
            content = None
            for item in parts:
                if not isinstance(item, dict):
                    continue

                item_type = item.get('type')
                if item_type == 'image_url' and isinstance(item.get('image_url'), dict):
                    image_url = item['image_url'].get('url')
                elif item_type == 'text' and isinstance(item.get('text'), str):
                    content = item['text']

            if image_url and content:
                new_history.append({"image_url": image_url, "role": "user"})
                new_history.append({"content": content, "role": "user"})
            elif content:
                # Text-only part lists used to be dropped entirely.
                new_history.append({"content": content, "role": entry["role"]})

        history = new_history

    for entry in history:
        if "image_url" in entry:
            content = _image_url_to_img_tag(entry['image_url'])
        else:
            content = entry["content"]

        role = entry["role"]

        if role == "user":
            user_input = content
            # A user message still waiting for a reply gets an empty one.
            if current_message:
                chat_dialogue.append([current_message, ''])

            current_message = content
        elif role == "assistant":
            # Pair the reply with the pending user message, or with an
            # empty message when none is pending.
            chat_dialogue.append([current_message or '', content])
            current_message = ""
        elif role == "system":
            system_message = content

    return user_input, system_message, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue)}
def chat_completions_common(body: dict, is_legacy: bool = False, stream=False) -> dict:
    """Generate a chat completion and yield OpenAI-style response objects.

    This function is a generator; the ``-> dict`` annotation describes each
    yielded item. Nothing runs, validation included, until the generator is
    first advanced.

    With ``stream=False`` exactly one response dict is yielded, holding the
    whole answer and token usage. With ``stream=True`` an initial empty
    chunk is yielded, then one chunk per piece of new text, then a final
    empty chunk carrying ``finish_reason`` and ``usage``.

    Args:
        body: Request body. It is modified in place by
            ``process_parameters``.
        is_legacy: When true, the token limit is read from ``'length'`` and
            results are listed under ``'data'`` instead of ``'choices'``.
        stream: Whether to yield incremental chunks.

    Raises:
        InvalidRequestError: if ``functions`` / ``function_call`` are used,
            ``messages`` is absent, or a message lacks a role, has role
            ``function``, or has neither ``content`` nor ``image_url``.
    """
    if body.get('functions', []):
        raise InvalidRequestError(message="functions is not supported.", param='functions')

    if body.get('function_call', ''):
        raise InvalidRequestError(message="function_call is not supported.", param='function_call')

    if 'messages' not in body:
        raise InvalidRequestError(message="messages is required", param='messages')

    messages = body['messages']
    for m in messages:
        if 'role' not in m:
            raise InvalidRequestError(message="messages: missing role", param='messages')
        elif m['role'] == 'function':
            raise InvalidRequestError(message="role: function is not supported.", param='messages')

        if 'content' not in m and "image_url" not in m:
            raise InvalidRequestError(message="messages: missing content", param='messages')

    # The OpenAI API names these objects "chat.completion" and
    # "chat.completion.chunk" (singular); the plural spellings used before
    # do not match the reference.
    object_type = 'chat.completion.chunk' if stream else 'chat.completion'
    created_time = int(time.time())
    cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'

    # Generation parameters. ``generate_params`` is the same dict object as
    # ``body``, so the ``body[...]`` reads below see the processed values.
    generate_params = process_parameters(body, is_legacy=is_legacy)
    continue_ = body['continue_']

    # Instruction template: an explicit template string wins, then a named
    # template, then the configured default.
    if body['instruction_template_str']:
        instruction_template_str = body['instruction_template_str']
    elif body['instruction_template']:
        instruction_template = body['instruction_template']
        instruction_template = "Alpaca" if instruction_template == "None" else instruction_template
        instruction_template_str = load_instruction_template_memoized(instruction_template)
    else:
        instruction_template_str = shared.settings['instruction_template_str']

    chat_template_str = body['chat_template_str'] or shared.settings['chat_template_str']
    chat_instruct_command = body['chat_instruct_command'] or shared.settings['chat-instruct_command']

    # Chat character: request values override what the character provides.
    character = body['character'] or shared.settings['character']
    character = "Assistant" if character == "None" else character
    name1 = body['user_name'] or shared.settings['name1']
    name1, name2, _, greeting, context = load_character_memoized(character, name1, '')
    name2 = body['bot_name'] or name2
    context = body['context'] or context
    greeting = body['greeting'] or greeting

    # History
    user_input, custom_system_message, history = convert_history(messages)

    generate_params.update({
        'mode': body['mode'],
        'name1': name1,
        'name2': name2,
        'context': context,
        'greeting': greeting,
        'instruction_template_str': instruction_template_str,
        'custom_system_message': custom_system_message,
        'chat_template_str': chat_template_str,
        'chat-instruct_command': chat_instruct_command,
        'history': history,
        'stream': stream
    })

    # No explicit limit: fall back to 512 and enable auto_max_new_tokens.
    max_tokens = generate_params['max_new_tokens']
    if max_tokens in [None, 0]:
        generate_params['max_new_tokens'] = 512
        generate_params['auto_max_new_tokens'] = True

    # These two entries are for this function only, so take them out of the
    # parameters before generation.
    requested_model = generate_params.pop('model')
    logprob_proc = generate_params.pop('logprob_proc', None)

    def chat_streaming_chunk(content):
        """Build one streaming chunk carrying ``content``."""
        chunk = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": shared.model_name,
            resp_list: [{
                "index": 0,
                "finish_reason": None,
                "message": {'role': 'assistant', 'content': content},
                "delta": {'role': 'assistant', 'content': content},
            }],
        }

        if logprob_proc:
            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
            chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}

        return chunk

    if stream:
        yield chat_streaming_chunk('')

    # Generate the reply.
    prompt = generate_chat_prompt(user_input, generate_params)
    token_count = len(encode(prompt)[0])
    debug_msg({'prompt': prompt, 'generate_params': generate_params})

    generator = generate_chat_reply(
        user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)

    answer = ''
    seen_content = ''
    completion_token_count = 0

    for a in generator:
        answer = a['internal'][-1][1]
        if stream:
            len_seen = len(seen_content)
            new_content = answer[len_seen:]

            # U+FFFD in the new text is treated as an incomplete character:
            # hold the text back until a later update.
            if not new_content or chr(0xfffd) in new_content:
                continue

            seen_content = answer
            chunk = chat_streaming_chunk(new_content)
            yield chunk

    completion_token_count = len(encode(answer)[0])
    stop_reason = "stop"
    if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
        stop_reason = "length"

    usage = {
        "prompt_tokens": token_count,
        "completion_tokens": completion_token_count,
        "total_tokens": token_count + completion_token_count
    }

    if stream:
        chunk = chat_streaming_chunk('')
        chunk[resp_list][0]['finish_reason'] = stop_reason
        chunk['usage'] = usage

        yield chunk
    else:
        resp = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": shared.model_name,
            resp_list: [{
                "index": 0,
                "finish_reason": stop_reason,
                "message": {"role": "assistant", "content": answer}
            }],
            "usage": usage
        }
        if logprob_proc:
            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
            resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}

        yield resp
def _decode_token_prompt(token_ids, requested_model):
    """Turn a prompt given as a list of token ids into text.

    The loaded model's tokenizer is used when the request names the loaded
    model, or when ``tiktoken`` does not know the requested model name;
    otherwise the ``tiktoken`` encoding for that model decodes the ids.
    """
    # NOTE(review): ``decode(...)[0]`` assumes ``decode`` returns a sequence
    # whose first element is the text -- confirm against
    # modules.text_generation.
    if requested_model == shared.model_name:
        return decode(token_ids)[0]

    try:
        encoder = tiktoken.encoding_for_model(requested_model)
        return encoder.decode(token_ids)
    except KeyError:
        return decode(token_ids)[0]


def completions_common(body: dict, is_legacy: bool = False, stream=False):
    """Generate text completions and yield OpenAI-style response objects.

    This function is a generator; nothing runs until it is first advanced.

    With ``stream=False`` the prompt may be a string, a list of token ids,
    or a list of such prompts; each is completed in turn and one response
    dict with a choice per prompt is yielded. With ``stream=True`` only a
    single prompt (string or token-id list) is accepted; an initial chunk
    (the prompt when ``echo`` is set, else empty), then one chunk per piece
    of new text, then a final chunk carrying the suffix, ``finish_reason``
    and ``usage`` are yielded.

    Args:
        body: Request body. It is modified in place by
            ``process_parameters``.
        is_legacy: When true, the prompt is read from ``'context'`` instead
            of ``'prompt'``, the token limit from ``'length'``, and results
            are listed under ``'data'`` instead of ``'choices'``.
        stream: Whether to yield incremental chunks.

    Raises:
        InvalidRequestError: if the prompt field is missing, or if a
            batched prompt is supplied while streaming.
    """
    object_type = 'text_completion.chunk' if stream else 'text_completion'
    created_time = int(time.time())
    cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'

    prompt_str = 'context' if is_legacy else 'prompt'

    if prompt_str not in body:
        raise InvalidRequestError("Missing required input", param=prompt_str)

    # ``generate_params`` is the same dict object as ``body``.
    generate_params = process_parameters(body, is_legacy=is_legacy)
    max_tokens = generate_params['max_new_tokens']
    generate_params['stream'] = stream
    requested_model = generate_params.pop('model')
    logprob_proc = generate_params.pop('logprob_proc', None)
    suffix = body['suffix'] if body['suffix'] else ''
    echo = body['echo']

    if not stream:
        prompt_arg = body[prompt_str]
        # Wrap a single prompt (text or token ids) so the loop below always
        # sees a batch. The emptiness guard avoids indexing into ``[]``.
        if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and prompt_arg and isinstance(prompt_arg[0], int)):
            prompt_arg = [prompt_arg]

        resp_list_data = []
        total_completion_token_count = 0
        total_prompt_token_count = 0

        for idx, prompt in enumerate(prompt_arg, start=0):
            # An empty prompt ("" or []) is left as-is rather than indexed,
            # matching the guard used by the streaming path.
            if prompt and isinstance(prompt[0], int):
                prompt = _decode_token_prompt(prompt, requested_model)

            prefix = prompt if echo else ''
            token_count = len(encode(prompt)[0])
            total_prompt_token_count += token_count

            # Generate the reply; only the final state of the generator is kept.
            debug_msg({'prompt': prompt, 'generate_params': generate_params})
            generator = generate_reply(prompt, generate_params, is_chat=False)
            answer = ''

            for a in generator:
                answer = a

            completion_token_count = len(encode(answer)[0])
            total_completion_token_count += completion_token_count
            stop_reason = "stop"
            if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
                stop_reason = "length"

            respi = {
                "index": idx,
                "finish_reason": stop_reason,
                "text": prefix + answer + suffix,
                "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
            }

            resp_list_data.append(respi)

        resp = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": shared.model_name,
            resp_list: resp_list_data,
            "usage": {
                "prompt_tokens": total_prompt_token_count,
                "completion_tokens": total_completion_token_count,
                "total_tokens": total_prompt_token_count + total_completion_token_count
            }
        }

        yield resp
    else:
        prompt = body[prompt_str]
        if isinstance(prompt, list):
            if prompt and isinstance(prompt[0], int):
                prompt = _decode_token_prompt(prompt, requested_model)
            else:
                raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)

        prefix = prompt if echo else ''
        token_count = len(encode(prompt)[0])

        def text_streaming_chunk(content):
            """Build one streaming chunk carrying ``content``."""
            chunk = {
                "id": cmpl_id,
                "object": object_type,
                "created": created_time,
                "model": shared.model_name,
                resp_list: [{
                    "index": 0,
                    "finish_reason": None,
                    "text": content,
                    "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
                }],
            }

            return chunk

        yield text_streaming_chunk(prefix)

        # Generate the reply.
        debug_msg({'prompt': prompt, 'generate_params': generate_params})
        generator = generate_reply(prompt, generate_params, is_chat=False)

        answer = ''
        seen_content = ''
        completion_token_count = 0

        for a in generator:
            answer = a

            len_seen = len(seen_content)
            new_content = answer[len_seen:]

            # U+FFFD in the new text is treated as an incomplete character:
            # hold the text back until a later update.
            if not new_content or chr(0xfffd) in new_content:
                continue

            seen_content = answer
            chunk = text_streaming_chunk(new_content)
            yield chunk

        completion_token_count = len(encode(answer)[0])
        stop_reason = "stop"
        if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
            stop_reason = "length"

        chunk = text_streaming_chunk(suffix)
        chunk[resp_list][0]["finish_reason"] = stop_reason
        chunk["usage"] = {
            "prompt_tokens": token_count,
            "completion_tokens": completion_token_count,
            "total_tokens": token_count + completion_token_count
        }

        yield chunk
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
    """Run a non-streaming chat completion and return its response dict."""
    responses = chat_completions_common(body, is_legacy, stream=False)
    # A one-slot deque drains the generator and keeps only its final item.
    final_only = deque(responses, maxlen=1)
    return final_only.pop()
def stream_chat_completions(body: dict, is_legacy: bool = False):
    """Yield each chunk of a streaming chat completion as it is produced."""
    yield from chat_completions_common(body, is_legacy, stream=True)
def completions(body: dict, is_legacy: bool = False) -> dict:
    """Run a non-streaming text completion and return its response dict."""
    responses = completions_common(body, is_legacy, stream=False)
    # A one-slot deque drains the generator and keeps only its final item.
    final_only = deque(responses, maxlen=1)
    return final_only.pop()
def stream_completions(body: dict, is_legacy: bool = False):
    """Yield each chunk of a streaming text completion as it is produced."""
    yield from completions_common(body, is_legacy, stream=True)