Spaces:
Build error
Build error
| import gradio as gr | |
| import pandas as pd | |
| from OpenAITools.FetchTools import fetch_clinical_trials, fetch_clinical_trials_jp | |
| from langchain_openai import ChatOpenAI | |
| from langchain_groq import ChatGroq | |
| from OpenAITools.CrinicalTrialTools import QuestionModifierEnglish, TumorNameExtractor, SimpleClinicalTrialAgent, GraderAgent | |
| # モデルとエージェントの初期化 | |
| groq = ChatGroq(model_name="llama3-70b-8192", temperature=0) | |
| modifier = QuestionModifierEnglish(groq) | |
| extractor = TumorNameExtractor(groq) | |
| CriteriaCheckAgent = SimpleClinicalTrialAgent(groq) | |
| grader_agent = GraderAgent(groq) | |
| # データフレームを生成する関数 | |
| def generate_dataframe_from_question(ex_question): | |
| # Modify and extract tumor name | |
| modified_question = modifier.modify_question(ex_question) | |
| tumor_name = extractor.extract_tumor_name(ex_question) | |
| # Get clinical trials data based on tumor name | |
| df = fetch_clinical_trials(tumor_name) | |
| df['AgentJudgment'] = None | |
| df['AgentGrade'] = None | |
| # NCTIDのリストを作成し、プログレスバーを表示 | |
| NCTIDs = list(df['NCTID']) | |
| progress = gr.Progress(track_tqdm=True) | |
| for i, nct_id in enumerate(NCTIDs): | |
| target_criteria = df.loc[df['NCTID'] == nct_id, 'Eligibility Criteria'].values[0] | |
| agent_judgment = CriteriaCheckAgent.evaluate_eligibility(target_criteria, modified_question) | |
| agent_grade = grader_agent.evaluate_eligibility(agent_judgment) | |
| # Update DataFrame | |
| df.loc[df['NCTID'] == nct_id, 'AgentJudgment'] = agent_judgment | |
| df.loc[df['NCTID'] == nct_id, 'AgentGrade'] = agent_grade | |
| # プログレスバーを更新(進行状況を浮動小数点数で渡す) | |
| progress((i + 1) / len(NCTIDs)) | |
| # 列を指定した順に並び替え | |
| columns_order = ['NCTID', 'AgentGrade', 'Title', 'AgentJudgment', 'Japanes Locations', | |
| 'Primary Completion Date', 'Cancer', 'Summary', 'Eligibility Criteria'] | |
| df = df[columns_order] | |
| return df, df # フィルタ用と表示用にデータフレームを返す | |
| # AgentGradeが特定の値(yes, no, unclear)の行だけを選択する関数 | |
| def filter_rows_by_grade(original_df, grade): | |
| df_filtered = original_df[original_df['AgentGrade'] == grade] | |
| return df_filtered, df_filtered # フィルタした結果を2つ返す | |
| # CSVとして保存しダウンロードする関数 | |
| def download_filtered_csv(df): | |
| file_path = "filtered_data.csv" # 現在の作業ディレクトリに保存 | |
| df.to_csv(file_path, index=False) # CSVファイルとして保存 | |
| return file_path | |
| # Gradioインターフェースの作成 | |
| with gr.Blocks() as demo: | |
| # 説明 | |
| gr.Markdown("## 質問を入力して、患者さんが参加可能な臨床治験の情報を収集。参加可能か否かを判断根拠も含めて提示します。結果はcsvとしてダウンロード可能です") | |
| # 質問入力ボックス | |
| question_input = gr.Textbox(label="質問を入力してください", placeholder="例: 65歳男性でBRCA遺伝子の変異がある前立腺癌患者さんが参加できる臨床治験を教えて下さい。") | |
| # データフレーム表示エリア | |
| dataframe_output = gr.DataFrame() | |
| # データの元となるDataFrameを保存するためのstate | |
| original_df = gr.State() | |
| filtered_df = gr.State() | |
| # データフレームを生成するボタン | |
| generate_button = gr.Button("日本で行われている患者さんの癌腫の臨床治験を全て取得する") | |
| # ボタンでAgentGradeがyes, no, unclearの行のみ表示 | |
| yes_button = gr.Button("AI Agentが患者さんが参加可能であると判断した臨床治験のみを表示") | |
| no_button = gr.Button("I Agentが患者さんが参加不可であると判断した臨床治験のみを表示") | |
| unclear_button = gr.Button("AI Agentが与えられた情報だけでは判断不可能とした臨床治験のみを表示") | |
| # フィルタ結果をダウンロードするボタン | |
| download_button = gr.Button("フィルタ結果をCSVとしてダウンロード") | |
| download_output = gr.File() # ダウンロード用の出力エリア | |
| # データフレームを生成して保存 | |
| generate_button.click(fn=generate_dataframe_from_question, inputs=question_input, outputs=[dataframe_output, original_df]) | |
| # yesボタン、noボタン、unclearボタンが押されたらフィルタしたデータを表示 | |
| yes_button.click(fn=filter_rows_by_grade, inputs=[original_df, gr.State("yes")], outputs=[dataframe_output, filtered_df]) | |
| no_button.click(fn=filter_rows_by_grade, inputs=[original_df, gr.State("no")], outputs=[dataframe_output, filtered_df]) | |
| unclear_button.click(fn=filter_rows_by_grade, inputs=[original_df, gr.State("unclear")], outputs=[dataframe_output, filtered_df]) | |
| # ダウンロードボタンを押すとフィルタ結果のCSVをダウンロード | |
| download_button.click(fn=download_filtered_csv, inputs=filtered_df, outputs=download_output) | |
| if __name__ == "__main__": | |
| demo.launch() |