File size: 5,848 Bytes
3ac45a5
b99fbc5
3ac45a5
ee324f3
 
 
3ac45a5
 
 
 
 
 
57957da
 
 
 
 
 
 
 
 
3ac45a5
 
 
 
57957da
 
 
 
 
 
 
 
3ac45a5
 
 
 
57957da
13a0823
57957da
13a0823
57957da
 
 
3ac45a5
57957da
3ac45a5
 
 
 
57957da
3ac45a5
 
13a0823
57957da
 
 
ee324f3
3ac45a5
 
 
57957da
 
 
 
 
 
 
 
3ac45a5
 
 
 
 
 
 
 
 
57957da
 
 
3ac45a5
 
 
 
 
 
57957da
 
 
 
 
3ac45a5
 
b99fbc5
 
3ac45a5
57957da
 
 
 
 
 
 
 
 
 
 
 
 
3ac45a5
 
57957da
3ac45a5
f143400
3ac45a5
57957da
3ac45a5
 
 
57957da
 
 
 
 
 
 
13a0823
c6f3317
13a0823
57957da
 
3ac45a5
 
13a0823
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import datetime

import gradio as gr
import pytz

# Model name passed to SentenceTransformersTextEmbedder when the pipeline is built.
sentense_transformers_model = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
# Model name passed to TransformersSimilarityRanker when the pipeline is built.
ranker_model = "hotchpotch/japanese-reranker-cross-encoder-base-v1"

from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.joiners import DocumentJoiner
from haystack.components.rankers import TransformersSimilarityRanker
from haystack import Pipeline,component,Document
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator
from haystack.components.builders import ChatPromptBuilder
from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
from haystack_experimental.components.retrievers import ChatMessageRetriever
from haystack_experimental.components.writers import ChatMessageWriter
from haystack.dataclasses import ChatMessage
from itertools import chain
from typing import Any,List
from haystack.core.component.types import Variadic

# Restore the document store from its JSON dump at import time and log
# how many documents it contains.
document_store = InMemoryDocumentStore.load_from_disk(path="./document_store.json")
print('document_store loaded', document_store.count_documents())

@component
class ListJoiner:
    """Pipeline component that merges several incoming lists into one flat list."""

    def __init__(self, _type: Any):
        """Declare ``_type`` as the type of the component's ``values`` output."""
        component.set_output_types(self, values=_type)

    def run(self, values: Variadic[Any]):
        """Concatenate every iterable received on ``values``, preserving order.

        Returns a dict with a single key ``"values"`` holding the flattened list.
        """
        flattened = [item for group in values for item in group]
        return {"values": flattened}

