# Instruction placed in the "### Instruction:" section of every prompt.
DEFAULT_SYSTEM_PROMPT = (
    "Below is a sentence. Identify the topic of the sentence in one word."
)


def generate_prompt(
    conversation: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT
) -> str:
    """Build an instruction-style prompt asking for the topic of *conversation*.

    The prompt has three sections -- ``### Instruction:``, ``### Input:`` and
    ``### Response:`` -- and ends right after the ``### Response:`` header, so
    generation continues from the response slot.

    Args:
        conversation: The text to classify; surrounding whitespace is stripped.
        system_prompt: The instruction placed after ``### Instruction:``.

    Returns:
        The prompt as a single string with no leading or trailing whitespace.
    """
    # Join explicit lines instead of using an indented triple-quoted literal,
    # so the prompt text does not depend on source indentation. The result
    # deliberately has no trailing newline after "### Response:".
    sections = [
        f"### Instruction: {system_prompt}",
        "### Input:",
        conversation.strip(),
        "### Response:",
    ]
    return "\n".join(sections)


def find_topic(model, text: str, *, max_new_tokens: int = 256) -> str:
    """Generate a completion for *text* and return only the newly generated text.

    Uses the module-level ``tokenizer``, ``DEVICE`` and ``torch`` names, which
    must be defined before this function is called.

    Args:
        model: An object exposing ``generate(**inputs, max_new_tokens=..., ...)``;
            presumably a Hugging Face causal LM -- confirm against where it is loaded.
        text: The full prompt (for example the output of ``generate_prompt``).
        max_new_tokens: Upper bound on the number of generated tokens
            (defaults to 256, the previous hard-coded value).

    Returns:
        The decoded continuation, excluding the prompt tokens and with special
        tokens skipped.
    """
    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
    # Number of prompt tokens; everything after this index in the output is new.
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Explicit greedy decoding for a deterministic answer. A near-zero
            # temperature is ignored (with a warning) when sampling is off, and
            # still samples when the model's generation config enables sampling.
            do_sample=False,
        )
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)


topic = find_topic(
    model, generate_prompt("I am attending some classes to learn math and physics")
)
# Keep only the text before any further "### Response:" marker in the
# generated continuation.
pprint(topic.partition("### Response:")[0].strip())
# Output from a previous run: 'Education & Reference'