File size: 9,863 Bytes
523a4a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 |
import traceback
from datetime import datetime
from pathlib import Path
import os
import random
import string
import tempfile
import re
import io
import PyPDF2
import docx
from reportlab.pdfgen import canvas
from reportlab.lib.pagesizes import letter
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.enums import TA_JUSTIFY
from ai_config import n_of_questions, load_model, openai_api_key, convert_text_to_speech
from knowledge_retrieval import setup_knowledge_retrieval, generate_report
from ai_config import n_of_questions, openai_api_key, load_model
# --- Runtime settings -------------------------------------------------------
# The imported ``n_of_questions`` callable is invoked once and the name is
# rebound to its result, so from here on ``n_of_questions`` is the value.
n_of_questions = n_of_questions()

# Timestamp captured once at import time, in two display formats.
current_datetime = datetime.now()
human_readable_datetime = current_datetime.strftime("%B %d, %Y at %H:%M")
current_date = current_datetime.strftime("%Y-%m-%d")

# --- Model and retrieval chains ---------------------------------------------
# Any failure while loading the model or wiring the retrieval chains switches
# the module to "basic mode" (``knowledge_base_connected`` is False).
try:
    llm = load_model(openai_api_key)
    (
        interview_retrieval_chain,
        report_retrieval_chain,
        combined_retriever,
    ) = setup_knowledge_retrieval(llm)
except Exception as e:
    print(f"Error initializing the model or retrieval chain: {str(e)}")
    knowledge_base_connected = False
    print("Falling back to basic mode without knowledge base.")
else:
    knowledge_base_connected = True
    print("Successfully connected to the knowledge base.")

# --- Mutable interview state (shared through ``global`` statements) ---------
question_count, interview_history = 0, []
last_audio_path = None  # path of the most recently generated question audio
initial_audio_path = None  # path of the audio file for the initial question
language = None  # set from the first user response
def generate_random_string(length=5):
    """Return ``length`` random characters drawn from ASCII letters and digits.

    Characters are sampled with replacement via ``random.choices``; this is
    meant for unique-ish file-name suffixes, not for security tokens.
    """
    alphabet = string.ascii_letters + string.digits
    picked = random.choices(alphabet, k=length)
    return ''.join(picked)
def _replace_last_audio(new_audio_path):
    """Delete the previously generated question audio and remember the new file."""
    global last_audio_path
    if last_audio_path and os.path.exists(last_audio_path):
        os.remove(last_audio_path)
    last_audio_path = new_audio_path


def respond(message, history):
    """Advance the interview by one turn.

    The user's message is appended to the module-level ``interview_history``
    and ``question_count`` is incremented. While fewer than
    ``n_of_questions`` turns have passed, the next question is produced
    (through the interview retrieval chain when the knowledge base is
    connected, otherwise a fixed fallback prompt); afterwards the report is
    generated instead. The reply is written into the last ``history`` entry.

    Args:
        message: The user's input: either the text itself, or a list whose
            last item is the text or a list with the text as first element.
        history: Chat history as a list of ``[user, bot]`` pairs. A value
            that is not a list is replaced by a new list.

    Returns:
        tuple: ``(history, audio_path)`` where ``audio_path`` is the path of
        the generated mp3 as ``str``, or ``None`` when no audio was produced
        or an error occurred (the error and traceback are printed).
    """
    global question_count, interview_history, combined_retriever, language
    global initial_audio_path, interview_retrieval_chain, report_retrieval_chain

    if not isinstance(history, list):
        history = []
    if not history or not history[-1]:
        history.append(["", ""])

    # Extract the actual message text. Empty containers map to "" instead of
    # raising IndexError before the error handling below is even entered.
    if isinstance(message, list):
        if not message:
            message = ""
        else:
            latest = message[-1]
            if isinstance(latest, list):
                message = latest[0] if latest else ""
            else:
                message = latest

    question_count += 1
    interview_history.append(f"Q{question_count}: {message}")
    history_str = "\n".join(interview_history)

    try:
        speech_file_path = None
        if knowledge_base_connected:
            if question_count == 1:
                # Capture the language from the first response
                language = message.strip().lower()
                # Reinitialize the chains with the new language
                interview_retrieval_chain, report_retrieval_chain, combined_retriever = setup_knowledge_retrieval(
                    llm, language)

            if question_count < n_of_questions:
                result = interview_retrieval_chain.invoke({
                    "input": f"Based on the patient's statement: '{message}', what should be the next question?",
                    "history": history_str,
                    "question_number": question_count + 1,
                    "language": language
                })
                question = result.get("answer", f"Can you tell me more about that? (in {language})")
            else:
                question = generate_report(report_retrieval_chain, interview_history, language)

            # NOTE(review): a non-empty report also reaches this branch, so
            # audio is generated for it as well, although the original
            # comments spoke of skipping audio for the report -- confirm the
            # intended behavior before changing it.
            if question:
                random_suffix = generate_random_string()
                speech_file_path = Path(__file__).parent / f"question_{question_count}_{random_suffix}.mp3"
                convert_text_to_speech(question, speech_file_path)
                print(f"Question {question_count} saved as audio at {speech_file_path}")
                _replace_last_audio(speech_file_path)
        else:
            # Fallback mode without knowledge base. The first answer still
            # names the interview language; without capturing it here the
            # prompt below would read "(in None)".
            if question_count == 1 and isinstance(message, str) and message.strip():
                language = message.strip().lower()
            question = f"Can you elaborate on that? (in {language})"
            if question_count < n_of_questions:
                speech_file_path = Path(__file__).parent / f"question_{question_count}.mp3"
                convert_text_to_speech(question, speech_file_path)
                print(f"Question {question_count} saved as audio at {speech_file_path}")
                _replace_last_audio(speech_file_path)

        history[-1][1] = f"{question}"

        # Remove the initial question audio file after the first user
        # response; the stored path is cleared even if the file is already
        # gone so it is not re-checked on every turn.
        if initial_audio_path:
            if os.path.exists(initial_audio_path):
                os.remove(initial_audio_path)
            initial_audio_path = None

        return history, str(speech_file_path) if speech_file_path else None
    except Exception as e:
        # Top-level boundary for the chat callback: report and degrade.
        print(f"Error in retrieval chain: {str(e)}")
        print(traceback.format_exc())
        return history, None
