jbdel committed on
Commit
8c5c31d
1 Parent(s): 2886238

chat_with_paper_update

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +13 -10
  3. paper_chat_tab.py +240 -193
README.md CHANGED
@@ -5,7 +5,7 @@ emoji: ⚡
5
  colorFrom: red
6
  colorTo: purple
7
  sdk: gradio
8
- sdk_version: 5.6.0
9
  app_file: app.py
10
  pinned: false
11
  header: mini
 
5
  colorFrom: red
6
  colorTo: purple
7
  sdk: gradio
8
+ sdk_version: 5.8.0
9
  app_file: app.py
10
  pinned: false
11
  header: mini
app.py CHANGED
@@ -82,6 +82,12 @@ with gr.Blocks(css_paths="style.css") as demo:
82
  link="https://huggingface.co/datasets/huggingface/paper-central-data")
83
 
84
  with gr.Tabs() as tabs:
 
 
 
 
 
 
85
  with gr.Tab("Paper-central", id="tab-paper-central"):
86
  # Create a row for navigation buttons and calendar
87
  with gr.Row():
@@ -178,6 +184,8 @@ with gr.Blocks(css_paths="style.css") as demo:
178
  wrap=True,
179
  )
180
 
 
 
181
  with gr.Tab("Edit papers", id="tab-pr"):
182
  pr_paper_central_tab(paper_central_df.df_raw)
183
 
@@ -187,19 +195,13 @@ with gr.Blocks(css_paths="style.css") as demo:
187
  with gr.Tab("Contributors"):
188
  author_resource_leaderboard_tab()
189
 
190
- with gr.Tab("Chat With Paper", id="tab-chat-with-paper", visible=False) as tab_chat_paper:
191
- gr.Markdown("## Chat with Paper")
192
- arxiv_id = gr.State(value=None)
193
- paper_from = gr.State(value=None)
194
- paper_chat_tab(arxiv_id, paper_from)
195
-
196
 
197
  # chat with paper
198
  def get_selected(evt: gr.SelectData, dataframe_origin):
199
 
200
  paper_id = gr.update(value=None)
201
  paper_from = gr.update(value=None)
202
- tab_chat_paper = gr.update(visible=False)
203
  selected_tab = gr.Tabs()
204
 
205
  try:
@@ -516,7 +518,7 @@ with gr.Blocks(css_paths="style.css") as demo:
516
  selected_tab = gr.Tabs()
517
  paper_id = gr.update(value=None)
518
  paper_from = gr.update(value=None)
519
- tab_chat_paper = gr.update(visible=False)
520
 
521
  if request:
522
  # print("Request headers dictionary:", dict(request.headers))
@@ -568,7 +570,8 @@ with gr.Blocks(css_paths="style.css") as demo:
568
  api_name="update_data",
569
  ).then(
570
  fn=echo,
571
- outputs=[calendar, date_range_radio, conference_options, hf_options, tabs, arxiv_id, paper_from, tab_chat_paper],
 
572
  api_name=False,
573
  ).then(
574
  # New then to handle LoginButton and HTML components
@@ -583,7 +586,7 @@ def main():
583
  """
584
  Launches the Gradio app.
585
  """
586
- demo.launch(ssr_mode=False)
587
 
588
 
589
  # Run the main function when the script is executed
 
82
  link="https://huggingface.co/datasets/huggingface/paper-central-data")
83
 
84
  with gr.Tabs() as tabs:
85
+ with gr.Tab("Chat With Paper", id="tab-chat-with-paper", visible=True) as tab_chat_paper:
86
+ gr.Markdown("## Chat with Paper")
87
+ arxiv_id = gr.State(value=None)
88
+ paper_from = gr.State(value=None)
89
+ paper_chat_tab(arxiv_id, paper_from, paper_central_df)
90
+
91
  with gr.Tab("Paper-central", id="tab-paper-central"):
92
  # Create a row for navigation buttons and calendar
93
  with gr.Row():
 
184
  wrap=True,
185
  )
186
 
187
+
188
+
189
  with gr.Tab("Edit papers", id="tab-pr"):
190
  pr_paper_central_tab(paper_central_df.df_raw)
191
 
 
195
  with gr.Tab("Contributors"):
196
  author_resource_leaderboard_tab()
197
 
 
 
 
 
 
 
198
 
199
  # chat with paper
200
  def get_selected(evt: gr.SelectData, dataframe_origin):
201
 
202
  paper_id = gr.update(value=None)
