vilarin committed on
Commit 652620b
1 Parent(s): b49792d

Update app.py

Files changed (1)
  1. app.py +48 -31
app.py CHANGED
@@ -1,29 +1,20 @@
-import subprocess
-subprocess.run(
-    'pip install flash-attn --no-build-isolation',
-    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
-    shell=True
-)
 import os
 import time
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import gradio as gr
+from threading import Thread
 
-MODEL_LIST = ["internlm/internlm2_5-7b-chat", "internlm/internlm2_5-7b-chat-1m"]
+MODEL_LIST = ["meta-llama/Meta-Llama-3.1-8B-Instruct"]
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
-MODEL_ID = os.environ.get("MODEL_ID", None)
-MODEL_NAME = MODEL_ID.split("/")[-1]
+MODEL = os.environ.get("MODEL_ID")
 
-TITLE = "<h1><center>internlm2.5-7b-chat</center></h1>"
+TITLE = "<h1><center>Mistral-Nemo</center></h1>"
 
-DESCRIPTION = f"""
-<h3>MODEL NOW: <a href="https://hf.co/{MODEL_ID}">{MODEL_NAME}</a></h3>
-"""
 PLACEHOLDER = """
 <center>
-<p>InternLM2.5 has open-sourced a 7 billion parameter base model<br> and a chat model tailored for practical scenarios.</p>
+<p>Hi! How can I help you today?</p>
 </center>
 """
 
@@ -40,14 +31,19 @@ h3 {
 }
 """
 
+device = "cuda" # for GPU usage or "cpu" for CPU usage
+
+tokenizer = AutoTokenizer.from_pretrained(MODEL)
 model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID,
-    torch_dtype=torch.float16,
-    attn_implementation="flash_attention_2",
-    trust_remote_code=True).cuda()
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+    MODEL,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    ignore_mismatched_sizes=True)
+terminators = [
+    tokenizer.eos_token_id,
+    tokenizer.convert_tokens_to_ids("<|eot_id|>")
+]
 
-model = model.eval()
 
 @spaces.GPU()
 def stream_chat(
@@ -57,28 +53,49 @@ def stream_chat(
     max_new_tokens: int = 1024,
     top_p: float = 1.0,
     top_k: int = 20,
-    penalty: float = 1.2
+    penalty: float = 1.2,
 ):
     print(f'message: {message}')
     print(f'history: {history}')
-    for resp, history in model.stream_chat(
-        tokenizer,
-        query = message,
-        history = history,
+
+    conversation = []
+    for prompt, answer in history:
+        conversation.extend([
+            {"role": "user", "content": prompt},
+            {"role": "assistant", "content": answer},
+        ])
+
+    conversation.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
         max_new_tokens = max_new_tokens,
         do_sample = False if temperature == 0 else True,
         top_p = top_p,
         top_k = top_k,
         temperature = temperature,
-    ):
-        yield resp
+        eos_token_id=terminators,
+        streamer=streamer,
+    )
 
+    with torch.no_grad():
+        thread = Thread(target=model.generate, kwargs=generate_kwargs)
+        thread.start()
+
+    buffer = ""
+    for new_text in streamer:
+        buffer += new_text
+        yield buffer
 
+
 chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
 
 with gr.Blocks(css=CSS, theme="soft") as demo:
     gr.HTML(TITLE)
-    gr.HTML(DESCRIPTION)
     gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
     gr.ChatInterface(
         fn=stream_chat,
@@ -99,7 +116,7 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
             maximum=8192,
            step=1,
            value=1024,
-           label="Max New Tokens",
+           label="Max new tokens",
            render=False,
        ),
        gr.Slider(
@@ -138,4 +155,4 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
 
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
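
For context on the new implementation, here is a minimal, self-contained sketch of the threaded streaming pattern the updated stream_chat follows: build the conversation for the chat template, start model.generate on a background Thread, and yield the growing text from a TextIteratorStreamer. The model id, the max_new_tokens value, and the bare gr.ChatInterface wiring below are illustrative assumptions, not lifted verbatim from the Space.

# Sketch only: mirrors the streaming approach used in the updated app.py.
# The model id and generation settings here are placeholders.
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # placeholder model id

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto")

def stream_chat(message, history):
    # Rebuild the conversation as chat-template messages, then add the new user turn.
    conversation = []
    for user_msg, assistant_msg in history:
        conversation.append({"role": "user", "content": user_msg})
        conversation.append({"role": "assistant", "content": assistant_msg})
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, return_tensors="pt").to(model.device)

    # The streamer receives tokens as generate() produces them in the background thread.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True)
    Thread(target=model.generate, kwargs=dict(
        input_ids=input_ids, max_new_tokens=1024, streamer=streamer)).start()

    # Yield the accumulated reply so gr.ChatInterface can render partial output.
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer

if __name__ == "__main__":
    gr.ChatInterface(fn=stream_chat).launch()

Running generate() on a separate thread is what lets the Gradio handler stay a simple generator: the main thread only drains the streamer and yields partial text to the UI.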