File size: 2,356 Bytes
8e53f74
 
 
 
 
 
 
f47e717
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e53f74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f47e717
 
 
 
 
8e53f74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f47e717
 
8e53f74
 
f47e717
8e53f74
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import gradio as gr
from typing import Dict
import os

from homepage2vec.model import WebsiteClassifier as Homepage2Vec

# Example inputs shown below the interface, as [model choice, website] pairs.
# They are grouped by website here so it is easy to see which model variants
# are showcased for each site; the flattened order is site by site.
EXAMPLES = [
    [model_name, site]
    for site, model_names in (
        # Personal site
        ("tanjasenghaasdesigns.de", ("original", "finetuned-gpt4")),
        # EPFL
        ("epfl.ch", ("finetuned-gpt3.5", "finetuned-gpt4")),
        # Czech Crunch - czech tech news
        ("cc.cz", ("original", "finetuned-gpt4")),
        # Promaminky - czech site for moms
        ("promaminky.cz", ("original", "finetuned-gpt3.5")),
    )
    for model_name in model_names
]


def _resolve_model_dir(model_choice: str) -> str:
    """
    Map a model choice to the directory its model is loaded from.

    Args:
        model_choice (str): Either "original" or a dash-separated name such
            as "finetuned-gpt4", whose second component names the variant.

    Returns:
        str: Relative path of the model directory.

    Raises:
        ValueError: If the choice is neither "original" nor a name with a
            non-empty variant after the first dash.
    """
    if model_choice == "original":
        return os.path.join("models", "homepage2vec")

    parts = model_choice.split("-") if isinstance(model_choice, str) else []
    # Without a variant component we would either crash on the index or
    # silently point at the parent "finetuned" directory.
    if len(parts) < 2 or not parts[1]:
        raise ValueError(
            f"Unknown model choice {model_choice!r}: expected 'original' "
            "or a name like 'finetuned-gpt4'."
        )
    return os.path.join("models", "finetuned", parts[1])


def predict(model_choice: str, url: str) -> Dict[str, float]:
    """
    Predict the categories of a website using the Homepage2Vec model.

    Args:
        model_choice (str): The model to use for prediction ("original" or
            a fine-tuned variant such as "finetuned-gpt4").
        url (str): The url of the website to predict. Surrounding
            whitespace is ignored.

    Returns:
        Dict[str, float]: The categories whose score is greater than 0.5,
            mapped to their scores.

    Raises:
        ValueError: If `model_choice` is not recognised or `url` is empty.
    """
    # Validate inputs before doing any model work so bad requests fail fast
    # with a clear message.
    model_dir = _resolve_model_dir(model_choice)
    if not isinstance(url, str) or not url.strip():
        raise ValueError("Please provide a website URL or domain.")

    # Initialise model
    # NOTE(review): a new classifier is constructed on every call; consider
    # caching one instance per model_dir if construction proves expensive.
    model = Homepage2Vec(model_dir=model_dir)

    # Website to predict
    website = model.fetch_website(url.strip())

    # Obtain scores and embeddings (the embeddings are not used here)
    scores, _ = model.predict(website)

    # Keep only the categories scoring above 0.5
    return {category: score for category, score in scores.items() if score > 0.5}
    

# Gradio UI: a model selector and a URL box feed `predict`, whose
# {category: score} result is rendered as a label component.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(
            choices=["original", "finetuned-gpt3.5", "finetuned-gpt4"],
            label="Select Model",
            show_label=True,
            value="finetuned-gpt4",
        ),
        gr.Textbox(
            label="Enter Website's URL or domain",
            placeholder="e.g. ikea.com",
        ),
    ],
    # NOTE(review): 14 is presumably the total number of categories the
    # model can output, so no predicted label is hidden - confirm.
    outputs=gr.Label(num_top_classes=14, label="Predicted Labels", show_label=True),
    title="Homepage2Vec",
    description=(
        "Select a version of the Homepage2Vec model and enter a website's "
        "URL or domain to predict its categories. The original model was "
        "trained on 886K websites from Curlie directory. The finetuned "
        "models, in addition, were trained on GPT annotated websites. On "
        "average, the finetuned models should predict more labels than the "
        "original model while maintaining high accuracy."
    ),
    examples=EXAMPLES,
    live=False,  # only run a prediction when the user submits
    allow_flagging="never",
)

iface.launch()