"""Extract structured therapy notes (SOAP, DAP, BIRP, ...) from session transcripts.

Each supported note type is paired with a schema from ``forms.schemas`` and a
system/human prompt pair loaded from
``langhub/prompts/therapy_extraction_prompt.yaml``.
"""

from __future__ import annotations

import os
from pathlib import Path

import yaml
from dotenv import load_dotenv
from langchain.cache import SQLiteCache
from langchain.globals import set_llm_cache
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableSequence
from langgraph.prebuilt import ValidationNode

from config.settings import settings
from forms.schemas import (
    BIRPNote,
    CBENote,
    DAPNote,
    DARENote,
    ExtractedNotes,
    FAIRFDARPNote,
    GIRPNote,
    NarrativeNote,
    PIENote,
    PIRPNote,
    POMRNote,
    SBARNote,
    SIRPNote,
    SOAPIENote,
    SOAPIERNote,
    SOAPNote,
)
from models.llm_provider import get_llm
from utils.text_processing import chunk_text
from utils.youtube import download_transcript

# Cache LLM responses on disk so repeated runs over the same input reuse results.
set_llm_cache(SQLiteCache(database_path=".langchain.db"))

# NOTE(review): `settings` was already imported (and presumably instantiated)
# above, so values that exist only in .env may not be visible to it — confirm
# whether load_dotenv() should run before `config.settings` is imported.
load_dotenv()

# Set environment for LangSmith tracing/logging.
# NOTE(review): tracing is enabled even when no API key is configured.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
if settings.LANGCHAIN_API_KEY:
    os.environ["LANGCHAIN_API_KEY"] = settings.LANGCHAIN_API_KEY

# Resolved against the current working directory, not this file's location.
_PROMPT_PATH = Path("langhub/prompts/therapy_extraction_prompt.yaml")

# Lower-case note-type key -> output schema for structured extraction.
_NOTE_SCHEMAS = {
    "soap": SOAPNote,
    "dap": DAPNote,
    "birp": BIRPNote,
    "birp_raw": BIRPNote,
    "pirp": PIRPNote,
    "girp": GIRPNote,
    "sirp": SIRPNote,
    "fair_fdarp": FAIRFDARPNote,
    "dare": DARENote,
    "pie": PIENote,
    "soapier": SOAPIERNote,
    "soapie": SOAPIENote,
    "pomr": POMRNote,
    "narrative": NarrativeNote,
    "cbe": CBENote,
    "sbar": SBARNote,
}


def load_prompt(note_type: str) -> tuple[str, str]:
    """Load the system and human prompt templates for ``note_type`` from YAML.

    The YAML file is expected to contain ``prompts: {<note_type>: {system: ..., human: ...}}``.

    Args:
        note_type: Note type key; matched case-insensitively.

    Returns:
        A ``(system_prompt, human_prompt)`` tuple.

    Raises:
        ValueError: If no prompt entry exists for ``note_type`` (including an
            empty prompt file or an empty ``prompts`` section).
        KeyError: If the entry lacks a ``system`` or ``human`` key.
        OSError: If the prompt file cannot be opened.
    """
    with _PROMPT_PATH.open("r", encoding="utf-8") as f:
        # safe_load returns None for an empty file; treat it as "no prompts".
        data = yaml.safe_load(f) or {}
    note_prompts = (data.get("prompts") or {}).get(note_type.lower())
    if not note_prompts:
        raise ValueError(f"No prompt template found for note type: {note_type}")
    return note_prompts["system"], note_prompts["human"]