def reset_interview():
    """Return the module-level interview state to its initial values.

    Zeroes the question counter, empties the interview history, deletes the
    last generated audio file if it still exists on disk, and clears both
    stored audio paths.
    """
    global question_count, interview_history, last_audio_path, initial_audio_path

    question_count, interview_history = 0, []

    stale_audio = last_audio_path
    if stale_audio and os.path.exists(stale_audio):
        os.remove(stale_audio)

    last_audio_path = initial_audio_path = None
def _pdf_text(source):
    """Return the text of all pages of a PDF given a path or binary stream."""
    reader = PyPDF2.PdfReader(source)
    # Defensive: a page that yields no extraction result counts as empty text.
    return "\n".join(page.extract_text() or "" for page in reader.pages)


def _docx_text(source):
    """Return the paragraph text of a .docx given a path or binary stream."""
    document = docx.Document(source)
    return "\n".join(paragraph.text for paragraph in document.paragraphs)


def read_file(file):
    """Extract the text content of an uploaded document.

    Args:
        file: ``None``, a filesystem path given as ``str``, or an object
            exposing ``name`` (file name) and ``content`` (the file data).

    Returns:
        str: The extracted text, or one of the messages
        ``"No file uploaded"``, ``"Unsupported file format"`` or
        ``"Unable to read file"``.

    String paths ending in ``.pdf`` / ``.docx`` are parsed with the matching
    library; every other path is read as UTF-8 text. Extensions are matched
    case-insensitively.
    """
    if file is None:
        return "No file uploaded"

    if isinstance(file, str):
        suffix = Path(file).suffix.lower()
        if suffix == '.pdf':
            return _pdf_text(file)
        if suffix == '.docx':
            return _docx_text(file)
        with open(file, 'r', encoding='utf-8') as f:
            return f.read()

    if hasattr(file, 'name'):  # Check if it's a file-like object
        lowered_name = file.name.lower()
        if lowered_name.endswith('.txt'):
            content = file.content
            # The binary formats below treat ``content`` as bytes, so decode
            # here to keep the return value a str for every format.
            if isinstance(content, (bytes, bytearray)):
                return bytes(content).decode('utf-8')
            return content
        if lowered_name.endswith('.pdf'):
            return _pdf_text(io.BytesIO(file.content))
        if lowered_name.endswith('.docx'):
            return _docx_text(io.BytesIO(file.content))
        return "Unsupported file format"

    return "Unable to read file"
