taras5500 committed
Commit
3a0bf59
1 Parent(s): 1d14120

Create models.py

Files changed (1):
  models.py +44 -0
models.py ADDED
@@ -0,0 +1,44 @@
+ from huggingface_hub import InferenceClient
+
+
+ class LlmBot:
+     def __init__(self, model):
+         self.client = InferenceClient(model)
+
+     def character_prompt(self, character, max_new_tokens):
+         # Build the system prompt from whichever character fields are set.
+         system_prompt = '<SYSTEM> <'
+         if character["name"]:
+             system_prompt += f'You are a person whose name is {character["name"]}. '
+         if character["description"]:
+             system_prompt += f'Your description: {character["description"]}. '
+         if character["user_name"]:
+             system_prompt += f'The user is named {character["user_name"]}. '
+         system_prompt += 'Add a greeting only in your first response. '
+         system_prompt += 'Be emotional in your responses. '
+         system_prompt += 'Do not include your own name in any responses. '
+         system_prompt += f'Ensure responses are shorter than {max_new_tokens} tokens.>'
+         return system_prompt
+
+     def format_prompt(self, prompt, history, system_setting):
+         # Wrap past turns in a <history> block, then append the new request.
+         formatted_prompt = "<history>"
+         for user_prompt, bot_response in history:
+             formatted_prompt += f"[INST] {user_prompt} [/INST] {bot_response} "
+         formatted_prompt += "</history> "
+         formatted_prompt += f"[INST] {system_setting}, <user>{prompt}</user> [/INST]"
+         return formatted_prompt
+
+     def call(self, prompt, history, name, description, user_name, max_new_tokens):
+         generate_kwargs = dict(
+             temperature=0.9,
+             max_new_tokens=max_new_tokens,
+             top_p=0.95,
+             repetition_penalty=1.0,
+             do_sample=True,
+         )
+         system_setting = self.character_prompt(
+             {"name": name, "description": description, "user_name": user_name},
+             max_new_tokens,
+         )
+         formatted_prompt = self.format_prompt(prompt, history, system_setting)
+         # Stream tokens from the endpoint, yielding the accumulated reply.
+         stream = self.client.text_generation(
+             formatted_prompt, **generate_kwargs,
+             stream=True, details=True, return_full_text=False,
+         )
+         output = ""
+         for response in stream:
+             output += response.token.text
+             yield output
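
For reference, a minimal usage sketch, not part of this commit: the model id, character fields, and chat history below are placeholder assumptions. call returns a generator that yields the reply as it accumulates, so iterating it streams the growing response.

# Hypothetical usage sketch; the model id and all values below are
# illustrative assumptions, not part of this commit.
bot = LlmBot("mistralai/Mistral-7B-Instruct-v0.2")  # assumed model id

history = [("Hi!", "Hello! So glad you stopped by!")]
reply = ""
for reply in bot.call(
    prompt="How are you today?",
    history=history,
    name="Alice",
    description="a cheerful barista",
    user_name="Bob",
    max_new_tokens=128,
):
    pass  # each yield is the accumulated reply so far
print(reply)  # final full response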