Spaces:
Running
Running
"""Generate a response.""" | |
# pylint:disable=line-too-long, too-many-argument | |
import torch | |
from logzero import logger | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from .force_async import force_async | |
# model_name = "microsoft/DialoGPT-large" | |
# model_name = "microsoft/DialoGPT-small" | |
# pylint: disable=invalid-name | |
model_name = "microsoft/DialoGPT-medium" | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
model = AutoModelForCausalLM.from_pretrained(model_name) | |
def _convbot( | |
text: str, | |
max_length: int = 1000, | |
do_sample: bool = True, | |
top_p: float = 0.95, | |
top_k: int = 0, | |
temperature: float = 0.75, | |
) -> str: | |
"""Generate a reponse. | |
Args | |
n_retires: retry if response is "" or the same as previouse resp. | |
Returns | |
reply | |
""" | |
try: | |
chat_history_ids = _convbot.chat_history_ids | |
except AttributeError: | |
chat_history_ids = "" | |
try: | |
chat_history_ids = _convbot.chat_history_ids | |
except AttributeError: | |
chat_history_ids = "" | |
input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt") | |
if isinstance(chat_history_ids, torch.Tensor): | |
bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1) | |
else: | |
bot_input_ids = input_ids | |
# generate a bot response | |
chat_history_ids = model.generate( | |
bot_input_ids, | |
max_length=max_length, | |
do_sample=do_sample, | |
top_p=top_p, | |
top_k=top_k, | |
temperature=temperature, | |
pad_token_id=tokenizer.eos_token_id, | |
) | |
output = tokenizer.decode( | |
chat_history_ids[:, bot_input_ids.shape[-1] :][0], skip_special_tokens=True | |
) | |
_convbot.chat_history_ids = chat_history_ids | |
return output | |
def convbot( | |
text: str, | |
n_retries: int = 3, | |
max_length: int = 1000, | |
do_sample: bool = True, | |
top_p: float = 0.95, | |
top_k: int = 0, | |
temperature: float = 0.75, | |
) -> str: | |
"""Generate a response.""" | |
try: | |
n_retries = int(n_retries) | |
except Exception as e: | |
logger.error(e) | |
raise | |
try: | |
prev_resp = convbot.prev_resp | |
except AttributeError: | |
prev_resp = "" | |
resp = _convbot(text, max_length, do_sample, top_p, top_k, temperature) | |
# retry n_retries if resp is empty | |
if not resp.strip(): | |
idx = 0 | |
while idx < n_retries: | |
idx += 1 | |
_convbot.chat_history_ids = "" | |
resp = _convbot(text, max_length, do_sample, top_p, top_k, temperature) | |
if resp.strip(): | |
break | |
else: | |
logger.warning("bot acting up (empty response), something has gone awry") | |
# check repeated responses | |
if resp.strip() == prev_resp: | |
idx = 0 | |
while idx < n_retries: | |
idx += 1 | |
resp = _convbot(text, max_length, do_sample, top_p, top_k, temperature) | |
if resp.strip() != prev_resp: | |
break | |
else: | |
logger.warning("bot acting up (repeating), something has gone awry") | |
convbot.prev_resp = resp | |
return resp | |
def aconvbot( | |
text: str, | |
n_retries: int = 3, | |
max_length: int = 1000, | |
do_sample: bool = True, | |
top_p: float = 0.95, | |
top_k: int = 0, | |
temperature: float = 0.75, | |
) -> str: | |
try: | |
_ = convbot(text, n_retries, max_length, do_sample, top_p, top_k, temperature) | |
except Exception as e: | |
logger.error(e) | |
raise | |
return _ | |
def main(): | |
print("Bot: Talk to me") | |
while 1: | |
text = input("You: ") | |
resp = _convbot(text) | |
print("Bot: ", resp) | |
if __name__ == "__main__": | |
main() | |