from text_generation import InferenceAPIClient
from huggingface_hub import login


class ChatModel:
    def __init__(self, hf_token: str,
                 mname: str = "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"):
        # Authenticate against the Hugging Face Hub, then open an
        # Inference API client for the chosen model.
        login(token=hf_token, add_to_git_credential=True, write_permission=True)
        self.chat_model = InferenceAPIClient(mname, token=hf_token, timeout=90)
        self.context = ""  # rolling conversation history
        self.number = 0    # number of exchanges so far
    def query_bot(self, text: str, BOT_SYSTEM_INSTRUCTION: str = '',
                  SYSTEM_REFRESH: str = ''):
        MAX_CONTEXT_LENGTH = 1000
        MAX_TOTAL_LENGTH = 1512
        print("User: " + text)
        # Append the prompt text to the context and truncate if necessary.
        prompt_text = "<|prompter|>{}<|endoftext|>\n<|assistant|>".format(
            f"{SYSTEM_REFRESH} {text}")
        system = ("<|prompter|>{}<|endoftext|>\n"
                  "<|assistant|>Yes, I got it.<|endoftext|>\n").format(
                      BOT_SYSTEM_INSTRUCTION)
        context = f"{self.context}{system}{prompt_text}"[-MAX_CONTEXT_LENGTH:]
        # Make sure the generated text does not exceed MAX_TOTAL_LENGTH.
        # Note: lengths here are measured in characters, used as a rough
        # proxy for tokens.
        max_new_tokens = min(MAX_TOTAL_LENGTH - len(context), 1200)
        print("max_new_tokens:", max_new_tokens)
        gen_kwargs = dict(
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            truncate=1000,
            do_sample=True,
            repetition_penalty=1.2,
            top_p=0.9,
        )
        try:
            inputs = self.chat_model.generate(
                context, **gen_kwargs).generated_text
        except Exception as err:
            print(err)
            context = prompt_text  # Reset the context and retry once
            inputs = self.chat_model.generate(
                context, **gen_kwargs).generated_text
        # Update the rolling history only after a successful generation
        # (if the retry also fails, the exception propagates to the caller).
        self.context = f"{context} {inputs}"[-MAX_CONTEXT_LENGTH:]
        self.number += 1
        return inputs
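

# A minimal usage sketch, assuming a valid Hugging Face token is exported
# as HF_TOKEN; the system instruction and message below are illustrative.
if __name__ == "__main__":
    import os

    bot = ChatModel(hf_token=os.environ["HF_TOKEN"])
    reply = bot.query_bot(
        "What is the capital of France?",
        BOT_SYSTEM_INSTRUCTION="You are a concise, helpful assistant.",
    )
    print("Assistant: " + reply)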