def generate_report_from_file(file, language):
    """Generate a clinical report (text and PDF) from an uploaded document.

    Args:
        file: The upload, in any form accepted by ``read_file``.
        language: Requested report language; surrounding whitespace and case
            are ignored, and a falsy value defaults to ``"english"``.

    Returns:
        tuple: ``(report_text, pdf_path)`` on success. When the file cannot
        be read or any step fails, ``(message, None)`` so that every path
        yields the same two-element shape.
    """
    try:
        file_content = read_file(file)
        # read_file signals problems with these fixed messages; pass them on
        # as the text half of the result rather than as a bare string.
        if file_content in ("No file uploaded", "Unsupported file format", "Unable to read file"):
            return file_content, None

        report_language = language.strip().lower() if language else "english"
        print('preferred language:', report_language)
        print(f"Generating report in language: {report_language}")  # For debugging

        # Reinitialize the report chain with the new language
        _, report_retrieval_chain, _ = setup_knowledge_retrieval(llm, report_language)

        result = report_retrieval_chain.invoke({
            "input": "Please provide a clinical report based on the following content:",
            "history": file_content,
            "language": report_language
        })
        report_content = result.get("answer", "Unable to generate report due to insufficient information.")
        pdf_path = create_pdf(report_content)
        return report_content, pdf_path
    except Exception as e:
        # Boundary for the UI callback: surface the failure as report text.
        return f"An error occurred while processing the file: {str(e)}", None
def generate_interview_report(interview_history, language):
    """Generate a clinical report (text and PDF) from the interview transcript.

    Args:
        interview_history: Iterable of transcript lines; joined with newlines
            and sent to the report chain as ``history``.
        language: Requested report language; stripped and lower-cased, with a
            falsy value defaulting to ``"english"``.

    Returns:
        tuple: ``(report_text, pdf_path)`` on success, or
        ``(error_message, None)`` when any step raises.
    """
    try:
        if language:
            report_language = language.strip().lower()
        else:
            report_language = "english"
        print('preferred report_language language:', report_language)

        # A fresh set of chains is built for the requested language; only the
        # report chain is needed here.
        _, report_retrieval_chain, _ = setup_knowledge_retrieval(llm, report_language)

        run_report_chain = report_retrieval_chain.invoke
        chain_output = run_report_chain({
            "input": "Please provide a clinical report based on the following interview:",
            "history": "\n".join(interview_history),
            "language": report_language
        })

        report_content = chain_output.get(
            "answer", "Unable to generate report due to insufficient information.")
        return report_content, create_pdf(report_content)
    except Exception as e:
        return f"An error occurred while generating the report: {str(e)}", None
def create_pdf(content):
    """Render report text into a temporary PDF file and return its path.

    Every line of ``content`` is split on ``**bold**`` markers. Each fragment
    becomes its own paragraph (bold fragments in Helvetica-Bold, the rest
    justified), followed by a 12-point spacer per input line.

    Args:
        content: The report text, with optional ``**...**`` bold markers.

    Returns:
        str: Path of the generated ``.pdf``. The file is created with
        ``delete=False``, so removing it is the caller's responsibility.

    Raises:
        Exception: Whatever ReportLab raises while building the document;
            the partially written temp file is removed first.
    """
    # Local import: only this function needs it.
    from xml.sax.saxutils import escape

    # Reserve a unique path, then close the handle straight away so it is not
    # leaked and ReportLab can open the file for writing on its own.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as temp_file:
        pdf_path = temp_file.name

    try:
        doc = SimpleDocTemplate(pdf_path, pagesize=letter)
        styles = getSampleStyleSheet()

        # Create a custom style for bold text
        bold_style = ParagraphStyle('Bold', parent=styles['Normal'], fontName='Helvetica-Bold', fontSize=10)
        # Create a custom style for normal text with justification
        normal_style = ParagraphStyle('Normal', parent=styles['Normal'], alignment=TA_JUSTIFY)

        bold_marker = re.compile(r'(\*\*.*?\*\*)')
        flowables = []
        for line in content.split('\n'):
            for part in bold_marker.split(line):
                if not part:
                    # re.split yields empty strings around markers; skip them
                    continue
                # Paragraph parses its text as markup, so literal '<', '>'
                # and '&' from the report must be escaped.
                if len(part) >= 4 and part.startswith('**') and part.endswith('**'):
                    # Drop exactly the two-character delimiters on each side
                    flowables.append(Paragraph(escape(part[2:-2]), bold_style))
                else:
                    flowables.append(Paragraph(escape(part), normal_style))
            flowables.append(Spacer(1, 12))  # Add space between paragraphs

        doc.build(flowables)
    except Exception:
        # Do not leave an orphaned temp file behind when rendering fails.
        if os.path.exists(pdf_path):
            os.remove(pdf_path)
        raise
    return pdf_path