"""Keyword-based book recommender.

Embeds BookCorpus passages with sentence-transformers/all-mpnet-base-v2,
ranks them against user keywords by cosine similarity, and summarizes the
top matches with facebook/bart-large-cnn behind a Gradio UI.
"""

import re

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer, pipeline

MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
SUMMARIZER_NAME = "facebook/bart-large-cnn"
DATASET_NAME = "bookcorpus"
CACHE_DIR = "./data-cache"

# Shared mpnet encoder for embeddings, plus a BART pipeline for summaries.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
summarizer = pipeline("summarization", model=SUMMARIZER_NAME)


def load_books():
    """Stream BookCorpus and keep long passages as pseudo-books."""
    dataset = load_dataset(DATASET_NAME, split="train", streaming=True, cache_dir=CACHE_DIR)
    books = []
    for book in dataset.take(50000):
        text = book["text"].strip()
        if len(text) > 500:
            # Use the first quoted phrase near the start as a title, if any.
            title = re.findall(r'"([^"]*)"', text[:200])
            books.append({
                "text": text,
                "title": title[0] if title else "Untitled Book",
            })
    return books
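
# Note: the Hugging Face "bookcorpus" dataset is sentence-per-row, so the
# 500-character filter above rejects most rows; the passages that survive act
# as stand-in "books". Grouping consecutive rows into documents before
# filtering would give more book-like units.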


def get_embeddings(texts):
    """Return L2-normalized mean-pooled embeddings, one row per input text."""
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
        embeddings = mean_pooling(outputs, inputs["attention_mask"])
    return F.normalize(embeddings, p=2, dim=1)
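
# Hypothetical sanity check: all-mpnet-base-v2 has a 768-dim hidden state, so
#   get_embeddings(["a wizard school", "a murder mystery"]).shape
# should be torch.Size([2, 768]), with each row unit-norm.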


def mean_pooling(model_output, attention_mask):
    """Average token embeddings over real tokens, ignoring padding positions."""
    token_embeddings = model_output.last_hidden_state
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def generate_summary(text):
    """Summarize a passage with BART, using the pipeline's own tokenizer."""
    # The mpnet `tokenizer` cannot feed BART, and bart-large-cnn takes no
    # "summarize:" task prefix (that is a T5 convention).
    inputs = summarizer.tokenizer(
        text,
        max_length=1024,
        truncation=True,
        return_tensors="pt",
    )
    summary_ids = summarizer.model.generate(
        **inputs,
        max_length=150,
        min_length=50,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    return summarizer.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
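
# Roughly equivalent one-liner via the pipeline API, which bundles the
# tokenize/generate/decode steps:
#   summarizer(text, max_length=150, min_length=50, truncation=True)[0]["summary_text"]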


def recommend_books(keywords, top_k=5):
    """Return the top_k books whose embeddings best match the given keywords."""
    # Normalize input: drop punctuation except commas, lowercase, split.
    keywords = re.sub(r'[^\w\s,]', '', keywords).lower()
    keywords = [k.strip() for k in keywords.split(',') if k.strip()]

    if len(keywords) < 2:
        return "❗ Please enter at least 2 keywords (e.g. 'fantasy, magic')"

    # Embed the query; corpus embeddings (book_embs) are precomputed at startup.
    keyword_emb = get_embeddings([" ".join(keywords)])

    sim_scores = cosine_similarity(keyword_emb.numpy(), book_embs.numpy())[0]
    top_indices = np.argsort(sim_scores)[-top_k:][::-1]

    results = []
    for idx in top_indices:
        book = books[idx]
        results.append({
            "title": book["title"],
            "summary": generate_summary(book["text"]),
            "score": f"{sim_scores[idx]:.2f}",
        })
    return results
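
# Example query:
#   recommend_books("mystery, detective", top_k=3)
#   -> [{"title": ..., "summary": ..., "score": ...}, ...]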


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📚 Smart Book Recommender")

    with gr.Row():
        inputs = gr.Textbox(label="Keywords (comma-separated)", placeholder="e.g. sci-fi, time travel")
        outputs = gr.JSON(label="Recommendations")

    examples = gr.Examples(
        examples=[
            ["romance, paris"],
            ["mystery, detective"],
            ["science fiction, space opera"],
        ],
        inputs=[inputs],
    )

    inputs.submit(fn=recommend_books, inputs=inputs, outputs=outputs)


print("Loading book data...")
books = load_books()
print(f"Loaded {len(books)} books")

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)