import subprocess
import sys

def install(package):
    """Install a package into the current interpreter's environment."""
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

# Install required libraries. Keys are import names, values are pip package
# names (they differ for rank_bm25 / sentence_transformers, so checking the
# pip name with __import__ would trigger a reinstall on every run).
REQUIREMENTS = {
    "litellm": "litellm",
    "gradio": "gradio",
    "datasets": "datasets",
    "rank_bm25": "rank_bm25",
    "sentence_transformers": "sentence-transformers",
}
for import_name, pip_name in REQUIREMENTS.items():
    try:
        __import__(import_name)
    except ImportError:
        install(pip_name)
from litellm import completion
import os

# Supply your own Groq API key; keep real keys out of source control.
os.environ["GROQ_API_KEY"] = "<your-groq-api-key>"

# Smoke test: one round trip confirms the key and model name work.
response = completion(
    model="groq/llama3-8b-8192",
    messages=[
        {"role": "user", "content": "hello from litellm"}
    ],
)
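# litellm responses are OpenAI-compatible, so the reply can be read out
# with dict-style indexing to confirm the call succeeded:
print(response['choices'][0]['message']['content'])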
from datasets import load_dataset

# Load a small corpus: the first 10 articles from a public HF dataset.
dataset = load_dataset("hugginglearners/russia-ukraine-conflict-articles")
docs = [item['articles'] for item in dataset['train'].select(range(10))]
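# Optional sanity check: docs should now hold 10 raw article strings.
# print(len(docs), type(docs[0]))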
def chunk_document(doc: str, doc_id: int, desired_chunk_size: int = 100, max_chunk_size: int = 3000):
    """Split one document into line-aligned chunks of at least
    desired_chunk_size characters, truncated at max_chunk_size.
    Yields (doc_id, chunk_number, text) tuples."""
    chunk = ''
    chunk_number = 0
    for line in doc.splitlines():
        chunk += line + '\n'
        if len(chunk) >= desired_chunk_size:
            yield (doc_id, chunk_number, chunk[:max_chunk_size])
            chunk = ''
            chunk_number += 1
    if chunk:  # emit any trailing partial chunk
        yield (doc_id, chunk_number, chunk)

def chunk_documents(docs: list[str], desired_chunk_size: int = 100, max_chunk_size: int = 3000):
    """Chunk every document, tagging each chunk with its source doc index."""
    chunks = []
    for doc_id, doc in enumerate(docs):
        chunks.extend(chunk_document(doc, doc_id, desired_chunk_size, max_chunk_size))
    return chunks
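# Illustrative check: each chunk is a (doc_id, chunk_number, text) tuple.
# >>> [(d, n, len(t)) for d, n, t in chunk_documents(["a\n" * 5], desired_chunk_size=4)]
# [(0, 0, 4), (0, 1, 4), (0, 2, 2)]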
import numpy as np
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
class Retriever:
    """Retrieves the most relevant chunks via BM25, SBERT embeddings,
    or a weighted hybrid of the two."""

    def __init__(self, docs: list[str]):
        self.chunks = chunk_documents(docs)
        self.docs = [chunk[2] for chunk in self.chunks]
        tokenized_docs = [doc.lower().split() for doc in self.docs]
        self.bm25 = BM25Okapi(tokenized_docs)
        self.sbert = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
        self.doc_embeddings = self.sbert.encode(self.docs)

    def get_docs(self, query, method="bm25", n=3):
        if method == "bm25":
            scores = self._get_bm25_scores(query)
        elif method == "sbert":
            scores = self._get_semantic_scores(query)
        elif method == "hybrid":
            # BM25 scores are unbounded while cosine similarity lies in
            # [-1, 1], so scale both to [0, 1] before mixing them.
            bm25_scores = self._normalize(self._get_bm25_scores(query))
            semantic_scores = self._normalize(self._get_semantic_scores(query))
            scores = 0.3 * bm25_scores + 0.7 * semantic_scores
        else:
            raise ValueError("Invalid method. Choose 'bm25', 'sbert', or 'hybrid'.")
        sorted_indices = np.argsort(scores)[::-1]
        # Return the top n chunks along with their source information.
        return [(self.chunks[i][0], self.chunks[i][1], self.docs[i]) for i in sorted_indices[:n]]

    @staticmethod
    def _normalize(scores):
        """Min-max scale scores to [0, 1]; constant arrays map to zeros."""
        spread = scores.max() - scores.min()
        return (scores - scores.min()) / spread if spread > 0 else np.zeros_like(scores)

    def _get_bm25_scores(self, query):
        tokenized_query = query.lower().split()
        return self.bm25.get_scores(tokenized_query)

    def _get_semantic_scores(self, query):
        query_embedding = self.sbert.encode(query)
        scores = torch.cosine_similarity(
            torch.tensor(query_embedding).unsqueeze(0),
            torch.tensor(self.doc_embeddings),
            dim=1,
        )
        return scores.numpy()
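# Example usage (commented out so the script stays side-effect free; the
# query string is purely illustrative):
# retriever = Retriever(docs)
# doc_id, chunk_id, text = retriever.get_docs("ceasefire talks", method="hybrid", n=1)[0]
# print(f"Doc {doc_id}, chunk {chunk_id}: {text[:200]}")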
class QuestionAnsweringBot:
    PROMPT = '''\
You are a helpful assistant that can answer questions.
Rules:
- Reply with the answer only and nothing but the answer.
- Say 'I don't know(((' if you don't know the answer.
- Use the provided context.
'''

    def __init__(self, docs):
        self.retriever = Retriever(docs)

    def answer_question(self, question: str, method: str = "bm25") -> str:
        context_with_indices = self.retriever.get_docs(question, method=method)
        if not context_with_indices:
            return "I don't know((("
        # Build the context for the model, labelling each chunk's origin.
        context = "\n".join(
            f"Doc {doc_id}, Chunk {chunk_id}: {text}"
            for doc_id, chunk_id, text in context_with_indices
        )
        messages = [
            {"role": "system", "content": self.PROMPT},
            {"role": "user", "content": f"Context: {context}\nQuestion: {question}"},
        ]
        try:
            response = completion(
                model="groq/llama3-8b-8192",
                messages=messages,
            )
            answer = response['choices'][0]['message']['content']
            # Append the source chunks so answers stay attributable.
            sources = [f"Doc {doc_id}: Chunk {chunk_id}" for doc_id, chunk_id, _ in context_with_indices]
            return f"{answer} [{'; '.join(sources)}]"
        except Exception as e:
            return f"Error: {e}"
# question = "Tell about war"
docs = docs
# bot = QuestionAnsweringBot(docs)
# answer = bot.answer_question(question)
# print(f'Q: {question}')
# print(f'A: {answer}')
import gradio as gr

# Build the bot once at startup; constructing it inside the handler would
# re-chunk and re-encode the whole corpus on every question.
bot = QuestionAnsweringBot(docs)

def answer_question_with_method(query, method):
    return bot.answer_question(query, method=method)

# Create the Gradio interface.
demo = gr.Interface(
    fn=answer_question_with_method,
    inputs=[
        gr.Textbox(label="Your Question"),
        gr.Dropdown(
            choices=["bm25", "sbert", "hybrid"],
            value="hybrid",
            label="Select Retrieval Method",
        ),
    ],
    outputs="text",
)
demo.launch()
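# Gradio's launch() accepts share=True to expose a temporary public URL,
# which is handy when running in a hosted notebook:
# demo.launch(share=True)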