YU-XI committed on
Commit
17633c5
1 Parent(s): 8b06ac5

Update app.py

Files changed (1)
  1. app.py +32 -35
app.py CHANGED
@@ -2,7 +2,6 @@ import os
 import gradio as gr
 import asyncio
 from langchain_core.prompts import PromptTemplate
-from langchain_community.output_parsers.rail_parser import GuardrailsOutputParser
 from langchain_community.document_loaders import PyPDFLoader
 from langchain_google_genai import ChatGoogleGenerativeAI
 import google.generativeai as genai
@@ -11,9 +10,8 @@ import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 # Gemini PDF QA System
-async def initialize(file_path, question):
+async def initialize_gemini(file_path, question):
     genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
-    model = genai.GenerativeModel('gemini-pro')
     model = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.3)
     prompt_template = """Answer the question as precise as possible using the provided context. If the answer is
     not contained in the context, say "answer not available in context" \n\n
@@ -27,49 +25,48 @@ async def initialize(file_path, question):
         pages = pdf_loader.load_and_split()
         context = "\n".join(str(page.page_content) for page in pages[:30])
         stuff_chain = load_qa_chain(model, chain_type="stuff", prompt=prompt)
-        stuff_answer = await stuff_chain({"input_documents": pages, "question": question, "context": context}, return_only_outputs=True)
+        stuff_answer = await stuff_chain.acall({"input_documents": pages, "question": question, "context": context}, return_only_outputs=True)
         return stuff_answer['output_text']
     else:
         return "Error: Unable to process the document. Please ensure the PDF file is valid."
 
-async def pdf_qa(file, question):
-    answer = await initialize(file.name, question)
-    return answer
-
 # Mistral Text Completion
-def load_mistral_model():
-    model_path = "nvidia/Mistral-NeMo-Minitron-8B-Base"
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    dtype = torch.bfloat16
-    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dtype, device_map=device)
-    return tokenizer, model
+class MistralModel:
+    def __init__(self):
+        self.model_path = "nvidia/Mistral-NeMo-Minitron-8B-Base"
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.dtype = torch.bfloat16
+        self.model = AutoModelForCausalLM.from_pretrained(self.model_path, torch_dtype=self.dtype, device_map=self.device)
 
-def generate_text(prompt, max_length=50):
-    tokenizer, model = load_mistral_model()
-    inputs = tokenizer.encode(prompt, return_tensors='pt').to(model.device)
-    outputs = model.generate(inputs, max_length=max_length)
-    return tokenizer.decode(outputs[0])
+    def generate_text(self, prompt, max_length=50):
+        inputs = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
+        outputs = self.model.generate(inputs, max_length=max_length)
+        return self.tokenizer.decode(outputs[0])
 
-# Gradio Interface
-def pdf_qa_wrapper(file, question):
-    return asyncio.run(pdf_qa(file, question))
-
+mistral_model = MistralModel()
+
+# Combined function for both models
+async def process_input(file, question):
+    gemini_answer = await initialize_gemini(file.name, question)
+    mistral_answer = mistral_model.generate_text(question)
+    return gemini_answer, mistral_answer
+
+# Gradio Interface
 with gr.Blocks() as demo:
-    gr.Markdown("# Combined PDF QA and Text Completion System")
+    gr.Markdown("# PDF Question Answering and Text Completion System")
 
-    with gr.Tab("PDF Question Answering"):
-        input_file = gr.File(label="Upload PDF File")
-        input_question = gr.Textbox(label="Ask about the document")
-        output_text_gemini = gr.Textbox(label="Answer - GeminiPro")
-        pdf_qa_button = gr.Button("Ask Question")
+    input_file = gr.File(label="Upload PDF File")
+    input_question = gr.Textbox(label="Ask a question or provide a prompt")
+    process_button = gr.Button("Process")
 
-    with gr.Tab("Text Completion"):
-        input_prompt = gr.Textbox(label="Enter prompt for text completion")
-        output_text_mistral = gr.Textbox(label="Completed Text - Mistral")
-        complete_text_button = gr.Button("Complete Text")
+    output_text_gemini = gr.Textbox(label="Answer - Gemini")
+    output_text_mistral = gr.Textbox(label="Answer - Mistral")
 
-    pdf_qa_button.click(pdf_qa_wrapper, inputs=[input_file, input_question], outputs=output_text_gemini)
-    complete_text_button.click(generate_text, inputs=input_prompt, outputs=output_text_mistral)
+    process_button.click(
+        fn=process_input,
+        inputs=[input_file, input_question],
+        outputs=[output_text_gemini, output_text_mistral]
+    )
 
 demo.launch()
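
For reference, a minimal standalone sketch of the async QA path this commit switches to: the chain is awaited through its .acall() entry point rather than being called directly. The helper name ask_pdf(), the sample.pdf path, and the exact prompt wiring are illustrative assumptions (the real prompt setup sits in the diff lines not shown here), and GOOGLE_API_KEY must be set in the environment.

import asyncio
import os

import google.generativeai as genai
from langchain.chains.question_answering import load_qa_chain
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI

async def ask_pdf(file_path, question):
    # Mirrors initialize_gemini(): configure the API key, build the Gemini chat
    # model, load the PDF, and run a "stuff" QA chain over the first pages.
    genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
    model = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.3)
    prompt = PromptTemplate(
        template=(
            "Answer the question as precisely as possible using the provided context. "
            'If the answer is not contained in the context, say "answer not available in context".\n\n'
            "Context:\n{context}\n\nQuestion:\n{question}\n\nAnswer:"
        ),
        input_variables=["context", "question"],
    )
    pages = PyPDFLoader(file_path).load_and_split()
    chain = load_qa_chain(model, chain_type="stuff", prompt=prompt)
    # A chain instance is not itself awaitable; .acall() is its async entry point.
    result = await chain.acall(
        {"input_documents": pages[:30], "question": question},
        return_only_outputs=True,
    )
    return result["output_text"]

if __name__ == "__main__":
    print(asyncio.run(ask_pdf("sample.pdf", "What is this document about?")))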
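
One design note on the Gradio wiring: the old code needed the synchronous pdf_qa_wrapper() with asyncio.run() because it bound a plain function to the button, whereas the new process_input is an async def bound directly to process_button.click(); Gradio awaits coroutine handlers itself, and a tuple return maps onto the list of outputs. A small self-contained sketch of that pattern (the echo() handler and labels are made up for illustration):

import asyncio
import gradio as gr

# Gradio accepts coroutine functions as event handlers and awaits them itself,
# so no asyncio.run() wrapper is needed.
async def echo(text):
    await asyncio.sleep(0.1)  # stand-in for awaited work such as chain.acall()
    return text.upper(), f"{len(text)} characters"

with gr.Blocks() as demo:
    box = gr.Textbox(label="Input")
    btn = gr.Button("Process")
    out_a = gr.Textbox(label="Output A")
    out_b = gr.Textbox(label="Output B")
    btn.click(fn=echo, inputs=box, outputs=[out_a, out_b])

if __name__ == "__main__":
    demo.launch()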