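# BuckLakeAI / preprocess.py
# Last commit 8bf0955 by parkerjj: 更新 Dockerfile,安装 Spacy 及其模型;在 preprocess.py 中添加模型下载处理
#   (Update the Dockerfile to install spaCy and its model; add model-download handling to preprocess.py.)
# File size: 20.3 kB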
import re
import sys
import os
import numpy as np
from collections import defaultdict
import pandas as pd
import time
# 如果使用 spaCy 进行 NLP 处理
import spacy
# 如果使用某种情感分析工具,比如 Hugging Face 的模型
from transformers import pipeline
# 还需要导入 pickle 模块(如果你在代码的其他部分使用了它来处理序列化/反序列化)
import pickle
from gensim.models import KeyedVectors
import akshare as ak
from gensim.models import Word2Vec
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from us_stock import *
# Select the GPU (when one is available) *before* loading any spaCy pipeline:
# spaCy only allocates a pipeline on the GPU if prefer_gpu()/require_gpu() ran
# before spacy.load(); data already allocated on the CPU is not moved later.
# To make a GPU mandatory instead, use: spacy.require_gpu()
spacy.prefer_gpu()

# Load the English pipeline, downloading it first if it is not installed
# (spacy.load raises OSError when the package cannot be found).
try:
    nlp = spacy.load("en_core_web_md")
except OSError:
    print("Downloading model 'en_core_web_md'...")
    from spacy.cli import download
    download("en_core_web_md")
    nlp = spacy.load("en_core_web_md")

# Report whether spaCy is running on the GPU (prefer_gpu() returns True if a GPU is active).
print("Is NPL GPU used Preprocessing.py:", spacy.prefer_gpu())
# Sentiment model: the "ProsusAI/finbert" checkpoint and its matching tokenizer.
model_name = "ProsusAI/finbert"
tokenizer, sa_model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSequenceClassification.from_pretrained(model_name),
)
# Sentiment-analysis pipeline built from the tokenizer/model pair above.
sentiment_analyzer = pipeline(task='sentiment-analysis', model=sa_model, tokenizer=tokenizer)
# US index data fetched through akshare's `index_us_stock_sina` at import time
# (network I/O on every import of this module). Calls run in the order
# .INX, .DJI, .IXIC, .NDX.
# NOTE(review): these four frames are not referenced elsewhere in this file;
# they are kept as module attributes in case other modules import them.
(index_us_stock_index_INX,
 index_us_stock_index_DJI,
 index_us_stock_index_IXIC,
 index_us_stock_index_NDX) = (
    ak.index_us_stock_sina(symbol=index_symbol)
    for index_symbol in (".INX", ".DJI", ".IXIC", ".NDX")
)
class LazyWord2Vec:
    """Lazy wrapper around a gensim ``KeyedVectors`` file.

    Nothing is read from disk until the vectors are first needed; they are
    then loaded memory-mapped read-only (``mmap='r'``) and cached on the
    instance. Supports ``key in wrapper``, ``wrapper[key]`` and
    ``wrapper.vector_size``.
    """

    def __init__(self, model_path):
        # Path to a file loadable by ``KeyedVectors.load``.
        self.model_path = model_path
        # Populated on first access to ``model``.
        self._model = None

    @property
    def model(self):
        """The loaded ``KeyedVectors``; loads them on first access."""
        if self._model is None:
            print("Loading Word2Vec model...")
            self._model = KeyedVectors.load(self.model_path, mmap='r')
        return self._model

    def load_model(self):
        """Load the vectors now (if not loaded yet) and return them."""
        return self.model

    @property
    def vector_size(self):
        """Dimensionality of the word vectors (triggers loading if needed)."""
        # Going through ``self.model`` performs the lazy load, so no separate
        # load step is required here.
        return self.model.vector_size

    def __getitem__(self, key):
        """Return the vector for ``key`` (raises ``KeyError`` if unknown)."""
        return self.model[key]

    def __contains__(self, key):
        """Return True if ``key`` is in the vocabulary."""
        return key in self.model
# Locate the pretrained Google News Word2Vec vectors. Candidate directories
# are tried in order and the first one containing `filename` wins.
search_paths = [
    "/BuckLake/Model/",
    "/Users/parker/Development/Server/BuckLake/Model/",
    "/Users/liuyue/Work/BuckLake/Model/",
]

current_directory = os.getcwd()
print(f"Current directory: {current_directory}")
# From here on `current_directory` is the directory containing this file.
current_directory = os.path.dirname(os.path.abspath(__file__))
# Also try `../Model` relative to this file's directory, as the second candidate.
search_paths.insert(1, os.path.join(current_directory, '..', 'Model'))

filename = 'word2vec-google-news-300.model'

# Report every candidate that is missing; stop at the first hit.
word2vec_path = None
for path in search_paths:
    potential_path = os.path.join(path, filename)
    if not os.path.exists(potential_path):
        print(f"{potential_path} not found.")
        continue
    word2vec_path = potential_path
    break

if not word2vec_path:
    raise FileNotFoundError(f"{filename} not found in any of the search paths: {search_paths}")

# LazyWord2Vec defers the actual disk read until the vectors are first used.
print(f"Loading Word2Vec model from {word2vec_path}...")
word2vec_model = LazyWord2Vec(word2vec_path)
def pos_tagging(text):
    """Tag the content tokens of ``text`` with their part of speech.

    Punctuation and stop words are skipped. Uses the module-level spaCy
    pipeline ``nlp``.

    Args:
        text (str): Text to analyse.

    Returns:
        tuple[list[str], list[str], list[str]]: Three parallel lists
        ``(tokens, pos_tags, tags)`` holding ``token.text``, the coarse POS
        (``token.pos_``) and the fine-grained tag (``token.tag_``). If
        analysis fails, the error is printed and three empty lists are
        returned, so callers always receive lists.
    """
    try:
        kept = [token for token in nlp(text) if not (token.is_punct or token.is_stop)]
        return (
            [token.text for token in kept],
            [token.pos_ for token in kept],
            [token.tag_ for token in kept],
        )
    except Exception as e:
        # str() so that a non-string input cannot crash the error report itself.
        print(f"Error in pos_tagging for text: {str(text)[:50]}... Error: {str(e)}")
        return [], [], []
# Named entity recognition
def named_entity_recognition(text):
    """Extract named entities from ``text`` with the module-level spaCy ``nlp``.

    Args:
        text (str): Text to analyse.

    Returns:
        list[tuple[str, str]]: ``(entity_text, entity_label)`` pairs. When no
        entity is found, or analysis fails (the error is printed), the
        placeholder ``[("", "")]`` is returned so the list is never empty.
    """
    entities = []
    try:
        entities = [(span.text, span.label_) for span in nlp(text).ents]
    except Exception as e:
        print(f"Error in named_entity_recognition for text: {text[:50]}... Error: {str(e)}")
    if not entities:
        return [("", "")]
    return entities
# Summarise named-entity recognition results
def process_entities(entities):
    """Count how often each entity label occurs.

    Args:
        entities: Sequence of ``(text, label)`` pairs, e.g. the output of
            ``named_entity_recognition``; element ``[1]`` is the label.

    Returns:
        tuple[numpy.ndarray, list]: ``(counts, entity_types)`` where
        ``entity_types`` is the sorted list of distinct labels and ``counts``
        their frequencies in the same order. On error the error is printed
        and ``(np.zeros(len(entities)), [])`` is returned.
    """
    label_counts = defaultdict(int)
    try:
        for entity in entities:
            label_counts[entity[1]] += 1
        entity_types = sorted(label_counts)
        counts = np.array([label_counts[label] for label in entity_types])
    except Exception as exc:
        print(f"Error in process_entities: {str(exc)}")
        return np.zeros(len(entities)), []
    return counts, entity_types
# Summarise part-of-speech tagging results
def process_pos_tags(pos_tags):
    """Count how often each tag occurs.

    Each element of ``pos_tags`` is indexed with ``[1]`` to obtain its tag
    (e.g. NN, VB), so the expected input is a sequence of ``(word, tag)``
    pairs.
    NOTE(review): ``pos_tagging`` in this file returns three parallel lists,
    not pairs — confirm what callers actually pass here.

    Returns:
        tuple[numpy.ndarray, list]: ``(counts, pos_types)`` where ``pos_types``
        is the sorted list of distinct tags and ``counts`` their frequencies
        in the same order. On error the error is printed and
        ``(np.zeros(len(pos_tags)), [])`` is returned.
    """
    tally = defaultdict(int)
    try:
        for item in pos_tags:
            tally[item[1]] += 1
        ordered_tags = sorted(tally.keys())
        return np.array([tally[tag] for tag in ordered_tags]), ordered_tags
    except Exception as err:
        print(f"Error in process_pos_tags: {str(err)}")
        return np.zeros(len(pos_tags)), []
# Function: build a document vector
def get_document_vector(words, model=word2vec_model):
    """Average the vectors of the in-vocabulary ``words`` into one vector.

    Args:
        words: Sequence of tokens; tokens not in ``model`` are skipped.
        model: Word-vector lookup supporting ``word in model``,
            ``model[word]`` and ``model.vector_size``. Defaults to the
            module-level ``word2vec_model`` (bound when this function is
            defined).

    Returns:
        numpy.ndarray: Element-wise mean of the known words' vectors, or a
        zero vector of length ``model.vector_size`` when no word is known or
        an error occurs (the error is printed).
    """
    try:
        vectors = [model[token] for token in words if token in model]
        if not vectors:
            return np.zeros(model.vector_size)
        return np.mean(vectors, axis=0)
    except Exception as e:
        print(f"Error in get_document_vector for words: {words[:5]}... Error: {str(e)}")
        return np.zeros(model.vector_size)
# Function: compute a sentiment score
def get_sentiment_score(text, analyzer=None):
    """Return a signed sentiment score for ``text``.

    The raw text is passed straight to the Hugging Face pipeline, which
    tokenizes it and truncates it to 512 tokens. The top prediction's
    confidence is mapped to a signed score (labels compared case-insensitively):

    * ``positive``                      -> ``+score``
    * ``neutral``                       -> ``0.0``
    * anything else (e.g. ``negative``) -> ``-score``

    NOTE(review): earlier versions mapped ``neutral`` to ``-score``. Any
    downstream model trained on those features should be re-checked.

    Args:
        text (str): Text to score.
        analyzer: Optional callable with the pipeline calling convention,
            ``analyzer(text, truncation=True, max_length=512)`` returning a
            list of ``{"label": ..., "score": ...}`` dicts. Defaults to the
            module-level ``sentiment_analyzer``.

    Returns:
        float: The signed score, or ``0.0`` if scoring fails (the error is
        printed).
    """
    try:
        if analyzer is None:
            analyzer = sentiment_analyzer
        result = analyzer(text, truncation=True, max_length=512)[0]
        label = str(result['label']).lower()
        if label == 'positive':
            score = result['score']
        elif label == 'neutral':
            # A confident "neutral" prediction carries no polarity; mapping it
            # to -score would make it look strongly negative.
            score = 0.0
        else:
            score = -result['score']
    except Exception as e:
        print(f"Error in get_sentiment_score for text: {str(text)[:50]}... Error: {str(e)}")
        score = 0.0
    return score
def get_stock_info(stock_codes, history_days=30):
    """Collect price windows for one stock and for four US market indices.

    The reference ("news") date is today's date from ``datetime.now()``. For
    every series, the rows before that date form the "previous" window and
    the rows after it form the "following" window (3 rows). Missing data is
    padded with ``-1``.

    Args:
        stock_codes: Sequence of stock-code strings. Only the first code is
            used. A falsy value or ``['']`` means "no stock": only the index
            windows are fetched and the stock windows are filled with ``-1``.
        history_days (int): Number of rows in each "previous" window.

    Returns:
        tuple: A 10-tuple of lists, each holding one window as a list of
        rows. Each row has the first 6 non-date columns of the source frame
        (expected to be 开盘/收盘/最高/最低/成交量/成交额)::

            (previous_stock, following_stock,
             previous_inx, previous_dj, previous_ixic, previous_ndx,
             following_inx, following_dj, following_ixic, following_ndx)
    """
    # `datetime` is not imported at the top of this file; import it here so
    # the function does not depend on a star import re-exporting it.
    from datetime import datetime

    following_days = 3
    news_date = datetime.now().strftime('%Y%m%d')

    previous_stock_history = []
    following_stock_history = []
    previous_stock_inx_index_history = []
    previous_stock_dj_index_history = []
    previous_stock_ixic_index_history = []
    previous_stock_ndx_index_history = []
    following_stock_inx_index_history = []
    following_stock_dj_index_history = []
    following_stock_ixic_index_history = []
    following_stock_ndx_index_history = []

    def _placeholder(n_rows):
        """Return an ``n_rows`` x 6 frame of -1 (used when data is missing)."""
        return pd.DataFrame({
            column: [-1] * n_rows
            for column in ('开盘', '收盘', '最高', '最低', '成交量', '成交额')
        })

    def _fit_rows(rows, n_rows):
        """Return exactly ``n_rows`` rows of the first 6 columns, padding with -1."""
        # Renumber from 0 first: reindex() matches index *labels*, so without
        # this a short slice taken from the middle or end of the history
        # (labels such as 998, 999) would be discarded and replaced entirely
        # by padding instead of being kept and padded.
        rows = rows.reset_index(drop=True)
        if len(rows) < n_rows:
            rows = rows.reindex(range(n_rows), fill_value=-1)
        return rows.iloc[:n_rows, :6]

    def process_history(stock_history, target_date, history_days=history_days,
                        following_days=following_days):
        """Split ``stock_history`` around ``target_date`` into (previous, following) frames."""
        if stock_history is None or stock_history.empty:
            return _placeholder(history_days), _placeholder(following_days)
        if 'date' not in stock_history.columns:
            print("'date' column not found in stock history. Returning empty data.")
            return _placeholder(history_days), _placeholder(following_days)

        # Work on a copy so the caller's frame is not modified in place.
        stock_history = stock_history.copy()
        stock_history['date'] = pd.to_datetime(stock_history['date'])
        target_date = pd.to_datetime(target_date)

        target_row = stock_history[stock_history['date'] == target_date]
        if target_row.empty:
            # No row on the target date: fall back to the closest available date.
            closest_date_index = (stock_history['date'] - target_date).abs().idxmin()
            target_date = stock_history.loc[closest_date_index, 'date']
            target_row = stock_history[stock_history['date'] == target_date]
        if target_row.empty:
            return _placeholder(history_days), _placeholder(following_days)

        target_pos = stock_history.index.get_loc(target_row.index[0])
        # The target row plus up to `history_days` rows before it ...
        previous_rows = stock_history.iloc[max(0, target_pos - history_days):target_pos + 1]
        # ... and up to `following_days` rows after it.
        following_rows = stock_history.iloc[target_pos + 1:target_pos + 1 + following_days]

        previous_rows = previous_rows.drop(columns=['date'])
        following_rows = following_rows.drop(columns=['date'])
        # NOTE(review): with enough history the previous slice has
        # history_days + 1 rows and only the first history_days are kept,
        # so the target-date row itself is dropped — confirm this is intended.
        return _fit_rows(previous_rows, history_days), _fit_rows(following_rows, following_days)

    def _append_index_windows():
        """Fetch the four index histories and append their windows to the result lists."""
        # Series id passed to get_stock_index_history -> (previous list, following list).
        targets = (
            (1, previous_stock_ndx_index_history, following_stock_ndx_index_history),
            (2, previous_stock_dj_index_history, following_stock_dj_index_history),
            (3, previous_stock_inx_index_history, following_stock_inx_index_history),
            (4, previous_stock_ixic_index_history, following_stock_ixic_index_history),
        )
        for index_id, previous_list, following_list in targets:
            index_history = get_stock_index_history("", news_date, index_id)
            previous_rows, following_rows = process_history(index_history, news_date, history_days)
            previous_list.append(previous_rows.values.tolist())
            following_list.append(following_rows.values.tolist())

    if not stock_codes or stock_codes == ['']:
        # No stock requested: index windows only, stock windows filled with -1.
        # Build independent row lists ([[-1] * 6] * n would alias one inner list).
        _append_index_windows()
        previous_stock_history.append([[-1] * 6 for _ in range(history_days)])
        following_stock_history.append([[-1] * 6 for _ in range(following_days)])
    else:
        for stock_code in stock_codes:
            stock_history = get_stock_history(stock_code.strip(), news_date)
            previous_rows, following_rows = process_history(stock_history, news_date)
            previous_stock_history.append(previous_rows.values.tolist())
            following_stock_history.append(following_rows.values.tolist())
            _append_index_windows()
            # Only the first stock code is used.
            break

    return (previous_stock_history, following_stock_history,
            previous_stock_inx_index_history, previous_stock_dj_index_history,
            previous_stock_ixic_index_history, previous_stock_ndx_index_history,
            following_stock_inx_index_history, following_stock_dj_index_history,
            following_stock_ixic_index_history, following_stock_ndx_index_history)
def lemmatized_entry(entry):
    """Clean ``entry`` and return its lemmatized word list.

    Thin wrapper around ``preprocessing_entry`` (step 1: entry aggregation
    and cleaning).

    Args:
        entry (str): Raw news text.

    Returns:
        list[str]: The lemmas produced by ``preprocessing_entry``.
    """
    return preprocessing_entry(entry)
# 1. Data cleaning
# 1.1 Merge the data
# 1.2 Remove noise
# 1.3 Lowercase
# 1.4 Remove stop words
# 1.5 Vocabulary correction and spell checking
# 1.6 Stemming and lemmatization
# Force GPU usage:
# spacy.require_gpu()
# Load the model.
# NOTE(review): this reloads "en_core_web_md" even though `nlp` was already
# loaded near the top of the file; the rebinding replaces that pipeline.
# spacy.prefer_gpu() has already been called above, so this load may place
# the pipeline on the GPU — check that before removing the reload.
nlp = spacy.load("en_core_web_md")
# Report whether spaCy is using the GPU (prefer_gpu() returns True if a GPU is active).
print("Is NPL GPU used Lemmatized:", spacy.prefer_gpu())
def preprocessing_entry(news_entry):
    """Run the text-cleaning pipeline on one news entry.

    Steps, in order:
      1.1 merge the entry into one text (``merge_text``),
      1.2 strip HTML tags, URLs and extra whitespace (``disposal_noise``),
      1.3 lowercase,
      1.4 keep only non-stop-word alphabetic / number-like tokens
          (``remove_stopwords``),
      1.6 lemmatize (``lemmatize_text``).
    Step 1.5 (spell checking) is currently disabled.

    Args:
        news_entry (str): Raw news text.

    Returns:
        list[str]: Lemmas of the remaining tokens.
    """
    cleaned = disposal_noise(merge_text(news_entry)).lower()
    return lemmatize_text(remove_stopwords(cleaned))
# 1.1 Merge the data
def merge_text(news_entry):
    """Step 1.1: merge the parts of a news entry into a single text.

    Currently a pass-through: the entry is returned unchanged (same object).
    """
    merged = news_entry
    return merged
# 1.2 Remove noise
def disposal_noise(text):
    """Step 1.2: strip HTML tags, URLs and redundant whitespace from ``text``.

    Tags and URLs are replaced by a space rather than deleted outright, so
    words on either side of e.g. ``<br>`` or ``</p><p>`` are not glued
    together ("fell.<br>Analysts" -> "fell. Analysts"). Every run of
    whitespace (including newlines, tabs and carriage returns) is then
    collapsed to a single space and the result is stripped.

    Args:
        text (str): Raw text.

    Returns:
        str: The cleaned text.
    """
    # HTML tags (non-greedy, so each tag is matched separately).
    text = re.sub(r'<.*?>', ' ', text)
    # URLs: "http..." (also covers "https...") or "www..." up to the next whitespace.
    text = re.sub(r'http\S+|www\S+', ' ', text)
    # Collapse all whitespace, which also removes \n, \t and \r.
    return re.sub(r'\s+', ' ', text).strip()
# 1.4 Remove stop words
def remove_stopwords(text):
    """Step 1.4: drop stop words and keep only word-like tokens.

    Uses the module-level spaCy pipeline ``nlp``. A token is kept when it is
    not a stop word and is either alphabetic (``is_alpha``) or number-like
    (``like_num``). The surviving token texts are joined with single spaces.
    """
    kept = []
    for token in nlp(text):
        if token.is_stop:
            continue
        if token.is_alpha or token.like_num:
            kept.append(token.text)
    return ' '.join(kept)
# 1.5 拼写检查
# 该函数用于检查输入文本的拼写错误,并修正
# def correct_spelling(text):
# corrected_text = []
# doc = nlp(text)
# for token in doc:
# if token.is_alpha: # 仅检查字母构成的单词
# corrected_word = spell.correction(token.text)
# if corrected_word is None:
# # 如果拼写检查没有建议,保留原始单词
# corrected_word = token.text
# corrected_text.append(corrected_word)
# else:
# corrected_text.append(token.text)
# return " ".join(corrected_text)
# 1.6 Stemming and lemmatization
def lemmatize_text(text):
    """Step 1.6: lemmatize ``text`` with the module-level spaCy ``nlp``.

    Returns:
        list[str]: ``token.lemma_`` for every token that is neither
        punctuation nor whitespace and is alphabetic or number-like, in
        document order.
    """
    return [
        tok.lemma_
        for tok in nlp(text)
        if not (tok.is_punct or tok.is_space) and (tok.is_alpha or tok.like_num)
    ]
# 2. Data augmentation and feature extraction
# 2.1 Part-of-speech tagging
# Tag each word with its part of speech (noun, verb, adjective, ...); this
# helps later syntactic analysis and information extraction.
# Tools: spaCy or NLTK
# 2.2 Named entity recognition (NER)
# Identify named entities such as people, places and organisations, and
# extract them.
# Tools: spaCy or Stanford NER
# 2.3 Syntactic and dependency parsing
# Analyse sentence structure and the relations between words (e.g.
# subject-verb-object).
# Tools: spaCy or NLTK
# 2 Feature extraction
# Force GPU usage:
#spacy.require_gpu()
# Load the model.
# NOTE(review): third load of "en_core_web_md" in this file. It is the last
# visible rebinding of `nlp`, so it is the pipeline used by every function
# here that reads `nlp` at call time.
nlp = spacy.load("en_core_web_md")
# Report whether spaCy is using the GPU (prefer_gpu() returns True if a GPU is active).
print("Is NPL GPU used Enchance_text.py:", spacy.prefer_gpu())
# 2.3 Syntactic and dependency parsing
def dependency_parsing(text):
    """Return ``(token, dependency label, head token)`` triples for ``text``.

    Punctuation and stop words are skipped. Uses the module-level spaCy
    pipeline ``nlp``.

    Returns:
        list[tuple[str, str, str]]: ``(token.text, token.dep_, token.head.text)``
        in document order.
    """
    # To keep only specific relations, additionally filter on ``tok.dep_``
    # (e.g. 'nsubj' = nominal subject, 'dobj' = direct object).
    return [
        (tok.text, tok.dep_, tok.head.text)
        for tok in nlp(text)
        if not (tok.is_punct or tok.is_stop)
    ]
def processing_entry(entry):
    """Run the full text-feature pipeline on one news entry.

    Args:
        entry (str): Raw news text.

    Returns:
        tuple: ``(lemmatized_entry, pos_tag, ner, dependency_parsed,
        sentiment_score)``. The lemmas come from ``preprocessing_entry`` on
        the raw entry; every other feature is computed on the
        noise-stripped text returned by ``disposal_noise``.
    """
    lemmas = preprocessing_entry(entry)
    cleaned_text = disposal_noise(entry)
    return (
        lemmas,
        pos_tagging(cleaned_text),
        named_entity_recognition(cleaned_text),
        dependency_parsing(cleaned_text),
        get_sentiment_score(cleaned_text),
    )