"""Gradio demo: zero-shot, multi-language text classification with Gemini Pro.

The user enters a comma-separated list of categories and a message. The model
is prompted to answer with a JSON object of per-category probabilities, which
is displayed in a ``gr.Label``.
"""

import json
import os

import google.generativeai as genai
import gradio as gr

GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
genai.configure(api_key=GOOGLE_API_KEY)

# Set up the model
generation_config = {
    "temperature": 0.9,
    "top_p": 1,
    "top_k": 1,
    "max_output_tokens": 2048,
}

safety_settings = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
]

model = genai.GenerativeModel(
    model_name="gemini-pro",
    generation_config=generation_config,
    safety_settings=safety_settings,
)

# Prompt template; the single ``{}`` placeholder receives the user's category
# list. Triple-quoted because the prompt spans several lines.
task_description = """ You are an SMS (Short Message Service) reader who reads every message that the short message service centre receives and you need to classify each message among the following categories: {}
Let the output be a softmax function output giving the probability of message belonging to each category.
The sum of the probabilities should be 1
The output must be in JSON format
"""


def _extract_probabilities(text):
    """Parse the JSON object embedded in *text* into ``{label: probability}``.

    The model may wrap its JSON in prose or Markdown code fences, so only the
    span from the first ``{`` to the last ``}`` is parsed.

    Args:
        text: Raw text returned by the model.

    Returns:
        dict mapping each label (as ``str``) to its score (as ``float``).

    Raises:
        ValueError: if no JSON object is present, the span is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError``), or a value cannot
            be converted to ``float``.
    """
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        raise ValueError("model output contains no JSON object")
    parsed = json.loads(text[start : end + 1])
    try:
        # gr.Label needs numeric confidences; the model sometimes emits strings.
        return {str(label): float(score) for label, score in parsed.items()}
    except (TypeError, ValueError) as err:
        raise ValueError("model output JSON values are not numeric") from err


def classify_msg(categories, message):
    """Classify *message* against the user-supplied *categories* with Gemini Pro.

    Args:
        categories: Comma-separated category names, as typed by the user.
        message: The text to classify.

    Returns:
        dict mapping each category label to the model's probability; this is
        the value format a ``gr.Label`` output accepts.

    Raises:
        gr.Error: if an input is empty, the model returns no text, or its
            output cannot be parsed into numeric probabilities. ``gr.Error``
            is shown to the user in the UI instead of a generic failure.
    """
    categories = (categories or "").strip()
    message = (message or "").strip()
    if not categories:
        raise gr.Error("Please enter at least one category.")
    if not message:
        raise gr.Error("Please enter a message to classify.")

    prompt_parts = [
        task_description.format(categories),
        f"Message: {message}",
        "Category: ",
    ]
    response = model.generate_content(prompt_parts)
    try:
        # The SDK's ``text`` accessor raises ValueError when the response has
        # no text part (e.g. the reply was blocked by ``safety_settings``).
        text = response.text
    except ValueError as err:
        raise gr.Error(
            "The model returned no text; the request may have been blocked "
            "by the safety settings."
        ) from err
    try:
        return _extract_probabilities(text)
    except ValueError as err:
        raise gr.Error(
            f"Could not parse the model output as JSON probabilities: {err}"
        ) from err


def clear_inputs_and_outputs():
    """Reset the categories box, the message box and the prediction label."""
    return [None, None, None]


with gr.Blocks() as demo:
    gr.Markdown(
        """
Multi-language Text Classifier using Gemini Pro

This space uses Gemini Pro in order to classify texts.
Depending on the list of categories that you specify, you can have a text classifier, a SPAM detector, a sentiment classifier, ...

For the categories, enter a list of words separated by commas
"""
    )
    with gr.Row():
        with gr.Column():
            with gr.Row():
                categories = gr.Textbox(
                    label="Categories",
                    placeholder="Input the list of categories as comma separated words",
                )
            with gr.Row():
                message = gr.Textbox(label="Message", placeholder="Enter Message")
            with gr.Row():
                clr_btn = gr.Button(value="Clear", variant="secondary")
                csf_btn = gr.Button(value="Classify")
        with gr.Column():
            lbl_output = gr.Label(label="Prediction")

    clr_btn.click(
        fn=clear_inputs_and_outputs,
        inputs=[],
        outputs=[categories, message, lbl_output],
    )
    csf_btn.click(
        fn=classify_msg,
        inputs=[categories, message],
        outputs=[lbl_output],
    )
    # NOTE: cache_examples=True calls the model for every example when the
    # cache is built, so a valid GOOGLE_API_KEY is needed at that point.
    gr.Examples(
        examples=[
            ["Normal, Promotional, Urgent", "Will you be passing by?"],
            ["Spam, Ham", "Plus de 300 % de perte de poids pendant le régime."],
            ["Χαρούμενος, Δυστυχισμένος", "Η εξυπηρέτηση σας ήταν απαίσια"],
            ["مهم، أقل أهمية ", "خبر عاجل"],
        ],
        inputs=[categories, message],
        outputs=lbl_output,
        fn=classify_msg,
        cache_examples=True,
    )

demo.queue(api_open=False)

if __name__ == "__main__":
    demo.launch(debug=True, share=True, show_api=False)