#A demo of a side functionality of Guillaume-Tell: guessing whether a question should trigger a source retrieval pipeline.
#The function should return a structured answer in JSON with two components:
##A short analysis with reasoning.
##A boolean answer in French ("oui" or "non").

#Note that JSON generation with LLMs is still challenging due to unpredictable behavior.
#Libraries like marginalia can guarantee that the output is always JSON-compliant: https://github.com/Pleias/marginalia

#A typical example:

#{
#  "analysis": "La question concerne un formulaire spécifique, le formulaire A36. Il est donc probable que des références encyclopédiques soient nécessaires pour fournir des informations précises sur ce formulaire.",
#  "result": "oui"
#}


import sys, os
from jinja2 import Environment, FileSystemLoader, meta
import yaml

from vllm import LLM, SamplingParams


sys.path.append(".")
#Run from the script's directory so relative paths (templates, config) resolve.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

#Helper to query the model and re-close the JSON output: generation stops on
#"}" (see SamplingParams below), so the stop character has to be appended back.
def get_llm_response(llm, prompt_template, sampling_params):
    prompts = [prompt_template]
    outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
    generated_text = outputs[0].outputs[0].text
    if not generated_text.endswith("}"):
        generated_text = generated_text + "}"
    prompt = prompt_template + generated_text
    return prompt, generated_text
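
#A minimal sketch (not part of the original pipeline) of turning the raw
#completion into a Python dict. The field names "analysis" and "result" follow
#the example documented above; returning None on failure is an assumption.
import json

def parse_llm_json(generated_text):
    try:
        answer = json.loads(generated_text)
    except json.JSONDecodeError:
        #Malformed output does happen; constrained-generation libraries such as
        #marginalia (linked above) are one way to avoid it entirely.
        return None
    #Keep only answers matching the documented schema.
    if not isinstance(answer, dict) or answer.get("result") not in ("oui", "non"):
        return None
    return answer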

if __name__ == "__main__":

    with open('prompt_config.yaml') as f:
        config = yaml.safe_load(f)

    print("prompt format:", config.get("prompt_format"))
    print(config)
    print()
    #Load the model once, outside the loop, to avoid reloading the weights.
    llm = LLM("mistral-mfs-reference-2/mistral-mfs-reference-2")
    env = Environment(loader=FileSystemLoader("."))

    for prompt in config["prompts"]:
        if prompt["mode"] == "analysis":
            print(f'--- prompt mode: {prompt["mode"]} ---')
            template = env.get_template(prompt["template"])

            #Introspect the template to find which variables it expects.
            source = env.loader.get_source(env, prompt["template"])[0]
            variables = meta.find_undeclared_variables(env.parse(source))

            print("variables:", variables)
            print("---")

            data = {"query": "Comment obtenir le formulaire A36 ?"}
            if "system_prompt" in variables:
                data["system_prompt"] = prompt["system_prompt"]

            rendered_template = template.render(**data)
            print(rendered_template)
            print("---")

            #Stop on the closing brace; get_llm_response appends it back.
            sampling_params = SamplingParams(temperature=0.2, top_p=0.95, max_tokens=300, stop="}")
            
            prompt, generated_text = get_llm_response(llm, rendered_template, sampling_params)
            print("Albert : ", generated_text)