rafaaa2105 commited on
Commit
74d06fe
1 Parent(s): b2a42bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -23
app.py CHANGED
@@ -1,30 +1,129 @@
1
- import gradio as gr
2
- from autollm import AutoQueryEngine
3
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import spaces
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  @spaces.GPU
7
- def query_engine(llm_model, document, query):
8
- api_key = os.getenv("HUGGINGFACE_API_KEY")
9
- llm_api_base = "https://api-inference.huggingface.co/models/"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- content = str(document)
 
 
 
 
 
 
 
 
 
 
12
 
13
- query_engine = AutoQueryEngine.from_defaults(
14
- documents=[content],
15
- llm_model=llm_model,
16
- llm_api_base=llm_api_base,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  )
18
- response = query_engine.query(query)
19
- return response
20
-
21
- interface = gr.Blocks()
22
- with interface:
23
- gr.Markdown("# AutoQueryEngine Interface")
24
- llm_model = gr.Textbox(label="LLM Model", value="mistralai/Mixtral-8x7B-Instruct-v0.1")
25
- document = gr.File(label="Documents")
26
- query = gr.Textbox(label="Query")
27
- output = gr.Textbox(label="Output")
28
- query_btn = gr.Button("Query")
29
- query_btn.click(fn=query_engine, inputs=[llm_model, document, query], outputs=output)
30
- interface.launch()
 
 
 
1
  import os
2
+ import dotenv
3
+ import gradio as gr
4
+ import lancedb
5
+ import logging
6
+ from langchain.embeddings.cohere import CohereEmbeddings
7
+ from langchain.llms import Cohere
8
+ from langchain.prompts import PromptTemplate
9
+ from langchain.chains import RetrievalQA
10
+ from langchain.vectorstores import LanceDB
11
+ from langchain.document_loaders import TextLoader
12
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
13
+ from langchain_community.document_loaders import PyPDFLoader
14
+ import argostranslate.package
15
+ import argostranslate.translate
16
  import spaces
17
 
18
+
19
+ # Configuration Management
20
+ dotenv.load_dotenv(".env")
21
+ DB_PATH = "/tmp/lancedb"
22
+
23
+ COHERE_MODEL_NAME = "multilingual-22-12"
24
+ LANGUAGE_ISO_CODES = {
25
+ "English": "en",
26
+ "Hindi": "hi",
27
+ "Turkish": "tr",
28
+ "French": "fr",
29
+ }
30
+
31
+ # Logging Configuration
32
+ logging.basicConfig(level=logging.INFO)
33
+ logger = logging.getLogger(__name__)
34
+
35
  @spaces.GPU
36
+ def initialize_documents_and_embeddings(input_file_path):
37
+ file_extension = os.path.splitext(input_file_path)[1]
38
+ if file_extension == '.txt':
39
+ logger.info("txt file processing")
40
+ # Handle text file
41
+ loader = TextLoader(input_file_path)
42
+ documents = loader.load()
43
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
44
+ texts = text_splitter.split_documents(documents)
45
+ elif file_extension == '.pdf':
46
+ logger.info("pdf file processing")
47
+ # Handle PDF file
48
+ loader = PyPDFLoader(input_file_path)
49
+ texts = loader.load_and_split()
50
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
51
+ texts = text_splitter.split_documents(texts)
52
+ else:
53
+ raise ValueError("Unsupported file type. Supported files are .txt and .pdf only.")
54
+
55
+ embeddings = CohereEmbeddings(model=COHERE_MODEL_NAME)
56
+ return texts, embeddings
57
+
58
+ # Database Initialization
59
+ def initialize_database(texts, embeddings):
60
+ db = lancedb.connect(DB_PATH)
61
+ table = db.create_table(
62
+ "multiling-rag",
63
+ data=[{"vector": embeddings.embed_query("Hello World"), "text": "Hello World", "id": "1"}],
64
+ mode="overwrite",
65
+ )
66
+ return LanceDB.from_documents(texts, embeddings, connection=table)
67
 
68
+ # Translation Function
69
+ def translate_text(text, from_code, to_code):
70
+ try:
71
+ argostranslate.package.update_package_index()
72
+ available_packages = argostranslate.package.get_available_packages()
73
+ package_to_install = next(filter(lambda x: x.from_code == from_code and x.to_code == to_code, available_packages))
74
+ argostranslate.package.install_from_path(package_to_install.download())
75
+ return argostranslate.translate.translate(text, from_code, to_code)
76
+ except Exception as e:
77
+ logger.error(f"Error in translate_text: {str(e)}")
78
+ return "Translation error"
79
 
80
+
81
+ prompt_template = """Text: {context}
82
+
83
+ Question: {question}
84
+
85
+ Answer the question based on the text provided. If the text doesn't contain the answer, reply that the answer is not available."""
86
+ PROMPT = PromptTemplate(
87
+ template=prompt_template, input_variables=["context", "question"])
88
+
89
+ # Question Answering Function
90
+ def answer_question(question, input_language, output_language, db):
91
+ try:
92
+ input_lang_code = LANGUAGE_ISO_CODES[input_language]
93
+ output_lang_code = LANGUAGE_ISO_CODES[output_language]
94
+
95
+ question_in_english = translate_text(question, from_code=input_lang_code, to_code="en") if input_language != "English" else question
96
+ prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
97
+ qa = RetrievalQA.from_chain_type(llm=Cohere(model="command", temperature=0), chain_type="stuff", retriever=db.as_retriever(), chain_type_kwargs={"prompt": prompt}, return_source_documents=True)
98
+
99
+ answer = qa({"query": question_in_english})
100
+ result_in_english = answer["result"].replace("\n", "").replace("Answer:", "")
101
+
102
+ return translate_text(result_in_english, from_code="en", to_code=output_lang_code) if output_language != "English" else result_in_english
103
+ except Exception as e:
104
+ logger.error(f"Error in answer_question: {str(e)}")
105
+ return "An error occurred while processing your question. Please try again."
106
+
107
+ def setup_gradio_interface(db):
108
+ return gr.Interface(
109
+ fn=lambda question, input_language, output_language: answer_question(question, input_language, output_language, db),
110
+ inputs=[
111
+ gr.Textbox(lines=2, placeholder="Type your question here..."),
112
+ gr.Dropdown(list(LANGUAGE_ISO_CODES.keys()), label="Input Language"),
113
+ gr.Dropdown(list(LANGUAGE_ISO_CODES.keys()), label="Output Language")
114
+ ],
115
+ outputs="text",
116
+ title="Multilingual Chatbot",
117
+ description="Ask any question in your chosen language and get an answer in the language of your choice."
118
  )
119
+
120
+ # Main Function
121
+ def main():
122
+ INPUT_FILE_PATH = "healthy-diet-fact-sheet-394.pdf"
123
+ texts, embeddings = initialize_documents_and_embeddings(INPUT_FILE_PATH)
124
+ db = initialize_database(texts, embeddings)
125
+ iface = setup_gradio_interface(db)
126
+ iface.launch(share=True, debug=True)
127
+
128
+ if __name__ == "__main__":
129
+ main()