king17pvp's picture
Upload folder using huggingface_hub
995aaf9 verified
raw
history blame
6.59 kB
import requests
import nltk
import random
import json
import os
import pickle
import re
nltk.download('punkt')
# Directory containing this module, using forward slashes and ending in "/",
# so `filepath + "data/..."` resolves next to this file whatever the module's
# file name is (the previous `.replace("utils.py", "")` only worked for a file
# literally named utils.py).
filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "").replace("\\", "/")
# NOTE(review): pickle.load can execute arbitrary code -- only load a
# hf_tokens.pkl you produced yourself. Presumably a list of HF API token strings.
with open(filepath + "data/hf_tokens.pkl", "rb") as f:
    hf_tokens = pickle.load(f)
MAX_TOKEN_LENGTH = 4096
# Character budget per transcript chunk (compared against len() of the chunk string).
MAX_CHUNK_SIZE = 16000
API_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
def prompt_template(prompt, sys_prompt=""):
    """Wrap a user prompt and optional system prompt in Llama 3 chat markup.

    The returned string ends with an open assistant header, so the model
    continues with its reply.

    Both texts are inserted verbatim in a single pass. Placeholder-looking
    text inside ``prompt`` (for example a literal ``<system_prompt>``) is
    therefore left alone instead of being substituted a second time.
    """
    return (
        "<|begin_of_text|>"
        f"<|start_header_id|>system<|end_header_id|>\n\n{sys_prompt}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
def query(payload: dict, hf_token: str, timeout: float = 300.0):
    """POST ``payload`` as JSON to ``API_URL`` and return the decoded JSON reply.

    Args:
        payload: JSON-serializable request body (``inputs`` plus ``parameters``).
        hf_token: Hugging Face API token, sent as a Bearer token.
        timeout: Seconds to wait on the connection, passed to ``requests.post``.
            Without a timeout, a stalled server would block the caller forever.

    Returns:
        The parsed JSON response. In this module, a dict with an ``"error"``
        key signals failure, and a list of ``{"generated_text": ...}`` dicts
        signals success.

    Raises:
        requests.exceptions.Timeout: if the server does not respond in time.
        ValueError (subclass): if the response body is not valid JSON.
    """
    headers = {"Authorization": f"Bearer {hf_token}"}
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    return response.json()
def gen_prompt(prompt: str, sys_prompt: str = ""):
    """Return the LLM's completion for ``prompt``, trying each HF token in turn.

    Each token in ``hf_tokens`` is tried in order with the real generation
    request (up to 512 new tokens). The first response that contains
    ``generated_text`` is used.

    Returns:
        Only the newly generated text. This assumes the response echoes the
        full input prompt at the start of ``generated_text``, which is sliced off.

    Raises:
        RuntimeError: if no token yields a usable response (for example every
            token is rate-limited, or ``hf_tokens`` is empty).
    """
    input_prompt = prompt_template(prompt, sys_prompt)
    payload = {
        "inputs": input_prompt,
        "parameters": {"max_new_tokens": 512},
    }
    last_response = None
    # The real request doubles as the token health check. This avoids a
    # separate probe call per token, and a failure on one token simply falls
    # through to the next instead of crashing on the error payload.
    for token in hf_tokens:
        output = query(payload, token)
        if (
            isinstance(output, list)
            and output
            and isinstance(output[0], dict)
            and "generated_text" in output[0]
        ):
            return output[0]["generated_text"][len(input_prompt):]
        last_response = output
    raise RuntimeError(
        f"No Hugging Face token produced a completion; last response: {last_response!r}"
    )
class Node:
    """One vertex of a summary tree.

    Holds a text ``summary``, an ordered ``children`` list, and a ``parent``
    back-reference (``None`` until the node is attached under another one).
    """

    def __init__(self, summary=None):
        self.summary, self.children, self.parent = summary, [], None

    def add_child(self, child_node):
        """Append ``child_node`` as the last child and make this node its parent."""
        self.children.append(child_node)
        child_node.parent = self
class MemWalker:
    """Build a binary tree of LLM summaries over a sequence of transcript segments.

    Leaves summarize individual segments. Each internal node summarizes its
    two children's summaries joined with ", ". Once built, the top node is
    stored in ``self.root``.
    """

    def __init__(self, segments):
        self.segments = segments
        # Sentinel meaning "no tree yet". It is replaced by a Node once
        # build_memory_tree() has at least one segment to summarize.
        self.root = 0

    def build_memory_tree(self):
        """Summarize every segment, then merge summaries pairwise up to one root.

        Each level pairs adjacent nodes (0, 1), (2, 3), and so on. An odd node
        left over at the end of a level is promoted unchanged. With no
        segments there is nothing to summarize, so ``self.root`` is left
        as-is instead of raising IndexError.
        """
        # Step 1: one leaf per segment (sum_type 0 = summarize transcript text).
        level = [Node(summarize(seg, 0)) for seg in self.segments]
        if not level:
            return
        # Step 2: merge adjacent pairs until a single node remains.
        while len(level) > 1:
            next_level = []
            for i in range(0, len(level), 2):
                if i + 1 < len(level):
                    left, right = level[i], level[i + 1]
                    # sum_type 1 = compress already-written summaries.
                    parent = Node(summarize(left.summary + ", " + right.summary, 1))
                    parent.add_child(left)
                    parent.add_child(right)
                else:
                    # Odd node out: carry it up to the next level unchanged.
                    parent = level[i]
                next_level.append(parent)
            level = next_level
        self.root = level[0]
# Summarization helpers that query the LLM via gen_prompt()
def summarize(text, sum_type: int = 1):
    """Ask the LLM, acting as a meeting-minutes writer, to summarize ``text``.

    Args:
        text: With ``sum_type=0``, raw meeting transcript. With ``sum_type=1``,
            existing summaries to compress.
        sum_type: 0 for a concise summary of at most 5 sentences; 1 to
            compress summaries into a much shorter one.

    Returns:
        The model's reply. If the reply contains several blank-line-separated
        paragraphs, only the second is returned. Presumably the first is a
        preamble such as "Here is a summary:".

    Raises:
        ValueError: if ``sum_type`` is not 0 or 1.
    """
    # Explicit check: an ``assert`` would be stripped under ``python -O``.
    if sum_type not in (0, 1):
        raise ValueError(f"sum_type should be either 0 or 1, got {sum_type!r}")
    if sum_type == 0:
        user_prompt = "Write a concise summary of the meeting transcript in maximum 5 sentences:" + "\n\n" + text
    else:
        user_prompt = "Compress the following summaries into a much shorter summary: " + "\n\n" + text
    sys_prompt = "Act as a professional technical meeting minutes writer."
    reply = gen_prompt(user_prompt, sys_prompt)
    paragraphs = reply.split("\n\n")
    return reply if len(paragraphs) == 1 else paragraphs[1]
def split_chunk(transcript: str, max_chunk_size=None, overlap: int = 10):
    """Split a transcript into sentence-aligned, overlapping chunks.

    Sentences come from ``nltk.sent_tokenize``. A sentence is added to the
    current chunk while the chunk plus that sentence stays under
    ``max_chunk_size`` characters. When a sentence does not fit, the current
    chunk is closed. The next chunk then starts with up to ``overlap``
    preceding sentences plus the sentence that did not fit, so adjacent
    chunks share context.

    Args:
        transcript: Full transcript text.
        max_chunk_size: Character budget per chunk. ``None`` means the
            module-level ``MAX_CHUNK_SIZE``. The overlap is always carried
            over, so a chunk can exceed the budget when sentences are very long.
        overlap: Number of preceding sentences repeated at the start of a new chunk.

    Returns:
        List of chunk strings; every sentence is followed by one space.
        An empty transcript yields ``[""]``.
    """
    if max_chunk_size is None:
        max_chunk_size = MAX_CHUNK_SIZE
    sentences = nltk.sent_tokenize(transcript)
    chunks = []
    current_chunk = ""
    for idx, sentence in enumerate(sentences):
        if len(current_chunk + sentence) < max_chunk_size:
            current_chunk += sentence + " "
            continue
        # Close the chunk only if it has content. If the very first sentence
        # is already over budget there is nothing to emit yet, and emitting
        # would produce an empty chunk.
        if current_chunk:
            chunks.append(current_chunk)
        # Clamp at 0: a negative start index would wrap around and pull
        # sentences from the END of the transcript into the overlap.
        start = max(0, idx - overlap)
        current_chunk = "".join(s + " " for s in sentences[start:idx + 1])
    chunks.append(current_chunk)
    return chunks
def summarize_three_ways(chunks: list[str]):
    """Summarize transcript chunks with a refine chain and return three variants.

    The first chunk is summarized directly. Every later chunk asks the LLM to
    refine the previous summary using that chunk's text. Finally, all partial
    summaries are joined and rewritten for coherency.

    Returns:
        dict with keys
            ``'truncated'``:  summary of the first chunk only,
            ``'accumulate'``: the last summary, refined through every chunk,
            ``'rewrite'``:    the LLM rewrite of all partial summaries joined.

    Raises:
        ValueError: if ``chunks`` is empty (checked before the rewrite call).
    """
    SYS_PROMPT = "Act as a professional technical meeting minutes writer."
    # Templates are filled with str.format in a single pass. Text that
    # happens to contain "{text}" (for example a previous summary) is
    # therefore inserted verbatim and never substituted again.
    PROMPT_TEMPLATE = "Write a concise summary of the meeting transcript in maximum 5 sentences:" + "\n\n" + "{text}"
    REFINE_TEMPLATE = (
        "Your job is to produce a final summary\n"
        "We have provided an existing summary up to a certain point: {existing_answer}\n"
        "We have the opportunity to refine the existing summary "
        "(only if needed) with some more context below.\n"
        "------------\n"
        "{text}\n"
        "------------\n"
        "Given the new context, refine the original summary in English within 5 sentences. If the context isn't useful, return the original summary."
    )

    def _strip_preamble(reply):
        # Heuristic: with several blank-line-separated paragraphs, keep only
        # the second (presumably the first is a preamble like "Here is ...:").
        paragraphs = reply.split("\n\n")
        return paragraphs[1] if len(paragraphs) > 1 else reply

    partial_sum = []
    for chunk in chunks:
        if not partial_sum:
            cur_prompt = PROMPT_TEMPLATE.format(text=chunk)
        else:
            cur_prompt = REFINE_TEMPLATE.format(existing_answer=partial_sum[-1], text=chunk)
        partial_sum.append(_strip_preamble(gen_prompt(cur_prompt, SYS_PROMPT)))

    if not partial_sum:
        raise ValueError("summarize_three_ways() needs at least one chunk")

    rewrite_prompt = "Rewrite the following text by maintaining coherency: " + "\n\n" + " ".join(partial_sum)
    final_sum = _strip_preamble(gen_prompt(rewrite_prompt, SYS_PROMPT))
    return {
        "truncated": partial_sum[0],
        "accumulate": partial_sum[-1],
        "rewrite": final_sum,
    }
def get_example(indices=(1, 2, 9, 13)) -> list[str]:
    """Load selected example transcripts from ``data/test.json``.

    The file is read as JSON Lines (one JSON object per line), and blank
    lines are skipped. For each requested record position, the record's
    ``transcript`` is sentence-split with ``nltk.sent_tokenize`` and
    re-joined with newlines.

    Args:
        indices: Record positions to return, in order. Defaults to the fixed
            records 1, 2, 9 and 13.

    Returns:
        One newline-separated transcript string per index.
    """
    records = []
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(filepath + "data/test.json", "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                records.append(json.loads(line))
    return ['\n'.join(nltk.sent_tokenize(records[i]['transcript'])) for i in indices]
if __name__ == "__main__":
    # Disabled exploration snippet. It would print the transcript length of
    # the first 100 records in data/test.json, but it is kept as a bare string
    # literal, so running this module as a script does nothing here.
    '''data = []
with open(filepath + "data/test.json", "r") as f:
for line in f:
data.append(json.loads(line))
tmp = data[:100]
for j, i in enumerate(tmp):
print(j, len(i['transcript']))'''