# Hugging Face Space (status: Paused)
import os

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Set the environment variable for debugging (you can remove this in production).
# Makes CUDA kernel launches synchronous so device-side errors are reported at
# the call that caused them; it slows inference. It must be set before CUDA is
# initialised, which is why it comes before any model is loaded.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Load model and tokenizer
base_model_name = "adarsh3601/my_gemma_pt3"
# Placeholder value meaning "no adapter configured". While adapter_name still
# holds it (or is empty), the base model is served on its own instead of
# failing inside PeftModel.from_pretrained.
_ADAPTER_PLACEHOLDER = "your_adapter_name_here"
adapter_name = _ADAPTER_PLACEHOLDER  # Replace with actual adapter name if needed
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# 4-bit bitsandbytes quantization is requested through a BitsAndBytesConfig
# (passing load_in_4bit directly to from_pretrained is deprecated). It is only
# used when a CUDA GPU is present; on CPU plain float32 weights are loaded.
quantization_config = (
    BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float32)
    if device == "cuda"
    else None
)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",  # weights are placed automatically; no .to() afterwards
    torch_dtype=torch.float32,  # float32 to avoid precision issues
    quantization_config=quantization_config,
)

# Load the adapter model, but only if a real adapter has been configured.
if adapter_name and adapter_name != _ADAPTER_PLACEHOLDER:
    model = PeftModel.from_pretrained(base_model, adapter_name)
else:
    model = base_model

# NOTE: deliberately no model.to(device). With device_map="auto" the weights
# are already placed, and moving bitsandbytes-quantized models with .to() is
# rejected or unsafe in many transformers/bitsandbytes versions.

# Ensure the model is in evaluation mode (disables dropout).
model.eval()
# Chat function with added input/output validation | |
def chat(message):
    """Generate a sampled model reply to a single user message.

    Args:
        message: Prompt text, as received from the Gradio input textbox.

    Returns:
        The newly generated text, with the prompt tokens stripped and special
        tokens skipped. Problems are returned as a human-readable string
        instead of being raised, so they appear in the UI output box.
    """
    if not message or not message.strip():
        return "Please enter a message."

    # Token ids must stay integer tensors: generate() embeds them by index,
    # so casting them to float (the old .half()) breaks generation. float16
    # also cannot represent ids above 2048 exactly. Integer id tensors can
    # never hold NaN/Inf, so the old checks on them were dead code. A NaN
    # produced inside the model instead surfaces as an exception from
    # sampling, which is caught below.
    # Inputs go to the device the model reports, because device_map="auto"
    # may have placed it somewhere other than a hand-picked device string.
    inputs = tokenizer(message, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    try:
        # Generate response
        outputs = model.generate(**inputs, max_new_tokens=150, do_sample=True)
        # Decoder-only generate() returns prompt + continuation; keep only the
        # continuation so the reply does not echo the user's message.
        response = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
    except Exception as e:
        # UI boundary: report any generation/decoding failure to the user.
        response = f"Unexpected error: {e}"
    return response
# Gradio interface for the chat
import gradio as gr


def gradio_interface():
    """Build the chat page and launch the Gradio app.

    The page has a heading, then an input textbox and a response textbox side
    by side, then a button. Clicking the button sends the input text through
    ``chat`` and writes the result into the response box.
    """
    with gr.Blocks() as ui:
        gr.Markdown("## Chat with Gemma Model")
        with gr.Row():
            prompt_box = gr.Textbox(label="Input Message")
            reply_box = gr.Textbox(label="Model Response")
        # Button to trigger the chat
        send_button = gr.Button("Generate Response")
        send_button.click(fn=chat, inputs=prompt_box, outputs=reply_box)
    ui.launch()


if __name__ == "__main__":
    gradio_interface()