# guillaumetell-7b / prompt_demo_analysis.py
# Last change by Camille1905: "Update prompt_demo_analysis.py" (commit 7f61ffc).
# A demo of a side feature of Guillaume-Tell: deciding whether a question should
# trigger a source-retrieval pipeline.
# The model is expected to return a structured JSON answer with two components:
## - a short analysis with its reasoning;
## - a boolean answer in French ("oui" or "non").
import sys, os
from pprint import pprint
from jinja2 import Environment, FileSystemLoader, meta
import yaml
import pandas as pd
from vllm import LLM, SamplingParams
# Make the relative paths used below (the YAML config and the Jinja template
# directory, both given relative to ".") resolve against the folder containing
# this file, whatever directory the script is launched from.
sys.path.append(".")
os.chdir(os.path.abspath(os.path.dirname(__file__)))
def get_llm_response(prompt_template, sampling_params, *, engine=None):
    """Run one prompt through the LLM and return its JSON answer, closed with "}".

    Helper specific to the JSON answer format. The sampling parameters used in
    this script stop generation at ``"}"``. vLLM omits the stop string from the
    returned text by default, so the closing brace is appended back here.

    Args:
        prompt_template: Fully rendered prompt string sent to the model.
        sampling_params: vLLM ``SamplingParams`` used for the generation.
        engine: Object exposing vLLM's ``LLM.generate`` interface. Defaults to
            the module-level ``llm`` created in the ``__main__`` block.

    Returns:
        A ``(prompt, generated_text)`` tuple. ``generated_text`` is the model
        output, guaranteed to end with ``"}"``. ``prompt`` is the input prompt
        followed by that text.
    """
    if engine is None:
        engine = llm  # module-level engine created by the __main__ block
    outputs = engine.generate([prompt_template], sampling_params, use_tqdm=False)
    # One prompt in, so take the first completion of the first request.
    generated_text = outputs[0].outputs[0].text
    # endswith() also handles an empty completion (the model produced the stop
    # string straight away), where indexing [-1] would raise IndexError.
    if not generated_text.endswith("}"):
        generated_text += "}"
    # NOTE(review): if generation ended on max_tokens rather than on "}", the
    # appended brace does not guarantee valid JSON; parse the result defensively.
    return prompt_template + generated_text, generated_text
if __name__ == "__main__":
    # Read the config as UTF-8 explicitly rather than relying on the platform's
    # default encoding.
    with open("prompt_config.yaml", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    print("prompt format:", config.get("prompt_format"))
    print(config)
    print()

    # Loop-invariant objects, built once instead of once per prompt.
    env = Environment(loader=FileSystemLoader("."))
    sampling_params = SamplingParams(temperature=0.2, top_p=0.95, max_tokens=300, stop="}")
    # Building the vLLM engine loads the model weights, which is slow and
    # memory-hungry, so it is created lazily, at most once, and only if an
    # "analysis" prompt is present. It stays a module-level name because
    # get_llm_response() reads the module-level `llm`.
    llm = None

    # Named `prompt_cfg` so the config entry is not rebound by the value
    # returned from get_llm_response() below.
    for prompt_cfg in config["prompts"]:
        if prompt_cfg["mode"] != "analysis":
            continue
        print(f'--- prompt mode: {prompt_cfg["mode"]} ---')

        template = env.get_template(prompt_cfg["template"])
        # List the variables the template references so that optional inputs
        # (here: system_prompt) are only supplied when the template uses them.
        template_source, _, _ = env.loader.get_source(env, template.name)
        variables = meta.find_undeclared_variables(env.parse(template_source))
        print("variables:", variables)
        print("---")

        data = {"query": "Comment obtenir le formulaire A36 ?"}
        if "system_prompt" in variables:
            data["system_prompt"] = prompt_cfg["system_prompt"]
        rendered_template = template.render(**data)
        print(rendered_template)
        print("---")

        if llm is None:
            llm = LLM("mistral-mfs-reference-2/mistral-mfs-reference-2")
        _, generated_text = get_llm_response(rendered_template, sampling_params)
        print("Albert : ", generated_text)