Spaces:
Runtime error
Runtime error
SidneyChen
committed on
Commit
•
727e077
1
Parent(s):
a8b4549
Upload demo_0113.py
Browse files- demo_0113.py +228 -0
demo_0113.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""demo_0113.ipynb
|
3 |
+
|
4 |
+
Automatically generated by Colaboratory.
|
5 |
+
|
6 |
+
Original file is located at
|
7 |
+
https://colab.research.google.com/drive/1ge4fiA7yDzLAH4vl1LN4_3NxkbLGdKhz
|
8 |
+
"""
|
9 |
+
|
10 |
+
# Dependency: transformers.
# The notebook's `!pip install -qq transformers` is IPython shell magic and is
# a SyntaxError in a plain .py file; install the package ahead of time instead
# (e.g. list it in requirements.txt).
|
11 |
+
|
12 |
+
import pandas as pd
|
13 |
+
# from catboost import CatBoostClassifier
|
14 |
+
from sklearn.preprocessing import LabelEncoder
|
15 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
16 |
+
from wordcloud import WordCloud
|
17 |
+
from tqdm import tqdm
|
18 |
+
import nltk
|
19 |
+
from nltk.stem import WordNetLemmatizer
|
20 |
+
from nltk.corpus import stopwords
|
21 |
+
import re
|
22 |
+
from sklearn.model_selection import train_test_split
|
23 |
+
from sklearn.svm import SVC,LinearSVC
|
24 |
+
from sklearn.ensemble import RandomForestClassifier,GradientBoostingClassifier
|
25 |
+
from xgboost import XGBClassifier
|
26 |
+
import matplotlib.pyplot as plt
|
27 |
+
import seaborn as sns
|
28 |
+
from sklearn.metrics import accuracy_score
|
29 |
+
from sklearn.naive_bayes import MultinomialNB
|
30 |
+
from sklearn.experimental import enable_hist_gradient_boosting
|
31 |
+
from sklearn.ensemble import HistGradientBoostingClassifier
|
32 |
+
from imblearn.over_sampling import SMOTE
|
33 |
+
import plotly.express as px
|
34 |
+
import warnings
|
35 |
+
import torch
|
36 |
+
torch.backends.cudnn.benchmark = True
|
37 |
+
from torchvision import transforms, utils
|
38 |
+
import math
|
39 |
+
import random
|
40 |
+
import numpy as np
|
41 |
+
from torch import nn, autograd, optim
|
42 |
+
import numpy as np
|
43 |
+
import random
|
44 |
+
|
45 |
+
warnings.filterwarnings('ignore')
|
46 |
+
|
47 |
+
# Dependencies: openai, gradio.
# The notebook's `!pip install openai` / `!pip install gradio` lines are IPython
# shell magic and are a SyntaxError in a plain .py file; install the packages
# ahead of time instead (e.g. list them in requirements.txt).
|
50 |
+
|
51 |
+
import os
|
52 |
+
import openai
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
# Commented out IPython magic to ensure Python compatibility.
# Mount Google Drive only when the google.colab package is importable, so the
# script can still start when it is run outside a Colab runtime.
try:
    from google.colab import drive
except ImportError:
    drive = None
else:
    drive.mount("/content/drive", force_remount=True)

# Raw string: keeps the literal backslash (shell-escaped space used by the
# `%cd` magic below) without relying on an invalid "\ " escape sequence.
FOLDERNAME = r"Colab\ Notebooks/finalproject_test"
# %cd drive/MyDrive/$FOLDERNAME
|
61 |
+
|
62 |
+
import time
|
63 |
+
import pandas as pd
|
64 |
+
import numpy as np
|
65 |
+
import matplotlib.pyplot as plt
|
66 |
+
import seaborn as sns
|
67 |
+
from sklearn.model_selection import train_test_split
|
68 |
+
from sklearn.metrics import f1_score, accuracy_score
|
69 |
+
import os,re
|
70 |
+
import warnings
|
71 |
+
warnings.filterwarnings('ignore')
|
72 |
+
import nltk
|
73 |
+
from nltk.corpus import stopwords
|
74 |
+
from nltk.stem import PorterStemmer
|
75 |
+
from wordcloud import WordCloud
|
76 |
+
from tqdm import tqdm, trange
|
77 |
+
import torch
|
78 |
+
from torch.nn import BCEWithLogitsLoss
|
79 |
+
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
|
80 |
+
from transformers import BertTokenizer, BertForSequenceClassification
|
81 |
+
|
82 |
+
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
# (The original immediately overwrote this with the string 'cuda', which makes
# every later `.to(device)` call fail on a machine without a GPU.)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# map_location re-homes tensors that were saved from a GPU session onto `device`.
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted
# checkpoint files. Assumes 'mbti_model.pt' holds a whole pickled model object
# (predict() calls model.eval() and model(...) on it) - TODO confirm.
model = torch.load('mbti_model.pt', map_location=device)
model.to(device)

max_length = 512   # token length every input is padded / truncated to
threshold = 0.50   # sigmoid score above which an output counts as positive
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
|
89 |
+
|
90 |
+
# def clean_text(posts):
|
91 |
+
# clean = []
|
92 |
+
# # lower case
|
93 |
+
# posts = posts.lower()
|
94 |
+
# # remove email
|
95 |
+
# posts = re.sub(re.compile(r'\S+@\S+'), "", posts)
|
96 |
+
# # remove tag
|
97 |
+
# posts = re.sub(re.compile(r'@\S+'), "", posts)
|
98 |
+
# # remove '
|
99 |
+
# posts = re.sub(re.compile(r'\''), "", posts)
|
100 |
+
# # posts(|||)->list
|
101 |
+
# posts = posts.split('|||')
|
102 |
+
# # removing links and len(posts) > 5
|
103 |
+
# posts = [s for s in posts if not re.search(r'https?:\/\/[^\s<>"]+|www\.[^\s<>"]+', s) if len(s)>5]
|
104 |
+
# posts = [re.sub(r'\'', '', s) for s in posts]
|
105 |
+
# return posts
|
106 |
+
|
107 |
+
sentence = "Share some fun facts to break the ice"
|
108 |
+
|
109 |
+
# sentence = clean_text(sentence)
|
110 |
+
|
111 |
+
def data_preprocess(sentence):
    """Tokenize one sentence into model-ready tensors.

    The text is encoded with the module-level ``tokenizer`` and padded or
    truncated to the module-level ``max_length``.

    Args:
        sentence: Text to encode.

    Returns:
        A ``(input_ids, attention_mask)`` tuple of tensors, each reshaped to
        ``(1, max_length)`` so it can be passed as a batch of one.
    """
    encoding = tokenizer.encode_plus(
        sentence,
        max_length=max_length,
        padding='max_length',  # replaces the deprecated pad_to_max_length=True
        truncation=True,
    )

    # Tie the batch-of-one reshape to max_length rather than a literal 512 so
    # the two cannot drift apart.
    test_inputs = torch.tensor(encoding['input_ids']).reshape(1, max_length)
    test_masks = torch.tensor(encoding['attention_mask']).reshape(1, max_length)

    return test_inputs, test_masks
|
122 |
+
|
123 |
+
a, b = data_preprocess(sentence)
|
124 |
+
|
125 |
+
def predict(test_inputs, test_masks):
    """Run the module-level ``model`` on one encoded input and build a 4-letter code.

    Args:
        test_inputs: Input-id tensor; moved to ``device`` before the forward pass.
        test_masks: Attention-mask tensor; moved to ``device`` as well.

    Returns:
        A four-character string. Position ``i`` is chosen by comparing the
        sigmoid of output ``i`` of the first row against the module-level
        ``threshold``: 'E'/'I', 'S'/'N', 'T'/'F', 'J'/'P' (first letter when
        the score is above the threshold).

    Raises:
        IndexError: If the first output row has fewer than four values.
    """
    model.eval()
    with torch.no_grad():
        # forward pass
        test_inputs = test_inputs.to(device)
        test_masks = test_masks.to(device)
        outs = model(test_inputs, token_type_ids=None, attention_mask=test_masks)
        # outs[0] is used as the raw scores (presumably the logits - confirm
        # against the model's output type).
        scores = torch.sigmoid(outs[0]).to('cpu').numpy()

    # Only the first row is inspected; callers pass a batch of one.
    above = scores[0] > threshold

    # (letter when above threshold, letter otherwise) for each output position.
    letter_pairs = (('E', 'I'), ('S', 'N'), ('T', 'F'), ('J', 'P'))
    return ''.join(
        hit if above[i] else miss
        for i, (hit, miss) in enumerate(letter_pairs)
    )
|
161 |
+
|
162 |
+
predict(a, b)
|
163 |
+
|
164 |
+
import os
|
165 |
+
import openai
|
166 |
+
import gradio as gr
|
167 |
+
import random
|
168 |
+
|
169 |
+
# Read the OpenAI credential from the environment instead of committing it to
# the repository.
# NOTE(review): the key that was hard-coded on this line has been published in
# version control; treat it as compromised and revoke/rotate it.
openai.api_key = os.environ.get("OPENAI_API_KEY")
|
170 |
+
|
171 |
+
def translation(text):
    """Ask the OpenAI completion API to translate *text*.

    The prompt is the literal prefix "中翻英" ("Chinese to English") followed
    directly by *text*.

    Args:
        text: The text to append to the translation prompt.

    Returns:
        The ``text`` field of the first returned choice, with surrounding
        whitespace stripped.
    """
    request = {
        "model": "text-davinci-003",
        "prompt": f"中翻英{text}",
        "max_tokens": 500,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
    }
    completion = openai.Completion.create(**request)
    first_choice = completion['choices'][0]
    return first_choice['text'].strip()
|
183 |
+
|
184 |
+
def predict_mbti(description):
    """Translate a description, encode it, and return the predicted 4-letter code.

    Chains ``translation`` -> ``data_preprocess`` -> ``predict``.

    Args:
        description: Free-form text entered by the user.

    Returns:
        The string produced by ``predict`` for the translated text.
    """
    english_text = translation(description)
    input_ids, attention_mask = data_preprocess(english_text)
    return predict(input_ids, attention_mask)
|
189 |
+
|
190 |
+
# with gr.Blocks(css=".gradio-container {background-color: red}") as demo
|
191 |
+
# demo = gr.Interface(fn=predict_mbti, #callable function
|
192 |
+
# inputs=gr.inputs.Textbox(label = '讓我來分析你最近的人格><', placeholder = '個性描述、自己的故事或是曾經發過的文章'), #input format
|
193 |
+
# outputs=gr.outputs.Textbox(label = '只有我最了解你,你是一位...'),
|
194 |
+
# # outputs = [gr.outputs.Textbox(label = '只有我最了解你,你是一位...'), gr.outputs.Textbox(label = '專屬推薦給你的電影🍿')],
|
195 |
+
# title = "AI-MBTI knows U.",
|
196 |
+
# description = 'Come on. Let us predict your MBTI type !!! We will tell you what kind of movie should you watch !',
|
197 |
+
# theme = 'grass',
|
198 |
+
|
199 |
+
|
200 |
+
# ) #output format
|
201 |
+
|
202 |
+
blocks = gr.Blocks()

# Page layout: one free-text input, two survey radios, a spare textbox, two
# buttons in a row, and two output boxes.
with blocks as demo:
    desc = gr.Textbox(label = '讓我來分析你最近的人格📝', placeholder= '個性描述、自己的故事或是曾經發過的文章')
    # verb = gr.Radio(label = '請問有聽過16型人格測驗(16pernalities)嗎 /n https://www.16personalities.com/free-personality-test', ["有", "沒有"])
    survey = gr.Radio(["⭕️有聽過👂16型人格測驗(16pernalities)", "❌沒有聽過👂16型人格測驗(16pernalities)"],
                      label = '民意調查中...')
    survey2 = gr.Radio(["✅曾經做過✏️16型人格測驗(16pernalities)", "❎沒有做過✏️16型人格測驗(16pernalities)"],
                       label = '搜集民意中...')
    # Renamed from `object` so the builtin of that name is not shadowed.
    object_box = gr.Textbox(placeholder="object")

    with gr.Row():
        type_btn = gr.Button("16型人格類型👨👧👦")
        movie_btn = gr.Button("推薦專屬電影🍿")

    output1 = gr.Textbox(label="👉根據這段描述,你的16型人格類型🪢會是...")
    output2 = gr.Textbox(label="👉由你的描述與人格特質,適合你的電影🎦有...")

    # Only the personality-type button is wired up; movie_btn and output2 have
    # no handler yet.
    type_btn.click(predict_mbti, desc, output1)
    # movie_btn.click(None, [subject, verb, object], output2, _js="(s, v, o) => o + ' ' + v + ' ' + s")
    # # verb.change(lambda x: x, verb, output3, _js="(x) => [...x].reverse().join('')")
    # foo_bar_btn.click(None, [], subject, _js="(x) => x + ' foo'")

# display the interface
demo.launch(share=True, debug=True)
|
228 |
+
|