Commit 258e5e7 (verified) · committed by made1570
Parent: 22dbba3

Update app.py

Files changed (1):
  app.py (+17, -60)
app.py CHANGED
@@ -1,78 +1,35 @@
-import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from peft import PeftModel
+import gradio as gr
 import os
 
-# Set the environment variable for debugging (you can remove this in production)
-os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-
-# Load model and tokenizer
+# Model loading
 base_model_name = "adarsh3601/my_gemma_pt3"
-adapter_name = "your_adapter_name_here"  # Replace with actual adapter name if needed
-
+adapter_name = "adarsh3601/my_gemma3_pt"
 device = "cuda" if torch.cuda.is_available() else "cpu"
+auth_token = os.getenv("HF_AUTH_TOKEN")  # Make sure to set the Hugging Face token as an environment variable
 
-# Load the tokenizer and model
-tokenizer = AutoTokenizer.from_pretrained(base_model_name)
+# Load model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_auth_token=auth_token)
 base_model = AutoModelForCausalLM.from_pretrained(
     base_model_name,
-    device_map="auto",  # Using device_map="auto" for automatic GPU assignment
-    torch_dtype=torch.float32,  # Switch to float32 to avoid precision issues
-    load_in_4bit=True  # This should still be set if your model supports it
+    device_map={"": device},
+    torch_dtype=torch.float16,
+    load_in_4bit=True,
+    use_auth_token=auth_token
 )
 
-# Load the adapter model
 model = PeftModel.from_pretrained(base_model, adapter_name)
 model.to(device)
 
-# Ensure the model is in evaluation mode
-model.eval()
-
-# Chat function with added input/output validation
+# Chat function
 def chat(message):
-    # Tokenize input message
     inputs = tokenizer(message, return_tensors="pt")
-
-    # Check if any input token contains NaN or Inf
-    if torch.any(torch.isnan(inputs['input_ids'])) or torch.any(torch.isinf(inputs['input_ids'])):
-        return "Input contains invalid values (NaN or Inf). Please check the input."
-
-    # Move tensors to the correct device
-    inputs = {k: v.to(device).half() for k, v in inputs.items()}  # Using half precision for performance
-
-    try:
-        # Generate response
-        outputs = model.generate(**inputs, max_new_tokens=150, do_sample=True)
-
-        # Check for NaNs or Infs in the output
-        if torch.any(torch.isnan(outputs)) or torch.any(torch.isinf(outputs)):
-            return "Model output contains invalid values (NaN or Inf). Please try again."
-
-        # Decode the response
-        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    except Exception as e:
-        # Catch any errors that occur during generation and return them
-        response = f"Unexpected error: {str(e)}"
-
+    inputs = {k: v.to(device).half() for k, v in inputs.items()}
+    outputs = model.generate(**inputs, max_new_tokens=150, do_sample=True)
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return response
 
-# Gradio interface for the chat
-import gradio as gr
-
-def gradio_interface():
-    with gr.Blocks() as demo:
-        gr.Markdown("## Chat with Gemma Model")
-
-        with gr.Row():
-            message_input = gr.Textbox(label="Input Message")
-            output = gr.Textbox(label="Model Response")
-
-        # Button to trigger the chat
-        button = gr.Button("Generate Response")
-        button.click(fn=chat, inputs=message_input, outputs=output)
-
-        demo.launch()
-
-if __name__ == "__main__":
-    gradio_interface()
+# Launch Gradio app
+iface = gr.Interface(fn=chat, inputs="text", outputs="text", title="Gemma Chatbot")
+iface.launch()
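
Note: as committed, the new app.py removes "import torch" but still calls torch.cuda.is_available() and passes torch_dtype=torch.float16, so the app would raise a NameError at startup. Two further issues carry over from the old version: model.to(device) is rejected by recent transformers releases for 4-bit bitsandbytes models loaded with load_in_4bit=True, and casting the tokenizer output with .half() converts input_ids to float16, which typically fails the embedding lookup inside generate(). The sketch below is one possible correction, not the committed code; it keeps the commit's model and adapter ids and assumes bitsandbytes plus a CUDA device are available for 4-bit loading.

import torch
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

base_model_name = "adarsh3601/my_gemma_pt3"
adapter_name = "adarsh3601/my_gemma3_pt"
device = "cuda" if torch.cuda.is_available() else "cpu"
auth_token = os.getenv("HF_AUTH_TOKEN")  # export HF_AUTH_TOKEN in the Space/host environment

tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_auth_token=auth_token)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map={"": device},
    torch_dtype=torch.float16,
    load_in_4bit=True,  # assumption: bitsandbytes is installed and a GPU is present
    use_auth_token=auth_token,
)
# device_map already places the quantized weights, so no .to(device) call here.
model = PeftModel.from_pretrained(base_model, adapter_name)
model.eval()

def chat(message):
    # Keep input_ids/attention_mask as integer tensors; only move them to the device.
    inputs = {k: v.to(device) for k, v in tokenizer(message, return_tensors="pt").items()}
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=150, do_sample=True)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

iface = gr.Interface(fn=chat, inputs="text", outputs="text", title="Gemma Chatbot")
iface.launch()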