File size: 2,627 Bytes
14a55ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import os
# -----------------------------
# 1. Login using HF token (from environment variable)
# -----------------------------
# Read the Hugging Face token from the environment (set as a Space secret).
# When it is missing we only warn: the model loaded below is requested with
# this same value, so a None token falls back to anonymous access.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("⚠️ Warning: HF_TOKEN not set. Add it in the Space settings.")
# -----------------------------
# 2. Load Model
# -----------------------------
model_id = "ibm-granite/granite-3.3-2b-instruct"

# Pass the token explicitly instead of the deprecated `use_auth_token=True`.
# `True` demands a locally saved token and fails when HF_TOKEN is unset, even
# though the login step above only warns in that case; `None` means anonymous.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)

# Half precision saves memory on GPU, but on CPU-only hosts float16 kernels are
# slow (or unsupported in some PyTorch versions), so use float32 there.
model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=model_dtype,
    device_map="auto",
    token=hf_token,
)
# -----------------------------
# 3. Functions
# -----------------------------
def classify_sentiment(text):
    """Classify the sentiment of a review with the Granite instruct model.

    Args:
        text: Review text entered in the Gradio textbox.

    Returns:
        "Positive", "Negative" or "Neutral" when one of those labels appears as
        a word in the model's reply; otherwise the model's stripped raw reply.
        Returns "" for empty / whitespace-only input without calling the model.
    """
    if not text or not text.strip():
        return ""

    prompt = (
        "Classify the sentiment of the following review as one word: "
        "Positive, Negative, or Neutral.\n"
        f'Review: "{text}"\n'
        "Only respond with the sentiment label."
    )
    # The model is an instruct/chat model: wrap the request in its chat template
    # so it answers as the assistant instead of continuing raw prompt text.
    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    output = model.generate(
        **inputs,
        max_new_tokens=5,
        # Greedy decoding. The previous temperature=0.1 had no effect because
        # sampling was never enabled. repetition_penalty is dropped: for
        # decoder-only models it also penalises prompt tokens, i.e. the very
        # label words listed in the instructions.
        do_sample=False,
    )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[1]
    reply = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()

    # Normalise replies such as "positive." or "**Negative**" to a canonical label.
    labels = ("Positive", "Negative", "Neutral")
    for word in reply.split():
        candidate = word.strip(".,:;!?\"'*").capitalize()
        if candidate in labels:
            return candidate
    return reply
def summarize_text(text):
    """Summarise text into a few bullet points with the Granite instruct model.

    Args:
        text: Text entered in the Gradio textbox.

    Returns:
        The model's stripped reply (the prompt asks for 3–5 bullet points; the
        format is not enforced). Returns "" for empty / whitespace-only input
        without calling the model.
    """
    if not text or not text.strip():
        return ""

    prompt = (
        "Summarize the following text in 3–5 bullet points.\n"
        "Make the summary short, simple, and avoid repeating sentences.\n"
        f"Text: {text}"
    )
    # Use the model's chat template so the instruct model replies as assistant.
    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    output = model.generate(
        **inputs,
        max_new_tokens=150,
        # temperature/top_p only take effect when sampling is enabled.
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        # Mild penalty only: it is applied to prompt tokens too, and the old
        # value of 2.0 heavily punished reusing words from the source text.
        repetition_penalty=1.1,
    )

    # Decode only the generated continuation; no "Summary:" splitting needed.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()
# -----------------------------
# 4. Gradio Interface
# -----------------------------
# Two tabs, each wiring a textbox + button to one of the model functions above.
with gr.Blocks() as demo:
    gr.Markdown("## Granite 3.3-2B — Sentiment + Summarization Demo")

    with gr.Tab("Sentiment Classification"):
        inp = gr.Textbox(label="Enter Review", lines=4)
        out = gr.Textbox(label="Predicted Sentiment", lines=1)
        btn = gr.Button("Classify")
        btn.click(classify_sentiment, inputs=inp, outputs=out)

    with gr.Tab("Summarization"):
        inp2 = gr.Textbox(label="Enter Text to Summarize", lines=8)
        out2 = gr.Textbox(label="Summary", lines=8)
        btn2 = gr.Button("Summarize")
        btn2.click(summarize_text, inputs=inp2, outputs=out2)

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch()