hanbin commited on
Commit
703f11a
·
1 Parent(s): a31e3cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -2
app.py CHANGED
@@ -1,4 +1,152 @@
 
 
 
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
1
+ # This file is .....
2
+ # Author: Hanbin Wang
3
+ # Date: 2023/4/18
4
+ import transformers
5
  import streamlit as st
6
+ from PIL import Image
7
+
8
+ from transformers import RobertaTokenizer, T5ForConditionalGeneration
9
+ from transformers import pipeline
10
+
11
+ @st.cache_resource
12
+ def get_model(model_path):
13
+ tokenizer = RobertaTokenizer.from_pretrained(model_path)
14
+ model = T5ForConditionalGeneration.from_pretrained(model_path)
15
+ model.eval()
16
+ return tokenizer, model
17
+
18
+
19
+ def main():
20
+ # `st.set_page_config` is used to display the default layout width, the title of the app, and the emoticon in the browser tab.
21
+
22
+ st.set_page_config(
23
+ layout="centered", page_title="MaMaL-Sum Demo(代码摘要)", page_icon="❄️"
24
+ )
25
+
26
+ c1, c2 ,c3 = st.columns([0.32, 2,0.5])
27
+
28
+ # The snowflake logo will be displayed in the first column, on the left.
29
+
30
+ with c1:
31
+ st.image(
32
+ "./panda27.png",
33
+ width=100,
34
+ )
35
+
36
+ # The heading will be on the right.
37
+
38
+ with c2:
39
+ st.caption("")
40
+ st.title("MaMaL-Sum(代码摘要)")
41
+
42
+
43
+ ############ SIDEBAR CONTENT ############
44
+
45
+ st.sidebar.image("./panda27.png",width=270)
46
+
47
+ st.sidebar.markdown("---")
48
+
49
+ st.sidebar.write(
50
+ """
51
+ ## 使用方法:
52
+ 在【输入】文本框输入想要解释的代码,点击【摘要】按钮,即会显示代码的自然语言描述。
53
+ """
54
+ )
55
+
56
+ st.sidebar.write(
57
+ """
58
+ ## 注意事项:
59
+ 1)APP托管在外网上,请确保您可以全局科学上网。
60
+
61
+ 2)您可以下载[MaMaL-Sum](https://huggingface.co/hanbin/MaMaL-Sum)模型,本地测试。(无需科学上网)
62
+ """
63
+ )
64
+ # For elements to be displayed in the sidebar, we need to add the sidebar element in the widget.
65
+
66
+ # We create a text input field for users to enter their API key.
67
+
68
+ # API_KEY = st.sidebar.text_input(
69
+ # "Enter your HuggingFace API key",
70
+ # help="Once you created you HuggingFace account, you can get your free API token in your settings page: https://huggingface.co/settings/tokens",
71
+ # type="password",
72
+ # )
73
+ #
74
+ # # Adding the HuggingFace API inference URL.
75
+ # API_URL = "https://api-inference.huggingface.co/models/valhalla/distilbart-mnli-12-3"
76
+ #
77
+ # # Now, let's create a Python dictionary to store the API headers.
78
+ # headers = {"Authorization": f"Bearer {API_KEY}"}
79
+
80
+
81
+ st.sidebar.markdown("---")
82
+
83
+
84
+ # Let's add some info about the app to the sidebar.
85
+
86
+ st.write(
87
+ "> **Tip:** 首次运行需要加载模型,可能需要一定的时间!"
88
+ )
89
+
90
+ st.write(
91
+ "> **Tip:** 左侧栏给出了一些good case 和 bad case,you can try it!"
92
+ )
93
+
94
+ st.sidebar.write(
95
+ "> **Good case:**"
96
+ )
97
+ code_good = """def svg_to_image(string, size=None):
98
+ if isinstance(string, unicode):
99
+ string = string.encode('utf-8')
100
+ renderer = QtSvg.QSvgRenderer(QtCore.QByteArray(string))
101
+ if not renderer.isValid():
102
+ raise ValueError('Invalid SVG data.')
103
+ if size is None:
104
+ size = renderer.defaultSize()
105
+ image = QtGui.QImage(size, QtGui.QImage.Format_ARGB32)
106
+ painter = QtGui.QPainter(image)
107
+ renderer.render(painter)
108
+ return image"""
109
+ st.sidebar.code(code_good, language='python')
110
+
111
+
112
+ st.sidebar.write(
113
+ "> **Bad cases:**"
114
+ )
115
+ code_bad = """from transformers import RobertaTokenizer, T5ForConditionalGeneration
116
+ from transformers import pipeline"""
117
+ st.sidebar.code(code_bad, language='python')
118
+
119
+ st.sidebar.write(
120
+ """
121
+ App 由 东北大学NLP课小组成员创建, 使用 [Streamlit](https://streamlit.io/)🎈 和 [HuggingFace](https://huggingface.co/inference-api)'s [MaMaL-Sum](https://huggingface.co/hanbin/MaMaL-Sum) 模型.
122
+ """
123
+ )
124
+
125
+ # model, tokenizer = load_model("hanbin/MaMaL-Gen")
126
+ st.write("### 输入:")
127
+ input = st.text_area("", height=200)
128
+ button = st.button('摘要')
129
+
130
+ tokenizer,model = get_model("hanbin/MaMaL-Sum")
131
+
132
+ input_ids = tokenizer(input, return_tensors="pt").input_ids
133
+ generated_ids = model.generate(input_ids, max_length=100)
134
+ output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
135
+ # generator = pipeline('text-generation', model="E:\DenseRetrievalGroup\CodeT5-base")
136
+ # output = generator(input)
137
+ # code = '''def hello():
138
+ # print("Hello, Streamlit!")'''
139
+ if button:
140
+ st.write("### 输出:")
141
+ st.code(output, language='python')
142
+ else:
143
+ st.write('')
144
+
145
+
146
+
147
+
148
+ if __name__ == '__main__':
149
+
150
+ main()
151
+
152