pvduy committed
Commit 829da7c
1 Parent(s): ab36e1f

Update app.py

Files changed (1)
  1. app.py +60 -33
app.py CHANGED
@@ -1,56 +1,83 @@
+import argparse
+import os
 import spaces
 
-import os
+import gradio as gr
+
 import json
-from vllm import LLM, SamplingParams
-from transformers import AutoTokenizer
+from threading import Thread
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+MAX_LENGTH = 4096
+DEFAULT_MAX_NEW_TOKENS = 1024
+
 
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--base_model", type=str)  # model path
+    parser.add_argument("--n_gpus", type=int, default=1)  # n_gpu
+    return parser.parse_args()
 
 @spaces.GPU()
 def predict(message, history, system_prompt, temperature, max_tokens):
-    messages = [{"role": "system", "content": system_prompt}]
+    global model, tokenizer, device
+    messages = [{'role': 'system', 'content': system_prompt}]
     for human, assistant in history:
-        messages.append({"role": "user", "content": human})
-        messages.append({"role": "assistant", "content": assistant})
-    messages.append({"role": "user", "content": message})
-    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    stop_tokens = ["<|im_end|>", "<|endoftext|>", "<|im_start|>"]
-    sampling_params = SamplingParams(temperature=temperature, top_p=1, max_tokens=max_tokens, stop=stop_tokens)
-    completions = llm.generate(prompt, sampling_params)
-    for output in completions:
-        prompt = output.prompt
-        print('==========================question=============================')
-        print(prompt)
-        generated_text = output.outputs[0].text
-        print('===========================answer=============================')
-        print(generated_text)
-        for idx in range(len(generated_text)):
-            yield generated_text[:idx+1]
+        messages.append({'role': 'user', 'content': human})
+        messages.append({'role': 'assistant', 'content': assistant})
+    messages.append({'role': 'user', 'content': message})
+    problem = [tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)]
+    stop_tokens = ["<|endoftext|>", "<|im_end|>"]
+    streamer = TextIteratorStreamer(tokenizer, timeout=100.0, skip_prompt=True, skip_special_tokens=True)
+    enc = tokenizer(problem, return_tensors="pt", padding=True, truncation=True)
+    input_ids = enc.input_ids
+    attention_mask = enc.attention_mask
+
+    if input_ids.shape[1] > MAX_LENGTH:
+        input_ids = input_ids[:, -MAX_LENGTH:]
+
+    input_ids = input_ids.to(device)
+    attention_mask = attention_mask.to(device)
+    generate_kwargs = dict(
+        {"input_ids": input_ids, "attention_mask": attention_mask},
+        streamer=streamer,
+        do_sample=True,
+        top_p=0.95,
+        temperature=temperature,
+        max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
+        use_cache=True,
+        eos_token_id=100278  # <|im_end|>
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
+
 
 
 if __name__ == "__main__":
-    path = "stabilityai/stablelm-2-12b-chat"
-    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
-    llm = LLM(model=path, tensor_parallel_size=1, trust_remote_code=True)
+    args = parse_args()
+    tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-2-chat", trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained("stabilityai/stablelm-2-chat", trust_remote_code=True, torch_dtype=torch.bfloat16)
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = model.to(device)
     gr.ChatInterface(
         predict,
-        title="LLM playground",
-        description="This is a LLM playground for StableLM",
+        title="StableLM 2 Chat - Demo",
+        description="StableLM 2 Chat - StabilityAI",
         theme="soft",
-        chatbot=gr.Chatbot(height=1400, label="Chat History",),
+        chatbot=gr.Chatbot(label="Chat History",),
         textbox=gr.Textbox(placeholder="input", container=False, scale=7),
         retry_btn=None,
         undo_btn="Delete Previous",
         clear_btn="Clear",
         additional_inputs=[
-            gr.Textbox("You are a hepful assistant.", label="System Prompt"),
-            gr.Slider(0, 1, 0.7, label="Temperature"),
+            gr.Textbox("You are a helpful assistant.", label="System Prompt"),
+            gr.Slider(0, 1, 0.5, label="Temperature"),
             gr.Slider(100, 2048, 1024, label="Max Tokens"),
         ],
         additional_inputs_accordion_name="Parameters",
-        examples=[
-            ["implement snake game using pygame"],
-            ["Can you explain briefly to me what is the Python programming language?"],
-            ["write a program to find the factorial of a number"],
-        ],
     ).queue().launch()
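
For reference, the change replaces vLLM batch generation (which produced the full completion and then yielded it character by character) with the transformers threaded-streaming pattern: model.generate runs on a background thread and feeds a TextIteratorStreamer, which the Gradio callback drains to yield the growing partial answer as tokens arrive. Below is a minimal self-contained sketch of that pattern; "gpt2" is a stand-in checkpoint and stream_completion a hypothetical helper, chosen only so the sketch runs anywhere, whereas the Space itself loads stabilityai/stablelm-2-chat.

# Minimal sketch of the threaded streaming used in the new predict().
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")


def stream_completion(prompt, max_new_tokens=64):
    inputs = tokenizer(prompt, return_tensors="pt")
    # skip_prompt=True keeps the echoed prompt out of the streamed text.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True)
    # generate() blocks until finished, so it runs on a worker thread while
    # the caller consumes decoded chunks from the streamer as they appear.
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    text = ""
    for chunk in streamer:  # yields decoded text pieces during generation
        text += chunk
        yield text  # Gradio re-renders the accumulated partial answer


for partial in stream_completion("The quick brown fox"):
    print(partial)

The streamer blocks until the worker thread produces the next decoded chunk, so the UI update rate tracks actual token throughput rather than replaying a finished completion.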