# Page-header residue (author / commit message / commit hash): SonyaX20 / new / 822abee
import base64
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import load_dataset
from sklearn.manifold import TSNE
from transformers import AutoModel, AutoTokenizer
# Load the pretrained model and its tokenizer (both resolved from MODEL_NAME).
MODEL_NAME = "bert-base-chinese"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
# Load the tnews dataset.
# NOTE(review): `dataset` is not referenced anywhere else in the visible part
# of this file — confirm whether this load is still needed.
dataset = load_dataset("clue", "tnews")
# Text preprocessing
def preprocess_text(text):
    """Tokenize *text* with the module-level tokenizer.

    The tokenizer is asked for PyTorch tensors, with padding enabled and
    truncation at 128 tokens.

    Args:
        text: The text to tokenize.

    Returns:
        Whatever the tokenizer call returns for these options (the encoded
        inputs that ``extract_features`` unpacks into the model).
    """
    tokenizer_options = {
        "return_tensors": "pt",
        "truncation": True,
        "padding": True,
        "max_length": 128,
    }
    return tokenizer(text, **tokenizer_options)
# Feature extraction
def extract_features(text):
    """Return the [CLS] token representation of *text* as a NumPy array.

    Args:
        text: The text to embed.

    Returns:
        The first-position slice of the model's last hidden state, with
        size-1 axes squeezed out and converted to NumPy.
    """
    encoded = preprocess_text(text)
    # Inference only: no gradients are needed, and the result must be
    # gradient-free for the NumPy conversion below.
    with torch.no_grad():
        model_output = model(**encoded)
    # Index 0 on the sequence axis is used as the [CLS] representation.
    cls_vector = model_output.last_hidden_state[:, 0, :]
    return cls_vector.squeeze().numpy()
# Cosine similarity
def cosine_similarity(vec1, vec2):
    """Return the cosine similarity between two vectors.

    Args:
        vec1: First vector (array-like).
        vec2: Second vector (array-like) with the same length as ``vec1``.

    Returns:
        ``dot(vec1, vec2) / (||vec1|| * ||vec2||)``, or ``0.0`` when either
        vector has zero norm (the similarity is undefined there; returning
        0.0 avoids a division by zero that would otherwise produce NaN).
    """
    denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if denominator == 0:
        return 0.0
    return np.dot(vec1, vec2) / denominator
# Predefined reference texts used for the similarity comparison.
predefined_texts = [
    "今天的天气很好,我想去散步。",
    "股票市场今天表现不错。",
    "人工智能正在改变我们的生活。"
]
# Feature vectors of the reference texts, computed once at import time and
# kept in the same order as `predefined_texts`.
predefined_features = [extract_features(text) for text in predefined_texts]
# Dimensionality-reduction plot
def plot_features(features):
    """Render a 2-D t-SNE scatter plot of the input and reference vectors.

    Args:
        features: Feature vector of the input text, as produced by
            ``extract_features``.

    Returns:
        An HTML ``<img>`` tag whose ``src`` is the plot encoded as a
        base64 PNG data URI.
    """
    # Row 0 is the input vector; the remaining rows are the reference vectors.
    samples = np.vstack([features] + predefined_features)
    n_samples = samples.shape[0]

    # scikit-learn requires perplexity < n_samples. The default (30) is
    # rejected for the handful of points plotted here, so clamp it.
    perplexity = min(30.0, n_samples - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    reduced_features = tsne.fit_transform(samples)

    colors = ["red"] + ["blue"] * (n_samples - 1)

    # Draw on an explicit figure so it is always closed, even if rendering
    # or encoding raises part-way through.
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        for i, point in enumerate(reduced_features):
            label = "Input" if i == 0 else f"Text {i}"
            ax.scatter(point[0], point[1], c=colors[i], label=label)
        ax.legend()
        ax.set_title("Feature Vector Visualization (t-SNE)")
        ax.set_xlabel("Dimension 1")
        ax.set_ylabel("Dimension 2")
        ax.grid(True)

        # Serialize the figure to an in-memory PNG.
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)

    img_str = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f'<img src="data:image/png;base64,{img_str}" />'
# Gradio handler
def predict(text):
    """Extract features for *text* and report similarities plus a t-SNE plot.

    Args:
        text: The text entered by the user.

    Returns:
        A ``(summary, plot_html)`` tuple: a text summary (feature shape, the
        first ten feature values, and the cosine similarity to each reference
        text) and the HTML produced by ``plot_features``.
    """
    features = extract_features(text)

    # Score the input against every reference text, in declaration order.
    scored_references = []
    for reference_text, reference_vector in zip(predefined_texts, predefined_features):
        score = cosine_similarity(features, reference_vector)
        scored_references.append((reference_text, score))

    # One line of text per reference.
    similarity_lines = []
    for reference_text, score in scored_references:
        similarity_lines.append(f"与 \"{reference_text}\" 的相似度: {score:.2f}")
    similarity_text = "\n".join(similarity_lines)

    tsne_plot = plot_features(features)

    summary = f"特征维度: {features.shape}\n特征向量(部分展示): {features[:10]}\n\n相似性结果:\n{similarity_text}\n"
    return summary, tsne_plot
# Define the Gradio interface: one text box in, and two outputs that match
# the (summary text, plot HTML) tuple returned by `predict`.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="输入中文文本..."),
    outputs=[
        "text",  # text output (the summary string)
        "html",  # image output (the <img> tag from plot_features)
    ],
    title="中文特征提取与分析",
    description="基于BERT的中文文本特征提取,支持相似性分析与降维可视化。",
)
# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()