georgesung committed
Commit 8cf6e52
1 Parent(s): 7a57d84

Not using vllm

Files changed (1)
  1. app.py +71 -92
app.py CHANGED
@@ -1,99 +1,78 @@
-import re
-
-import gradio as gr
 import torch
-from transformers import (AutoConfig, AutoModel, AutoModelForSeq2SeqLM,
-                          AutoTokenizer, LlamaForCausalLM, LlamaTokenizer)
-from vllm import LLM, SamplingParams
-
-model_id = "georgesung/llama2_7b_chat_uncensored"
-
-prompt_config = {
-    "system_header": None,
-    "system_footer": None,
-    "user_header": "### HUMAN:",
-    "user_footer": None,
-    "input_header": None,
-    "response_header": "### RESPONSE:",
-}
-
-def get_llm_response_chat(prompt):
-    outputs = llm.generate(prompt, sampling_params)
-    output = outputs[0].outputs[0].text
-
-    # Remove trailing eos token
-    eos_token = llm.get_tokenizer().eos_token
-    if output.endswith(eos_token):
-        output = output[:-len(eos_token)]
-    return output
-
-def hist_to_prompt(history):
-    prompt = ""
-    if prompt_config["system_header"]:
-        system_footer = ""
-        if prompt_config["system_footer"]:
-            system_footer = prompt_config["system_footer"]
-        prompt += f"{prompt_config['system_header']}\n{SYSTEM_MESSAGE}{system_footer}\n\n"
-
-    for i, (human_text, bot_text) in enumerate(history):
-        user_footer = ""
-        if prompt_config["user_footer"]:
-            user_footer = prompt_config["user_footer"]
-
-        prompt += f"{prompt_config['user_header']}\n{human_text}{user_footer}\n\n"
-
-        prompt += f"{prompt_config['response_header']}\n"
-
-        if bot_text:
-            prompt += f"{bot_text}\n\n"
-    return prompt
-
-def get_bot_response(text):
-    bot_text_index = text.rfind(prompt_config['response_header'])
-    if bot_text_index != -1:
-        text = text[bot_text_index + len(prompt_config['response_header']):].strip()
-    return text
-
-def main():
-    # RE llama tokenizer:
-    # RuntimeError: Failed to load the tokenizer.
-    # If you are using a LLaMA-based model, use 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
-    llm = LLM(model=model_id, tokenizer='hf-internal-testing/llama-tokenizer')
-
-    sampling_params = SamplingParams(temperature=0.01, top_p=0.1, top_k=40, max_tokens=2048)
-
-    tokenizer = llm.get_tokenizer()

-    with gr.Blocks() as demo:
-        gr.Markdown(
-            """
-            # Let's chat
-            """)
-
-        chatbot = gr.Chatbot()
-        msg = gr.Textbox()
-        clear = gr.Button("Clear")
-
-        def user(user_message, history):
-            return "", history + [[user_message, None]]
-
-        def bot(history):
-            hist_text = hist_to_prompt(history)
-
-            bot_message = get_llm_response_chat(hist_text) #+ tokenizer.eos_token
-            history[-1][1] = bot_message # add bot message to overall history
-
-            return history

-        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-            bot, chatbot, chatbot
-        )
-        clear.click(lambda: None, None, chatbot, queue=False)

-    demo.queue()
-    demo.launch()
-
-
-if __name__ == "__main__":
-    main()
+from transformers import LlamaForCausalLM, LlamaTokenizer, pipeline
 import torch
+
+import gradio as gr
+
+# LLM helper functions
+def get_response_text(data):
+    text = data[0]["generated_text"]
+
+    assistant_text_index = text.rfind('### RESPONSE:')
+    if assistant_text_index != -1:
+        text = text[assistant_text_index+len('### RESPONSE:'):].strip()
+
+    return text
+
+def get_llm_response(prompt, pipe):
+    raw_output = pipe(prompt)
+    text = get_response_text(raw_output)
+    return text
+
+# Load LLM
+model_id = "georgesung/llama2_7b_chat_uncensored"
+tokenizer = LlamaTokenizer.from_pretrained(model_id)
+model = LlamaForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)
+
+# Llama tokenizer missing pad token
+tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+
+pipe = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    max_length=4096,  # Llama-2 default context window
+    temperature=0.7,
+    top_p=0.95,
+    repetition_penalty=1.15
+)
+
+with gr.Blocks() as demo:
+    chatbot = gr.Chatbot()
+    msg = gr.Textbox()
+    clear = gr.Button("Clear")
+
+    def hist_to_prompt(history):
+        prompt = ""
+        for human_text, bot_text in history:
+            prompt += f"### HUMAN:\n{human_text}\n\n### RESPONSE:\n"
+            if bot_text:
+                prompt += f"{bot_text}\n\n"
+        return prompt
+
+    def get_bot_response(text):
+        bot_text_index = text.rfind('### RESPONSE:')
+        if bot_text_index != -1:
+            text = text[bot_text_index + len('### RESPONSE:'):].strip()
+        return text
+
+    def user(user_message, history):
+        return "", history + [[user_message, None]]
+
+    def bot(history):
+        #bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"])
+        #history[-1][1] = bot_message + '</s>'
+
+        hist_text = hist_to_prompt(history)
+        print(hist_text)
+        bot_message = get_llm_response(hist_text, pipe) + tokenizer.eos_token
+        history[-1][1] = bot_message # add bot message to overall history
+
+        return history
+
+    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+        bot, chatbot, chatbot
+    )
+    clear.click(lambda: None, None, chatbot, queue=False)
+
+demo.queue()
+demo.launch()
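
For reference, a minimal sketch of how the new prompt helpers behave. It assumes only the hist_to_prompt and get_response_text definitions from the new app.py above; the sample history and the fake pipeline output are illustrative and not part of the commit.

# Hypothetical chat history in the Gradio Chatbot format: [user_text, bot_text] pairs.
history = [
    ["Hi there!", "Hello! How can I help?"],
    ["Tell me a joke.", None],  # latest user turn, no bot reply yet
]

print(hist_to_prompt(history))
# ### HUMAN:
# Hi there!
#
# ### RESPONSE:
# Hello! How can I help?
#
# ### HUMAN:
# Tell me a joke.
#
# ### RESPONSE:

# A transformers text-generation pipeline returns the prompt plus the completion in
# "generated_text" by default, which is why get_response_text() keeps only the text
# after the last "### RESPONSE:" marker.
fake_output = [{"generated_text": hist_to_prompt(history) + "Why did the llama cross the road?"}]
print(get_response_text(fake_output))  # -> Why did the llama cross the road?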