|
|
|
""" |
|
Demo script for using the Shopping Assistant model with the Hugging Face Inference API |
|
""" |
|
import requests |
|
import json |
|
import argparse |
|
|
|
def query_model(
    text,
    api_token=None,
    model_id="selvaonline/shopping-assistant",
    timeout=120.0,
):
    """Query a model through the Hugging Face Inference API.

    Args:
        text: Input text, sent as the ``inputs`` field of the JSON payload.
        api_token: Optional Hugging Face API token. When truthy it is sent
            as a ``Bearer`` token in the ``Authorization`` header;
            otherwise the request is made without that header.
        model_id: Model repository ID appended to the Inference API URL.
        timeout: Seconds to wait for the server before giving up. Passed
            unchanged to ``requests.post``; ``None`` disables the timeout.

    Returns:
        The decoded JSON body when the server answers with HTTP 200.
        ``None`` when the request cannot be completed, the server answers
        with any other status code, or the body is not valid JSON. Every
        failure is reported on stdout before ``None`` is returned.
    """
    api_url = f"https://api-inference.huggingface.co/models/{model_id}"

    headers = {}
    if api_token:
        headers["Authorization"] = f"Bearer {api_token}"

    payload = {
        "inputs": text,
        # Ask the API to wait for the model rather than reject the request
        # while it is still loading (see the Inference API documentation).
        "options": {"wait_for_model": True},
    }

    try:
        # Without a timeout ``requests`` waits indefinitely for a reply.
        response = requests.post(
            api_url, headers=headers, json=payload, timeout=timeout
        )
    except requests.RequestException as exc:
        # Connection problems and timeouts follow the same contract as
        # HTTP errors: report the problem and return None.
        print(f"Error: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error: {response.status_code}")
        print(response.text)
        return None

    try:
        return response.json()
    except ValueError as exc:
        # A 200 response whose body is not JSON cannot be processed.
        print(f"Error: {exc}")
        print(response.text)
        return None
|
|
|
def process_results(results, text):
    """Turn raw Inference API output into a human-readable summary.

    ``results[0]`` is read as a flat sequence of numeric scores. Each score
    is passed through a sigmoid, and the i-th score is paired with the i-th
    name of the category list defined below. Categories whose probability
    exceeds 0.5 are reported, highest probability first.

    Args:
        results: Value returned by the API call. Anything that is not a
            non-empty list is treated as "no results".
        text: The original query, echoed in the returned message.

    Returns:
        A message string: the ranked categories plus a recommendation, a
        "no categories" notice when nothing exceeds 0.5, a "no results"
        notice for empty or non-list input, or an "unexpected format"
        notice when ``results[0]`` is not a flat numeric sequence.
    """
    if not results or not isinstance(results, list):
        return f"No results found for '{text}'"

    # Imported here so the "no results" path above does not need numpy.
    import numpy as np

    # NOTE(review): this order is assumed to match the model's output
    # order; confirm against the model's label configuration.
    categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]

    try:
        logits = np.asarray(results[0], dtype=float)
    except (TypeError, ValueError):
        # e.g. label/score dicts, strings or ragged nested lists.
        logits = None
    if logits is None or logits.ndim != 1:
        return f"Unexpected response format for '{text}'"

    probabilities = 1 / (1 + np.exp(-logits))

    # zip() stops at the shorter sequence, so scores beyond the known
    # category names are ignored instead of raising IndexError.
    top_categories = sorted(
        (
            (name, score)
            for name, score in zip(categories, probabilities.tolist())
            if score > 0.5
        ),
        key=lambda item: item[1],
        reverse=True,
    )

    if not top_categories:
        return f"No categories found for '{text}'. Please try a different query."

    parts = [f"Top categories for '{text}':\n"]
    parts.extend(
        f" {category}: {score:.4f}\n" for category, score in top_categories
    )
    parts.append(
        "\nBased on your query, I would recommend looking for deals in the "
        f"{top_categories[0][0]} category."
    )
    return "".join(parts)
|
|
|
def main():
    """Command-line entry point: classify ``--text`` and print the summary.

    Reads ``--text`` (required), ``--token`` and ``--model-id`` from the
    command line, sends the text to the model via ``query_model`` and
    prints the message built by ``process_results``.
    """
    cli = argparse.ArgumentParser(
        description=(
            "Demo for using the Shopping Assistant model "
            "with the Hugging Face Inference API"
        ),
    )
    cli.add_argument(
        "--text",
        type=str,
        required=True,
        help="Text to classify",
    )
    cli.add_argument(
        "--token",
        type=str,
        help="Hugging Face API token",
    )
    cli.add_argument(
        "--model-id",
        type=str,
        default="selvaonline/shopping-assistant",
        help="Hugging Face model ID",
    )
    options = cli.parse_args()

    # Fetch the raw model output, then render it for the user.
    raw_output = query_model(options.text, options.token, options.model_id)
    summary = process_results(raw_output, options.text)
    print(summary)
|
|
|
# Run the demo only when this file is executed as a script, so importing
# the module does not trigger argument parsing or a network request.
if __name__ == "__main__":
    main()
|
|