|
from text_generation import InferenceAPIClient |
|
from huggingface_hub import login |
|
|
|
|
|
class chat_model():
    """Stateful chat wrapper around a Hugging Face Inference API text model.

    Keeps a rolling, length-bounded conversation context between calls so
    each query is answered with some memory of previous exchanges.
    """

    def __init__(self, hf_token, mname: str = "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"):
        """Authenticate against the Hugging Face Hub and build the client.

        hf_token: Hub access token used for `login` and API requests.
        mname:    model repository id served by the Inference API.
        """
        login(token=hf_token, add_to_git_credential=True, write_permission=True)
        self.chat_model = InferenceAPIClient(mname, timeout=90)
        self.context = ""  # rolling conversation history (tail-truncated)
        self.number = 0    # count of completed exchanges

    def query_bot(self, text: str, BOT_SYSTEM_INSTRUCTION: str = '', SYSTEM_REFRESH: str = ''):
        """Send *text* to the model and return the generated reply string.

        BOT_SYSTEM_INSTRUCTION is injected as a priming prompter/assistant
        turn; SYSTEM_REFRESH is prepended to the user text on every call.
        If generation with the accumulated context fails, one retry is made
        with a context-free prompt; a second failure propagates the error.
        """
        MAX_CONTEXT_LENGTH = 1000  # max chars of history kept between calls
        MAX_TOTAL_LENGTH = 1512    # rough char budget for prompt + reply
        print("Usuario: " + text)

        prompt_text = '''<|prompter|>{}<|endoftext|>
<|assistant|>'''.format(f"{SYSTEM_REFRESH} {text}")
        system = '''<|prompter|>{}<|endoftext|>
<|assistant|>Yes, I got.<|endoftext|>
'''.format(BOT_SYSTEM_INSTRUCTION)
        # Keep only the tail of the combined prompt so the request stays small.
        context = (f"{self.context}{system}{prompt_text}")[-MAX_CONTEXT_LENGTH:]

        try:
            inputs = self._generate(context, MAX_TOTAL_LENGTH)
        except Exception as err:
            # Fall back to a history-free prompt. The original version kept
            # the old max_new_tokens here and, on a second failure, crashed
            # with UnboundLocalError on `inputs` inside `finally`, masking
            # the real exception; now a second failure simply propagates.
            print(err)
            context = prompt_text
            inputs = self._generate(context, MAX_TOTAL_LENGTH)

        # Persist the truncated context including the new reply — only after
        # a successful generation, so failed calls leave state untouched.
        self.context = (f"{context} {inputs}")[-MAX_CONTEXT_LENGTH:]
        self.number += 1
        return inputs

    def _generate(self, context: str, max_total_length: int) -> str:
        """Call the inference client once and return its generated text."""
        # Token budget shrinks as the prompt grows (char-based heuristic);
        # clamp to >= 1 so an oversized prompt can't request <= 0 tokens.
        max_new_tokens = max(1, min(max_total_length - len(context), 1200))
        print(max_new_tokens)
        return self.chat_model.generate(
            context, max_new_tokens=max_new_tokens,
            temperature=0.8, truncate=1000,
            do_sample=True,
            repetition_penalty=1.2,
            top_p=0.9).generated_text
|
|