guillaumetell-7b / prompt_demo_inference.py
Pclanglais's picture
Create prompt_demo_inference.py
371b665 verified
raw history blame
No virus
3.44 kB
#Full demo of the Guillaume-Tell reference model with two references.
#Guillaume-Tell is currently trained by default on five references but future version will enhance the flexibility of the model.L
#Example of generated text:
#Il est difficile de déterminer le meilleur moyen de cuire une blanquette avec les informations disponibles.
#Cependant, voici un résumé des références fournies:
#La blanquette peut être préparée avec du beurre et une sauce épaisse pour rendre le plat plus savoureux<ref text="Moi j'aime la blanquette avec du beurre dedans Et une sauce bien épaisse.">hash49080805</ref>.
#Une autre méthode possible consiste à faire chauffer la blanquette à feu doux pendant 46 heures<ref text="(Recette de blanquette : faîtes chauffer la blanquette à feu doux pendant 46 heures.)">hash49080806</ref>.
#Ces deux références ne permettent pas de donner une réponse définitive sur le meilleur moyen de cuire une blanquette.
import sys, os
from pprint import pprint
from jinja2 import Environment, FileSystemLoader, meta
import yaml
import pandas as pd
from vllm import LLM, SamplingParams
sys.path.append(".")
os.chdir(os.path.dirname(os.path.abspath(__file__)))
def get_llm_response(prompt_template):
sampling_params = SamplingParams(temperature=0.4, top_p=.95, max_tokens=2000, presence_penalty = 2)
prompts = [prompt_template]
outputs = llm.generate(prompts, sampling_params, use_tqdm = False)
generated_text = outputs[0].outputs[0].text
prompt = prompt_template + generated_text
return prompt, generated_text
if __name__ == "__main__":
with open('prompt_config.yaml') as f:
config = yaml.safe_load(f)
print("prompt format:", config.get("prompt_format"))
print(config)
print()
for prompt in config["prompts"]:
print(f'--- prompt mode: {prompt["mode"]} ---')
env = Environment(loader=FileSystemLoader("."))
template = env.get_template(prompt["template"])
source = template.environment.loader.get_source(template.environment, template.name)
variables = meta.find_undeclared_variables(env.parse(source[0]))
print("variables:", variables)
print("---")
data = {
"query": "Quel est le meilleur moyen de cuire une blanquette?",
"chunks" : [
{
"url": "http://data.gouv.fr",
"h": "hash49080805",
"title": "A chunk title",
"text": "Moi j'aime la blanquette avec du beurre dedans\nEt une sauce bien épaisse.",
},
{
"url": "http://...",
"h": "hash49080806",
"title": "A chunk title",
"text": "text texs\ntext again ",
"context": "Recette de blanquette : faîtes chauffer la blanquette à feu doux pendant 46 heures."
},
]
}
if "system_prompt" in variables:
data["system_prompt"] = prompt["system_prompt"]
rendered_template = template.render(**data)
print(rendered_template)
print("---")
llm = LLM("mistral-mfs-reference/mistral-mfs-reference")
sampling_params = SamplingParams(temperature=0.4, top_p=0.95, max_tokens=1500)
prompt, generated_text = get_llm_response(rendered_template)
print("Albert : ", generated_text)