# Page header captured along with the file (not part of the program):
# md-qa-test / question_generator.py
# ambrosfitz's picture
# Update question_generator.py
# afaec0f verified
# raw
# history blame
# 3.55 kB
import random
import csv
import os
import logging
import hashlib
import json
import re
from typing import List, Dict
from datetime import datetime
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
# Configure logging for the module: timestamp, level, then the message.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# The Mistral API key comes from the environment; stop immediately when it is
# missing or empty, since nothing below can work without it.
api_key = os.getenv("MISTRAL_API_KEY")
if not api_key:
    logging.error("MISTRAL_API_KEY environment variable is not set.")
    raise ValueError("MISTRAL_API_KEY environment variable is not set.")

# Model identifier passed to every chat request.
model = "mistral-large-latest"

# Shared Mistral client used by the generator functions.
client = MistralClient(api_key=api_key)
# ... (previous functions remain the same)
def extract_json_from_markdown(markdown_text: str) -> str:
    """Extract JSON content from Markdown-formatted text.

    The first fenced code block tagged ``json`` (in any letter case) is
    preferred. When no tagged block exists, the first untagged fence whose
    payload starts with ``{`` or ``[`` is used instead. Whitespace directly
    inside the fence markers is stripped from the result.

    Args:
        markdown_text: Text that may contain a fenced JSON code block.

    Returns:
        The text found between the fence markers. It is not parsed or
        validated here.

    Raises:
        ValueError: If no suitable fenced block is found.
    """
    # Preferred form: a fence explicitly tagged as JSON. The tag is matched
    # case-insensitively so "```JSON" is accepted as well as "```json".
    tagged = re.search(
        r'```json\s*(.*?)\s*```',
        markdown_text,
        re.DOTALL | re.IGNORECASE,
    )
    if tagged:
        return tagged.group(1)

    # Fallback: a fence with no language tag. The payload must begin with
    # "{" or "[" so that fences holding other content (for example
    # "```python ...") are not mistaken for JSON.
    untagged = re.search(r'```\s*([{\[].*?)\s*```', markdown_text, re.DOTALL)
    if untagged:
        return untagged.group(1)

    raise ValueError("No JSON content found in the Markdown text")
def generate_microbiology_question() -> Dict[str, str]:
    """Generate a microbiology question.

    Sends a chat request to the Mistral API, parses the reply as JSON
    (unwrapping a Markdown code fence when one is present), validates the
    structure, and records the question's hash if it has not been seen.

    NOTE(review): ``prompt``, ``hash_question``, ``generated_questions`` and
    ``save_generated_question`` are not defined in the visible part of this
    file; presumably they come from the elided code. Confirm.

    Returns:
        The parsed question object. It contains at least the keys
        "question", "options", "correct_answer", "explanation" and
        "medical_reasoning"; "options" contains the keys "A" through "E".

    Raises:
        json.JSONDecodeError: If the reply is not valid JSON.
        ValueError: If required keys or option keys are missing.
        Exception: Any other error is logged and re-raised unchanged.
    """
    # ... (previous code remains the same)
    # Bound before the ``try`` so the error handlers below can always log it,
    # even when the API call itself fails before a reply has been received.
    response_content = None
    try:
        chat_response = client.chat(
            model=model,
            messages=[
                ChatMessage(role="system", content="You are a medical educator creating unique microbiology questions for the NBME exam. Ensure each question is distinct from previously generated ones and follows the specified template."),
                ChatMessage(role="user", content=prompt)
            ]
        )
        response_content = chat_response.choices[0].message.content
        logging.info(f"Received response from Mistral API: {response_content[:100]}...")  # Log first 100 characters

        # Extract JSON from Markdown if necessary
        try:
            json_content = extract_json_from_markdown(response_content)
        except ValueError:
            json_content = response_content  # If not in Markdown, use the original content

        # Parse the JSON response
        question_data = json.loads(json_content)

        # Validate the structure of the parsed JSON
        required_keys = ["question", "options", "correct_answer", "explanation", "medical_reasoning"]
        if not all(key in question_data for key in required_keys):
            raise ValueError("Response is missing required keys")
        if not all(key in question_data["options"] for key in ["A", "B", "C", "D", "E"]):
            raise ValueError("Response is missing required option keys")

        # Save the question hash (only the first time this question is seen)
        question_hash = hash_question(question_data['question'])
        if question_hash not in generated_questions:
            generated_questions.add(question_hash)
            save_generated_question(question_hash)
        return question_data
    except json.JSONDecodeError as e:
        # Must precede the ValueError handler: JSONDecodeError subclasses ValueError.
        logging.error(f"Failed to parse JSON response: {e}")
        logging.error(f"Response content: {response_content}")
        raise
    except ValueError as e:
        logging.error(f"Invalid response structure: {e}")
        logging.error(f"Response content: {response_content}")
        raise
    except Exception as e:
        logging.error(f"An unexpected error occurred: {e}")
        raise
# Example usage: generate one question and pretty-print it as JSON.
if __name__ == "__main__":
    print(json.dumps(generate_microbiology_question(), indent=2))