File size: 1,684 Bytes
2a36adb
 
 
 
 
 
 
 
 
 
 
 
 
6e8a510
 
 
 
1268880
 
 
 
6e8a510
 
1268880
 
6e8a510
2a36adb
 
 
03d1cd4
8700bd0
 
 
4645f37
8700bd0
4645f37
 
 
03d1cd4
2a36adb
 
b6e5fba
2a36adb
 
b6e5fba
cc84955
2a36adb
 
 
1268880
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import gradio as gr
import numpy as np
import pandas as pd
import tensorflow as tf
from transformers import XLNetTokenizer, TFXLNetModel

# Read the tweet dataset and pull out the text and label columns
# (labels presumably 0/1 — confirm against the CSV).
# NOTE(review): df, text_data and label_data are not referenced by the
# prediction code in this file; reading the CSV only adds start-up cost.
df = pd.read_csv("disaster_tweet.csv")  # Update the filename here
text_data, label_data = df["text"], df["target"]

# Pretrained XLNet tokenizer (used by predict_disaster_tweet) and base model.
# NOTE(review): xlnet_model is not referenced below — the classifier is loaded
# from xlnet_model.h5 instead — so this appears to be an unused download.
xlnet_tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
xlnet_model = TFXLNetModel.from_pretrained('xlnet-base-cased')

# Load the fine-tuned Keras classifier with TFXLNetModel registered as a custom object.
def load_model_with_custom_objects(path="xlnet_model.h5"):
    """Load the saved Keras classifier inside a custom-object scope.

    ``TFXLNetModel`` is registered under the name ``"TFXLNetModel"`` while the
    model is deserialized (presumably because the saved graph embeds an XLNet
    layer by that class name — confirm against the training script).

    Args:
        path: Location of the saved Keras model. Defaults to
            ``"xlnet_model.h5"``, the file this app has always loaded.

    Returns:
        The deserialized ``tf.keras`` model.
    """
    with tf.keras.utils.custom_object_scope({"TFXLNetModel": TFXLNetModel}):
        return tf.keras.models.load_model(path)


# Load the classifier once at import time; every prediction reuses it.
model = load_model_with_custom_objects()

# Define function to predict disaster tweet
def predict_disaster_tweet(text):
    """Classify a single tweet as disaster-related or not.

    Args:
        text: Raw tweet text, as supplied by the Gradio text box.

    Returns:
        ``"Disaster"`` when the model score is >= 0.5, ``"Non-Disaster"``
        when it is below 0.5, and ``"Uncertain"`` when the input is blank or
        the model produced a NaN score.

    Raises:
        ValueError: If the model returns anything other than exactly one
            score for the single input.
    """
    # Nothing to classify: don't run the model on a sequence of pure padding.
    if text is None or not str(text).strip():
        return "Uncertain"

    # Use the tokenizer's own attention mask so the positions added by
    # padding="max_length" are masked out (an all-ones mask would make the
    # model attend to padding). truncation=True keeps long tweets at exactly
    # max_length tokens, presumably the length the classifier was trained on.
    encoded = xlnet_tokenizer(
        text,
        add_special_tokens=True,
        max_length=100,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="tf",
    )
    pred = model.predict(
        [encoded["input_ids"], encoded["attention_mask"]], verbose=0
    )

    # Reduce the (batch-of-one) prediction to a scalar explicitly instead of
    # relying on the truthiness of a one-element array.
    scores = np.asarray(pred).ravel()
    if scores.size != 1:
        raise ValueError(
            f"expected a single score from the model, got shape {np.shape(pred)}"
        )
    score = float(scores[0])

    # A NaN score previously fell through to "Non-Disaster"; report it instead.
    if np.isnan(score):
        return "Uncertain"
    return "Disaster" if score >= 0.5 else "Non-Disaster"

# Wrap the classifier in a minimal web UI: one free-text input, one text
# output carrying the predicted label.
iface = gr.Interface(
    predict_disaster_tweet,
    "text",
    "text",
    description="Enter a tweet and get prediction whether it's a (Disaster, Non-Disaster, or Uncertain).",
    title="Disaster Tweet Prediction",
)

# Start serving the app; inline=False asks Gradio not to render the UI inline
# (relevant when run from a notebook).
iface.launch(inline=False)