# azamat — commit 31ae1e9: "Final fix"
import re

import gradio as gr
import requests
import torch
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import pipeline
def process_tweet(tweet):
    """Normalize raw tweet text before it is fed to the models.

    Removes links (``www.``-prefixed and ``http(s)://``) and ``@username``
    mentions, collapses every whitespace run into a single space, keeps the
    word of each hashtag while dropping the ``#``, and trims surrounding
    spaces and quote characters.

    Args:
        tweet: Raw tweet text.

    Returns:
        The cleaned text (possibly empty).
    """
    # Remove links; both branches consume the rest of the token (\S+).
    tweet = re.sub(r'(www\.\S+)|(https?://\S+)', '', tweet)
    # Remove @usernames.
    tweet = re.sub(r'@\S+', '', tweet)
    # Collapse spaces, tabs and newlines into a single space.
    tweet = re.sub(r'\s+', ' ', tweet)
    # Replace hashtags with their bare word.
    tweet = re.sub(r'#(\S+)', r'\1', tweet)
    # Trim the spaces left behind by removed tokens together with wrapping
    # quotes; whitespace is already normalized to ' ' at this point.
    return tweet.strip(' \'"')
# Checkpoint used by predict_coordinates. The tokenizer and the model must be
# loaded from the same checkpoint, so the id is defined exactly once.
COORDINATES_MODEL_NAME = "azamat/geocoder_model_xlm_roberta_50"
# Checkpoint of the relevancy classifier used by predict_relevancy.
RELEVANCY_MODEL_NAME = "azamat/geocoder_model"

tokenizer = AutoTokenizer.from_pretrained(COORDINATES_MODEL_NAME)
relevancy_pipeline = pipeline("sentiment-analysis", model=RELEVANCY_MODEL_NAME)
coordinates_model = AutoModelForSequenceClassification.from_pretrained(
    COORDINATES_MODEL_NAME,
)
def predict_relevancy(text):
    """Run the relevancy classifier on *text*.

    Returns:
        Tuple ``(label, score)`` taken from the first prediction returned by
        ``relevancy_pipeline``.
    """
    top_prediction = relevancy_pipeline(text)[0]
    label = top_prediction['label']
    score = top_prediction['score']
    return label, score
def predict_coordinates(text):
    """Predict ``(lat, lon)`` for *text* with the coordinates model.

    The text is tokenized to exactly 128 tokens (padded / truncated) and the
    first two values of the model's first output for the single input row
    are read as latitude and longitude.

    Returns:
        Tuple ``(lat, lon)`` of floats rounded to 3 decimal places.
    """
    encoding = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors='pt',
    )
    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        outputs = coordinates_model(**encoding)
    # outputs[0] is the first model output (logits); row 0 is the only input.
    lat = outputs[0][0][0].item()
    lon = outputs[0][0][1].item()
    return round(lat, 3), round(lon, 3)
def reverse_geocode(lat, lon, timeout=10):
    """Look up a human-readable place name for the given coordinates.

    Queries the geocode.maps.co reverse-geocoding endpoint (zoom 12, English).

    Args:
        lat: Latitude.
        lon: Longitude.
        timeout: Seconds to wait for the service; without a timeout a stalled
            request would block the caller indefinitely.

    Returns:
        A message containing the service's ``display_name``, or a fallback
        message when the request fails or the response lacks that field.
    """
    payload = {
        'lat': lat,
        'lon': lon,
        'zoom': 12,
        'format': 'jsonv2',
        'accept-language': 'en',
    }
    try:
        response = requests.get(
            'https://geocode.maps.co/reverse', params=payload, timeout=timeout
        )
        response.raise_for_status()
        display_name = response.json()['display_name']
    except (requests.RequestException, ValueError, KeyError, TypeError):
        # Network errors, HTTP error statuses, non-JSON bodies and payloads
        # without a display_name all degrade to the same user-facing message.
        return "Service couldn't reverse geocode provided coordinates."
    return f"Reverse geocoded coordinates: {display_name}"
def predict(text):
    """Gradio handler: describe the geolocation content of a tweet.

    Cleans the text, classifies its relevancy and, when the label is
    ``'relevant'``, appends predicted coordinates and a reverse-geocoded
    place name.

    Returns:
        A human-readable multi-line message.
    """
    text = process_tweet(text)
    relevancy_label, relevancy_score = predict_relevancy(text)
    # Report confidence as a percentage rounded to 2 decimals in both branches.
    confidence = round(relevancy_score * 100, 2)
    if relevancy_label != 'relevant':
        return (
            f"Confident for {confidence}% that tweet does not have the "
            "geolocation relevant information."
        )
    lat, lon = predict_coordinates(text)
    reverse_geocoded = reverse_geocode(lat, lon)
    return (
        f"Confident for {confidence}% that tweet has the geolocation relevant information.\n"
        f"Predicted coordinates are: lat: {lat} lon: {lon}.\n"
        f"{reverse_geocoded}"
    )
# Web UI: one text box in, plain text out, handled by predict().
iface = gr.Interface(
    inputs="text",
    outputs="text",
    fn=predict,
)
iface.launch()