|
import gradio as gr
from huggingface_hub import InferenceClient
import logging
import torch
from transformers import AutoTokenizer, pipeline
import spaces
from concurrent.futures import ThreadPoolExecutor, as_completed
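
# Toxic Eye: a Gradio app that sends one input text to several models
# (local transformers pipelines and Hugging Face Inference API endpoints)
# and shows each model's assessment of the text's toxicity side by side.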
|
|
|
|
|
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
|
|
|
|
|
LOCAL = "local"
INFERENCE_API = "api"
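
# Model registries. Each entry declares a display name and description, the
# backend "type" (LOCAL or INFERENCE_API), and either a local "model_path" or a
# Hub "model_id" for the Inference API; generation entries also set "chat_model"
# to mark whether the model expects chat-formatted messages.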
|
|
|
|
|
TEXT_GENERATION_MODELS = [
    {
        "name": "Llama-2-7b-chat-hf",
        "description": "Meta's Llama 2 7B chat model, called via the Inference API",
        "chat_model": True,
        "type": INFERENCE_API,
        "model_id": "meta-llama/Llama-2-7b-chat-hf"
    },
    {
        "name": "TinyLlama-1.1B-Chat-v1.0",
        "description": "TinyLlama 1.1B chat model, called via the Inference API",
        "chat_model": True,
        "type": INFERENCE_API,
        "model_id": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    },
    {
        "name": "Mistral-7B-v0.1",
        "description": "Mistral 7B v0.1 base model, loaded and run locally",
        "chat_model": False,
        "type": LOCAL,
        "model_path": "mistralai/Mistral-7B-v0.1"
    }
]

CLASSIFICATION_MODELS = [
    {
        "name": "Toxic-BERT",
        "description": "Fine-tuned for toxic content detection",
        "type": LOCAL,
        "model_path": "unitary/toxic-bert"
    }
]
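
# Per-process caches, filled once at startup: tokenizers and pipelines for
# local models (keyed by model_path) and InferenceClient instances for
# API-backed models (keyed by model_id).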
|
|
|
|
|
tokenizers = {}
pipelines = {}
api_clients = {}
|
|
|
|
|
def initialize_api_clients():
    """Initialize Inference API clients for all API-backed models."""
    for model in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS:
        if model["type"] == INFERENCE_API and "model_id" in model:
            logger.info(f"Initializing API client for {model['name']}")
            # token=True: use the token already configured on this machine
            # (for a Space, typically the HF_TOKEN secret) and fail if none is found.
            api_clients[model["model_id"]] = InferenceClient(
                model["model_id"],
                token=True
            )
    logger.info("API clients initialized")
|
|
|
|
|
def preload_local_models():
    """Preload all local models at application startup."""
    logger.info("Preloading local models at application startup...")

    # Text generation models
    for model in TEXT_GENERATION_MODELS:
        if model["type"] == LOCAL and "model_path" in model:
            model_path = model["model_path"]
            try:
                logger.info(f"Preloading text generation model: {model_path}")
                tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
                pipelines[model_path] = pipeline(
                    "text-generation",
                    model=model_path,
                    tokenizer=tokenizers[model_path],
                    torch_dtype=torch.float16,
                    device_map="auto",
                    trust_remote_code=True
                )
                logger.info(f"Model preloaded successfully: {model_path}")
            except Exception as e:
                logger.error(f"Error preloading model {model_path}: {str(e)}")

    # Classification models
    for model in CLASSIFICATION_MODELS:
        if model["type"] == LOCAL and "model_path" in model:
            model_path = model["model_path"]
            try:
                logger.info(f"Preloading classification model: {model_path}")
                tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
                pipelines[model_path] = pipeline(
                    "text-classification",
                    model=model_path,
                    tokenizer=tokenizers[model_path],
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                    device_map="auto"
                )
                logger.info(f"Model preloaded successfully: {model_path}")
            except Exception as e:
                logger.error(f"Error preloading model {model_path}: {str(e)}")
|
|
|
@spaces.GPU
def generate_text_local(model_path, chat_model, text):
    """Text generation with a locally loaded model."""
    try:
        logger.info(f"Running local text generation with {model_path}")
        # Use a distinct name so we do not shadow transformers.pipeline
        pipe = pipelines[model_path]

        # Move the model onto the GPU for the duration of this call
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pipe.model = pipe.model.to(device)
        if hasattr(pipe, "device"):
            pipe.device = device

        device_info = next(pipe.model.parameters()).device
        logger.info(f"Model {model_path} is running on device: {device_info}")

        if chat_model:
            # For chat input the pipeline returns the whole conversation,
            # so pull out the assistant's reply.
            outputs = pipe(
                [{"role": "user", "content": text}],
                max_new_tokens=40,
                do_sample=False,
                num_return_sequences=1
            )
            generated = outputs[0]["generated_text"][-1]["content"]
        else:
            outputs = pipe(
                text,
                max_new_tokens=40,
                do_sample=False,
                num_return_sequences=1
            )
            generated = outputs[0]["generated_text"]

        # Release the GPU again once generation is done
        pipe.model = pipe.model.to("cpu")
        if hasattr(pipe, "device"):
            pipe.device = torch.device("cpu")

        return generated
    except Exception as e:
        logger.error(f"Error in local text generation with {model_path}: {str(e)}")
        return f"Error: {str(e)}"
|
|
|
def generate_text_api(model_id, chat_model, text):
    """Text generation via the Inference API."""
    try:
        logger.info(f"Running API text generation with {model_id}")
        if chat_model:
            response = api_clients[model_id].chat.completions.create(
                messages=[{"role": "user", "content": text}],
                max_tokens=512
            )
            response = response.choices[0].message.content
        else:
            response = api_clients[model_id].text_generation(
                text,
                max_new_tokens=40,
                temperature=0.7
            )
        return response
    except Exception as e:
        logger.error(f"Error in API text generation with {model_id}: {str(e)}")
        return f"Error: {str(e)}"
|
|
|
@spaces.GPU
def classify_text_local(model_path, text):
    """Text classification with a locally loaded model."""
    try:
        logger.info(f"Running local classification with {model_path}")
        pipe = pipelines[model_path]

        # Move the model onto the GPU for the duration of this call
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pipe.model = pipe.model.to(device)
        if hasattr(pipe, "device"):
            pipe.device = device

        device_info = next(pipe.model.parameters()).device
        logger.info(f"Model {model_path} is running on device: {device_info}")

        result = pipe(text)

        # Release the GPU again once classification is done
        pipe.model = pipe.model.to("cpu")
        if hasattr(pipe, "device"):
            pipe.device = torch.device("cpu")

        return str(result)
    except Exception as e:
        logger.error(f"Error in local classification with {model_path}: {str(e)}")
        return f"Error: {str(e)}"
|
|
|
def classify_text_api(model_id, text):
    """Text classification via the Inference API."""
    try:
        logger.info(f"Running API classification with {model_id}")
        response = api_clients[model_id].text_classification(text)
        return str(response)
    except Exception as e:
        logger.error(f"Error in API classification with {model_id}: {str(e)}")
        return f"Error: {str(e)}"
|
|
|
|
|
def handle_invoke(text, selected_types):
    """Handler for the Invoke button: run every selected model concurrently."""
    futures_to_model = {}
    selected_models = [
        m for m in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS
        if m["type"] in selected_types
    ]

    # ThreadPoolExecutor requires max_workers >= 1, even if nothing is selected
    with ThreadPoolExecutor(max_workers=max(1, len(selected_models))) as executor:
        futures = []

        # Text generation models
        for model in TEXT_GENERATION_MODELS:
            if model["type"] in selected_types:
                if model["type"] == LOCAL:
                    future = executor.submit(generate_text_local, model["model_path"], model["chat_model"], text)
                else:
                    future = executor.submit(generate_text_api, model["model_id"], model["chat_model"], text)
                futures.append(future)
                futures_to_model[future] = model

        # Classification models
        for model in CLASSIFICATION_MODELS:
            if model["type"] in selected_types:
                if model["type"] == LOCAL:
                    future = executor.submit(classify_text_local, model["model_path"], text)
                else:
                    future = executor.submit(classify_text_api, model["model_id"], text)
                futures.append(future)
                futures_to_model[future] = model

        # Collect results in registry order so they line up with the output boxes
        all_models = TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS
        results = [""] * len(all_models)
        for future in as_completed(futures):
            model = futures_to_model[future]
            model_index = all_models.index(model)
            results[model_index] = future.result()

    return results
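
# The result list returned by handle_invoke follows the registry order
# (TEXT_GENERATION_MODELS then CLASSIFICATION_MODELS), which is the same order
# in which create_ui() wires the output textboxes to the Invoke button.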
|
|
|
|
|
def update_model_visibility(selected_types):
    """Update which model output groups are visible based on the selected types."""
    logger.info(f"Updating visibility for types: {selected_types}")

    updates = []
    for model_outputs in [gen_model_outputs, class_model_outputs]:
        for output in model_outputs:
            visible = output["type"] in selected_types
            logger.info(f"Model {output['name']} (type: {output['type']}): visible = {visible}")
            updates.append(gr.update(visible=visible))
    return updates
|
|
|
|
|
def load_models_and_update_ui():
    """Swap the loading screen for the main UI once the app is ready.

    The models are already preloaded in main() before launch, so this handler
    only toggles the visibility of the two top-level groups.
    """
    try:
        return gr.update(visible=False), gr.update(visible=True)
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        return gr.update(value=f"Error loading models: {e}"), gr.update(visible=False)
|
|
|
|
|
def create_model_grid(models):
    """Create a two-column grid of output boxes, one per model.

    Returns a list of dicts holding each model's type, name, output Textbox and
    enclosing Group (the Group is used for visibility toggling).
    """
    outputs = []
    with gr.Column() as container:
        # Lay the models out two per row
        for i in range(0, len(models), 2):
            with gr.Row() as row:
                for j in range(min(2, len(models) - i)):
                    model = models[i + j]
                    with gr.Column():
                        with gr.Group() as group:
                            gr.Markdown(f"### {model['name']}")
                            gr.Markdown(f"Type: {model['type']}")
                            output = gr.Textbox(
                                label="Model Output",
                                lines=5,
                                interactive=False,
                                info=model['description']
                            )
                            outputs.append({
                                "type": model["type"],
                                "name": model["name"],
                                "output": output,
                                "group": group
                            })
    return outputs
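
# Module-level handles to the UI components. create_ui() assigns these so the
# callbacks above (handle_invoke, update_model_visibility) can be wired to them.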
|
|
|
|
|
input_text = None
filter_checkboxes = None
invoke_button = None
gen_model_outputs = []
class_model_outputs = []
community_output = None
|
|
|
|
|
def create_ui():
    """Build the Gradio UI."""
    global input_text, filter_checkboxes, invoke_button, gen_model_outputs, class_model_outputs, community_output

    with gr.Blocks() as demo:
        # Loading screen shown until demo.load() swaps in the main UI
        with gr.Group(visible=True) as loading_group:
            gr.Markdown("""
# Toxic Eye

### Loading models... This may take a few minutes.

The application is initializing and preloading all models.
Please wait while the models are being loaded...
""")

        # Main UI, hidden until loading is finished
        with gr.Group(visible=False) as main_ui_group:
            gr.Markdown("""
# Toxic Eye
This system evaluates the toxicity level of input text using multiple approaches.
""")

            # Input area
            with gr.Row():
                input_text = gr.Textbox(
                    label="Input Text",
                    placeholder="Enter text to analyze...",
                    lines=3
                )

            # Backend filter
            with gr.Row():
                filter_checkboxes = gr.CheckboxGroup(
                    choices=[LOCAL, INFERENCE_API],
                    value=[LOCAL, INFERENCE_API],
                    label="Filter Models",
                    info="Choose which types of models to display",
                    interactive=True
                )

            # Invoke button
            with gr.Row():
                invoke_button = gr.Button(
                    "Invoke Selected Models",
                    variant="primary",
                    size="lg"
                )

            # Model output tabs
            with gr.Tabs():
                with gr.Tab("Text Generation LLM"):
                    gen_model_outputs = create_model_grid(TEXT_GENERATION_MODELS)
                with gr.Tab("Classification LLM"):
                    class_model_outputs = create_model_grid(CLASSIFICATION_MODELS)
                with gr.Tab("Community (Not implemented)"):
                    with gr.Column():
                        community_output = gr.Textbox(
                            label="Related Community Topics",
                            lines=5,
                            interactive=False
                        )

            # Toggle group visibility when the backend filter changes
            filter_checkboxes.change(
                fn=update_model_visibility,
                inputs=[filter_checkboxes],
                outputs=[
                    output["group"]
                    for outputs in [gen_model_outputs, class_model_outputs]
                    for output in outputs
                ]
            )

            # Fan the input text out to all selected models
            invoke_button.click(
                fn=handle_invoke,
                inputs=[input_text, filter_checkboxes],
                outputs=[
                    output["output"]
                    for outputs in [gen_model_outputs, class_model_outputs]
                    for output in outputs
                ]
            )

        # Once the page loads, hide the loading screen and show the main UI
        demo.load(
            fn=load_models_and_update_ui,
            inputs=None,
            outputs=[loading_group, main_ui_group]
        )

    return demo
|
|
|
|
|
def main():
    logger.info("Starting Toxic Eye application")
    initialize_api_clients()
    preload_local_models()
    logger.info("Models loaded successfully")
    demo = create_ui()
    demo.launch()


if __name__ == "__main__":
    main()