File size: 3,090 Bytes
ee0f4cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74d362a
 
 
 
 
 
ee0f4cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import streamlit as st
from transformers import (
    pipeline
)

# Earlier model experiments, kept for reference:
#   albert:  uer/albert-base-chinese-cluecorpussmall (fill-mask)
#   roberta: xlm-roberta-base / xlm-roberta-large (fill-mask)
#   rjx/chinese-new-text-classification-10200-albert-base-chinese-cluecorpussmall
#
# Streamlit re-executes this whole script on every widget interaction, so a
# bare module-level `pipeline(...)` call would rebuild the model on every form
# submit. st.cache_resource builds it once and reuses the same object on reruns.
@st.cache_resource
def load_pipeline():
    """Build the Hugging Face pipeline for the private rjx classifier model.

    The access token for the private model repo is read from
    ``st.secrets["read_key"]``. No task is passed, so transformers picks it
    from the model's configuration.
    """
    return pipeline(
        model="rjx/rjxai-xlm-roberta-longformer-1024-and-dataset-en-0523",
        # NOTE(review): newer transformers versions prefer `token=` over the
        # deprecated `use_auth_token=`; kept as-is for compatibility with the
        # version this Space was built against.
        use_auth_token=st.secrets["read_key"],
    )


# Module-level handle used by form_article_click.
pipe = load_pipeline()

# Seed the per-session keys with empty strings, but only when they are still
# missing so values written by form_article_click survive later reruns.
for _state_key in ("type", "article", "result"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""

# def charlength():
#     if len(st.session_state.article)>=512:
#         st.warning("article length is 512")

def form_article_click():
    """Classify the submitted article and store the outcome in session state.

    Registered as the form's submit ``on_click`` callback. The callback:

    * reads the text from ``st.session_state.article_key``;
    * runs it through ``pipe``;
    * writes an English display label to ``st.session_state.type``;
    * writes the raw pipeline output to ``st.session_state.result``.

    Empty or whitespace-only input only shows a warning.
    """
    # Character cut-off; presumably chosen to match the "1024" in the model
    # name. This counts characters, not tokens — TODO confirm it stays within
    # the model's token limit.
    max_chars = 1024
    # Model labels -> display text ("人类写的" = written by a human,
    # "ai写的" = written by AI).
    label_names = {"人类写的": "Human writing", "ai写的": "AI writing"}

    article = st.session_state.article_key
    # Treat whitespace-only input the same as empty input.
    if not article.strip():
        st.warning("Need to enter content")
        return

    # Slicing is a no-op for shorter strings, so no length check is needed.
    article = article[:max_chars]
    result = pipe(article)
    print(result)
    if result:
        label = result[0]["label"]
        # Fall back to the raw label so an unexpected label never leaves the
        # previous article's classification displayed.
        st.session_state.type = label_names.get(label, label)
    st.session_state.result = result

st.title("AI write or Human write v2")

col1, col2 = st.columns(2)

# Left column: article entry form.
col1.header("input")
with col1.form(key="article_form"):
    # Model picker, disabled until other models exist:
    # option_input = st.selectbox(
    #     'select rjx model:(Other model making)',
    #     ('chinese-new-text-classification-10200-albert-base-chinese-cluecorpussmall', 'other'),
    #     disabled=True,
    #     key='model_key',
    # )

    # Free-text input; the callback reads it back through st.session_state.article_key.
    article_input = st.text_area(
        'content',
        height=270,
        key='article_key',
    )
    # form_article_click is registered as the submit callback.
    submit_button = st.form_submit_button(label='submit', on_click=form_article_click)

# Right column: read-only view of the most recent classification.
col2.header("output")
col2.text_input('classification', st.session_state.type, disabled=True)
# Display of the raw pipeline output, currently disabled:
# col2.text_area('result', st.session_state.result, height=270, disabled=True)


# st.write('selected classification model:', option)


# file_name = st.file_uploader("Upload a hot dog candidate image")

# if file_name is not None:
#     col1, col2 = st.columns(2)

#     image = Image.open(file_name)
#     col1.image(image, use_column_width=True)
#     predictions = pipeline(image)

#     col2.header("Probabilities")
#     for p in predictions:
#         col2.subheader(f"{ p['label'] }: { round(p['score'] * 100, 1)}%")