fllay committed
Commit 5a1c3b8 · verified · 1 Parent(s): 33553e1

Update app.py

Files changed (1)
  1. app.py +24 -16
app.py CHANGED
@@ -2,70 +2,78 @@ import torch
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM

-# Hugging Face repo ID (from the model page)
 MODEL_NAME = "NextGLab/ORANSight_Gemma_2_2B_Instruct"

-# Load tokenizer & model
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

 model = AutoModelForCausalLM.from_pretrained(
     MODEL_NAME,
-    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
-    device_map="auto"
 )

-# --- Helper function ---
 def chat(message, history, max_new_tokens=128, temperature=0.7):
     try:
-        # Turn history into messages for the chat template
         messages = []
         for user_msg, bot_msg in history:
             messages.append({"role": "user", "content": user_msg})
             messages.append({"role": "assistant", "content": bot_msg})
         messages.append({"role": "user", "content": message})

-        # Tokenize input
         inputs = tokenizer.apply_chat_template(
             messages,
             add_generation_prompt=True,
             tokenize=True,
             return_tensors="pt",
-        ).to(model.device)

-        # Generate response
         outputs = model.generate(
             **inputs,
             max_new_tokens=max_new_tokens,
             temperature=temperature,
             do_sample=True,
-            pad_token_id=tokenizer.eos_token_id,
         )

-        # Decode only new tokens
         response = tokenizer.decode(
             outputs[0][inputs["input_ids"].shape[-1]:],
             skip_special_tokens=True
         ).strip()

         history.append((message, response))
         return history, history, ""

     except Exception as e:
         import traceback
-        traceback.print_exc()  # this will show the full error in Logs
         return history + [(message, f"⚠️ Error: {str(e)}")], history, ""

-# --- Gradio App ---
 with gr.Blocks() as demo:
-    gr.Markdown("# 🤖 ORANSight Gemma 2 2B Instruct")

     chatbot = gr.Chatbot()
     msg = gr.Textbox(show_label=False, placeholder="Type a message...")
     send = gr.Button("Send")
     clear = gr.Button("Clear Chat")

-    max_tokens = gr.Slider(50, 512, value=128, step=10, label="Max new tokens")
-    temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")

     state = gr.State([])
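The removed call chain above is the bug this commit fixes: without `return_dict=True`, `apply_chat_template(..., tokenize=True, return_tensors="pt")` returns a bare tensor of token ids in current transformers releases, so the later `model.generate(**inputs, ...)` and `inputs["input_ids"]` lookups fail and every turn lands in the `except` branch. A minimal sketch of the two return shapes, reusing `tokenizer`, `messages`, and `model` from this file (illustrative; exact types depend on the installed transformers version):

    # Without return_dict=True: a plain tensor of token ids comes back.
    ids = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt")
    # model.generate(**ids)   # TypeError: ** expects a mapping, got a tensor
    # ids["input_ids"]        # fails too: tensors aren't indexed by string keys

    # With return_dict=True: a dict-like BatchEncoding with input_ids/attention_mask.
    enc = tokenizer.apply_chat_template(messages, tokenize=True,
                                        return_tensors="pt", return_dict=True)
    out = model.generate(**enc, max_new_tokens=8)  # ** unpacks cleanly

The updated version below requests the dictionary form and moves each tensor to the model device before generating.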
 
 
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM

+# Choose your model repo (from NextGLab)
 MODEL_NAME = "NextGLab/ORANSight_Gemma_2_2B_Instruct"

+# Load tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

 model = AutoModelForCausalLM.from_pretrained(
     MODEL_NAME,
+    torch_dtype="auto",   # lets HF decide (fp16/bf16/fp32 depending on GPU)
+    device_map="auto"     # put on GPU if available
 )

+# --- Chat function ---
 def chat(message, history, max_new_tokens=128, temperature=0.7):
     try:
+        # Convert history into Hugging Face messages format
         messages = []
         for user_msg, bot_msg in history:
             messages.append({"role": "user", "content": user_msg})
             messages.append({"role": "assistant", "content": bot_msg})
         messages.append({"role": "user", "content": message})

+        # Prepare inputs with chat template → return dictionary
         inputs = tokenizer.apply_chat_template(
             messages,
             add_generation_prompt=True,
             tokenize=True,
             return_tensors="pt",
+            return_dict=True
+        )
+
+        # Move all tensors in the input dict to the model device
+        for k in inputs:
+            inputs[k] = inputs[k].to(model.device)

+        # Generate model output
         outputs = model.generate(
             **inputs,
             max_new_tokens=max_new_tokens,
             temperature=temperature,
             do_sample=True,
+            pad_token_id=tokenizer.eos_token_id
         )

+        # Decode ONLY the newly generated tokens (past the input length)
         response = tokenizer.decode(
             outputs[0][inputs["input_ids"].shape[-1]:],
             skip_special_tokens=True
         ).strip()

+        # Append to history
         history.append((message, response))
         return history, history, ""

     except Exception as e:
         import traceback
+        traceback.print_exc()  # will show in HF Space Logs
         return history + [(message, f"⚠️ Error: {str(e)}")], history, ""

+
+# --- Gradio UI ---
 with gr.Blocks() as demo:
+    gr.Markdown("# 🤖 ORANSight Gemma Chat (2B Instruct)")

     chatbot = gr.Chatbot()
     msg = gr.Textbox(show_label=False, placeholder="Type a message...")
     send = gr.Button("Send")
     clear = gr.Button("Clear Chat")

+    with gr.Row():
+        max_tokens = gr.Slider(50, 512, step=10, value=128, label="Max tokens")
+        temperature = gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature")

     state = gr.State([])
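Two follow-up notes on the new version. The `BatchEncoding` returned by `apply_chat_template(..., return_dict=True, return_tensors="pt")` also supports `.to(device)` directly, so the per-key loop could be collapsed to `inputs = inputs.to(model.device)` with the same effect. And since the hunk ends at the `state` declaration, the event wiring is outside this diff; a hypothetical hookup consistent with `chat()`'s `(history, history, "")` return signature (the handlers below are an assumption, not part of this commit) would look like:

    # Hypothetical wiring (not shown in this hunk): run chat() on Send/Enter,
    # updating the chatbot, the history state, and clearing the textbox.
    send.click(chat, inputs=[msg, state, max_tokens, temperature],
               outputs=[chatbot, state, msg])
    msg.submit(chat, inputs=[msg, state, max_tokens, temperature],
               outputs=[chatbot, state, msg])
    clear.click(lambda: ([], [], ""), outputs=[chatbot, state, msg])

    demo.launch()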