Update app.py
app.py CHANGED
@@ -30,7 +30,7 @@ async def initialize_gemini(file_path, question):
     else:
         return "Error: Unable to process the document. Please ensure the PDF file is valid."
 
-# Mistral Text Completion
+# Improved Mistral Text Completion
 class MistralModel:
     def __init__(self):
         self.model_path = "nvidia/Mistral-NeMo-Minitron-8B-Base"
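For orientation: the new `generate_text` method below relies on `self.tokenizer` and `self.device`, which are set on lines 37-38, outside this diff's context window. A minimal sketch of what the full `__init__` plausibly looks like under that assumption; the two hidden assignments are guesses inferred from how the names are used, not part of the commit:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class MistralModel:
    def __init__(self):
        self.model_path = "nvidia/Mistral-NeMo-Minitron-8B-Base"
        # Assumed: the diff hides lines 37-38, but generate_text uses both names.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.bfloat16  # bf16 halves memory vs. fp32 for an 8B model
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path, torch_dtype=self.dtype, device_map=self.device
        )
```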
@@ -39,10 +39,24 @@ class MistralModel:
         self.dtype = torch.bfloat16
         self.model = AutoModelForCausalLM.from_pretrained(self.model_path, torch_dtype=self.dtype, device_map=self.device)
 
-    def generate_text(self, prompt, max_length=
-
-
-
+    def generate_text(self, prompt, max_length=200):
+        # Improve the prompt for better context
+        enhanced_prompt = f"Question: {prompt}\n\nAnswer: Let's approach this step-by-step:\n1."
+        inputs = self.tokenizer.encode(enhanced_prompt, return_tensors='pt').to(self.model.device)
+
+        # Generate with more nuanced parameters
+        outputs = self.model.generate(
+            inputs,
+            max_length=max_length,
+            num_return_sequences=1,
+            no_repeat_ngram_size=3,
+            top_k=50,
+            top_p=0.95,
+            temperature=0.7,
+            do_sample=True
+        )
+
+        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 mistral_model = MistralModel()
 
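A note on the new decoding parameters: `do_sample=True` replaces greedy decoding with sampling, `temperature=0.7` softens the next-token distribution, `top_k=50` and `top_p=0.95` restrict it to the 50 most likely tokens and then the 95% probability nucleus, and `no_repeat_ngram_size=3` forbids any trigram from repeating. One caveat: `max_length` counts the prompt tokens as well, so a long enhanced prompt can leave only a few tokens for the answer; `max_new_tokens` bounds just the continuation. A self-contained sketch of that variant, an assumption about intent rather than what this commit does:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "nvidia/Mistral-NeMo-Minitron-8B-Base"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="auto"
)

prompt = "Question: What is nucleus sampling?\n\nAnswer: Let's approach this step-by-step:\n1."
inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

outputs = model.generate(
    inputs,
    max_new_tokens=200,      # bound the answer alone, regardless of prompt length
    do_sample=True,          # stochastic decoding instead of greedy search
    temperature=0.7,         # soften the next-token distribution
    top_k=50,                # keep only the 50 most likely tokens...
    top_p=0.95,              # ...then trim to the 95% probability nucleus
    no_repeat_ngram_size=3,  # block any repeated trigram
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```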
|
|
@@ -54,14 +68,14 @@ async def process_input(file, question):
 
 # Gradio Interface
 with gr.Blocks() as demo:
-    gr.Markdown("# PDF Question Answering and Text Completion System")
+    gr.Markdown("# Enhanced PDF Question Answering and Text Completion System")
 
-    input_file = gr.File(label="Upload PDF File")
+    input_file = gr.File(label="Upload PDF File (Optional)")
     input_question = gr.Textbox(label="Ask a question or provide a prompt")
     process_button = gr.Button("Process")
 
-    output_text_gemini = gr.Textbox(label="Answer - Gemini")
-    output_text_mistral = gr.Textbox(label="Answer - Mistral")
+    output_text_gemini = gr.Textbox(label="Answer - Gemini (PDF-based if file uploaded)")
+    output_text_mistral = gr.Textbox(label="Answer - Mistral (General knowledge)")
 
     process_button.click(
         fn=process_input,
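The diff ends mid-call, after `fn=process_input,`. In Gradio's `Button.click` API, the remaining keyword arguments list the input components (passed to the function in order) and the output components (filled from the returned tuple in order). A runnable sketch of the likely wiring, with the Gemini and Mistral calls stubbed out; the `inputs=`/`outputs=` lists are an assumption, since they fall past the cutoff:

```python
import gradio as gr

def process_input(file, question):
    # Stand-ins for the app's real Gemini and Mistral handlers.
    gemini_answer = f"(Gemini stub) file={getattr(file, 'name', None)}, question={question}"
    mistral_answer = f"(Mistral stub) question={question}"
    return gemini_answer, mistral_answer

with gr.Blocks() as demo:
    gr.Markdown("# Enhanced PDF Question Answering and Text Completion System")
    input_file = gr.File(label="Upload PDF File (Optional)")
    input_question = gr.Textbox(label="Ask a question or provide a prompt")
    process_button = gr.Button("Process")
    output_text_gemini = gr.Textbox(label="Answer - Gemini (PDF-based if file uploaded)")
    output_text_mistral = gr.Textbox(label="Answer - Mistral (General knowledge)")

    # Component values map to process_input's parameters in order;
    # the returned pair fills the two output textboxes in order.
    process_button.click(
        fn=process_input,
        inputs=[input_file, input_question],
        outputs=[output_text_gemini, output_text_mistral],
    )

demo.launch()
```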