# app.py
from datasets import load_dataset
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, pipeline
from sklearn.metrics.pairwise import cosine_similarity
import gradio as gr
import re
# Global configuration
MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"  # strong general-purpose sentence-embedding model
SUMMARIZER_NAME = "facebook/bart-large-cnn"
DATASET_NAME = "bookcorpus"
CACHE_DIR = "./data-cache"
# Preload models
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
summarizer = pipeline("summarization", model=SUMMARIZER_NAME)
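# Optional speed-up (assumption: a CUDA-capable torch build is installed):
# move the embedding model to GPU and send each tokenized batch to the same
# device inside get_embeddings, e.g.
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model.to(device)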
# Load and preprocess book data
def load_books():
    # Note: bookcorpus rows are plain-text passages rather than whole books,
    # so the length filter keeps only long passages and the "title" is a heuristic.
    dataset = load_dataset(DATASET_NAME, split='train', streaming=True, cache_dir=CACHE_DIR)
    books = []
    for book in dataset.take(50000):  # scan the first 50k records
        text = book['text'].strip()
        if len(text) > 500:  # filter out short texts
            title = re.findall(r'"([^"]*)"', text[:200])  # try to extract a quoted title
            books.append({
                "text": text,
                "title": title[0] if title else "Untitled Book"
            })
    return books
# Generate semantic embeddings (batched so tens of thousands of texts don't exhaust memory)
def get_embeddings(texts, batch_size=64):
    chunks = []
    for i in range(0, len(texts), batch_size):
        inputs = tokenizer(texts[i:i + batch_size], padding=True, truncation=True,
                           return_tensors="pt", max_length=512)
        with torch.no_grad():
            outputs = model(**inputs)
        embeddings = mean_pooling(outputs, inputs['attention_mask'])
        chunks.append(F.normalize(embeddings, p=2, dim=1))
    return torch.cat(chunks, dim=0)
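# Example (illustrative): two short strings -> a (2, 768) tensor of unit-norm
# rows, so a dot product between rows is directly a cosine similarity:
#   embs = get_embeddings(["wizards and dragons", "a detective in paris"])
#   float(embs[0] @ embs[1])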
# Mean pooling over token embeddings, weighted by the attention mask
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output.last_hidden_state
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
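# Shape sketch (hypothetical batch of 2 texts, 10 tokens, hidden size 768):
#   last_hidden_state: (2, 10, 768); attention_mask: (2, 10) -> expanded to (2, 10, 768).
#   Summing over the token axis and dividing by each text's real token count
#   (padding masked out) yields one (768,)-dim vector per input: a (2, 768) tensor.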
# Generate an abstractive summary
def generate_summary(text):
    # Use the summarizer's own tokenizer; the global `tokenizer` belongs to the
    # embedding model and has the wrong vocabulary for BART. (The "summarize:"
    # prefix is a T5 convention and is not needed for BART.)
    inputs = summarizer.tokenizer(
        text,
        max_length=1024,
        truncation=True,
        return_tensors="pt"
    )
    summary_ids = summarizer.model.generate(
        inputs.input_ids,
        max_length=150,
        min_length=50,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True
    )
    return summarizer.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
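# Example (illustrative): generate_summary(books[0]["text"]) returns a single
# 50-150 token abstract produced by 4-beam search over facebook/bart-large-cnn.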
# Core recommendation logic
def recommend_books(keywords, top_k=5):
    # Clean the input
    keywords = re.sub(r'[^\w\s,]', '', keywords).lower()
    keywords = [k.strip() for k in keywords.split(',') if k.strip()]
    if len(keywords) < 2:
        return "❗ Please enter at least 2 keywords (e.g. 'fantasy, magic')"
    # Embed the query only; book embeddings are precomputed once at startup
    keyword_emb = get_embeddings([" ".join(keywords)])
    # Compute similarities (sklearn expects numpy arrays, not torch tensors)
    sim_scores = cosine_similarity(keyword_emb.numpy(), book_embs.numpy())[0]
    top_indices = np.argsort(sim_scores)[-top_k:][::-1]
    # Build the results
    results = []
    for idx in top_indices:
        book = books[idx]
        summary = generate_summary(book['text'])
        results.append({
            "title": book['title'],
            "summary": summary,
            "score": f"{sim_scores[idx]:.2f}"
        })
    return results
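# Quick sanity check outside the UI (illustrative output shape):
#   recommend_books("mystery, detective")
#   -> [{"title": "...", "summary": "...", "score": "0.41"}, ...]  (top_k entries)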
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📚 Intelligent Book Recommender")
    with gr.Row():
        inputs = gr.Textbox(label="Enter keywords (comma-separated)", placeholder="e.g. sci-fi, time travel")
        outputs = gr.JSON(label="Recommendations")
    examples = gr.Examples(
        examples=[
            ["romance, paris"],
            ["mystery, detective"],
            ["science fiction, space opera"]
        ],
        inputs=[inputs]
    )
    inputs.submit(
        fn=recommend_books,
        inputs=inputs,
        outputs=outputs
    )
# Initialize data
print("Loading book data...")
books = load_books()
print(f"Loaded {len(books)} books")
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)