hugo1234 committed
Commit 0d694c7
1 Parent(s): 79b351c

Create utils.py

Files changed (1)
utils.py +200 -0
utils.py ADDED
@@ -0,0 +1,200 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Tuple, Type
import logging
import json
import os
import datetime
import hashlib
import csv
import requests
import re
import html
import torch
import sys
import gc
import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from pygments import highlight
from pygments.lexers import guess_lexer, get_lexer_by_name
from pygments.util import ClassNotFound
from pygments.formatters import HtmlFormatter


def reset_state():
    return [], [], "Reset Done"

def reset_textbox():
    return gr.update(value=""), ""

def cancel_outputing():
    return "Stop Done"

def transfer_input(inputs):
    # Forward the submitted text, clear the textbox, and reveal the cancel
    # button while generation runs.
    return (
        inputs,
        gr.update(value=""),
        gr.Button.update(visible=True),
    )

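# Illustrative gradio wiring (hypothetical component names, not part of this
# file), matching transfer_input's (text, cleared textbox, visible button)
# outputs:
#
#   submit_btn.click(
#       transfer_input,
#       inputs=[user_input],
#       outputs=[forwarded_text, user_input, cancel_btn],
#   )
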
def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
    # True if s ends with a stop word, or with any proper prefix of one, so
    # streaming output can hold back text that may grow into a stop word.
    for stop_word in stop_words:
        if s.endswith(stop_word):
            return True
        for i in range(1, len(stop_word)):
            if s.endswith(stop_word[:i]):
                return True
    return False

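# Example (illustrative values, using the [|Human|]/[|AI|] markers from
# generate_prompt_with_history below):
#
#   is_stop_word_or_prefix("Hi!\n[|Hu", ["[|Human|]", "[|AI|]"])  # True: prefix of a stop word
#   is_stop_word_or_prefix("Hi!", ["[|Human|]", "[|AI|]"])        # False
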
def generate_prompt_with_history(text, history, tokenizer, max_length=2048):
    prompt = "The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n[|Human|]Hello!\n[|AI|]Hi!"
    history = ["\n[|Human|]{}\n[|AI|]{}".format(x[0], x[1]) for x in history]
    history.append("\n[|Human|]{}\n[|AI|]".format(text))
    history_text = ""
    flag = False
    # Walk the history from newest to oldest, keeping as many turns as fit
    # within max_length tokens alongside the system prompt.
    for x in history[::-1]:
        if tokenizer(prompt + history_text + x, return_tensors="pt")['input_ids'].size(-1) <= max_length:
            history_text = x + history_text
            flag = True
        else:
            break
    if flag:
        return prompt + history_text, tokenizer(prompt + history_text, return_tensors="pt")
    else:
        # Not even the latest turn fits within max_length.
        return None

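# Example (illustrative; uses the module-level tokenizer loaded below):
#
#   result = generate_prompt_with_history(
#       "What is nucleus sampling?",
#       [("Hello!", "Hi! How can I help?")],
#       tokenizer,
#       max_length=2048,
#   )
#   if result is not None:
#       prompt_text, inputs = result  # inputs["input_ids"] feeds greedy_search
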

# Default checkpoint, loaded once at import time. The Baize checkpoint can be
# swapped in by uncommenting the lines below instead.
# tokenizer = AutoTokenizer.from_pretrained("project-baize/baize-v2-7b")
# model = AutoModelForCausalLM.from_pretrained("project-baize/baize-v2-7b")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")


def load_tokenizer_and_model(base_model, load_8bit=False):
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=False)
    if device == "cuda":
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            # load_in_8bit=load_8bit,
            # torch_dtype=torch.float16,
            device_map="auto",
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            base_model, device_map={"": device}, low_cpu_mem_usage=True
        )

    # if not load_8bit:
    #     model.half()  # seems to fix bugs for some users.

    model.eval()
    return tokenizer, model, device

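# Example (illustrative checkpoint name):
#
#   tokenizer, model, device = load_tokenizer_and_model("project-baize/baize-v2-7b")
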
# Despite the name, this performs sampling with temperature and top-p
# (nucleus) filtering, streaming the decoded text token by token.
def greedy_search(input_ids: torch.Tensor,
                  model: torch.nn.Module,
                  tokenizer: transformers.PreTrainedTokenizer,
                  stop_words: list,
                  max_length: int,
                  temperature: float = 1.0,
                  top_p: float = 1.0,
                  top_k: int = 25) -> Iterator[str]:
    generated_tokens = []
    past_key_values = None
    for i in range(max_length):
        with torch.no_grad():
            if past_key_values is None:
                outputs = model(input_ids)
            else:
                # Reuse the KV cache: only the newest token is fed forward.
                outputs = model(input_ids[:, -1:], past_key_values=past_key_values)
            logits = outputs.logits[:, -1, :]
            past_key_values = outputs.past_key_values

        # apply temperature
        logits /= temperature

        probs = torch.softmax(logits, dim=-1)
        # apply top_p: keep the smallest set of tokens whose cumulative
        # probability stays within top_p, zeroing out the rest
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0.0

        # apply top_k (currently disabled)
        # if top_k is not None:
        #     probs_sort1, _ = torch.topk(probs_sort, top_k)
        #     min_top_probs_sort = torch.min(probs_sort1, dim=-1, keepdim=True).values
        #     probs_sort = torch.where(probs_sort < min_top_probs_sort, torch.full_like(probs_sort, float(0.0)), probs_sort)

        # renormalize the surviving probabilities and sample the next token
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        next_token = torch.multinomial(probs_sort, num_samples=1)
        next_token = torch.gather(probs_idx, -1, next_token)

        input_ids = torch.cat((input_ids, next_token), dim=-1)

        generated_tokens.append(next_token[0].item())
        text = tokenizer.decode(generated_tokens)

        yield text
        if any([x in text for x in stop_words]):
            del past_key_values
            del logits
            del probs
            del probs_sort
            del probs_idx
            del probs_sum
            gc.collect()
            return

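# Example streaming loop (illustrative; assumes `inputs` came from
# generate_prompt_with_history and `device` from load_tokenizer_and_model):
#
#   for partial_text in greedy_search(inputs["input_ids"].to(device), model,
#                                     tokenizer, stop_words=["[|Human|]", "[|AI|]"],
#                                     max_length=512, temperature=0.9, top_p=0.9):
#       if shared_state.interrupted:  # see State at the end of this file
#           shared_state.recover()
#           break
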
def convert_to_markdown(text):
    # Escape characters that Markdown would otherwise interpret: dollar signs
    # (MathJax), leading whitespace, and leading "#" headings.
    text = text.replace("$", "&#36;")

    def replace_leading_tabs_and_spaces(line):
        new_line = []
        for char in line:
            if char == "\t":
                new_line.append("&#9;")
            elif char == " ":
                new_line.append("&nbsp;")
            else:
                break
        return "".join(new_line) + line[len(new_line):]

    markdown_text = ""
    lines = text.split("\n")
    in_code_block = False

    for line in lines:
        if in_code_block is False and line.startswith("```"):
            in_code_block = True
            markdown_text += f"{line}\n"
        elif in_code_block is True and line.startswith("```"):
            in_code_block = False
            markdown_text += f"{line}\n"
        elif in_code_block:
            # Leave fenced code blocks untouched.
            markdown_text += f"{line}\n"
        else:
            line = replace_leading_tabs_and_spaces(line)
            line = re.sub(r"^(#)", r"\\\1", line)
            # Two trailing spaces force a Markdown hard line break.
            markdown_text += f"{line}  \n"

    return markdown_text

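# Example (illustrative):
#
#   convert_to_markdown("# not a heading\nprice: $5")
#   # -> '\\# not a heading  \nprice: &#36;5  \n'
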

class State:
    # Shared flag that lets the UI interrupt streaming generation.
    interrupted = False

    def interrupt(self):
        self.interrupted = True

    def recover(self):
        self.interrupted = False


shared_state = State()
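
# Illustrative use (hypothetical wiring; a Stop button callback and the
# streaming loop share this flag):
#
#   shared_state.interrupt()      # called from the Stop button
#   if shared_state.interrupted:  # checked between streamed chunks
#       shared_state.recover()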