# JStage_RAG / app.py
# Uploaded by Coldog2333 via huggingface_hub (commit 9cf9ef9, verified)
import random
import time
from dataclasses import dataclass
from typing import List, Optional

import gradio as gr
import ujson as json
@dataclass
class Paper:
    """One paper record from the S2ORC-style corpus.

    Fields defaulting to ``None`` are optional in the source data;
    ``load_database`` fills ``authors`` with ``[]`` when absent.
    """

    paper_id: str  # unique identifier within the corpus
    title: str
    abstract: str
    authors: Optional[List[str]] = None  # was annotated List[str] but defaults to None
    year: Optional[int] = None
    doi: Optional[str] = None
def load_database(filename: str) -> "List[Paper]":
    """Load a JSON-Lines corpus file into a list of ``Paper`` records.

    Each non-empty line must be a JSON object with at least ``paper_id``,
    ``title`` and ``abstract``; ``authors``, ``year`` and ``doi`` are
    optional.

    Args:
        filename: Path to the ``.jsonl`` file.

    Returns:
        Papers in file order.
    """
    database = []
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank/trailing lines, which would otherwise
                # crash json.loads with a parse error.
                continue
            json_data = json.loads(line)
            database.append(
                Paper(
                    paper_id=json_data["paper_id"],
                    title=json_data["title"],
                    abstract=json_data["abstract"],
                    authors=json_data.get("authors", []),
                    year=json_data.get("year"),
                    doi=json_data.get("doi"),
                )
            )
    return database
class S2ORCRAGPipeline:
    """Retrieval-augmented generation pipeline over an S2ORC-style corpus.

    Both stages are placeholder ("fake") implementations: retrieval is a
    deterministic pseudo-random sample and generation is a templated
    string passed through ``model``. Real implementations are tracked by
    the TODOs on the respective methods.
    """

    def __init__(
        self,
        s2orc_filename,
        model=lambda x: x,
    ):
        """Load the corpus from ``s2orc_filename``.

        Args:
            s2orc_filename: Path to a JSON-Lines corpus readable by
                ``load_database``.
            model: Callable applied to the templated response before it
                is returned; defaults to the identity function.
        """
        self.s2orc_filename = s2orc_filename
        self.database = load_database(s2orc_filename)
        self.model = model

    def retrieve_top_k(
        self,
        query: str,
        topk: int = 5,
    ):
        """Return up to ``topk`` papers for ``query``.

        Fake implementation: a pseudo-random sample seeded by
        ``len(query) + topk`` so repeated identical calls return the same
        papers.

        TODO(DB-team): replace with a real retriever.
        """
        random.seed(len(query) + topk)
        # Clamp so that topk larger than the corpus does not make
        # random.sample raise ValueError.
        k = min(topk, len(self.database))
        return random.sample(self.database, k)

    def generate_response(
        self,
        query,
        retrieved_papers,
    ):
        """Build a templated answer listing the retrieved papers.

        Fake implementation: the template (in Japanese) echoes the query,
        lists each paper as "- title: abstract", and is then passed
        through ``self.model``.

        TODO(Generation-team): replace with real generation.
        """
        parts = [f"{query}... わかった!こちらはあなたの質問に関連する論文です:\n"]
        for paper in retrieved_papers:
            parts.append(f"- {paper.title}: {paper.abstract}\n")
        parts.append("\nどう思いますか?\n")
        # join avoids quadratic += string concatenation.
        return self.model("".join(parts))

    def __call__(
        self,
        query,
    ):
        """Run the full pipeline: retrieve top-3 papers, then generate."""
        # Firstly, retrieve papers from database
        retrieved_papers = self.retrieve_top_k(query, topk=3)
        # Secondly, generate response based on query and the retrieved papers
        return self.generate_response(query, retrieved_papers)

    def slow_echo(self, message, history):
        """Gradio streaming callback: yield the answer one char at a time.

        ``history`` is required by the ChatInterface signature but unused.
        """
        output = self.__call__(query=message)
        for i in range(len(output)):
            time.sleep(0.001)  # ~1 ms per character, simulates streaming
            yield output[: i + 1]
if __name__ == "__main__":
    # Build the pipeline over the bundled S2ORC sample corpus, using the
    # identity function as the placeholder generation model.
    example_filename = "sample.jsonl"
    pipeline = S2ORCRAGPipeline(
        s2orc_filename=example_filename,
        model=lambda x: x,
    )

    greeting = [{"role": "assistant", "content": "こんにちは〜今日は何の論文を探したいですか?"}]
    chat_window = gr.Chatbot(
        value=greeting,
        type="messages",
        resizable=True,
        height=700,
        placeholder="こんにちは〜今日は何の論文を探したいですか?",
    )

    demo = gr.ChatInterface(
        pipeline.slow_echo,
        chatbot=chat_window,
        type="messages",
        flagging_mode="manual",
        flagging_options=["Like", "Spam", "Inappropriate", "Other"],
        title="LLMC S2ORC 論文検索 (+RAG)",
        description="",
        save_history=True,
        examples=["こんにちは", "LLM関連の論文を探したい", "Find Suzuki's papers on graphene from 2019 to 2021 in Surface Science Journal."],
    )
    # NOTE: share=True is known to fail on the NII network.
    demo.launch(debug=True, share=True)