File size: 5,451 Bytes
232f6b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os
import re
from openai import OpenAI
from langchain_openai import ChatOpenAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import UnstructuredWordDocumentLoader as DocxLoader
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI
from pydantic import BaseModel
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
import time

def clean_response(response):
    """Normalize an LLM answer string before returning it to the client.

    Trims surrounding whitespace, strips any quote characters wrapping the
    text, collapses runs of newlines into one, and drops literal ``\\n``
    escape sequences the model sometimes emits.
    """
    # Strip outer whitespace first so edge quotes sit at the string ends.
    text = re.sub(r'^["\']+|["\']+$', '', response.strip())
    # Squash newline runs, then remove any literal backslash-n sequences.
    return re.sub(r'\n+', '\n', text).replace('\\n', '')

# FastAPI application instance for the TIET chatbot service.
app = FastAPI()

# Allow cross-origin requests from any origin so a separately hosted
# frontend can call this API. NOTE(review): "*" origins with
# allow_credentials=True is permissive — tighten for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# LLM client; requires the OPENAI_API_KEY environment variable to be set
# (if missing, openai_api_key is None and API calls will fail at request time).
openai_api_key = os.environ.get('OPENAI_API_KEY')
llm = ChatOpenAI(
    api_key=openai_api_key,
    model_name="gpt-4-turbo-preview",  # or "gpt-3.5-turbo" for a more economical option
    temperature=0.7
)

@app.get("/")
def read_root():
    """Health-check endpoint confirming the service is reachable."""
    payload = {"Hello": "World"}
    return payload

class Query(BaseModel):
    """Request body for POST /chat: the user's question text."""
    query_text: str

# System prompt for the retrieval chain. {context} is filled with retrieved
# document chunks and {input} with the user's question. The template text is
# runtime behavior — do not edit casually.
prompt = ChatPromptTemplate.from_template(
"""
You are a helpful assistant designed specifically for the Thapar Institute of Engineering and Technology (TIET), a renowned technical college. Your task is to answer all queries related to TIET. Every response you provide should be relevant to the context of TIET. If a question falls outside of this context, please decline by stating, 'Sorry, I cannot help with that.' If you do not know the answer to a question, do not attempt to fabricate a response; instead, politely decline.
You may elaborate on your answers slightly to provide more information, but avoid sounding boastful or exaggerating. Stay focused on the context provided.
If the query is not related to TIET or falls outside the context of education, respond with:
        "Sorry, I cannot help with that. I'm specifically designed to answer questions about the Thapar Institute of Engineering and Technology. 
        For more information, please contact at our toll-free number: 18002024100 or E-mail us at admissions@thapar.edu
<context>
{context}
</context>
Question: {input}  
"""
)

def vector_embedding():
    """Build the FAISS vector store from ./data/Data.docx and persist it.

    Loads the source document, splits it into overlapping chunks, embeds
    them with the shared BGE model, and saves the index to ./vectors_db.

    Returns:
        dict: {"response": <status message>} — errors are reported in the
        message rather than raised, so the /setup endpoint never 500s here.
    """
    try:
        file_path = "./data/Data.docx"
        if not os.path.exists(file_path):
            print(f"The file {file_path} does not exist.")
            return {"response": "Error: Data file not found"}

        loader = DocxLoader(file_path)
        documents = loader.load()
        print(f"Loaded document: {file_path}")

        # Overlap preserves context across chunk boundaries for retrieval.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
        chunks = text_splitter.split_documents(documents)
        print(f"Created {len(chunks)} chunks.")

        # Reuse the shared embedding factory so index-time embeddings are
        # guaranteed to match the model used at query time in /chat.
        db = FAISS.from_documents(chunks, get_embeddings())
        db.save_local("./vectors_db")

        print("Vector store created and saved successfully.")
        return {"response": "Vector Store DB Is Ready"}

    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return {"response": f"Error: {str(e)}"}

def get_embeddings():
    """Construct the BGE embedding model shared by indexing and retrieval.

    Embeddings are L2-normalized (``normalize_embeddings=True``) so FAISS
    similarity scores are comparable across documents.
    """
    return HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-base-en",
        encode_kwargs={'normalize_embeddings': True},
    )

@app.post("/chat")  # Changed from /anthropic to /chat
def read_item(query: Query):
    """Answer a user question via RAG over the TIET document index.

    Loads the persisted FAISS index, retrieves relevant chunks, runs the
    LLM chain, and returns the cleaned answer string. If the index is
    missing, returns a dict prompting the caller to run /setup first.
    """
    try:
        embeddings = get_embeddings()
        vectors = FAISS.load_local("./vectors_db", embeddings, allow_dangerous_deserialization=True)
    except Exception as e:
        print(f"Error loading vector store: {str(e)}")
        return {"response": "Vector Store Not Found or Error Loading. Please run /setup first."}

    prompt1 = query.query_text
    if not prompt1:
        return "No Query Found"

    # perf_counter measures wall-clock time; process_time counts CPU time
    # only and would exclude the time spent blocked on the OpenAI API call,
    # making the printed latency meaningless for this I/O-bound request.
    start = time.perf_counter()
    document_chain = create_stuff_documents_chain(llm, prompt)
    retriever = vectors.as_retriever()
    retrieval_chain = create_retrieval_chain(retriever, document_chain)
    response = retrieval_chain.invoke({'input': prompt1})
    print("Response time:", time.perf_counter() - start)

    # Normalize the raw model answer before returning it to the client.
    cleaned_response = clean_response(response['answer'])
    print("Cleaned response:", repr(cleaned_response))
    return cleaned_response

@app.get("/setup")
def setup():
    """Build (or rebuild) the FAISS index from the source document."""
    result = vector_embedding()
    return result

# Uncomment this to check if the API key is set
# print(f"API key set: {'Yes' if os.environ.get('OPENAI_API_KEY') else 'No'}")

# Run a development server when executed directly (not under an ASGI host).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)