quantaji's picture
add post-process file
21a2300
raw
history blame
7.58 kB
import tiktoken
from typing import Dict, Tuple, List
def slide_generation(res, num_tokens_limit=1800):
    """Greedily merge consecutive text pieces into chunks under a token budget.

    The first piece always starts the first chunk. Each following non-empty
    piece is appended to the current chunk when the token count of the
    current chunk plus the token count of the piece is strictly below
    ``num_tokens_limit``; otherwise the piece starts a new chunk. Falsy
    pieces (e.g. empty strings) after the first are skipped.

    Args:
        res: Sequence of text pieces, in order.
        num_tokens_limit: Exclusive upper bound on the summed token counts
            of a chunk and the piece being appended to it.

    Returns:
        A new list of merged chunks; an empty list when ``res`` is empty.
    """
    if not res:
        # Nothing to merge: return early instead of failing on res[0].
        return []

    merged = [res[0]]
    for piece in res[1:]:
        if not piece:
            continue
        # The current chunk is re-measured each time because it grows as
        # pieces are appended to it.
        combined = get_num_tokens(merged[-1]) + get_num_tokens(piece)
        if combined < num_tokens_limit:
            merged[-1] += piece
        else:
            merged.append(piece)
    return merged
def slide_generation_ver2(res, num_tokens_limit=1800):
    """Re-split generated slides on ``[PE]`` and merge them under a token budget.

    The pieces in ``res`` are joined with newlines and split at every
    ``[PE]`` marker. Each segment that has content is stripped and given a
    trailing ``"\\n[PE]\\n"``; the segments are then merged by
    ``slide_generation``.

    Args:
        res: Iterable of slide text pieces.
        num_tokens_limit: Token budget forwarded to ``slide_generation``.

    Returns:
        The list of merged chunks produced by ``slide_generation``.
    """
    segments = "\n".join(res).split("[PE]")
    # Test the *stripped* segment: the text after the final "[PE]" is usually
    # just a newline, and must not be turned into a stray "[PE]" marker.
    # Empty strings are skipped by slide_generation.
    pages = [
        (segment.strip() + "\n[PE]\n") if segment.strip() else ""
        for segment in segments
    ]
    return slide_generation(pages, num_tokens_limit=num_tokens_limit)
def parse_prompt(file: str, data: List[str] = None):
    """Parse a role-tagged prompt template file into a list of chat messages.

    The file is read line by line:

    * blank lines and any line containing ``#`` (anywhere in the line) are
      skipped;
    * a line containing ``[user]``, ``[assistant]`` or ``[system]`` starts a
      new message with that role (the rest of that line is not kept);
    * every other line is stripped and added to the current message; the
      lines of one message are joined with single spaces.

    In each message the two-character sequence backslash + ``n`` is turned
    into a real newline, then the placeholders ``[data_tag_0]``,
    ``[data_tag_1]``, ... are replaced, in increasing order, by the matching
    entries of ``data``. A message may contain several consecutive tags.

    Args:
        file: Path of the prompt template file.
        data: Replacement strings for the data tags; ``None`` means no data.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts in file order.

    Raises:
        AssertionError: if content appears before any role marker, the file
            defines no role, the first role is not ``user``/``system``, later
            roles are not alternating ``user``/``assistant``, or the number
            of substituted tags does not match ``len(data)``.
    """
    data = [] if data is None else data
    # Checked in this order for every line, mirroring an if/elif chain.
    role_markers = (
        ("[user]", "user"),
        ("[assistant]", "assistant"),
        ("[system]", "system"),
    )

    roles = []
    contents = []
    # Context manager so the handle is closed even if an assertion fires.
    with open(file, "r") as handle:
        for line in handle:
            stripped = line.strip()
            # if line is empty or a comment, skip
            if "#" in line or not stripped:
                continue
            role = next(
                (name for marker, name in role_markers if marker in line),
                None,
            )
            if role is not None:
                roles.append(role)
                contents.append([])
                continue
            assert roles, "No role specified"
            contents[-1].append(stripped)

    # checking roles
    assert roles, "No role specified"
    assert roles[0] in ["user", "system"], "First role must be user or system"
    for previous, current in zip(roles, roles[1:]):
        assert current in ["user", "assistant"], "Roles must be user or assistant"
        assert current != previous, "Roles must alternate between user and assistant"

    messages = []
    curr_idx = 0
    for role, lines in zip(roles, contents):
        # replace \n with newline
        text = " ".join(lines).replace("\\n", "\n")
        tag = f"[data_tag_{curr_idx}]"
        # Loop rather than a single check so that several tags inside the
        # same message are all substituted.
        while tag in text:
            assert curr_idx < len(data), "Not enough data for the data tags in the prompt"
            text = text.replace(tag, data[curr_idx])
            curr_idx += 1
            tag = f"[data_tag_{curr_idx}]"
        messages.append({"role": role, "content": text})

    assert curr_idx == len(data), "Not all data tags were replaced"
    return messages
