mrfakename committed on
Commit
37cccbf
1 Parent(s): ba60fdb

Create app.py

Files changed (1)
  1. app.py +130 -0
app.py ADDED
@@ -0,0 +1,130 @@
+ import spaces
+
+ #######################
+ '''
+ Name: Phine Inference
+ License: MIT
+ '''
+ #######################
+
+
+ ##### Dependencies
+
+ """ IMPORTANT: Uncomment the following line if you are in a Colab/notebook environment """
+
+ #!pip install gradio einops accelerate bitsandbytes transformers
+
+ #####
+
+ import gradio as gr
+ import transformers
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+
+ def cut_text_after_last_token(text, token):
+     """Return the text after the last occurrence of `token`, or None if it never appears."""
+     last_occurrence = text.rfind(token)
+     if last_occurrence != -1:
+         return text[last_occurrence + len(token):].strip()
+     return None
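+
+ # Quick sanity check (illustrative only, not executed by the app):
+ #   cut_text_after_last_token("<|response|>Hi<|endoftext|>", "<|response|>")
+ #   returns "Hi<|endoftext|>"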
+
+ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
+     """Stops generation once a sentinel token sequence appears in the generated tokens."""
+
+     def __init__(self, sentinel_token_ids: torch.LongTensor, starting_idx: int):
+         super().__init__()
+         self.sentinel_token_ids = sentinel_token_ids
+         self.starting_idx = starting_idx
+
+     def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
+         for sample in input_ids:
+             # Only inspect tokens generated after the prompt.
+             trimmed_sample = sample[self.starting_idx:]
+             if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
+                 continue
+             # Slide a sentinel-sized window over the generated tokens.
+             for window in trimmed_sample.unfold(0, self.sentinel_token_ids.shape[-1], 1):
+                 if torch.all(torch.eq(self.sentinel_token_ids, window)):
+                     return True
+         return False
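+
+ # Illustration (assumed values, not executed): with sentinel ids tensor([[50256]])
+ # and starting_idx=3, a sample [.., .., .., 11, 50256] is trimmed to [11, 50256];
+ # unfold() yields windows [11] and [50256]; the second matches, so generation stops.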
+
+
+ model_path = 'freecs/phine-2-v0'
+
+ device = "cuda"
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ model = AutoModelForCausalLM.from_pretrained(
+     model_path,
+     trust_remote_code=True,
+     torch_dtype=torch.float16,
+ ).to(device)  # remove .to(device) if you load in 4/8-bit instead
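+
+ # Hypothetical alternative for smaller GPUs (assumes bitsandbytes is installed):
+ # model = AutoModelForCausalLM.from_pretrained(
+ #     model_path, trust_remote_code=True, load_in_4bit=True, device_map="auto"
+ # )  # quantized weights are dispatched automatically; skip the .to(device) call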
+
+ sys_message = "You are an AI assistant named Phine developed by FreeCS.org. You are polite and smart."  # System message
+
+ @spaces.GPU
+ def phine(message, history, temperature, top_p, top_k, repetition_penalty):
+     # Rebuild the conversation context from Gradio's (user, assistant) history pairs.
+     n = 0
+     context = ""
+     if history:
+         for exchange in history:
+             for h in exchange:
+                 if n % 2 == 0:
+                     context += f"\n<|prompt|>{h}\n"
+                 else:
+                     context += f"<|response|>{h}"
+                 n += 1
+
+     prompt = f"\n<|system|>{sys_message}" + context + "\n<|prompt|>" + message + "<|endoftext|>\n<|response|>"
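+     # Illustrative rendering (assumed example, not executed): with
+     # history=[["Hi", "Hello!"]] and message="How are you?", the prompt becomes
+     # "\n<|system|>{sys_message}\n<|prompt|>Hi\n<|response|>Hello!\n<|prompt|>How are you?<|endoftext|>\n<|response|>"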
+     tokenized = tokenizer(prompt, return_tensors="pt").to(device)
+
+     # Stop generating as soon as the model emits the <|endoftext|> sentinel.
+     stopping_criteria_list = transformers.StoppingCriteriaList([
+         _SentinelTokenStoppingCriteria(
+             sentinel_token_ids=tokenizer(
+                 "<|endoftext|>",
+                 add_special_tokens=False,
+                 return_tensors="pt",
+             ).input_ids.to(device),
+             starting_idx=tokenized.input_ids.shape[-1])
+     ])
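+
+     # generate() evaluates each stopping criterion after every new token and halts
+     # once one returns True, so decoding ends right after the sentinel appears.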
+
+     output_ids = model.generate(
+         **tokenized,
+         stopping_criteria=stopping_criteria_list,
+         do_sample=True,
+         max_length=2048,
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         repetition_penalty=repetition_penalty,
+     )
+
+     completion = tokenizer.decode(output_ids[0], skip_special_tokens=False)
+     # The decoded completion still contains the prompt; keep only what follows
+     # the last <|response|> marker and drop the trailing sentinel.
+     res = cut_text_after_last_token(completion, "<|response|>")
+     return res.replace('<|endoftext|>', '') if res else ""
+
+
+ demo = gr.ChatInterface(
+     phine,
+     additional_inputs=[
+         gr.Slider(0.1, 2.0, label="Temperature", value=0.5),
+         gr.Slider(0.1, 2.0, label="Top P", value=0.9),
+         gr.Slider(1, 500, label="Top K", value=50),
+         gr.Slider(0.1, 2.0, label="Repetition Penalty", value=1.15),
+     ],
+ )
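+
+ # Note: the sliders are passed positionally, mapping in order to phine()'s
+ # temperature, top_p, top_k, and repetition_penalty parameters.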
+
+ if __name__ == "__main__":
+     demo.queue().launch(share=True, debug=True)  # If debug=True causes problems, you can set it to False