"""
Example script for using the NEED AI model from the Hugging Face Hub.
"""
import torch
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download

# Configuration
MODEL_REPO = "yogami9/need-ai-conversational-model"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def setup_model():
    """Download the model files and load the model and tokenizer."""
    print("📥 Downloading model files...")

    # Download the model weights
    model_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename="pytorch_model.bin"
    )

    # Download the custom modeling code
    modeling_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename="modeling_need.py"
    )
    print("✅ Files downloaded")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)

    # Make the downloaded module importable, then import the custom model class
    import sys
    import os
    sys.path.insert(0, os.path.dirname(modeling_path))
    from modeling_need import NEEDConversationalModel

    # Load the weights and move the model to the target device
    model = NEEDConversationalModel.from_pretrained(model_path)
    model = model.to(DEVICE)
    model.eval()
    print(f"✅ Model loaded on {DEVICE}")

    return model, tokenizer
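

# Alternative loading path (sketch): if the repo's config.json declares an
# ``auto_map`` entry for NEEDConversationalModel, Transformers can resolve the
# custom class directly via ``trust_remote_code=True``. Whether this repo is
# configured that way is an assumption; the manual download in setup_model()
# works regardless.
def setup_model_via_remote_code():
    """Hedged alternative loader; assumes the repo supports remote code."""
    from transformers import AutoModel
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    model = AutoModel.from_pretrained(MODEL_REPO, trust_remote_code=True)
    return model.to(DEVICE).eval(), tokenizer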


def generate_response(model, tokenizer, text: str, max_length: int = 100):
    """Generate a response for a single user input."""
    # Prepare input: prefix with the speaker tag and encode
    input_text = f"Human: {text}"
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(DEVICE)
    speaker_ids = torch.zeros_like(input_ids)  # all tokens tagged as speaker 0 (the human turn)

    # Generate without tracking gradients
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            speaker_ids=speaker_ids,
            max_length=max_length,
            temperature=0.8,
            top_k=50
        )

    # Decode the generated token ids back to text
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return response
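

# Multi-turn sketch: generate_response() above tags the whole prompt as
# speaker 0. The helper below assumes the speaker vocabulary is 0 = human and
# 1 = assistant -- that mapping is an assumption; check modeling_need.py for
# the convention the model was actually trained with.
def generate_multi_turn(model, tokenizer, turns, max_length: int = 100):
    """Build a prompt from alternating Human/Assistant turns and generate."""
    id_chunks, speaker_chunks = [], []
    for i, turn in enumerate(turns):
        prefix = "Human: " if i % 2 == 0 else "Assistant: "
        ids = tokenizer.encode(prefix + turn, return_tensors='pt')
        id_chunks.append(ids)
        speaker_chunks.append(torch.full_like(ids, i % 2))  # assumed speaker mapping
    input_ids = torch.cat(id_chunks, dim=1).to(DEVICE)
    speaker_ids = torch.cat(speaker_chunks, dim=1).to(DEVICE)
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            speaker_ids=speaker_ids,
            max_length=max_length,
            temperature=0.8,
            top_k=50
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)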


if __name__ == "__main__":
    # Setup
    model, tokenizer = setup_model()

    # Example queries
    queries = [
        "I need a house cleaner in Lagos",
        "How much does tutoring cost?",
        "I need help with plumbing",
    ]

    print("\n" + "=" * 60)
    print("Testing NEED AI Model")
    print("=" * 60)

    for query in queries:
        print(f"\n👤 User: {query}")
        response = generate_response(model, tokenizer, query)
        print(f"🤖 Assistant: {response}")