def clean_slides(slide):
    """Drop every line of ``slide`` that does not begin with a slide tag.

    A line is kept only when it starts (at column 0, no leading whitespace)
    with ``[F]``, ``[T]``, ``[PB]`` or ``[PE]``. Lines starting with
    ``[T][T]`` are kept as well, since they also start with ``[T]``.

    Args:
        slide: Slide text with one tagged entry per line.

    Returns:
        The kept lines, in order, joined with newlines.
    """
    kept_prefixes = ('[F]', '[T]', '[PB]', '[PE]')
    return '\n'.join(
        row for row in slide.split('\n') if row.startswith(kept_prefixes)
    )
# Characters with a special meaning in LaTeX text, mapped to an escaped form.
# "^" needs the empty group: a bare "\^" is an accent command that would
# swallow the following character.
_LATEX_ESCAPES = str.maketrans({
    "_": "\\_",
    "&": "\\&",
    "$": "\\$",
    "%": "\\%",
    "#": "\\#",
    "^": "\\^{}",
})


def _escape_latex(text):
    """Escape the characters listed in ``_LATEX_ESCAPES`` in a single pass."""
    return text.translate(_LATEX_ESCAPES)


def _parse_page_items(content):
    """Group the tagged lines of one page as ``[[point, subpoint, ...], ...]``."""
    points = []
    for item in content.split("\n"):
        if not item.strip():
            # Skip blank lines instead of discarding the rest of the page.
            continue
        if "[T][T]" in item:
            assert len(points) > 0, "Subpoint cannot be the first item in a page"
            # maxsplit=1 keeps the whole text even if it contains the tag again.
            points[-1].append(item.split("[T][T]", 1)[1])
        elif "[T]" in item:
            points.append([item.split("[T]", 1)[1]])
        else:
            # Untagged lines are treated as top-level points.
            points.append([item])
    return points


def _render_itemize(points):
    """Render grouped points as a (possibly nested) Beamer ``itemize`` block."""
    lines = ["\\begin{itemize}\n"]
    for point in points:
        lines.append(f"\\item {_escape_latex(point[0])}\n")
        if len(point) > 1:
            lines.append("\\begin{itemize}\n")
            lines.extend(f"\\item {_escape_latex(sub)}\n" for sub in point[1:])
            lines.append("\\end{itemize}\n")
    lines.append("\\end{itemize}\n")
    return "".join(lines)


def generate_latex_slide(slide, output_path=None):
    """Convert tagged slide text into a LaTeX Beamer document.

    Pages are the pieces of ``slide`` that follow each ``[PB]`` marker. Within
    a page, the first line is the frame title and the following lines, up to
    ``[PE]`` (or the end of the page when ``[PE]`` is missing), are items:
    ``[T]`` lines are top-level points, ``[T][T]`` lines are sub-points of
    the preceding point, and untagged lines are top-level points. Pages with
    no title text are skipped; a page with a title but no items yields an
    empty frame. ``\\end{document}`` is always emitted, even without pages.

    The characters ``_ & $ % # ^`` are escaped in titles and items.

    Args:
        slide: Tagged slide text.
        output_path: If truthy, the LaTeX source is also written to this path.

    Returns:
        The generated LaTeX source as a string.

    Raises:
        AssertionError: if a page's first item is a ``[T][T]`` sub-point.
    """
    # Initialize the Beamer document
    parts = ["\\documentclass{beamer} \n\\begin{document}"]

    # Split the slide string into pages
    pages = slide.split('[PB]')[1:]

    for i, page in enumerate(pages):
        page = page.strip()
        print(i, page)

        # Everything from the end marker onwards is not part of the page.
        body = page.split("[PE]", 1)[0]
        title_line, _, content = body.partition("\n")
        title = title_line.strip()
        if not title:
            # The page was stripped, so an empty title means the page holds
            # nothing before "[PE]": skip it.
            continue

        # Start a new frame with the page title
        parts.append(f"\n\\begin{{frame}}{{{_escape_latex(title)}}}\n\n")

        points = _parse_page_items(content.strip())
        if points:
            parts.append(_render_itemize(points))

        # End the frame
        parts.append("\n\\end{frame}\n")

    parts.append("\n\\end{document}")
    latex_code = "".join(parts)

    if output_path:
        with open(output_path, 'w') as f:
            f.write(latex_code)
    return latex_code
