Upload 12 files
- app (10).py +109 -0
- browser (2).py +317 -0
- chat_with_bot (2).py +278 -0
- cookies (2).py +713 -0
- llm_engine (2).py +179 -0
- mdconvert (2).py +659 -0
- requirements (93).txt +32 -0
- rwkv_cpp_model (2).py +388 -0
- rwkv_cpp_shared_library (2).py +450 -0
- rwkv_world_tokenizer (2).py +126 -0
- sampling (2).py +52 -0
- tokenizer_util (2).py +38 -0
app (10).py
ADDED
@@ -0,0 +1,109 @@
from huggingface_hub import login, InferenceClient
import os, gc, time, random, datetime, json, re

HF_TOKEN = os.getenv('HF_TOKEN')
SERP_API_KEY = os.getenv('SERP_KEY')
login(token=HF_TOKEN)

import gradio as gr
from transformers import CodeAgent, Tool, ToolCollection, load_tool, ReactCodeAgent, ReactJsonAgent
from transformers.agents import PythonInterpreterTool
from langchain.memory import ConversationBufferMemory
import bs4
import requests
from llm_engine import HfEngine
import datasets
import spaces
import tqdm
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.vectorstores import VectorStore
from transformers.agents.prompts import DEFAULT_REACT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_JSON_SYSTEM_PROMPT
from transformers.agents.default_tools import Tool, PythonInterpreterTool
from duckduckgo_search import DDGS
from web_surfer import (SearchInformationTool, NavigationalSearchTool, VisitTool, DownloadTool, PageUpTool, PageDownTool, FinderTool, FindNextTool, ArchiveSearchTool,)
from mdconvert import MarkdownConverter
from visual_qa import VisualQATool, VisualQAGPT4Tool


def search_ducky(query):
    # Plain DuckDuckGo text search; concatenates the result snippets.
    with DDGS() as ddgs:
        results = list(ddgs.text(query, max_results=10))
        content = ''
        if results:
            for result in results:
                content += result['body']
        return content


# Build the knowledge base: embed the first 1000 chunks of m-ric/huggingface_doc into FAISS.
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
source_docs = [Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}) for doc in knowledge_base]
docs_processed = RecursiveCharacterTextSplitter(chunk_size=500).split_documents(source_docs)[:1000]
embedding_model = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
vectordb = FAISS.from_documents(documents=docs_processed, embedding=embedding_model)
all_sources = list(set([doc.metadata["source"] for doc in docs_processed]))
print(all_sources)


class RetrieverTool(Tool):
    name = "retriever"
    description = "Retrieves some documents from the knowledge base that have the closest embeddings to the input query."
    inputs = {
        "query": {
            "type": "text",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        },
        "source": {
            "type": "text",
            "description": "",
        },
    }
    output_type = "text"

    def __init__(self, vectordb: VectorStore, all_sources: str, **kwargs):
        super().__init__(**kwargs)
        self.vectordb = vectordb
        self.inputs["source"]["description"] = (
            f"The source of the documents to search, as a str representation of a list. Possible values in the list are: {all_sources}. If this argument is not provided, all sources will be searched."
        )

    def forward(self, query: str, source: str = None) -> str:
        assert isinstance(query, str), "Your search query must be a string"

        if source:
            if isinstance(source, str) and "[" not in str(source):  # if the source is not representing a list
                source = [source]
            source = json.loads(str(source).replace("'", '"'))

        docs = self.vectordb.similarity_search(query, filter=({"source": source} if source else None), k=3)

        if len(docs) == 0:
            return "No documents found with this filtering. Try removing the source filter."
        return "Retrieved documents:\n\n" + "\n===Document===\n".join([doc.page_content for doc in docs])


memory = ConversationBufferMemory(memory_key="chat_history")
llm_engine = HfEngine(model="Jopmt/JoPmt")
##gradio_prompt_generator_tool = StableDiffusionPromptGeneratorTool()
##prompt_generator_tool = Tool.from_gradio(gradio_prompt_generator_tool)
##tools = [StableDiffusionTool().langchain, ImageCaptioningTool().langchain, StableDiffusionPromptGeneratorTool().langchain, TextToVideoTool().langchain]
##tools=[prompt_generator_tool(), image_generation_tool(), PythonInterpreterTool()]


class SearchTool(Tool):
    name = "ask_search_agent"
    description = "A search agent that will browse the internet to answer a question. Use it to gather information, not for problem-solving."

    inputs = {
        "question": {
            "description": "Your question, as a natural language sentence. You are talking to an agent, so provide them with as much context as possible.",
            "type": "text",
        }
    }
    output_type = "text"

    def forward(self, question: str) -> str:
        return websurfer_agent.run(question)


tools = [PythonInterpreterTool(), SearchTool(), RetrieverTool(vectordb, all_sources)]
additional_authorized_imports = ['requests', 'bs4', 'os', 'time', 'datetime', 'json', 're']
WEB_TOOLS = [SearchInformationTool(), NavigationalSearchTool(), VisitTool(), DownloadTool(), PageUpTool(), PageDownTool(), FinderTool(), FindNextTool(), ArchiveSearchTool()]
websurfer_agent = ReactJsonAgent(tools=WEB_TOOLS, llm_engine=llm_engine, add_base_tools=True, max_iterations=1)
reagent = ReactCodeAgent(tools=tools, llm_engine=llm_engine, add_base_tools=True, max_iterations=1, additional_authorized_imports=additional_authorized_imports)


def plix(inut, progress=gr.Progress(track_tqdm=True)):
    goose = reagent.run(inut)
    return goose


with gr.Blocks(theme=random.choice([gr.themes.Monochrome(), gr.themes.Base.from_hub("gradio/seafoam"), gr.themes.Base.from_hub("freddyaboulton/dracula_revamped"), gr.themes.Glass(), gr.themes.Base()]), analytics_enabled=False) as iface:
    out = gr.Textbox(label="🤗Output", lines=5, interactive=False)
    inut = gr.Textbox(label="Prompt")
    btn = gr.Button("GENERATE")
    btn.click(fn=plix, inputs=inut, outputs=out)
iface.queue(max_size=1, api_open=False)
iface.launch(max_threads=20, inline=False, show_api=False)
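Taken together, app (10).py embeds the documentation corpus into FAISS and hands the index to a ReactCodeAgent through RetrieverTool. A minimal sketch of exercising the tool directly, assuming the `vectordb` and `all_sources` objects built above (the query strings are illustrative, not from the source):

# Sketch: query RetrieverTool outside the agent loop, using the objects built in app (10).py.
retriever = RetrieverTool(vectordb, all_sources)

# Unfiltered search across all sources.
print(retriever.forward("How do I push a model to the Hugging Face Hub?"))

# Filtered search: `source` is the str representation of a list, which
# forward() normalizes via json.loads before building the FAISS filter.
print(retriever.forward("push a model to the Hub", source=str(all_sources[:1])))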
browser (2).py
ADDED
@@ -0,0 +1,317 @@
# Shamelessly stolen from Microsoft Autogen team: thanks to them for this great resource!
# https://github.com/microsoft/autogen/blob/gaia_multiagent_v01_march_1st/autogen/browser_utils.py
import json
import os
import requests
import re
import io
import uuid
import mimetypes
import time
import pathlib
import pathvalidate
from urllib.parse import urljoin, urlparse, unquote, parse_qs
from urllib.request import url2pathname
from typing import Any, Dict, List, Optional, Union, Tuple
from mdconvert import MarkdownConverter, UnsupportedFormatException, FileConversionException
from serpapi import GoogleSearch
from cookies import COOKIES
from duckduckgo_search import DDGS


class SimpleTextBrowser:
    """(In preview) An extremely simple text-based web browser comparable to Lynx. Suitable for Agentic use."""

    def __init__(
        self,
        start_page: Optional[str] = None,
        viewport_size: Optional[int] = 1024 * 8,
        downloads_folder: Optional[Union[str, None]] = None,
        ##serpapi_key: Optional[Union[str, None]] = None,
        request_kwargs: Optional[Union[Dict[str, Any], None]] = None,
    ):
        self.start_page: str = start_page if start_page else "about:blank"
        self.viewport_size = viewport_size  # Applies only to the standard uri types
        self.downloads_folder = downloads_folder
        self.history: List[Tuple[str, float]] = list()
        self.page_title: Optional[str] = None
        self.viewport_current_page = 0
        self.viewport_pages: List[Tuple[int, int]] = list()
        self.set_address(self.start_page)
        ##self.serpapi_key = serpapi_key
        # Guard against the default None so the cookie injection below cannot crash.
        self.request_kwargs = request_kwargs if request_kwargs is not None else {}
        self.request_kwargs["cookies"] = COOKIES
        self._mdconvert = MarkdownConverter()
        self._page_content: str = ""

        self._find_on_page_query: Union[str, None] = None
        self._find_on_page_last_result: Union[int, None] = None  # Location of the last result

    @property
    def address(self) -> str:
        """Return the address of the current page."""
        return self.history[-1][0]

    def set_address(self, uri_or_path: str) -> None:
        # TODO: Handle anchors
        self.history.append((uri_or_path, time.time()))

        # Handle special URIs
        if uri_or_path == "about:blank":
            self._set_page_content("")
        elif uri_or_path.startswith("google:"):
            self._serpapi_search(uri_or_path[len("google:"):].strip())
        else:
            if (
                not uri_or_path.startswith("http:")
                and not uri_or_path.startswith("https:")
                and not uri_or_path.startswith("file:")
            ):
                if len(self.history) > 1:
                    prior_address = self.history[-2][0]
                    uri_or_path = urljoin(prior_address, uri_or_path)
                    # Update the address with the fully-qualified path
                    self.history[-1] = (uri_or_path, self.history[-1][1])
            self._fetch_page(uri_or_path)

        self.viewport_current_page = 0
        self._find_on_page_query = None
        self._find_on_page_last_result = None

    @property
    def viewport(self) -> str:
        """Return the content of the current viewport."""
        bounds = self.viewport_pages[self.viewport_current_page]
        return self.page_content[bounds[0] : bounds[1]]

    @property
    def page_content(self) -> str:
        """Return the full contents of the current page."""
        return self._page_content

    def _set_page_content(self, content: str) -> None:
        """Sets the text content of the current page."""
        self._page_content = content
        self._split_pages()
        if self.viewport_current_page >= len(self.viewport_pages):
            self.viewport_current_page = len(self.viewport_pages) - 1

    def page_down(self) -> None:
        self.viewport_current_page = min(self.viewport_current_page + 1, len(self.viewport_pages) - 1)

    def page_up(self) -> None:
        self.viewport_current_page = max(self.viewport_current_page - 1, 0)

    def find_on_page(self, query: str) -> Union[str, None]:
        """Searches for the query from the current viewport forward, looping back to the start if necessary."""

        # Did we get here via a previous find_on_page search with the same query?
        # If so, map to find_next
        if query == self._find_on_page_query and self.viewport_current_page == self._find_on_page_last_result:
            return self.find_next()

        # Ok it's a new search start from the current viewport
        self._find_on_page_query = query
        viewport_match = self._find_next_viewport(query, self.viewport_current_page)
        if viewport_match is None:
            self._find_on_page_last_result = None
            return None
        else:
            self.viewport_current_page = viewport_match
            self._find_on_page_last_result = viewport_match
            return self.viewport

    def find_next(self) -> Union[str, None]:
        """Scroll to the next viewport that matches the query"""

        if self._find_on_page_query is None:
            return None

        starting_viewport = self._find_on_page_last_result
        if starting_viewport is None:
            starting_viewport = 0
        else:
            starting_viewport += 1
            if starting_viewport >= len(self.viewport_pages):
                starting_viewport = 0

        viewport_match = self._find_next_viewport(self._find_on_page_query, starting_viewport)
        if viewport_match is None:
            self._find_on_page_last_result = None
            return None
        else:
            self.viewport_current_page = viewport_match
            self._find_on_page_last_result = viewport_match
            return self.viewport

    def _find_next_viewport(self, query: str, starting_viewport: int) -> Union[int, None]:
        """Search for matches between the starting viewport looping when reaching the end."""

        if query is None:
            return None

        # Normalize the query, and convert to a regular expression
        nquery = re.sub(r"\*", "__STAR__", query)
        nquery = " " + (" ".join(re.split(r"\W+", nquery))).strip() + " "
        nquery = nquery.replace(" __STAR__ ", "__STAR__ ")  # Merge isolated stars with prior word
        nquery = nquery.replace("__STAR__", ".*").lower()

        if nquery.strip() == "":
            return None

        idxs = list()
        idxs.extend(range(starting_viewport, len(self.viewport_pages)))
        idxs.extend(range(0, starting_viewport))

        for i in idxs:
            bounds = self.viewport_pages[i]
            content = self.page_content[bounds[0] : bounds[1]]

            # TODO: Remove markdown links and images
            ncontent = " " + (" ".join(re.split(r"\W+", content))).strip().lower() + " "
            if re.search(nquery, ncontent):
                return i

        return None

    def visit_page(self, path_or_uri: str) -> str:
        self.set_address(path_or_uri)
        return self.viewport

    def _split_pages(self) -> None:
        # Do not split search results
        if self.address.startswith("google:"):
            self.viewport_pages = [(0, len(self._page_content))]
            return

        # Handle empty pages
        if len(self._page_content) == 0:
            self.viewport_pages = [(0, 0)]
            return

        # Break the viewport into pages
        self.viewport_pages = []
        start_idx = 0
        while start_idx < len(self._page_content):
            end_idx = min(start_idx + self.viewport_size, len(self._page_content))  # type: ignore[operator]
            # Adjust to end on a space
            while end_idx < len(self._page_content) and self._page_content[end_idx - 1] not in [" ", "\t", "\r", "\n"]:
                end_idx += 1
            self.viewport_pages.append((start_idx, end_idx))
            start_idx = end_idx

    def _prev_visit(self, url: str) -> str:
        # NOTE: this helper is referenced by _serpapi_search but was missing from the
        # uploaded file; restored here as a minimal implementation (an assumption,
        # modeled on the upstream browser_utils version).
        for i in range(len(self.history) - 1, -1, -1):
            if self.history[i][0] == url:
                return f"You previously visited this page {round(time.time() - self.history[i][1])} seconds ago.\n"
        return ""

    def _serpapi_search(self, query: str, filter_year: Optional[int] = None) -> None:
        # Despite the name, this now queries DuckDuckGo (DDGS) rather than SerpAPI.
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=10))

        self.page_title = f"{query} - Search"

        if not results:
            year_filter_message = f" with filter year={filter_year}" if filter_year is not None else ""
            self._set_page_content(f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter.")
            return

        web_snippets: List[str] = list()
        for idx, page in enumerate(results, 1):
            snippet = f"\n{page['body']}" if 'body' in page else ""
            redacted_version = f"{idx}. [{page['title']}]({page['href']})\n{self._prev_visit(page['href'])}{snippet}"
            web_snippets.append(redacted_version)

        content = (f"A search for '{query}' found {len(web_snippets)} results:\n\n## Web Results\n" +
                   "\n\n".join(web_snippets))
        self._set_page_content(content)

    def _fetch_page(self, url: str) -> None:
        download_path = ""
        try:
            if url.startswith("file://"):
                download_path = os.path.normcase(os.path.normpath(unquote(url[7:])))
                res = self._mdconvert.convert_local(download_path)
                self.page_title = res.title
                self._set_page_content(res.text_content)
            else:
                # Prepare the request parameters
                request_kwargs = self.request_kwargs.copy() if self.request_kwargs is not None else {}
                request_kwargs["stream"] = True

                # Send a HTTP request to the URL
                response = requests.get(url, **request_kwargs)
                response.raise_for_status()

                # If the HTTP request was successful
                content_type = response.headers.get("content-type", "")

                # Text or HTML
                if "text/" in content_type.lower():
                    res = self._mdconvert.convert_response(response)
                    self.page_title = res.title
                    self._set_page_content(res.text_content)
                # A download
                else:
                    # Try producing a safe filename
                    fname = None
                    download_path = None
                    try:
                        fname = pathvalidate.sanitize_filename(os.path.basename(urlparse(url).path)).strip()
                        download_path = os.path.abspath(os.path.join(self.downloads_folder, fname))

                        suffix = 0
                        while os.path.exists(download_path) and suffix < 1000:
                            suffix += 1
                            base, ext = os.path.splitext(fname)
                            new_fname = f"{base}__{suffix}{ext}"
                            download_path = os.path.abspath(os.path.join(self.downloads_folder, new_fname))

                    except NameError:
                        pass

                    # No suitable name, so make one
                    if fname is None:
                        extension = mimetypes.guess_extension(content_type)
                        if extension is None:
                            extension = ".download"
                        fname = str(uuid.uuid4()) + extension
                        download_path = os.path.abspath(os.path.join(self.downloads_folder, fname))

                    # Open a file for writing
                    with open(download_path, "wb") as fh:
                        for chunk in response.iter_content(chunk_size=512):
                            fh.write(chunk)

                    # Render it
                    local_uri = pathlib.Path(download_path).as_uri()
                    self.set_address(local_uri)

        except UnsupportedFormatException as e:
            print(e)
            self.page_title = "Download complete."
            self._set_page_content(f"# Download complete\n\nSaved file to '{download_path}'")
        except FileConversionException as e:
            print(e)
            self.page_title = "Download complete."
            self._set_page_content(f"# Download complete\n\nSaved file to '{download_path}'")
        except FileNotFoundError:
            self.page_title = "Error 404"
            self._set_page_content(f"## Error 404\n\nFile not found: {download_path}")
        except requests.exceptions.RequestException as request_exception:
            try:
                self.page_title = f"Error {response.status_code}"

                # If the error was rendered in HTML we might as well render it
                content_type = response.headers.get("content-type", "")
                if content_type is not None and "text/html" in content_type.lower():
                    res = self._mdconvert.convert(response)
                    self.page_title = f"Error {response.status_code}"
                    self._set_page_content(f"## Error {response.status_code}\n\n{res.text_content}")
                else:
                    text = ""
                    for chunk in response.iter_content(chunk_size=512, decode_unicode=True):
                        text += chunk
                    self.page_title = f"Error {response.status_code}"
                    self._set_page_content(f"## Error {response.status_code}\n\n{text}")
            except NameError:
                self.page_title = "Error"
                self._set_page_content(f"## Error\n\n{str(request_exception)}")
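SimpleTextBrowser is the engine behind the web_surfer tools imported in app (10).py: "google:" addresses are routed to the DDGS-backed search, everything else is fetched and converted to Markdown, and the result is read through a paging viewport. A minimal sketch of driving it by hand, assuming the class above is importable and network access is available (the query is illustrative):

# Sketch: exercise SimpleTextBrowser directly, outside the tool wrappers.
browser = SimpleTextBrowser(viewport_size=1024 * 8, downloads_folder=".", request_kwargs={})

# A "google:" address triggers the DDGS search path and renders a results page.
print(browser.visit_page("google: rwkv ggml quantization"))

# Page through the current document and search within it.
browser.page_down()
print(browser.address, browser.page_title)
match = browser.find_on_page("quantization")
if match is not None:
    print(match)  # the first viewport containing the query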
chat_with_bot (2).py
ADDED
@@ -0,0 +1,278 @@
# Provides terminal-based chat interface for RWKV model.
# Usage: python chat_with_bot.py C:\rwkv.cpp-169M.bin
# Prompts and code adapted from https://github.com/BlinkDL/ChatRWKV/blob/9ca4cdba90efaee25cfec21a0bae72cbd48d8acd/chat.py

import os
import argparse
import pathlib
import copy
import json
import time
import sampling
from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model
from tokenizer_util import add_tokenizer_argument, get_tokenizer
from typing import List, Dict, Optional

# ======================================== Script settings ========================================

# English, Chinese, Japanese
LANGUAGE: str = 'English'
# QA: Question and Answer prompt to talk to an AI assistant.
# Chat: chat prompt (need a large model for adequate quality, 7B+).
PROMPT_TYPE: str = 'QA'

MAX_GENERATION_LENGTH: int = 250

# Sampling temperature. It could be a good idea to increase temperature when top_p is low.
TEMPERATURE: float = 0.8
# For better Q&A accuracy and less diversity, reduce top_p (to 0.5, 0.2, 0.1 etc.)
TOP_P: float = 0.5
# Penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
PRESENCE_PENALTY: float = 0.2
# Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
FREQUENCY_PENALTY: float = 0.2

END_OF_LINE_TOKEN: int = 187
DOUBLE_END_OF_LINE_TOKEN: int = 535
END_OF_TEXT_TOKEN: int = 0

# =================================================================================================

parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
add_tokenizer_argument(parser)
args = parser.parse_args()

script_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent

with open(script_dir / 'prompt' / f'{LANGUAGE}-{PROMPT_TYPE}.json', 'r', encoding='utf8') as json_file:
    prompt_data = json.load(json_file)

user, bot, separator, init_prompt = prompt_data['user'], prompt_data['bot'], prompt_data['separator'], prompt_data['prompt']

if init_prompt == '':
    raise ValueError('Prompt must not be empty')

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')

print('Loading RWKV model')
model = rwkv_cpp_model.RWKVModel(library, args.model_path)

tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab)

# =================================================================================================

processed_tokens: List[int] = []
logits: Optional[rwkv_cpp_model.NumpyArrayOrPyTorchTensor] = None
state: Optional[rwkv_cpp_model.NumpyArrayOrPyTorchTensor] = None

def process_tokens(_tokens: List[int], new_line_logit_bias: float = 0.0) -> None:
    global processed_tokens, logits, state

    logits, state = model.eval_sequence_in_chunks(_tokens, state, state, logits, use_numpy=True)

    processed_tokens += _tokens

    logits[END_OF_LINE_TOKEN] += new_line_logit_bias

state_by_thread: Dict[str, Dict] = {}

def save_thread_state(_thread: str) -> None:
    state_by_thread[_thread] = {
        'tokens': copy.deepcopy(processed_tokens),
        'logits': copy.deepcopy(logits),
        'state': copy.deepcopy(state)
    }

def load_thread_state(_thread: str) -> None:
    global processed_tokens, logits, state

    thread_state = state_by_thread[_thread]

    processed_tokens = copy.deepcopy(thread_state['tokens'])
    logits = copy.deepcopy(thread_state['logits'])
    state = copy.deepcopy(thread_state['state'])

# Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end.
# See https://github.com/BlinkDL/ChatRWKV/pull/110/files
def split_last_end_of_line(tokens: List[int]) -> List[int]:
    if len(tokens) > 0 and tokens[-1] == DOUBLE_END_OF_LINE_TOKEN:
        tokens = tokens[:-1] + [END_OF_LINE_TOKEN, END_OF_LINE_TOKEN]

    return tokens

# =================================================================================================

processing_start: float = time.time()

prompt_tokens = tokenizer_encode(init_prompt)
prompt_token_count = len(prompt_tokens)
print(f'Processing {prompt_token_count} prompt tokens, may take a while')

process_tokens(split_last_end_of_line(prompt_tokens))

processing_duration: float = time.time() - processing_start

print(f'Processed in {int(processing_duration)} s, {int(processing_duration / prompt_token_count * 1000)} ms per token')

save_thread_state('chat_init')
save_thread_state('chat')

print(f'\nChat initialized! Your name is {user}. Write something and press Enter. Use \\n to add line breaks to your message.')

while True:
    # Read user input
    user_input: str = input(f'> {user}{separator} ')
    msg: str = user_input.replace('\\n', '\n').strip()

    temperature: float = TEMPERATURE
    top_p: float = TOP_P

    if '-temp=' in msg:
        temperature = float(msg.split('-temp=')[1].split(' ')[0])

        msg = msg.replace('-temp=' + f'{temperature:g}', '')

        if temperature <= 0.2:
            temperature = 0.2

        if temperature >= 5:
            temperature = 5

    if '-top_p=' in msg:
        top_p = float(msg.split('-top_p=')[1].split(' ')[0])

        msg = msg.replace('-top_p=' + f'{top_p:g}', '')

        if top_p <= 0:
            top_p = 0

    msg = msg.strip()

    # + reset --> reset chat
    if msg == '+reset':
        load_thread_state('chat_init')
        save_thread_state('chat')
        print(f'{bot}{separator} Chat reset.\n')
        continue
    elif msg[:5].lower() == '+gen ' or msg[:3].lower() == '+i ' or msg[:4].lower() == '+qa ' or msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++':

        # +gen YOUR PROMPT --> free single-round generation with any prompt. Requires Novel model.
        if msg[:5].lower() == '+gen ':
            new = '\n' + msg[5:].strip()
            state = None
            processed_tokens = []
            process_tokens(tokenizer_encode(new))
            save_thread_state('gen_0')

        # +i YOUR INSTRUCT --> free single-round generation with any instruct. Requires Raven model.
        elif msg[:3].lower() == '+i ':
            new = f'''
Below is an instruction that describes a task. Write a response that appropriately completes the request.

# Instruction:
{msg[3:].strip()}

# Response:
'''
            state = None
            processed_tokens = []
            process_tokens(tokenizer_encode(new))
            save_thread_state('gen_0')

        # +qq YOUR QUESTION --> answer an independent question with more creativity (regardless of context).
        elif msg[:4].lower() == '+qq ':
            new = '\nQ: ' + msg[4:].strip() + '\nA:'
            state = None
            processed_tokens = []
            process_tokens(tokenizer_encode(new))
            save_thread_state('gen_0')

        # +qa YOUR QUESTION --> answer an independent question (regardless of context).
        elif msg[:4].lower() == '+qa ':
            load_thread_state('chat_init')

            real_msg = msg[4:].strip()
            new = f'{user}{separator} {real_msg}\n\n{bot}{separator}'

            process_tokens(tokenizer_encode(new))
            save_thread_state('gen_0')

        # +++ --> continue last free generation (only for +gen / +i)
        elif msg.lower() == '+++':
            try:
                load_thread_state('gen_1')
                save_thread_state('gen_0')
            except Exception as e:
                print(e)
                continue

        # ++ --> retry last free generation (only for +gen / +i)
        elif msg.lower() == '++':
            try:
                load_thread_state('gen_0')
            except Exception as e:
                print(e)
                continue

        thread = 'gen_1'
    else:
        # + --> alternate chat reply
        if msg.lower() == '+':
            try:
                load_thread_state('chat_pre')
            except Exception as e:
                print(e)
                continue
        # chat with bot
        else:
            load_thread_state('chat')
            new = f'{user}{separator} {msg}\n\n{bot}{separator}'
            process_tokens(tokenizer_encode(new), new_line_logit_bias=-999999999)
            save_thread_state('chat_pre')

        thread = 'chat'

    # Print bot response
    print(f'> {bot}{separator}', end='')

    start_index: int = len(processed_tokens)
    accumulated_tokens: List[int] = []
    token_counts: Dict[int, int] = {}

    for i in range(MAX_GENERATION_LENGTH):
        for n in token_counts:
            logits[n] -= PRESENCE_PENALTY + token_counts[n] * FREQUENCY_PENALTY

        token: int = sampling.sample_logits(logits, temperature, top_p)

        if token == END_OF_TEXT_TOKEN:
            print()
            break

        if token not in token_counts:
            token_counts[token] = 1
        else:
            token_counts[token] += 1

        process_tokens([token])

        # Avoid UTF-8 display issues
        accumulated_tokens += [token]

        decoded: str = tokenizer_decode(accumulated_tokens)

        if '\uFFFD' not in decoded:
            print(decoded, end='', flush=True)

            accumulated_tokens = []

        if thread == 'chat':
            if '\n\n' in tokenizer_decode(processed_tokens[start_index:]):
                break

        if i == MAX_GENERATION_LENGTH - 1:
            print()

    save_thread_state(thread)
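The generation loop above applies two repetition controls before each sample: every token emitted so far is penalized a flat PRESENCE_PENALTY plus FREQUENCY_PENALTY scaled by how often it has appeared. A worked sketch of that adjustment on toy logits (numpy assumed; values illustrative):

import numpy as np

PRESENCE_PENALTY = 0.2
FREQUENCY_PENALTY = 0.2

logits = np.array([1.0, 2.0, 3.0, 4.0])
token_counts = {2: 3, 3: 1}  # token 2 sampled three times so far, token 3 once

# Same update as the chat loop: flat presence term plus count-scaled frequency term.
for token, count in token_counts.items():
    logits[token] -= PRESENCE_PENALTY + count * FREQUENCY_PENALTY

# token 2: 3.0 - (0.2 + 3 * 0.2) = 2.2; token 3: 4.0 - (0.2 + 1 * 0.2) = 3.6
print(logits)  # [1.  2.  2.2 3.6]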
cookies (2).py
ADDED
@@ -0,0 +1,713 @@
COOKIES_LIST = [
    {
        "domain": ".youtube.com",
        "expirationDate": 1718884961,
        "hostOnly": False,
        "httpOnly": False,
        "name": "ST-xuwub9",
        "path": "/",
        "sameSite": None,
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "session_logininfo=AFmmF2swRAIgf4gadACOuWOcipI1anW-dakEjtidNLkufnOC8uml7EECIDh2YisqWELDBJPTGUysCucJ3I0wjXxYjVHro1LHrdW0%3AQUQ3MjNmd2Jiajl3OWZYRnpFNnZlWWV5ZGJWZ0hpcmp4LVVPU280bk4zOS03Z0ozZG9fOFhWZ0dXaVo3NG1wTEg1b3hGaG10TFBlaFBnTlJfbER5bEp0aFhoNS1OLVhYNFRZT2F6ajgzOFpDbGhlUjZpMWRETlFFRjFfTTRiM0RnNTROSkdmMTFMVjFic1VuZ2trbGp4aktDa0JJUC1BWDh3"
    },
    {
        "domain": ".youtube.com",
        "expirationDate": 1753004444.745411,
        "hostOnly": False,
        "httpOnly": True,
        "name": "__Secure-YEC",
        "path": "/",
        "sameSite": "lax",
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "CgtRVnI5LW1zRHlQVSjbtNCzBjIhCgJGUhIbEhcSFRMLFBUWFwwYGRobHB0eHw4PIBAREiAk"
    },
    {
        "domain": ".youtube.com",
        "expirationDate": 1753434620.050824,
        "hostOnly": False,
        "httpOnly": True,
        "name": "__Secure-3PSID",
        "path": "/",
        "sameSite": "no_restriction",
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "g.a000kwibeLUu8Ea9Y-vLun7u3kU5VNJVuMAZl_jdfJaNm50JyDBB4ezJ_bdWu46a7YwObVn44wACgYKAakSARQSFQHGX2MicJcTzecTKH6bHzqU6TMbTxoVAUF8yKqQYK-MoI6Ql3vI2oYTB3E-0076"
    },
    {
        "domain": ".youtube.com",
        "expirationDate": 1750420959.974642,
        "hostOnly": False,
        "httpOnly": False,
        "name": "SIDCC",
        "path": "/",
        "sameSite": None,
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "AKEyXzWQZauHKOo8t87zoEcjaVNIYUX54ohoWXT-tX4aAhEuZzIIptxZAcNkHuG2oDXYL6t-lw"
    },
    {
        "domain": ".youtube.com",
        "expirationDate": 1753434620.050652,
        "hostOnly": False,
        "httpOnly": False,
        "name": "SID",
        "path": "/",
        "sameSite": None,
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "g.a000kwibeLUu8Ea9Y-vLun7u3kU5VNJVuMAZl_jdfJaNm50JyDBB6VHrZcC3gBAsFPbCQ0gF5AACgYKAYkSARQSFQHGX2Mi9kt0gHg5CxCYSkLQGHWaeBoVAUF8yKre_V6r3jZVak6JV4o2Q0FL0076"
    },
    {
        "domain": ".youtube.com",
        "expirationDate": 1750420958.397534,
        "hostOnly": False,
        "httpOnly": True,
        "name": "__Secure-1PSIDTS",
        "path": "/",
        "sameSite": None,
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "sidts-CjIB3EgAEkYL2L-GfrEzW5Dfy62S9oefGNLgst78S_986htCnGcfkxECch_9oz-qytSsZBAA"
    },
    {
        "domain": ".youtube.com",
        "expirationDate": 1753433494.44729,
        "hostOnly": False,
        "httpOnly": False,
        "name": "_ga_M0180HEFCY",
        "path": "/",
        "sameSite": None,
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "GS1.1.1718871908.1.0.1718873494.0.0.0"
    },
    {
        "domain": ".youtube.com",
        "expirationDate": 1753434620.050933,
        "hostOnly": False,
        "httpOnly": False,
        "name": "SAPISID",
        "path": "/",
        "sameSite": None,
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "mfeuiC-HraNJ-A03/ASXvCPNJSw7yTFgd6"
    },
    {
        "domain": ".youtube.com",
        "expirationDate": 1750420959.974764,
        "hostOnly": False,
        "httpOnly": True,
        "name": "__Secure-1PSIDCC",
        "path": "/",
        "sameSite": None,
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "AKEyXzWHDSoXGCZpZhPxRrnC7B1s8zGIUjeMVyvgtQfsm1fs92lXPtFEI_td9LBUyqVUe0xK"
    },
    {
        "domain": ".youtube.com",
        "expirationDate": 1753434620.050881,
        "hostOnly": False,
        "httpOnly": True,
        "name": "SSID",
        "path": "/",
        "sameSite": None,
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "AmlwXHnQvOQ10LVd-"
    },
    {
        "domain": ".youtube.com",
        "expirationDate": 1753434620.050959,
        "hostOnly": False,
        "httpOnly": False,
        "name": "__Secure-1PAPISID",
        "path": "/",
        "sameSite": None,
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "mfeuiC-HraNJ-A03/ASXvCPNJSw7yTFgd6"
    },
    {
        "domain": ".youtube.com",
        "expirationDate": 1753434620.050795,
        "hostOnly": False,
        "httpOnly": True,
        "name": "__Secure-1PSID",
        "path": "/",
        "sameSite": None,
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "g.a000kwibeLUu8Ea9Y-vLun7u3kU5VNJVuMAZl_jdfJaNm50JyDBBrlk7lRpKQGywAHEon7WGQAACgYKAQsSARQSFQHGX2MirAmnSRdZl6GPG6KLd4hOihoVAUF8yKoV17Tcj1a_OenIOkf2wBjO0076"
    },
    {
        "domain": ".youtube.com",
        "expirationDate": 1753434620.050993,
        "hostOnly": False,
        "httpOnly": False,
        "name": "__Secure-3PAPISID",
        "path": "/",
        "sameSite": "no_restriction",
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "mfeuiC-HraNJ-A03/ASXvCPNJSw7yTFgd6"
    },
    {
        "domain": ".youtube.com",
        "expirationDate": 1750420959.974815,
        "hostOnly": False,
        "httpOnly": True,
        "name": "__Secure-3PSIDCC",
        "path": "/",
        "sameSite": "no_restriction",
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "AKEyXzXM5UjKUEXwSHVmRAIo6hGHA4G63adj3EE1VdNriD0f38jZQbsUKiD4LQbA3BValmTFDg"
    },
    {
        "domain": ".youtube.com",
        "expirationDate": 1750420958.397647,
        "hostOnly": False,
        "httpOnly": True,
        "name": "__Secure-3PSIDTS",
        "path": "/",
        "sameSite": "no_restriction",
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "sidts-CjIB3EgAEkYL2L-GfrEzW5Dfy62S9oefGNLgst78S_986htCnGcfkxECch_9oz-qytSsZBAA"
    },
    {
        "domain": ".youtube.com",
        "expirationDate": 1753434620.050908,
        "hostOnly": False,
        "httpOnly": False,
        "name": "APISID",
        "path": "/",
        "sameSite": None,
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "IlQWLPjdNqziwCrV/ANG7Z4x5FF-IBxbZk"
    },
    {
        "domain": ".youtube.com",
        "expirationDate": 1753434620.050855,
        "hostOnly": False,
        "httpOnly": True,
        "name": "HSID",
        "path": "/",
        "sameSite": None,
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "AasA7hmRuTFv7vjoq"
    },
    {
        "domain": ".youtube.com",
        "expirationDate": 1753435873.577793,
        "hostOnly": False,
        "httpOnly": True,
        "name": "LOGIN_INFO",
        "path": "/",
        "sameSite": "no_restriction",
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "AFmmF2swRAIgf4gadACOuWOcipI1anW-dakEjtidNLkufnOC8uml7EECIDh2YisqWELDBJPTGUysCucJ3I0wjXxYjVHro1LHrdW0:QUQ3MjNmd2Jiajl3OWZYRnpFNnZlWWV5ZGJWZ0hpcmp4LVVPU280bk4zOS03Z0ozZG9fOFhWZ0dXaVo3NG1wTEg1b3hGaG10TFBlaFBnTlJfbER5bEp0aFhoNS1OLVhYNFRZT2F6ajgzOFpDbGhlUjZpMWRETlFFRjFfTTRiM0RnNTROSkdmMTFMVjFic1VuZ2trbGp4aktDa0JJUC1BWDh3"
    },
    {
        "domain": ".youtube.com",
        "expirationDate": 1753444956.555608,
        "hostOnly": False,
        "httpOnly": False,
        "name": "PREF",
        "path": "/",
        "sameSite": None,
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "f4=4000000&f6=40000000&tz=Europe.Paris&f5=30000&f7=100"
    }
]

COOKIES_LIST += [
    {
        "domain": ".www.researchgate.net",
        "hostOnly": False,
        "httpOnly": True,
        "name": "isInstIp",
        "path": "/",
        "sameSite": None,
        "secure": True,
        "session": True,
        "storeId": None,
        "value": "False"
    },
    {
        "domain": ".researchgate.net",
        "expirationDate": 1734423981,
        "hostOnly": False,
        "httpOnly": False,
        "name": "__eoi",
        "path": "/",
        "sameSite": None,
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "ID=c26f752377373146:T=1718871981:RT=1718884914:S=AA-AfjZw-T_OOX2kW2LLaFzXImgc"
    },
    {
        "domain": ".www.researchgate.net",
        "expirationDate": 1753444909.646103,
        "hostOnly": False,
        "httpOnly": True,
        "name": "ptc",
        "path": "/",
        "sameSite": None,
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "RG1.8947708639250500550.1718872043"
    },
    {
        "domain": ".researchgate.net",
        "expirationDate": 1750507578,
        "hostOnly": False,
        "httpOnly": False,
        "name": "euconsent-v2-didomi",
        "path": "/",
        "sameSite": "lax",
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "CQAgmoAQAgmoAAHABBENA5EsAP_gAEPgAAYgJ2pB5G5UTWlBIG53YMskIAUFhFBoQEAgAACAAwIBSBIAIIwEAGAAIAgAICACAAIAIBIAIABAGAAAAAAAYIAAIAAIAAAQIAAKIAAAAAAAAgBQAAgIAgggEAAAgEBEABAAgAAAEIIAQNgACgAAACCAAAAAAAABAAAAAAAAQAAAAAAAYCQAAAJIAAAAACAIABAIAAAAAAAAAAAAAAAABBAAIJ2wPIAFAAXABQAFQALgAcAA8ACAAEgALwAZAA0ACIAEcAJgAUgAqgBcADEAGgAPQAfgBEACOAE4AMMAZYA0QBsgDkAHOAO4AfsBBwEIAItARwBHQC6gHUAO2Ae0A_4CHQEXgJ2AUOAo8BT4CpQFqALYAXmAwQBkgDLAGXANjAhCBG8CbAE3gJ1gTtAA.f_wACHwAAAAA"
    },
    {
        "domain": ".researchgate.net",
        "expirationDate": 1718885236,
        "hostOnly": False,
        "httpOnly": False,
        "name": "_gat",
        "path": "/",
        "sameSite": None,
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "1"
    },
    {
        "domain": "www.researchgate.net",
        "expirationDate": 1721477183,
        "hostOnly": True,
        "httpOnly": False,
        "name": "_pbjs_userid_consent_data",
        "path": "/",
        "sameSite": "lax",
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "3524755945110770"
    },
    {
        "domain": ".researchgate.net",
        "expirationDate": 1752567981,
        "hostOnly": False,
        "httpOnly": False,
        "name": "__gads",
        "path": "/",
        "sameSite": None,
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "ID=eca2adb88969c830:T=1718871981:RT=1718884914:S=ALNI_MY2qZchynrhWX6hWMlaI87Pcj9riQ"
    },
    {
        "domain": ".researchgate.net",
        "expirationDate": 1718886709.646173,
        "hostOnly": False,
        "httpOnly": True,
        "name": "__cf_bm",
        "path": "/",
        "sameSite": "no_restriction",
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "IkQ_J4ciBzKQduRvjqsfSmQu8UygDWbHeROO5JVccfo-1718884909-1.0.1.1-qvNGEdbfI0HfhFP6kwe7R7mkTqODNhFuKhs72lLly6K2BOPMG3kbahpQFGvPK0U8FUfkznkq65gngd1sWj7sDA"
    },
    {
        "domain": ".researchgate.net",
        "expirationDate": 1752567981,
        "hostOnly": False,
        "httpOnly": False,
        "name": "__gpi",
        "path": "/",
        "sameSite": None,
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "UID=00000e4e9aa2e6f2:T=1718871981:RT=1718884914:S=ALNI_MYFNrgzkKn7K6Bd2y8hC6GJCvDiSg"
    },
    {
        "domain": ".researchgate.net",
        "hostOnly": False,
        "httpOnly": True,
        "name": "_cfuvid",
        "path": "/",
        "sameSite": "no_restriction",
        "secure": True,
        "session": True,
        "storeId": None,
        "value": "_GPmGZkBymiH3UiqTqzakEpi98br3nfFUWC2_u_wqkc-1718884909785-0.0.1.1-604800000"
    },
    {
        "domain": ".researchgate.net",
        "expirationDate": 1753445177.271667,
        "hostOnly": False,
        "httpOnly": False,
        "name": "_ga",
        "path": "/",
        "sameSite": None,
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "GA1.1.1525244793.1718885177"
    },
    {
        "domain": ".researchgate.net",
        "expirationDate": 1753445177.271482,
        "hostOnly": False,
        "httpOnly": False,
        "name": "_ga_4P31SJ70EJ",
        "path": "/",
        "sameSite": None,
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "GS1.1.1718885177.1.0.1718885177.0.0.0"
    },
    {
        "domain": ".researchgate.net",
        "expirationDate": 1718971576,
        "hostOnly": False,
        "httpOnly": False,
        "name": "_gid",
        "path": "/",
        "sameSite": None,
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "GA1.2.854907463.1718885177"
    },
    {
        "domain": ".www.researchgate.net",
        "expirationDate": 1750407982.506505,
        "hostOnly": False,
        "httpOnly": True,
        "name": "did",
        "path": "/",
        "sameSite": None,
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "1dWLO3C6am8l667Q4VUlBo0O1LI49Qi2Vw21SJEXHavBDYT56DI9007W5rYGVFVH"
    },
    {
        "domain": ".researchgate.net",
        "expirationDate": 1750507578,
        "hostOnly": False,
        "httpOnly": False,
        "name": "didomi_token",
        "path": "/",
        "sameSite": "lax",
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "eyJ1c2VyX2lkIjoiMTkwMzU4YTUtNWU2My02Y2UzLWJlNzAtZGFjNzVmYjdiY2ExIiwiY3JlYXRlZCI6IjIwMjQtMDYtMjBUMTI6MDY6MTYuODA2WiIsInVwZGF0ZWQiOiIyMDI0LTA2LTIwVDEyOjA2OjE4Ljc4MVoiLCJ2ZW5kb3JzIjp7ImVuYWJsZWQiOlsidHdpdHRlciIsImdvb2dsZSIsImM6bGlua2VkaW4tbWFya2V0aW5nLXNvbHV0aW9ucyIsImM6b3duZXJpcSIsImM6b21uaXR1cmUtYWRvYmUtYW5hbHl0aWNzIiwiYzp0ZWNobm9yYXRpLW1lZGlhIiwiYzppbnRlcmNvbSIsImM6aW50ZW50LWlxIiwiYzppcHJvbSIsImM6bGlua2VkaW4iLCJjOmFtYXpvbmFkdi16Y1hGTEI2WCIsImM6bWVkaWFuZXQtY1V3YUtFNnoiLCJjOmluZGV4ZXhjaC1OWkNRTTY4UCIsImM6emVvdGFwZ21iLWQ3YndtdGp3IiwiYzp0cmlwbGVsaWYtZGRKSDM0clkiLCJjOnJ0Ymhvd
XNlLWI4Y2RIOHRNIiwiYzptZHByaW1pcy1lYU4yOVdjUCIsImM6bG9vcG1lbGktVGRhWXRCUHEiLCJjOm1hZ25pdGVpbi05d1RZTHFSRCIsImM6Ymlkc3dpdGNoLWQ2N0V3N1c5IiwiYzpvcmFjbGVhZHYtcUhlREptQUwiLCJjOmdvb2dsZWFuYS00VFhuSmlnUiIsImM6bG90YW1lc29sLURIaTdMUmpNIiwiYzpuZXh0bWlsbGUtR0pyZlg4VWMiLCJjOm5yaWNodGVjLXFVVlEyUlFxIiwiYzpicml0ZXBvb2wtQldWeVdHeVUiLCJjOnRhcGFkaW5jLXFxY2tVN1BXIiwiYzppZDV0ZWNobi16Tk1KNGR3ZiIsImM6bWljcm9zb2Z0IiwiYzpwZXJtdXRpdmUtSjdpaHJlTWsiLCJjOm9wZXJhc29mdC1CY1hjRFZKTSIsImM6cG9zdGhvZy1Cakp4RmRGOSJdfSwicHVycG9zZXMiOnsiZW5hYmxlZCI6WyJnZW9sb2NhdGlvbl9kYXRhIiwiZGV2aWNlX2NoYXJhY3RlcmlzdGljcyJdfSwidmVuZG9yc19saSI6eyJlbmFibGVkIjpbImdvb2dsZSIsImM6b3BlcmFzb2Z0LUJjWGNEVkpNIl19LCJ2ZXJzaW9uIjoyLCJhYyI6IkRIU0FvQUZrQWNnQTVnSHFnUUhBeGdCNndEMTRJR0FRTkFqMEJJd0NTY0VyQUtCd1YtZ3MxQmgwREc0R09nQUEuREhTQW9BRmtBY2dBNWdIcWdRSEF4Z0I2d0QxNElHQVFOQWowQkl3Q1NjRXJBS0J3Vi1nczFCaDBERzRHT2dBQSJ9"
    },
    {
        "domain": ".www.researchgate.net",
        "hostOnly": False,
        "httpOnly": True,
        "name": "hasPdpNext",
        "path": "/",
        "sameSite": None,
        "secure": True,
        "session": True,
        "storeId": None,
        "value": "False"
    },
    {
        "domain": ".researchgate.net",
        "expirationDate": 1750421183,
        "hostOnly": False,
        "httpOnly": False,
        "name": "ph_phc_ma1XTQyee96N1GML6qUTgLQRiDifnRcE9STiHTZ0CfZ_posthog",
        "path": "/",
        "sameSite": "lax",
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "%7B%22distinct_id%22%3A%220190358a-56a1-7313-83b0-d13dddeac787%22%2C%22%24sesid%22%3A%5B1718885183223%2C%220190358a-56a1-7313-83b0-d13b2b87778d%22%2C1718885176993%5D%2C%22%24session_is_sampled%22%3Atrue%7D"
    },
    {
        "domain": ".www.researchgate.net",
        "hostOnly": False,
        "httpOnly": True,
        "name": "sid",
        "path": "/",
        "sameSite": None,
        "secure": True,
        "session": True,
        "storeId": None,
        "value": "qmH5Lc4f0CUJ3zeaxORcV0S8I8V1MuCFZtcIQqPYtv1XPejrbSLAQRbT50PL40TqeKQ1XsQDWt9gtYVzuL80bRmPjw6jn3cQ0ikNqW40maHcQ3JL2Vfa8ZZf0j7p35eJ"
    }
]

COOKIES_LIST += [
    {
        "domain": "github.com",
        "hostOnly": True,
        "httpOnly": True,
        "name": "_gh_sess",
        "path": "/",
        "sameSite": "lax",
        "secure": True,
        "session": True,
        "storeId": None,
        "value": "P%2Fmof1avuqwHaUQUIJR%2FZYn7jqbT7lgGuTGjp1BGAFIG5UpNDusEE3b8dRjz0eATE5xPdPjLYFqMs%2FI9AOalKX4YuYfSEEnxCMawU01099b4o9Xzzcv%2BmecrmO0Q8q%2Bdq1h8SIv6nvPP7HzlFesl8ysafb9b%2F0q6dTArKdSOurasza8UgLSYD08ofA50Pcm0IG7CTzF8ZCizrGgGTMi%2F%2B7L3E17jav5PM1Sf2vQKg15Gbg1QIOppJJHzlufgQoZigqFv%2BWznaws0Tt7Y2lSFCw%3D%3D--CJRhqMXJnwOaJgk4--DhUErlL4GdROikEjKD4O9g%3D%3D"
    },
    {
        "domain": ".github.com",
        "expirationDate": 1750408875.763785,
        "hostOnly": False,
        "httpOnly": False,
        "name": "_octo",
        "path": "/",
        "sameSite": "lax",
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "GH1.1.728652011.1718872875"
    },
    {
        "domain": ".github.com",
        "expirationDate": 1750408875.763926,
        "hostOnly": False,
        "httpOnly": True,
        "name": "logged_in",
        "path": "/",
        "sameSite": "lax",
        "secure": True,
        "session": False,
        "storeId": None,
        "value": "no"
    },
    {
        "domain": ".github.com",
        "hostOnly": False,
        "httpOnly": False,
        "name": "preferred_color_mode",
        "path": "/",
        "sameSite": "lax",
        "secure": True,
        "session": True,
        "storeId": None,
        "value": "dark"
    },
    {
        "domain": ".github.com",
        "hostOnly": False,
        "httpOnly": False,
        "name": "tz",
        "path": "/",
        "sameSite": "lax",
        "secure": True,
        "session": True,
        "storeId": None,
        "value": "Europe%2FParis"
    }
]

COOKIES_LIST += [
    {
        "domain": ".web.archive.org",
        "expirationDate": 1718886430,
        "hostOnly": False,
        "httpOnly": False,
        "name": "_gat",
        "path": "/web/20201123221659/http://orcid.org/",
        "sameSite": None,
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "1"
    },
    {
        "domain": ".web.archive.org",
        "expirationDate": 1718972770,
        "hostOnly": False,
        "httpOnly": False,
        "name": "_gid",
        "path": "/web/20201123221659/http://orcid.org/",
        "sameSite": None,
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "GA1.2.402246368.1606169825"
    },
    {
        "domain": ".web.archive.org",
        "expirationDate": 1753446370.315621,
        "hostOnly": False,
        "httpOnly": False,
        "name": "_ga",
        "path": "/web/20201123221659/http://orcid.org/",
        "sameSite": None,
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "GA1.2.1301409987.1606169825"
    },
    {
        "domain": ".web.archive.org",
        "expirationDate": 1750422367,
        "hostOnly": False,
        "httpOnly": False,
        "name": "_hjid",
        "path": "/web/20201123221659/http://orcid.org/",
        "sameSite": "lax",
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "07f80263-a631-4bf4-8ffd-8fc8912085e2"
    },
    {
        "domain": ".web.archive.org",
        "expirationDate": 1718888167,
        "hostOnly": False,
        "httpOnly": False,
        "name": "_hjFirstSeen",
        "path": "/web/20201123221659/http://orcid.org/",
        "sameSite": "lax",
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "1"
    }
]
COOKIES_LIST += [
    {
        "domain": "orcid.org",
        "hostOnly": True,
        "httpOnly": False,
        "name": "AWSELBCORS",
        "path": "/",
        "sameSite": "no_restriction",
        "secure": True,
        "session": True,
        "storeId": None,
        "value": "CBD1D7FF1216388FA48838CBCA4774FD22800B8FB548A40EF92BB0994D5B77A8410307CDEAA69C52236663F2BF89B252C17BC0FCDF790FD59771BDDF6EA8CA4CFD29D8733F"
    },
    {
        "domain": ".orcid.org",
        "expirationDate": 1753452454.637671,
        "hostOnly": False,
        "httpOnly": False,
        "name": "_ga_9R61FWK9H5",
        "path": "/",
        "sameSite": None,
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "GS1.1.1718892454.1.0.1718892454.0.0.0"
    },
    {
        "domain": ".orcid.org",
        "expirationDate": 1753452454.63421,
        "hostOnly": False,
        "httpOnly": False,
        "name": "_ga",
        "path": "/",
        "sameSite": None,
        "secure": False,
        "session": False,
        "storeId": None,
        "value": "GA1.1.2021310691.1718892455"
    },
    {
        "domain": "orcid.org",
        "hostOnly": True,
        "httpOnly": False,
        "name": "AWSELB",
|
660 |
+
"path": "/",
|
661 |
+
"sameSite": None,
|
662 |
+
"secure": False,
|
663 |
+
"session": True,
|
664 |
+
"storeId": None,
|
665 |
+
"value": "CBD1D7FF1216388FA48838CBCA4774FD22800B8FB548A40EF92BB0994D5B77A8410307CDEAA69C52236663F2BF89B252C17BC0FCDF790FD59771BDDF6EA8CA4CFD29D8733F"
|
666 |
+
},
|
667 |
+
{
|
668 |
+
"domain": ".orcid.org",
|
669 |
+
"expirationDate": 1750428454,
|
670 |
+
"hostOnly": False,
|
671 |
+
"httpOnly": False,
|
672 |
+
"name": "OptanonAlertBoxClosed",
|
673 |
+
"path": "/",
|
674 |
+
"sameSite": "lax",
|
675 |
+
"secure": False,
|
676 |
+
"session": False,
|
677 |
+
"storeId": None,
|
678 |
+
"value": "2024-06-20T14:07:34.583Z"
|
679 |
+
},
|
680 |
+
{
|
681 |
+
"domain": ".orcid.org",
|
682 |
+
"expirationDate": 1750428454,
|
683 |
+
"hostOnly": False,
|
684 |
+
"httpOnly": False,
|
685 |
+
"name": "OptanonConsent",
|
686 |
+
"path": "/",
|
687 |
+
"sameSite": "lax",
|
688 |
+
"secure": False,
|
689 |
+
"session": False,
|
690 |
+
"storeId": None,
|
691 |
+
"value": "isGpcEnabled=0&datestamp=Thu+Jun+20+2024+16%3A07%3A34+GMT%2B0200+(heure+d%E2%80%99%C3%A9t%C3%A9+d%E2%80%99Europe+centrale)&version=202310.2.0&browserGpcFlag=0&isIABGlobal=False&hosts=&landingPath=NotLandingPage&groups=C0001%3A1%2CC0003%3A1%2CC0002%3A1%2CC0004%3A1"
|
692 |
+
},
|
693 |
+
{
|
694 |
+
"domain": "orcid.org",
|
695 |
+
"hostOnly": True,
|
696 |
+
"httpOnly": False,
|
697 |
+
"name": "XSRF-TOKEN",
|
698 |
+
"path": "/",
|
699 |
+
"sameSite": None,
|
700 |
+
"secure": True,
|
701 |
+
"session": True,
|
702 |
+
"storeId": None,
|
703 |
+
"value": "6957be7a-bcb4-4d59-a522-ea9b6b210ed9"
|
704 |
+
}
|
705 |
+
]
|
706 |
+
from requests.cookies import RequestsCookieJar
|
707 |
+
|
708 |
+
# Create a RequestsCookieJar instance
|
709 |
+
COOKIES = RequestsCookieJar()
|
710 |
+
|
711 |
+
# Add cookies to the jar
|
712 |
+
for cookie in COOKIES_LIST:
|
713 |
+
COOKIES.set(cookie['name'], cookie['value'], domain=cookie['domain'], path=cookie['path'])
|
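For context, a jar assembled this way plugs straight into a requests session, so every request carries these cookies automatically. A minimal sketch, assuming the jar above was built successfully; the session object and example URL are illustrative and not part of the uploaded files:

import requests

# Hypothetical usage of the COOKIES jar built in cookies (2).py:
session = requests.Session()
session.cookies = COOKIES  # every request made through this session now sends the stored cookies
# response = session.get("https://github.com")  # example request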
llm_engine (2).py
ADDED
@@ -0,0 +1,179 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, time, json, re, gc, subprocess
import gradio as gr
import torch
import numpy as np
import argparse
import sampling
import copy
from datetime import datetime
from huggingface_hub import hf_hub_download, snapshot_download
from pynvml import *
from tokenizer_util import add_tokenizer_argument, get_tokenizer
import rwkv_world_tokenizer
# NOTE: '~' is not expanded here; both calls use the same literal path, so they stay consistent.
hf_hub_download(repo_id="JoPmt/RWKV-5-3B-V2-Quant", filename="rwkv-5-world-3b-v2-20231118-ctx16k.Q4_0.bin", local_dir='~/app/Downloads')
model_path = '~/app/Downloads/rwkv-5-world-3b-v2-20231118-ctx16k.Q4_0.bin'
from copy import deepcopy
from enum import Enum
from typing import Dict, List
from huggingface_hub import InferenceClient
from transformers.agents import PythonInterpreterTool
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Hermes-2-Pro-Llama-3-8B", revision="pr/13")
tools = [PythonInterpreterTool()]
os.system("apt-get update && apt-get install cmake gcc g++")
os.system("git clone --recursive https://github.com/JoPmt/rwkv.cpp.git && cd rwkv.cpp && mkdir build && cd build && cmake .. -DRWKV_CUBLAS=ON -DRWKV_BUILD_SHARED_LIBRARY=ON -DGGML_CUDA=ON -DRWKV_BUILD_PYTHON_MODULE=ON -DRWKV_BUILD_TOOLS=ON -DRWKV_BUILD_EXTRAS=ON && cmake --build . --config Release && make RWKV_CUBLAS=1 GGML_CUDA=1")
import rwkv_cpp_model
import rwkv_cpp_shared_library

def find_lib():
    # Locate the compiled librwkv.so anywhere on the filesystem.
    for root, dirs, files in os.walk("/"):
        for file in files:
            if file == "librwkv.so":
                return os.path.join(root, file)
    return None

library_path = find_lib()
rwkv_lib = rwkv_cpp_shared_library.RWKVSharedLibrary(library_path)
modal = rwkv_cpp_model.RWKVModel(rwkv_lib, model_path, thread_count=2)
print('Loading RWKV model')
tokenizer_decode, tokenizer_encode = get_tokenizer('auto', modal.n_vocab)
out_str = ''
prompt = out_str
token_count = 1200
temperature = 1.0
top_p = 0.7
presence_penalty = 0.1
count_penalty = 0.4

def generate_prompt(instruction, zput=""):
    instruction = instruction.strip().replace('\r\n', '\n').replace('\n\n', '\n')
    zput = zput.strip().replace('\r\n', '\n').replace('\n\n', '\n')
    if zput:
        return f"""Instruction: {instruction}
Input: {zput}
Response:"""
    else:
        return f"""User: hi
Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
User: {instruction}
Assistant:"""

class MessageRole(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    TOOL_CALL = "tool-call"
    TOOL_RESPONSE = "tool-response"

    @classmethod
    def roles(cls):
        return [r.value for r in cls]

def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}):
    """
    Subsequent messages with the same role will be concatenated to a single message.

    Args:
        message_list (`List[Dict[str, str]]`): List of chat messages.
    """
    final_message_list = []
    message_list = deepcopy(message_list)  # Avoid modifying the original list
    for message in message_list:
        if not set(message.keys()) == {"role", "content"}:
            raise ValueError("Message should contain only 'role' and 'content' keys!")

        role = message["role"]
        if role not in MessageRole.roles():
            raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")

        if role in role_conversions:
            message["role"] = role_conversions[role]

        if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
            # Concatenate (rather than overwrite) consecutive same-role messages, as the docstring promises.
            final_message_list[-1]["content"] += "\n=======\n" + message["content"]
        else:
            final_message_list.append(message)
    return final_message_list

llama_role_conversions = {
    MessageRole.TOOL_RESPONSE: MessageRole.USER,
    MessageRole.TOOL_CALL: MessageRole.USER,
}

class HfEngine:
    def __init__(self, model: str = "JoPmt/JoPmt"):
        self.model = model
        self.client = modal

    def __call__(self, messages: List[Dict[str, str]], stop_sequences=[]) -> str:
        messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
        print(messages)
        pret = ''
        prut = ''
        for message in messages:
            print(message['content'])
            if message['role'].lower() == 'system':
                pret += '' + message['content'] + ''
            if message['role'].lower() == 'user':
                prut += '' + message['content'] + ''
        ##prompt = ins.format(question=''+pret+''+prut+'', system=pret)
        # NOTE: the chat-template prompt below is only printed for inspection;
        # generation actually uses generate_prompt(pret, prut) further down.
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True,)
        print(prompt)
        token_count = 1200
        temperature = 1.0
        top_p = 0.7
        presencePenalty = 0.1
        countPenalty = 0.4
        token_ban = []
        stop_token = [0]
        ctx = pret
        prompt = prut
        all_tokens = []
        out_last = 0
        out_str = ''
        occurrence = {}
        state = None
        ctx = generate_prompt(ctx, prompt)
        prompt_tokens = tokenizer_encode(ctx)
        prompt_token_count = len(prompt_tokens)
        # Feed the whole prompt through the model in chunks to obtain the initial logits and state.
        init_logits, init_state = modal.eval_sequence_in_chunks(prompt_tokens, None, None, None, use_numpy=True)
        logits, state = init_logits.copy(), init_state.copy()
        out_str = ''
        occurrence = {}
        bof = []
        for i in range(token_count):
            # Penalize tokens that have already occurred (presence + count penalties).
            for n in occurrence:
                logits[n] -= (presencePenalty + occurrence[n] * countPenalty)
            token = sampling.sample_logits(logits, temperature, top_p)

            if token in stop_token:
                break
            all_tokens += [token]

            # Decay the counters so old tokens are penalized less over time.
            for xxx in occurrence:
                occurrence[xxx] *= 0.996

            if token not in occurrence:
                occurrence[token] = 1
            else:
                occurrence[token] += 1

            tmp = tokenizer_decode(all_tokens[out_last:])
            if '\ufffd' not in tmp:  # only emit once a complete UTF-8 character is available
                out_str += tmp
                out_last = i + 1
            ##yield out_str.strip()
            logits, state = modal.eval(token, state, state, logits, use_numpy=True)
        del state
        gc.collect()
        return out_str.strip()
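A quick way to exercise the engine defined above — a sketch assuming the RWKV model and tokenizers loaded successfully; the message contents are made up for illustration:

# Hypothetical usage of HfEngine from llm_engine (2).py:
engine = HfEngine()
reply = engine([
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Explain what rwkv.cpp is in one sentence."},
])
print(reply)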
mdconvert (2).py
ADDED
@@ -0,0 +1,659 @@
# ruff: noqa: E722
import json
import os
import requests
import re
import markdownify
import mimetypes
import html
import puremagic
import tempfile
import copy
import mammoth
import pptx
import pandas as pd
import traceback

from urllib.parse import urlparse, parse_qs
from bs4 import BeautifulSoup
from typing import Any, Dict, List, Optional, Union
import pdfminer
import pdfminer.high_level
from youtube_transcript_api import YouTubeTranscriptApi


class DocumentConverterResult:
    """The result of converting a document to text."""

    def __init__(self, title: Union[str, None] = None, text_content: str = ""):
        self.title = title
        self.text_content = text_content


class DocumentConverter:
    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        raise NotImplementedError()


class PlainTextConverter(DocumentConverter):
    """Anything with content type text/plain"""

    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        extension = kwargs.get("file_extension", "")
        if extension == "":
            return None

        content_type, encoding = mimetypes.guess_type("__placeholder" + extension)

        text_content = ""
        with open(local_path, "rt") as fh:
            text_content = fh.read()

        return DocumentConverterResult(
            title=None,
            text_content=text_content,
        )


class HtmlConverter(DocumentConverter):
    """Anything with content type text/html"""

    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        # Bail if not html
        extension = kwargs.get("file_extension", "")
        if extension.lower() not in [".html", ".htm"]:
            return None

        result = None
        with open(local_path, "rt") as fh:
            result = self._convert(fh.read())

        return result

    def _convert(self, html_content) -> Union[None, DocumentConverterResult]:
        """Helper function that converts an HTML string."""

        # Parse the string
        soup = BeautifulSoup(html_content, "html.parser")

        # Remove javascript and style blocks
        for script in soup(["script", "style"]):
            script.extract()

        # Print only the main content
        body_elm = soup.find("body")
        webpage_text = ""
        if body_elm:
            webpage_text = markdownify.MarkdownConverter().convert_soup(body_elm)
        else:
            webpage_text = markdownify.MarkdownConverter().convert_soup(soup)

        return DocumentConverterResult(
            title=None if soup.title is None else soup.title.string,
            text_content=webpage_text,
        )


class WikipediaConverter(DocumentConverter):
    """Handle Wikipedia pages separately, focusing only on the main document content."""

    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        # Bail if not Wikipedia
        extension = kwargs.get("file_extension", "")
        if extension.lower() not in [".html", ".htm"]:
            return None
        url = kwargs.get("url", "")
        if not re.search(r"^https?:\/\/[a-zA-Z]{2,3}\.wikipedia.org\/", url):
            return None

        # Parse the file
        soup = None
        with open(local_path, "rt") as fh:
            soup = BeautifulSoup(fh.read(), "html.parser")

        # Remove javascript and style blocks
        for script in soup(["script", "style"]):
            script.extract()

        # Print only the main content
        body_elm = soup.find("div", {"id": "mw-content-text"})
        title_elm = soup.find("span", {"class": "mw-page-title-main"})

        webpage_text = ""
        if body_elm:
            # What's the title
            main_title = soup.title.string
            if title_elm and len(title_elm) > 0:
                main_title = title_elm.string

            # Convert the page
            webpage_text = "# " + main_title + "\n\n" + markdownify.MarkdownConverter().convert_soup(body_elm)
        else:
            webpage_text = markdownify.MarkdownConverter().convert_soup(soup)

        return DocumentConverterResult(
            title=soup.title.string,
            text_content=webpage_text,
        )


class YouTubeConverter(DocumentConverter):
    """Handle YouTube specially, focusing on the video title, description, and transcript."""

    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        # Bail if not YouTube
        extension = kwargs.get("file_extension", "")
        if extension.lower() not in [".html", ".htm"]:
            return None
        url = kwargs.get("url", "")
        if not url.startswith("https://www.youtube.com/watch?"):
            return None

        # Parse the file
        soup = None
        with open(local_path, "rt") as fh:
            soup = BeautifulSoup(fh.read(), "html.parser")

        # Read the meta tags
        metadata = {"title": soup.title.string}
        for meta in soup(["meta"]):
            for a in meta.attrs:
                if a in ["itemprop", "property", "name"]:
                    metadata[meta[a]] = meta.get("content", "")
                    break

        # We can also try to read the full description. This is more prone to breaking, since it reaches into the page implementation
        try:
            for script in soup(["script"]):
                content = script.text
                if "ytInitialData" in content:
                    lines = re.split(r"\r?\n", content)
                    obj_start = lines[0].find("{")
                    obj_end = lines[0].rfind("}")
                    if obj_start >= 0 and obj_end >= 0:
                        data = json.loads(lines[0][obj_start : obj_end + 1])
                        attrdesc = self._findKey(data, "attributedDescriptionBodyText")
                        if attrdesc:
                            metadata["description"] = attrdesc["content"]
                    break
        except:
            pass

        # Start preparing the page
        webpage_text = "# YouTube\n"

        title = self._get(metadata, ["title", "og:title", "name"])
        if title:
            webpage_text += f"\n## {title}\n"

        stats = ""
        views = self._get(metadata, ["interactionCount"])
        if views:
            stats += f"- **Views:** {views}\n"

        keywords = self._get(metadata, ["keywords"])
        if keywords:
            stats += f"- **Keywords:** {keywords}\n"

        runtime = self._get(metadata, ["duration"])
        if runtime:
            stats += f"- **Runtime:** {runtime}\n"

        if len(stats) > 0:
            webpage_text += f"\n### Video Metadata\n{stats}\n"

        description = self._get(metadata, ["description", "og:description"])
        if description:
            webpage_text += f"\n### Description\n{description}\n"

        transcript_text = ""
        parsed_url = urlparse(url)
        params = parse_qs(parsed_url.query)

        video_id = params["v"][0]
        # Must be a single transcript.
        print("Video ID:", video_id)
        transcript = YouTubeTranscriptApi.get_transcript(video_id)
        transcript_text = " ".join([part["text"] for part in transcript])
        # Alternative formatting:
        # formatter = TextFormatter()
        # formatter.format_transcript(transcript)
        if transcript_text:
            webpage_text += f"\n### Transcript\n{transcript_text}\n"

        return DocumentConverterResult(
            title=title if title else soup.title.string,
            text_content=webpage_text,
        )

    def _get(self, json, keys, default=None):
        for k in keys:
            if k in json:
                return json[k]
        return default

    def _findKey(self, json, key):
        if isinstance(json, list):
            for elm in json:
                ret = self._findKey(elm, key)
                if ret is not None:
                    return ret
        elif isinstance(json, dict):
            for k in json:
                if k == key:
                    return json[k]
                else:
                    ret = self._findKey(json[k], key)
                    if ret is not None:
                        return ret
        return None


class PdfConverter(DocumentConverter):
    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        # Bail if not a PDF
        extension = kwargs.get("file_extension", "")
        if extension.lower() != ".pdf":
            return None

        return DocumentConverterResult(
            title=None,
            text_content=pdfminer.high_level.extract_text(local_path),
        )

from huggingface_hub import InferenceClient

class AudioConverter(DocumentConverter):
    def __init__(self):
        super().__init__()
        self.client = InferenceClient("distil-whisper/distil-large-v3")

    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        # Bail if not an audio file
        extension = kwargs.get("file_extension", "")
        if extension.lower() not in [".wav", ".mp3", ".flac", ".m4a"]:
            return None
        try:
            result = self.client.automatic_speech_recognition(audio=local_path).text
        except Exception as e:
            print("Exception in decoding audio:", e)
            # Fall back to OpenAI Whisper if the Hugging Face endpoint fails.
            from openai import OpenAI
            oai_client = OpenAI()
            from pathlib import Path
            result = oai_client.audio.transcriptions.create(
                model="whisper-1",
                file=Path(local_path)
            ).text

        return DocumentConverterResult(
            title=None,
            text_content=result,
        )


class DocxConverter(HtmlConverter):
    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        # Bail if not a DOCX
        extension = kwargs.get("file_extension", "")
        if extension.lower() != ".docx":
            return None

        result = None
        with open(local_path, "rb") as docx_file:
            result = mammoth.convert_to_html(docx_file)
            html_content = result.value
            result = self._convert(html_content)

        return result


class XlsxConverter(HtmlConverter):
    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        # Bail if not a XLSX
        extension = kwargs.get("file_extension", "")

        if extension.lower() not in [".xlsx", ".xls"]:
            return None

        sheets = pd.read_excel(local_path, sheet_name=None)
        md_content = ""
        for s in sheets:
            md_content += f"## {s}\n"
            html_content = sheets[s].to_html(index=False)
            md_content += self._convert(html_content).text_content.strip() + "\n\n"

        return DocumentConverterResult(
            title=None,
            text_content=md_content.strip(),
        )


import xml.etree.ElementTree as ET

class XmlConverter(DocumentConverter):
    def convert(self, local_path, **kwargs) -> None | DocumentConverterResult:
        # Parse the XML string
        extension = kwargs.get("file_extension", "")

        if extension.lower() not in [".xml"]:
            return None

        xml_string = ""
        with open(local_path, "rt") as fh:
            xml_string = fh.read()

        def extract_table_from_html_like(xml_root):
            table = xml_root.find('.//table')
            if table is None:
                raise ValueError("No table found in the XML")

            headers = [th.text for th in table.find('thead').findall('th')]
            rows = [[td.text for td in tr.findall('td')] for tr in table.find('tbody').findall('tr')]

            # Create markdown table
            markdown = '| ' + ' | '.join(headers) + ' |\n'
            markdown += '| ' + ' | '.join(['---'] * len(headers)) + ' |\n'
            for row in rows:
                markdown += '| ' + ' | '.join(row) + ' |\n'
            return markdown  # this return was missing; without it the function yielded None

        def extract_table_from_wordml(xml_root, namespaces):
            # Parse the XML content
            root = xml_root
            namespace = {'w': 'http://schemas.microsoft.com/office/word/2003/wordml'}

            # Extract text content
            body = root.find('w:body', namespace)
            paragraphs = body.findall('.//w:p', namespace)
            text_content = []
            for para in paragraphs:
                texts = para.findall('.//w:t', namespace)
                for text in texts:
                    text_content.append(text.text)

            return '\n'.join(text_content)

        # Parse the XML string
        root = ET.fromstring(xml_string)
        namespaces = {'w': 'http://schemas.microsoft.com/office/word/2003/wordml'}

        if root.tag.endswith('wordDocument'):
            markdown = extract_table_from_wordml(root, namespaces)
        else:
            markdown = extract_table_from_html_like(root)

        return DocumentConverterResult(
            title=None,
            text_content=markdown.strip(),
        )

class PptxConverter(HtmlConverter):
    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        # Bail if not a PPTX
        extension = kwargs.get("file_extension", "")
        if extension.lower() != ".pptx":
            return None

        md_content = ""

        presentation = pptx.Presentation(local_path)
        slide_num = 0
        for slide in presentation.slides:
            slide_num += 1

            md_content += f"\n\n<!-- Slide number: {slide_num} -->\n"

            title = slide.shapes.title
            for shape in slide.shapes:
                # Pictures
                if self._is_picture(shape):
                    # https://github.com/scanny/python-pptx/pull/512#issuecomment-1713100069
                    alt_text = ""
                    try:
                        alt_text = shape._element._nvXxPr.cNvPr.attrib.get("descr", "")
                    except:
                        pass

                    # A placeholder name
                    filename = re.sub(r"\W", "", shape.name) + ".jpg"
                    # try:
                    #     filename = shape.image.filename
                    # except:
                    #     pass

                    md_content += "\n![" + (alt_text if alt_text else shape.name) + "](" + filename + ")\n"

                # Tables
                if self._is_table(shape):
                    html_table = "<html><body><table>"
                    first_row = True
                    for row in shape.table.rows:
                        html_table += "<tr>"
                        for cell in row.cells:
                            if first_row:
                                html_table += "<th>" + html.escape(cell.text) + "</th>"
                            else:
                                html_table += "<td>" + html.escape(cell.text) + "</td>"
                        html_table += "</tr>"
                        first_row = False
                    html_table += "</table></body></html>"
                    md_content += "\n" + self._convert(html_table).text_content.strip() + "\n"

                # Text areas
                elif shape.has_text_frame:
                    if shape == title:
                        md_content += "# " + shape.text.lstrip() + " "
                    else:
                        md_content += shape.text + " "

            md_content = md_content.strip()

            if slide.has_notes_slide:
                md_content += "\n\n### Notes:\n"
                notes_frame = slide.notes_slide.notes_text_frame
                if notes_frame is not None:
                    md_content += notes_frame.text
                md_content = md_content.strip()

        return DocumentConverterResult(
            title=None,
            text_content=md_content.strip(),
        )

    def _is_picture(self, shape):
        if shape.shape_type == pptx.enum.shapes.MSO_SHAPE_TYPE.PICTURE:
            return True
        if shape.shape_type == pptx.enum.shapes.MSO_SHAPE_TYPE.PLACEHOLDER:
            if hasattr(shape, "image"):
                return True
        return False

    def _is_table(self, shape):
        if shape.shape_type == pptx.enum.shapes.MSO_SHAPE_TYPE.TABLE:
            return True
        return False

class FileConversionException(Exception):
    pass

class UnsupportedFormatException(Exception):
    pass

class MarkdownConverter:
    """(In preview) An extremely simple text-based document reader, suitable for LLM use.
    This reader will convert common file-types or webpages to Markdown."""

    def __init__(
        self,
        requests_session: Optional[requests.Session] = None,
    ):
        if requests_session is None:
            self._requests_session = requests.Session()
        else:
            self._requests_session = requests_session

        self._page_converters: List[DocumentConverter] = []

        # Register converters for successful browsing operations
        # Later registrations are tried first / take higher priority than earlier registrations
        # To this end, the most specific converters should appear below the most generic converters
        self.register_page_converter(WikipediaConverter())
        self.register_page_converter(XmlConverter())
        self.register_page_converter(YouTubeConverter())
        self.register_page_converter(DocxConverter())
        self.register_page_converter(XlsxConverter())
        self.register_page_converter(PptxConverter())
        # self.register_page_converter(ImageConverter())
        self.register_page_converter(PdfConverter())
        self.register_page_converter(AudioConverter())
        self.register_page_converter(HtmlConverter())
        self.register_page_converter(PlainTextConverter())

    def convert(self, source, **kwargs):
        """
        Args:
            - source: can be a string representing a path or url, or a requests.response object
            - extension: specifies the file extension to use when interpreting the file. If None, infer from source (path, uri, content-type, etc.)
        """

        # Local path or url
        if isinstance(source, str):
            if source.startswith("http://") or source.startswith("https://") or source.startswith("file://"):
                return self.convert_url(source, **kwargs)
            else:
                return self.convert_local(source, **kwargs)
        # Request response
        elif isinstance(source, requests.Response):
            return self.convert_response(source, **kwargs)

    def convert_local(self, path, **kwargs):
        # Prepare a list of extensions to try (in order of priority)
        ext = kwargs.get("file_extension")
        extensions = [ext] if ext is not None else []

        # Get extension alternatives from the path and puremagic
        base, ext = os.path.splitext(path)
        self._append_ext(extensions, ext)
        self._append_ext(extensions, self._guess_ext_magic(path))

        # Convert
        return self._convert(path, extensions, **kwargs)

    def convert_url(self, url, **kwargs):
        # Send a HTTP request to the URL
        user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
        response = self._requests_session.get(url, stream=True, headers={"User-Agent": user_agent})
        response.raise_for_status()
        return self.convert_response(response, **kwargs)

    def convert_response(self, response, **kwargs):
        # Prepare a list of extensions to try (in order of priority)
        ext = kwargs.get("file_extension")
        extensions = [ext] if ext is not None else []

        # Guess from the mimetype
        content_type = response.headers.get("content-type", "").split(";")[0]
        self._append_ext(extensions, mimetypes.guess_extension(content_type))

        # Read the content disposition if there is one
        content_disposition = response.headers.get("content-disposition", "")
        m = re.search(r"filename=([^;]+)", content_disposition)
        if m:
            base, ext = os.path.splitext(m.group(1).strip("\"'"))
            self._append_ext(extensions, ext)

        # Read the extension from the path
        base, ext = os.path.splitext(urlparse(response.url).path)
        self._append_ext(extensions, ext)

        # Save the file locally to a temporary file. It will be deleted before this method exits
        handle, temp_path = tempfile.mkstemp()
        fh = os.fdopen(handle, "wb")
        result = None
        try:
            # Download the file
            for chunk in response.iter_content(chunk_size=512):
                fh.write(chunk)
            fh.close()

            # Use puremagic to check for more extension options
            self._append_ext(extensions, self._guess_ext_magic(temp_path))

            # Convert
            result = self._convert(temp_path, extensions, url=response.url)
        except Exception as e:
            print(f"Error in converting: {e}")

        # Clean up
        finally:
            try:
                fh.close()
            except:
                pass
            os.unlink(temp_path)

        return result

    def _convert(self, local_path, extensions, **kwargs):
        error_trace = ""
        for ext in extensions:
            for converter in self._page_converters:
                _kwargs = copy.deepcopy(kwargs)
                _kwargs.update({"file_extension": ext})
                # If we hit an error log it and keep trying
                try:
                    res = converter.convert(local_path, **_kwargs)
                    if res is not None:
                        # Normalize the content
                        res.text_content = "\n".join([line.rstrip() for line in re.split(r"\r?\n", res.text_content)])
                        res.text_content = re.sub(r"\n{3,}", "\n\n", res.text_content)

                        # Todo
                        return res
                except Exception as e:
                    error_trace = ("\n\n" + traceback.format_exc()).strip()

        # If we got this far without success, report any exceptions
        if len(error_trace) > 0:
            raise FileConversionException(
                f"Could not convert '{local_path}' to Markdown. File type was recognized as {extensions}. While converting the file, the following error was encountered:\n\n{error_trace}"
            )

        # Nothing can handle it!
        # raise UnsupportedFormatException(
        #     f"Could not convert '{local_path}' to Markdown. The formats {extensions} are not supported."
        # )
        res = PlainTextConverter().convert(local_path, **kwargs)
        return res

    def _append_ext(self, extensions, ext):
        """Append a unique non-None, non-empty extension to a list of extensions."""
        if ext is None:
            return
        ext = ext.strip()
        if ext == "":
            return
        # if ext not in extensions:
        extensions.append(ext)

    def _guess_ext_magic(self, path):
        """Use puremagic (a Python implementation of libmagic) to guess a file's extension based on the first few bytes."""
        # Use puremagic to guess
        try:
            guesses = puremagic.magic_file(path)
            if len(guesses) > 0:
                ext = guesses[0].extension.strip()
                if len(ext) > 0:
                    return ext
        except FileNotFoundError:
            pass
        except IsADirectoryError:
            pass
        except PermissionError:
            pass
        return None

    def register_page_converter(self, converter: DocumentConverter) -> None:
        """Register a page text converter."""
        self._page_converters.append(converter)
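A minimal sketch of driving the MarkdownConverter defined above; the URL is just an example, and convert() also accepts local paths and requests.Response objects:

# Hypothetical usage of MarkdownConverter from mdconvert (2).py:
mdconv = MarkdownConverter()
result = mdconv.convert("https://en.wikipedia.org/wiki/Markdown")
if result is not None:
    print(result.title)
    print(result.text_content[:500])  # first 500 characters of the Markdown output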
requirements (93).txt
ADDED
@@ -0,0 +1,32 @@
transformers==4.43.0
torch
gradio
huggingface_hub
beautifulsoup4
requests
gradio_tools
accelerate
langchain
sentence-transformers
faiss-cpu
langchain_community
langchain-huggingface
pypdf
markdownify
urllib3
pathvalidate
pdfminer.six
pdfminer
mammoth
python-pptx
pandas
puremagic
youtube_transcript_api
google-search-results
duckduckgo_search
cmake
numpy
pynvml
argparse
typing
tqdm
rwkv_cpp_model (2).py
ADDED
@@ -0,0 +1,388 @@
1 |
+
import os
|
2 |
+
import multiprocessing
|
3 |
+
|
4 |
+
# Pre-import PyTorch, if available.
|
5 |
+
# This fixes "OSError: [WinError 127] The specified procedure could not be found".
|
6 |
+
try:
|
7 |
+
import torch
|
8 |
+
except ModuleNotFoundError:
|
9 |
+
pass
|
10 |
+
|
11 |
+
# I'm sure this is not strictly correct, but let's keep this crutch for now.
|
12 |
+
try:
|
13 |
+
import rwkv_cpp_shared_library
|
14 |
+
except ModuleNotFoundError:
|
15 |
+
from . import rwkv_cpp_shared_library
|
16 |
+
|
17 |
+
from typing import TypeVar, Optional, Tuple, List
|
18 |
+
|
19 |
+
# A value of this type is either a numpy's ndarray or a PyTorch's Tensor.
|
20 |
+
NumpyArrayOrPyTorchTensor: TypeVar = TypeVar('NumpyArrayOrPyTorchTensor')
|
21 |
+
|
22 |
+
class RWKVModel:
|
23 |
+
"""
|
24 |
+
An RWKV model managed by rwkv.cpp library.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary,
|
30 |
+
model_path: str,
|
31 |
+
thread_count: int = max(1, multiprocessing.cpu_count() // 2),
|
32 |
+
gpu_layer_count: int = 0,
|
33 |
+
**kwargs
|
34 |
+
) -> None:
|
35 |
+
"""
|
36 |
+
Loads the model and prepares it for inference.
|
37 |
+
In case of any error, this method will throw an exception.
|
38 |
+
|
39 |
+
Parameters
|
40 |
+
----------
|
41 |
+
shared_library : RWKVSharedLibrary
|
42 |
+
rwkv.cpp shared library.
|
43 |
+
model_path : str
|
44 |
+
Path to RWKV model file in ggml format.
|
45 |
+
thread_count : int
|
46 |
+
Thread count to use. If not set, defaults to CPU count / 2.
|
47 |
+
gpu_layer_count : int
|
48 |
+
Count of layers to offload onto the GPU, must be >= 0.
|
49 |
+
See documentation of `gpu_offload_layers` for details about layer offloading.
|
50 |
+
"""
|
51 |
+
|
52 |
+
if 'gpu_layers_count' in kwargs:
|
53 |
+
gpu_layer_count = kwargs['gpu_layers_count']
|
54 |
+
|
55 |
+
if not os.path.isfile(model_path):
|
56 |
+
raise ValueError(f'{model_path} is not a file')
|
57 |
+
|
58 |
+
if not (thread_count > 0):
|
59 |
+
raise ValueError('Thread count must be > 0')
|
60 |
+
|
61 |
+
if not (gpu_layer_count >= 0):
|
62 |
+
raise ValueError('GPU layer count must be >= 0')
|
63 |
+
|
64 |
+
self._library: rwkv_cpp_shared_library.RWKVSharedLibrary = shared_library
|
65 |
+
|
66 |
+
self._ctx: rwkv_cpp_shared_library.RWKVContext = self._library.rwkv_init_from_file(model_path, thread_count)
|
67 |
+
|
68 |
+
if gpu_layer_count > 0:
|
69 |
+
self.gpu_offload_layers(gpu_layer_count)
|
70 |
+
|
71 |
+
self._state_buffer_element_count: int = self._library.rwkv_get_state_buffer_element_count(self._ctx)
|
72 |
+
self._logits_buffer_element_count: int = self._library.rwkv_get_logits_buffer_element_count(self._ctx)
|
73 |
+
|
74 |
+
self._valid: bool = True
|
75 |
+
|
76 |
+
def gpu_offload_layers(self, layer_count: int) -> bool:
|
77 |
+
"""
|
78 |
+
Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast.
|
79 |
+
For the purposes of this function, model head (unembedding matrix) is treated as an additional layer:
|
80 |
+
- pass `model.n_layer` to offload all layers except model head
|
81 |
+
- pass `model.n_layer + 1` to offload all layers, including model head
|
82 |
+
|
83 |
+
Returns true if at least one layer was offloaded.
|
84 |
+
If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false.
|
85 |
+
|
86 |
+
Parameters
|
87 |
+
----------
|
88 |
+
layer_count : int
|
89 |
+
Count of layers to offload onto the GPU, must be >= 0.
|
90 |
+
"""
|
91 |
+
|
92 |
+
if not (layer_count >= 0):
|
93 |
+
raise ValueError('Layer count must be >= 0')
|
94 |
+
|
95 |
+
return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count)
|
96 |
+
|
97 |
+
@property
|
98 |
+
def n_vocab(self) -> int:
|
99 |
+
return self._library.rwkv_get_n_vocab(self._ctx)
|
100 |
+
|
101 |
+
@property
|
102 |
+
def n_embed(self) -> int:
|
103 |
+
return self._library.rwkv_get_n_embed(self._ctx)
|
104 |
+
|
105 |
+
@property
|
106 |
+
def n_layer(self) -> int:
|
107 |
+
return self._library.rwkv_get_n_layer(self._ctx)
|
108 |
+
|
109 |
+
def eval(
|
110 |
+
self,
|
111 |
+
token: int,
|
112 |
+
state_in: Optional[NumpyArrayOrPyTorchTensor],
|
113 |
+
state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
|
114 |
+
logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
|
115 |
+
use_numpy: bool = False
|
116 |
+
) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
|
117 |
+
"""
|
118 |
+
Evaluates the model for a single token.
|
119 |
+
In case of any error, this method will throw an exception.
|
120 |
+
|
121 |
+
Parameters
|
122 |
+
----------
|
123 |
+
token : int
|
124 |
+
Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
|
125 |
+
state_in : Optional[NumpyArrayOrTorchTensor]
|
126 |
+
State from previous call of this method. If this is a first pass, set it to None.
|
127 |
+
state_out : Optional[NumpyArrayOrTorchTensor]
|
128 |
+
Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
|
129 |
+
logits_out : Optional[NumpyArrayOrTorchTensor]
|
130 |
+
Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
|
131 |
+
use_numpy : bool
|
132 |
+
If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
|
133 |
+
This parameter is ignored if any tensor parameter is not None; in such case,
|
134 |
+
type of returned tensors will match the type of received tensors.
|
135 |
+
|
136 |
+
Returns
|
137 |
+
-------
|
138 |
+
logits, state
|
139 |
+
Logits vector of shape (n_vocab); state for the next step.
|
140 |
+
"""
|
141 |
+
|
142 |
+
if not self._valid:
|
143 |
+
raise ValueError('Model was freed')
|
144 |
+
|
145 |
+
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
|
146 |
+
|
147 |
+
if state_in is not None:
|
148 |
+
self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
|
149 |
+
|
150 |
+
state_in_ptr = self._get_data_ptr(state_in)
|
151 |
+
else:
|
152 |
+
state_in_ptr = 0
|
153 |
+
|
154 |
+
if state_out is not None:
|
155 |
+
self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
|
156 |
+
else:
|
157 |
+
state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
|
158 |
+
|
159 |
+
if logits_out is not None:
|
160 |
+
self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
|
161 |
+
else:
|
162 |
+
logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
|
163 |
+
|
164 |
+
self._library.rwkv_eval(
|
165 |
+
self._ctx,
|
166 |
+
token,
|
167 |
+
state_in_ptr,
|
168 |
+
self._get_data_ptr(state_out),
|
169 |
+
self._get_data_ptr(logits_out)
|
170 |
+
)
|
171 |
+
|
172 |
+
return logits_out, state_out
|
173 |
+
|
174 |
+
def eval_sequence(
|
175 |
+
self,
|
176 |
+
tokens: List[int],
|
177 |
+
state_in: Optional[NumpyArrayOrPyTorchTensor],
|
178 |
+
state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
|
179 |
+
logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
|
180 |
+
use_numpy: bool = False
|
181 |
+
) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
|
182 |
+
"""
|
183 |
+
Evaluates the model for a sequence of tokens.
|
184 |
+
|
185 |
+
NOTE ON GGML NODE LIMIT
|
186 |
+
|
187 |
+
ggml has a hard-coded limit on max amount of nodes in a computation graph. The sequence graph is built in a way that quickly exceedes
|
188 |
+
this limit when using large models and/or large sequence lengths.
|
189 |
+
Fortunately, rwkv.cpp's fork of ggml has increased limit which was tested to work for sequence lengths up to 64 for 14B models.
|
190 |
+
|
191 |
+
If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
|
192 |
+
To get rid of the assertion failure, reduce the model size and/or sequence length.
|
193 |
+
|
194 |
+
In case of any error, this method will throw an exception.
|
195 |
+
|
196 |
+
Parameters
|
197 |
+
----------
|
198 |
+
tokens : List[int]
|
199 |
+
Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab.
|
200 |
+
state_in : Optional[NumpyArrayOrTorchTensor]
|
201 |
+
State from previous call of this method. If this is a first pass, set it to None.
|
202 |
+
state_out : Optional[NumpyArrayOrTorchTensor]
|
203 |
+
Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
|
204 |
+
logits_out : Optional[NumpyArrayOrTorchTensor]
|
205 |
+
Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
|
206 |
+
use_numpy : bool
|
207 |
+
If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
|
208 |
+
This parameter is ignored if any tensor parameter is not None; in such case,
|
209 |
+
type of returned tensors will match the type of received tensors.
|
210 |
+
|
211 |
+
Returns
|
212 |
+
-------
|
213 |
+
logits, state
|
214 |
+
Logits vector of shape (n_vocab); state for the next step.
|
215 |
+
"""
|
216 |
+
|
217 |
+
if not self._valid:
|
218 |
+
raise ValueError('Model was freed')
|
219 |
+
|
220 |
+
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
|
221 |
+
|
222 |
+
if state_in is not None:
|
223 |
+
self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
|
224 |
+
|
225 |
+
state_in_ptr = self._get_data_ptr(state_in)
|
226 |
+
else:
|
227 |
+
state_in_ptr = 0
|
228 |
+
|
229 |
+
if state_out is not None:
|
230 |
+
self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
|
231 |
+
else:
|
232 |
+
state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
|
233 |
+
|
234 |
+
if logits_out is not None:
|
235 |
+
self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
|
236 |
+
else:
|
237 |
+
logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
|
238 |
+
|
239 |
+
self._library.rwkv_eval_sequence(
|
240 |
+
self._ctx,
|
241 |
+
tokens,
|
242 |
+
state_in_ptr,
|
243 |
+
self._get_data_ptr(state_out),
|
244 |
+
self._get_data_ptr(logits_out)
|
245 |
+
)
|
246 |
+
|
247 |
+
return logits_out, state_out

    def eval_sequence_in_chunks(
        self,
        tokens: List[int],
        state_in: Optional[NumpyArrayOrPyTorchTensor],
        state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
        logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
        chunk_size: int = 16,
        use_numpy: bool = False
    ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
        """
        Evaluates the model for a sequence of tokens using `eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
        This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
        It is recommended to use this function instead of `eval_sequence` to avoid mistakes and get maximum performance.

        Chunking allows processing sequences of thousands of tokens, while not reaching ggml's node limit and not consuming too much memory.
        A reasonable and recommended chunk size is 16. If you want maximum performance, try different chunk sizes in the range [2..64]
        and choose the one that works best in your use case.

        In case of any error, this method will throw an exception.

        Parameters
        ----------
        tokens : List[int]
            Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab.
        state_in : Optional[NumpyArrayOrPyTorchTensor]
            State from the previous call of this method. If this is the first pass, set it to None.
        state_out : Optional[NumpyArrayOrPyTorchTensor]
            Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count,).
        logits_out : Optional[NumpyArrayOrPyTorchTensor]
            Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count,).
        chunk_size : int
            Size of each chunk in tokens, must be positive.
        use_numpy : bool
            If set to True, numpy ndarrays will be created instead of PyTorch Tensors.
            This parameter is ignored if any tensor parameter is not None; in that case,
            the type of the returned tensors will match the type of the received tensors.

        Returns
        -------
        logits, state
            Logits vector of shape (n_vocab,); state for the next step.
        """

        if not self._valid:
            raise ValueError('Model was freed')

        use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)

        if state_in is not None:
            self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)

            state_in_ptr = self._get_data_ptr(state_in)
        else:
            state_in_ptr = 0

        if state_out is not None:
            self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
        else:
            state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)

        if logits_out is not None:
            self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
        else:
            logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)

        self._library.rwkv_eval_sequence_in_chunks(
            self._ctx,
            tokens,
            chunk_size,
            state_in_ptr,
            self._get_data_ptr(state_out),
            self._get_data_ptr(logits_out)
        )

        return logits_out, state_out

    def free(self) -> None:
        """
        Frees all allocated resources.
        In case of any error, this method will throw an exception.
        The object must not be used anymore after calling this method.
        """

        if not self._valid:
            raise ValueError('Already freed')

        self._valid = False

        self._library.rwkv_free(self._ctx)

    def __del__(self) -> None:
        # Free the context on GC in case the user forgot to call free() explicitly.
        if hasattr(self, '_valid') and self._valid:
            self.free()

    def _is_pytorch_tensor(self, tensor: NumpyArrayOrPyTorchTensor) -> bool:
        return hasattr(tensor, '__module__') and tensor.__module__ == 'torch'

    def _detect_numpy_usage(self, tensors: List[Optional[NumpyArrayOrPyTorchTensor]], use_numpy_by_default: bool) -> bool:
        # The first non-None tensor determines the tensor type for this call.
        for tensor in tensors:
            if tensor is not None:
                return not self._is_pytorch_tensor(tensor)

        return use_numpy_by_default

    def _validate_tensor(self, tensor: NumpyArrayOrPyTorchTensor, name: str, size: int) -> None:
        if self._is_pytorch_tensor(tensor):
            tensor: torch.Tensor = tensor

            if tensor.device != torch.device('cpu'):
                raise ValueError(f'{name} is not on CPU')
            if tensor.dtype != torch.float32:
                raise ValueError(f'{name} is not of type float32')
            if tensor.shape != (size,):
                raise ValueError(f'{name} has invalid shape {tensor.shape}, expected ({size},)')
            if not tensor.is_contiguous():
                raise ValueError(f'{name} is not contiguous')
        else:
            import numpy as np
            tensor: np.ndarray = tensor

            if tensor.dtype != np.float32:
                raise ValueError(f'{name} is not of type float32')
            if tensor.shape != (size,):
                raise ValueError(f'{name} has invalid shape {tensor.shape}, expected ({size},)')
            if not tensor.data.contiguous:
                raise ValueError(f'{name} is not contiguous')

    def _get_data_ptr(self, tensor: NumpyArrayOrPyTorchTensor):
        if self._is_pytorch_tensor(tensor):
            return tensor.data_ptr()
        else:
            return tensor.ctypes.data

    def _zeros_float32(self, element_count: int, use_numpy: bool) -> NumpyArrayOrPyTorchTensor:
        if use_numpy:
            import numpy as np
            return np.zeros(element_count, dtype=np.float32)
        else:
            return torch.zeros(element_count, dtype=torch.float32, device='cpu')
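
A minimal usage sketch for the model wrapper above (hedged: it assumes the `RWKVModel` class constructed earlier in rwkv_cpp_model (2).py accepts a library handle, a model path and a thread count; the model path and token ids are placeholders, not part of the upload):

# sketch_model_usage.py -- illustration only, not part of the upload
import rwkv_cpp_model
import rwkv_cpp_shared_library

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, 'path/to/rwkv-model.bin', thread_count=4)

prompt_tokens = [510, 4234, 345]  # placeholder token ids from a tokenizer

# Process the whole prompt in chunks of 16 tokens; state_in=None starts fresh.
logits, state = model.eval_sequence_in_chunks(prompt_tokens, None, chunk_size=16)

# Greedy continuation for illustration; real code would call a sampler.
next_token = int(logits.argmax())
logits, state = model.eval_sequence_in_chunks([next_token], state)

model.free()  # release the native context explicitly rather than relying on __del__
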
rwkv_cpp_shared_library (2).py
ADDED
@@ -0,0 +1,450 @@
import os
import sys
import ctypes
import pathlib
import platform
from typing import Optional, List, Tuple, Callable

QUANTIZED_FORMAT_NAMES: Tuple[str, str, str, str, str] = (
    'Q4_0',
    'Q4_1',
    'Q5_0',
    'Q5_1',
    'Q8_0'
)

P_FLOAT = ctypes.POINTER(ctypes.c_float)
P_INT = ctypes.POINTER(ctypes.c_int32)

class RWKVContext:

    def __init__(self, ptr: ctypes.pointer) -> None:
        self.ptr: ctypes.pointer = ptr

class RWKVSharedLibrary:
    """
    Python wrapper around the rwkv.cpp shared library.
    """

    def __init__(self, shared_library_path: str) -> None:
        """
        Loads the shared library from the specified file.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        shared_library_path : str
            Path to the rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'; on Linux, 'librwkv.so'; on macOS, 'librwkv.dylib'.
        """
        # On Python 3.8+, custom DLLs need to be loaded with an explicit winmode
        # to prevent loading failure errors; see
        # https://docs.python.org/3/whatsnew/3.8.html#ctypes
        if platform.system().lower() == 'windows':
            self.library = ctypes.CDLL(shared_library_path, winmode=0)
        else:
            self.library = ctypes.cdll.LoadLibrary(shared_library_path)

        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
        self.library.rwkv_init_from_file.restype = ctypes.c_void_p

        self.library.rwkv_gpu_offload_layers.argtypes = [ctypes.c_void_p, ctypes.c_uint32]
        self.library.rwkv_gpu_offload_layers.restype = ctypes.c_bool

        self.library.rwkv_eval.argtypes = [
            ctypes.c_void_p, # ctx
            ctypes.c_int32,  # token
            P_FLOAT,         # state_in
            P_FLOAT,         # state_out
            P_FLOAT          # logits_out
        ]
        self.library.rwkv_eval.restype = ctypes.c_bool

        self.library.rwkv_eval_sequence.argtypes = [
            ctypes.c_void_p, # ctx
            P_INT,           # tokens
            ctypes.c_size_t, # token count
            P_FLOAT,         # state_in
            P_FLOAT,         # state_out
            P_FLOAT          # logits_out
        ]
        self.library.rwkv_eval_sequence.restype = ctypes.c_bool

        self.library.rwkv_eval_sequence_in_chunks.argtypes = [
            ctypes.c_void_p, # ctx
            P_INT,           # tokens
            ctypes.c_size_t, # token count
            ctypes.c_size_t, # chunk size
            P_FLOAT,         # state_in
            P_FLOAT,         # state_out
            P_FLOAT          # logits_out
        ]
        self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool

        self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p]
        self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t

        self.library.rwkv_get_n_embed.argtypes = [ctypes.c_void_p]
        self.library.rwkv_get_n_embed.restype = ctypes.c_size_t

        self.library.rwkv_get_n_layer.argtypes = [ctypes.c_void_p]
        self.library.rwkv_get_n_layer.restype = ctypes.c_size_t

        self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
        self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_uint32

        self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
        self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_uint32

        self.library.rwkv_free.argtypes = [ctypes.c_void_p]
        self.library.rwkv_free.restype = None

        self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p]
        self.library.rwkv_quantize_model_file.restype = ctypes.c_bool

        self.library.rwkv_get_system_info_string.argtypes = []
        self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p

        self.nullptr = ctypes.cast(0, ctypes.c_void_p)

    def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
        """
        Loads the model from a file and prepares it for inference.
        Throws an exception in case of any error. Error messages would be printed to stderr.

        Parameters
        ----------
        model_file_path : str
            Path to model file in ggml format.
        thread_count : int
            Count of threads to use, must be positive.
        """

        ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))

        if ptr is None:
            raise ValueError('rwkv_init_from_file failed, check stderr')

        return RWKVContext(ptr)

    def rwkv_gpu_offload_layers(self, ctx: RWKVContext, layer_count: int) -> bool:
        """
        Offloads the specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast.
        For the purposes of this function, the model head (unembedding matrix) is treated as an additional layer:
        - pass `rwkv_get_n_layer(ctx)` to offload all layers except the model head
        - pass `rwkv_get_n_layer(ctx) + 1` to offload all layers, including the model head
        Returns True if at least one layer was offloaded.
        If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns False.

        Parameters
        ----------
        ctx : RWKVContext
            RWKV context obtained from rwkv_init_from_file.
        layer_count : int
            Count of layers to offload onto the GPU, must be >= 0.
        """

        if layer_count < 0:
            raise ValueError('Layer count must be >= 0')

        return self.library.rwkv_gpu_offload_layers(ctx.ptr, ctypes.c_uint32(layer_count))

    def rwkv_eval(
        self,
        ctx: RWKVContext,
        token: int,
        state_in_address: Optional[int],
        state_out_address: int,
        logits_out_address: int
    ) -> None:
        """
        Evaluates the model for a single token.
        Throws an exception in case of any error. Error messages would be printed to stderr.
        Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.

        Parameters
        ----------
        ctx : RWKVContext
            RWKV context obtained from rwkv_init_from_file.
        token : int
            Next token index, in range 0 <= token < n_vocab.
        state_in_address : Optional[int]
            Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is the first pass.
        state_out_address : int
            Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
        logits_out_address : int
            Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
        """

        if not self.library.rwkv_eval(
            ctx.ptr,
            ctypes.c_int32(token),
            ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
            ctypes.cast(state_out_address, P_FLOAT),
            ctypes.cast(logits_out_address, P_FLOAT)
        ):
            raise ValueError('rwkv_eval failed, check stderr')

    def rwkv_eval_sequence(
        self,
        ctx: RWKVContext,
        tokens: List[int],
        state_in_address: Optional[int],
        state_out_address: int,
        logits_out_address: int
    ) -> None:
        """
        Evaluates the model for a sequence of tokens.
        Uses a faster algorithm than `rwkv_eval` if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so.
        Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.

        NOTE ON GGML NODE LIMIT

        ggml has a hard-coded limit on the max amount of nodes in a computation graph. The sequence graph is built in a way that quickly exceeds
        this limit when using large models and/or large sequence lengths.
        Fortunately, rwkv.cpp's fork of ggml has an increased limit, which was tested to work for sequence lengths up to 64 for 14B models.

        If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
        To get rid of the assertion failure, reduce the model size and/or sequence length.

        Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
        Throws an exception in case of any error. Error messages would be printed to stderr.

        Parameters
        ----------
        ctx : RWKVContext
            RWKV context obtained from rwkv_init_from_file.
        tokens : List[int]
            Next token indices, in range 0 <= token < n_vocab.
        state_in_address : Optional[int]
            Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is the first pass.
        state_out_address : int
            Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
        logits_out_address : int
            Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
        """

        if not self.library.rwkv_eval_sequence(
            ctx.ptr,
            ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
            ctypes.c_size_t(len(tokens)),
            ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
            ctypes.cast(state_out_address, P_FLOAT),
            ctypes.cast(logits_out_address, P_FLOAT)
        ):
            raise ValueError('rwkv_eval_sequence failed, check stderr')

    def rwkv_eval_sequence_in_chunks(
        self,
        ctx: RWKVContext,
        tokens: List[int],
        chunk_size: int,
        state_in_address: Optional[int],
        state_out_address: int,
        logits_out_address: int
    ) -> None:
        """
        Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
        This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
        It is recommended to use this function instead of `rwkv_eval_sequence` to avoid mistakes and get maximum performance.

        Chunking allows processing sequences of thousands of tokens, while not reaching ggml's node limit and not consuming too much memory.
        A reasonable and recommended chunk size is 16. If you want maximum performance, try different chunk sizes in the range [2..64]
        and choose the one that works best in your use case.

        Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
        Throws an exception in case of any error. Error messages would be printed to stderr.

        Parameters
        ----------
        ctx : RWKVContext
            RWKV context obtained from rwkv_init_from_file.
        tokens : List[int]
            Next token indices, in range 0 <= token < n_vocab.
        chunk_size : int
            Size of each chunk in tokens, must be positive.
        state_in_address : Optional[int]
            Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is the first pass.
        state_out_address : int
            Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
        logits_out_address : int
            Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
        """

        if not self.library.rwkv_eval_sequence_in_chunks(
            ctx.ptr,
            ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
            ctypes.c_size_t(len(tokens)),
            ctypes.c_size_t(chunk_size),
            ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
            ctypes.cast(state_out_address, P_FLOAT),
            ctypes.cast(logits_out_address, P_FLOAT)
        ):
            raise ValueError('rwkv_eval_sequence_in_chunks failed, check stderr')

    def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int:
        """
        Returns the number of tokens in the given model's vocabulary.
        Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).

        Parameters
        ----------
        ctx : RWKVContext
            RWKV context obtained from rwkv_init_from_file.
        """

        return self.library.rwkv_get_n_vocab(ctx.ptr)

    def rwkv_get_n_embed(self, ctx: RWKVContext) -> int:
        """
        Returns the number of elements in the given model's embedding.
        Useful for reading individual fields of a model's hidden state.

        Parameters
        ----------
        ctx : RWKVContext
            RWKV context obtained from rwkv_init_from_file.
        """

        return self.library.rwkv_get_n_embed(ctx.ptr)

    def rwkv_get_n_layer(self, ctx: RWKVContext) -> int:
        """
        Returns the number of layers in the given model.
        A layer is a pair of RWKV and FFN operations, stacked multiple times throughout the model.
        Embedding matrix and model head (unembedding matrix) are NOT counted in `n_layer`.
        Useful for always offloading the entire model to GPU.

        Parameters
        ----------
        ctx : RWKVContext
            RWKV context obtained from rwkv_init_from_file.
        """

        return self.library.rwkv_get_n_layer(ctx.ptr)

    def rwkv_get_state_buffer_element_count(self, ctx: RWKVContext) -> int:
        """
        Returns count of FP32 elements in state buffer.

        Parameters
        ----------
        ctx : RWKVContext
            RWKV context obtained from rwkv_init_from_file.
        """

        return self.library.rwkv_get_state_buffer_element_count(ctx.ptr)

    def rwkv_get_logits_buffer_element_count(self, ctx: RWKVContext) -> int:
        """
        Returns count of FP32 elements in logits buffer.

        Parameters
        ----------
        ctx : RWKVContext
            RWKV context obtained from rwkv_init_from_file.
        """

        return self.library.rwkv_get_logits_buffer_element_count(ctx.ptr)

    def rwkv_free(self, ctx: RWKVContext) -> None:
        """
        Frees all allocated memory and the context.

        Parameters
        ----------
        ctx : RWKVContext
            RWKV context obtained from rwkv_init_from_file.
        """

        self.library.rwkv_free(ctx.ptr)

        ctx.ptr = self.nullptr

    def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, format_name: str) -> None:
        """
        Quantizes an FP32 or FP16 model to one of the quantized formats.
        Throws an exception in case of any error. Error messages would be printed to stderr.

        Parameters
        ----------
        model_file_path_in : str
            Path to model file in ggml format, must be either FP32 or FP16.
        model_file_path_out : str
            Quantized model will be written here.
        format_name : str
            One of QUANTIZED_FORMAT_NAMES.
        """

        if format_name not in QUANTIZED_FORMAT_NAMES:
            raise ValueError(f'Unknown format name {format_name}, use one of {QUANTIZED_FORMAT_NAMES}')

        if not self.library.rwkv_quantize_model_file(
            model_file_path_in.encode('utf-8'),
            model_file_path_out.encode('utf-8'),
            format_name.encode('utf-8')
        ):
            raise ValueError('rwkv_quantize_model_file failed, check stderr')

    def rwkv_get_system_info_string(self) -> str:
        """
        Returns the system information string.
        """

        return self.library.rwkv_get_system_info_string().decode('utf-8')

def load_rwkv_shared_library() -> RWKVSharedLibrary:
    """
    Attempts to find the rwkv.cpp shared library and load it.
    To specify the exact path to the library, create an instance of RWKVSharedLibrary explicitly.
    """

    file_name: str

    if 'win32' in sys.platform or 'cygwin' in sys.platform:
        file_name = 'rwkv.dll'
    elif 'darwin' in sys.platform:
        file_name = 'librwkv.dylib'
    else:
        file_name = 'librwkv.so'

    # Possible sub-paths to the library relative to the repo dir.
    child_paths: List[Callable[[pathlib.Path], pathlib.Path]] = [
        # No lookup for the Debug config here.
        # I assume that if a user wants to debug the library,
        # they will be able to find the library and set the exact path explicitly.
        lambda p: p / 'bin' / 'Release' / file_name,
        lambda p: p / 'bin' / file_name,
        # Some people prefer to build in the "build" subdirectory.
        lambda p: p / 'build' / 'bin' / 'Release' / file_name,
        lambda p: p / 'build' / 'bin' / file_name,
        lambda p: p / 'build' / file_name,
        # Fallback.
        lambda p: p / file_name
    ]

    working_dir: pathlib.Path = pathlib.Path(os.path.abspath(os.getcwd()))

    parent_paths: List[pathlib.Path] = [
        # Possible repo dirs relative to the working dir.
        # ./python/rwkv_cpp
        working_dir.parent.parent,
        # ./python
        working_dir.parent,
        # .
        working_dir,
        # Repo dir relative to this Python file.
        pathlib.Path(os.path.abspath(__file__)).parent.parent.parent
    ]

    for parent_path in parent_paths:
        for child_path in child_paths:
            full_path: pathlib.Path = child_path(parent_path)

            if os.path.isfile(full_path):
                return RWKVSharedLibrary(str(full_path))

    raise ValueError(f'Failed to find {file_name} automatically; '
                     f'find the library manually and create an RWKVSharedLibrary with its exact path')
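
A low-level usage sketch for the wrapper above (hedged: the model path and token ids are placeholders; buffer sizes are taken from the ctx getters, matching the buffer contract documented in rwkv_eval):

# sketch_shared_library_usage.py -- illustration only, not part of the upload
import numpy as np
from rwkv_cpp_shared_library import load_rwkv_shared_library

lib = load_rwkv_shared_library()
ctx = lib.rwkv_init_from_file('path/to/rwkv-model.bin', thread_count=4)

state_a = np.zeros(lib.rwkv_get_state_buffer_element_count(ctx), dtype=np.float32)
state_b = np.zeros_like(state_a)
logits = np.zeros(lib.rwkv_get_logits_buffer_element_count(ctx), dtype=np.float32)

# First pass: state_in_address=None starts from an empty state.
lib.rwkv_eval(ctx, 0, None, state_a.ctypes.data, logits.ctypes.data)
# Subsequent passes feed the previous state buffer back in.
lib.rwkv_eval(ctx, 1, state_a.ctypes.data, state_b.ctypes.data, logits.ctypes.data)

lib.rwkv_free(ctx)
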
rwkv_world_tokenizer (2).py
ADDED
@@ -0,0 +1,126 @@
import os
import pathlib
from typing import List, Set, Tuple, Callable

# Taken from https://github.com/BlinkDL/ChatRWKV/tree/main/tokenizer/rwkv_tokenizer.py

class Trie:
    __slots__ = ('ch', 'to', 'values', 'front')

    def __init__(self, front=None, ch=None) -> None:
        self.ch = ch
        self.to: List = [None for _ in range(256)]
        self.values: Set = set()
        self.front = front

    def add(self, key: bytes, idx: int = 0, val=None) -> 'Trie':
        if idx == len(key):
            if val is None:
                val = key

            self.values.add(val)

            return self

        ch = key[idx]

        if self.to[ch] is None:
            self.to[ch] = Trie(front=self, ch=ch)

        return self.to[ch].add(key, idx=idx + 1, val=val)

    def find_longest(self, key: bytes, idx: int = 0) -> Tuple[int, 'Trie', set]:
        u: Trie = self
        ch: int = key[idx]
        ret = None

        while u.to[ch] is not None:
            u = u.to[ch]
            idx += 1

            if u.values:
                ret = idx, u, u.values

            if idx == len(key):
                break

            ch = key[idx]

        if ret is None:
            raise ValueError('Entry not found')

        return ret

    def __repr__(self) -> str:
        fr = self
        ret = []

        while fr is not None:
            if fr.ch is not None:
                ret.append(fr.ch)

            fr = fr.front

        return '<TRIE %s %s>' % (ret[::-1], self.values)

class WorldTokenizer:

    def __init__(self, file_path) -> None:
        self.index_to_token = {}

        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        for line in lines:
            idx = int(line[:line.index(' ')])
            # Each vocabulary line stores the token as a Python str/bytes literal.
            x = eval(line[line.index(' '):line.rindex(' ')])
            x = x.encode('utf-8') if isinstance(x, str) else x
            assert isinstance(x, bytes)
            assert len(x) == int(line[line.rindex(' '):])
            self.index_to_token[idx] = x

        self.token_to_index = {}

        for k, v in self.index_to_token.items():
            self.token_to_index[v] = int(k)

        self.root = Trie()

        for t, i in self.token_to_index.items():
            _ = self.root.add(t, val=(t, i))

    def encode_bytes(self, src: bytes) -> List[int]:
        idx: int = 0
        tokens: List[int] = []

        while idx < len(src):
            _idx: int = idx
            idx, _, values = self.root.find_longest(src, idx)
            assert idx != _idx
            _, token = next(iter(values))
            tokens.append(token)

        return tokens

    def decode_bytes(self, tokens: List[int]) -> bytes:
        return b''.join(map(lambda i: self.index_to_token[i], tokens))

    def encode(self, src: str) -> List[int]:
        return self.encode_bytes(src.encode('utf-8'))

    def decode(self, tokens: List[int]) -> str:
        # The 'replace' error handling mode will insert \uFFFD characters in place of malformed/partial UTF-8 sequences.
        # Downstream code needs to detect \uFFFD and postpone decoding until more tokens arrive and the UTF-8 sequences are complete.
        return self.decode_bytes(tokens).decode('utf-8', errors='replace')

def get_world_tokenizer_v20230424() -> Tuple[
    Callable[[List[int]], str],
    Callable[[str], List[int]]
]:
    """
    Loads the default World tokenizer, commonly used in RWKV v4 World models.
    Returns a tuple of `decode` and `encode` functions.
    """
    parent: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent
    tokenizer: WorldTokenizer = WorldTokenizer(parent / 'rwkv_vocab_v20230424.txt')
    return tokenizer.decode, tokenizer.encode
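
A roundtrip sketch for the World tokenizer (hedged: it assumes rwkv_vocab_v20230424.txt sits next to rwkv_world_tokenizer.py, which is what get_world_tokenizer_v20230424 expects):

# sketch_tokenizer_roundtrip.py -- illustration only, not part of the upload
from rwkv_world_tokenizer import get_world_tokenizer_v20230424

decode, encode = get_world_tokenizer_v20230424()

tokens = encode('Hello, world!')
assert decode(tokens) == 'Hello, world!'

# When streaming token-by-token, decode() may emit U+FFFD for a partial
# UTF-8 sequence; buffer tokens until the replacement character disappears.
text_so_far = decode(tokens[:1])
if '\uFFFD' in text_so_far:
    pass  # wait for more tokens before displaying
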
sampling (2).py
ADDED
@@ -0,0 +1,52 @@
import numpy as np
from typing import Dict, Optional

# https://stackoverflow.com/a/50425683
def softmax(x: np.ndarray, axis: int) -> np.ndarray:
    # Subtracting the max keeps exp() numerically stable and does not change the result.
    # Copy instead of mutating in place, so the caller's logits array stays intact.
    x = x - x.max(axis=axis, keepdims=True)
    e: np.ndarray = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def sample_logits(out, temperature: float = 1.0, top_p: float = 0.8, logit_bias: Optional[Dict[int, float]] = None) -> int:
    if hasattr(out, '__module__') and out.__module__ == 'torch':
        out = out.cpu().numpy()

    probs: np.ndarray = softmax(out, axis=-1)

    return sample_probs(probs, temperature, top_p, logit_bias)

def sample_probs(probs: np.ndarray, temperature: float = 1.0, top_p: float = 0.8, logit_bias: Optional[Dict[int, float]] = None) -> int:
    if not (0.0 <= temperature):
        raise ValueError('temperature must be >= 0')
    if not (0.0 <= top_p <= 1.0):
        raise ValueError('top_p must be in [0, 1]')

    if top_p == 0.0:
        top_p = 1.0

    if logit_bias is not None and len(logit_bias) > 0:
        logits: np.ndarray = np.log(probs)

        ids, values = zip(*logit_bias.items())
        logits[list(ids)] += values

        # Makes the calculation more numerically stable, does not change the result
        logits -= logits.max(axis=-1, keepdims=True)

        probs = np.exp(logits) / np.sum(np.exp(logits))

    if temperature == 0.0:
        return np.argmax(probs).item()

    if top_p < 1.0:
        # Zero out everything below the smallest probability needed to cover top_p mass.
        sorted_probs = np.sort(probs)[::-1]
        cumulative_probs = np.cumsum(sorted_probs)
        cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
        probs[probs < cutoff] = 0

    if temperature != 1.0:
        probs = np.power(probs, 1.0 / temperature)

    probs = probs / np.sum(probs)

    return int(np.random.choice(a=len(probs), p=probs))
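
A sampling sketch (hedged: the logits array is a random stand-in for real model output, and the token id in logit_bias is arbitrary):

# sketch_sampling.py -- illustration only, not part of the upload
import numpy as np
from sampling import sample_logits

logits = np.random.randn(65536).astype(np.float32)  # stand-in for model logits

greedy_token = sample_logits(logits, temperature=0.0)               # argmax pick
creative_token = sample_logits(logits, temperature=1.2, top_p=0.5)  # nucleus sampling

# logit_bias nudges specific token ids; here token 0 is discouraged.
biased_token = sample_logits(logits, logit_bias={0: -5.0})
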
tokenizer_util (2).py
ADDED
@@ -0,0 +1,38 @@
import os
import pathlib
import rwkv_world_tokenizer
from typing import List, Tuple, Callable

def add_tokenizer_argument(parser) -> None:
    parser.add_argument(
        'tokenizer',
        help='Tokenizer to use; supported tokenizers: auto (guess from n_vocab), 20B, world',
        nargs='?',
        type=str,
        default='auto'
    )

def get_tokenizer(tokenizer_name: str, n_vocab: int) -> Tuple[
    Callable[[List[int]], str],
    Callable[[str], List[int]]
]:
    if tokenizer_name == 'auto':
        if n_vocab == 50277:
            tokenizer_name = '20B'
        elif n_vocab == 65536:
            tokenizer_name = 'world'
        else:
            raise ValueError(f'Cannot guess the tokenizer from the n_vocab value of {n_vocab}')

    parent: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent

    if tokenizer_name == 'world':
        print('Loading World v20230424 tokenizer')
        return rwkv_world_tokenizer.get_world_tokenizer_v20230424()
    elif tokenizer_name == '20B':
        print('Loading 20B tokenizer')
        import tokenizers
        tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(parent / '20B_tokenizer.json'))
        return tokenizer.decode, lambda x: tokenizer.encode(x).ids
    else:
        raise ValueError(f'Unknown tokenizer {tokenizer_name}')
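
A CLI wiring sketch for the helper above (hedged: the script name, the model-path argument, and the hard-coded n_vocab of 65536 are illustrative; a real script would read n_vocab from the loaded model via rwkv_get_n_vocab):

# sketch_cli.py -- illustration only, not part of the upload
import argparse
import tokenizer_util

parser = argparse.ArgumentParser()
parser.add_argument('model_path', help='Path to the ggml model file')
tokenizer_util.add_tokenizer_argument(parser)
args = parser.parse_args()

# With the optional positional tokenizer argument left at 'auto',
# n_vocab == 65536 selects the World tokenizer and 50277 selects the 20B one.
decode, encode = tokenizer_util.get_tokenizer(args.tokenizer, n_vocab=65536)

tokens = encode('Hello')
print(tokens, decode(tokens))
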