|
import requests |
|
import os |
|
import gradio as gr |
|
import time |
|
import heapq |
|
import re |
|
from utils import package_installer |
|
|
|
# Install runtime dependencies that may be missing from the environment,
# in the same order the module needs them below.
for _package in ('sentence_transformers', 'nbconvert', 'inflect'):
    package_installer(_package)
|
|
|
|
|
|
|
from sentence_transformers import SentenceTransformer, util |
|
import nbconvert |
|
import nbformat |
|
from bs4 import BeautifulSoup |
|
from inflect import engine |
|
|
|
# Shared inflect engine, used to spell cell numbers out as English words.
inflect_engine = engine()
|
|
|
def convert_ipynb_to_html(input_file):
    """Render a Jupyter notebook as an HTML document string.

    Args:
        input_file: Path (or file-like object) of the .ipynb notebook.

    Returns:
        The HTML body produced by nbconvert's HTMLExporter.
    """
    notebook = nbformat.read(input_file, as_version=4)

    html_exporter = nbconvert.HTMLExporter()
    # from_notebook_node also returns a resources dict, which is unused here.
    body, _ = html_exporter.from_notebook_node(notebook)

    return body
|
|
|
|
|
# Hugging Face API token; raises KeyError at import time if HF_TOKEN is unset.
API_TOKEN = os.environ['HF_TOKEN']



# Serverless Inference API endpoint for the Phi-3-mini instruct model.
API_URL = "https://api-inference.huggingface.co/models/microsoft/Phi-3-mini-4k-instruct"

headers = {"Authorization": f"Bearer {API_TOKEN}"}



# Sentence-embedding model used for retrieval over notebook cells and chat history.
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
|
|
def user_prompt_template(user_msg: str) -> str:
    """Wrap a raw user message in Phi-3 chat-template control tokens."""
    return "<|user|>\n" + user_msg + "<|end|>\n<|assistant|>"
|
|
|
def assistant_response_template(assistant_msg: str) -> str:
    """Append the Phi-3 end-of-turn token to an assistant reply."""
    return "".join((assistant_msg, "<|end|>\n"))
|
|
|
|
|
def query(payload):
    """POST *payload* to the HF Inference API and return the decoded JSON reply."""
    api_response = requests.post(API_URL, json=payload, headers=headers, timeout=120)
    return api_response.json()
