JoPmt committed
Commit a600684
1 Parent(s): 44540ae

Upload 12 files
app (10).py ADDED
@@ -0,0 +1,109 @@
+ from huggingface_hub import login, InferenceClient
+ import os, gc, time, random, datetime, json, re
+ HF_TOKEN=os.getenv('HF_TOKEN')
+ SERP_API_KEY=os.getenv('SERP_KEY')
+ login(token=HF_TOKEN)
+ import gradio as gr
+ from transformers import CodeAgent, Tool, ToolCollection, load_tool, ReactCodeAgent, ReactJsonAgent
+ from transformers.agents import PythonInterpreterTool
+ from langchain.memory import ConversationBufferMemory
+ import bs4
+ import requests
+ from llm_engine import HfEngine
+ import datasets
+ import spaces
+ import tqdm
+ from langchain_huggingface.embeddings import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import FAISS
+ from langchain.docstore.document import Document
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_core.vectorstores import VectorStore
+ from transformers.agents.prompts import DEFAULT_REACT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_JSON_SYSTEM_PROMPT
+ from transformers.agents.default_tools import Tool, PythonInterpreterTool
+ from duckduckgo_search import DDGS
+ from web_surfer import (SearchInformationTool, NavigationalSearchTool, VisitTool, DownloadTool, PageUpTool, PageDownTool, FinderTool, FindNextTool, ArchiveSearchTool,)
+ from mdconvert import MarkdownConverter
+ from visual_qa import VisualQATool, VisualQAGPT4Tool
+
+ def search_ducky(query):
+     # Concatenate the snippets of up to 10 DuckDuckGo results into a single string
+     with DDGS() as ddgs:
+         results = list(ddgs.text(query, max_results=10))
+         content = ''
+         if results:
+             for result in results:
+                 content += result['body']
+         return content
+
+ knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
+ source_docs = [Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}) for doc in knowledge_base]
+ docs_processed = RecursiveCharacterTextSplitter(chunk_size=500).split_documents(source_docs)[:1000]
+ embedding_model = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
+ vectordb = FAISS.from_documents(documents=docs_processed, embedding=embedding_model)
+ all_sources = list(set([doc.metadata["source"] for doc in docs_processed]))
+ print(all_sources)
+
+ class RetrieverTool(Tool):
+     name = "retriever"
+     description = "Retrieves some documents from the knowledge base that have the closest embeddings to the input query."
+     inputs = {
+         "query": {
+             "type": "text",
+             "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
+         },
+         "source": {
+             "type": "text",
+             "description": ""
+         },
+     }
+     output_type = "text"
+
+     def __init__(self, vectordb: VectorStore, all_sources: list, **kwargs):
+         super().__init__(**kwargs)
+         self.vectordb = vectordb
+         self.inputs["source"]["description"] = (f"The source of the documents to search, as a str representation of a list. Possible values in the list are: {all_sources}. If this argument is not provided, all sources will be searched.")
+
+     def forward(self, query: str, source: str = None) -> str:
+         assert isinstance(query, str), "Your search query must be a string"
+
+         if source:
+             if isinstance(source, str) and "[" not in str(source):  # if the source is not representing a list
+                 source = [source]
+             source = json.loads(str(source).replace("'", '"'))
+
+         docs = self.vectordb.similarity_search(query, filter=({"source": source} if source else None), k=3)
+
+         if len(docs) == 0:
+             return "No documents found with this filtering. Try removing the source filter."
+         return "Retrieved documents:\n\n" + "\n===Document===\n".join([doc.page_content for doc in docs])
+
+ memory = ConversationBufferMemory(memory_key="chat_history")
+ llm_engine = HfEngine(model="Jopmt/JoPmt")
+ ##gradio_prompt_generator_tool = StableDiffusionPromptGeneratorTool()
+ ##prompt_generator_tool = Tool.from_gradio(gradio_prompt_generator_tool)
+ ##tools = [StableDiffusionTool().langchain, ImageCaptioningTool().langchain, StableDiffusionPromptGeneratorTool().langchain, TextToVideoTool().langchain]
+ ##tools=[prompt_generator_tool(), image_generation_tool(), PythonInterpreterTool()]
+
+ class SearchTool(Tool):
+     name = "ask_search_agent"
+     description = "A search agent that will browse the internet to answer a question. Use it to gather information, not for problem-solving."
+
+     inputs = {
+         "question": {
+             "description": "Your question, as a natural language sentence. You are talking to an agent, so provide them with as much context as possible.",
+             "type": "text",
+         }
+     }
+     output_type = "text"
+
+     def forward(self, question: str) -> str:
+         # websurfer_agent is defined below; it exists by the time the agent first calls this tool
+         return websurfer_agent.run(question)
+
+ tools=[PythonInterpreterTool(),SearchTool(),RetrieverTool(vectordb, all_sources)]
+ additional_authorized_imports=['requests', 'bs4', 'os', 'time', 'datetime', 'json', 're']
+ WEB_TOOLS = [SearchInformationTool(), NavigationalSearchTool(), VisitTool(), DownloadTool(), PageUpTool(), PageDownTool(), FinderTool(), FindNextTool(), ArchiveSearchTool(),]
+ websurfer_agent = ReactJsonAgent(tools=WEB_TOOLS,llm_engine=llm_engine, add_base_tools=True,max_iterations=1)
+ reagent = ReactCodeAgent(tools=tools, llm_engine=llm_engine, add_base_tools=True,max_iterations=1,additional_authorized_imports=additional_authorized_imports)
+
+ def plix(inut, progress=gr.Progress(track_tqdm=True)):
+     goose=reagent.run(inut)
+     return goose
+
+ with gr.Blocks(theme=random.choice([gr.themes.Monochrome(),gr.themes.Base.from_hub("gradio/seafoam"),gr.themes.Base.from_hub("freddyaboulton/dracula_revamped"),gr.themes.Glass(),gr.themes.Base(),]),analytics_enabled=False) as iface:
+     out=gr.Textbox(label="🤗Output",lines=5,interactive=False)
+     inut=gr.Textbox(label="Prompt")
+     btn=gr.Button("GENERATE")
+     btn.click(fn=plix,inputs=inut,outputs=out)
+ iface.queue(max_size=1,api_open=False)
+ iface.launch(max_threads=20,inline=False,show_api=False)
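
For reference, the retrieval pattern that RetrieverTool wraps can be exercised on its own, outside the agent. A minimal sketch, assuming langchain-huggingface, langchain-community and faiss-cpu are installed; the two sample documents are hypothetical stand-ins for the chunked m-ric/huggingface_doc corpus:

from langchain.docstore.document import Document
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# Hypothetical mini-corpus standing in for the chunked knowledge base
docs = [
    Document(page_content="FAISS builds an index over embedding vectors for fast similarity search.", metadata={"source": "faiss"}),
    Document(page_content="Gradio Blocks composes textboxes, buttons and callbacks into a UI.", metadata={"source": "gradio"}),
]
embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
store = FAISS.from_documents(documents=docs, embedding=embeddings)
# The same call RetrieverTool.forward makes: top-k nearest documents by embedding distance
for doc in store.similarity_search("index embeddings for similarity search", k=1):
    print(doc.metadata["source"], "->", doc.page_content)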
browser (2).py ADDED
@@ -0,0 +1,317 @@
+ # Shamelessly stolen from Microsoft Autogen team: thanks to them for this great resource!
+ # https://github.com/microsoft/autogen/blob/gaia_multiagent_v01_march_1st/autogen/browser_utils.py
+ import json
+ import os
+ import requests
+ import re
+ import io
+ import uuid
+ import mimetypes
+ import time
+ import pathlib
+ import pathvalidate
+ from urllib.parse import urljoin, urlparse, unquote, parse_qs
+ from urllib.request import url2pathname
+ from typing import Any, Dict, List, Optional, Union, Tuple
+ from mdconvert import MarkdownConverter, UnsupportedFormatException, FileConversionException
+ from serpapi import GoogleSearch
+ from cookies import COOKIES
+ from duckduckgo_search import DDGS
+
+
+ class SimpleTextBrowser:
+     """(In preview) An extremely simple text-based web browser comparable to Lynx. Suitable for agentic use."""
+
+     def __init__(
+         self,
+         start_page: Optional[str] = None,
+         viewport_size: Optional[int] = 1024 * 8,
+         downloads_folder: Optional[Union[str, None]] = None,
+         ##serpapi_key: Optional[Union[str, None]] = None,
+         request_kwargs: Optional[Union[Dict[str, Any], None]] = None,
+     ):
+         self.start_page: str = start_page if start_page else "about:blank"
+         self.viewport_size = viewport_size  # Applies only to the standard uri types
+         self.downloads_folder = downloads_folder
+         self.history: List[Tuple[str, float]] = list()
+         self.page_title: Optional[str] = None
+         self.viewport_current_page = 0
+         self.viewport_pages: List[Tuple[int, int]] = list()
+         self.set_address(self.start_page)
+         ##self.serpapi_key = serpapi_key
+         # Guard against request_kwargs=None before attaching the cookie jar (the original crashed here)
+         self.request_kwargs = request_kwargs if request_kwargs is not None else {}
+         self.request_kwargs["cookies"] = COOKIES
+         self._mdconvert = MarkdownConverter()
+         self._page_content: str = ""
+
+         self._find_on_page_query: Union[str, None] = None
+         self._find_on_page_last_result: Union[int, None] = None  # Location of the last result
+
+     @property
+     def address(self) -> str:
+         """Return the address of the current page."""
+         return self.history[-1][0]
+
+     def set_address(self, uri_or_path: str) -> None:
+         # TODO: Handle anchors
+         self.history.append((uri_or_path, time.time()))
+
+         # Handle special URIs
+         if uri_or_path == "about:blank":
+             self._set_page_content("")
+         elif uri_or_path.startswith("google:"):
+             self._serpapi_search(uri_or_path[len("google:"):].strip())
+         else:
+             if (
+                 not uri_or_path.startswith("http:")
+                 and not uri_or_path.startswith("https:")
+                 and not uri_or_path.startswith("file:")
+             ):
+                 if len(self.history) > 1:
+                     prior_address = self.history[-2][0]
+                     uri_or_path = urljoin(prior_address, uri_or_path)
+                     # Update the address with the fully-qualified path
+                     self.history[-1] = (uri_or_path, self.history[-1][1])
+             self._fetch_page(uri_or_path)
+
+         self.viewport_current_page = 0
+         # Reset the find-on-page state for the new address (the original set two unused attributes)
+         self._find_on_page_query = None
+         self._find_on_page_last_result = None
+
+     @property
+     def viewport(self) -> str:
+         """Return the content of the current viewport."""
+         bounds = self.viewport_pages[self.viewport_current_page]
+         return self.page_content[bounds[0] : bounds[1]]
+
+     @property
+     def page_content(self) -> str:
+         """Return the full contents of the current page."""
+         return self._page_content
+
+     def _set_page_content(self, content: str) -> None:
+         """Sets the text content of the current page."""
+         self._page_content = content
+         self._split_pages()
+         if self.viewport_current_page >= len(self.viewport_pages):
+             self.viewport_current_page = len(self.viewport_pages) - 1
+
+     def page_down(self) -> None:
+         self.viewport_current_page = min(self.viewport_current_page + 1, len(self.viewport_pages) - 1)
+
+     def page_up(self) -> None:
+         self.viewport_current_page = max(self.viewport_current_page - 1, 0)
+
+     def find_on_page(self, query: str) -> Union[str, None]:
+         """Searches for the query from the current viewport forward, looping back to the start if necessary."""
+
+         # Did we get here via a previous find_on_page search with the same query?
+         # If so, map to find_next
+         if query == self._find_on_page_query and self.viewport_current_page == self._find_on_page_last_result:
+             return self.find_next()
+
+         # OK, it's a new search; start from the current viewport
+         self._find_on_page_query = query
+         viewport_match = self._find_next_viewport(query, self.viewport_current_page)
+         if viewport_match is None:
+             self._find_on_page_last_result = None
+             return None
+         else:
+             self.viewport_current_page = viewport_match
+             self._find_on_page_last_result = viewport_match
+             return self.viewport
+
+     def find_next(self) -> Union[str, None]:
+         """Scroll to the next viewport that matches the query."""
+
+         if self._find_on_page_query is None:
+             return None
+
+         starting_viewport = self._find_on_page_last_result
+         if starting_viewport is None:
+             starting_viewport = 0
+         else:
+             starting_viewport += 1
+             if starting_viewport >= len(self.viewport_pages):
+                 starting_viewport = 0
+
+         viewport_match = self._find_next_viewport(self._find_on_page_query, starting_viewport)
+         if viewport_match is None:
+             self._find_on_page_last_result = None
+             return None
+         else:
+             self.viewport_current_page = viewport_match
+             self._find_on_page_last_result = viewport_match
+             return self.viewport
+
+     def _find_next_viewport(self, query: str, starting_viewport: int) -> Union[int, None]:
+         """Search for matches from the starting viewport onward, looping back to the start when reaching the end."""
+
+         if query is None:
+             return None
+
+         # Normalize the query, and convert to a regular expression
+         nquery = re.sub(r"\*", "__STAR__", query)
+         nquery = " " + (" ".join(re.split(r"\W+", nquery))).strip() + " "
+         nquery = nquery.replace(" __STAR__ ", "__STAR__ ")  # Merge isolated stars with prior word
+         nquery = nquery.replace("__STAR__", ".*").lower()
+
+         if nquery.strip() == "":
+             return None
+
+         idxs = list()
+         idxs.extend(range(starting_viewport, len(self.viewport_pages)))
+         idxs.extend(range(0, starting_viewport))
+
+         for i in idxs:
+             bounds = self.viewport_pages[i]
+             content = self.page_content[bounds[0] : bounds[1]]
+
+             # TODO: Remove markdown links and images
+             ncontent = " " + (" ".join(re.split(r"\W+", content))).strip().lower() + " "
+             if re.search(nquery, ncontent):
+                 return i
+
+         return None
+
+     def visit_page(self, path_or_uri: str) -> str:
+         self.set_address(path_or_uri)
+         return self.viewport
+
+     def _split_pages(self) -> None:
+         # Do not split search results
+         if self.address.startswith("google:"):
+             self.viewport_pages = [(0, len(self._page_content))]
+             return
+
+         # Handle empty pages
+         if len(self._page_content) == 0:
+             self.viewport_pages = [(0, 0)]
+             return
+
+         # Break the viewport into pages
+         self.viewport_pages = []
+         start_idx = 0
+         while start_idx < len(self._page_content):
+             end_idx = min(start_idx + self.viewport_size, len(self._page_content))  # type: ignore[operator]
+             # Adjust to end on a space
+             while end_idx < len(self._page_content) and self._page_content[end_idx - 1] not in [" ", "\t", "\r", "\n"]:
+                 end_idx += 1
+             self.viewport_pages.append((start_idx, end_idx))
+             start_idx = end_idx
+
+     def _prev_visit(self, url: str) -> str:
+         # NOTE: this helper is called by _serpapi_search but was missing from the uploaded file.
+         # Minimal reconstruction (an assumption): report when this URL was last visited, else "".
+         for address, timestamp in reversed(self.history[:-1]):
+             if address == url:
+                 return f"You previously visited this page {round(time.time() - timestamp)} seconds ago.\n"
+         return ""
+
+     def _serpapi_search(self, query: str, filter_year: Optional[int] = None) -> None:
+         # Despite the name, this now queries DuckDuckGo rather than SerpAPI
+         with DDGS() as ddgs:
+             results = list(ddgs.text(query, max_results=10))
+
+         self.page_title = f"{query} - Search"
+
+         if not results:
+             year_filter_message = f" with filter year={filter_year}" if filter_year is not None else ""
+             self._set_page_content(f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter.")
+             return
+
+         web_snippets: List[str] = list()
+         for idx, page in enumerate(results, 1):
+             snippet = f"\n{page['body']}" if 'body' in page else ""
+             redacted_version = f"{idx}. [{page['title']}]({page['href']})\n{self._prev_visit(page['href'])}{snippet}"
+             web_snippets.append(redacted_version)
+
+         content = (f"A search for '{query}' found {len(web_snippets)} results:\n\n## Web Results\n" +
+                    "\n\n".join(web_snippets))
+         self._set_page_content(content)
+
+     def _fetch_page(self, url: str) -> None:
+         download_path = ""
+         try:
+             if url.startswith("file://"):
+                 download_path = os.path.normcase(os.path.normpath(unquote(url[7:])))
+                 res = self._mdconvert.convert_local(download_path)
+                 self.page_title = res.title
+                 self._set_page_content(res.text_content)
+             else:
+                 # Prepare the request parameters
+                 request_kwargs = self.request_kwargs.copy() if self.request_kwargs is not None else {}
+                 request_kwargs["stream"] = True
+
+                 # Send an HTTP request to the URL
+                 response = requests.get(url, **request_kwargs)
+                 response.raise_for_status()
+
+                 # If the HTTP request was successful
+                 content_type = response.headers.get("content-type", "")
+
+                 # Text or HTML
+                 if "text/" in content_type.lower():
+                     res = self._mdconvert.convert_response(response)
+                     self.page_title = res.title
+                     self._set_page_content(res.text_content)
+                 # A download
+                 else:
+                     # Try producing a safe filename
+                     fname = None
+                     download_path = None
+                     try:
+                         fname = pathvalidate.sanitize_filename(os.path.basename(urlparse(url).path)).strip()
+                         download_path = os.path.abspath(os.path.join(self.downloads_folder, fname))
+
+                         suffix = 0
+                         while os.path.exists(download_path) and suffix < 1000:
+                             suffix += 1
+                             base, ext = os.path.splitext(fname)
+                             new_fname = f"{base}__{suffix}{ext}"
+                             download_path = os.path.abspath(os.path.join(self.downloads_folder, new_fname))
+
+                     except NameError:
+                         pass
+
+                     # No suitable name, so make one
+                     if fname is None:
+                         extension = mimetypes.guess_extension(content_type)
+                         if extension is None:
+                             extension = ".download"
+                         fname = str(uuid.uuid4()) + extension
+                         download_path = os.path.abspath(os.path.join(self.downloads_folder, fname))
+
+                     # Open a file for writing
+                     with open(download_path, "wb") as fh:
+                         for chunk in response.iter_content(chunk_size=512):
+                             fh.write(chunk)
+
+                     # Render it
+                     local_uri = pathlib.Path(download_path).as_uri()
+                     self.set_address(local_uri)
+
+         except UnsupportedFormatException as e:
+             print(e)
+             self.page_title = "Download complete."  # was assigned a stray one-element tuple
+             self._set_page_content(f"# Download complete\n\nSaved file to '{download_path}'")
+         except FileConversionException as e:
+             print(e)
+             self.page_title = "Download complete."
+             self._set_page_content(f"# Download complete\n\nSaved file to '{download_path}'")
+         except FileNotFoundError:
+             self.page_title = "Error 404"
+             self._set_page_content(f"## Error 404\n\nFile not found: {download_path}")
+         except requests.exceptions.RequestException as request_exception:
+             try:
+                 self.page_title = f"Error {response.status_code}"
+
+                 # If the error was rendered in HTML we might as well render it
+                 content_type = response.headers.get("content-type", "")
+                 if content_type is not None and "text/html" in content_type.lower():
+                     res = self._mdconvert.convert(response)
+                     self.page_title = f"Error {response.status_code}"
+                     self._set_page_content(f"## Error {response.status_code}\n\n{res.text_content}")
+                 else:
+                     text = ""
+                     for chunk in response.iter_content(chunk_size=512, decode_unicode=True):
+                         text += chunk
+                     self.page_title = f"Error {response.status_code}"
+                     self._set_page_content(f"## Error {response.status_code}\n\n{text}")
+             except NameError:
+                 self.page_title = "Error"
+                 self._set_page_content(f"## Error\n\n{str(request_exception)}")
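
A quick usage sketch for the class above. This is a hedged example, not part of the commit: it assumes the file is importable as browser (the upload is named "browser (2).py"), that the local mdconvert and cookies modules resolve, and that the example URL is reachable:

from browser import SimpleTextBrowser

surfer = SimpleTextBrowser(viewport_size=1024 * 4, request_kwargs={"timeout": 30})
print(surfer.visit_page("https://example.com"))  # first viewport of the converted page
surfer.page_down()                               # advance one viewport (clamped at the end)
match = surfer.find_on_page("example *")         # '*' acts as a wildcard in the query
print(match if match is not None else "no match on this page")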
chat_with_bot (2).py ADDED
@@ -0,0 +1,278 @@
+ # Provides terminal-based chat interface for RWKV model.
+ # Usage: python chat_with_bot.py C:\rwkv.cpp-169M.bin
+ # Prompts and code adapted from https://github.com/BlinkDL/ChatRWKV/blob/9ca4cdba90efaee25cfec21a0bae72cbd48d8acd/chat.py
+
+ import os
+ import argparse
+ import pathlib
+ import copy
+ import json
+ import time
+ import sampling
+ from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model
+ from tokenizer_util import add_tokenizer_argument, get_tokenizer
+ from typing import List, Dict, Optional
+
+ # ======================================== Script settings ========================================
+
+ # English, Chinese, Japanese
+ LANGUAGE: str = 'English'
+ # QA: Question and Answer prompt to talk to an AI assistant.
+ # Chat: chat prompt (needs a large model for adequate quality, 7B+).
+ PROMPT_TYPE: str = 'QA'
+
+ MAX_GENERATION_LENGTH: int = 250
+
+ # Sampling temperature. It could be a good idea to increase temperature when top_p is low.
+ TEMPERATURE: float = 0.8
+ # For better Q&A accuracy and less diversity, reduce top_p (to 0.5, 0.2, 0.1 etc.)
+ TOP_P: float = 0.5
+ # Penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
+ PRESENCE_PENALTY: float = 0.2
+ # Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
+ FREQUENCY_PENALTY: float = 0.2
+
+ END_OF_LINE_TOKEN: int = 187
+ DOUBLE_END_OF_LINE_TOKEN: int = 535
+ END_OF_TEXT_TOKEN: int = 0
+
+ # =================================================================================================
+
+ parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model')
+ parser.add_argument('model_path', help='Path to RWKV model in ggml format')
+ add_tokenizer_argument(parser)
+ args = parser.parse_args()
+
+ script_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent
+
+ with open(script_dir / 'prompt' / f'{LANGUAGE}-{PROMPT_TYPE}.json', 'r', encoding='utf8') as json_file:
+     prompt_data = json.load(json_file)
+
+ user, bot, separator, init_prompt = prompt_data['user'], prompt_data['bot'], prompt_data['separator'], prompt_data['prompt']
+
+ if init_prompt == '':
+     raise ValueError('Prompt must not be empty')
+
+ library = rwkv_cpp_shared_library.load_rwkv_shared_library()
+ print(f'System info: {library.rwkv_get_system_info_string()}')
+
+ print('Loading RWKV model')
+ model = rwkv_cpp_model.RWKVModel(library, args.model_path)
+
+ tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab)
+
+ # =================================================================================================
+
+ processed_tokens: List[int] = []
+ logits: Optional[rwkv_cpp_model.NumpyArrayOrPyTorchTensor] = None
+ state: Optional[rwkv_cpp_model.NumpyArrayOrPyTorchTensor] = None
+
+ def process_tokens(_tokens: List[int], new_line_logit_bias: float = 0.0) -> None:
+     global processed_tokens, logits, state
+
+     logits, state = model.eval_sequence_in_chunks(_tokens, state, state, logits, use_numpy=True)
+
+     processed_tokens += _tokens
+
+     logits[END_OF_LINE_TOKEN] += new_line_logit_bias
+
+ state_by_thread: Dict[str, Dict] = {}
+
+ def save_thread_state(_thread: str) -> None:
+     state_by_thread[_thread] = {
+         'tokens': copy.deepcopy(processed_tokens),
+         'logits': copy.deepcopy(logits),
+         'state': copy.deepcopy(state)
+     }
+
+ def load_thread_state(_thread: str) -> None:
+     global processed_tokens, logits, state
+
+     thread_state = state_by_thread[_thread]
+
+     processed_tokens = copy.deepcopy(thread_state['tokens'])
+     logits = copy.deepcopy(thread_state['logits'])
+     state = copy.deepcopy(thread_state['state'])
+
+ # Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end.
+ # See https://github.com/BlinkDL/ChatRWKV/pull/110/files
+ def split_last_end_of_line(tokens: List[int]) -> List[int]:
+     if len(tokens) > 0 and tokens[-1] == DOUBLE_END_OF_LINE_TOKEN:
+         tokens = tokens[:-1] + [END_OF_LINE_TOKEN, END_OF_LINE_TOKEN]
+
+     return tokens
+
+ # =================================================================================================
+
+ processing_start: float = time.time()
+
+ prompt_tokens = tokenizer_encode(init_prompt)
+ prompt_token_count = len(prompt_tokens)
+ print(f'Processing {prompt_token_count} prompt tokens, may take a while')
+
+ process_tokens(split_last_end_of_line(prompt_tokens))
+
+ processing_duration: float = time.time() - processing_start
+
+ print(f'Processed in {int(processing_duration)} s, {int(processing_duration / prompt_token_count * 1000)} ms per token')
+
+ save_thread_state('chat_init')
+ save_thread_state('chat')
+
+ print(f'\nChat initialized! Your name is {user}. Write something and press Enter. Use \\n to add line breaks to your message.')
+
+ while True:
+     # Read user input
+     user_input: str = input(f'> {user}{separator} ')
+     msg: str = user_input.replace('\\n', '\n').strip()
+
+     temperature: float = TEMPERATURE
+     top_p: float = TOP_P
+
+     if '-temp=' in msg:
+         temperature = float(msg.split('-temp=')[1].split(' ')[0])
+
+         msg = msg.replace('-temp='+f'{temperature:g}', '')
+
+         if temperature <= 0.2:
+             temperature = 0.2
+
+         if temperature >= 5:
+             temperature = 5
+
+     if '-top_p=' in msg:
+         top_p = float(msg.split('-top_p=')[1].split(' ')[0])
+
+         msg = msg.replace('-top_p='+f'{top_p:g}', '')
+
+         if top_p <= 0:
+             top_p = 0
+
+     msg = msg.strip()
+
+     # +reset --> reset chat
+     if msg == '+reset':
+         load_thread_state('chat_init')
+         save_thread_state('chat')
+         print(f'{bot}{separator} Chat reset.\n')
+         continue
+     elif msg[:5].lower() == '+gen ' or msg[:3].lower() == '+i ' or msg[:4].lower() == '+qa ' or msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++':
+
+         # +gen YOUR PROMPT --> free single-round generation with any prompt. Requires Novel model.
+         if msg[:5].lower() == '+gen ':
+             new = '\n' + msg[5:].strip()
+             state = None
+             processed_tokens = []
+             process_tokens(tokenizer_encode(new))
+             save_thread_state('gen_0')
+
+         # +i YOUR INSTRUCT --> free single-round generation with any instruct. Requires Raven model.
+         elif msg[:3].lower() == '+i ':
+             new = f'''
+ Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+ # Instruction:
+ {msg[3:].strip()}
+
+ # Response:
+ '''
+             state = None
+             processed_tokens = []
+             process_tokens(tokenizer_encode(new))
+             save_thread_state('gen_0')
+
+         # +qq YOUR QUESTION --> answer an independent question with more creativity (regardless of context).
+         elif msg[:4].lower() == '+qq ':
+             new = '\nQ: ' + msg[4:].strip() + '\nA:'
+             state = None
+             processed_tokens = []
+             process_tokens(tokenizer_encode(new))
+             save_thread_state('gen_0')
+
+         # +qa YOUR QUESTION --> answer an independent question (regardless of context).
+         elif msg[:4].lower() == '+qa ':
+             load_thread_state('chat_init')
+
+             real_msg = msg[4:].strip()
+             new = f'{user}{separator} {real_msg}\n\n{bot}{separator}'
+
+             process_tokens(tokenizer_encode(new))
+             save_thread_state('gen_0')
+
+         # +++ --> continue last free generation (only for +gen / +i)
+         elif msg.lower() == '+++':
+             try:
+                 load_thread_state('gen_1')
+                 save_thread_state('gen_0')
+             except Exception as e:
+                 print(e)
+                 continue
+
+         # ++ --> retry last free generation (only for +gen / +i)
+         elif msg.lower() == '++':
+             try:
+                 load_thread_state('gen_0')
+             except Exception as e:
+                 print(e)
+                 continue
+         thread = 'gen_1'
+
+     else:
+         # + --> alternate chat reply
+         if msg.lower() == '+':
+             try:
+                 load_thread_state('chat_pre')
+             except Exception as e:
+                 print(e)
+                 continue
+         # chat with bot
+         else:
+             load_thread_state('chat')
+             new = f'{user}{separator} {msg}\n\n{bot}{separator}'
+             process_tokens(tokenizer_encode(new), new_line_logit_bias=-999999999)
+             save_thread_state('chat_pre')
+
+         thread = 'chat'
+
+     # Print bot response
+     print(f'> {bot}{separator}', end='')
+
+     start_index: int = len(processed_tokens)
+     accumulated_tokens: List[int] = []
+     token_counts: Dict[int, int] = {}
+
+     for i in range(MAX_GENERATION_LENGTH):
+         for n in token_counts:
+             logits[n] -= PRESENCE_PENALTY + token_counts[n] * FREQUENCY_PENALTY
+
+         token: int = sampling.sample_logits(logits, temperature, top_p)
+
+         if token == END_OF_TEXT_TOKEN:
+             print()
+             break
+
+         if token not in token_counts:
+             token_counts[token] = 1
+         else:
+             token_counts[token] += 1
+
+         process_tokens([token])
+
+         # Avoid UTF-8 display issues
+         accumulated_tokens += [token]
+
+         decoded: str = tokenizer_decode(accumulated_tokens)
+
+         if '\uFFFD' not in decoded:
+             print(decoded, end='', flush=True)
+
+             accumulated_tokens = []
+
+         if thread == 'chat':
+             if '\n\n' in tokenizer_decode(processed_tokens[start_index:]):
+                 break
+
+         if i == MAX_GENERATION_LENGTH - 1:
+             print()
+
+     save_thread_state(thread)
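
The repetition control inside the generation loop above is a plain linear offset on the logits: every token emitted so far is docked once for being present at all, plus a further amount per occurrence. A self-contained sketch of just that step (the function name and shapes here are illustrative):

import numpy as np

def penalize(logits: np.ndarray, token_counts: dict,
             presence: float = 0.2, frequency: float = 0.2) -> np.ndarray:
    # Mirrors the loop above: logits[n] -= PRESENCE_PENALTY + token_counts[n] * FREQUENCY_PENALTY
    out = logits.copy()
    for token_id, count in token_counts.items():
        out[token_id] -= presence + count * frequency
    return out

print(penalize(np.zeros(8), {3: 2}))  # token 3 gets -(0.2 + 2 * 0.2) = -0.6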
cookies (2).py ADDED
@@ -0,0 +1,713 @@
+ COOKIES_LIST = [
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1718884961,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "ST-xuwub9",
+         "path": "/",
+         "sameSite": None,
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "session_logininfo=AFmmF2swRAIgf4gadACOuWOcipI1anW-dakEjtidNLkufnOC8uml7EECIDh2YisqWELDBJPTGUysCucJ3I0wjXxYjVHro1LHrdW0%3AQUQ3MjNmd2Jiajl3OWZYRnpFNnZlWWV5ZGJWZ0hpcmp4LVVPU280bk4zOS03Z0ozZG9fOFhWZ0dXaVo3NG1wTEg1b3hGaG10TFBlaFBnTlJfbER5bEp0aFhoNS1OLVhYNFRZT2F6ajgzOFpDbGhlUjZpMWRETlFFRjFfTTRiM0RnNTROSkdmMTFMVjFic1VuZ2trbGp4aktDa0JJUC1BWDh3"
+     },
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1753004444.745411,
+         "hostOnly": False,
+         "httpOnly": True,
+         "name": "__Secure-YEC",
+         "path": "/",
+         "sameSite": "lax",
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "CgtRVnI5LW1zRHlQVSjbtNCzBjIhCgJGUhIbEhcSFRMLFBUWFwwYGRobHB0eHw4PIBAREiAk"
+     },
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1753434620.050824,
+         "hostOnly": False,
+         "httpOnly": True,
+         "name": "__Secure-3PSID",
+         "path": "/",
+         "sameSite": "no_restriction",
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "g.a000kwibeLUu8Ea9Y-vLun7u3kU5VNJVuMAZl_jdfJaNm50JyDBB4ezJ_bdWu46a7YwObVn44wACgYKAakSARQSFQHGX2MicJcTzecTKH6bHzqU6TMbTxoVAUF8yKqQYK-MoI6Ql3vI2oYTB3E-0076"
+     },
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1750420959.974642,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "SIDCC",
+         "path": "/",
+         "sameSite": None,
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "AKEyXzWQZauHKOo8t87zoEcjaVNIYUX54ohoWXT-tX4aAhEuZzIIptxZAcNkHuG2oDXYL6t-lw"
+     },
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1753434620.050652,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "SID",
+         "path": "/",
+         "sameSite": None,
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "g.a000kwibeLUu8Ea9Y-vLun7u3kU5VNJVuMAZl_jdfJaNm50JyDBB6VHrZcC3gBAsFPbCQ0gF5AACgYKAYkSARQSFQHGX2Mi9kt0gHg5CxCYSkLQGHWaeBoVAUF8yKre_V6r3jZVak6JV4o2Q0FL0076"
+     },
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1750420958.397534,
+         "hostOnly": False,
+         "httpOnly": True,
+         "name": "__Secure-1PSIDTS",
+         "path": "/",
+         "sameSite": None,
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "sidts-CjIB3EgAEkYL2L-GfrEzW5Dfy62S9oefGNLgst78S_986htCnGcfkxECch_9oz-qytSsZBAA"
+     },
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1753433494.44729,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "_ga_M0180HEFCY",
+         "path": "/",
+         "sameSite": None,
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "GS1.1.1718871908.1.0.1718873494.0.0.0"
+     },
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1753434620.050933,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "SAPISID",
+         "path": "/",
+         "sameSite": None,
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "mfeuiC-HraNJ-A03/ASXvCPNJSw7yTFgd6"
+     },
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1750420959.974764,
+         "hostOnly": False,
+         "httpOnly": True,
+         "name": "__Secure-1PSIDCC",
+         "path": "/",
+         "sameSite": None,
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "AKEyXzWHDSoXGCZpZhPxRrnC7B1s8zGIUjeMVyvgtQfsm1fs92lXPtFEI_td9LBUyqVUe0xK"
+     },
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1753434620.050881,
+         "hostOnly": False,
+         "httpOnly": True,
+         "name": "SSID",
+         "path": "/",
+         "sameSite": None,
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "AmlwXHnQvOQ10LVd-"
+     },
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1753434620.050959,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "__Secure-1PAPISID",
+         "path": "/",
+         "sameSite": None,
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "mfeuiC-HraNJ-A03/ASXvCPNJSw7yTFgd6"
+     },
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1753434620.050795,
+         "hostOnly": False,
+         "httpOnly": True,
+         "name": "__Secure-1PSID",
+         "path": "/",
+         "sameSite": None,
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "g.a000kwibeLUu8Ea9Y-vLun7u3kU5VNJVuMAZl_jdfJaNm50JyDBBrlk7lRpKQGywAHEon7WGQAACgYKAQsSARQSFQHGX2MirAmnSRdZl6GPG6KLd4hOihoVAUF8yKoV17Tcj1a_OenIOkf2wBjO0076"
+     },
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1753434620.050993,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "__Secure-3PAPISID",
+         "path": "/",
+         "sameSite": "no_restriction",
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "mfeuiC-HraNJ-A03/ASXvCPNJSw7yTFgd6"
+     },
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1750420959.974815,
+         "hostOnly": False,
+         "httpOnly": True,
+         "name": "__Secure-3PSIDCC",
+         "path": "/",
+         "sameSite": "no_restriction",
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "AKEyXzXM5UjKUEXwSHVmRAIo6hGHA4G63adj3EE1VdNriD0f38jZQbsUKiD4LQbA3BValmTFDg"
+     },
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1750420958.397647,
+         "hostOnly": False,
+         "httpOnly": True,
+         "name": "__Secure-3PSIDTS",
+         "path": "/",
+         "sameSite": "no_restriction",
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "sidts-CjIB3EgAEkYL2L-GfrEzW5Dfy62S9oefGNLgst78S_986htCnGcfkxECch_9oz-qytSsZBAA"
+     },
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1753434620.050908,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "APISID",
+         "path": "/",
+         "sameSite": None,
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "IlQWLPjdNqziwCrV/ANG7Z4x5FF-IBxbZk"
+     },
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1753434620.050855,
+         "hostOnly": False,
+         "httpOnly": True,
+         "name": "HSID",
+         "path": "/",
+         "sameSite": None,
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "AasA7hmRuTFv7vjoq"
+     },
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1753435873.577793,
+         "hostOnly": False,
+         "httpOnly": True,
+         "name": "LOGIN_INFO",
+         "path": "/",
+         "sameSite": "no_restriction",
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "AFmmF2swRAIgf4gadACOuWOcipI1anW-dakEjtidNLkufnOC8uml7EECIDh2YisqWELDBJPTGUysCucJ3I0wjXxYjVHro1LHrdW0:QUQ3MjNmd2Jiajl3OWZYRnpFNnZlWWV5ZGJWZ0hpcmp4LVVPU280bk4zOS03Z0ozZG9fOFhWZ0dXaVo3NG1wTEg1b3hGaG10TFBlaFBnTlJfbER5bEp0aFhoNS1OLVhYNFRZT2F6ajgzOFpDbGhlUjZpMWRETlFFRjFfTTRiM0RnNTROSkdmMTFMVjFic1VuZ2trbGp4aktDa0JJUC1BWDh3"
+     },
+     {
+         "domain": ".youtube.com",
+         "expirationDate": 1753444956.555608,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "PREF",
+         "path": "/",
+         "sameSite": None,
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "f4=4000000&f6=40000000&tz=Europe.Paris&f5=30000&f7=100"
+     }
+ ]
+
+ COOKIES_LIST += [
+     {
+         "domain": ".www.researchgate.net",
+         "hostOnly": False,
+         "httpOnly": True,
+         "name": "isInstIp",
+         "path": "/",
+         "sameSite": None,
+         "secure": True,
+         "session": True,
+         "storeId": None,
+         "value": "False"
+     },
+     {
+         "domain": ".researchgate.net",
+         "expirationDate": 1734423981,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "__eoi",
+         "path": "/",
+         "sameSite": None,
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "ID=c26f752377373146:T=1718871981:RT=1718884914:S=AA-AfjZw-T_OOX2kW2LLaFzXImgc"
+     },
+     {
+         "domain": ".www.researchgate.net",
+         "expirationDate": 1753444909.646103,
+         "hostOnly": False,
+         "httpOnly": True,
+         "name": "ptc",
+         "path": "/",
+         "sameSite": None,
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "RG1.8947708639250500550.1718872043"
+     },
+     {
+         "domain": ".researchgate.net",
+         "expirationDate": 1750507578,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "euconsent-v2-didomi",
+         "path": "/",
+         "sameSite": "lax",
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "CQAgmoAQAgmoAAHABBENA5EsAP_gAEPgAAYgJ2pB5G5UTWlBIG53YMskIAUFhFBoQEAgAACAAwIBSBIAIIwEAGAAIAgAICACAAIAIBIAIABAGAAAAAAAYIAAIAAIAAAQIAAKIAAAAAAAAgBQAAgIAgggEAAAgEBEABAAgAAAEIIAQNgACgAAACCAAAAAAAABAAAAAAAAQAAAAAAAYCQAAAJIAAAAACAIABAIAAAAAAAAAAAAAAAABBAAIJ2wPIAFAAXABQAFQALgAcAA8ACAAEgALwAZAA0ACIAEcAJgAUgAqgBcADEAGgAPQAfgBEACOAE4AMMAZYA0QBsgDkAHOAO4AfsBBwEIAItARwBHQC6gHUAO2Ae0A_4CHQEXgJ2AUOAo8BT4CpQFqALYAXmAwQBkgDLAGXANjAhCBG8CbAE3gJ1gTtAA.f_wACHwAAAAA"
+     },
+     {
+         "domain": ".researchgate.net",
+         "expirationDate": 1718885236,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "_gat",
+         "path": "/",
+         "sameSite": None,
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "1"
+     },
+     {
+         "domain": "www.researchgate.net",
+         "expirationDate": 1721477183,
+         "hostOnly": True,
+         "httpOnly": False,
+         "name": "_pbjs_userid_consent_data",
+         "path": "/",
+         "sameSite": "lax",
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "3524755945110770"
+     },
+     {
+         "domain": ".researchgate.net",
+         "expirationDate": 1752567981,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "__gads",
+         "path": "/",
+         "sameSite": None,
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "ID=eca2adb88969c830:T=1718871981:RT=1718884914:S=ALNI_MY2qZchynrhWX6hWMlaI87Pcj9riQ"
+     },
+     {
+         "domain": ".researchgate.net",
+         "expirationDate": 1718886709.646173,
+         "hostOnly": False,
+         "httpOnly": True,
+         "name": "__cf_bm",
+         "path": "/",
+         "sameSite": "no_restriction",
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "IkQ_J4ciBzKQduRvjqsfSmQu8UygDWbHeROO5JVccfo-1718884909-1.0.1.1-qvNGEdbfI0HfhFP6kwe7R7mkTqODNhFuKhs72lLly6K2BOPMG3kbahpQFGvPK0U8FUfkznkq65gngd1sWj7sDA"
+     },
+     {
+         "domain": ".researchgate.net",
+         "expirationDate": 1752567981,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "__gpi",
+         "path": "/",
+         "sameSite": None,
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "UID=00000e4e9aa2e6f2:T=1718871981:RT=1718884914:S=ALNI_MYFNrgzkKn7K6Bd2y8hC6GJCvDiSg"
+     },
+     {
+         "domain": ".researchgate.net",
+         "hostOnly": False,
+         "httpOnly": True,
+         "name": "_cfuvid",
+         "path": "/",
+         "sameSite": "no_restriction",
+         "secure": True,
+         "session": True,
+         "storeId": None,
+         "value": "_GPmGZkBymiH3UiqTqzakEpi98br3nfFUWC2_u_wqkc-1718884909785-0.0.1.1-604800000"
+     },
+     {
+         "domain": ".researchgate.net",
+         "expirationDate": 1753445177.271667,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "_ga",
+         "path": "/",
+         "sameSite": None,
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "GA1.1.1525244793.1718885177"
+     },
+     {
+         "domain": ".researchgate.net",
+         "expirationDate": 1753445177.271482,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "_ga_4P31SJ70EJ",
+         "path": "/",
+         "sameSite": None,
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "GS1.1.1718885177.1.0.1718885177.0.0.0"
+     },
+     {
+         "domain": ".researchgate.net",
+         "expirationDate": 1718971576,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "_gid",
+         "path": "/",
+         "sameSite": None,
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "GA1.2.854907463.1718885177"
+     },
+     {
+         "domain": ".www.researchgate.net",
+         "expirationDate": 1750407982.506505,
+         "hostOnly": False,
+         "httpOnly": True,
+         "name": "did",
+         "path": "/",
+         "sameSite": None,
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "1dWLO3C6am8l667Q4VUlBo0O1LI49Qi2Vw21SJEXHavBDYT56DI9007W5rYGVFVH"
+     },
+     {
+         "domain": ".researchgate.net",
+         "expirationDate": 1750507578,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "didomi_token",
+         "path": "/",
+         "sameSite": "lax",
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "eyJ1c2VyX2lkIjoiMTkwMzU4YTUtNWU2My02Y2UzLWJlNzAtZGFjNzVmYjdiY2ExIiwiY3JlYXRlZCI6IjIwMjQtMDYtMjBUMTI6MDY6MTYuODA2WiIsInVwZGF0ZWQiOiIyMDI0LTA2LTIwVDEyOjA2OjE4Ljc4MVoiLCJ2ZW5kb3JzIjp7ImVuYWJsZWQiOlsidHdpdHRlciIsImdvb2dsZSIsImM6bGlua2VkaW4tbWFya2V0aW5nLXNvbHV0aW9ucyIsImM6b3duZXJpcSIsImM6b21uaXR1cmUtYWRvYmUtYW5hbHl0aWNzIiwiYzp0ZWNobm9yYXRpLW1lZGlhIiwiYzppbnRlcmNvbSIsImM6aW50ZW50LWlxIiwiYzppcHJvbSIsImM6bGlua2VkaW4iLCJjOmFtYXpvbmFkdi16Y1hGTEI2WCIsImM6bWVkaWFuZXQtY1V3YUtFNnoiLCJjOmluZGV4ZXhjaC1OWkNRTTY4UCIsImM6emVvdGFwZ21iLWQ3YndtdGp3IiwiYzp0cmlwbGVsaWYtZGRKSDM0clkiLCJjOnJ0Ymhvd XNlLWI4Y2RIOHRNIiwiYzptZHByaW1pcy1lYU4yOVdjUCIsImM6bG9vcG1lbGktVGRhWXRCUHEiLCJjOm1hZ25pdGVpbi05d1RZTHFSRCIsImM6Ymlkc3dpdGNoLWQ2N0V3N1c5IiwiYzpvcmFjbGVhZHYtcUhlREptQUwiLCJjOmdvb2dsZWFuYS00VFhuSmlnUiIsImM6bG90YW1lc29sLURIaTdMUmpNIiwiYzpuZXh0bWlsbGUtR0pyZlg4VWMiLCJjOm5yaWNodGVjLXFVVlEyUlFxIiwiYzpicml0ZXBvb2wtQldWeVdHeVUiLCJjOnRhcGFkaW5jLXFxY2tVN1BXIiwiYzppZDV0ZWNobi16Tk1KNGR3ZiIsImM6bWljcm9zb2Z0IiwiYzpwZXJtdXRpdmUtSjdpaHJlTWsiLCJjOm9wZXJhc29mdC1CY1hjRFZKTSIsImM6cG9zdGhvZy1Cakp4RmRGOSJdfSwicHVycG9zZXMiOnsiZW5hYmxlZCI6WyJnZW9sb2NhdGlvbl9kYXRhIiwiZGV2aWNlX2NoYXJhY3RlcmlzdGljcyJdfSwidmVuZG9yc19saSI6eyJlbmFibGVkIjpbImdvb2dsZSIsImM6b3BlcmFzb2Z0LUJjWGNEVkpNIl19LCJ2ZXJzaW9uIjoyLCJhYyI6IkRIU0FvQUZrQWNnQTVnSHFnUUhBeGdCNndEMTRJR0FRTkFqMEJJd0NTY0VyQUtCd1YtZ3MxQmgwREc0R09nQUEuREhTQW9BRmtBY2dBNWdIcWdRSEF4Z0I2d0QxNElHQVFOQWowQkl3Q1NjRXJBS0J3Vi1nczFCaDBERzRHT2dBQSJ9"
+     },
+     {
+         "domain": ".www.researchgate.net",
+         "hostOnly": False,
+         "httpOnly": True,
+         "name": "hasPdpNext",
+         "path": "/",
+         "sameSite": None,
+         "secure": True,
+         "session": True,
+         "storeId": None,
+         "value": "False"
+     },
+     {
+         "domain": ".researchgate.net",
+         "expirationDate": 1750421183,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "ph_phc_ma1XTQyee96N1GML6qUTgLQRiDifnRcE9STiHTZ0CfZ_posthog",
+         "path": "/",
+         "sameSite": "lax",
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "%7B%22distinct_id%22%3A%220190358a-56a1-7313-83b0-d13dddeac787%22%2C%22%24sesid%22%3A%5B1718885183223%2C%220190358a-56a1-7313-83b0-d13b2b87778d%22%2C1718885176993%5D%2C%22%24session_is_sampled%22%3Atrue%7D"
+     },
+     {
+         "domain": ".www.researchgate.net",
+         "hostOnly": False,
+         "httpOnly": True,
+         "name": "sid",
+         "path": "/",
+         "sameSite": None,
+         "secure": True,
+         "session": True,
+         "storeId": None,
+         "value": "qmH5Lc4f0CUJ3zeaxORcV0S8I8V1MuCFZtcIQqPYtv1XPejrbSLAQRbT50PL40TqeKQ1XsQDWt9gtYVzuL80bRmPjw6jn3cQ0ikNqW40maHcQ3JL2Vfa8ZZf0j7p35eJ"
+     }
+ ]
+
+ COOKIES_LIST += [
+     {
+         "domain": "github.com",
+         "hostOnly": True,
+         "httpOnly": True,
+         "name": "_gh_sess",
+         "path": "/",
+         "sameSite": "lax",
+         "secure": True,
+         "session": True,
+         "storeId": None,
+         "value": "P%2Fmof1avuqwHaUQUIJR%2FZYn7jqbT7lgGuTGjp1BGAFIG5UpNDusEE3b8dRjz0eATE5xPdPjLYFqMs%2FI9AOalKX4YuYfSEEnxCMawU01099b4o9Xzzcv%2BmecrmO0Q8q%2Bdq1h8SIv6nvPP7HzlFesl8ysafb9b%2F0q6dTArKdSOurasza8UgLSYD08ofA50Pcm0IG7CTzF8ZCizrGgGTMi%2F%2B7L3E17jav5PM1Sf2vQKg15Gbg1QIOppJJHzlufgQoZigqFv%2BWznaws0Tt7Y2lSFCw%3D%3D--CJRhqMXJnwOaJgk4--DhUErlL4GdROikEjKD4O9g%3D%3D"
+     },
+     {
+         "domain": ".github.com",
+         "expirationDate": 1750408875.763785,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "_octo",
+         "path": "/",
+         "sameSite": "lax",
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "GH1.1.728652011.1718872875"
+     },
+     {
+         "domain": ".github.com",
+         "expirationDate": 1750408875.763926,
+         "hostOnly": False,
+         "httpOnly": True,
+         "name": "logged_in",
+         "path": "/",
+         "sameSite": "lax",
+         "secure": True,
+         "session": False,
+         "storeId": None,
+         "value": "no"
+     },
+     {
+         "domain": ".github.com",
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "preferred_color_mode",
+         "path": "/",
+         "sameSite": "lax",
+         "secure": True,
+         "session": True,
+         "storeId": None,
+         "value": "dark"
+     },
+     {
+         "domain": ".github.com",
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "tz",
+         "path": "/",
+         "sameSite": "lax",
+         "secure": True,
+         "session": True,
+         "storeId": None,
+         "value": "Europe%2FParis"
+     }
+ ]
+
+ COOKIES_LIST += [
+     {
+         "domain": ".web.archive.org",
+         "expirationDate": 1718886430,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "_gat",
+         "path": "/web/20201123221659/http://orcid.org/",
+         "sameSite": None,
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "1"
+     },
+     {
+         "domain": ".web.archive.org",
+         "expirationDate": 1718972770,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "_gid",
+         "path": "/web/20201123221659/http://orcid.org/",
+         "sameSite": None,
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "GA1.2.402246368.1606169825"
+     },
+     {
+         "domain": ".web.archive.org",
+         "expirationDate": 1753446370.315621,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "_ga",
+         "path": "/web/20201123221659/http://orcid.org/",
+         "sameSite": None,
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "GA1.2.1301409987.1606169825"
+     },
+     {
+         "domain": ".web.archive.org",
+         "expirationDate": 1750422367,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "_hjid",
+         "path": "/web/20201123221659/http://orcid.org/",
+         "sameSite": "lax",
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "07f80263-a631-4bf4-8ffd-8fc8912085e2"
+     },
+     {
+         "domain": ".web.archive.org",
+         "expirationDate": 1718888167,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "_hjFirstSeen",
+         "path": "/web/20201123221659/http://orcid.org/",
+         "sameSite": "lax",
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "1"
+     }
+ ]
+ COOKIES_LIST += [
+     {
+         "domain": "orcid.org",
+         "hostOnly": True,
+         "httpOnly": False,
+         "name": "AWSELBCORS",
+         "path": "/",
+         "sameSite": "no_restriction",
+         "secure": True,
+         "session": True,
+         "storeId": None,
+         "value": "CBD1D7FF1216388FA48838CBCA4774FD22800B8FB548A40EF92BB0994D5B77A8410307CDEAA69C52236663F2BF89B252C17BC0FCDF790FD59771BDDF6EA8CA4CFD29D8733F"
+     },
+     {
+         "domain": ".orcid.org",
+         "expirationDate": 1753452454.637671,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "_ga_9R61FWK9H5",
+         "path": "/",
+         "sameSite": None,
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "GS1.1.1718892454.1.0.1718892454.0.0.0"
+     },
+     {
+         "domain": ".orcid.org",
+         "expirationDate": 1753452454.63421,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "_ga",
+         "path": "/",
+         "sameSite": None,
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "GA1.1.2021310691.1718892455"
+     },
+     {
+         "domain": "orcid.org",
+         "hostOnly": True,
+         "httpOnly": False,
+         "name": "AWSELB",
+         "path": "/",
+         "sameSite": None,
+         "secure": False,
+         "session": True,
+         "storeId": None,
+         "value": "CBD1D7FF1216388FA48838CBCA4774FD22800B8FB548A40EF92BB0994D5B77A8410307CDEAA69C52236663F2BF89B252C17BC0FCDF790FD59771BDDF6EA8CA4CFD29D8733F"
+     },
+     {
+         "domain": ".orcid.org",
+         "expirationDate": 1750428454,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "OptanonAlertBoxClosed",
+         "path": "/",
+         "sameSite": "lax",
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "2024-06-20T14:07:34.583Z"
+     },
+     {
+         "domain": ".orcid.org",
+         "expirationDate": 1750428454,
+         "hostOnly": False,
+         "httpOnly": False,
+         "name": "OptanonConsent",
+         "path": "/",
+         "sameSite": "lax",
+         "secure": False,
+         "session": False,
+         "storeId": None,
+         "value": "isGpcEnabled=0&datestamp=Thu+Jun+20+2024+16%3A07%3A34+GMT%2B0200+(heure+d%E2%80%99%C3%A9t%C3%A9+d%E2%80%99Europe+centrale)&version=202310.2.0&browserGpcFlag=0&isIABGlobal=False&hosts=&landingPath=NotLandingPage&groups=C0001%3A1%2CC0003%3A1%2CC0002%3A1%2CC0004%3A1"
+     },
+     {
+         "domain": "orcid.org",
+         "hostOnly": True,
+         "httpOnly": False,
+         "name": "XSRF-TOKEN",
+         "path": "/",
+         "sameSite": None,
+         "secure": True,
+         "session": True,
+         "storeId": None,
+         "value": "6957be7a-bcb4-4d59-a522-ea9b6b210ed9"
+     }
+ ]
+ from requests.cookies import RequestsCookieJar
+
+ # Create a RequestsCookieJar instance
+ COOKIES = RequestsCookieJar()
+
+ # Add cookies to the jar
+ for cookie in COOKIES_LIST:
+     COOKIES.set(cookie['name'], cookie['value'], domain=cookie['domain'], path=cookie['path'])
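
The resulting jar plugs directly into requests. A minimal sketch, assuming this file is saved as cookies.py so the import resolves; the URL is just an example:

import requests
from cookies import COOKIES

# requests accepts a RequestsCookieJar via the cookies= keyword
response = requests.get("https://github.com", cookies=COOKIES, timeout=30)
print(response.status_code)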
llm_engine (2).py ADDED
@@ -0,0 +1,179 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os, time, json, re, gc, subprocess
+ import gradio as gr
+ import torch
+ import numpy as np
+ import argparse
+ import time
+ import sampling
+ import copy
+ from datetime import datetime
+ from huggingface_hub import hf_hub_download
+ from pynvml import *
+ from tokenizer_util import add_tokenizer_argument, get_tokenizer
+ import rwkv_world_tokenizer
+ from huggingface_hub import snapshot_download, hf_hub_download
+ hf_hub_download(repo_id="JoPmt/RWKV-5-3B-V2-Quant", filename="rwkv-5-world-3b-v2-20231118-ctx16k.Q4_0.bin", local_dir='~/app/Downloads')
+ model_path='~/app/Downloads/rwkv-5-world-3b-v2-20231118-ctx16k.Q4_0.bin'
+ from copy import deepcopy
+ from enum import Enum
+ from typing import Dict, List
+ from huggingface_hub import InferenceClient
+ from transformers.agents import PythonInterpreterTool
+ from transformers import AutoTokenizer
+ tokenizer=AutoTokenizer.from_pretrained("NousResearch/Hermes-2-Pro-Llama-3-8B",revision="pr/13")
+ tools=[PythonInterpreterTool()]
+ os.system("apt-get update && apt-get install cmake gcc g++")
+ os.system("git clone --recursive https://github.com/JoPmt/rwkv.cpp.git && cd rwkv.cpp && mkdir build && cd build && cmake .. -DRWKV_CUBLAS=ON -DRWKV_BUILD_SHARED_LIBRARY=ON -DGGML_CUDA=ON -DRWKV_BUILD_PYTHON_MODULE=ON -DRWKV_BUILD_TOOLS=ON -DRWKV_BUILD_EXTRAS=ON && cmake --build . --config Release && make RWKV_CUBLAS=1 GGML_CUDA=1")
+ import rwkv_cpp_model
+ import rwkv_cpp_shared_library
+
+ def find_lib():
+     for root, dirs, files in os.walk("/"):
+         for file in files:
+             if file == "librwkv.so":
+                 return os.path.join(root, file)
+     return None
+ library_path = find_lib()
+ rwkv_lib = rwkv_cpp_shared_library.RWKVSharedLibrary(library_path)
+ modal = rwkv_cpp_model.RWKVModel(rwkv_lib,model_path,thread_count=2)
+ print('Loading RWKV model')
+ tokenizer_decode, tokenizer_encode = get_tokenizer('auto', modal.n_vocab)
+ out_str = ''
+ prompt = out_str
+ token_count = 1200
+ temperature = 1.0
+ top_p = 0.7
+ presence_penalty = 0.1
+ count_penalty = 0.4
+ def generate_prompt(instruction, zput=""):
+     instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
+     zput = zput.strip().replace('\r\n','\n').replace('\n\n','\n')
+     if zput:
+         return f"""Instruction: {instruction}
+ Input: {zput}
+ Response:"""
+     else:
+         return f"""User: hi
+ Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
+ User: {instruction}
+ Assistant:"""
+ class MessageRole(str, Enum):
+     USER = "user"
+     ASSISTANT = "assistant"
+     SYSTEM = "system"
+     TOOL_CALL = "tool-call"
+     TOOL_RESPONSE = "tool-response"
+     @classmethod
+     def roles(cls):
+         return [r.value for r in cls]
+ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}):
+     """
+     Subsequent messages with the same role will be concatenated to a single message.
+
+     Args:
+         message_list (`List[Dict[str, str]]`): List of chat messages.
+     """
+     final_message_list = []
+     message_list = deepcopy(message_list)  # Avoid modifying the original list
+     for message in message_list:
+         if not set(message.keys()) == {"role", "content"}:
+             raise ValueError("Message should contain only 'role' and 'content' keys!")
+
+         role = message["role"]
+         if role not in MessageRole.roles():
+             raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")
+
+         if role in role_conversions:
+             message["role"] = role_conversions[role]
+
+         if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
+             final_message_list[-1]["content"] += "\n=======\n" + message["content"]  # += so same-role messages are concatenated, not overwritten
+         else:
+             final_message_list.append(message)
+     return final_message_list
+ llama_role_conversions = {
+     MessageRole.TOOL_RESPONSE: MessageRole.USER,
+     MessageRole.TOOL_CALL: MessageRole.USER,
+ }
+ class HfEngine:
+     def __init__(self, model: str = "JoPmt/JoPmt"):
+         self.model = model
+         self.client = modal
+     def __call__(self, messages: List[Dict[str, str]], stop_sequences=[]) -> str:
+         messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
+         print(messages)
+         pret=''
+         prut=''
+         for message in messages:
+             print(message['content'])
+             if message['role'].lower() == 'system':
+                 pret+=''+message['content']+''
+             if message['role'].lower() == 'user':
+                 prut+=''+message['content']+''
+         ##prompt = ins.format(question=''+pret+''+prut+'', system=pret)
+         prompt=tokenizer.apply_chat_template(messages,tokenize=False,add_generation_prompt=True,)
+         print(prompt)
+         token_count=1200
+         temperature=1.0
+         top_p=0.7
+         presencePenalty = 0.1
+         countPenalty = 0.4
+         token_ban=[]
+         stop_token=[0]
+         ctx=pret
+         prompt=prut
+         all_tokens = []
+         out_last = 0
+         out_str = ''
+         occurrence = {}
+         state = None
+         ctx=generate_prompt(ctx,prompt)
+         prompt_tokens = tokenizer_encode(ctx)
+         prompt_token_count = len(prompt_tokens)
+         init_logits, init_state = modal.eval_sequence_in_chunks(prompt_tokens, None, None, None, use_numpy=True)
+         logits, state = init_logits.copy(), init_state.copy()
+         out_str = ''
+         occurrence = {}
+         bof=[]
+         for i in range(token_count):
+             for n in occurrence:
+                 logits[n] -= (presencePenalty + occurrence[n] * countPenalty)
+             token = sampling.sample_logits(logits, temperature, top_p)
+
+             if token in stop_token:
+                 break
+             all_tokens += [token]
+
+             for xxx in occurrence:
+                 occurrence[xxx] *= 0.996
+
+             if token not in occurrence:
+                 occurrence[token] = 1
+             else:
+                 occurrence[token] += 1
+
+             tmp = tokenizer_decode(all_tokens[out_last:])
+             if '\ufffd' not in tmp:
+                 out_str += tmp
+                 out_last = i + 1
+             ##yield out_str.strip()
+             logits, state = modal.eval(token, state, state, logits, use_numpy=True)
+         del state
+         gc.collect()
+         return out_str.strip()
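HfEngine acts as a drop-in llm_engine for the transformers agents: it takes a list of role/content dicts and returns the decoded completion as a plain string. A minimal sketch of a direct call, assuming the module-level RWKV model has finished loading (the message contents are illustrative):

    from llm_engine import HfEngine

    engine = HfEngine()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Name three uses of FAISS."},
    ]
    answer = engine(messages)  # cleans roles, builds the prompt, samples up to 1200 tokens
    print(answer)

Tool-call and tool-response messages are remapped to user turns via llama_role_conversions before the prompt is built.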
mdconvert (2).py ADDED
@@ -0,0 +1,659 @@
+ # ruff: noqa: E722
+ import json
+ import os
+ import requests
+ import re
+ import markdownify
+ import mimetypes
+ import html
+ import puremagic
+ import tempfile
+ import copy
+ import mammoth
+ import pptx
+ import pandas as pd
+ import traceback
+
+ from urllib.parse import urlparse, parse_qs
+ from bs4 import BeautifulSoup
+ from typing import Any, Dict, List, Optional, Union
+ import pdfminer
+ import pdfminer.high_level
+ from youtube_transcript_api import YouTubeTranscriptApi
+
+
+ class DocumentConverterResult:
+     """The result of converting a document to text."""
+
+     def __init__(self, title: Union[str, None] = None, text_content: str = ""):
+         self.title = title
+         self.text_content = text_content
+
+
+ class DocumentConverter:
+     def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
+         raise NotImplementedError()
+
+
+ class PlainTextConverter(DocumentConverter):
+     """Anything with content type text/plain"""
+
+     def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
+         extension = kwargs.get("file_extension", "")
+         if extension == "":
+             return None
+
+         content_type, encoding = mimetypes.guess_type("__placeholder" + extension)
+
+         text_content = ""
+         with open(local_path, "rt") as fh:
+             text_content = fh.read()
+
+         return DocumentConverterResult(
+             title=None,
+             text_content=text_content,
+         )
+
+
+ class HtmlConverter(DocumentConverter):
+     """Anything with content type text/html"""
+
+     def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
+         # Bail if not html
+         extension = kwargs.get("file_extension", "")
+         if extension.lower() not in [".html", ".htm"]:
+             return None
+
+         result = None
+         with open(local_path, "rt") as fh:
+             result = self._convert(fh.read())
+
+         return result
+
+     def _convert(self, html_content) -> Union[None, DocumentConverterResult]:
+         """Helper function that converts an HTML string."""
+
+         # Parse the string
+         soup = BeautifulSoup(html_content, "html.parser")
+
+         # Remove javascript and style blocks
+         for script in soup(["script", "style"]):
+             script.extract()
+
+         # Print only the main content
+         body_elm = soup.find("body")
+         webpage_text = ""
+         if body_elm:
+             webpage_text = markdownify.MarkdownConverter().convert_soup(body_elm)
+         else:
+             webpage_text = markdownify.MarkdownConverter().convert_soup(soup)
+
+         return DocumentConverterResult(
+             title=None if soup.title is None else soup.title.string,
+             text_content=webpage_text,
+         )
+
+
+ class WikipediaConverter(DocumentConverter):
+     """Handle Wikipedia pages separately, focusing only on the main document content."""
+
+     def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
+         # Bail if not Wikipedia
+         extension = kwargs.get("file_extension", "")
+         if extension.lower() not in [".html", ".htm"]:
+             return None
+         url = kwargs.get("url", "")
+         if not re.search(r"^https?:\/\/[a-zA-Z]{2,3}\.wikipedia.org\/", url):
+             return None
+
+         # Parse the file
+         soup = None
+         with open(local_path, "rt") as fh:
+             soup = BeautifulSoup(fh.read(), "html.parser")
+
+         # Remove javascript and style blocks
+         for script in soup(["script", "style"]):
+             script.extract()
+
+         # Print only the main content
+         body_elm = soup.find("div", {"id": "mw-content-text"})
+         title_elm = soup.find("span", {"class": "mw-page-title-main"})
+
+         webpage_text = ""
+         if body_elm:
+             # What's the title
+             main_title = soup.title.string
+             if title_elm and len(title_elm) > 0:
+                 main_title = title_elm.string
+
+             # Convert the page
+             webpage_text = "# " + main_title + "\n\n" + markdownify.MarkdownConverter().convert_soup(body_elm)
+         else:
+             webpage_text = markdownify.MarkdownConverter().convert_soup(soup)
+
+         return DocumentConverterResult(
+             title=soup.title.string,
+             text_content=webpage_text,
+         )
+
+
+ class YouTubeConverter(DocumentConverter):
+     """Handle YouTube specially, focusing on the video title, description, and transcript."""
+
+     def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
+         # Bail if not YouTube
+         extension = kwargs.get("file_extension", "")
+         if extension.lower() not in [".html", ".htm"]:
+             return None
+         url = kwargs.get("url", "")
+         if not url.startswith("https://www.youtube.com/watch?"):
+             return None
+
+         # Parse the file
+         soup = None
+         with open(local_path, "rt") as fh:
+             soup = BeautifulSoup(fh.read(), "html.parser")
+
+         # Read the meta tags
+         metadata = {"title": soup.title.string}
+         for meta in soup(["meta"]):
+             for a in meta.attrs:
+                 if a in ["itemprop", "property", "name"]:
+                     metadata[meta[a]] = meta.get("content", "")
+                     break
+
+         # We can also try to read the full description. This is more prone to breaking, since it reaches into the page implementation
+         try:
+             for script in soup(["script"]):
+                 content = script.text
+                 if "ytInitialData" in content:
+                     lines = re.split(r"\r?\n", content)
+                     obj_start = lines[0].find("{")
+                     obj_end = lines[0].rfind("}")
+                     if obj_start >= 0 and obj_end >= 0:
+                         data = json.loads(lines[0][obj_start : obj_end + 1])
+                         attrdesc = self._findKey(data, "attributedDescriptionBodyText")
+                         if attrdesc:
+                             metadata["description"] = attrdesc["content"]
+                     break
+         except:
+             pass
+
+         # Start preparing the page
+         webpage_text = "# YouTube\n"
+
+         title = self._get(metadata, ["title", "og:title", "name"])
+         if title:
+             webpage_text += f"\n## {title}\n"
+
+         stats = ""
+         views = self._get(metadata, ["interactionCount"])
+         if views:
+             stats += f"- **Views:** {views}\n"
+
+         keywords = self._get(metadata, ["keywords"])
+         if keywords:
+             stats += f"- **Keywords:** {keywords}\n"
+
+         runtime = self._get(metadata, ["duration"])
+         if runtime:
+             stats += f"- **Runtime:** {runtime}\n"
+
+         if len(stats) > 0:
+             webpage_text += f"\n### Video Metadata\n{stats}\n"
+
+         description = self._get(metadata, ["description", "og:description"])
+         if description:
+             webpage_text += f"\n### Description\n{description}\n"
+
+         transcript_text = ""
+         parsed_url = urlparse(url)
+         params = parse_qs(parsed_url.query)
+
+         video_id = params["v"][0]
+         # Must be a single transcript.
+         print("VIDDDD ID:", video_id)
+         transcript = YouTubeTranscriptApi.get_transcript(video_id)
+         transcript_text = " ".join([part["text"] for part in transcript])
+         # Alternative formatting:
+         # formatter = TextFormatter()
+         # formatter.format_transcript(transcript)
+         if transcript_text:
+             webpage_text += f"\n### Transcript\n{transcript_text}\n"
+
+         return DocumentConverterResult(
+             title=title if title else soup.title.string,
+             text_content=webpage_text,
+         )
+
+     def _get(self, json, keys, default=None):
+         for k in keys:
+             if k in json:
+                 return json[k]
+         return default
+
+     def _findKey(self, json, key):
+         if isinstance(json, list):
+             for elm in json:
+                 ret = self._findKey(elm, key)
+                 if ret is not None:
+                     return ret
+         elif isinstance(json, dict):
+             for k in json:
+                 if k == key:
+                     return json[k]
+                 else:
+                     ret = self._findKey(json[k], key)
+                     if ret is not None:
+                         return ret
+         return None
+
+
+ class PdfConverter(DocumentConverter):
+     def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
+         # Bail if not a PDF
+         extension = kwargs.get("file_extension", "")
+         if extension.lower() != ".pdf":
+             return None
+
+         return DocumentConverterResult(
+             title=None,
+             text_content=pdfminer.high_level.extract_text(local_path),
+         )
+
+ from huggingface_hub import InferenceClient
+ class AudioConverter(DocumentConverter):
+     def __init__(self):
+         super().__init__()
+         self.client = InferenceClient("distil-whisper/distil-large-v3")
+
+     def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
+         # Bail if not an audio file
+         extension = kwargs.get("file_extension", "")
+         if extension.lower() not in [".wav", ".mp3", ".flac", ".m4a"]:
+             return None
+         try:
+             result = self.client.automatic_speech_recognition(audio=local_path).text
+         except Exception as e:
+             print("Exception in decoding audio:", e)
+             from openai import OpenAI
+             oai_client = OpenAI()
+             from pathlib import Path
+             result = oai_client.audio.transcriptions.create(
+                 model="whisper-1",
+                 file=Path(local_path)
+             ).text
+
+         return DocumentConverterResult(
+             title=None,
+             text_content=result,
+         )
+
+
+ class DocxConverter(HtmlConverter):
+     def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
+         # Bail if not a DOCX
+         extension = kwargs.get("file_extension", "")
+         if extension.lower() != ".docx":
+             return None
+
+         result = None
+         with open(local_path, "rb") as docx_file:
+             result = mammoth.convert_to_html(docx_file)
+             html_content = result.value
+             result = self._convert(html_content)
+
+         return result
+
+
+ class XlsxConverter(HtmlConverter):
+     def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
+         # Bail if not a XLSX
+         extension = kwargs.get("file_extension", "")
+
+         if extension.lower() not in [".xlsx", ".xls"]:
+             return None
+
+         sheets = pd.read_excel(local_path, sheet_name=None)
+         md_content = ""
+         for s in sheets:
+             md_content += f"## {s}\n"
+             html_content = sheets[s].to_html(index=False)
+             md_content += self._convert(html_content).text_content.strip() + "\n\n"
+
+         return DocumentConverterResult(
+             title=None,
+             text_content=md_content.strip(),
+         )
+
+
+ import xml.etree.ElementTree as ET
+ class XmlConverter(DocumentConverter):
+     def convert(self, local_path, **kwargs) -> None | DocumentConverterResult:
+         # Parse the XML string
+         extension = kwargs.get("file_extension", "")
+
+         if extension.lower() not in [".xml"]:
+             return None
+
+         xml_string = ""
+         with open(local_path, "rt") as fh:
+             xml_string = fh.read()
+
+         def extract_table_from_html_like(xml_root):
+             table = xml_root.find('.//table')
+             if table is None:
+                 raise ValueError("No table found in the XML")
+
+             headers = [th.text for th in table.find('thead').findall('th')]
+             rows = [[td.text for td in tr.findall('td')] for tr in table.find('tbody').findall('tr')]
+
+             # Create markdown table
+             markdown = '| ' + ' | '.join(headers) + ' |\n'
+             markdown += '| ' + ' | '.join(['---'] * len(headers)) + ' |\n'
+             for row in rows:
+                 markdown += '| ' + ' | '.join(row) + ' |\n'
+             return markdown  # fix: the original built the table but never returned it
+
+         def extract_table_from_wordml(xml_root, namespaces):
+             # Parse the XML content
+             root = xml_root
+             namespace = {'w': 'http://schemas.microsoft.com/office/word/2003/wordml'}
+
+             # Extract text content
+             body = root.find('w:body', namespace)
+             paragraphs = body.findall('.//w:p', namespace)
+             text_content = []
+             for para in paragraphs:
+                 texts = para.findall('.//w:t', namespace)
+                 for text in texts:
+                     text_content.append(text.text)
+
+             return '\n'.join(text_content)
+
+         # Parse the XML string
+         root = ET.fromstring(xml_string)
+         namespaces = {'w': 'http://schemas.microsoft.com/office/word/2003/wordml'}
+
+         if root.tag.endswith('wordDocument'):
+             markdown = extract_table_from_wordml(root, namespaces)
+         else:
+             markdown = extract_table_from_html_like(root)
+
+         return DocumentConverterResult(
+             title=None,
+             text_content=markdown.strip(),
+         )
+
+ class PptxConverter(HtmlConverter):
+     def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
+         # Bail if not a PPTX
+         extension = kwargs.get("file_extension", "")
+         if extension.lower() != ".pptx":
+             return None
+
+         md_content = ""
+
+         presentation = pptx.Presentation(local_path)
+         slide_num = 0
+         for slide in presentation.slides:
+             slide_num += 1
+
+             md_content += f"\n\n<!-- Slide number: {slide_num} -->\n"
+
+             title = slide.shapes.title
+             for shape in slide.shapes:
+                 # Pictures
+                 if self._is_picture(shape):
+                     # https://github.com/scanny/python-pptx/pull/512#issuecomment-1713100069
+                     alt_text = ""
+                     try:
+                         alt_text = shape._element._nvXxPr.cNvPr.attrib.get("descr", "")
+                     except:
+                         pass
+
+                     # A placeholder name
+                     filename = re.sub(r"\W", "", shape.name) + ".jpg"
+                     # try:
+                     #     filename = shape.image.filename
+                     # except:
+                     #     pass
+
+                     md_content += "\n![" + (alt_text if alt_text else shape.name) + "](" + filename + ")\n"
+
+                 # Tables
+                 if self._is_table(shape):
+                     html_table = "<html><body><table>"
+                     first_row = True
+                     for row in shape.table.rows:
+                         html_table += "<tr>"
+                         for cell in row.cells:
+                             if first_row:
+                                 html_table += "<th>" + html.escape(cell.text) + "</th>"
+                             else:
+                                 html_table += "<td>" + html.escape(cell.text) + "</td>"
+                         html_table += "</tr>"
+                         first_row = False
+                     html_table += "</table></body></html>"
+                     md_content += "\n" + self._convert(html_table).text_content.strip() + "\n"
+
+                 # Text areas
+                 elif shape.has_text_frame:
+                     if shape == title:
+                         md_content += "# " + shape.text.lstrip() + " "
+                     else:
+                         md_content += shape.text + " "
+
+             md_content = md_content.strip()
+
+             if slide.has_notes_slide:
+                 md_content += "\n\n### Notes:\n"
+                 notes_frame = slide.notes_slide.notes_text_frame
+                 if notes_frame is not None:
+                     md_content += notes_frame.text
+                 md_content = md_content.strip()
+
+         return DocumentConverterResult(
+             title=None,
+             text_content=md_content.strip(),
+         )
+
+     def _is_picture(self, shape):
+         if shape.shape_type == pptx.enum.shapes.MSO_SHAPE_TYPE.PICTURE:
+             return True
+         if shape.shape_type == pptx.enum.shapes.MSO_SHAPE_TYPE.PLACEHOLDER:
+             if hasattr(shape, "image"):
+                 return True
+         return False
+
+     def _is_table(self, shape):
+         if shape.shape_type == pptx.enum.shapes.MSO_SHAPE_TYPE.TABLE:
+             return True
+         return False
+
+ class FileConversionException(Exception):
+     pass
+
+ class UnsupportedFormatException(Exception):
+     pass
+
+ class MarkdownConverter:
+     """(In preview) An extremely simple text-based document reader, suitable for LLM use.
+     This reader will convert common file-types or webpages to Markdown."""
+
+     def __init__(
+         self,
+         requests_session: Optional[requests.Session] = None,
+     ):
+         if requests_session is None:
+             self._requests_session = requests.Session()
+         else:
+             self._requests_session = requests_session
+
+         self._page_converters: List[DocumentConverter] = []
+
+         # Register converters for successful browsing operations
+         # Later registrations are tried first / take higher priority than earlier registrations
+         # To this end, the most specific converters should appear below the most generic converters
+         self.register_page_converter(WikipediaConverter())
+         self.register_page_converter(XmlConverter())
+         self.register_page_converter(YouTubeConverter())
+         self.register_page_converter(DocxConverter())
+         self.register_page_converter(XlsxConverter())
+         self.register_page_converter(PptxConverter())
+         # self.register_page_converter(ImageConverter())
+         self.register_page_converter(PdfConverter())
+         self.register_page_converter(AudioConverter())
+         self.register_page_converter(HtmlConverter())
+         self.register_page_converter(PlainTextConverter())
+
+     def convert(self, source, **kwargs):
+         """
+         Args:
+             - source: can be a string representing a path or url, or a requests.response object
+             - extension: specifies the file extension to use when interpreting the file. If None, infer from source (path, uri, content-type, etc.)
+         """
+
+         # Local path or url
+         if isinstance(source, str):
+             if source.startswith("http://") or source.startswith("https://") or source.startswith("file://"):
+                 return self.convert_url(source, **kwargs)
+             else:
+                 return self.convert_local(source, **kwargs)
+         # Request response
+         elif isinstance(source, requests.Response):
+             return self.convert_response(source, **kwargs)
+
+     def convert_local(self, path, **kwargs):
+         # Prepare a list of extensions to try (in order of priority)
+         ext = kwargs.get("file_extension")
+         extensions = [ext] if ext is not None else []
+
+         # Get extension alternatives from the path and puremagic
+         base, ext = os.path.splitext(path)
+         self._append_ext(extensions, ext)
+         self._append_ext(extensions, self._guess_ext_magic(path))
+
+         # Convert
+         return self._convert(path, extensions, **kwargs)
+
+     def convert_url(self, url, **kwargs):
+         # Send a HTTP request to the URL
+         user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
+         response = self._requests_session.get(url, stream=True, headers={"User-Agent": user_agent})
+         response.raise_for_status()
+         return self.convert_response(response, **kwargs)
+
+     def convert_response(self, response, **kwargs):
+         # Prepare a list of extensions to try (in order of priority)
+         ext = kwargs.get("file_extension")
+         extensions = [ext] if ext is not None else []
+
+         # Guess from the mimetype
+         content_type = response.headers.get("content-type", "").split(";")[0]
+         self._append_ext(extensions, mimetypes.guess_extension(content_type))
+
+         # Read the content disposition if there is one
+         content_disposition = response.headers.get("content-disposition", "")
+         m = re.search(r"filename=([^;]+)", content_disposition)
+         if m:
+             base, ext = os.path.splitext(m.group(1).strip("\"'"))
+             self._append_ext(extensions, ext)
+
+         # Read the extension from the path
+         base, ext = os.path.splitext(urlparse(response.url).path)
+         self._append_ext(extensions, ext)
+
+         # Save the file locally to a temporary file. It will be deleted before this method exits
+         handle, temp_path = tempfile.mkstemp()
+         fh = os.fdopen(handle, "wb")
+         result = None
+         try:
+             # Download the file
+             for chunk in response.iter_content(chunk_size=512):
+                 fh.write(chunk)
+             fh.close()
+
+             # Use puremagic to check for more extension options
+             self._append_ext(extensions, self._guess_ext_magic(temp_path))
+
+             # Convert
+             result = self._convert(temp_path, extensions, url=response.url)
+         except Exception as e:
+             print(f"Error in converting: {e}")
+
+         # Clean up
+         finally:
+             try:
+                 fh.close()
+             except:
+                 pass
+             os.unlink(temp_path)
+
+         return result
+
+     def _convert(self, local_path, extensions, **kwargs):
+         error_trace = ""
+         for ext in extensions:
+             for converter in self._page_converters:
+                 _kwargs = copy.deepcopy(kwargs)
+                 _kwargs.update({"file_extension": ext})
+                 # If we hit an error log it and keep trying
+                 try:
+                     res = converter.convert(local_path, **_kwargs)
+                     if res is not None:
+                         # Normalize the content
+                         res.text_content = "\n".join([line.rstrip() for line in re.split(r"\r?\n", res.text_content)])
+                         res.text_content = re.sub(r"\n{3,}", "\n\n", res.text_content)
+
+                         # Todo
+                         return res
+                 except Exception as e:
+                     error_trace = ("\n\n" + traceback.format_exc()).strip()
+
+         # If we got this far without success, report any exceptions
+         if len(error_trace) > 0:
+             raise FileConversionException(
+                 f"Could not convert '{local_path}' to Markdown. File type was recognized as {extensions}. While converting the file, the following error was encountered:\n\n{error_trace}"
+             )
+
+         # Nothing can handle it!
+         # raise UnsupportedFormatException(
+         #     f"Could not convert '{local_path}' to Markdown. The formats {extensions} are not supported."
+         # )
+         res = PlainTextConverter().convert(local_path, **kwargs)
+         return res
+
+     def _append_ext(self, extensions, ext):
+         """Append a unique non-None, non-empty extension to a list of extensions."""
+         if ext is None:
+             return
+         ext = ext.strip()
+         if ext == "":
+             return
+         # if ext not in extensions:
+         if True:
+             extensions.append(ext)
+
+     def _guess_ext_magic(self, path):
+         """Use puremagic (a Python implementation of libmagic) to guess a file's extension based on the first few bytes."""
+         # Use puremagic to guess
+         try:
+             guesses = puremagic.magic_file(path)
+             if len(guesses) > 0:
+                 ext = guesses[0].extension.strip()
+                 if len(ext) > 0:
+                     return ext
+         except FileNotFoundError:
+             pass
+         except IsADirectoryError:
+             pass
+         except PermissionError:
+             pass
+         return None
+
+     def register_page_converter(self, converter: DocumentConverter) -> None:
+         """Register a page text converter."""
+         self._page_converters.append(converter)
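MarkdownConverter.convert accepts a local path, a URL, or a requests.Response; for each guessed extension it walks the registered converters until one returns a result, falling back to plain text. A minimal sketch of the public entry point (the URL is illustrative):

    from mdconvert import MarkdownConverter

    mdc = MarkdownConverter()
    result = mdc.convert("https://en.wikipedia.org/wiki/Markdown")  # illustrative URL
    if result is not None:
        print(result.title)
        print(result.text_content[:300])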
requirements (93).txt ADDED
@@ -0,0 +1,32 @@
+ transformers==4.43.0
+ torch
+ gradio
+ huggingface_hub
+ beautifulsoup4
+ requests
+ gradio_tools
+ accelerate
+ langchain
+ sentence-transformers
+ faiss-cpu
+ langchain_community
+ langchain-huggingface
+ pypdf
+ markdownify
+ urllib3
+ pathvalidate
+ pdfminer.six
+ pdfminer
+ mammoth
+ python-pptx
+ pandas
+ puremagic
+ youtube_transcript_api
+ google-search-results
+ duckduckgo_search
+ cmake
+ numpy
+ pynvml
+ argparse
+ typing
+ tqdm
rwkv_cpp_model (2).py ADDED
@@ -0,0 +1,388 @@
+ import os
+ import multiprocessing
+
+ # Pre-import PyTorch, if available.
+ # This fixes "OSError: [WinError 127] The specified procedure could not be found".
+ try:
+     import torch
+ except ModuleNotFoundError:
+     pass
+
+ # I'm sure this is not strictly correct, but let's keep this crutch for now.
+ try:
+     import rwkv_cpp_shared_library
+ except ModuleNotFoundError:
+     from . import rwkv_cpp_shared_library
+
+ from typing import TypeVar, Optional, Tuple, List
+
+ # A value of this type is either a numpy's ndarray or a PyTorch's Tensor.
+ NumpyArrayOrPyTorchTensor: TypeVar = TypeVar('NumpyArrayOrPyTorchTensor')
+
+ class RWKVModel:
+     """
+     An RWKV model managed by rwkv.cpp library.
+     """
+
+     def __init__(
+         self,
+         shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary,
+         model_path: str,
+         thread_count: int = max(1, multiprocessing.cpu_count() // 2),
+         gpu_layer_count: int = 0,
+         **kwargs
+     ) -> None:
+         """
+         Loads the model and prepares it for inference.
+         In case of any error, this method will throw an exception.
+
+         Parameters
+         ----------
+         shared_library : RWKVSharedLibrary
+             rwkv.cpp shared library.
+         model_path : str
+             Path to RWKV model file in ggml format.
+         thread_count : int
+             Thread count to use. If not set, defaults to CPU count / 2.
+         gpu_layer_count : int
+             Count of layers to offload onto the GPU, must be >= 0.
+             See documentation of `gpu_offload_layers` for details about layer offloading.
+         """
+
+         if 'gpu_layers_count' in kwargs:
+             gpu_layer_count = kwargs['gpu_layers_count']
+
+         if not os.path.isfile(model_path):
+             raise ValueError(f'{model_path} is not a file')
+
+         if not (thread_count > 0):
+             raise ValueError('Thread count must be > 0')
+
+         if not (gpu_layer_count >= 0):
+             raise ValueError('GPU layer count must be >= 0')
+
+         self._library: rwkv_cpp_shared_library.RWKVSharedLibrary = shared_library
+
+         self._ctx: rwkv_cpp_shared_library.RWKVContext = self._library.rwkv_init_from_file(model_path, thread_count)
+
+         if gpu_layer_count > 0:
+             self.gpu_offload_layers(gpu_layer_count)
+
+         self._state_buffer_element_count: int = self._library.rwkv_get_state_buffer_element_count(self._ctx)
+         self._logits_buffer_element_count: int = self._library.rwkv_get_logits_buffer_element_count(self._ctx)
+
+         self._valid: bool = True
+
+     def gpu_offload_layers(self, layer_count: int) -> bool:
+         """
+         Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast.
+         For the purposes of this function, model head (unembedding matrix) is treated as an additional layer:
+         - pass `model.n_layer` to offload all layers except model head
+         - pass `model.n_layer + 1` to offload all layers, including model head
+
+         Returns true if at least one layer was offloaded.
+         If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false.
+
+         Parameters
+         ----------
+         layer_count : int
+             Count of layers to offload onto the GPU, must be >= 0.
+         """
+
+         if not (layer_count >= 0):
+             raise ValueError('Layer count must be >= 0')
+
+         return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count)
+
+     @property
+     def n_vocab(self) -> int:
+         return self._library.rwkv_get_n_vocab(self._ctx)
+
+     @property
+     def n_embed(self) -> int:
+         return self._library.rwkv_get_n_embed(self._ctx)
+
+     @property
+     def n_layer(self) -> int:
+         return self._library.rwkv_get_n_layer(self._ctx)
+
+     def eval(
+         self,
+         token: int,
+         state_in: Optional[NumpyArrayOrPyTorchTensor],
+         state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
+         logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
+         use_numpy: bool = False
+     ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
+         """
+         Evaluates the model for a single token.
+         In case of any error, this method will throw an exception.
+
+         Parameters
+         ----------
+         token : int
+             Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
+         state_in : Optional[NumpyArrayOrTorchTensor]
+             State from previous call of this method. If this is a first pass, set it to None.
+         state_out : Optional[NumpyArrayOrTorchTensor]
+             Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
+         logits_out : Optional[NumpyArrayOrTorchTensor]
+             Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
+         use_numpy : bool
+             If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
+             This parameter is ignored if any tensor parameter is not None; in such case,
+             type of returned tensors will match the type of received tensors.
+
+         Returns
+         -------
+         logits, state
+             Logits vector of shape (n_vocab); state for the next step.
+         """
+
+         if not self._valid:
+             raise ValueError('Model was freed')
+
+         use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
+
+         if state_in is not None:
+             self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
+
+             state_in_ptr = self._get_data_ptr(state_in)
+         else:
+             state_in_ptr = 0
+
+         if state_out is not None:
+             self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
+         else:
+             state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
+
+         if logits_out is not None:
+             self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
+         else:
+             logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
+
+         self._library.rwkv_eval(
+             self._ctx,
+             token,
+             state_in_ptr,
+             self._get_data_ptr(state_out),
+             self._get_data_ptr(logits_out)
+         )
+
+         return logits_out, state_out
+
+     def eval_sequence(
+         self,
+         tokens: List[int],
+         state_in: Optional[NumpyArrayOrPyTorchTensor],
+         state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
+         logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
+         use_numpy: bool = False
+     ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
+         """
+         Evaluates the model for a sequence of tokens.
+
+         NOTE ON GGML NODE LIMIT
+
+         ggml has a hard-coded limit on max amount of nodes in a computation graph. The sequence graph is built in a way that quickly exceeds
+         this limit when using large models and/or large sequence lengths.
+         Fortunately, rwkv.cpp's fork of ggml has increased limit which was tested to work for sequence lengths up to 64 for 14B models.
+
+         If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
+         To get rid of the assertion failure, reduce the model size and/or sequence length.
+
+         In case of any error, this method will throw an exception.
+
+         Parameters
+         ----------
+         tokens : List[int]
+             Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab.
+         state_in : Optional[NumpyArrayOrTorchTensor]
+             State from previous call of this method. If this is a first pass, set it to None.
+         state_out : Optional[NumpyArrayOrTorchTensor]
+             Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
+         logits_out : Optional[NumpyArrayOrTorchTensor]
+             Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
+         use_numpy : bool
+             If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
+             This parameter is ignored if any tensor parameter is not None; in such case,
+             type of returned tensors will match the type of received tensors.
+
+         Returns
+         -------
+         logits, state
+             Logits vector of shape (n_vocab); state for the next step.
+         """
+
+         if not self._valid:
+             raise ValueError('Model was freed')
+
+         use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
+
+         if state_in is not None:
+             self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
+
+             state_in_ptr = self._get_data_ptr(state_in)
+         else:
+             state_in_ptr = 0
+
+         if state_out is not None:
+             self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
+         else:
+             state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
+
+         if logits_out is not None:
+             self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
+         else:
+             logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
+
+         self._library.rwkv_eval_sequence(
+             self._ctx,
+             tokens,
+             state_in_ptr,
+             self._get_data_ptr(state_out),
+             self._get_data_ptr(logits_out)
+         )
+
+         return logits_out, state_out
+
+     def eval_sequence_in_chunks(
+         self,
+         tokens: List[int],
+         state_in: Optional[NumpyArrayOrPyTorchTensor],
+         state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
+         logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
+         chunk_size: int = 16,
+         use_numpy: bool = False
+     ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
+         """
+         Evaluates the model for a sequence of tokens using `eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
+         This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
+         It is recommended to use this function instead of `eval_sequence` to avoid mistakes and get maximum performance.
+
+         Chunking allows processing sequences of thousands of tokens, while not reaching the ggml's node limit and not consuming too much memory.
+         A reasonable and recommended value of chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64]
+         and choose one that works the best in your use case.
+
+         In case of any error, this method will throw an exception.
+
+         Parameters
+         ----------
+         tokens : List[int]
+             Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab.
+         chunk_size : int
+             Size of each chunk in tokens, must be positive.
+         state_in : Optional[NumpyArrayOrTorchTensor]
+             State from previous call of this method. If this is a first pass, set it to None.
+         state_out : Optional[NumpyArrayOrTorchTensor]
+             Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
+         logits_out : Optional[NumpyArrayOrTorchTensor]
+             Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
+         use_numpy : bool
+             If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
+             This parameter is ignored if any tensor parameter is not None; in such case,
+             type of returned tensors will match the type of received tensors.
+
+         Returns
+         -------
+         logits, state
+             Logits vector of shape (n_vocab); state for the next step.
+         """
+
+         if not self._valid:
+             raise ValueError('Model was freed')
+
+         use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
+
+         if state_in is not None:
+             self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
+
+             state_in_ptr = self._get_data_ptr(state_in)
+         else:
+             state_in_ptr = 0
+
+         if state_out is not None:
+             self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
+         else:
+             state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
+
+         if logits_out is not None:
+             self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
+         else:
+             logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
+
+         self._library.rwkv_eval_sequence_in_chunks(
+             self._ctx,
+             tokens,
+             chunk_size,
+             state_in_ptr,
+             self._get_data_ptr(state_out),
+             self._get_data_ptr(logits_out)
+         )
+
+         return logits_out, state_out
+
+     def free(self) -> None:
+         """
+         Frees all allocated resources.
+         In case of any error, this method will throw an exception.
+         The object must not be used anymore after calling this method.
+         """
+
+         if not self._valid:
+             raise ValueError('Already freed')
+
+         self._valid = False
+
+         self._library.rwkv_free(self._ctx)
+
+     def __del__(self) -> None:
+         # Free the context on GC in case user forgot to call free() explicitly.
+         if hasattr(self, '_valid') and self._valid:
+             self.free()
+
+     def _is_pytorch_tensor(self, tensor: NumpyArrayOrPyTorchTensor) -> bool:
+         return hasattr(tensor, '__module__') and tensor.__module__ == 'torch'
+
+     def _detect_numpy_usage(self, tensors: List[Optional[NumpyArrayOrPyTorchTensor]], use_numpy_by_default: bool) -> bool:
+         for tensor in tensors:
+             if tensor is not None:
+                 return False if self._is_pytorch_tensor(tensor) else True
+
+         return use_numpy_by_default
+
+     def _validate_tensor(self, tensor: NumpyArrayOrPyTorchTensor, name: str, size: int) -> None:
+         if self._is_pytorch_tensor(tensor):
+             tensor: torch.Tensor = tensor
+
+             if tensor.device != torch.device('cpu'):
+                 raise ValueError(f'{name} is not on CPU')
+             if tensor.dtype != torch.float32:
+                 raise ValueError(f'{name} is not of type float32')
+             if tensor.shape != (size,):
+                 raise ValueError(f'{name} has invalid shape {tensor.shape}, expected ({size})')
+             if not tensor.is_contiguous():
+                 raise ValueError(f'{name} is not contiguous')
+         else:
+             import numpy as np
+             tensor: np.ndarray = tensor
+
+             if tensor.dtype != np.float32:
+                 raise ValueError(f'{name} is not of type float32')
+             if tensor.shape != (size,):
+                 raise ValueError(f'{name} has invalid shape {tensor.shape}, expected ({size})')
+             if not tensor.data.contiguous:
+                 raise ValueError(f'{name} is not contiguous')
+
+     def _get_data_ptr(self, tensor: NumpyArrayOrPyTorchTensor):
+         if self._is_pytorch_tensor(tensor):
+             return tensor.data_ptr()
+         else:
+             return tensor.ctypes.data
+
+     def _zeros_float32(self, element_count: int, use_numpy: bool) -> NumpyArrayOrPyTorchTensor:
+         if use_numpy:
+             import numpy as np
+             return np.zeros(element_count, dtype=np.float32)
+         else:
+             return torch.zeros(element_count, dtype=torch.float32, device='cpu')
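Taken together with llm_engine above, the intended calling pattern is: prime the state with eval_sequence_in_chunks, then step one token at a time with eval. A minimal greedy-decoding sketch, assuming a loaded shared library, a ggml model file, and an encoder mapping text to vocabulary ids (the paths, file names, and tokenizer_encode are illustrative assumptions):

    import rwkv_cpp_model
    import rwkv_cpp_shared_library

    library = rwkv_cpp_shared_library.RWKVSharedLibrary('librwkv.so')   # illustrative path
    model = rwkv_cpp_model.RWKVModel(library, 'rwkv-model.Q4_0.bin')    # illustrative file

    prompt_tokens = tokenizer_encode('Hello')  # assumed encoder returning List[int], as in llm_engine
    logits, state = model.eval_sequence_in_chunks(prompt_tokens, None, use_numpy=True)

    tokens = []
    for _ in range(64):
        token = int(logits.argmax())  # greedy pick; sampling.sample_logits works the same way
        tokens.append(token)
        logits, state = model.eval(token, state, state, logits, use_numpy=True)

    model.free()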
rwkv_cpp_shared_library (2).py ADDED
@@ -0,0 +1,450 @@
+ import os
+ import sys
+ import ctypes
+ import pathlib
+ import platform
+ from typing import Optional, List, Tuple, Callable
+
+ QUANTIZED_FORMAT_NAMES: Tuple[str, str, str, str, str] = (
+     'Q4_0',
+     'Q4_1',
+     'Q5_0',
+     'Q5_1',
+     'Q8_0'
+ )
+
+ P_FLOAT = ctypes.POINTER(ctypes.c_float)
+ P_INT = ctypes.POINTER(ctypes.c_int32)
+
+ class RWKVContext:
+
+     def __init__(self, ptr: ctypes.pointer) -> None:
+         self.ptr: ctypes.pointer = ptr
+
+ class RWKVSharedLibrary:
+     """
+     Python wrapper around rwkv.cpp shared library.
+     """
+
+     def __init__(self, shared_library_path: str) -> None:
+         """
+         Loads the shared library from specified file.
+         In case of any error, this method will throw an exception.
+
+         Parameters
+         ----------
+         shared_library_path : str
+             Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'.
+         """
+         # When Python is greater than 3.8, we need to reprocess the custom dll
+         # according to the documentation to prevent loading failure errors.
+         # https://docs.python.org/3/whatsnew/3.8.html#ctypes
+         if platform.system().lower() == 'windows':
+             self.library = ctypes.CDLL(shared_library_path, winmode=0)
+         else:
+             self.library = ctypes.cdll.LoadLibrary(shared_library_path)
+
+         self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
+         self.library.rwkv_init_from_file.restype = ctypes.c_void_p
+
+         self.library.rwkv_gpu_offload_layers.argtypes = [ctypes.c_void_p, ctypes.c_uint32]
+         self.library.rwkv_gpu_offload_layers.restype = ctypes.c_bool
+
+         self.library.rwkv_eval.argtypes = [
+             ctypes.c_void_p,  # ctx
+             ctypes.c_int32,   # token
+             P_FLOAT,          # state_in
+             P_FLOAT,          # state_out
+             P_FLOAT           # logits_out
+         ]
+         self.library.rwkv_eval.restype = ctypes.c_bool
+
+         self.library.rwkv_eval_sequence.argtypes = [
+             ctypes.c_void_p,  # ctx
+             P_INT,            # tokens
+             ctypes.c_size_t,  # token count
+             P_FLOAT,          # state_in
+             P_FLOAT,          # state_out
+             P_FLOAT           # logits_out
+         ]
+         self.library.rwkv_eval_sequence.restype = ctypes.c_bool
+
+         self.library.rwkv_eval_sequence_in_chunks.argtypes = [
+             ctypes.c_void_p,  # ctx
+             P_INT,            # tokens
+             ctypes.c_size_t,  # token count
+             ctypes.c_size_t,  # chunk size
+             P_FLOAT,          # state_in
+             P_FLOAT,          # state_out
+             P_FLOAT           # logits_out
+         ]
+         self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool
+
+         self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p]
+         self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t
+
+         self.library.rwkv_get_n_embed.argtypes = [ctypes.c_void_p]
+         self.library.rwkv_get_n_embed.restype = ctypes.c_size_t
+
+         self.library.rwkv_get_n_layer.argtypes = [ctypes.c_void_p]
+         self.library.rwkv_get_n_layer.restype = ctypes.c_size_t
+
+         self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
+         self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_uint32
+
+         self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
+         self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_uint32
+
+         self.library.rwkv_free.argtypes = [ctypes.c_void_p]
+         self.library.rwkv_free.restype = None
+
+         self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p]
+         self.library.rwkv_quantize_model_file.restype = ctypes.c_bool
+
+         self.library.rwkv_get_system_info_string.argtypes = []
+         self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p
+
+         self.nullptr = ctypes.cast(0, ctypes.c_void_p)
+
+     def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
+         """
+         Loads the model from a file and prepares it for inference.
+         Throws an exception in case of any error. Error messages would be printed to stderr.
+
+         Parameters
+         ----------
+         model_file_path : str
+             Path to model file in ggml format.
+         thread_count : int
+             Count of threads to use, must be positive.
+         """
+
+         ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))
+
+         if ptr is None:
+             raise ValueError('rwkv_init_from_file failed, check stderr')
+
+         return RWKVContext(ptr)
+
+     def rwkv_gpu_offload_layers(self, ctx: RWKVContext, layer_count: int) -> bool:
+         """
+         Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast.
+         For the purposes of this function, model head (unembedding matrix) is treated as an additional layer:
+         - pass `rwkv_get_n_layer(ctx)` to offload all layers except model head
+         - pass `rwkv_get_n_layer(ctx) + 1` to offload all layers, including model head
+         Returns true if at least one layer was offloaded.
+         If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false.
+
+         Parameters
+         ----------
+         ctx : RWKVContext
+             RWKV context obtained from rwkv_init_from_file.
+         layer_count : int
+             Count of layers to offload onto the GPU, must be >= 0.
+         """
+
+         if not (layer_count >= 0):
+             raise ValueError('Layer count must be >= 0')
+
+         return self.library.rwkv_gpu_offload_layers(ctx.ptr, ctypes.c_uint32(layer_count))
+
+     def rwkv_eval(
+         self,
+         ctx: RWKVContext,
+         token: int,
+         state_in_address: Optional[int],
+         state_out_address: int,
+         logits_out_address: int
+     ) -> None:
+         """
+         Evaluates the model for a single token.
+         Throws an exception in case of any error. Error messages would be printed to stderr.
+         Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
+
+         Parameters
+         ----------
+         ctx : RWKVContext
+             RWKV context obtained from rwkv_init_from_file.
+         token : int
+             Next token index, in range 0 <= token < n_vocab.
+         state_in_address : int
+             Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
+         state_out_address : int
+             Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
+         logits_out_address : int
+             Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
+         """
+
+         if not self.library.rwkv_eval(
+             ctx.ptr,
+             ctypes.c_int32(token),
+             ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
+             ctypes.cast(state_out_address, P_FLOAT),
+             ctypes.cast(logits_out_address, P_FLOAT)
+         ):
+             raise ValueError('rwkv_eval failed, check stderr')
+
+     def rwkv_eval_sequence(
+         self,
+         ctx: RWKVContext,
+         tokens: List[int],
+         state_in_address: Optional[int],
+         state_out_address: int,
+         logits_out_address: int
+     ) -> None:
+         """
+         Evaluates the model for a sequence of tokens.
+         Uses a faster algorithm than `rwkv_eval` if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so.
+         Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
+
+         NOTE ON GGML NODE LIMIT
+
+         ggml has a hard-coded limit on max amount of nodes in a computation graph. The sequence graph is built in a way that quickly exceeds
+         this limit when using large models and/or large sequence lengths.
+         Fortunately, rwkv.cpp's fork of ggml has increased limit which was tested to work for sequence lengths up to 64 for 14B models.
+
+         If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
+         To get rid of the assertion failure, reduce the model size and/or sequence length.
+
+         Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
+         Throws an exception in case of any error. Error messages would be printed to stderr.
+
+         Parameters
+         ----------
+         ctx : RWKVContext
+             RWKV context obtained from rwkv_init_from_file.
+         tokens : List[int]
+             Next token indices, in range 0 <= token < n_vocab.
+         state_in_address : int
+             Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
+         state_out_address : int
+             Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
+         logits_out_address : int
+             Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
+         """
+
+         if not self.library.rwkv_eval_sequence(
+             ctx.ptr,
+             ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
+             ctypes.c_size_t(len(tokens)),
+             ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
+             ctypes.cast(state_out_address, P_FLOAT),
+             ctypes.cast(logits_out_address, P_FLOAT)
+         ):
+             raise ValueError('rwkv_eval_sequence failed, check stderr')
+
+     def rwkv_eval_sequence_in_chunks(
+         self,
+         ctx: RWKVContext,
+         tokens: List[int],
+         chunk_size: int,
+         state_in_address: Optional[int],
+         state_out_address: int,
+         logits_out_address: int
+     ) -> None:
+         """
+         Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
+         This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
+         It is recommended to use this function instead of `rwkv_eval_sequence` to avoid mistakes and get maximum performance.
+
+         Chunking allows processing sequences of thousands of tokens, while not reaching the ggml's node limit and not consuming too much memory.
+         A reasonable and recommended value of chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64]
+         and choose one that works the best in your use case.
+
+         Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
+         Throws an exception in case of any error. Error messages would be printed to stderr.
+
+         Parameters
+         ----------
+         ctx : RWKVContext
+             RWKV context obtained from rwkv_init_from_file.
+         tokens : List[int]
+             Next token indices, in range 0 <= token < n_vocab.
+         chunk_size : int
+             Size of each chunk in tokens, must be positive.
+         state_in_address : int
+             Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
+         state_out_address : int
+             Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
+         logits_out_address : int
+             Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
+         """
+
+         if not self.library.rwkv_eval_sequence_in_chunks(
+             ctx.ptr,
+             ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
+             ctypes.c_size_t(len(tokens)),
+             ctypes.c_size_t(chunk_size),
+             ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
+             ctypes.cast(state_out_address, P_FLOAT),
+             ctypes.cast(logits_out_address, P_FLOAT)
+         ):
+             raise ValueError('rwkv_eval_sequence_in_chunks failed, check stderr')
+
+     def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int:
+         """
+         Returns the number of tokens in the given model's vocabulary.
+         Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
+
+         Parameters
+         ----------
+         ctx : RWKVContext
+             RWKV context obtained from rwkv_init_from_file.
+         """
+
+         return self.library.rwkv_get_n_vocab(ctx.ptr)
+
+     def rwkv_get_n_embed(self, ctx: RWKVContext) -> int:
+         """
+         Returns the number of elements in the given model's embedding.
+         Useful for reading individual fields of a model's hidden state.
+
+         Parameters
+         ----------
+         ctx : RWKVContext
+             RWKV context obtained from rwkv_init_from_file.
+         """
+
+         return self.library.rwkv_get_n_embed(ctx.ptr)
+
+     def rwkv_get_n_layer(self, ctx: RWKVContext) -> int:
+         """
+         Returns the number of layers in the given model.
+         A layer is a pair of RWKV and FFN operations, stacked multiple times throughout the model.
+         Embedding matrix and model head (unembedding matrix) are NOT counted in `n_layer`.
+         Useful for always offloading the entire model to GPU.
+
+         Parameters
+         ----------
+         ctx : RWKVContext
+             RWKV context obtained from rwkv_init_from_file.
+         """
+
+         return self.library.rwkv_get_n_layer(ctx.ptr)
+
+     def rwkv_get_state_buffer_element_count(self, ctx: RWKVContext) -> int:
+         """
+         Returns count of FP32 elements in state buffer.
+
+         Parameters
+         ----------
+         ctx : RWKVContext
+             RWKV context obtained from rwkv_init_from_file.
+         """
+
+         return self.library.rwkv_get_state_buffer_element_count(ctx.ptr)
+
+     def rwkv_get_logits_buffer_element_count(self, ctx: RWKVContext) -> int:
+         """
+         Returns count of FP32 elements in logits buffer.
+
+         Parameters
+         ----------
+         ctx : RWKVContext
+             RWKV context obtained from rwkv_init_from_file.
+         """
+
+         return self.library.rwkv_get_logits_buffer_element_count(ctx.ptr)
+
+     def rwkv_free(self, ctx: RWKVContext) -> None:
+         """
+         Frees all allocated memory and the context.
+
+         Parameters
+         ----------
+         ctx : RWKVContext
+             RWKV context obtained from rwkv_init_from_file.
+         """
+
+         self.library.rwkv_free(ctx.ptr)
+
+         ctx.ptr = self.nullptr
+
+     def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, format_name: str) -> None:
+         """
+         Quantizes FP32 or FP16 model to one of INT4 formats.
+         Throws an exception in case of any error. Error messages would be printed to stderr.
+
+         Parameters
+         ----------
+         model_file_path_in : str
+             Path to model file in ggml format, must be either FP32 or FP16.
+         model_file_path_out : str
+             Quantized model will be written here.
+         format_name : str
+             One of QUANTIZED_FORMAT_NAMES.
+         """
+
+         if format_name not in QUANTIZED_FORMAT_NAMES:
+             raise ValueError(f'Unknown format name {format_name}, use one of {QUANTIZED_FORMAT_NAMES}')
383
+
384
+ if not self.library.rwkv_quantize_model_file(
385
+ model_file_path_in.encode('utf-8'),
386
+ model_file_path_out.encode('utf-8'),
387
+ format_name.encode('utf-8')
388
+ ):
389
+ raise ValueError('rwkv_quantize_model_file failed, check stderr')
390
+
391
+ def rwkv_get_system_info_string(self) -> str:
392
+ """
393
+ Returns system information string.
394
+ """
395
+
396
+ return self.library.rwkv_get_system_info_string().decode('utf-8')
397
+
398
+ def load_rwkv_shared_library() -> RWKVSharedLibrary:
399
+ """
400
+ Attempts to find rwkv.cpp shared library and load it.
401
+ To specify exact path to the library, create an instance of RWKVSharedLibrary explicitly.
402
+ """
403
+
404
+ file_name: str
405
+
406
+ if 'win32' in sys.platform or 'cygwin' in sys.platform:
407
+ file_name = 'rwkv.dll'
408
+ elif 'darwin' in sys.platform:
409
+ file_name = 'librwkv.dylib'
410
+ else:
411
+ file_name = 'librwkv.so'
412
+
413
+ # Possible sub-paths to the library relative to the repo dir.
414
+ child_paths: List[Callable[[pathlib.Path], pathlib.Path]] = [
415
+ # No lookup for Debug config here.
416
+ # I assume that if a user wants to debug the library,
417
+ # they will be able to find the library and set the exact path explicitly.
418
+ lambda p: p / 'bin' / 'Release' / file_name,
419
+ lambda p: p / 'bin' / file_name,
420
+ # Some people prefer to build in the "build" subdirectory.
421
+ lambda p: p / 'build' / 'bin' / 'Release' / file_name,
422
+ lambda p: p / 'build' / 'bin' / file_name,
423
+ lambda p: p / 'build' / file_name,
424
+ # Fallback.
425
+ lambda p: p / file_name
426
+ ]
427
+
428
+ working_dir: pathlib.Path = pathlib.Path(os.path.abspath(os.getcwd()))
429
+
430
+ parent_paths: List[pathlib.Path] = [
431
+ # Possible repo dirs relative to the working dir.
432
+ # ./python/rwkv_cpp
433
+ working_dir.parent.parent,
434
+ # ./python
435
+ working_dir.parent,
436
+ # .
437
+ working_dir,
438
+ # Repo dir relative to this Python file.
439
+ pathlib.Path(os.path.abspath(__file__)).parent.parent.parent
440
+ ]
441
+
442
+ for parent_path in parent_paths:
443
+ for child_path in child_paths:
444
+ full_path: pathlib.Path = child_path(parent_path)
445
+
446
+ if os.path.isfile(full_path):
447
+ return RWKVSharedLibrary(str(full_path))
448
+
449
+ raise ValueError(f'Failed to find {file_name} automatically; '
450
+ f'you need to find the library and create RWKVSharedLibrary specifying the path to it')
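The chunked evaluation API above is easiest to see in a concrete call. The sketch below is illustrative only, not part of this commit: the module name `rwkv_cpp_shared_library`, the model path, and the `rwkv_init_from_file(path, thread_count)` signature are assumptions; only the methods shown in this diff are taken as given.

    import numpy as np
    from rwkv_cpp_shared_library import load_rwkv_shared_library

    library = load_rwkv_shared_library()
    ctx = library.rwkv_init_from_file('path/to/model.bin', 4)  # assumed signature: (path, thread_count)

    # Allocate FP32 buffers of the sizes the library reports, and pass their addresses.
    state = np.zeros(library.rwkv_get_state_buffer_element_count(ctx), dtype=np.float32)
    logits = np.zeros(library.rwkv_get_logits_buffer_element_count(ctx), dtype=np.float32)

    tokens = [1, 2, 3]  # placeholder token ids, each in 0 <= token < rwkv_get_n_vocab(ctx)

    library.rwkv_eval_sequence_in_chunks(
        ctx,
        tokens,
        16,                  # recommended chunk size per the docstring
        None,                # first pass: no input state
        state.ctypes.data,   # address of the first element of the state buffer
        logits.ctypes.data   # address of the first element of the logits buffer
    )

    library.rwkv_free(ctx)

On later passes, `state.ctypes.data` would also be passed as `state_in_address`, so the model continues from the saved hidden state.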
rwkv_world_tokenizer (2).py ADDED
@@ -0,0 +1,126 @@
+import os
+import pathlib
+from typing import List, Set, Tuple, Callable
+
+# Taken from https://github.com/BlinkDL/ChatRWKV/tree/main/tokenizer/rwkv_tokenizer.py
+
+class Trie:
+    __slots__ = ('ch', 'to', 'values', 'front')
+
+    def __init__(self, front=None, ch=None) -> None:
+        self.ch = ch
+        self.to: List = [None for _ in range(256)]
+        self.values: Set = set()
+        self.front = front
+
+    def add(self, key: bytes, idx: int = 0, val=None) -> 'Trie':
+        if idx == len(key):
+            if val is None:
+                val = key
+
+            self.values.add(val)
+
+            return self
+
+        ch = key[idx]
+
+        if self.to[ch] is None:
+            self.to[ch] = Trie(front=self, ch=ch)
+
+        return self.to[ch].add(key, idx=idx + 1, val=val)
+
+    def find_longest(self, key: bytes, idx: int = 0) -> Tuple[int, 'Trie', set]:
+        u: Trie = self
+        ch: int = key[idx]
+        ret = None
+
+        while u.to[ch] is not None:
+            u = u.to[ch]
+            idx += 1
+
+            if u.values:
+                ret = idx, u, u.values
+
+            if idx == len(key):
+                break
+
+            ch = key[idx]
+
+        if ret is None:
+            raise ValueError('Entry not found')
+
+        return ret
+
+    def __repr__(self) -> str:
+        fr = self
+        ret = []
+
+        while fr is not None:
+            if fr.ch is not None:
+                ret.append(fr.ch)
+
+            fr = fr.front
+
+        return '<TRIE %s %s>' % (ret[::-1], self.values)
+
+class WorldTokenizer:
+
+    def __init__(self, file_path) -> None:
+        self.index_to_token = {}
+
+        with open(file_path, 'r', encoding='utf-8') as f:
+            lines = f.readlines()
+
+        for line in lines:
+            idx = int(line[:line.index(' ')])
+            x = eval(line[line.index(' '):line.rindex(' ')])
+            x = x.encode('utf-8') if isinstance(x, str) else x
+            assert isinstance(x, bytes)
+            assert len(x) == int(line[line.rindex(' '):])
+            self.index_to_token[idx] = x
+
+        self.token_to_index = {}
+
+        for k, v in self.index_to_token.items():
+            self.token_to_index[v] = int(k)
+
+        self.root = Trie()
+
+        for t, i in self.token_to_index.items():
+            _ = self.root.add(t, val=(t, i))
+
+    def encode_bytes(self, src: bytes) -> List[int]:
+        idx: int = 0
+        tokens: List[int] = []
+
+        while idx < len(src):
+            _idx: int = idx
+            idx, _, values = self.root.find_longest(src, idx)
+            assert (idx != _idx)
+            _, token = next(iter(values))
+            tokens.append(token)
+
+        return tokens
+
+    def decode_bytes(self, tokens: List[int]) -> bytes:
+        return b''.join(map(lambda i: self.index_to_token[i], tokens))
+
+    def encode(self, src: str) -> List[int]:
+        return self.encode_bytes(src.encode('utf-8'))
+
+    def decode(self, tokens: List[int]) -> str:
+        # 'replace' error handling mode will insert \uFFFD characters in place of malformed/partial UTF-8 sequences.
+        # Downstream code needs to detect \uFFFD and attempt to postpone decoding until more tokens arrive and UTF-8 sequences are complete.
+        return self.decode_bytes(tokens).decode('utf-8', errors='replace')
+
+def get_world_tokenizer_v20230424() -> Tuple[
+    Callable[[List[int]], str],
+    Callable[[str], List[int]]
+]:
+    """
+    Loads the default World tokenizer, commonly used in RWKV v4 World models.
+    Returns a tuple of `decode` and `encode` functions.
+    """
+    parent: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent
+    tokenizer: WorldTokenizer = WorldTokenizer(parent / 'rwkv_vocab_v20230424.txt')
+    return tokenizer.decode, tokenizer.encode
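For reference, a round-trip sketch of the tokenizer above (assuming the file is importable as `rwkv_world_tokenizer` and that `rwkv_vocab_v20230424.txt` sits next to it, as `get_world_tokenizer_v20230424` expects):

    from rwkv_world_tokenizer import get_world_tokenizer_v20230424

    decode, encode = get_world_tokenizer_v20230424()

    tokens = encode('Hello, world!')
    assert decode(tokens) == 'Hello, world!'

    # When streaming token-by-token, a multi-byte UTF-8 character can be split across
    # tokens; decode() then returns U+FFFD until the remaining bytes arrive, as noted
    # in the decode() comment above.

Note that `encode_bytes` does greedy longest-match via `Trie.find_longest` and asserts that each match consumes at least one byte, which holds as long as every single byte appears in the vocabulary.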
sampling (2).py ADDED
@@ -0,0 +1,52 @@
+import numpy as np
+from typing import Dict
+
+# https://stackoverflow.com/a/50425683
+def softmax(x: np.ndarray, axis: int):
+    x -= x.max(axis=axis, keepdims=True)
+    e: np.ndarray = np.exp(x)
+    return e / e.sum(axis=axis, keepdims=True)
+
+def sample_logits(out, temperature: float = 1.0, top_p: float = 0.8, logit_bias: Dict[int, float] = None) -> int:
+    if hasattr(out, '__module__') and out.__module__ == 'torch':
+        out = out.cpu().numpy()
+
+    probs: np.ndarray = softmax(out, axis=-1)
+
+    return sample_probs(probs, temperature, top_p, logit_bias)
+
+def sample_probs(probs: np.ndarray, temperature: float = 1.0, top_p: float = 0.8, logit_bias: Dict[int, float] = None) -> int:
+    if not (0.0 <= temperature):
+        raise ValueError('temperature')
+    if not (0.0 <= top_p <= 1.0):
+        raise ValueError('top_p')
+
+    if top_p == 0.0:
+        top_p = 1.0
+
+    if logit_bias is not None and len(logit_bias) > 0:
+        logits: np.ndarray = np.log(probs)
+
+        ids, values = zip(*logit_bias.items())
+        logits[list(ids)] += values
+
+        # Makes calculation more numerically stable, does not change the result
+        logits -= logits.max(axis=-1, keepdims=True)
+
+        probs = np.exp(logits) / np.sum(np.exp(logits))
+
+    if temperature == 0.0:
+        return np.argmax(probs).item()
+
+    if top_p < 1.0:
+        sorted_probs = np.sort(probs)[::-1]
+        cumulative_probs = np.cumsum(sorted_probs)
+        cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
+        probs[probs < cutoff] = 0
+
+    if temperature != 1.0:
+        probs = np.power(probs, 1.0 / temperature)
+
+    probs = probs / np.sum(probs)
+
+    return np.random.choice(a=len(probs), p=probs)
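A quick sketch of how these functions compose (assuming the file is importable as `sampling`; the logits array is a stand-in for real model output):

    import numpy as np
    from sampling import sample_logits

    logits = np.random.randn(65536).astype(np.float32)  # stand-in for RWKV model logits

    # Nucleus (top-p) sampling with a mild temperature; token 0 is biased strongly down.
    token = sample_logits(logits, temperature=0.8, top_p=0.5, logit_bias={0: -100.0})

    # temperature == 0.0 short-circuits to argmax, i.e. greedy decoding.
    greedy = sample_logits(logits, temperature=0.0)

Two behaviors worth noting: `top_p == 0.0` is treated as 1.0 (top-p disabled), and `softmax` modifies its input array in place via `x -= x.max(...)`.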
tokenizer_util (2).py ADDED
@@ -0,0 +1,38 @@
+import os
+import pathlib
+import rwkv_world_tokenizer
+from typing import List, Tuple, Callable
+
+def add_tokenizer_argument(parser) -> None:
+    parser.add_argument(
+        'tokenizer',
+        help='Tokenizer to use; supported tokenizers: auto (guess from n_vocab), 20B, world',
+        nargs='?',
+        type=str,
+        default='auto'
+    )
+
+def get_tokenizer(tokenizer_name: str, n_vocab: int) -> Tuple[
+    Callable[[List[int]], str],
+    Callable[[str], List[int]]
+]:
+    if tokenizer_name == 'auto':
+        if n_vocab == 50277:
+            tokenizer_name = '20B'
+        elif n_vocab == 65536:
+            tokenizer_name = 'world'
+        else:
+            raise ValueError(f'Can not guess the tokenizer from n_vocab value of {n_vocab}')
+
+    parent: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent
+
+    if tokenizer_name == 'world':
+        print('Loading World v20230424 tokenizer')
+        return rwkv_world_tokenizer.get_world_tokenizer_v20230424()
+    elif tokenizer_name == '20B':
+        print('Loading 20B tokenizer')
+        import tokenizers
+        tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(parent / '20B_tokenizer.json'))
+        return tokenizer.decode, lambda x: tokenizer.encode(x).ids
+    else:
+        raise ValueError(f'Unknown tokenizer {tokenizer_name}')
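Taken together with the tokenizer module, a typical CLI wiring looks like this sketch (assuming the file is importable as `tokenizer_util`; the argument values are illustrative, and `n_vocab` would normally come from `rwkv_get_n_vocab`):

    import argparse
    from tokenizer_util import add_tokenizer_argument, get_tokenizer

    parser = argparse.ArgumentParser()
    add_tokenizer_argument(parser)
    args = parser.parse_args(['auto'])

    # 65536 makes 'auto' resolve to the World tokenizer; 50277 would select 20B.
    decode, encode = get_tokenizer(args.tokenizer, n_vocab=65536)

    print(encode('test'))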