Pclanglais committed on
Commit
6b608bb
1 Parent(s): 5e8dd09

Create prompt_demo_analysis.py

Files changed (1)
  1. prompt_demo_analysis.py +73 -0
prompt_demo_analysis.py ADDED
@@ -0,0 +1,73 @@
#A demo of a side functionality of Guillaume-Tell: guessing whether the question should open up a source retrieval pipeline.
#The function should return a structured answer in json with two components:
##A short analysis with reasoning.
##A boolean answer in French ("oui" or "non").

#Notice that json generation with LLMs is still challenging due to unpredictable behavior.
#A library like marginalia ensures the output will always be json-compliant: https://github.com/Pleias/marginalia

#A typical example:

#{
#  "analysis": "La question concerne un formulaire spécifique, le formulaire A36. Il est donc probable que des références encyclopédiques soient nécessaires pour fournir des informations précises sur ce formulaire.",
#  "result": "oui"
#}

import sys, os
from pprint import pprint
from jinja2 import Environment, FileSystemLoader, meta
import yaml

import pandas as pd
from vllm import LLM, SamplingParams


sys.path.append(".")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

#Specific function to deal with the json format: generation stops on "}" (see the
#sampling_params below), so the closing brace is re-appended when it is missing.
#Relies on the module-level llm created in the __main__ block.
def get_llm_response(prompt_template, sampling_params):
    prompts = [prompt_template]
    outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
    generated_text = outputs[0].outputs[0].text
    if not generated_text.endswith("}"):
        generated_text = generated_text + "}"
    prompt = prompt_template + generated_text
    return prompt, generated_text
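
#Illustrative sketch only: one way the structured answer could be parsed back into Python,
#assuming generated_text holds a complete json object like the typical example in the header
#comments. The helper name and the fallback behavior are assumptions, not part of the pipeline.
import json

def parse_analysis(generated_text):
    try:
        parsed = json.loads(generated_text)
        #"result" is expected to be "oui" or "non".
        return parsed.get("analysis", ""), parsed.get("result") == "oui"
    except json.JSONDecodeError:
        #The model drifted away from valid json: return the raw text and no decision.
        return generated_text, None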

if __name__ == "__main__":

    with open('prompt_config.yaml') as f:
        config = yaml.safe_load(f)
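
    #The contents of prompt_config.yaml are not shown in this commit; the keys below are
    #inferred from how the script reads the config, and the values are only placeholders.
    #After yaml.safe_load, config is expected to look roughly like this:
    #config = {
    #    "prompt_format": "...",
    #    "prompts": [
    #        {"mode": "analysis", "template": "analysis.jinja", "system_prompt": "..."},
    #    ],
    #}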

    print("prompt format:", config.get("prompt_format"))
    print(config)
    print()
    for prompt in config["prompts"]:
        if prompt["mode"] == "analysis":
            print(f'--- prompt mode: {prompt["mode"]} ---')
            env = Environment(loader=FileSystemLoader("."))
            template = env.get_template(prompt["template"])

            #Introspect the template source to find which variables it expects.
            source = template.environment.loader.get_source(template.environment, template.name)
            variables = meta.find_undeclared_variables(env.parse(source[0]))

            print("variables:", variables)
            print("---")

            data = {"query": "Comment obtenir le formulaire A36 ?"}
            if "system_prompt" in variables:
                data["system_prompt"] = prompt["system_prompt"]

            rendered_template = template.render(**data)
            print(rendered_template)
            print("---")

            llm = LLM("mistral-mfs-reference-2/mistral-mfs-reference-2")

            #Stop on "}" so the model does not keep generating past the json object.
            sampling_params = SamplingParams(temperature=0.2, top_p=0.95, max_tokens=300, stop="}")

            #full_prompt is the rendered prompt plus the generated completion.
            full_prompt, generated_text = get_llm_response(rendered_template, sampling_params)
            print("Albert : ", generated_text)
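
            #The completion could then be fed to the illustrative parse_analysis helper
            #defined above to recover the two fields, for instance:
            #analysis, needs_sources = parse_analysis(generated_text)
            #print("analysis:", analysis)
            #print("needs sources:", needs_sources)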