def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"):
    """
    Count the tokens needed to encode a list of chat messages.

    Each message contributes a fixed overhead of 4 tokens plus the encoded
    length of every value in the message dict; a ``name`` key subtracts one
    token. A final 2 tokens account for the primed assistant reply.

    source: https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/chatgpt?pivots=programming-language-chat-completions#managing-conversations
    """
    encoder = tiktoken.encoding_for_model(model)
    # every reply is primed with <im_start>assistant
    total = 2
    for msg in messages:
        # every message follows <im_start>{role/name}\n{content}<im_end>\n
        total += 4
        for field, text in msg.items():
            total += len(encoder.encode(text))
            if field == "name":
                # if there's a name, the role is omitted;
                # role is always required and always 1 token
                total -= 1
    return total
def get_num_tokens(message, model="gpt-3.5-turbo-0301"):
    """Return the number of tokens ``message`` encodes to.

    The tokenizer is the one ``tiktoken.encoding_for_model`` returns for
    ``model``.
    """
    tokenizer = tiktoken.encoding_for_model(model)
    return len(tokenizer.encode(message))
def get_paper_text_in_chunks(example, chunk_size=4000):
    """Serialize a paper into marker-delimited text chunks of bounded token size.

    The first chunk starts with the title (``[TB] ... [TE]``) and abstract
    (``[AB] ... [AE]``). Sections (``[SB] <n> <section> [SC] <text> [SE]``)
    and then figure captions (``[FB] ... [FE]``) are appended in order; when
    adding a piece would push the running token count above ``chunk_size``,
    the current chunk is closed and the piece starts a new one.

    The running count is the sum of the individual pieces' token counts; it
    is not re-measured on the concatenated text.

    Args:
        example: Mapping read as ``example['title']`` and
            ``example['paper']`` with keys ``'abstract'``, ``'text'``
            (items indexed by position, each with a ``'string'``),
            ``'headers'`` (each with ``'n'``, ``'section'``, ``'start'``,
            ``'end'``; ``end`` is inclusive and clamped to the text length)
            and ``'figures'`` (each with a ``'caption'``).
        chunk_size: Maximum running token count per chunk.

    Returns:
        The list of chunk strings.

    Raises:
        AssertionError: if a single section or figure caption has
            ``chunk_size`` tokens or more.
    """
    paper = example['paper']
    paragraphs = paper['text']
    paper_length = len(paragraphs)

    title = '[TB] ' + example['title'] + ' [TE] '
    abstract = '[AB] ' + paper['abstract'] + ' [AE] '

    sections = []
    for head in paper['headers']:
        # 'end' is inclusive; clamp so it never runs past the paper text.
        stop = min(head['end'] + 1, paper_length)
        body = ' '.join(
            paragraphs[idx]['string'] for idx in range(head['start'], stop)
        )
        sections.append(
            ' [SB] ' + head['n'] + ' ' + head['section'] + ' [SC] ' + body + ' [SE] '
        )
    figures = [' [FB] ' + fig['caption'] + ' [FE] ' for fig in paper['figures']]

    chunks = []
    temp_chunk = title + abstract
    temp_chunk_length = get_num_tokens(temp_chunk)
    for piece in sections + figures:
        # Tokenize each piece once and reuse the count below.
        piece_length = get_num_tokens(piece)
        assert piece_length < chunk_size, "Section or figure is too long to fit in a chunk"
        if temp_chunk_length + piece_length > chunk_size:
            chunks.append(temp_chunk)
            temp_chunk = piece
            temp_chunk_length = piece_length
        else:
            temp_chunk += piece
            temp_chunk_length += piece_length
    if temp_chunk_length > 0:
        chunks.append(temp_chunk)
    return chunks