nisten committed
Commit 3d92619
1 parent: 9ca55ad

Update app.py

Files changed (1)
app.py: +10 -10
app.py CHANGED
@@ -1,20 +1,19 @@
 import gradio as gr
 import spaces
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import OlmoeForCausalLM, AutoTokenizer
 import torch
 import subprocess
 import sys
-import os
 
 # Force upgrade transformers to the latest version
 subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "transformers"])
 
-model_name = "allenai/OLMoE-1B-7B-0924"
+model_name = "allenai/OLMoE-1B-7B-0924-Instruct"
 
 # Wrap model loading in a try-except block to handle potential errors
 try:
     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-    model = AutoModelForCausalLM.from_pretrained(
+    model = OlmoeForCausalLM.from_pretrained(
         model_name,
         trust_remote_code=True,
         torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
@@ -37,24 +36,25 @@ def generate_response(message, history, temperature, max_new_tokens):
     if model is None or tokenizer is None:
         return "Model or tokenizer not loaded properly. Please check the logs."
 
-    full_prompt = f"{system_prompt}\n\nHuman: {message}\n\nAssistant:"
+    messages = [{"role": "system", "content": system_prompt},
+                {"role": "user", "content": message}]
 
-    inputs = tokenizer(full_prompt, return_tensors="pt").to(DEVICE)
+    inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
 
     with torch.no_grad():
         generate_ids = model.generate(
-            **inputs,
+            inputs,
             max_new_tokens=max_new_tokens,
             do_sample=True,
             temperature=temperature,
             eos_token_id=tokenizer.eos_token_id,
         )
-    response = tokenizer.decode(generate_ids[0, inputs['input_ids'].shape[1]:], skip_special_tokens=True)
+    response = tokenizer.decode(generate_ids[0, inputs.shape[1]:], skip_special_tokens=True)
     return response.strip()
 
 css = """
 #output {
-    height: 500px;
+    height: 900px;
     overflow: auto;
     border: 1px solid #ccc;
 }
@@ -85,4 +85,4 @@ with gr.Blocks(css=css) as demo:
 
 if __name__ == "__main__":
     demo.queue(api_open=False)
-    demo.launch(debug=True, show_api=False)
+    demo.launch(debug=True, show_api=True, share=True)
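
Note on the change above: the old code built a raw "Human:/Assistant:" prompt string, while the new code uses the tokenizer's chat template. With tokenize=True and return_tensors="pt", apply_chat_template returns a tensor of input ids rather than a dict, which is why **inputs becomes a positional inputs in generate() and why the prompt length is read from inputs.shape[1]. Below is a minimal standalone sketch of the new inference path, assuming a transformers release recent enough to include OlmoeForCausalLM and that the instruct checkpoint ships a chat template; the example messages and sampling values are placeholders, not part of the commit.

import torch
from transformers import OlmoeForCausalLM, AutoTokenizer

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "allenai/OLMoE-1B-7B-0924-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = OlmoeForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
).to(DEVICE)

# Placeholder conversation; in app.py, system_prompt and message come from the Gradio UI.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Explain mixture-of-experts in one sentence."},
]

# With tokenize=True and return_tensors="pt", this returns a tensor of input ids
# (not a dict), so it is passed to generate() positionally below.
inputs = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
).to(DEVICE)

with torch.no_grad():
    generate_ids = model.generate(
        inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        eos_token_id=tokenizer.eos_token_id,
    )

# Decode only the tokens generated after the prompt.
response = tokenizer.decode(generate_ids[0, inputs.shape[1]:], skip_special_tokens=True)
print(response.strip())

Slicing the output at inputs.shape[1] decodes only the newly generated tokens, so the echoed prompt never appears in the chat response.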