Spaces:
Build error
Build error
File size: 6,736 Bytes
b8522d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
from datasets import load_dataset
import pandas as pd
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertForQuestionAnswering, BertConfig
from pymongo import MongoClient
class Database:
@staticmethod
def get_mongodb():
# MongoDB bağlantı bilgilerini döndürecek şekilde tanımlanmalıdır.
return 'mongodb://localhost:27017/', 'yeniDatabase', 'train'
@staticmethod
def get_input_texts():
# MongoDB bağlantı bilgilerini alma
mongo_url, db_name, collection_name = Database.get_mongodb()
# MongoDB'ye bağlanma
client = MongoClient(mongo_url)
db = client[db_name]
collection = db[collection_name]
# Sorguyu tanımlama
query = {"Prompt": {"$exists": True}}
# Sorguyu çalıştırma ve dökümanları çekme
cursor = collection.find(query, {"Prompt": 1, "_id": 0})
# Cursor'ı döküman listesine dönüştürme
input_texts_from_db = list(cursor)
# Input text'leri döndürme
return input_texts_from_db
@staticmethod
def get_output_texts():
# MongoDB bağlantı bilgilerini alma
mongo_url, db_name, collection_name = Database.get_mongodb()
# MongoDB'ye bağlanma
client = MongoClient(mongo_url)
db = client[db_name]
collection = db[collection_name]
# Sorguyu tanımlama
query = {"Response": {"$exists": True}}
# Sorguyu çalıştırma ve dökümanları çekme
cursor = collection.find(query, {"Response": 1, "_id": 0})
# Cursor'ı döküman listesine dönüştürme
output_texts_from_db = list(cursor)
# Input text'leri döndürme
return output_texts_from_db
@staticmethod
def get_average_prompt_token_length():
# MongoDB bağlantı bilgilerini alma
mongo_url, db_name, collection_name = Database.get_mongodb()
# MongoDB'ye bağlanma
client = MongoClient(mongo_url)
db = client[db_name]
collection = db[collection_name]
# Tüm dökümanları çekme ve 'prompt_token_length' alanını alma
docs = collection.find({}, {'Prompt_token_length': 1})
# 'prompt_token_length' değerlerini toplama ve sayma
total_length = 0
count = 0
for doc in docs:
if 'Prompt_token_length' in doc:
total_length += doc['Prompt_token_length']
count += 1
# Ortalama hesaplama
average_length = total_length / count if count > 0 else 0
return int(average_length)
# Tokenizer ve Modeli yükleme
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Encode işlemi
def tokenize_and_encode(train_df,doc):
input_texts_from_db = Database.get_input_texts()
output_texts_from_db= Database.get_output_texts()
input_texts = [doc["Prompt"] for doc in input_texts_from_db]
output_texts= [doc["Response"] for doc in output_texts_from_db]
encoded = tokenizer.batch_encode_plus(
#doc['Prompt'].tolist(),
#text_pair= doc['Response'].tolist(),
input_texts,
output_texts,
padding=True,
truncation=True,
max_length=100,
return_attention_mask=True,
return_tensors='pt'
)
return encoded
encoded_data=tokenize_and_encode()
class QA:
#buradaki verilerin değeri değiştirilmeli
def __init__(self, model_path: str):
self.max_seq_length = 384
self.doc_stride = 128
self.do_lower_case = False
self.max_query_length = 30
self.n_best_size = 3
self.max_answer_length = 30
self.version_2_with_negative = False
self.model, self.tokenizer = self.load_model(model_path)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model.to(self.device)
self.model.eval()
def load_model(self, model_path: str, do_lower_case=False):
config = BertConfig.from_pretrained(model_path)
tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=do_lower_case)
model = BertForQuestionAnswering.from_pretrained(model_path, from_tf=False, config=config)
return model, tokenizer
def extract_features_from_dataset(self, train_df):
def get_max_length(examples):
return {
'max_seq_length': max(len(e) for e in examples),
'max_query_length': max(len(q) for q in examples)
}
# Örnek bir kullanım
features = get_max_length(train_df)
return features
# Ortalama prompt token uzunluğunu al ve yazdır
average_length = Database.get_average_prompt_token_length()
print(f"Ortalama prompt token uzunluğu: {average_length}")
# QA sınıfını oluştur
qa = QA(model_path='bert-base-uncased')
#tensor veri setini koda entegre etme
"""# Tensor veri kümesi oluşturma
input_ids = encoded_data['input_ids']
attention_mask = encoded_data['attention_mask']
token_type_ids = encoded_data['token_type_ids']
labels = torch.tensor(data['Response'].tolist()) # Cevapları etiket olarak kullanın
# TensorDataset oluşturma
dataset = TensorDataset(input_ids, attention_mask, token_type_ids, labels)
# DataLoader oluşturma
batch_size = 16
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset),
batch_size=batch_size
)"""
#modelin için epoch sayısının tanımlaması
"""# Eğitim için optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)
# Eğitim döngüsü
model.train()
for epoch in range(3): # Örnek olarak 3 epoch
for batch in dataloader:
input_ids, attention_mask, token_type_ids, labels = [t.to(device) for t in batch]
optimizer.zero_grad()
outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, start_positions=labels, end_positions=labels)
loss = outputs.loss
loss.backward()
optimizer.step()
print(f"Epoch {epoch+1} loss: {loss.item()}")"""
#sonuçların sınıflandırılması
"""# Modeli değerlendirme aşamasına getirme
model.eval()
# Örnek tahmin
with torch.no_grad():
for batch in dataloader:
input_ids, attention_mask, token_type_ids, _ = [t.to(device) for t in batch]
outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
# Çıktıları kullanarak başlık, alt başlık ve anahtar kelimeler belirleyebilirsiniz
""" |