made1570 committed on
Commit 22dbba3 · verified · 1 Parent(s): f118086

Update app.py

Files changed (1)
  1. app.py +58 -31
app.py CHANGED
@@ -1,51 +1,78 @@
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from peft import PeftModel
-import gradio as gr
+import os
+
+# Set the environment variable for debugging (you can remove this in production)
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+# Load model and tokenizer
+base_model_name = "adarsh3601/my_gemma_pt3"
+adapter_name = "your_adapter_name_here"  # Replace with actual adapter name if needed
 
-# Model loading
-base_model_name = "unsloth/gemma-3-12b-it-unsloth-bnb-4bit"
-adapter_name = "adarsh3601/my_gemma3_pt"
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Load base model in 4-bit with float16
+# Load the tokenizer and model
+tokenizer = AutoTokenizer.from_pretrained(base_model_name)
 base_model = AutoModelForCausalLM.from_pretrained(
     base_model_name,
-    device_map="auto",
-    torch_dtype=torch.float16,
-    load_in_4bit=True
+    device_map="auto",  # Using device_map="auto" for automatic GPU assignment
+    torch_dtype=torch.float32,  # Switch to float32 to avoid precision issues
+    load_in_4bit=True  # This should still be set if your model supports it
 )
 
-# Load tokenizer
-tokenizer = AutoTokenizer.from_pretrained(base_model_name)
-
-# Load fine-tuned adapter
+# Load the adapter model
 model = PeftModel.from_pretrained(base_model, adapter_name)
 model.to(device)
 
-# Chat function
+# Ensure the model is in evaluation mode
+model.eval()
+
+# Chat function with added input/output validation
 def chat(message):
+    # Tokenize input message
+    inputs = tokenizer(message, return_tensors="pt")
+
+    # Check if any input token contains NaN or Inf
+    if torch.any(torch.isnan(inputs['input_ids'])) or torch.any(torch.isinf(inputs['input_ids'])):
+        return "Input contains invalid values (NaN or Inf). Please check the input."
+
+    # Move tensors to the correct device
+    inputs = {k: v.to(device).half() for k, v in inputs.items()}  # Using half precision for performance
+
     try:
-        # Tokenize input (do NOT convert to .half())
-        inputs = tokenizer(message, return_tensors="pt").to(device)
+        # Generate response
+        outputs = model.generate(**inputs, max_new_tokens=150, do_sample=True)
 
-        # Generate output
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=150,
-            do_sample=True,
-            temperature=0.7,
-            top_p=0.95
-        )
-
-        # Decode output
+        # Check for NaNs or Infs in the output
+        if torch.any(torch.isnan(outputs)) or torch.any(torch.isinf(outputs)):
+            return "Model output contains invalid values (NaN or Inf). Please try again."
+
+        # Decode the response
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return response
 
     except Exception as e:
-        print("Unexpected error:", e)
-        return "An error occurred during generation."
+        # Catch any errors that occur during generation and return them
+        response = f"Unexpected error: {str(e)}"
+
+    return response
+
+# Gradio interface for the chat
+import gradio as gr
+
+def gradio_interface():
+    with gr.Blocks() as demo:
+        gr.Markdown("## Chat with Gemma Model")
+
+        with gr.Row():
+            message_input = gr.Textbox(label="Input Message")
+            output = gr.Textbox(label="Model Response")
+
+        # Button to trigger the chat
+        button = gr.Button("Generate Response")
+        button.click(fn=chat, inputs=message_input, outputs=output)
+
+    demo.launch()
 
-# Launch Gradio interface
-iface = gr.Interface(fn=chat, inputs="text", outputs="text", title="Gemma Chatbot")
-iface.launch()
+if __name__ == "__main__":
+    gradio_interface()
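
Two implementation notes on the updated chat() path, given as sketches rather than drop-in fixes. The tokenizer returns integer tensors (input_ids, attention_mask), so the blanket .half() cast converts token ids to float16, which model.generate() cannot use for the embedding lookup; for the same reason, the NaN/Inf check on integer ids never fires. A minimal sketch, assuming the tokenizer, model, and device defined above, that moves tensors to the device while down-casting only floating-point values:

# Sketch only: reuses `tokenizer`, `model`, and `device` from app.py above.
def prepare_inputs(message):
    inputs = tokenizer(message, return_tensors="pt")
    # Keep integer token ids / attention masks intact; cast only float tensors.
    return {
        k: (v.to(device, dtype=torch.float16) if v.is_floating_point() else v.to(device))
        for k, v in inputs.items()
    }

# Hypothetical usage:
# outputs = model.generate(**prepare_inputs("Hello"), max_new_tokens=150, do_sample=True)

Separately, newer transformers releases deprecate passing load_in_4bit directly to from_pretrained in favour of an explicit quantization config; a hedged equivalent of the loading block, assuming the bitsandbytes backend is installed, might look like:

# Sketch only: explicit 4-bit quantization config instead of load_in_4bit=True.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # compute dtype used by the 4-bit layers
)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",
    quantization_config=bnb_config,
)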