# app.py
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel  # PeftModel loads a LoRA/QLoRA adapter on top of a base model
import gc
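# NOTE: the Space also needs a requirements.txt next to this file. A minimal
# sketch (unpinned versions are an assumption -- pin whatever you trained with):
#   gradio
#   torch
#   transformers
#   peft
#   accelerate  # required by transformers when passing device_map to from_pretrained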
# --- Configuration ---
# Base model ID (the one you fine-tuned FROM)
base_model_id = "Qwen/Qwen2-0.5B"
# Path WITHIN THE SPACE where you will upload your adapter files.
# Create a folder named 'adapter' in your Space and upload the adapter files there.
adapter_path = "./adapter"
# Determine device (use GPU if available in the Space)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# --- Load Model and Tokenizer ---
print(f"Loading base model: {base_model_id}")
# Load the base model unquantized so the QLoRA adapter can be merged cleanly below.
# Pick a dtype: bf16 on Ampere+ GPUs, fp16 on older GPUs, fp32 on CPU
# (half precision is poorly supported for CPU inference, and
# torch.cuda.is_bf16_supported() should only be queried when CUDA is available).
if device == "cuda":
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    dtype = torch.float32
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=dtype,
    device_map=device,  # load straight onto the target device (not device_map="auto")
    trust_remote_code=True,
)
base_model.config.use_cache = True  # enable the KV cache for faster generation
print(f"Base model loaded to device: {device}")
# --- Load PEFT Adapter ---
print(f"Loading PEFT adapter from: {adapter_path}")
# Attach the trained adapter to the base model. The base model must already be
# on the target device before the adapter is loaded.
model = PeftModel.from_pretrained(base_model, adapter_path)
print("Adapter loaded.")
# --- Merge Adapter ---
print("Merging adapter weights...")
model = model.merge_and_unload()  # bake the adapter weights into the base model
print("Adapter merged.")  # the merged model remains on the device chosen above
# --- Load Tokenizer ---
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
# Set a padding token if the tokenizer lacks one (same logic as the training script)
if tokenizer.pad_token is None:
    if tokenizer.eos_token:
        tokenizer.pad_token = tokenizer.eos_token
        print(f"Set tokenizer pad_token to eos_token: {tokenizer.pad_token}")
    else:
        print("Warning: EOS token not found, cannot set pad_token automatically.")
tokenizer.padding_side = "left"  # decoder-only models must be left-padded for generation
print("Model and tokenizer loaded successfully.")
# --- Inference Function ---
def summarize_text(article_text):
    if not article_text:
        return "Please enter some text to summarize."
    # Prompt format for the Qwen base model (matches the training script)
    prompt = f"Summarize the following text:\n\n{article_text}\n\nSummary:"
    try:
        print("Tokenizing input...")
        inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=True).to(device)
        print("Generating summary...")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,  # maximum length of the summary
                temperature=0.6,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )
        # Decode only the newly generated tokens (everything after the prompt)
        response_ids = outputs[0][inputs["input_ids"].shape[1]:]
        summary = tokenizer.decode(response_ids, skip_special_tokens=True).strip()
        print("Summary generated.")
        # Free memory after generation
        del inputs, outputs
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return summary
    except Exception as e:
        print(f"Error during inference: {e}")
        return f"An error occurred: {e}"
# --- Create Gradio Interface ---
print("Creating Gradio interface...")
iface = gr.Interface(
    fn=summarize_text,
    inputs=gr.Textbox(lines=10, placeholder="Paste the text you want to summarize here...", label="Article Text"),
    outputs=gr.Textbox(label="Generated Summary"),
    title="Qwen2-0.5B Base - Fine-tuned Summarizer (GRPO/QLoRA)",
    description="Enter text to get a summary generated by the fine-tuned Qwen2-0.5B base model.",
    examples=[
        ["SUBREDDIT: r/relationships TITLE: I (f/22) have to figure out if I want to still know these girls or not and would hate to sound insulting POST: Not sure if this belongs here but it's worth a try... (rest of example text from your logs)"]
        # Add more examples if you like
    ],
)
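# Optional: enable Gradio's request queue so long generations don't hit the
# Space's request timeout (whether this is needed depends on your Gradio
# version and traffic -- an assumption here, not part of the original script):
# iface = iface.queue()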
# --- Launch the App ---
print("Launching Gradio app...")
iface.launch()