daspartho commited on
Commit
8815850
1 Parent(s): 2da7cfd
Files changed (3) hide show
  1. app.py +24 -0
  2. labels.bin +3 -0
  3. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline
3
+ import pickle
4
+
5
+ tokenizer = AutoTokenizer.from_pretrained("daspartho/subreddit-predictor")
6
+ model = AutoModelForSequenceClassification.from_pretrained("daspartho/subreddit-predictor") # i've uploaded the model on HuggingFace :)
7
+
8
+ with open('labels.bin', 'rb') as f:
9
+ label_map = pickle.load(f)
10
+
11
+ pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, top_k=3)
12
+
13
+ def classify_text(plot):
14
+ predictions = pipe(plot)[0]
15
+ return {label_map[pred['label']]: float(pred['score']) for pred in predictions}
16
+
17
+ iface = gr.Interface(
18
+ description = "Enter a title for a reddit post, and the model will attempt to predict the subreddit.",
19
+ article = "<p style='text-align: center'><a href='https://github.com/daspartho/predict-subreddit' target='_blank'>Github</a></p>",
20
+ fn=classify_text,
21
+ inputs=gr.inputs.Textbox(label="Type the title here"),
22
+ outputs=gr.outputs.Label(label='What the model thinks'),
23
+ )
24
+ iface.launch()
labels.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7c51566ebf1c3c795393b317d2105a51d75e81e761dc28a04430e1e4c6063ec
3
+ size 4003
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers
2
+ torch
3
+ gradio