def create_extraction_chain(note_type: str = "soap") -> RunnableSequence:
    """Build a ``prompt | structured LLM`` chain for extracting one note type.

    The note type and its prompts are validated before the LLM is initialized,
    so an unsupported type fails fast without constructing a model client.

    Args:
        note_type: Case-insensitive key into ``_NOTE_SCHEMAS`` (e.g. ``"soap"``).

    Returns:
        A runnable whose output (because of ``include_raw=True``) is a dict with
        ``"raw"``, ``"parsed"`` and ``"parsing_error"`` entries. ``process_session``
        invokes it with ``note_type`` and ``text`` variables; the YAML prompts
        presumably reference those placeholders — verify against the prompt file.

    Raises:
        ValueError: If ``note_type`` has no schema or no prompt template.
    """
    print(f"Creating extraction chain for {note_type.upper()} notes...")

    schema = _NOTE_SCHEMAS.get(note_type.lower())
    if schema is None:
        raise ValueError(f"Unsupported note type: {note_type}")

    print("Loading system prompt...")
    system_prompt, human_prompt = load_prompt(note_type)

    print("Initializing LLM...")
    llm = get_llm()

    print("Creating structured LLM output...")
    # include_raw=True keeps the raw model message alongside the parsed object,
    # so parsing failures are reported instead of raised.
    structured_llm = llm.with_structured_output(schema=schema, include_raw=True)

    print("Creating prompt template...")
    prompt_template = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        ("human", human_prompt),
    ])

    print("Building extraction chain...")
    chain = prompt_template | structured_llm
    print("Extraction chain created successfully")
    return chain


def _structured_output_to_dict(result: object) -> dict:
    """Convert a structured-output chain result into a plain dict.

    With ``include_raw=True`` the chain returns ``{"raw", "parsed", "parsing_error"}``
    rather than the schema instance, so the parsed object is unwrapped first.

    Raises:
        ValueError: If the model output could not be parsed into the schema.
        TypeError: If the parsed value is neither a model nor a dict.
    """
    parsed = result
    if isinstance(result, dict) and "parsed" in result:
        parsed = result["parsed"]
        if parsed is None:
            raise ValueError(
                "Model output could not be parsed into the note schema: "
                f"{result.get('parsing_error')}"
            )
    if hasattr(parsed, "model_dump"):
        return parsed.model_dump()
    if isinstance(parsed, dict):
        return parsed
    raise TypeError(f"Unexpected structured output type: {type(parsed).__name__}")


def process_session(
    url: str,
    note_type: str = "soap",
    *,
    transcript: str | None = None,
) -> dict:
    """Extract one note type from a single therapy session.

    Args:
        url: Session URL handed to ``download_transcript``.
        note_type: Note format to extract (a key of ``_NOTE_SCHEMAS``).
        transcript: Already-downloaded transcript text. When given, ``url`` is
            not fetched again, so callers extracting several note types from
            the same session can download it only once.

    Returns:
        The extracted note as a plain dict, or ``{}`` if any step fails. This is
        best-effort: the error is printed rather than raised.
    """
    try:
        if transcript is None:
            print(f"Downloading transcript from {url}...")
            transcript = download_transcript(url)

        chain = create_extraction_chain(note_type)

        print("Extracting structured notes...")
        result = chain.invoke({
            "note_type": note_type.upper(),
            "text": transcript,
        })
        return _structured_output_to_dict(result)
    except Exception as e:
        print(f"Error processing session: {e}")
        return {}


def main() -> None:
    """Demo: extract SOAP, DAP and BIRP notes for sample sessions and print them as YAML."""
    # Example YouTube sessions
    sessions = [
        {
            "title": "CBT Role-Play – Complete Session – Part 6",
            "url": "https://www.youtube.com/watch?v=KuHLL2AE-SE",
        },
        {
            "title": "CBT Role-Play – Complete Session – Part 7",
            "url": "https://www.youtube.com/watch?v=jS1KE3_Pqlc",
        },
    ]
    # Extract notes in different formats
    note_types = ["soap", "dap", "birp"]

    for session in sessions:
        title, url = session["title"], session["url"]
        print(f"\nProcessing session: {title}")

        # Download once per session and reuse it for every note type.
        try:
            print(f"Downloading transcript from {url}...")
            transcript = download_transcript(url)
        except Exception as e:
            print(f"Error processing session: {e}")
            continue

        results = {}
        for note_type in note_types:
            print(f"\nExtracting {note_type.upper()} notes...")
            results[note_type] = process_session(url, note_type, transcript=transcript)

        print(f"\nResults for '{title}':")
        for note_type, notes in results.items():
            print(f"\n{note_type.upper()} Notes:")
            print(yaml.dump(notes, default_flow_style=False))


if __name__ == "__main__":
    main()