amiguel committed on
Commit bdb3fd1
1 Parent(s): bfc8095

Create app.py

Files changed (1)
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
+ import streamlit as st
+ import torch
+ from transformers import GPT2Tokenizer
+ import pandas as pd
+
+ # Load the tokenizer
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+ # Define the classification function
+ def classify_review(text, model, tokenizer, device, max_length=None, pad_token_id=50256):
+     model.eval()
+
+     # Prepare inputs to the model
+     input_ids = tokenizer.encode(text)
+     supported_context_length = model.pos_emb.weight.shape[1]
+     if max_length is None:
+         max_length = supported_context_length
+
+     # Truncate sequences if they are too long
+     input_ids = input_ids[:min(max_length, supported_context_length)]
+
+     # Pad sequences to max_length
+     input_ids += [pad_token_id] * (max_length - len(input_ids))
+     input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)  # add batch dimension
+
+     # Model inference
+     with torch.no_grad():
+         logits = model(input_tensor)[:, -1, :]  # Logits of the last output token
+     predicted_label = torch.argmax(logits, dim=-1).item()
+
+     # Return the classified result
+     return "Proper Naming Notfcn" if predicted_label == 1 else "Wrong Naming Notificn"
+
+ # Set the device to run the model on (GPU if available, else CPU)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load the trained model from the local directory (a full pickled model object;
+ # newer PyTorch versions may additionally require weights_only=False here)
+ model_path = "clv__classifier_774M.pth"
+ model = torch.load(model_path, map_location=device)
+ model.to(device)
+ model.eval()
+
+ # Maximum sequence length for classification; train_dataset.max_length is not
+ # available in this app, so the model's context length is used instead
+ max_length = model.pos_emb.weight.shape[1]
+
+ # Streamlit app
+ def main():
+     st.title("Text Classification App")
+
+     # Input options
+     input_option = st.radio("Select input option", ("Single Text Query", "Upload Table"))
+
+     if input_option == "Single Text Query":
+         # Single text query input
+         text_query = st.text_input("Enter text query")
+         if st.button("Classify"):
+             if text_query:
+                 # Classify the text query
+                 predicted_label = classify_review(text_query, model, tokenizer, device, max_length=max_length)
+                 st.write("Predicted Label:")
+                 st.write(predicted_label)
+             else:
+                 st.warning("Please enter a text query.")
+
+     elif input_option == "Upload Table":
+         # Table upload
+         uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])
+         if uploaded_file is not None:
+             # Read the uploaded file
+             if uploaded_file.name.endswith(".csv"):
+                 df = pd.read_csv(uploaded_file)
+             else:
+                 df = pd.read_excel(uploaded_file)
+
+             # Select the text column
+             text_column = st.selectbox("Select the text column", df.columns)
+
+             # Classify the texts in the selected column
+             predicted_labels = []
+             for text in df[text_column]:
+                 predicted_label = classify_review(str(text), model, tokenizer, device, max_length=max_length)
+                 predicted_labels.append(predicted_label)
+
+             # Add the predicted labels to the DataFrame
+             df["Predicted Label"] = predicted_labels
+
+             # Display the DataFrame with predicted labels
+             st.write(df)
+
+ if __name__ == "__main__":
+     main()
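
The app itself is launched with `streamlit run app.py`. For reference, below is a minimal standalone sketch (not part of the commit) of how classify_review can be exercised outside Streamlit; it assumes clv__classifier_774M.pth is present locally and contains a full pickled GPT-style model with a pos_emb attribute, as app.py expects, and the sample query string is purely hypothetical.

# Standalone sanity check (hypothetical, not part of app.py)
import torch
from transformers import GPT2Tokenizer

from app import classify_review  # reuses the function defined in the commit above

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load("clv__classifier_774M.pth", map_location=device)  # assumes the checkpoint exists locally
model.to(device)
model.eval()

sample_text = "Replace pressure safety valve on unit 12"  # hypothetical query
print(classify_review(sample_text, model, tokenizer, device,
                      max_length=model.pos_emb.weight.shape[1]))

Note that importing app also runs its module-level setup (tokenizer and model loading), so the checkpoint must be available before the import.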