"""RAG-style paper-search demo over an S2ORC-like JSONL dump, served via Gradio.

Retrieval and generation are currently mock implementations (random sample /
template text); the real versions are TODOs for the DB and Generation teams.
"""

import random
import time
from dataclasses import dataclass, field
from typing import Callable, List, Optional

# Prefer the faster ujson when available, but fall back to the stdlib parser
# so the module stays importable with only the standard library.
try:
    import ujson as json
except ImportError:
    import json


@dataclass
class Paper:
    """One paper record parsed from the S2ORC JSONL dump."""

    paper_id: str
    title: str
    abstract: str
    # default_factory avoids a `None` default that contradicted the declared
    # List[str] type (and the shared-mutable-default pitfall).
    authors: List[str] = field(default_factory=list)
    year: Optional[int] = None
    doi: Optional[str] = None


def load_database(filename: str) -> List[Paper]:
    """Load a JSONL file of paper records into a list of ``Paper`` objects.

    Each line must be a JSON object with at least ``paper_id``, ``title`` and
    ``abstract``; ``authors``, ``year`` and ``doi`` are optional.

    Args:
        filename: Path to the JSONL dump.

    Returns:
        Papers in file order.
    """
    database: List[Paper] = []
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank lines in the dump
                continue
            json_data = json.loads(line)
            database.append(
                Paper(
                    paper_id=json_data["paper_id"],
                    title=json_data["title"],
                    abstract=json_data["abstract"],
                    authors=json_data.get("authors", []),
                    year=json_data.get("year"),
                    doi=json_data.get("doi"),
                )
            )
    return database


class S2ORCRAGPipeline:
    """Two-stage pipeline: retrieve papers for a query, then generate a reply."""

    def __init__(
        self,
        s2orc_filename: str,
        model: Callable[[str], str] = lambda x: x,
    ):
        """
        Args:
            s2orc_filename: Path to the S2ORC JSONL dump to load.
            model: Post-processing callable applied to the generated response
                text (identity by default).
        """
        self.s2orc_filename = s2orc_filename
        self.database = load_database(s2orc_filename)
        self.model = model

    def retrieve_top_k(self, query: str, topk: int = 5) -> List[Paper]:
        """Return up to ``topk`` papers for ``query``.

        Fake implementation: a random sample seeded from the query length so
        the same query always yields the same papers. A private ``Random``
        instance is used instead of ``random.seed`` so the global RNG state is
        not clobbered, and ``topk`` is clamped to the database size so small
        dumps do not raise ``ValueError``.
        """
        rng = random.Random(len(query) + topk)
        return rng.sample(self.database, min(topk, len(self.database)))
        # Real implementation: TODO(DB-team)

    def generate_response(self, query: str, retrieved_papers: List[Paper]) -> str:
        """Build the reply text listing the retrieved papers.

        Fake implementation: formats titles/abstracts into a template and
        pipes the text through ``self.model``. ``str.join`` replaces the
        original quadratic ``+=`` loop; the output is identical.
        """
        lines = [f"{query}... わかった!こちらはあなたの質問に関連する論文です:"]
        lines.extend(f"- {paper.title}: {paper.abstract}" for paper in retrieved_papers)
        lines.append("")
        lines.append("どう思いますか?")
        response = "\n".join(lines) + "\n"
        return self.model(response)
        # Real implementation: TODO(Generation-team)

    def __call__(self, query: str) -> str:
        """Run the pipeline: retrieve top-3 papers, then generate a response."""
        # Firstly, retrieve papers from the database
        retrieved_papers = self.retrieve_top_k(query, topk=3)
        # Secondly, generate a response based on the query and retrieved papers
        return self.generate_response(query, retrieved_papers)

    def slow_echo(self, message: str, history):
        """Gradio streaming callback: yield the response one character at a time.

        ``history`` is supplied by Gradio but unused here.
        """
        output = self.__call__(query=message)
        for i in range(len(output)):
            time.sleep(0.001)  # small delay to simulate token streaming
            yield output[: i + 1]


if __name__ == "__main__":
    # gradio is only needed to serve the demo; importing it here keeps the
    # module importable (e.g. for tests) without the dependency installed.
    import gradio as gr

    # Load from an S2ORC-format sample dump.
    example_filename = "sample.jsonl"
    pipeline = S2ORCRAGPipeline(
        s2orc_filename=example_filename,
        model=lambda x: x,
    )
    initial_messages = [
        {"role": "assistant", "content": "こんにちは〜今日は何の論文を探したいですか?"}
    ]
    demo = gr.ChatInterface(
        pipeline.slow_echo,
        chatbot=gr.Chatbot(
            value=initial_messages,
            type="messages",
            resizable=True,
            height=700,
            placeholder="こんにちは〜今日は何の論文を探したいですか?",
        ),
        type="messages",
        flagging_mode="manual",
        flagging_options=["Like", "Spam", "Inappropriate", "Other"],
        title="LLMC S2ORC 論文検索 (+RAG)",
        description="",
        save_history=True,
        examples=[
            "こんにちは",
            "LLM関連の論文を探したい",
            "Find Suzuki's papers on graphene from 2019 to 2021 in Surface Science Journal.",
        ],
    )
    # NOTE: share=True can fail on restricted networks (e.g. the NII network).
    demo.launch(debug=True, share=True)