""" |
|
|
Example script for using NEED AI model from Hugging Face Hub |
|
|
""" |

import torch
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download

# Hugging Face Hub repository that hosts the model weights and custom code.
MODEL_REPO = "yogami9/need-ai-conversational-model"

# Use a GPU when available, otherwise fall back to CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def setup_model():
    """Download the model files and load the model and tokenizer."""
    print("📥 Downloading model files...")

    # Fetch the model weights and the custom modeling code from the Hub.
    model_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename="pytorch_model.bin"
    )
    modeling_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename="modeling_need.py"
    )
    print("✅ Files downloaded")

    # Load the tokenizer directly from the repository.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)

    # Make the downloaded modeling code importable, then import the model class.
    import os
    import sys
    sys.path.insert(0, os.path.dirname(modeling_path))
    from modeling_need import NEEDConversationalModel

    # Load the weights, move the model to the target device, and switch to eval mode.
    model = NEEDConversationalModel.from_pretrained(model_path)
    model = model.to(DEVICE)
    model.eval()

    print(f"✅ Model loaded on {DEVICE}")

    return model, tokenizer
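
# Note: if the repository also ships a standard config and registers its
# architecture with the Transformers auto classes, the model may be loadable
# via AutoModel.from_pretrained(MODEL_REPO, trust_remote_code=True); this
# script uses the explicit file-download path instead.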


def generate_response(model, tokenizer, text: str, max_length: int = 100):
    """Generate a response for a single user input."""
    # Prefix the user text with the "Human:" speaker tag and encode it.
    input_text = f"Human: {text}"
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(DEVICE)

    # Assign speaker id 0 to every input token.
    speaker_ids = torch.zeros_like(input_ids)

    # Sample a reply without tracking gradients.
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            speaker_ids=speaker_ids,
            max_length=max_length,
            temperature=0.8,
            top_k=50
        )

    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return response
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
model, tokenizer = setup_model() |
|
|
|
|
|
|
|
|
queries = [ |
|
|
"I need a house cleaner in Lagos", |
|
|
"How much does tutoring cost?", |
|
|
"I need help with plumbing", |
|
|
] |
|
|
|
|
|
print("\n" + "="*60) |
|
|
print("Testing NEED AI Model") |
|
|
print("="*60) |
|
|
|
|
|
for query in queries: |
|
|
print(f"\n👤 User: {query}") |
|
|
response = generate_response(model, tokenizer, query) |
|
|
print(f"🤖 Assistant: {response}") |
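
    # Optional: a minimal interactive loop (sketch) built on the same helpers.
    # while True:
    #     user_text = input("\n👤 You: ")
    #     if user_text.strip().lower() in {"quit", "exit"}:
    #         break
    #     print(f"🤖 Assistant: {generate_response(model, tokenizer, user_text)}")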