|
|
|
|
|
|
|
def _ordinal_suffix(n: int) -> str:
    """English ordinal suffix for *n* ('st'/'nd'/'rd'/'th'); 11-13 take 'th'."""
    if 10 <= n % 100 <= 13:
        return 'th'
    return {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')


def _extract_notebook_cells(notebook_file):
    """Return the text of each notebook cell, via an HTML render of the notebook.

    Consecutive non-empty lines of the rendered text form one "cell"; the
    nbconvert boilerplate entries ('Notebook', empty 'In [ ]:' prompts) are
    dropped.
    """
    html_code = convert_ipynb_to_html(notebook_file)
    text = BeautifulSoup(html_code, 'html.parser').get_text()

    cells, buffer = [], ''
    for line in text.split('\n'):
        if line:
            buffer += line + '\n'
        elif buffer:
            cells.append(buffer)
            buffer = ''
    return [cell for cell in cells if cell not in ('Notebook\n', 'In\xa0[\xa0]:\n')]


def _index_cells(cells):
    """Prefix each cell with address comments; return (indexed_cells, headers).

    Only the headers get embedded for retrieval, so the cell number is
    written three ways (ordinal, digit, spelled-out words) to maximise the
    chance of matching however the user refers to a cell.
    """
    indexed_cells, headers = [], []
    for number, cell_text in enumerate(cells, start=1):
        # NOTE: the original `i+1 % 10 == 1` parsed as `(i+1) == 1`, so only
        # cells 1-3 ever got st/nd/rd and 11-13 were wrong; fixed here.
        header = (f'# {number}{_ordinal_suffix(number)} cell\n'
                  f'# Cell Number: {number}\n'
                  f'# Cell Number: {inflect_engine.number_to_words(number)}\n')
        indexed_cells.append(header + cell_text)
        headers.append(header)
    return indexed_cells, headers


def _top_k_indices(query_embedding, texts, k):
    """Indices (ascending) of the *k* texts most similar to *query_embedding*."""
    text_embeddings = embedding_model.encode(texts, convert_to_tensor=True)
    scores = util.cos_sim(query_embedding, text_embeddings)[0].tolist()
    # nlargest over (index, score) pairs avoids the O(n^2), tie-breaking-buggy
    # `list(scores).index(score)` lookup of the original.
    best = heapq.nlargest(k, enumerate(scores), key=lambda pair: pair[1])
    return sorted(index for index, _ in best)


def chat(message, history):
    """Gradio ChatInterface callback.

    Builds a prompt from (a) up to two past chat turns most similar to the
    new message and (b) up to five notebook cells most relevant to it (when
    the message carries an .ipynb upload), then queries the hosted Phi-3
    model and returns its reply.

    Args:
        message: Multimodal gradio message dict with 'text' and 'files' keys.
        history: List of [user, assistant] pairs maintained by gradio.

    Returns:
        The model's reply string, or a diagnostic message on API errors.
    """
    formatted_user_msg = user_prompt_template(message['text'])

    # Drop pure file-upload turns ((filepath,), None) from the history.
    usable_history = [pair for pair in history
                      if not (None in pair and isinstance(pair[0], tuple))]

    # Best-effort notebook context: an absent or unreadable notebook simply
    # means no cell context.  (The original left these names undefined on
    # failure and crashed with a NameError further down.)
    indexed_cells, cell_headers = [], []
    if message.get('files'):
        try:
            cells = _extract_notebook_cells(message['files'][0])
            indexed_cells, cell_headers = _index_cells(cells)
        except Exception:
            pass

    emb_user_msg = embedding_model.encode(formatted_user_msg, convert_to_tensor=True)

    # Rank cells by the similarity of their address headers to the message,
    # but include the full cell text in the prompt.
    top_cells = []
    if cell_headers:
        top_cells = [indexed_cells[i]
                     for i in _top_k_indices(emb_user_msg, cell_headers, 5)]

    # Retrieve the two most relevant past turns, kept in chronological order.
    similar_chat_history = ''
    if usable_history:
        past_turns = [user_prompt_template(user_turn) + assistant_response_template(bot_turn)
                      for user_turn, bot_turn in usable_history]
        similar_chat_history = ''.join(
            past_turns[i] for i in _top_k_indices(emb_user_msg, past_turns, 2))

    context_plus_message = '\n'.join(top_cells) + message['text']
    user_input = similar_chat_history + user_prompt_template(context_plus_message)

    payload = {"inputs": user_input,
               "parameters": {"max_new_tokens": 750, "temperature": 0.01}}
    output = query(payload)

    try:
        generated_text = output[0]['generated_text']
        formatted_assistant_msg = generated_text.replace(user_input, '').strip().removesuffix('<|end|>')
    except (KeyError, IndexError, TypeError):
        # The API returns a dict (e.g. {"error": ...}) instead of a list of
        # generations on failure; surface whatever came back for debugging.
        if isinstance(output, dict):
            formatted_assistant_msg = f"Error has occured, type of output is {type(output)} and items of output are: {output.items()}"
        else:
            formatted_assistant_msg = f"Error has occured, type of output is {type(output)} and length of output is: {len(output)}"

    # Debug logging of the final prompt and the extracted, indexed cells.
    print(user_input)
    print()
    print(indexed_cells)
    return formatted_assistant_msg
|
|
|
# Multimodal chat UI: accepts text plus optional .ipynb file uploads.
demo = gr.ChatInterface(chat, multimodal=True)
|
|
|
# Launch the Gradio app only when executed as a script (not on import).
if __name__ == '__main__':

    demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|