yenniejun commited on
Commit
d3dbd6c
Β·
1 Parent(s): 644276e

Adding plotly plot

Browse files
Files changed (2) hide show
  1. app.py +74 -24
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,33 +1,83 @@
1
- import gradio as gr
 
 
 
 
 
2
 
3
- title = "RoBERTa"
 
4
 
5
- description = "Gradio Demo for RoBERTa. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
6
 
7
- article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1907.11692' target='_blank'>RoBERTa: A Robustly Optimized BERT Pretraining Approach</a></p>"
 
 
 
 
 
 
8
 
9
- examples = [
10
- ['The goal of life is <mask>.','roberta-base']
11
- ]
12
 
13
- io1 = gr.Interface.load("huggingface/roberta-base")
 
14
 
15
- io2 = gr.Interface.load("huggingface/roberta-large")
 
 
 
16
 
 
 
17
 
18
- def inference(inputtext, model):
19
- if model == "roberta-base":
20
- outlabel = io1(inputtext)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  else:
22
- outlabel = io2(inputtext)
23
- return outlabel
24
-
25
-
26
- gr.Interface(
27
- inference,
28
- [gr.Textbox(label="Context",lines=10),gr.Dropdown(choices=["roberta-base","roberta-large"], type="value", label="model")],
29
- [gr.Label(label="Output")],
30
- examples=examples,
31
- article=article,
32
- title=title,
33
- description=description).launch(enable_queue=True)
 
1
+ """
2
+ HuggingFace Spaces that:
3
+ - loads in HanmunRoBERTa model https://huggingface.co/bdsl/HanmunRoBERTa
4
+ - optionally strips text of punctuation and unwanted charactesr
5
+ - predicts century for the input text
6
+ - Visualizes prediction scores for each century
7
 
8
+ # https://huggingface.co/blog/streamlit-spaces
9
+ # https://huggingface.co/docs/hub/en/spaces-sdks-streamlit
10
 
11
+ """
12
 
13
+ import streamlit as st
14
+ from transformers import pipeline
15
+ from string import punctuation
16
+ import pandas as pd
17
+ import plotly.express as px
18
+ import plotly.graph_objects as go
19
+ colors = px.colors.qualitative.Plotly
20
 
21
+ # from huggingface_hub import InferenceClient
22
+ # client = InferenceClient(model="bdsl/HanmunRoBERTa")
 
23
 
24
+ # Load the pipeline with the HanmunRoBERTa model
25
+ model_pipeline = pipeline(task="text-classification", model="bdsl/HanmunRoBERTa")
26
 
27
+ # Streamlit app layout
28
+ title = "HanmunRoBERTa Century Classifier"
29
+ st.set_page_config(page_title=title, page_icon="πŸ“š")
30
+ st.title(title)
31
 
32
+ # Checkbox to remove punctuation
33
+ remove_punct = st.checkbox(label="Remove punctuation", value=True)
34
 
35
+ # Text area for user input
36
+ input_str = st.text_area("Input text", height=275)
37
+
38
+ # Remove punctuation if checkbox is selected
39
+ if remove_punct and input_str:
40
+ # Specify the characters to remove
41
+ characters_to_remove = "β—‹β–‘()〔〕:\"。·, ?ㆍ" + punctuation
42
+ translating = str.maketrans('', '', characters_to_remove)
43
+ input_str = input_str.translate(translating)
44
+
45
+ # Display the input text after processing
46
+ st.write("Processed input:", input_str)
47
+
48
+ # Predict and display the classification scores if input is provided
49
+ if st.button("Classify"):
50
+ if input_str:
51
+ predictions = model_pipeline(input_str)
52
+ data = pd.DataFrame(predictions)
53
+ data=data.sort_values(by='score', ascending=True)
54
+ data.label = data.label.astype(str)
55
+
56
+
57
+ # Displaying predictions as a bar chart
58
+ fig = go.Figure(
59
+ go.Bar(
60
+ x=data.score.values,
61
+ y=[f'{i}th Century' for i in data.label.values],
62
+ orientation='h',
63
+ text=[f'{score:.3f}' for score in data['score'].values], # Format text with 2 decimal points
64
+ textposition='outside', # Position the text outside the bars
65
+ hoverinfo='text', # Use custom text for hover info
66
+ hovertext=[f'{i}th Century<br>Score: {score:.3f}' for i, score in zip(data['label'], data['score'])], # Custom hover text
67
+ marker=dict(color=[colors[i % len(colors)] for i in range(len(data))]), # Cycle through colors
68
+
69
+ ))
70
+ fig.update_traces(width=0.4)
71
+
72
+ fig.update_layout(
73
+ height=300, # Custom height
74
+ xaxis_title='Score',
75
+ yaxis_title='',
76
+ title='Model predictions and scores',
77
+ margin=dict(l=100, r=200, t=50, b=50),
78
+ uniformtext_minsize=8,
79
+ uniformtext_mode='hide',
80
+ )
81
+ st.pyplot(fig=fig)
82
  else:
83
+ st.write("Please enter some text to classify.")
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  streamlit
2
  torch
3
  transformers
4
- pandas
 
 
1
  streamlit
2
  torch
3
  transformers
4
+ pandas
5
+ plotly