203
  paper_from = gr.update(value=None)
204
+ tab_chat_paper = gr.update(visible=True)
205
  selected_tab = gr.Tabs()
206
 
207
  try:
 
518
  selected_tab = gr.Tabs()
519
  paper_id = gr.update(value=None)
520
  paper_from = gr.update(value=None)
521
+ tab_chat_paper = gr.update(visible=True)
522
 
523
  if request:
524
  # print("Request headers dictionary:", dict(request.headers))
 
570
  api_name="update_data",
571
  ).then(
572
  fn=echo,
573
+ outputs=[calendar, date_range_radio, conference_options, hf_options, tabs, arxiv_id, paper_from,
574
+ tab_chat_paper],
575
  api_name=False,
576
  ).then(
577
  # New then to handle LoginButton and HTML components
 
586
  """
587
  Launches the Gradio app.
588
  """
589
+ demo.launch(share=True)
590
 
591
 
592
  # Run the main function when the script is executed
paper_chat_tab.py CHANGED
@@ -7,9 +7,10 @@ import requests
7
  from io import BytesIO
8
  from transformers import AutoTokenizer
9
  import json
10
-
11
  import os
12
  from openai import OpenAI
 
13
 
14
  # Cache for tokenizers to avoid reloading
15
  tokenizer_cache = {}
@@ -23,7 +24,6 @@ PROVIDERS = {
23
  "api_key_env_var": "SAMBANOVA_API_KEY",
24
  "models": [
25
  "Meta-Llama-3.1-70B-Instruct",
26
- # Add more models if needed
27
  ],
28
  "type": "tuples",
29
  "max_total_tokens": "50000",
@@ -43,12 +43,12 @@ PROVIDERS = {
43
  }
44
 
45
 
46
- # Function to fetch paper information from OpenReview
47
  def fetch_paper_info_neurips(paper_id):
48
  url = f"https://openreview.net/forum?id={paper_id}"
49
  response = requests.get(url)
50
  if response.status_code != 200:
51
- return None
52
 
53
  html_content = response.content
54
  soup = BeautifulSoup(html_content, 'html.parser')
@@ -73,66 +73,104 @@ def fetch_paper_info_neurips(paper_id):
73
  else:
74
  abstract = 'Abstract not found'
75
 
76
- # Construct preamble in Markdown
77
- preamble = f"**[{title}](https://openreview.net/forum?id={paper_id})**\n\n{author_list}\n\n"
 
78
 
79
- return preamble
80
 
81
-
82
- def fetch_paper_content_arxiv(paper_id):
83
  try:
84
- # Construct the URL for the arXiv PDF
85
- url = f"https://arxiv.org/pdf/{paper_id}.pdf"
86
-
87
- # Fetch the PDF
88
  response = requests.get(url)
89
- response.raise_for_status() # Raise an exception for HTTP errors
90
-
91
- # Read the PDF content
92
  pdf_content = BytesIO(response.content)
93
  reader = PdfReader(pdf_content)
94
-
95
- # Extract text from the PDF
96
  text = ""
97
  for page in reader.pages:
98
  text += page.extract_text()
99
-
100
- return text # Return full text; truncation will be handled later
101
- except Exception as e:
102
- print(f"Error fetching paper content: {e}")
103
  return None
104
 
105
 
106
- def fetch_paper_content(paper_id):
107
  try:
108
- # Construct the URL
109
- url = f"https://openreview.net/pdf?id={paper_id}"
110
-
111
- # Fetch the PDF
112
  response = requests.get(url)
113
- response.raise_for_status() # Raise an exception for HTTP errors
114
-
115
- # Read the PDF content
116
  pdf_content = BytesIO(response.content)
117
  reader = PdfReader(pdf_content)
118
-
119
- # Extract text from the PDF
120
  text = ""
121
  for page in reader.pages:
122
  text += page.extract_text()
123
-
124
- return text # Return full text; truncation will be handled later
125
-
126
  except Exception as e:
127
- print(f"An error occurred: {e}")
128
  return None
129
 
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_token_input, default_type,
132
  provider_max_total_tokens):
133
  # Define the function to handle the chat
134
- print("the type is", default_type.value)
135
-
136
  def get_fn(message, history, paper_content_value, hf_token_value, provider_name_value, model_name_value,
137
  max_total_tokens):
138
  provider_info = PROVIDERS[provider_name_value]
@@ -141,11 +179,9 @@ def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_t
141
  models = provider_info['models']
142
  max_total_tokens = int(max_total_tokens)
143
 
144
- # Load tokenizer and cache it
145
  tokenizer_key = f"{provider_name_value}_{model_name_value}"
146
  if tokenizer_key not in tokenizer_cache:
147
- # Load the tokenizer; adjust the model path based on the provider and model
148
- # This is a placeholder; you need to provide the correct tokenizer path
149
  tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct",
150
  token=os.environ.get("HF_TOKEN"))
