yenniejun commited on
Commit
6fc21a4
Β·
1 Parent(s): a686b95

Going back to streamlit

Browse files
Files changed (2) hide show
  1. app.py +39 -38
  2. requirements.txt +4 -1
app.py CHANGED
@@ -8,49 +8,50 @@ HuggingFace Spaces that:
8
  # https://huggingface.co/blog/streamlit-spaces
9
  # https://huggingface.co/docs/hub/en/spaces-sdks-streamlit
10
 
11
- # https://www.gradio.app/docs/interface
12
- # https://huggingface.co/spaces/docs-demos/roberta-base/blob/main/app.py
13
-
14
  """
15
- import gradio as gr
 
 
16
  from string import punctuation
 
 
 
 
 
 
17
 
 
18
  title = "HanmunRoBERTa Century Classifier"
19
- description = "Century classifier for classical Chinese and Korean texts"
 
20
 
21
- # Load the HanmunRoBERTa model
22
- hanmun_roberta = gr.load("huggingface/bdsl/HanmunRoBERTa")
23
 
24
- def strip_text(inputtext):
 
 
 
 
 
25
  characters_to_remove = "β—‹β–‘()〔〕:\"。·, ?ㆍ" + punctuation
26
  translating = str.maketrans('', '', characters_to_remove)
27
- return inputtext.translate(translating)
28
-
29
- def inference(inputtext, model, strip_text_flag):
30
- if strip_text_flag:
31
- inputtext = strip_text(inputtext)
32
- if model == "HanmunRoBERTa":
33
- outlabel = hanmun_roberta(inputtext)
34
- return outlabel
35
-
36
- # Define some example inputs for your interface
37
- examples = [["Example text 1", "HanmunRoBERTa", True],
38
- ["Example text 2", "HanmunRoBERTa", True]]
39
-
40
- # Set up the Gradio interface
41
- gr.Interface(
42
- inference,
43
- [gr.inputs.Textbox(label="Input text", lines=10),
44
- gr.inputs.Dropdown(choices=["HanmunRoBERTa"],
45
- type="value",
46
- default="HanmunRoBERTa",
47
- label="Model"),
48
- gr.inputs.Checkbox(label="Remove punctuation")],
49
- [gr.outputs.Label(label="Output")],
50
- examples=examples,
51
- title=title,
52
- description=description).launch(enable_queue=True)
53
-
54
-
55
-
56
-
 
8
  # https://huggingface.co/blog/streamlit-spaces
9
  # https://huggingface.co/docs/hub/en/spaces-sdks-streamlit
10
 
 
 
 
11
  """
12
+
13
+ import streamlit as st
14
+ from transformers import pipeline
15
  from string import punctuation
16
+ import pandas as pd
17
+ # from huggingface_hub import InferenceClient
18
+ # client = InferenceClient(model="bdsl/HanmunRoBERTa")
19
+
20
+ # Load the pipeline with the HanmunRoBERTa model
21
+ model_pipeline = pipeline(task="text-classification", model="bdsl/HanmunRoBERTa")
22
 
23
+ # Streamlit app layout
24
  title = "HanmunRoBERTa Century Classifier"
25
+ st.title(title)
26
+ st.set_page_config(layout=layout, page_title=title, page_icon="πŸ“š")
27
 
28
+ # Checkbox to remove punctuation
29
+ remove_punct = st.checkbox(label="Remove punctuation", value=True)
30
 
31
+ # Text area for user input
32
+ input_str = st.text_area("Input text", height=275)
33
+
34
+ # Remove punctuation if checkbox is selected
35
+ if remove_punct and input_str:
36
+ # Specify the characters to remove
37
  characters_to_remove = "β—‹β–‘()〔〕:\"。·, ?ㆍ" + punctuation
38
  translating = str.maketrans('', '', characters_to_remove)
39
+ input_str = input_str.translate(translating)
40
+
41
+ # Display the input text after processing
42
+ st.write("Processed input:", input_str)
43
+
44
+ # Predict and display the classification scores if input is provided
45
+ if st.button("Classify"):
46
+ if input_str:
47
+ predictions = model_pipeline(input_str)
48
+
49
+ # Prepare the data for plotting
50
+ labels = [prediction['label'] for prediction in predictions]
51
+ scores = [prediction['score'] for prediction in predictions]
52
+ data = pd.DataFrame({"Label": labels, "Score": scores})
53
+
54
+ # Displaying predictions as a bar chart
55
+ st.bar_chart(data.set_index('Label'))
56
+ else:
57
+ st.write("Please enter some text to classify.")
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1 +1,4 @@
1
- gradio==3.50
 
 
 
 
1
+ streamlit
2
+ torch
3
+ transformers
4
+ pandas