# convbot/convbot.py  (repo: freemt/convbot, commit ec9d18e — "Update app.py")
"""Generate a response."""
# pylint:disable=line-too-long, too-many-argument
import torch
from logzero import logger
from transformers import AutoModelForCausalLM, AutoTokenizer
from .force_async import force_async
# model_name = "microsoft/DialoGPT-large"
# model_name = "microsoft/DialoGPT-small"
# pylint: disable=invalid-name
# Tokenizer and model are loaded once at import time (first run downloads
# the weights from the Hugging Face hub, so importing can be slow).
model_name = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
def _convbot(
    text: str,
    max_length: int = 1000,
    do_sample: bool = True,
    top_p: float = 0.95,
    top_k: int = 0,
    temperature: float = 0.75,
) -> str:
    """Generate a single response to *text*, continuing the conversation.

    The running conversation history is persisted on the function object
    as ``_convbot.chat_history_ids`` so successive calls continue the
    same chat.  Set it to ``""`` to start a fresh conversation.

    Args:
        text: user utterance to respond to.
        max_length: max total token length passed to ``model.generate``.
        do_sample: sample instead of greedy decoding.
        top_p: nucleus-sampling probability mass.
        top_k: top-k sampling cutoff (0 disables it).
        temperature: sampling temperature.

    Returns:
        The bot's reply; may be "" if generation yields no new tokens.
    """
    # Persisted history, or "" when no conversation has happened yet.
    # (Replaces a try/except AttributeError that was duplicated twice.)
    chat_history_ids = getattr(_convbot, "chat_history_ids", "")

    input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")

    # Prepend history only when it is a real tensor ("" means "no history").
    if isinstance(chat_history_ids, torch.Tensor):
        bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1)
    else:
        bot_input_ids = input_ids

    # generate a bot response
    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=max_length,
        do_sample=do_sample,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens (everything after the prompt).
    output = tokenizer.decode(
        chat_history_ids[:, bot_input_ids.shape[-1] :][0], skip_special_tokens=True
    )

    _convbot.chat_history_ids = chat_history_ids
    return output
def convbot(
    text: str,
    n_retries: int = 3,
    max_length: int = 1000,
    do_sample: bool = True,
    top_p: float = 0.95,
    top_k: int = 0,
    temperature: float = 0.75,
) -> str:
    """Generate a response, retrying on empty or repeated replies.

    Args:
        text: user utterance.
        n_retries: extra attempts when the reply is empty or identical to
            the previous reply.
        max_length, do_sample, top_p, top_k, temperature: forwarded to
            ``_convbot`` / ``model.generate``.

    Returns:
        The bot's reply (best effort; may still be empty or repeated
        after ``n_retries`` attempts).

    Raises:
        TypeError, ValueError: if *n_retries* cannot be coerced to int.
    """
    try:
        n_retries = int(n_retries)
    except (TypeError, ValueError) as exc:  # int() raises only these
        logger.error(exc)
        raise

    # Previous reply persisted on the function object; "" on first call.
    prev_resp = getattr(convbot, "prev_resp", "")

    resp = _convbot(text, max_length, do_sample, top_p, top_k, temperature)

    # Retry with a fresh conversation if the reply is empty.
    if not resp.strip():
        for _ in range(n_retries):
            _convbot.chat_history_ids = ""  # reset history, try afresh
            resp = _convbot(text, max_length, do_sample, top_p, top_k, temperature)
            if resp.strip():
                break
        else:
            logger.warning("bot acting up (empty response), something has gone awry")

    # Retry if the bot repeats its previous reply.
    if resp.strip() == prev_resp:
        for _ in range(n_retries):
            resp = _convbot(text, max_length, do_sample, top_p, top_k, temperature)
            if resp.strip() != prev_resp:
                break
        else:
            logger.warning("bot acting up (repeating), something has gone awry")

    # Store the stripped form so the next call's comparison against
    # resp.strip() is like-for-like (previously the raw reply was stored,
    # so trailing whitespace defeated the repeat check).
    convbot.prev_resp = resp.strip()
    return resp
@force_async
def aconvbot(
    text: str,
    n_retries: int = 3,
    max_length: int = 1000,
    do_sample: bool = True,
    top_p: float = 0.95,
    top_k: int = 0,
    temperature: float = 0.75,
) -> str:
    """Async wrapper around :func:`convbot` (offloaded via ``force_async``).

    Parameters and return value are identical to :func:`convbot`.
    Any exception from :func:`convbot` is logged and re-raised.
    """
    try:
        resp = convbot(text, n_retries, max_length, do_sample, top_p, top_k, temperature)
    except Exception as exc:  # boundary: log, then re-raise for the caller
        logger.error(exc)
        raise
    return resp
def main():
    """Simple interactive REPL: read a line, print the bot's reply, repeat."""
    print("Bot: Talk to me")
    while True:
        text = input("You: ")
        # Use the public wrapper so empty/repeated replies get retried
        # (the original called the private _convbot, bypassing retries).
        resp = convbot(text)
        print("Bot: ", resp)


if __name__ == "__main__":
    main()