Roger Condori committed
Commit 17149e5
1 Parent(s): 0c04d6f

upload backend

conversadocs/bones.py CHANGED
@@ -7,7 +7,7 @@ from langchain.memory import ConversationBufferMemory
 from langchain.chat_models import ChatOpenAI
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain import HuggingFaceHub
-from langchain.llms import LlamaCpp
+from conversadocs.llamacppmodels import LlamaCpp #from langchain.llms import LlamaCpp
 from huggingface_hub import hf_hub_download
 import param
 import os
@@ -23,37 +23,28 @@ from langchain.document_loaders import (
     UnstructuredWordDocumentLoader,
     PyPDFLoader,
 )
+import gc
+gc.collect()
+torch.cuda.empty_cache()

 #YOUR_HF_TOKEN = os.getenv("My_hf_token")
-llm_api=HuggingFaceHub(
-    huggingfacehub_api_token=os.getenv("My_hf_token"),
-    repo_id="tiiuae/falcon-7b-instruct",
-    model_kwargs={
-        "temperature":0.2,
-        "max_new_tokens":500,
-        "top_k":50,
-        "top_p":0.95,
-        "repetition_penalty":1.2,
-        },), #ChatOpenAI(model_name=llm_name, temperature=0)

+EXTENSIONS = {
+    ".txt": (TextLoader, {"encoding": "utf8"}),
+    ".pdf": (PyPDFLoader, {}),
+    ".doc": (UnstructuredWordDocumentLoader, {}),
+    ".docx": (UnstructuredWordDocumentLoader, {}),
+    ".enex": (EverNoteLoader, {}),
+    ".epub": (UnstructuredEPubLoader, {}),
+    ".html": (UnstructuredHTMLLoader, {}),
+    ".md": (UnstructuredMarkdownLoader, {}),
+    ".odt": (UnstructuredODTLoader, {}),
+    ".ppt": (UnstructuredPowerPointLoader, {}),
+    ".pptx": (UnstructuredPowerPointLoader, {}),
+}

 #alter
 def load_db(files):
-    EXTENSIONS = {
-        ".txt": (TextLoader, {"encoding": "utf8"}),
-        ".pdf": (PyPDFLoader, {}),
-        ".doc": (UnstructuredWordDocumentLoader, {}),
-        ".docx": (UnstructuredWordDocumentLoader, {}),
-        ".enex": (EverNoteLoader, {}),
-        ".epub": (UnstructuredEPubLoader, {}),
-        ".html": (UnstructuredHTMLLoader, {}),
-        ".md": (UnstructuredMarkdownLoader, {}),
-        ".odt": (UnstructuredODTLoader, {}),
-        ".ppt": (UnstructuredPowerPointLoader, {}),
-        ".pptx": (UnstructuredPowerPointLoader, {}),
-    }
-
-

     # select extensions loader
     documents = []
@@ -102,14 +93,14 @@ class DocChat(param.Parameterized):
     answer = param.String("")
     db_query = param.String("")
     db_response = param.List([])
-    llm = llm_api[0]
     k_value = param.Integer(3)
-
+    llm = None

     def __init__(self, **params):
         super(DocChat, self).__init__( **params)
-        self.loaded_file = "demo_docs/demo.txt"
+        self.loaded_file = ["demo_docs/demo.txt"]
         self.db = load_db(self.loaded_file)
+        self.change_llm("TheBloke/Llama-2-7B-Chat-GGML", "llama-2-7b-chat.ggmlv3.q5_1.bin", max_tokens=256, temperature=0.2, top_p=0.95, top_k=50, repeat_penalty=1.2, k=3)
         self.qa = q_a(self.db, "stuff", self.k_value, self.llm)


@@ -141,7 +132,8 @@ class DocChat(param.Parameterized):
         try:
             result = self.qa({"question": query, "chat_history": self.chat_history})
         except:
-            self.default_falcon_model()
+            print("Error not get response from model, reloaded default llama-2 7B config")
+            self.change_llm("TheBloke/Llama-2-7B-Chat-GGML", "llama-2-7b-chat.ggmlv3.q5_1.bin", max_tokens=256, temperature=0.2, top_p=0.95, top_k=50, repeat_penalty=1.2, k=3)
             self.qa = q_a(self.db, "stuff", k_max, self.llm)
             result = self.qa({"question": query, "chat_history": self.chat_history})

@@ -151,17 +143,48 @@ class DocChat(param.Parameterized):
         self.answer = result['answer']
         return self.answer

-    def change_llm(self, repo_, file_, max_tokens=16, temperature=0.2, top_p=0.95, top_k=50, repeat_penalty=1.2, k=3):
+    def summarize(self, chunk_size=2000, chunk_overlap=100):
+        # load docs
+        documents = []
+        for file in self.loaded_file:
+            ext = "." + file.rsplit(".", 1)[-1]
+            if ext in EXTENSIONS:
+                loader_class, loader_args = EXTENSIONS[ext]
+                loader = loader_class(file, **loader_args)
+                documents.extend(loader.load_and_split())
+
+        if documents == []:
+            return "Error in summarization"
+
+        # split documents
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+            separators=["\n\n", "\n", "(?<=\. )", " ", ""]
+        )
+        docs = text_splitter.split_documents(documents)
+        # summarize
+        from langchain.chains.summarize import load_summarize_chain
+        chain = load_summarize_chain(self.llm, chain_type='map_reduce', verbose=True)
+        return chain.run(docs)
+
+    def change_llm(self, repo_, file_, max_tokens=256, temperature=0.2, top_p=0.95, top_k=50, repeat_penalty=1.2, k=3):

         if torch.cuda.is_available():
             try:
                 model_path = hf_hub_download(repo_id=repo_, filename=file_)
+
+                self.qa = None
+                self.llm = None
+                gc.collect()
+                torch.cuda.empty_cache()
+                gpu_llm_layers = 35 if not '70B' in repo_.upper() else 25 # fix for 70B

                 self.llm = LlamaCpp(
                     model_path=model_path,
-                    n_ctx=1000,
+                    n_ctx=4096,
                     n_batch=512,
-                    n_gpu_layers=35,
+                    n_gpu_layers=gpu_llm_layers,
                     max_tokens=max_tokens,
                     verbose=False,
                     temperature=temperature,
@@ -173,14 +196,20 @@ class DocChat(param.Parameterized):
                 self.k_value = k
                 return f"Loaded {file_} [GPU INFERENCE]"
             except:
-                return "No valid model"
+                self.change_llm("TheBloke/Llama-2-7B-Chat-GGML", "llama-2-7b-chat.ggmlv3.q5_1.bin", max_tokens=256, temperature=0.2, top_p=0.95, top_k=50, repeat_penalty=1.2, k=3)
+                return "No valid model | Reloaded Reloaded default llama-2 7B config"
         else:
             try:
                 model_path = hf_hub_download(repo_id=repo_, filename=file_)
+
+                self.qa = None
+                self.llm = None
+                gc.collect()
+                torch.cuda.empty_cache()

                 self.llm = LlamaCpp(
                     model_path=model_path,
-                    n_ctx=1000,
+                    n_ctx=2048,
                     n_batch=8,
                     max_tokens=max_tokens,
                     verbose=False,
@@ -193,10 +222,20 @@ class DocChat(param.Parameterized):
                 self.k_value = k
                 return f"Loaded {file_} [CPU INFERENCE SLOW]"
             except:
-                return "No valid model"
+                self.change_llm("TheBloke/Llama-2-7B-Chat-GGML", "llama-2-7b-chat.ggmlv3.q5_1.bin", max_tokens=256, temperature=0.2, top_p=0.95, top_k=50, repeat_penalty=1.2, k=3)
+                return "No valid model | Reloaded default llama-2 7B config"

-    def default_falcon_model(self):
-        self.llm = llm_api[0]
+    def default_falcon_model(self, HF_TOKEN):
+        self.llm = llm_api=HuggingFaceHub(
+            huggingfacehub_api_token=HF_TOKEN,
+            repo_id="tiiuae/falcon-7b-instruct",
+            model_kwargs={
+                "temperature":0.2,
+                "max_new_tokens":500,
+                "top_k":50,
+                "top_p":0.95,
+                "repetition_penalty":1.2,
+            },)
         self.qa = q_a(self.db, "stuff", self.k_value, self.llm)
         return "Loaded model Falcon 7B-instruct [API FAST INFERENCE]"

@@ -204,7 +243,7 @@ class DocChat(param.Parameterized):
         self.llm = ChatOpenAI(temperature=0, openai_api_key=API_KEY, model_name='gpt-3.5-turbo')
         self.qa = q_a(self.db, "stuff", self.k_value, self.llm)
         API_KEY = ""
-        return "Loaded model OpenAI gpt-3.5-turbo [API FAST INFERENCE] | If there is no response from the API, Falcon 7B-instruct will be used."
+        return "Loaded model OpenAI gpt-3.5-turbo [API FAST INFERENCE]"

     @param.depends('db_query ', )
     def get_lquest(self):
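
Usage note (editor's sketch, not part of the commit): the flow below is inferred from the DocChat class in this diff; the repo/file names are the defaults the commit hard-codes, and everything else is an assumption.

# Hypothetical driver for the updated backend in conversadocs/bones.py (sketch only).
from conversadocs.bones import DocChat

docs_chat = DocChat()  # __init__ indexes demo_docs/demo.txt and loads the default llama-2 7B GGML model

# Swap the local model; change_llm downloads the file from the Hub via hf_hub_download
# and returns a status string such as "Loaded ... [GPU INFERENCE]" or the fallback message.
status = docs_chat.change_llm(
    "TheBloke/Llama-2-7B-Chat-GGML",
    "llama-2-7b-chat.ggmlv3.q5_1.bin",
    max_tokens=256,
    temperature=0.2,
    k=3,
)
print(status)

# Summarize the currently loaded documents with a map_reduce chain.
print(docs_chat.summarize(chunk_size=2000, chunk_overlap=100))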
conversadocs/llamacppmodels.py ADDED
@@ -0,0 +1,309 @@
+import logging
+from typing import Any, Dict, Iterator, List, Optional
+
+from pydantic import Field, root_validator
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.base import LLM
+from langchain.schema.output import GenerationChunk
+
+logger = logging.getLogger(__name__)
+
+
+class LlamaCpp(LLM):
+    """llama.cpp model.
+
+    To use, you should have the llama-cpp-python library installed, and provide the
+    path to the Llama model as a named parameter to the constructor.
+    Check out: https://github.com/abetlen/llama-cpp-python
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import LlamaCpp
+            llm = LlamaCpp(model_path="/path/to/llama/model")
+    """
+
+    client: Any #: :meta private:
+    model_path: str
+    """The path to the Llama model file."""
+
+    lora_base: Optional[str] = None
+    """The path to the Llama LoRA base model."""
+
+    lora_path: Optional[str] = None
+    """The path to the Llama LoRA. If None, no LoRa is loaded."""
+
+    n_ctx: int = Field(512, alias="n_ctx")
+    """Token context window."""
+
+    n_parts: int = Field(-1, alias="n_parts")
+    """Number of parts to split the model into.
+    If -1, the number of parts is automatically determined."""
+
+    seed: int = Field(-1, alias="seed")
+    """Seed. If -1, a random seed is used."""
+
+    f16_kv: bool = Field(True, alias="f16_kv")
+    """Use half-precision for key/value cache."""
+
+    logits_all: bool = Field(False, alias="logits_all")
+    """Return logits for all tokens, not just the last token."""
+
+    vocab_only: bool = Field(False, alias="vocab_only")
+    """Only load the vocabulary, no weights."""
+
+    use_mlock: bool = Field(False, alias="use_mlock")
+    """Force system to keep model in RAM."""
+
+    n_threads: Optional[int] = Field(None, alias="n_threads")
+    """Number of threads to use.
+    If None, the number of threads is automatically determined."""
+
+    n_batch: Optional[int] = Field(8, alias="n_batch")
+    """Number of tokens to process in parallel.
+    Should be a number between 1 and n_ctx."""
+
+    n_gpu_layers: Optional[int] = Field(None, alias="n_gpu_layers")
+    """Number of layers to be loaded into gpu memory. Default None."""
+
+    suffix: Optional[str] = Field(None)
+    """A suffix to append to the generated text. If None, no suffix is appended."""
+
+    max_tokens: Optional[int] = 256
+    """The maximum number of tokens to generate."""
+
+    temperature: Optional[float] = 0.8
+    """The temperature to use for sampling."""
+
+    top_p: Optional[float] = 0.95
+    """The top-p value to use for sampling."""
+
+    logprobs: Optional[int] = Field(None)
+    """The number of logprobs to return. If None, no logprobs are returned."""
+
+    echo: Optional[bool] = False
+    """Whether to echo the prompt."""
+
+    stop: Optional[List[str]] = []
+    """A list of strings to stop generation when encountered."""
+
+    repeat_penalty: Optional[float] = 1.1
+    """The penalty to apply to repeated tokens."""
+
+    top_k: Optional[int] = 40
+    """The top-k value to use for sampling."""
+
+    last_n_tokens_size: Optional[int] = 64
+    """The number of tokens to look back when applying the repeat_penalty."""
+
+    use_mmap: Optional[bool] = True
+    """Whether to keep the model loaded in RAM"""
+
+    rope_freq_scale: float = 1.0
+    """Scale factor for rope sampling."""
+
+    rope_freq_base: float = 10000.0
+    """Base frequency for rope sampling."""
+
+    streaming: bool = True
+    """Whether to stream the results, token by token."""
+
+    verbose: bool = True
+    """Print verbose output to stderr."""
+
+    n_gqa: Optional[int] = None
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that llama-cpp-python library is installed."""
+
+
+        model_path = values["model_path"]
+        model_param_names = [
+            "n_gqa",
+            "rope_freq_scale",
+            "rope_freq_base",
+            "lora_path",
+            "lora_base",
+            "n_ctx",
+            "n_parts",
+            "seed",
+            "f16_kv",
+            "logits_all",
+            "vocab_only",
+            "use_mlock",
+            "n_threads",
+            "n_batch",
+            "use_mmap",
+            "last_n_tokens_size",
+            "verbose",
+        ]
+        model_params = {k: values[k] for k in model_param_names}
+
+        model_params['n_gqa'] = 8 if '70B' in model_path.upper() else None # (TEMPORARY) must be 8 for llama2 70b
+        # For backwards compatibility, only include if non-null.
+        if values["n_gpu_layers"] is not None:
+            model_params["n_gpu_layers"] = values["n_gpu_layers"]
+
+        try:
+            from llama_cpp import Llama
+
+            values["client"] = Llama(model_path, **model_params)
+        except ImportError:
+            raise ImportError(
+                "Could not import llama-cpp-python library. "
+                "Please install the llama-cpp-python library to "
+                "use this embedding model: pip install llama-cpp-python"
+            )
+        except Exception as e:
+            raise ValueError(
+                f"Could not load Llama model from path: {model_path}. "
+                f"Received error {e}"
+            )
+
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling llama_cpp."""
+        return {
+            "suffix": self.suffix,
+            "max_tokens": self.max_tokens,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "logprobs": self.logprobs,
+            "echo": self.echo,
+            "stop_sequences": self.stop, # key here is convention among LLM classes
+            "repeat_penalty": self.repeat_penalty,
+            "top_k": self.top_k,
+        }
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {**{"model_path": self.model_path}, **self._default_params}
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "llamacpp"
+
+    def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
+        """
+        Performs sanity check, preparing parameters in format needed by llama_cpp.
+
+        Args:
+            stop (Optional[List[str]]): List of stop sequences for llama_cpp.
+
+        Returns:
+            Dictionary containing the combined parameters.
+        """
+
+        # Raise error if stop sequences are in both input and default params
+        if self.stop and stop is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+
+        params = self._default_params
+
+        # llama_cpp expects the "stop" key not this, so we remove it:
+        params.pop("stop_sequences")
+
+        # then sets it as configured, or default to an empty list:
+        params["stop"] = self.stop or stop or []
+
+        return params
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call the Llama model and return the output.
+
+        Args:
+            prompt: The prompt to use for generation.
+            stop: A list of strings to stop generation when encountered.
+
+        Returns:
+            The generated text.
+
+        Example:
+            .. code-block:: python
+
+                from langchain.llms import LlamaCpp
+                llm = LlamaCpp(model_path="/path/to/local/llama/model.bin")
+                llm("This is a prompt.")
+        """
+        if self.streaming:
+            # If streaming is enabled, we use the stream
+            # method that yields as they are generated
+            # and return the combined strings from the first choices's text:
+            combined_text_output = ""
+            for chunk in self._stream(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                combined_text_output += chunk.text
+            return combined_text_output
+        else:
+            params = self._get_parameters(stop)
+            params = {**params, **kwargs}
+            result = self.client(prompt=prompt, **params)
+            return result["choices"][0]["text"]
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        """Yields results objects as they are generated in real time.
+
+        It also calls the callback manager's on_llm_new_token event with
+        similar parameters to the OpenAI LLM class method of the same name.
+
+        Args:
+            prompt: The prompts to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            A generator representing the stream of tokens being generated.
+
+        Yields:
+            A dictionary like objects containing a string token and metadata.
+            See llama-cpp-python docs and below for more.
+
+        Example:
+            .. code-block:: python
+
+                from langchain.llms import LlamaCpp
+                llm = LlamaCpp(
+                    model_path="/path/to/local/model.bin",
+                    temperature = 0.5
+                )
+                for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
+                        stop=["'","\n"]):
+                    result = chunk["choices"][0]
+                    print(result["text"], end='', flush=True)
+
+        """
+        params = {**self._get_parameters(stop), **kwargs}
+        result = self.client(prompt=prompt, stream=True, **params)
+        for part in result:
+            logprobs = part["choices"][0].get("logprobs", None)
+            chunk = GenerationChunk(
+                text=part["choices"][0]["text"],
+                generation_info={"logprobs": logprobs},
+            )
+            yield chunk
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    token=chunk.text, verbose=self.verbose, log_probs=logprobs
+                )
+
+    def get_num_tokens(self, text: str) -> int:
+        tokenized_text = self.client.tokenize(text.encode("utf-8"))
+        return len(tokenized_text)
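
Usage note (editor's sketch, not part of the commit): the vendored wrapper can also be driven on its own, much like the docstring example above; the model path below is a placeholder and llama-cpp-python must be installed.

# Hypothetical standalone use of the vendored LlamaCpp wrapper (sketch only).
from conversadocs.llamacppmodels import LlamaCpp

llm = LlamaCpp(
    model_path="/path/to/llama-2-7b-chat.ggmlv3.q5_1.bin",  # placeholder path
    n_ctx=2048,
    n_batch=8,
    max_tokens=256,
    temperature=0.2,
    streaming=True,  # _call concatenates the chunks yielded by _stream
)

print(llm("Briefly explain what llama.cpp does."))     # LLM.__call__ dispatches to _call
print(llm.get_num_tokens("How many tokens is this?"))  # uses the llama.cpp tokenizer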
conversadocs/llm_chess.py ADDED
@@ -0,0 +1,101 @@
+import os
+import chess
+import chess.pgn
+
+# credit code https://github.com/notnil/chess-gpt
+def get_legal_moves(board):
+    """Returns a list of legal moves in UCI notation."""
+    return list(map(board.san, board.legal_moves))
+
+def init_game() -> tuple[chess.pgn.Game, chess.Board]:
+    """Initializes a new game."""
+    board = chess.Board()
+    game = chess.pgn.Game()
+    game.headers["White"] = "User"
+    game.headers["Black"] = "Chess-engine"
+    del game.headers["Event"]
+    del game.headers["Date"]
+    del game.headers["Site"]
+    del game.headers["Round"]
+    del game.headers["Result"]
+    game.setup(board)
+    return game, board
+
+def generate_prompt(game: chess.pgn.Game, board: chess.Board) -> str:
+
+    moves = get_legal_moves(board)
+    moves_str = ",".join(moves)
+    return f"""
+The task is play a Chess game:
+You are the Chess-engine playing a chess match against the user as black and trying to win.
+
+The current FEN notation is:
+{board.fen()}
+
+The next valid moves are:
+{moves_str}
+
+Continue the game.
+{str(game)[:-2]}"""
+
+def get_move(content, moves):
+    lines = content.splitlines()
+    for line in lines:
+        for lm in moves:
+            if lm in line:
+                return lm
+
+class ChessGame:
+    def __init__(self, docschatllm):
+        self.docschatllm = docschatllm
+
+    def start_game(self):
+        self.game, self.board = init_game()
+        self.game_cp, _ = init_game()
+        self.node = self.game
+        self.node_copy = self.game_cp
+
+        svg_board = chess.svg.board(self.board, size=350)
+        return svg_board, "Valid moves: "+",".join(get_legal_moves(self.board)) # display(self.board)
+
+    def user_move(self, move_input):
+        try:
+            self.board.push_san(move_input)
+        except ValueError:
+            print("Invalid move")
+            svg_board = chess.svg.board(self.board, size=350)
+            return svg_board, "Valid moves: "+",".join(get_legal_moves(self.board)), 'Invalid move'
+        self.node = self.node.add_variation(self.board.move_stack[-1])
+        self.node_copy = self.node_copy.add_variation(self.board.move_stack[-1])
+
+        if self.board.is_game_over():
+            svg_board = chess.svg.board(self.board, size=350)
+            return svg_board, ",".join(get_legal_moves(self.board)), 'GAME OVER'
+
+        prompt = generate_prompt(self.game, self.board)
+        print("Prompt: \n"+prompt)
+        print("#############")
+        for i in range(10): #tries
+            if i == 9:
+                svg_board = chess.svg.board(self.board, size=350)
+                return svg_board, ",".join(get_legal_moves(self.board)), "The model can't do a valid move"
+            try:
+                """Returns the move from the prompt."""
+                content = self.docschatllm.llm.predict(prompt) ### from selected model ###
+                #print(moves)
+                print("Response: \n"+content)
+                print("#############")
+
+                moves = get_legal_moves(self.board)
+                move = get_move(content, moves)
+                print(move)
+                print("#############")
+                self.board.push_san(move)
+                break
+            except:
+                prompt = prompt[1:]
+                print("attempt a move.")
+        self.node = self.node.add_variation(self.board.move_stack[-1])
+        self.node_copy = self.node_copy.add_variation(self.board.move_stack[-1])
+        svg_board = chess.svg.board(self.board, size=350)
+        return svg_board, "Valid moves: "+",".join(get_legal_moves(self.board)), ''
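
Usage note (editor's sketch, not part of the commit): ChessGame only needs an object whose .llm exposes .predict(), such as the DocChat instance from bones.py; this module calls chess.svg.board without importing chess.svg, so the sketch imports it explicitly.

# Hypothetical driver for ChessGame (sketch only).
import chess.svg  # chess.svg.board is used by llm_chess but not imported there

from conversadocs.bones import DocChat
from conversadocs.llm_chess import ChessGame

game = ChessGame(DocChat())
svg, legal, status = game.start_game()     # SVG of the initial board plus the legal-move list
svg, legal, status = game.user_move("e4")  # user plays 1. e4 in SAN; the model answers as black

with open("board.svg", "w") as f:          # persist the rendered board for inspection
    f.write(svg)
print(status or legal)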