Malikosama1 commited on
Commit
893f310
·
1 Parent(s): 0ad240f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -0
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import gradio as gr
4
+ import google.generativeai as genai
5
+
6
+ GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
7
+ genai.configure(api_key=GOOGLE_API_KEY)
8
+
9
+ # Set up the model
10
+ generation_config = {
11
+ "temperature": 0.9,
12
+ "top_p": 1,
13
+ "top_k": 1,
14
+ "max_output_tokens": 2048,
15
+ }
16
+
17
+ safety_settings = [
18
+ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
19
+ {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
20
+ {
21
+ "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
22
+ "threshold": "BLOCK_MEDIUM_AND_ABOVE",
23
+ },
24
+ {
25
+ "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
26
+ "threshold": "BLOCK_MEDIUM_AND_ABOVE",
27
+ },
28
+ ]
29
+
30
+ model = genai.GenerativeModel(
31
+ model_name="gemini-pro",
32
+ generation_config=generation_config,
33
+ safety_settings=safety_settings,
34
+ )
35
+
36
+ task_description = " You are an SMS (Short Message Service) reader who reads every message that the short message service centre receives and you need to classify each message among the following categories: {}<div>Let the output be a softmax function output giving the probability of message belonging to each category.</div><div>The sum of the probabilities should be 1</div><div>The output must be in JSON format</div>"
37
+
38
+
39
+ def classify_msg(categories, message):
40
+ prompt_parts = [
41
+ task_description.format(categories),
42
+ f"Message: {message}",
43
+ "Category: ",
44
+ ]
45
+
46
+ response = model.generate_content(prompt_parts)
47
+
48
+ json_response = json.loads(
49
+ response.text[response.text.find("{") : response.text.rfind("}") + 1]
50
+ )
51
+
52
+ return gr.Label(json_response)
53
+
54
+
55
+ def clear_inputs_and_outputs():
56
+ return [None, None, None]
57
+
58
+
59
+ with gr.Blocks() as demo:
60
+ gr.Markdown(
61
+ """
62
+ <h1 align="center">Multi-language Text Classifier using Gemini Pro</h1> \
63
+ This space uses Gemini Pro in order to classify texts.<br> \
64
+ Depending on the list of categories that you specify, you can have text classifier, a SPAM detector, a sentiment classifier, ... <br><br> \
65
+ <b>For the categories, enter a list of words separated by commas</b><br><br>
66
+ """
67
+ )
68
+ with gr.Row():
69
+ with gr.Column():
70
+ with gr.Row():
71
+ categories = gr.Textbox(
72
+ label="Categories",
73
+ placeholder="Input the list of categories as comma separated words",
74
+ )
75
+ with gr.Row():
76
+ message = gr.Textbox(label="Message", placeholder="Enter Message")
77
+ with gr.Row():
78
+ clr_btn = gr.Button(value="Clear", variant="secondary")
79
+ csf_btn = gr.Button(value="Classify")
80
+ with gr.Column():
81
+ lbl_output = gr.Label(label="Prediction")
82
+
83
+ clr_btn.click(
84
+ fn=clear_inputs_and_outputs,
85
+ inputs=[],
86
+ outputs=[categories, message, lbl_output],
87
+ )
88
+ csf_btn.click(
89
+ fn=classify_msg,
90
+ inputs=[categories, message],
91
+ outputs=[lbl_output],
92
+ )
93
+
94
+ gr.Examples(
95
+ examples=[
96
+ ["Normal, Promotional, Urgent", "Will you be passing by?"],
97
+ ["Spam, Ham", "Plus de 300 % de perte de poids pendant le régime."],
98
+ ["Χαρούμενος, Δυστυχισμένος", "Η εξυπηρέτηση σας ήταν απαίσια"],
99
+ ["مهم، أقل أهمية ", "خبر عاجل"],
100
+ ],
101
+ inputs=[categories, message],
102
+ outputs=lbl_output,
103
+ fn=classify_msg,
104
+ cache_examples=True,
105
+ )
106
+
107
+ demo.queue(api_open=False)
108
+ demo.launch(debug=True, share=True, show_api=False)