class Niwa_rag :
    """Retrieval-augmented chat pipeline with conversation memory.

    The pipeline combines BM25 and embedding retrieval over the module-level
    ``document_store``, re-ranks the joined hits, builds a chat prompt from the
    ranked documents plus stored chat messages, and asks Gemini for a reply.
    The user message and the model replies are written back to an in-memory
    chat message store so they are available to later calls.
    """

    def __init__(self):
        """Build the pipeline immediately; it is stored on ``self.pipe``."""
        self.createPipe()

    def createPipe(self):
        """Create all components, wire them together and store the result in ``self.pipe``."""
        # Jinja template rendered by ChatPromptBuilder; it reads the variables
        # `memories`, `documents` and `query`.
        user_message_template = """
            γƒγƒ£γƒƒγƒˆε±₯ζ­΄γ¨ζδΎ›γ•γ‚ŒγŸθ³‡ζ–™γ«εŸΊγ₯いて、θ³ͺε•γ«η­”γˆγ¦γγ γ•γ„γ€‚
        
            γƒγƒ£γƒƒγƒˆε±₯ζ­΄:
            {% for memory in memories %}
                {{ memory.content }}
            {% endfor %}
        
            資料:
            {% for document in documents %}
                {{ document.content }}
            {% endfor %}
        
            θ³ͺ問: {{query}}
            ε›žη­”:
        """
        system_message = ChatMessage.from_system("あγͺγŸγ―γ€ζδΎ›γ•γ‚ŒγŸθ³‡ζ–™γ¨γƒγƒ£γƒƒγƒˆε±₯歴を使用して人間を支援するAIγ‚’γ‚·γ‚Ήγ‚Ώγƒ³γƒˆγ§γ™")
        user_message = ChatMessage.from_user(user_message_template)
        messages = [system_message, user_message]

        # Retrieval side: embed the query, retrieve by embedding and by BM25,
        # join both result lists and re-rank them.
        text_embedder = SentenceTransformersTextEmbedder(model=sentense_transformers_model)
        embedding_retriever = InMemoryEmbeddingRetriever(document_store)
        bm25_retriever = InMemoryBM25Retriever(document_store)
        document_joiner = DocumentJoiner()
        ranker = TransformersSimilarityRanker(model=ranker_model,top_k=8)

        # Generation side: all three template variables are required, so the
        # prompt is only built once documents and memories have arrived.
        prompt_builder = ChatPromptBuilder(template=messages,variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"])
        gemini = GoogleAIGeminiChatGenerator(model="models/gemini-1.0-pro")

        # Conversation memory: one store shared by the retriever and the writer.
        memory_store = InMemoryChatMessageStore()
        memory_joiner = ListJoiner(List[ChatMessage])
        memory_retriever = ChatMessageRetriever(memory_store)
        memory_writer = ChatMessageWriter(memory_store)

        pipe = Pipeline()
        pipe.add_component("text_embedder", text_embedder)
        pipe.add_component("embedding_retriever", embedding_retriever)
        pipe.add_component("bm25_retriever", bm25_retriever)
        pipe.add_component("document_joiner", document_joiner)
        pipe.add_component("ranker", ranker)
        pipe.add_component("prompt_builder", prompt_builder)
        pipe.add_component("llm", gemini)
        pipe.add_component("memory_retriever", memory_retriever)
        pipe.add_component("memory_writer", memory_writer)
        pipe.add_component("memory_joiner", memory_joiner)

        pipe.connect("text_embedder", "embedding_retriever")
        pipe.connect("bm25_retriever", "document_joiner")
        pipe.connect("embedding_retriever", "document_joiner")
        pipe.connect("document_joiner", "ranker")
        pipe.connect("ranker.documents", "prompt_builder.documents")
        pipe.connect("prompt_builder.prompt", "llm.messages")
        # The LLM replies are merged with the user message supplied in run()
        # and then persisted by the memory writer.
        pipe.connect("llm.replies", "memory_joiner")
        pipe.connect("memory_joiner", "memory_writer")
        pipe.connect("memory_retriever", "prompt_builder.memories")

        self.pipe = pipe

    def run(self,q):
        """Answer the question ``q`` and list the documents used as sources.

        Args:
            q: The user's question. A falsy value (``None`` or ``''``) skips
                the pipeline entirely.

        Returns:
            A dict with two string entries: ``'reply'`` (the content of the
            first LLM reply) and ``'sources'`` (an HTML fragment with one link
            per ranked document). Both are empty strings when ``q`` is falsy.

        Raises:
            KeyError: If a ranked document's ``meta`` lacks ``'title'`` or
                ``'link'``.
        """
        # Requires the module-level `import datetime`; the timestamp is only
        # used for the log line below.
        now = datetime.datetime.now(pytz.timezone('Asia/Tokyo'))
        print('q:',q,now)
        if not q :
            return {'reply':'','sources':''}
        result = self.pipe.run({
            "text_embedder": {"text": q},
            "bm25_retriever": {"query": q},
            "ranker": {"query": q},
            "prompt_builder": { "query": q},
            "memory_joiner": {"values": [ChatMessage.from_user(q)]},
        },include_outputs_from=["llm",'ranker'])
        reply = result['llm']['replies'][0]
        docs = result['ranker']['documents']
        print('reply:',reply)

        # NOTE(review): title/link are interpolated into HTML unescaped; this
        # presumes the document store holds trusted metadata -- confirm.
        rows = []
        for doc in docs :
            title = doc.meta['title']
            link = doc.meta['link']
            rows.append(f'<div><a class="link" href="{link}" target="_blank">{title}</a></div>')
            # 'type' is only logged, so a missing key must not fail the request.
            print('',title,link,doc.meta.get('type'),doc.score)
        sources = ''.join(['<div class="ref-title">ε‚θ€ƒθ¨˜δΊ‹</div><div class="ref">', *rows, '</div>'])
        return {'reply':reply.content,'sources':sources}

# Build the pipeline once at import time; fn() reuses this single instance
# for every chat request.
rag = Niwa_rag()

def fn(q,history):
    """Chat callback: answer ``q`` and append the HTML list of sources.

    ``history`` is accepted as part of the callback signature but is not read
    here. Returns the reply text immediately followed by the sources HTML.
    """
    answer = rag.run(q)
    reply_text, sources_html = answer['reply'], answer['sources']
    return reply_text + sources_html

# Chat UI: every submitted message is routed to fn(); styling comes from app.css.
app = gr.ChatInterface(
    fn=fn,
    type="messages",
    title='庭フゑン Chatbot',
    textbox=gr.Textbox(
        placeholder='θ³ͺε•γ‚’θ¨˜ε…₯して下さい',
        submit_btn=True,
    ),
    css_paths='./app.css',
)


# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()