vilarin commited on
Commit
6d1d1e9
·
verified ·
1 Parent(s): 4ed884e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -12
app.py CHANGED
@@ -1,23 +1,22 @@
1
  import os
2
  import time
3
- import spaces
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
  import gradio as gr
7
 
8
- MODEL_LIST = ["internlm/internlm2_5-7b-chat", "internlm/internlm2_5-7b-chat-1m"]
9
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
10
  MODEL_ID = os.environ.get("MODEL_ID", None)
11
  MODEL_NAME = MODEL_ID.split("/")[-1]
12
 
13
- TITLE = "<h1><center>internlm2.5-7b-chat</center></h1>"
14
 
15
  DESCRIPTION = f"""
16
  <h3>MODEL NOW: <a href="https://hf.co/{MODEL_ID}">{MODEL_NAME}</a></h3>
17
  """
18
  PLACEHOLDER = """
19
  <center>
20
- <p>InternLM2.5 has open-sourced a 7 billion parameter base model<br> and a chat model tailored for practical scenarios.</p>
21
  </center>
22
  """
23
 
@@ -36,13 +35,12 @@ h3 {
36
 
37
  model = AutoModelForCausalLM.from_pretrained(
38
  MODEL_ID,
39
- torch_dtype=torch.float16,
40
- trust_remote_code=True).cuda()
41
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
42
 
43
- model = model.eval()
44
 
45
- @spaces.GPU()
46
  def stream_chat(
47
  message: str,
48
  history: list,
@@ -54,11 +52,11 @@ def stream_chat(
54
  ):
55
  print(f'message: {message}')
56
  print(f'history: {history}')
57
- for resp, history in model.stream_chat(
58
  tokenizer,
59
  query = message,
60
  history = history,
61
- max_new_tokens = max_new_tokens,
62
  do_sample = False if temperature == 0 else True,
63
  top_p = top_p,
64
  top_k = top_k,
@@ -92,7 +90,7 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
92
  maximum=8192,
93
  step=1,
94
  value=1024,
95
- label="Max New Tokens",
96
  render=False,
97
  ),
98
  gr.Slider(
 
1
  import os
2
  import time
 
3
  import torch
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  import gradio as gr
6
 
7
+ MODEL_LIST = ["openbmb/MiniCPM-1B-sft-bf16", "openbmb/MiniCPM-S-1B-sft"]
8
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
9
  MODEL_ID = os.environ.get("MODEL_ID", None)
10
  MODEL_NAME = MODEL_ID.split("/")[-1]
11
 
12
+ TITLE = "<h1><center>MiniCPM-1B-chat</center></h1>"
13
 
14
  DESCRIPTION = f"""
15
  <h3>MODEL NOW: <a href="https://hf.co/{MODEL_ID}">{MODEL_NAME}</a></h3>
16
  """
17
  PLACEHOLDER = """
18
  <center>
19
+ <p>MiniCPM is an End-Size LLM developed by ModelBest Inc. and TsinghuaNLP, with only 1.2B parameters excluding embeddings.</p>
20
  </center>
21
  """
22
 
 
35
 
36
  model = AutoModelForCausalLM.from_pretrained(
37
  MODEL_ID,
38
+ torch_dtype=torch.bfloat16,
39
+ device_map='auto',
40
+ trust_remote_code=True)
41
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
42
 
 
43
 
 
44
  def stream_chat(
45
  message: str,
46
  history: list,
 
52
  ):
53
  print(f'message: {message}')
54
  print(f'history: {history}')
55
+ for resp, history in model.chat(
56
  tokenizer,
57
  query = message,
58
  history = history,
59
+ max_length = max_new_tokens,
60
  do_sample = False if temperature == 0 else True,
61
  top_p = top_p,
62
  top_k = top_k,
 
90
  maximum=8192,
91
  step=1,
92
  value=1024,
93
+ label="Max Length",
94
  render=False,
95
  ),
96
  gr.Slider(