import os

import gradio as gr
import pandas as pd
import torch
import torch.nn.functional as F
from groq import Groq
from sentence_transformers import SentenceTransformer
# This model supports two prompts: "s2p_query" and "s2s_query" for sentence-to-passage and sentence-to-sentence tasks, respectively.
# They are defined in `config_sentence_transformers.json`
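# For example (per the model card), encoding a query for passage retrieval
# uses the "s2p_query" prompt:
#   model.encode(question, prompt_name="s2p_query")
# while sentence-to-sentence similarity would use prompt_name="s2s_query".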
def answer_pipeline(question, progress=gr.Progress()):
    # Prefer the GPU when available; otherwise fall back to CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    progress(0, desc="Loading Model...")
    # The model is reloaded on every call; a long-running app could load it
    # once at module level instead.
    model = SentenceTransformer("dunzhang/stella_en_1.5B_v5", trust_remote_code=True).to(device)
# print("Loading Data")
progress(0.5, desc="Loading Data...")
corpus = pd.read_pickle('corpus.pkl')
docs = corpus["Text"].tolist()
ids = corpus["Document ID"].tolist()
doc_embeddings = torch.Tensor(corpus["Embedding"].tolist())
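    # A sketch of how such a corpus could be built offline (hypothetical --
    # the actual preprocessing script is not part of this file):
    #   corpus = pd.DataFrame({"Document ID": doc_ids, "Text": texts})
    #   corpus["Embedding"] = list(model.encode(texts))
    #   corpus.to_pickle("corpus.pkl")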
    question = question.strip()
# question = """What is a significant push factor for entrepreneurship in the UAE?
# A. Strict intellectual property regulations
# B. High taxation
# C. Supportive legislation and policies
# D. Limited access to funding
# """
    # Embed the query with the sentence-to-passage prompt, then rank every
    # document by cosine similarity to it.
    query_embedding = torch.Tensor(model.encode(question, prompt_name="s2p_query"))
    cos_sims = F.cosine_similarity(doc_embeddings, query_embedding, dim=1)
    sorted_indices = torch.argsort(cos_sims, descending=True)
    sorted_documents_with_scores = [
        (docs[idx.item()], cos_sims[idx].item(), ids[idx.item()])
        for idx in sorted_indices
    ]
    # Optionally inspect the best-matching passages while debugging:
    # for doc, score, doc_id in sorted_documents_with_scores[:3]:
    #     print(f"[{doc_id}] score={score:.4f}\n{doc}")
    # Join the top-N passages into one context block, using the same dashed
    # separator that the system prompt describes.
    top_n = 10
    context = "\n---------------\n".join([x[0] for x in sorted_documents_with_scores[:top_n]])
    # The Groq API key is read from an environment variable (e.g. a Space
    # secret) named 'Groq'.
    client = Groq(api_key=os.getenv('Groq'))
system_instruction = """You are an expert research assistant specializing in educational programs. Your task is to generate concise, informative, and contextually relevant responses based on the provided contexts.
Guidelines:
- Analyze the given context(s) carefully and provide responses strictly based on the information available or your internal knowledge.
- If the answer is not clear from the provided context or your internal knowledge, respond with "I don't know."
- **Provide Detailed and Informative Answers**: Avoid brief answers. Include key details to make the response comprehensive and useful. For example, for "What study program offers XYZ?", provide both the program name and a brief description.
- Provide a direct response without any prelude or introductory text.
- Maintain the exact formatting provided below:
The input will contain contexts followed by a query, separated by a line of dashes ("\n---------------\n"):
Context:
{
---Context 1---
"\n---------------\n"
---Context 2---
"\n---------------\n"
...
---Context n---
"\n---------------\n"
}
Query:
{Question}
Response:
[Start your response directly here without any prelude. Adhere strictly to this structure and guidelines]
"""
content = f"""
Context:
{context}
Query:
{question}
Response:
"""
    progress(0.7, desc="Generating Output...")
    # Send the assembled context and question to the LLM.
    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": system_instruction,
            },
            {
                "role": "user",
                "content": content,
            },
        ],
        model="llama3-70b-8192",
        temperature=1,
    )
    return chat_completion.choices[0].message.content
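

# A minimal wiring sketch (an assumption -- the Space may connect this
# function to the UI elsewhere): serve answer_pipeline through a simple
# Gradio interface so the progress updates above show up in the browser.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=answer_pipeline,
        inputs=gr.Textbox(label="Question", lines=3),
        outputs=gr.Textbox(label="Answer"),
        title="Study RAG",
    )
    demo.launch()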