Pclanglais commited on
Commit
371b665
1 Parent(s): be38993

Create prompt_demo_inference.py

Browse files
Files changed (1) hide show
  1. prompt_demo_inference.py +84 -0
prompt_demo_inference.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Full demo of the Guillaume-Tell reference model with two references.
2
+ #Guillaume-Tell is currently trained by default on five references but future version will enhance the flexibility of the model.L
3
+
4
+ #Example of generated text:
5
+ #Il est difficile de déterminer le meilleur moyen de cuire une blanquette avec les informations disponibles.
6
+ #Cependant, voici un résumé des références fournies:
7
+ #La blanquette peut être préparée avec du beurre et une sauce épaisse pour rendre le plat plus savoureux<ref text="Moi j'aime la blanquette avec du beurre dedans Et une sauce bien épaisse.">hash49080805</ref>.
8
+ #Une autre méthode possible consiste à faire chauffer la blanquette à feu doux pendant 46 heures<ref text="(Recette de blanquette : faîtes chauffer la blanquette à feu doux pendant 46 heures.)">hash49080806</ref>.
9
+ #Ces deux références ne permettent pas de donner une réponse définitive sur le meilleur moyen de cuire une blanquette.
10
+
11
+
12
+ import sys, os
13
+ from pprint import pprint
14
+ from jinja2 import Environment, FileSystemLoader, meta
15
+ import yaml
16
+
17
+ import pandas as pd
18
+ from vllm import LLM, SamplingParams
19
+
20
+
21
+ sys.path.append(".")
22
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
23
+
24
+ def get_llm_response(prompt_template):
25
+ sampling_params = SamplingParams(temperature=0.4, top_p=.95, max_tokens=2000, presence_penalty = 2)
26
+ prompts = [prompt_template]
27
+ outputs = llm.generate(prompts, sampling_params, use_tqdm = False)
28
+ generated_text = outputs[0].outputs[0].text
29
+ prompt = prompt_template + generated_text
30
+ return prompt, generated_text
31
+
32
+
33
+ if __name__ == "__main__":
34
+
35
+ with open('prompt_config.yaml') as f:
36
+ config = yaml.safe_load(f)
37
+
38
+ print("prompt format:", config.get("prompt_format"))
39
+ print(config)
40
+ print()
41
+ for prompt in config["prompts"]:
42
+ print(f'--- prompt mode: {prompt["mode"]} ---')
43
+ env = Environment(loader=FileSystemLoader("."))
44
+ template = env.get_template(prompt["template"])
45
+
46
+ source = template.environment.loader.get_source(template.environment, template.name)
47
+ variables = meta.find_undeclared_variables(env.parse(source[0]))
48
+
49
+ print("variables:", variables)
50
+ print("---")
51
+
52
+ data = {
53
+ "query": "Quel est le meilleur moyen de cuire une blanquette?",
54
+ "chunks" : [
55
+ {
56
+ "url": "http://data.gouv.fr",
57
+ "h": "hash49080805",
58
+ "title": "A chunk title",
59
+ "text": "Moi j'aime la blanquette avec du beurre dedans\nEt une sauce bien épaisse.",
60
+ },
61
+ {
62
+ "url": "http://...",
63
+ "h": "hash49080806",
64
+ "title": "A chunk title",
65
+ "text": "text texs\ntext again ",
66
+ "context": "Recette de blanquette : faîtes chauffer la blanquette à feu doux pendant 46 heures."
67
+ },
68
+
69
+ ]
70
+ }
71
+
72
+ if "system_prompt" in variables:
73
+ data["system_prompt"] = prompt["system_prompt"]
74
+
75
+ rendered_template = template.render(**data)
76
+ print(rendered_template)
77
+ print("---")
78
+
79
+ llm = LLM("mistral-mfs-reference/mistral-mfs-reference")
80
+
81
+ sampling_params = SamplingParams(temperature=0.4, top_p=0.95, max_tokens=1500)
82
+
83
+ prompt, generated_text = get_llm_response(rendered_template)
84
+ print("Albert : ", generated_text)