151
  tokenizer_cache[tokenizer_key] = tokenizer
@@ -189,32 +225,28 @@ def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_t
189
 
190
  # Check if total tokens exceed the maximum allowed tokens
191
  if total_tokens > max_total_tokens:
192
- # Attempt to truncate the context first
193
  available_tokens = max_total_tokens - (total_tokens - context_token_length)
194
  if available_tokens > 0:
195
- # Truncate the context to fit the available tokens
196
  truncated_context_tokens = context_tokens[:available_tokens]
197
  context = tokenizer.decode(truncated_context_tokens)
198
  context_token_length = available_tokens
199
  total_tokens = total_tokens - len(context_tokens) + context_token_length
200
  else:
201
- # Not enough space for context; remove it
202
  context = ""
203
  total_tokens -= context_token_length
204
  context_token_length = 0
205
 
206
- # If total tokens still exceed the limit, truncate the message history
207
  while total_tokens > max_total_tokens and len(messages) > 1:
208
- # Remove the oldest message
209
  removed_message = messages.pop(0)
210
  removed_tokens = message_tokens_list.pop(0)
211
  total_tokens -= removed_tokens
212
 
213
- # Rebuild the final messages list including the (possibly truncated) context
214
  final_messages = []
215
  if context:
216
- final_messages.append(
217
- {"role": "system", "content": f"{context}"})
218
  final_messages.extend(messages)
219
 
220
  # Use the provider's API key
@@ -222,14 +254,13 @@ def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_t
222
  if not api_key:
223
  raise ValueError("API token is not provided.")
224
 
225
- # Initialize the OpenAI client with the provider's endpoint
226
  client = OpenAI(
227
  base_url=endpoint,
228
  api_key=api_key,
229
  )
230
 
231
  try:
