yenniejun commited on
Commit
c9a9ab8
Β·
1 Parent(s): 8ffb5c6

Adding plotly plot

Browse files
Files changed (1) hide show
  1. app.py +45 -51
app.py CHANGED
@@ -29,61 +29,55 @@ title = "HanmunRoBERTa Century Classifier"
29
  st.set_page_config(page_title=title, page_icon="πŸ“š")
30
  st.title(title)
31
 
32
- # Create a two-column layout
33
- col1, col2 = st.columns([2, 3]) # Adjust the width ratio as needed
34
 
35
- with col1:
36
- # Checkbox to remove punctuation
37
- remove_punct = st.checkbox(label="Remove punctuation", value=True)
38
 
39
- # Text area for user input
40
- input_str = st.text_area("Input text", height=275)
 
 
 
 
41
 
42
- # Remove punctuation if checkbox is selected
43
- if remove_punct and input_str:
44
- # Specify the characters to remove
45
- characters_to_remove = "β—‹β–‘()〔〕:\"。·, ?ㆍ" + punctuation
46
- translating = str.maketrans('', '', characters_to_remove)
47
- input_str = input_str.translate(translating)
48
 
49
- # Display the input text after processing
50
- st.write("Processed input:", input_str)
 
 
 
 
 
51
 
52
- # Button for prediction
53
- classify_button = st.button("Classify")
 
 
 
 
 
 
 
 
 
 
54
 
55
- # Predict and display the classification scores if input is provided and button is clicked
56
- if classify_button and input_str:
57
- predictions = model_pipeline(input_str)
58
- data = pd.DataFrame(predictions)
59
- data = data.sort_values(by='score', ascending=True)
60
- data.label = data.label.astype(str)
61
-
62
- # Displaying predictions as a bar chart
63
- fig = go.Figure(
64
- go.Bar(
65
- x=data.score.values,
66
- y=[f'{i}th Century' for i in data.label.values],
67
- orientation='h',
68
- text=[f'{score:.3f}' for score in data['score'].values],
69
- textposition='outside',
70
- hoverinfo='text',
71
- hovertext=[f'{i}th Century<br>Score: {score:.3f}' for i, score in zip(data['label'], data['score'])],
72
- marker=dict(color=[colors[i % len(colors)] for i in range(len(data))]),
73
- ))
74
- fig.update_traces(width=0.4)
75
- fig.update_layout(
76
- height=300, # Custom height
77
- xaxis_title='Score',
78
- yaxis_title='',
79
- title='Model predictions and scores',
80
- margin=dict(l=100, r=200, t=50, b=50),
81
- uniformtext_minsize=8,
82
- uniformtext_mode='hide',
83
- )
84
-
85
- with col2:
86
  st.plotly_chart(figure_or_data=fig, use_container_width=True)
87
- else:
88
- with col2:
89
- st.write("Please enter some text to classify and click 'Classify'.")
 
29
  st.set_page_config(page_title=title, page_icon="πŸ“š")
30
  st.title(title)
31
 
32
+ # Checkbox to remove punctuation
33
+ remove_punct = st.checkbox(label="Remove punctuation", value=True)
34
 
35
+ # Text area for user input
36
+ input_str = st.text_area("Input text", height=275)
 
37
 
38
+ # Remove punctuation if checkbox is selected
39
+ if remove_punct and input_str:
40
+ # Specify the characters to remove
41
+ characters_to_remove = "β—‹β–‘()〔〕:\"。·, ?ㆍ" + punctuation
42
+ translating = str.maketrans('', '', characters_to_remove)
43
+ input_str = input_str.translate(translating)
44
 
45
+ # Display the input text after processing
46
+ st.write("Processed input:", input_str)
 
 
 
 
47
 
48
+ # Predict and display the classification scores if input is provided
49
+ if st.button("Classify"):
50
+ if input_str:
51
+ predictions = model_pipeline(input_str)
52
+ data = pd.DataFrame(predictions)
53
+ data=data.sort_values(by='score', ascending=True)
54
+ data.label = data.label.astype(str)
55
 
56
+
57
+ # Displaying predictions as a bar chart
58
+ fig = go.Figure(
59
+ go.Bar(
60
+ x=data.score.values,
61
+ y=[f'{i}th Century' for i in data.label.values],
62
+ orientation='h',
63
+ text=[f'{score:.3f}' for score in data['score'].values], # Format text with 2 decimal points
64
+ textposition='outside', # Position the text outside the bars
65
+ hoverinfo='text', # Use custom text for hover info
66
+ hovertext=[f'{i}th Century<br>Score: {score:.3f}' for i, score in zip(data['label'], data['score'])], # Custom hover text
67
+ marker=dict(color=[colors[i % len(colors)] for i in range(len(data))]), # Cycle through colors
68
 
69
+ ))
70
+ fig.update_traces(width=0.4)
71
+
72
+ fig.update_layout(
73
+ height=300, # Custom height
74
+ xaxis_title='Score',
75
+ yaxis_title='',
76
+ title='Model predictions and scores',
77
+ margin=dict(l=100, r=200, t=50, b=50),
78
+ uniformtext_minsize=8,
79
+ uniformtext_mode='hide',
80
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  st.plotly_chart(figure_or_data=fig, use_container_width=True)
82
+ else:
83
+ st.write("Please enter some text to classify.")