|
import gradio as gr |
|
from huggingface_hub import snapshot_download |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import os |
|
|
|
|
|
# Local directory where the Hugging Face model snapshot is cached.
model_dir = "models/Miwa-Keita/zenz-v2.5-medium"

# Fetch the model files from the Hub into model_dir. Only inference
# artifacts (weights, configs, tokenizer files) are downloaded; training
# leftovers such as optimizer state and checkpoints are skipped.
snapshot_download(

    repo_id="Miwa-Keita/zenz-v2.5-medium",

    local_dir=model_dir,

    allow_patterns=["*.safetensors", "*.json", "*.txt", "*.model"],

    ignore_patterns=["optimizer.pt", "checkpoint*"],

)

# Load the tokenizer and causal-LM weights from the local snapshot.
# float32 keeps the model runnable on CPU; use_safetensors avoids the
# pickle-based .bin loading path.
tokenizer = AutoTokenizer.from_pretrained(model_dir)

model = AutoModelForCausalLM.from_pretrained(

    model_dir,

    torch_dtype=torch.float32,

    use_safetensors=True

)
|
|
|
|
|
def preprocess_input(user_input):
    """Wrap *user_input* in the zenz sentinel characters.

    The model expects the reading delimited by two private-use code
    points: U+EE00 marks the start of the input and U+EE01 marks the
    boundary between the input and the expected conversion output.
    """
    return f"\uEE00{user_input}\uEE01"
|
|
|
|
|
def postprocess_output(model_output):
    """Extract the conversion result from a decoded model output.

    The decoded sequence has the shape ``<U+EE00>reading<U+EE01>result``;
    everything after the first U+EE01 sentinel is the conversion.

    Returns *model_output* unchanged when the sentinel is absent
    (e.g. generation stopped before producing it).
    """
    suffix = "\uEE01"
    # str.partition splits on the FIRST occurrence only, so a stray
    # second sentinel inside the generated text no longer truncates the
    # result — the old ``split(suffix)[1]`` kept only the segment
    # between the first and second occurrence.
    _, sep, tail = model_output.partition(suffix)
    return tail if sep else model_output
|
|
|
|
|
def generate_text(user_input):
    """Convert a katakana reading to kanji text with the zenz model.

    Parameters
    ----------
    user_input : str
        The katakana string to convert.

    Returns
    -------
    str
        The conversion result (text after the U+EE01 sentinel), or the
        full decoded output if the sentinel was not generated.
    """
    processed_input = preprocess_input(user_input)

    inputs = tokenizer(processed_input, return_tensors="pt")

    # Inference only: disable autograd tracking so generate() does not
    # build a gradient graph (saves memory and compute, same output).
    with torch.no_grad():
        # NOTE: max_length counts the prompt tokens too, so longer
        # inputs leave less room for generated text — kept as-is for
        # backward-compatible behavior.
        outputs = model.generate(**inputs, max_length=100)

    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return postprocess_output(decoded_output)
|
|
|
|
|
# Gradio UI: one textbox in (katakana reading), one textbox out (result).
iface = gr.Interface(
    fn=generate_text,
    inputs=gr.Textbox(label="変換する文字列(カタカナ)"),
    outputs=gr.Textbox(label="変換結果"),
    title="ニューラルかな漢字変換モデル zenz-v2.5 のデモ",
    # Fixed grammar: 「〜をカタカナを入力」 → 「〜をカタカナで入力」.
    description="変換したい文字列をカタカナで入力してください",
)

# Start the local Gradio server (blocks until shut down).
iface.launch()