import time
import random
import ujson as json  # ujson: a faster drop-in replacement for the stdlib json
from typing import List, Optional
from dataclasses import dataclass
import gradio as gr
@dataclass
class Paper:
    """A single S2ORC record; optional metadata fields default to None."""
    paper_id: str
    title: str
    abstract: str
    authors: Optional[List[str]] = None
    year: Optional[int] = None
    doi: Optional[str] = None
def load_database(filename):
    """Load papers from a JSONL file, one JSON object per line."""
    database = []
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            json_data = json.loads(line)
            data_point = Paper(
                paper_id=json_data["paper_id"],
                title=json_data["title"],
                abstract=json_data["abstract"],
                authors=json_data.get("authors", []),
                year=json_data.get("year", None),
                doi=json_data.get("doi", None),
            )
            database.append(data_point)
    return database
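
# A sketch of the JSONL schema load_database() expects, inferred from the
# constructor call above; the values are illustrative, not real data:
# {"paper_id": "0001", "title": "An Example Title", "abstract": "...",
#  "authors": ["A. Suzuki"], "year": 2020, "doi": "10.0000/example"}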
class S2ORCRAGPipeline:
    def __init__(
        self,
        s2orc_filename,
        model=lambda x: x,  # identity model by default; see the sketch in __main__
    ):
        self.s2orc_filename = s2orc_filename
        self.database = load_database(s2orc_filename)
        self.model = model
    def retrieve_top_k(
        self,
        query: str,
        topk=5,
    ):
        # Fake retrieval: a deterministic random sample, seeded by the query,
        # so the demo runs without a real index.
        random.seed(len(query) + topk)
        return random.sample(self.database, min(topk, len(self.database)))
        # Real retrieval: TODO (DB-team); a sketch follows below.
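
    # A minimal sketch of what the DB-team TODO could look like, not the
    # project's actual retriever: rank papers by TF-IDF cosine similarity
    # between the query and each title+abstract. Assumes scikit-learn is
    # installed; the method name is hypothetical.
    def retrieve_top_k_tfidf(self, query: str, topk=5):
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.metrics.pairwise import cosine_similarity
        texts = [f"{p.title} {p.abstract}" for p in self.database]
        vectorizer = TfidfVectorizer()
        doc_matrix = vectorizer.fit_transform(texts)  # (n_papers, vocab_size)
        query_vec = vectorizer.transform([query])     # (1, vocab_size)
        scores = cosine_similarity(query_vec, doc_matrix)[0]
        top_indices = scores.argsort()[::-1][:topk]   # highest scores first
        return [self.database[i] for i in top_indices]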
    def generate_response(
        self,
        query,
        retrieved_papers,
    ):
        # Fake generation: template the retrieved papers into a canned reply.
        response = f"{query}... Got it! Here are papers related to your question:\n"
        for paper in retrieved_papers:
            response += f"- {paper.title}: {paper.abstract}\n"
        response += "\nWhat do you think?\n"
        response = self.model(response)
        return response
        # Real generation: TODO (Generation-team); a sketch follows below.
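
    # A minimal sketch of what the Generation-team TODO could look like, not
    # the project's actual generator: build a grounded prompt from the
    # retrieved papers and let self.model answer. The method name and prompt
    # wording are hypothetical.
    def generate_response_grounded(self, query, retrieved_papers):
        context = "\n".join(
            f"- {p.title} ({p.year}): {p.abstract}" for p in retrieved_papers
        )
        prompt = (
            "Answer the question using only the papers listed below.\n"
            f"Papers:\n{context}\n\n"
            f"Question: {query}\nAnswer:"
        )
        return self.model(prompt)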
    def __call__(
        self,
        query,
    ):
        # First, retrieve candidate papers from the database.
        retrieved_papers = self.retrieve_top_k(query, topk=3)
        # Then, generate a response grounded in the query and the retrieved papers.
        response = self.generate_response(query, retrieved_papers)
        return response
    def slow_echo(self, message, history):
        # Stream the response character by character so the Gradio chat UI
        # shows a typing effect.
        output = self.__call__(query=message)
        for i in range(len(output)):
            time.sleep(0.001)
            yield output[: i + 1]
if __name__ == "__main__":
    # Load a small S2ORC sample; the identity model simply echoes the prompt.
    example_filename = "sample.jsonl"
    pipeline = S2ORCRAGPipeline(
        s2orc_filename=example_filename,
        model=lambda x: x,
    )
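
    # A real LLM can be plugged in through the `model` callable. A sketch,
    # assuming the openai package (>=1.0) and an API key are available; the
    # model name and helper are hypothetical:
    #
    # from openai import OpenAI
    # client = OpenAI()
    # def call_llm(prompt: str) -> str:
    #     resp = client.chat.completions.create(
    #         model="gpt-4o-mini",
    #         messages=[{"role": "user", "content": prompt}],
    #     )
    #     return resp.choices[0].message.content
    # pipeline = S2ORCRAGPipeline(example_filename, model=call_llm)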
    initial_messages = [{"role": "assistant", "content": "Hello! What papers would you like to look for today?"}]
    demo = gr.ChatInterface(
        pipeline.slow_echo,
        chatbot=gr.Chatbot(
            value=initial_messages,
            type="messages",
            resizable=True,
            height=700,
            placeholder="Hello! What papers would you like to look for today?",
        ),
        type="messages",
        flagging_mode="manual",
        flagging_options=["Like", "Spam", "Inappropriate", "Other"],
        title="LLMC S2ORC Paper Search (+RAG)",
        description="",
        save_history=True,
        examples=[
            "Hello",
            "I want to find papers related to LLMs",
            "Find Suzuki's papers on graphene from 2019 to 2021 in Surface Science Journal.",
        ],
    )
    demo.launch(debug=True, share=True)  # share=True fails on the NII network