# NOTE: "Spaces: Running" is a Hugging Face Spaces status banner captured by
# the page extraction — it is not part of the program.
""" | |
License: | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
In no event shall the authors or copyright holders be liable | |
for any claim, damages or other liability, whether in an action of contract,otherwise, | |
arising from, out of or in connection with the software or the use or | |
other dealings in the software. | |
Copyright (c) 2024 pi19404. All rights reserved. | |
Authors: | |
pi19404 <pi19404@gmail.com> | |
""" | |
""" | |
Gradio Interface for Shield Gemma LLM Evaluator | |
This module provides a Gradio interface to interact with the Shield Gemma LLM Evaluator. | |
It allows users to input JSON data and select various options to evaluate the content | |
for policy violations. | |
Functions: | |
my_inference_function: The main inference function to process input data and return results. | |
""" | |
import gradio as gr | |
from gradio_client import Client | |
import torch | |
import json | |
import threading | |
import os | |
# Hugging Face access token for the remote worker Space; read from the
# environment so the secret never lives in source control.
API_TOKEN = os.environ.get("API_TOKEN")

# Serializes calls to the shared client across Gradio's request threads.
lock = threading.Lock()

# Client for the "pi19404/ai-worker" Space that actually hosts the model.
client = Client("pi19404/ai-worker", hf_token=API_TOKEN)
def my_inference_function(input_data, output_data, mode, max_length, max_new_tokens, model_size):
    """
    Forward an evaluation request to the remote Shield Gemma worker.

    The remote call is serialized with the module-level ``lock`` because the
    shared ``client`` is reused across Gradio's concurrent request threads.

    Args:
        input_data (str or dict): The input data in JSON format.
        output_data (str or dict): The model response text to evaluate
            alongside the input (e.g. for prompt/response scoring).
        mode (str): The mode of operation ("scoring" or "generative").
        max_length (int): The maximum length of the input prompt.
        max_new_tokens (int): The maximum number of new tokens to generate.
        model_size (str): The size of the model to be used ("2B", "9B" or "27B").

    Returns:
        str: The worker's result on success, or a JSON-encoded error
        message ({"error": ...}) when the request fails.
    """
    with lock:
        try:
            result = client.predict(
                input_data=input_data,
                output_data=output_data,
                mode=mode,
                max_length=max_length,
                max_new_tokens=max_new_tokens,
                model_size=model_size,
                api_name="/my_inference_function",
            )
            # Return the worker's payload as-is; error cases below are
            # normalized to a JSON string so the UI always gets text.
            return result
        except json.JSONDecodeError:
            return json.dumps({"error": "Invalid JSON input"})
        except KeyError:
            return json.dumps({"error": "Missing 'input' key in JSON"})
        except ValueError as e:
            return json.dumps({"error": str(e)})
# Build the web UI: one tab collecting the prompt/response pair plus the
# generation settings, and a read-only box showing the evaluator's verdict.
with gr.Blocks() as demo:
    gr.Markdown("## LLM Safety Evaluation")

    with gr.Tab("ShieldGemma2"):
        input_text = gr.Textbox(label="Input Text")
        output_text = gr.Textbox(
            label="Response Text",
            lines=5,
            max_lines=10,
            show_copy_button=True,
            elem_classes=["wrap-text"],
        )
        mode_input = gr.Dropdown(choices=["scoring", "generative"], label="Prediction Mode")
        max_length_input = gr.Number(label="Max Length", value=150)
        max_new_tokens_input = gr.Number(label="Max New Tokens", value=1024)
        model_size_input = gr.Dropdown(choices=["2B", "9B", "27B"], label="Model Size")
        response_text = gr.Textbox(
            label="Output Text",
            lines=10,
            max_lines=20,
            show_copy_button=True,
            elem_classes=["wrap-text"],
        )

        text_button = gr.Button("Submit")
        text_button.click(
            fn=my_inference_function,
            inputs=[
                input_text,
                output_text,
                mode_input,
                max_length_input,
                max_new_tokens_input,
                model_size_input,
            ],
            outputs=response_text,
        )

# share=True publishes a temporary public URL in addition to the local server.
demo.launch(share=True)