SonFox2920 committed on
Commit
2b377c2
1 Parent(s): ee41e0e

Update app.py

Files changed (1): app.py +90 -34
app.py CHANGED
@@ -1,35 +1,91 @@
  import streamlit as st
- import pandas as pd
- import matplotlib.pyplot as plt
- import seaborn as sns
-
- # Create a sidebar for uploading the file
- st.sidebar.title("Upload Dataset")
- uploaded_file = st.sidebar.file_uploader("Choose a CSV file", type=["csv"])
-
- # Check whether a file has been uploaded
- if uploaded_file is not None:
-     # Read the data from the CSV file
-     df = pd.read_csv(uploaded_file)
-
-     # Display the data
-     st.subheader("Data from the CSV file")
-     st.write(df)
-
-     # Label count statistics
-     st.subheader("Label count statistics")
-     label_counts = df['label_id'].value_counts()
-
-     # Display the number of free_text rows for each label
-     st.write("Number of HATE (2) labels:", label_counts.get(2, 0))
-     st.write("Number of CLEAN (0) labels:", label_counts.get(0, 0))
-     st.write("Number of OFFENSIVE (1) labels:", label_counts.get(1, 0))
-
-     # Display the statistics chart
-     st.subheader("Statistics chart")
-     fig, ax = plt.subplots(figsize=(8, 5))
-     sns.countplot(x='label_id', data=df, ax=ax)
-     st.pyplot(fig)
-
- else:
-     st.warning("Please upload a CSV file.")
+ from pyspark.sql import SparkSession
+ from pyspark.ml.pipeline import Pipeline, PipelineModel
+ from pyspark.sql.types import *
+ from pyspark.sql.functions import *
+
+ from pyspark.sql import DataFrame
+ from pyspark import keyword_only
+ from pyspark.ml import Transformer
+ from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param, Params, TypeConverters
+ from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
+
+ import re
+ import string
+
+ # Start a local Spark session
+ spark = SparkSession.builder\
+     .appName("HateSpeechDetection")\
+     .master('local[*]')\
+     .getOrCreate()
+
+ # Custom text-preprocessing transformer used as a stage of the saved pipeline
+ class TextTransformer(Transformer, HasInputCol, HasOutputCol, DefaultParamsReadable, DefaultParamsWritable):
+     input_col = Param(Params._dummy(), "input_col", "input column name.", typeConverter=TypeConverters.toString)
+     output_col = Param(Params._dummy(), "output_col", "output column name.", typeConverter=TypeConverters.toString)
+
+     @keyword_only
+     def __init__(self, input_col: str = "input", output_col: str = "output"):
+         super(TextTransformer, self).__init__()
+         self._setDefault(input_col=None, output_col=None)
+         kwargs = self._input_kwargs
+         self.set_params(**kwargs)
+
+     @keyword_only
+     def set_params(self, input_col: str = "input", output_col: str = "output"):
+         kwargs = self._input_kwargs
+         self._set(**kwargs)
+
+     def get_input_col(self):
+         return self.getOrDefault(self.input_col)
+
+     def get_output_col(self):
+         return self.getOrDefault(self.output_col)
+
+     def _transform(self, df: DataFrame):
+         def preprocess_text(text) -> str:
+             # Remove digits, replace punctuation with spaces, then trim and lowercase
+             text = re.sub(r'\d+', '', str(text))
+             text = text.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation)))
+             return text.strip().lower()
+         input_col = self.get_input_col()
+         output_col = self.get_output_col()
+         # Apply the preprocessing function to the input column as a string UDF
+         transform_udf = udf(preprocess_text, StringType())
+         return df.withColumn(output_col, transform_udf(input_col))
+
+ # Load the pre-trained model (after TextTransformer is defined, so the
+ # custom pipeline stage can be deserialized)
+ loaded_model = PipelineModel.load('LogisticRegression')
+
+ # Create a Streamlit app
+ def main():
+     st.title("Text Classification App")
+
+     # User input text
+     user_input = st.text_area("Enter text here:")
+
+     if st.button("Predict"):
+         if user_input:
+             # Create a DataFrame with a single column 'free_text' containing the input text
+             data = [(user_input,)]
+             columns = ['free_text']
+             input_df = spark.createDataFrame(data, columns)
+
+             # Use the loaded model to make predictions
+             predictions = loaded_model.transform(input_df)
+
+             # Extract the prediction result (a double, e.g. 1.0)
+             result = predictions.select("prediction").collect()[0]["prediction"]
+
+             # Map the prediction result to the corresponding label
+             labels = {0: "CLEAN", 1: "OFFENSIVE", 2: "HATE"}
+             predicted_class = labels.get(int(result), "UNKNOWN")
+
+             # Display the result
+             st.success(f"Predicted class: {predicted_class}")
+         else:
+             st.warning("Please enter some text.")
+
+ if __name__ == "__main__":
+     main()
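
For a quick local sanity check of the preprocessing stage, something like the following should work once the SparkSession and TextTransformer from app.py are in scope. This is a minimal sketch, not part of the commit; the 'free_text' and 'clean_text' column names are illustrative.

    # Illustrative check of TextTransformer (assumes `spark` and TextTransformer from app.py)
    sample = spark.createDataFrame([("Bài viết SỐ 1, có dấu câu!",)], ["free_text"])
    cleaned = TextTransformer(input_col="free_text", output_col="clean_text").transform(sample)
    cleaned.show(truncate=False)  # digits and punctuation stripped, text lowercased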