import os
from typing import Literal

from dotenv import load_dotenv
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq

# Load environment variables
load_dotenv()
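
# Assumption: the Groq API key is supplied via the environment, e.g. a .env file
# containing a single line such as:
#   GROQ_API_KEY=<your-groq-api-key>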

# Initialize LLMs
def initialize_llms():
    """Initialize and return the LLM instances."""
    groq_api_key = os.getenv("GROQ_API_KEY")
    return {
        "rewrite_llm": ChatGroq(
            temperature=0.1,
            model="llama-3.3-70b-versatile",
            api_key=groq_api_key
        ),
        "step_back_llm": ChatGroq(
            temperature=0,
            model="Gemma2-9B-IT",
            api_key=groq_api_key
        )
    }
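
# Usage sketch (assumes GROQ_API_KEY is set and both model IDs are available on Groq):
#   llms = initialize_llms()
#   llms["rewrite_llm"].invoke("Hello").content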

# Certification classification
def classify_certification(
    query: str,
    llm: ChatGroq,
    certs_dir: str = "docs/processed"
) -> str:
    """
    Classify which certification a query is referring to.
    Returns the certification name or 'no certification mentioned'.
    """
    # The list of certifications is hardcoded; the names are expected to match
    # the directory names under certs_dir.
    available_certs = (
        "2BSvs, CertifHy - National Green Certificate (NGC), CertifHy - RFNBO, "
        "Certified_Hydrogen_Producer, GH2_Standard, Green_Hydrogen_Certification, "
        "ISCC CORSIA, ISCC EU (International Sustainability & Carbon Certification), "
        "ISCC PLUS, ISO_19880_Hydrogen_Quality, REDcert-EU, RSB, "
        "Scottish Quality Farm Assured Combinable Crops (SQC), TUV Rheinland H2.21, "
        "UK RTFO_regulation"
    )

    template = """
    You are an AI assistant classifying user queries based on the certification they are asking about in a RAG system.
    Classify the given query into one of the following certifications:
    - {available_certifications}
    Do not give any explanation; return only the name of the certification.
    Use the exact name of the certification as it appears in the directory.
    If the query refers to multiple certifications, return the most relevant one.
    If the query doesn't mention any certification, respond with "no certification mentioned".
    Original query: {original_query}
    Classification:
    """

    prompt = PromptTemplate(
        input_variables=["original_query", "available_certifications"],
        template=template
    )
    chain = prompt | llm
    response = chain.invoke({
        "original_query": query,
        "available_certifications": available_certs
    }).content.strip()
    return response
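
# Illustrative only (hypothetical query; the model's actual output may differ):
#   classify_certification("What purity does the GH2 Standard require?", llms["rewrite_llm"])
#   is expected to return "GH2_Standard".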

# Query specificity classification
def classify_query_specificity(
    query: str,
    llm: ChatGroq
) -> Literal["specific", "general", "too narrow"]:
    """
    Classify query specificity.
    Returns one of: 'specific', 'general', or 'too narrow'.
    """
    template = """
    You are an AI assistant classifying user queries based on their specificity for a RAG system.
    Classify the given query into one of:
    - "specific" → if it asks for exact values, certifications, or well-defined facts.
    - "general" → if it is broad and needs refinement for better retrieval.
    - "too narrow" → if it is very specific and might need broader context.
    DO NOT output explanations; return only one of: "specific", "general", or "too narrow".
    Original query: {original_query}
    Classification:
    """

    prompt = PromptTemplate(
        input_variables=["original_query"],
        template=template
    )
    chain = prompt | llm
    response = chain.invoke({"original_query": query}).content.strip().lower()
    # Keep only the first line in case the model appends extra text
    return response.split("\n")[0].strip()  # type: ignore
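
# Note: process_query below compares this label against exactly "specific",
# "general", and "too narrow"; any other output falls through to returning
# the original query unchanged.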

# Query refinement
def refine_query(
    query: str,
    llm: ChatGroq
) -> str:
    """Rewrite a query to be clearer and more detailed while keeping the original intent."""
    template = """
    You are an AI assistant that improves queries for retrieving precise certification and compliance data.
    Rewrite the query to be clearer while keeping the intent unchanged.
    Original query: {original_query}
    Refined query:
    """

    prompt = PromptTemplate(
        input_variables=["original_query"],
        template=template
    )
    chain = prompt | llm
    return chain.invoke({"original_query": query}).content

# Step-back query generation
def generate_step_back_query(
    query: str,
    llm: ChatGroq
) -> str:
    """Generate a broader step-back query to retrieve relevant background information."""
    template = """
    You are an AI assistant generating broader queries to improve retrieval context.
    Given the original query, generate a more general step-back query to retrieve relevant background information.
    Original query: {original_query}
    Step-back query:
    """

    prompt = PromptTemplate(
        input_variables=["original_query"],
        template=template
    )
    chain = prompt | llm
    return chain.invoke({"original_query": query}).content

# Main query processing pipeline
def process_query(
    original_query: str,
    llms: dict
) -> str:
    """
    Process a query through the full pipeline:
    1. Classify specificity
    2. Apply the appropriate refinement
    """
    specificity = classify_query_specificity(original_query, llms["rewrite_llm"])
    # Both "specific" and "general" queries are rewritten with the rewrite LLM;
    # "too narrow" queries are broadened with the step-back LLM.
    if specificity in ("specific", "general"):
        return refine_query(original_query, llms["rewrite_llm"])
    elif specificity == "too narrow":
        return generate_step_back_query(original_query, llms["step_back_llm"])
    return original_query
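
# Example usage (hypothetical query, assuming the LLMs initialize successfully):
#   llms = initialize_llms()
#   process_query("How does hydrogen certification work?", llms)
#   # a "general" query like this one would be routed through refine_query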

# Test setup
def test_hydrogen_certification_functions():
    # Initialize LLMs
    llms = initialize_llms()

    # Create a test directory with hydrogen certifications
    test_certs_dir = "docs/processed"
    os.makedirs(test_certs_dir, exist_ok=True)

    # Create some dummy certification folders
    hydrogen_certifications = [
        "GH2_Standard",
        "Certified_Hydrogen_Producer",
        "Green_Hydrogen_Certification",
        "ISO_19880_Hydrogen_Quality"
    ]
    for cert in hydrogen_certifications:
        os.makedirs(os.path.join(test_certs_dir, cert), exist_ok=True)

    # Test queries paired with their expected specificity labels
    test_queries = [
        ("What are the purity requirements in GH2 Standard?", "specific"),
        ("How does hydrogen certification work?", "general"),
        ("What's the exact ppm of CO2 allowed in ISO_19880_Hydrogen_Quality section 4.2?", "too narrow"),
        ("What safety protocols exist for hydrogen storage?", "general")
    ]

    print("=== Testing Certification Classification ===")
    for query, _ in test_queries:
        cert = classify_certification(query, llms["rewrite_llm"], test_certs_dir)
        print(f"Query: {query}\nClassification: {cert}\n")

    print("\n=== Testing Specificity Classification ===")
    for query, expected_type in test_queries:
        specificity = classify_query_specificity(query, llms["rewrite_llm"])
        print(f"Query: {query}\nExpected: {expected_type}, Got: {specificity}\n")

    print("\n=== Testing Full Query Processing ===")
    for query, _ in test_queries:
        processed = process_query(query, llms)
        print(f"Original: {query}\nProcessed: {processed}\n")

# Run the tests
if __name__ == "__main__":
    test_hydrogen_certification_functions()