"""Gradio app that recommends category tags for a piece of blog text.

The input text is cleaned, embedded with a local feature-extraction
pipeline, and compared against a pre-computed embeddings dataset; the tags
of the closest blog post are returned when the match is strong enough.
"""

import os
import re

import emoji
import gradio as gr
import lxml
import nltk
import requests
import torch
from bs4 import BeautifulSoup
from datasets import load_dataset
from markdown import markdown
from nltk.corpus import stopwords
from retry import retry
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline

# NOTE(review): this translation pipeline is not used anywhere in this file;
# it is kept only so the module-level name `pipe` stays available. Consider
# dropping it to save start-up time and memory.
pipe = pipeline("translation", model="Helsinki-NLP/opus-mt-en-es")

# Make sure the NLTK stop-word list has been downloaded.
nltk.download('stopwords')

# Read the Hugging Face token from the environment.
# NOTE(review): the token is not used by the code below.
hf_token = os.getenv('HF_TOKEN')

model_id = "BAAI/bge-large-en-v1.5"
# Embeddings are computed locally with this pipeline; an earlier revision of
# the script called the hosted Hugging Face Inference API instead.
feature_extraction_pipeline = pipeline("feature-extraction", model=model_id)

# A match scoring below this similarity is reported as unrelated content.
SCORE_THRESHOLD = 0.6
# Number of nearest neighbours requested from the semantic search.
TOP_K = 5
NOT_RELATED_MESSAGE = "Content Not related"

# Patterns used by clean_content(), compiled once at import time.
_CODE_RE = re.compile(r"(```.*?```|`.*?`)", flags=re.DOTALL)
_NON_LETTER_RE = re.compile(r"[^a-zA-Z\s]")
_URL_RE = re.compile(r"http\S+|www\S+|https\S+", flags=re.MULTILINE)
_WHITESPACE_RE = re.compile(r'\s+')


# NOTE(review): the retry dates from the remote-API version of this function;
# local inference failures are usually deterministic, so retrying mostly
# delays the error. Kept to preserve the existing behaviour.
@retry(tries=3, delay=10)
def query(texts):
    """Embed one or more texts by mean-pooling their token features.

    Args:
        texts: A single string or an iterable of strings.

    Returns:
        A float tensor with one row per input text.

    Raises:
        RuntimeError: If ``texts`` is an empty iterable (nothing to stack).
    """
    batch = [texts] if isinstance(texts, str) else list(texts)
    # truncation=True keeps inputs longer than the model's maximum sequence
    # length from failing inside the model.
    features = feature_extraction_pipeline(batch, truncation=True)

    pooled = []
    for sentence_features in features:
        token_vectors = torch.tensor(sentence_features, dtype=torch.float)
        # The pipeline may wrap each result in a leading batch axis of size
        # one; flatten to (tokens, dim) so the mean is always over tokens.
        token_vectors = token_vectors.reshape(-1, token_vectors.shape[-1])
        pooled.append(token_vectors.mean(dim=0))
    return torch.stack(pooled)


# Load the pre-computed embeddings dataset.
faqs_embeddings_dataset = load_dataset('chenglu/hf-blogs-baai-embeddings')
df = faqs_embeddings_dataset["train"].to_pandas()
# After the transpose each row is the vector for one DataFrame column; the
# column labels are later used as `local` identifiers.
embeddings_array = df.T.to_numpy()
dataset_embeddings = torch.from_numpy(embeddings_array).to(torch.float)

# Load the original dataset (holds the `local` and `tags` fields).
original_dataset = load_dataset("chenglu/hf-blogs")['train']

# English stop-word set.
stop_words = set(stopwords.words('english'))


def remove_stopwords(text):
    """Return ``text`` with English stop words removed (case-insensitive)."""
    return ' '.join(
        word for word in text.split() if word.lower() not in stop_words
    )


def clean_content(content):
    """Reduce markdown/HTML blog content to plain space-separated letters.

    The order of the steps is intentionally the same as in the original
    implementation so that query text is cleaned the same way as before.
    NOTE(review): URL stripping and markdown rendering run after all
    non-letter characters were removed; confirm this matches how the stored
    embeddings were produced before reordering.
    """
    # Drop fenced code blocks and inline code spans.
    content = _CODE_RE.sub("", content)
    # Strip HTML tags.
    content = BeautifulSoup(content, "html.parser").get_text()
    content = emoji.replace_emoji(content, replace='')
    # Keep only ASCII letters and whitespace.
    content = _NON_LETTER_RE.sub("", content)
    content = _URL_RE.sub('', content)
    # Render as markdown and extract the text nodes again.
    content = markdown(content)
    # find_all(string=True) is the non-deprecated spelling of
    # findAll(text=True).
    content = ''.join(BeautifulSoup(content, 'lxml').find_all(string=True))
    return _WHITESPACE_RE.sub(' ', content)


def get_tags_for_local(dataset, local_value):
    """Return the ``tags`` of the first entry whose ``local`` matches.

    Args:
        dataset: Iterable of mappings with ``local`` and ``tags`` keys.
        local_value: Identifier to look up.

    Returns:
        The matching entry's tags, or ``None`` when no entry matches.
    """
    entry = next(
        (item for item in dataset if item['local'] == local_value), None
    )
    return entry['tags'] if entry else None


def gradio_query_interface(input_text):
    """Recommend category tags for ``input_text``.

    Returns:
        ``"Content Not related"`` when the input is empty after cleaning or
        no stored embedding reaches the similarity threshold; otherwise a
        message containing the tags of the best-matching entry.
    """
    cleaned_text = clean_content(input_text or "")
    no_stopwords_text = remove_stopwords(cleaned_text)
    if not no_stopwords_text:
        # Nothing meaningful is left to embed.
        return NOT_RELATED_MESSAGE

    query_embeddings = query(no_stopwords_text).to(torch.float)
    hits = util.semantic_search(
        query_embeddings, dataset_embeddings, top_k=TOP_K
    )

    best_hit = max(hits[0], key=lambda hit: hit['score'], default=None)
    if best_hit is None or best_hit['score'] < SCORE_THRESHOLD:
        return NOT_RELATED_MESSAGE

    local = df.columns[best_hit['corpus_id']]
    recommended_tags = get_tags_for_local(original_dataset, local)
    return f"Recommended category tags: {recommended_tags}"


iface = gr.Interface(
    fn=gradio_query_interface,
    inputs="text",
    outputs="label"
)

if __name__ == "__main__":
    # Only start the server when run as a script, not on import.
    iface.launch()