File size: 2,898 Bytes
f12a60c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ccc55b
f12a60c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import gradio as gr
import torch
from dual_regression_model import DualRegressionModel
import transformers
from transformers import pipeline
from functools import partial

# load the models
# CLF: A-pt-bs16-dbmdz-bert-base-italian-cased, loaded from the local
# "clf_model/" directory and wrapped in a text-classification pipeline.
clf_model_tag = "clf_model/"
clf_tokenizer = transformers.AutoTokenizer.from_pretrained(clf_model_tag)
clf_model = transformers.AutoModelForSequenceClassification.from_pretrained(clf_model_tag)
clf_pipeline = pipeline("text-classification", model=clf_model, tokenizer=clf_tokenizer)

# REG: DualRegressionModel built on the multilingual DistilBERT backbone;
# load_model() presumably restores the trained weights from the .pt file.
reg_model_tag = "distilbert-base-multilingual-cased"
reg_model_folder = "reg_model/regression_model.pt"
reg_model = DualRegressionModel(model_name_or_path=reg_model_tag)
reg_model.load_model(reg_model_folder)
# A torch.nn.Module starts in training mode, where dropout layers are active and
# repeated predictions on the same text can differ. This script only does
# inference, so switch to evaluation mode when the wrapper is a Module.
# NOTE(review): load_model() may already do this — calling eval() twice is harmless.
if isinstance(reg_model, torch.nn.Module):
    reg_model.eval()


# Bounding box (in degrees) of the embedded map. It is chosen to contain the
# whole Italian mainland plus Sicily and Sardinia: roughly 36.6–47.1 °N
# (Capo Passero to South Tyrol) and 6.6–18.5 °E (Aosta Valley to Salento),
# so the predicted marker stays inside the initial view.
_MAP_LAT_MIN = 36.5
_MAP_LAT_MAX = 47.2
_MAP_LONG_MIN = 6.5
_MAP_LONG_MAX = 18.6


def _render_prediction_html(label, latitude, longitude):
    """Build the HTML fragment shown to the user for one prediction.

    Args:
        label: Region label produced by the classifier.
        latitude: Predicted latitude, in degrees.
        longitude: Predicted longitude, in degrees.

    Returns:
        HTML with the region heading, the raw coordinates, an OpenStreetMap
        embed (fixed bounding box, marker on the predicted point) and a link
        to a zoomed-in OpenStreetMap view centred on that point.
    """
    # OpenStreetMap's embed takes bbox=min_lon,min_lat,max_lon,max_lat and
    # marker=lat,lon; "%2C" is a URL-encoded comma.
    bbox = f"{_MAP_LONG_MIN}%2C{_MAP_LAT_MIN}%2C{_MAP_LONG_MAX}%2C{_MAP_LAT_MAX}"
    # Every "&" inside the HTML attribute is written as "&amp;" so the markup
    # is well-formed (the browser decodes it back to "&" in the URL).
    embed_src = (
        "https://www.openstreetmap.org/export/embed.html"
        f"?bbox={bbox}&amp;layer=mapnik&amp;marker={latitude}%2C{longitude}"
    )
    zoom_link = f"https://www.openstreetmap.org/#map=13/{latitude}/{longitude}"
    return (
        f"<h3>The identified region is: {label}</h3>"
        "<h3>Predicted point on map:</h3>"
        f"<p>Latitude: {latitude}</p><p>Longitude: {longitude}</p>"
        '<iframe width="425" height="350" frameborder="0" scrolling="no" '
        f'marginheight="0" marginwidth="0" src="{embed_src}" '
        'style="border: 1px solid black"></iframe>'
        f'<br/><small><a href="{zoom_link}">Visualizza mappa ingrandita</a></small>'
    )


# define the function to be used for prediction
def predict(text):
    """Predict the dialect region of ``text`` and locate it on a map of Italy.

    Runs the two models loaded at module import time on the same input:
    ``clf_pipeline`` supplies the region label (the ``"label"`` of its first
    result) and ``reg_model`` supplies a latitude/longitude pair.

    Args:
        text: The sentence(s) entered in the Gradio textbox.

    Returns:
        An HTML string with the predicted region, the predicted coordinates
        and an embedded OpenStreetMap view with a marker on that point.
    """
    # predict the class
    clf_prediction = clf_pipeline(text)[0]

    # predict the coordinates; this is pure inference, so skip building the
    # autograd graph (saves memory and time on every request)
    reg_input = reg_model.tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        reg_prediction = reg_model(reg_input)
    latitude = reg_prediction["latitude"].item()
    longitude = reg_prediction["longitude"].item()

    return _render_prediction_html(clf_prediction["label"], latitude, longitude)

# --------------------------------------------------------------------------------------------
# Gradio interface
# --------------------------------------------------------------------------------------------

# define the interface: one free-text input, HTML output produced by predict()
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Insert the text here..."),
    outputs=gr.HTML(),
    title="DANTE: Dialect ANalysis TEam",
    description="This is a demo of a classification and regression model for locating the italian dialect of a given text.",
    # Sample inputs shown under the textbox (one single-element list per
    # example, matching the single input component).
    examples=[
        ["Bisognerebbe saperli materializzare .... !!  Ma  ovviamente  .. belin .... NO SE PEU SCIUSCIA' E SCIORBI'"],
        ["Guaglio' Buongiorno! Azz! Vir te si scurdat puparuol e mulignane pero '!! E che se fa😑"],
        ["Il massimo...ghe ne minga par nisun"],
        ["Che poi a me la tuta piace na cifra da vede. Subisco un po' lo stigma sociale che noi con la fregna dovemo stà sempre apposto."],
    ],
)

# launch the interface only when this file is run as a script, so importing
# the module (e.g. from tests or another app) does not start a web server
if __name__ == "__main__":
    iface.launch()