232
- # Create the chat completion
233
  completion = client.chat.completions.create(
234
  model=model_name_value,
235
  messages=final_messages,
@@ -241,29 +272,13 @@ def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_t
241
  response_text += delta
242
  yield response_text
243
  except json.JSONDecodeError as e:
244
- print("Failed to decode JSON during the completion creation process.")
245
- print(f"Error Message: {e.msg}")
246
- print(f"Error Position: Line {e.lineno}, Column {e.colno} (Character {e.pos})")
247
- print(f"Problematic JSON Data: {e.doc}")
248
- yield f"{e.doc}"
249
  except openai.OpenAIError as openai_err:
250
- # Handle other OpenAI-related errors
251
- print(f"An OpenAI error occurred: {openai_err}")
252
- yield f"{openai_err}"
253
  except Exception as ex:
254
- # Handle any other exceptions
255
- print(f"An unexpected error occurred: {ex}")
256
- yield f"{ex}"
257
-
258
- # Create the Chatbot separately to access it later
259
- chatbot = gr.Chatbot(
260
- label="Chatbot",
261
- scale=1,
262
- height=400,
263
- autoscroll=True,
264
- )
265
 
266
- # Create the ChatInterface
267
  chat_interface = gr.ChatInterface(
268
  fn=get_fn,
269
  chatbot=chatbot,
@@ -273,142 +288,164 @@ def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_t
273
  return chat_interface, chatbot
274
 
275
 
276
- def paper_chat_tab(paper_id, paper_from):
277
- with gr.Column():
278
-
279
- # Preamble message to hint the user
280
- gr.Markdown("**Note:** Providing your own API token can help you avoid rate limits.")
 
281
 
282
- # Input for API token
283
- provider_names = list(PROVIDERS.keys())
284
- default_provider = provider_names[0]
 
 
 
 
285
 
286
- default_type = gr.State(value=PROVIDERS[default_provider]["type"])
287
- default_max_total_tokens = gr.State(value=PROVIDERS[default_provider]["max_total_tokens"])
288
 
289
- provider_dropdown = gr.Dropdown(
290
- label="Select Provider",
291
- choices=provider_names,
292
- value=default_provider
293
- )
294
 
295
- hf_token_input = gr.Textbox(
296
- label=f"Enter your {default_provider} API token (optional)",
297
- type="password",
298
- placeholder=f"Enter your {default_provider} API token to avoid rate limits"
299
- )
300
-
301
- # Dropdown for selecting the model
302
- model_dropdown = gr.Dropdown(
303
- label="Select Model",
304
- choices=PROVIDERS[default_provider]['models'],
305
- value=PROVIDERS[default_provider]['models'][0]
306
- )
307
 
308
- # Placeholder for the provider logo
309
- logo_html = gr.HTML(
310
- value=f'<img src="{PROVIDERS[default_provider]["logo"]}" width="100px" />'
311
- )
312
 
313
- # Note about the provider
314
- note_markdown = gr.Markdown(f"**Note:** This model is supported by {default_provider}.")
 
 
 
315
 
316
- # State to store the paper content
317
- paper_content = gr.State()
318
 
319
- # Textbox to display the paper title and authors
320
- content = gr.Markdown(value="")
 
 
 
321
 
322
- # Create the chat interface and get the chatbot component
323
- chat_interface, chatbot = create_chat_interface(provider_dropdown, model_dropdown, paper_content,
324
- hf_token_input,
325
- default_type, default_max_total_tokens)
 
326
 
327
- # Function to update models and logo when provider changes
328
- def update_provider(selected_provider):
329
- provider_info = PROVIDERS[selected_provider]
330
- models = provider_info['models']
331
- logo_url = provider_info['logo']
332
- chatbot_message_type = provider_info['type']
333
- max_total_tokens = provider_info['max_total_tokens']
334
 
335
- # Update the models dropdown
336
- model_dropdown_choices = gr.update(choices=models, value=models[0])
 
337
 
338
- # Update the logo image
339
- logo_html_content = f'<img src="{logo_url}" width="100px" />'
340
- logo_html_update = gr.update(value=logo_html_content)
341
 
342
- # Update the note markdown
343
- note_markdown_update = gr.update(value=f"**Note:** This model is supported by {selected_provider}.")
344
 
345
- # Update the hf_token_input label and placeholder
346
- hf_token_input_update = gr.update(
347
- label=f"Enter your {selected_provider} API token (optional)",
348
- placeholder=f"Enter your {selected_provider} API token to avoid rate limits"
349
- )
350
 
351
- # Reset the chatbot history
352
- chatbot_reset = [] # This resets the chatbot conversation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
 
354
- return model_dropdown_choices, logo_html_update, note_markdown_update, hf_token_input_update, chatbot_message_type, max_total_tokens, chatbot_reset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
 
356
- provider_dropdown.change(
357
- fn=update_provider,
358
- inputs=provider_dropdown,
359
- outputs=[model_dropdown, logo_html, note_markdown, hf_token_input, default_type, default_max_total_tokens,
360
- chatbot],
361
- queue=False
362
- )
363
 
364
- # Function to update the paper info
365
- def update_paper_info(paper_id_value, paper_from_value, selected_model):
366
- if paper_from_value == "neurips":
367
- preamble = fetch_paper_info_neurips(paper_id_value)
368
- text = fetch_paper_content(paper_id_value)
369
- if preamble is None:
370
- preamble = "Paper not found or could not retrieve paper information."
371
- if text is None:
372
- return preamble, None, []
373
- return preamble, text, []
374
- elif paper_from_value == "paper_page":
375
- # Fetch the paper information from Hugging Face API
376
- url = f"https://huggingface.co/api/papers/{paper_id_value}?field=comments"
377
- response = requests.get(url)
378
- if response.status_code != 200:
379
- return "Paper not found or could not retrieve paper information.", None, []
380
- paper_info = response.json()
381
-
382
- # Extract required information
383
- title = paper_info.get('title', 'No Title')
384
- link = f"https://huggingface.co/papers/{paper_id_value}"
385
- authors_list = [author.get('name', 'Unknown') for author in paper_info.get('authors', [])]
386
- authors = ', '.join(authors_list)
387
- summary = paper_info.get('summary', 'No Summary')
388
- num_comments = len(paper_info.get('comments', []))
389
- num_upvotes = paper_info.get('upvotes', 0)
390
-
391
- # Format the preamble
392
- preamble = f"🤗 [paper-page]({link})<br/>"
393
- preamble += f"**Title:** {title}<br/>"
394
- preamble += f"**Authors:** {authors}<br/>"
395
- preamble += f"**Summary:**<br/>>\n{summary}<br/>"
396
- preamble += f"👍{num_comments} 💬{num_upvotes} <br/>"
397
-
398
- # Fetch the paper content
399
- text = fetch_paper_content_arxiv(paper_id_value)
400
- if text is None:
401
- text = "Paper content could not be retrieved."
402
- return preamble, text, []
403
- else:
404
- return "", "", []
405
 
406
- # Update paper content when paper ID changes
407
- paper_id.change(
408
- fn=update_paper_info,
409
- inputs=[paper_id, paper_from, model_dropdown],
410
- outputs=[content, paper_content, chatbot]
411
- )
412
 
413
 
414
  def main():
@@ -416,10 +453,7 @@ def main():
416
  Launches the Gradio app.
417
  """
418
  with gr.Blocks(css_paths="style.css") as demo:
419
- # Create an input for paper_id
420
  paper_id = gr.Textbox(label="Paper ID", value="")
421
-
422
- # Create an input for paper_from (e.g., 'neurips' or 'paper_page')
423
  paper_from = gr.Radio(
424
  label="Paper Source",
425
  choices=["neurips", "paper_page"],
@@ -427,11 +461,24 @@ def main():
427
  )
428
 
429
  # Build the paper chat tab
430
- paper_chat_tab(paper_id, paper_from)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
 
432
  demo.launch(ssr_mode=False)
433
 
434
 
435
- # Run the main function when the script is executed
436
  if __name__ == "__main__":
437
  main()
 
7
  from io import BytesIO
8
  from transformers import AutoTokenizer
9
  import json
10
+ from datetime import datetime
11
  import os
12
  from openai import OpenAI
13
+ import re
14
 
15
  # Cache for tokenizers to avoid reloading
16
  tokenizer_cache = {}
 
24
  "api_key_env_var": "SAMBANOVA_API_KEY",
25
  "models": [
26
  "Meta-Llama-3.1-70B-Instruct",
 
27
  ],
28
  "type": "tuples",
29
  "max_total_tokens": "50000",
 
43
  }
44
 
45
 
46
+ # Functions for paper fetching
47
  def fetch_paper_info_neurips(paper_id):
48
  url = f"https://openreview.net/forum?id={paper_id}"
49
  response = requests.get(url)
50
  if response.status_code != 200:
51
+ return None, None, None
52
 
53
  html_content = response.content
54
  soup = BeautifulSoup(html_content, 'html.parser')
 
73
  else:
74
  abstract = 'Abstract not found'
75
 
76
+ # Construct preamble
77
+ link = f"https://openreview.net/forum?id={paper_id}"
78
+ return title, author_list, f"**Abstract:** {abstract}\n\n[View on OpenReview]({link})"
79
 
 
80
 
81
+ def fetch_paper_content_neurips(paper_id):
 
82
  try:
83
+ url = f"https://openreview.net/pdf?id={paper_id}"
 
 
 
84
  response = requests.get(url)
85
+ response.raise_for_status()
 
 
86
  pdf_content = BytesIO(response.content)
87
  reader = PdfReader(pdf_content)
 
 
88
  text = ""
89
  for page in reader.pages:
90
  text += page.extract_text()
91
+ return text
92
+ except:
 
 
93
  return None
94
 
95
 
96
+ def fetch_paper_content_arxiv(paper_id):
97
  try:
98
+ url = f"https://arxiv.org/pdf/{paper_id}.pdf"
 
 
 
99
  response = requests.get(url)
100
+ response.raise_for_status()
 
 
101
  pdf_content = BytesIO(response.content)
102
  reader = PdfReader(pdf_content)
 
 
103
  text = ""
104
  for page in reader.pages:
105
  text += page.extract_text()
106
+ return text
 
 
107
  except Exception as e:
108
+ print(f"Error fetching paper content: {e}")
109
  return None
110
 
111
 
112
+ def fetch_paper_info_paperpage(paper_id_value):
113
+ # Extract paper_id from paper_page link or input
114
+ def extract_paper_id(input_string):
115
+ # Already in correct form?
116
+ if re.fullmatch(r'\d+\.\d+', input_string.strip()):
117
+ return input_string.strip()
118
+ # If URL
119
+ match = re.search(r'https://huggingface\.co/papers/(\d+\.\d+)', input_string)
120
+ if match:
121
+ return match.group(1)
122
+ return input_string.strip()
123
+
124
+ paper_id_value = extract_paper_id(paper_id_value)
125
+ url = f"https://huggingface.co/api/papers/{paper_id_value}?field=comments"
126
+ response = requests.get(url)
127
+ if response.status_code != 200:
128
+ return None, None, None
129
+ paper_info = response.json()
130
+ title = paper_info.get('title', 'No Title')
131
+ authors_list = [author.get('name', 'Unknown') for author in paper_info.get('authors', [])]
132
+ authors = ', '.join(authors_list)
133
+ summary = paper_info.get('summary', 'No Summary')
134
+ num_comments = len(paper_info.get('comments', []))
135
+ num_upvotes = paper_info.get('upvotes', 0)
136
+ link = f"https://huggingface.co/papers/{paper_id_value}"
137
+
138
+ details = f"{summary}<br/>👍{num_comments} 💬{num_upvotes}<br/> <a href='{link}' " \
139
+ f"target='_blank'>View on 🤗 hugging face</a>"
140
+ return title, authors, details
141
+
142
+
143
+ def fetch_paper_content_paperpage(paper_id_value):
144
+ # Extract paper_id
145
+ def extract_paper_id(input_string):
146
+ if re.fullmatch(r'\d+\.\d+', input_string.strip()):
147
+ return input_string.strip()
148
+ match = re.search(r'https://huggingface\.co/papers/(\d+\.\d+)', input_string)
149
+ if match:
150
+ return match.group(1)
151
+ return input_string.strip()
152
+
153
+ paper_id_value = extract_paper_id(paper_id_value)
154
+ text = fetch_paper_content_arxiv(paper_id_value)
155
+ return text
156
+
157
+
158
+ # Dictionary for paper sources
159
+ PAPER_SOURCES = {
160
+ "neurips": {
161
+ "fetch_info": fetch_paper_info_neurips,
162
+ "fetch_pdf": fetch_paper_content_neurips
163
+ },
164
+ "paper_page": {
165
+ "fetch_info": fetch_paper_info_paperpage,
166
+ "fetch_pdf": fetch_paper_content_paperpage
167
+ }
168
+ }
169
+
170
+
171
  def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_token_input, default_type,
172
  provider_max_total_tokens):
173
  # Define the function to handle the chat
 
 
174
  def get_fn(message, history, paper_content_value, hf_token_value, provider_name_value, model_name_value,
175
  max_total_tokens):
176
  provider_info = PROVIDERS[provider_name_value]
 
179
  models = provider_info['models']
180
  max_total_tokens = int(max_total_tokens)
181
 
182
+ # Load tokenizer
183
  tokenizer_key = f"{provider_name_value}_{model_name_value}"
184
  if tokenizer_key not in tokenizer_cache:
 
 
185
  tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct",
186
  token=os.environ.get("HF_TOKEN"))
187
  tokenizer_cache[tokenizer_key] = tokenizer
 
225
 
226
  # Check if total tokens exceed the maximum allowed tokens
227
  if total_tokens > max_total_tokens:
228
+ # Attempt to truncate context
229
  available_tokens = max_total_tokens - (total_tokens - context_token_length)
230
  if available_tokens > 0:
 
231
  truncated_context_tokens = context_tokens[:available_tokens]
232
  context = tokenizer.decode(truncated_context_tokens)
233
  context_token_length = available_tokens
234
  total_tokens = total_tokens - len(context_tokens) + context_token_length
235
  else:
 
236
  context = ""
237
  total_tokens -= context_token_length
238
  context_token_length = 0
239
 
240
+ # Truncate message history if needed
241
  while total_tokens > max_total_tokens and len(messages) > 1:
 
242
  removed_message = messages.pop(0)
243
  removed_tokens = message_tokens_list.pop(0)
244
  total_tokens -= removed_tokens
245
 
246
+ # Rebuild the final messages
247
  final_messages = []
248
  if context:
249
+ final_messages.append({"role": "system", "content": f"{context}"})
 
250
  final_messages.extend(messages)
251
 
252
  # Use the provider's API key
 
254
  if not api_key:
255
  raise ValueError("API token is not provided.")
256
 
257
+ # Initialize the OpenAI client
258
  client = OpenAI(
259
  base_url=endpoint,
260
  api_key=api_key,
261
  )
262
 
263
  try:
 
264
  completion = client.chat.completions.create(
265
  model=model_name_value,
266
  messages=final_messages,
 
272
  response_text += delta
273
  yield response_text
274
  except json.JSONDecodeError as e:
275
+ yield f"JSON decoding error: {e.msg}"
 
 
 
 
276
  except openai.OpenAIError as openai_err:
277
+ yield f"OpenAI error: {openai_err}"
 
 
278
  except Exception as ex:
279
+ yield f"Unexpected error: {ex}"
 
 
 
 
 
 
 
 
 
 
280
 
281
+ chatbot = gr.Chatbot(label="Chatbot", scale=1, height=400, autoscroll=True)
282
  chat_interface = gr.ChatInterface(
283
  fn=get_fn,
284
  chatbot=chatbot,
 
288
  return chat_interface, chatbot
289
 
290
 
291
+ def paper_chat_tab(paper_id, paper_from, paper_central_df):
292
+ with gr.Row():
293
+ # Left column: Paper selection and display
294
+ with gr.Column(scale=1):
295
+ gr.Markdown("### Select a Paper")
296
+ todays_date = datetime.today().strftime('%Y-%m-%d')
297
 
298
+ # Filter papers for today's date and having a paper_page
299
+ selectable_papers = paper_central_df.df_prettified
300
+ selectable_papers = selectable_papers[
301
+ selectable_papers['paper_page'].notna() &
302
+ (selectable_papers['paper_page'] != "") &
303
+ (selectable_papers['date'] == todays_date)
304
+ ]
305
 
306
+ paper_choices = [(row['title'], row['paper_page']) for _, row in selectable_papers.iterrows()]
307
+ paper_choices = sorted(paper_choices, key=lambda x: x[0])
308
 
309
+ if not paper_choices:
310
+ paper_choices = [("No available papers for today", "")]
 
 
 
311
 
312
+ paper_select = gr.Dropdown(
313
+ label="Select a paper to chat with:",
314
+ choices=[p[0] for p in paper_choices],
315
+ value=paper_choices[0][0] if paper_choices else None
316
+ )
317
+ select_paper_button = gr.Button("Load this paper")
 
 
 
 
 
 
318
 
319
+ # Paper info display - styled card
320
+ content = gr.HTML(value="", elem_id="paper_info_card")
 
 
321
 
322
+ # Right column: Provider and model selection + chat
323
+ with gr.Column(scale=1, visible=False) as provider_section:
324
+ gr.Markdown("### LLM Provider and Model")
325
+ provider_names = list(PROVIDERS.keys())
326
+ default_provider = provider_names[0]
327
 
328
+ default_type = gr.State(value=PROVIDERS[default_provider]["type"])
329
+ default_max_total_tokens = gr.State(value=PROVIDERS[default_provider]["max_total_tokens"])
330
 
331
+ provider_dropdown = gr.Dropdown(
332
+ label="Select Provider",
333
+ choices=provider_names,
334
+ value=default_provider
335
+ )
336
 
337
+ hf_token_input = gr.Textbox(
338
+ label=f"Enter your {default_provider} API token (optional)",
339
+ type="password",
340
+ placeholder=f"Enter your {default_provider} API token to avoid rate limits"
341
+ )
342
 
343
+ model_dropdown = gr.Dropdown(
344
+ label="Select Model",
345
+ choices=PROVIDERS[default_provider]['models'],
346
+ value=PROVIDERS[default_provider]['models'][0]
347
+ )
 
 
348
 
349
+ logo_html = gr.HTML(
350
+ value=f'<img src="{PROVIDERS[default_provider]["logo"]}" width="100px" />'
351
+ )
352
 
353
+ note_markdown = gr.Markdown(f"**Note:** This model is supported by {default_provider}.")
 
 
354
 
355
+ paper_content = gr.State()
 
356
 
357
+ # Create chat interface
358
+ chat_interface, chatbot = create_chat_interface(provider_dropdown, model_dropdown, paper_content,
359
+ hf_token_input, default_type, default_max_total_tokens)
 
 
360
 
361
+ def update_provider(selected_provider):
362
+ provider_info = PROVIDERS[selected_provider]
363
+ models = provider_info['models']
364
+ logo_url = provider_info['logo']
365
+ chatbot_message_type = provider_info['type']
366
+ max_total_tokens = provider_info['max_total_tokens']
367
+
368
+ model_dropdown_choices = gr.update(choices=models, value=models[0])
369
+ logo_html_content = f'<img src="{logo_url}" width="100px" />'
370
+ logo_html_update = gr.update(value=logo_html_content)
371
+ note_markdown_update = gr.update(value=f"**Note:** This model is supported by {selected_provider}.")
372
+ hf_token_input_update = gr.update(
373
+ label=f"Enter your {selected_provider} API token (optional)",
374
+ placeholder=f"Enter your {selected_provider} API token to avoid rate limits"
375
+ )
376
+ chatbot_reset = []
377
+ return model_dropdown_choices, logo_html_update, note_markdown_update, hf_token_input_update, chatbot_message_type, max_total_tokens, chatbot_reset
378
+
379
+ provider_dropdown.change(
380
+ fn=update_provider,
381
+ inputs=provider_dropdown,
382
+ outputs=[model_dropdown, logo_html, note_markdown, hf_token_input, default_type, default_max_total_tokens,
383
+ chatbot],
384
+ queue=False
385
+ )
386
 
387
+ def update_paper_info(paper_id_value, paper_from_value, selected_model, old_content):
388
+ # Use PAPER_SOURCES to fetch info
389
+ source_info = PAPER_SOURCES.get(paper_from_value, {})
390
+ fetch_info_fn = source_info.get("fetch_info")
391
+ fetch_pdf_fn = source_info.get("fetch_pdf")
392
+
393
+ if not fetch_info_fn or not fetch_pdf_fn:
394
+ return gr.update(value="<div>No information available.</div>"), None, []
395
+
396
+ title, authors, details = fetch_info_fn(paper_id_value)
397
+ if title is None and authors is None and details is None:
398
+ return gr.update(value="<div>No information could be retrieved.</div>"), None, []
399
+
400
+ text = fetch_pdf_fn(paper_id_value)
401
+ if text is None:
402
+ text = "Paper content could not be retrieved."
403
+
404
+ # Create a styled card for the paper info
405
+ card_html = f"""
406
+ <div style="border:1px solid #ccc; border-radius:6px; background:#f9f9f9; padding:15px; margin-bottom:10px;">
407
+ <center><h3 style="margin-top:0; text-decoration:underline;">You are talking with:</h3></center>
408
+ <h3>{title}</h3>
409
+ <p><strong>Authors:</strong> {authors}</p>
410
+ <p>{details}</p>
411
+ </div>
412
+ """
413
+
414
+ return gr.update(value=card_html), text, []
415
+
416
+ def select_paper(paper_title):
417
+ # Find the corresponding paper_page from the title
418
+ for t, ppage in paper_choices:
419
+ if t == paper_title:
420
+ return ppage, "paper_page"
421
+ return "", ""
422
+
423
+ select_paper_button.click(
424
+ fn=select_paper,
425
+ inputs=[paper_select],
426
+ outputs=[paper_id, paper_from]
427
+ )
428
 
429
+ # After updating paper_id, we update paper info
430
+ paper_id.change(
431
+ fn=update_paper_info,
432
+ inputs=[paper_id, paper_from, model_dropdown, content],
433
+ outputs=[content, paper_content, chatbot]
434
+ )
 
435
 
436
+ # Function to toggle visibility of the right column based on paper_id
437
+ def toggle_provider_visibility(paper_id_value):
438
+ if paper_id_value and paper_id_value.strip():
439
+ return gr.update(visible=True)
440
+ else:
441
+ return gr.update(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
442
 
443
+ # Chain a then call to toggle visibility of the provider_section after paper info update
444
+ paper_id.change(
445
+ fn=toggle_provider_visibility,
446
+ inputs=[paper_id],
447
+ outputs=[provider_section]
448
+ )
449
 
450
 
451
  def main():
 
453
  Launches the Gradio app.
454
  """
455
  with gr.Blocks(css_paths="style.css") as demo:
 
456
  paper_id = gr.Textbox(label="Paper ID", value="")
 
 
457
  paper_from = gr.Radio(
458
  label="Paper Source",
459
  choices=["neurips", "paper_page"],
 
461
  )
462
 
463
  # Build the paper chat tab
464
+ dummy_calendar = gr.State(datetime.now().strftime("%Y-%m-%d"))
465
+
466
+ class MockPaperCentral:
467
+ def __init__(self):
468
+ import pandas as pd
469
+ data = {
470
+ 'date': [datetime.today().strftime('%Y-%m-%d')],
471
+ 'paper_page': ['1234.56789'],
472
+ 'title': ['An Example Paper']
473
+ }
474
+ self.df_prettified = pd.DataFrame(data)
475
+
476
+ paper_central_df = MockPaperCentral()
477
+
478
+ paper_chat_tab(paper_id, paper_from, paper_central_df)
479
 
480
  demo.launch(ssr_mode=False)
481
 
482
 
 
483
  if __name__ == "__main__":
484
  main()