nisten committed on
Commit 9ca55ad
1 Parent(s): be3574c

Update app.py

Files changed (1)
  1. app.py +17 -12
app.py CHANGED
@@ -4,16 +4,23 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import subprocess
 import sys
+import os
 
-# Force install the latest transformers version and flash attention
-subprocess.check_call([sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps", "transformers", "flash-attn"])
+# Force upgrade transformers to the latest version
+subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "transformers"])
 
 model_name = "allenai/OLMoE-1B-7B-0924"
 
 # Wrap model loading in a try-except block to handle potential errors
 try:
     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, torch_dtype="auto", _attn_implementation="flash_attention_2").to(DEVICE)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        trust_remote_code=True,
+        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
+        low_cpu_mem_usage=True,
+        device_map="auto"
+    )
     tokenizer = AutoTokenizer.from_pretrained(model_name)
 except Exception as e:
     print(f"Error loading model: {e}")
@@ -32,20 +39,18 @@ def generate_response(message, history, temperature, max_new_tokens):
 
     full_prompt = f"{system_prompt}\n\nHuman: {message}\n\nAssistant:"
 
-    inputs = tokenizer(full_prompt, return_tensors="pt")
-    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
+    inputs = tokenizer(full_prompt, return_tensors="pt").to(DEVICE)
 
     with torch.no_grad():
         generate_ids = model.generate(
             **inputs,
-            max_length=inputs['input_ids'].shape[1] + max_new_tokens,
+            max_new_tokens=max_new_tokens,
             do_sample=True,
             temperature=temperature,
+            eos_token_id=tokenizer.eos_token_id,
         )
-    response = tokenizer.decode(generate_ids[0], skip_special_tokens=True)
-    # Extract only the assistant's response
-    assistant_response = response.split("Assistant:")[-1].strip()
-    return assistant_response
+    response = tokenizer.decode(generate_ids[0, inputs['input_ids'].shape[1]:], skip_special_tokens=True)
+    return response.strip()
 
 css = """
 #output {
@@ -56,9 +61,9 @@ css = """
 """
 
 with gr.Blocks(css=css) as demo:
-    gr.Markdown("# Nisten's Karpathy Chatbot with OSS olMoE")
+    gr.Markdown("# Nisten's Karpathy Chatbot with OSS OLMoE")
     chatbot = gr.Chatbot(elem_id="output")
-    msg = gr.Textbox(label="Your prompt")
+    msg = gr.Textbox(label="Your message")
     with gr.Row():
         temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
         max_new_tokens = gr.Slider(minimum=50, maximum=4000, value=1000, step=50, label="Max New Tokens")
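For reference, the new loading path from the first hunk can be exercised on its own. A minimal sketch, assuming transformers and accelerate are installed; the device/dtype check at the end is illustrative and not part of this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "allenai/OLMoE-1B-7B-0924"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# device_map="auto" lets Accelerate place the weights, so no explicit .to(DEVICE) is needed.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    low_cpu_mem_usage=True,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Illustrative check (not in the commit): confirm where the weights landed and at what precision.
first_param = next(model.parameters())
print(first_param.device, first_param.dtype)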
 
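The key change in the second hunk is how the assistant reply is recovered: instead of decoding the whole sequence and splitting on "Assistant:", the new code decodes only the tokens generated after the prompt. A standalone sketch of that pattern, using gpt2 purely as a small stand-in model (it is not part of this Space):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# gpt2 is only a convenient small model for demonstrating the slicing; any causal LM works.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Human: Name one planet.\n\nAssistant:"
inputs = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)

# Decode only the tokens that come after the prompt, so no string splitting is needed.
prompt_len = inputs["input_ids"].shape[1]
response = tokenizer.decode(generate_ids[0, prompt_len:], skip_special_tokens=True)
print(response.strip())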