| | import gradio as gr |
| | import pandas as pd |
| | from tqdm import tqdm |
| | from langchain.docstore.document import Document as LangchainDocument |
| | from langchain.text_splitter import RecursiveCharacterTextSplitter |
| | from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
| | from langchain_community.vectorstores import FAISS |
| | from langchain_community.embeddings import HuggingFaceEmbeddings |
| | from langchain_community.vectorstores.utils import DistanceStrategy |
| | import torch |
| | import matplotlib.pyplot as plt |
| | from typing import Optional, List |
| | from tqdm import tqdm |
| | from langchain_community.vectorstores import FAISS |
| |
|
| | |
| |
|
| |
|
| | |
# Show full column contents when printing DataFrames (useful when inspecting
# long document chunks interactively).
pd.set_option("display.max_colwidth", None)

# Knowledge-base source files: team info, per-match summaries, player info.
_SOURCE_FILES = [
    "iplteams_info.txt",
    "match_summaries_sentences.txt",
    "formatted_playersinfo.txt",
]


def _read_text(path: str) -> str:
    """Return the full contents of *path* decoded as UTF-8."""
    # Explicit encoding: the platform default can mis-decode non-ASCII
    # player/team names (e.g. on Windows).
    with open(path, "r", encoding="utf-8") as fp:
        return fp.read()


content1, content2, content3 = (_read_text(p) for p in _SOURCE_FILES)

# Concatenate the three corpora; "\n\n\n" acts as the document delimiter so
# the combined text can be split back into individual documents below.
combined_content = content1 + "\n\n\n" + content2 + "\n\n\n" + content3

# One entry per delimited document across all three source files.
s = combined_content.split("\n\n\n")

# Quick sanity check of the first document and the corpus size.
print(s[0])
print(len(s))
| |
|
| | |
# Wrap every raw text document in a LangchainDocument so the splitter and
# vector store can consume it; tqdm shows progress over the corpus.
RAW_KNOWLEDGE_BASE = []
for raw_text in tqdm(s):
    RAW_KNOWLEDGE_BASE.append(LangchainDocument(page_content=raw_text))
| |
|
| | |
# Separators tried in order by RecursiveCharacterTextSplitter, from coarsest
# (markdown headings, horizontal rules) down to single characters.
# NOTE(review): several entries ("\n#{1,6}", "\n\\*\\*\\*+\n", "\n---+\n",
# "\n__+\n") look like regex patterns, but the splitters below are not
# constructed with is_separator_regex=True, so they are matched as literal
# strings — confirm this is intended.
MARKDOWN_SEPARATORS = [
    "\n#{1,6}",
    "```\n",
    "\n\\*\\*\\*+\n",
    "\n---+\n",
    "\n__+\n",
    "\n\n",
    "\n",
    " ",
    ""
]
| |
|
# Character-based splitter: ~1000-char chunks with 100-char overlap so facts
# spanning a chunk boundary survive in at least one chunk.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
    add_start_index=True,      # record each chunk's offset in its parent doc
    strip_whitespace=True,
    separators=MARKDOWN_SEPARATORS,
)

# split_documents() already accepts a list of documents, so a single call
# replaces the original per-document loop with repeated list concatenation.
# NOTE(review): this result is overwritten later by the token-based
# split_documents() helper — it is kept only for the exploratory histogram.
docs_processed = text_splitter.split_documents(RAW_KNOWLEDGE_BASE)
| |
|
# Measure every chunk's length in tokens of the embedding model's tokenizer
# to check the chunks fit its context window.
# Renamed from `tokenizer` to avoid clobbering/being confused with the
# Phi-3 `tokenizer` defined further below.
gte_tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-small")
lengths = [len(gte_tokenizer.encode(doc.page_content)) for doc in tqdm(docs_processed)]

# pandas' .hist() returns a matplotlib Axes (not a Figure), so name it `ax`.
# The original set two competing titles (set_title then plt.title, which
# overrode it); keep a single descriptive title.
ax = pd.Series(lengths).hist()
ax.set_title("Distribution of Document Token Lengths")
plt.show()
| |
|
| | EMBEDDING_MODEL_NAME = "thenlper/gte-small" |
| |
|
def split_documents(
    chunk_size: int,
    knowledge_base: List[LangchainDocument],
    tokenizer_name: Optional[str] = EMBEDDING_MODEL_NAME,
) -> List[LangchainDocument]:
    """Split *knowledge_base* into token-sized, de-duplicated chunks.

    Args:
        chunk_size: Maximum chunk length, in tokens of *tokenizer_name*.
        knowledge_base: Documents to split.
        tokenizer_name: HF tokenizer id used to measure token lengths.

    Returns:
        The unique chunks, in first-seen order.
    """
    text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        AutoTokenizer.from_pretrained(tokenizer_name),
        chunk_size=chunk_size,
        chunk_overlap=int(chunk_size / 10),  # 10% overlap between chunks
        add_start_index=True,
        strip_whitespace=True,
        separators=MARKDOWN_SEPARATORS,
    )
    # split_documents() accepts the whole list; no per-document loop needed.
    docs_processed = text_splitter.split_documents(knowledge_base)

    # Drop exact-duplicate chunks (overlapping sources can repeat text)
    # while preserving first-seen order; a set gives O(1) membership tests
    # instead of the original dict-of-True bookkeeping.
    seen_texts = set()
    docs_processed_unique = []
    for doc in docs_processed:
        if doc.page_content not in seen_texts:
            seen_texts.add(doc.page_content)
            docs_processed_unique.append(doc)
    return docs_processed_unique
| |
|
# Re-chunk the corpus with a 512-token budget measured by the embedding
# model's tokenizer; this replaces the exploratory character-based split.
docs_processed = split_documents(
    512,
    RAW_KNOWLEDGE_BASE,
    tokenizer_name=EMBEDDING_MODEL_NAME,
)
print(len(docs_processed))
print(docs_processed[:3])

# Sanity check that a GPU is visible before models are placed on "cuda".
print(torch.cuda.is_available())
| |
|
# Sentence-embedding model served on GPU; embeddings are L2-normalised so
# cosine similarity reduces to a dot product.
embedding_model = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL_NAME,
    multi_process=True,  # parallelise encoding across worker processes
    model_kwargs={"device": "cuda"},
    encode_kwargs={"normalize_embeddings": True},
)

# Build the FAISS index over all chunks using cosine distance, matching the
# normalised embeddings above.
KNOWLEDGE_VECTOR_DATABASE = FAISS.from_documents(
    docs_processed,
    embedding_model,
    distance_strategy=DistanceStrategy.COSINE,
)
| |
|
# Fix the RNG seed so any sampled generation is reproducible across runs.
torch.random.manual_seed(0)

# Phi-3-mini (128k context) as the answer-generation LLM, placed on GPU.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-128k-instruct",
    device_map="cuda",
    torch_dtype="auto",      # use the checkpoint's native dtype
    trust_remote_code=True,  # Phi-3 ships custom modelling code on the Hub
)
# NOTE: rebinds the module-level `tokenizer` (previously the gte-small
# tokenizer used for the histogram); everything below expects Phi-3's.
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")

# Text-generation pipeline wrapping the model + tokenizer.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
| |
|
# Decoding settings shared by every pipeline call below.
generation_args = {
    "max_new_tokens": 500,
    "return_full_text": False,  # return only the newly generated answer
    "temperature": 0.0,         # NOTE(review): ignored when do_sample=False — confirm
    "do_sample": False,         # greedy decoding for deterministic answers
}
| |
|
# Chat-format RAG prompt: a system instruction plus a user turn containing
# {context} and {question} placeholders that are filled in at query time.
prompt_chat=[
    {
        "role":"system",
        "content":"""Using the information contained in the context,
Give a comprehensive answer to the question.
Respond only to the question asked , response should be concise and relevant to the question.
provide the number of the source document when relevant.
If the answer cannot be deduced from the context, do not give an answer""",
    },
    {
        "role":"user",
        "content":"""Context:
{context}
---
Now here is the Question you need to answer.
Question:{question}
""",
    },
]

# Render the chat messages through the model's chat template into a single
# prompt string (tokenize=False keeps it as text; add_generation_prompt
# appends the assistant header so the model starts answering immediately).
RAG_PROMPT_TEMPLATE = tokenizer.apply_chat_template(
    prompt_chat, tokenize=False, add_generation_prompt=True,
)
print(RAG_PROMPT_TEMPLATE)
| |
|
# --- One-off smoke test of the full RAG pipeline ---
u_query = "give the match summary of royal challengers bengaluru and mumbai indians in 2024"
retrieved_docs = KNOWLEDGE_VECTOR_DATABASE.similarity_search(query=u_query, k=3)

# Use ALL retrieved documents, numbered: the original kept only the first,
# wasting the k=3 retrieval and defeating the system prompt's request to
# cite source-document numbers.
context = "\n".join(
    f"Document {i}:\n{doc.page_content}"
    for i, doc in enumerate(retrieved_docs)
)
# str.replace instead of str.format: retrieved text may legally contain
# '{' or '}', which would make .format() raise KeyError/ValueError.
final_prompt = RAG_PROMPT_TEMPLATE.replace("{question}", u_query).replace(
    "{context}", context
)

output = pipe(final_prompt, **generation_args)
print("YOUR QUESTION:\n", u_query, "\n")
print("MICROSOFT 128K ANSWER: \n", output[0]['generated_text'])
| |
|
def handle_query(question):
    """Answer *question* with RAG over the IPL knowledge base.

    Retrieves the top-3 most similar chunks, builds a numbered context
    block (the system prompt asks the model to cite document numbers),
    fills the prompt template, and returns the generated answer text.

    Args:
        question: The user's free-text question.

    Returns:
        The model's generated answer as a string.
    """
    retrieved_docs = KNOWLEDGE_VECTOR_DATABASE.similarity_search(query=question, k=3)
    # Use every retrieved chunk, not just the first, and number them so the
    # model can reference its sources as the system prompt requests.
    context = "\n".join(
        f"Document {i}:\n{doc.page_content}"
        for i, doc in enumerate(retrieved_docs)
    )
    # str.replace instead of .format(): user questions and retrieved text
    # may contain literal braces that would crash str.format with KeyError.
    final_prompt = RAG_PROMPT_TEMPLATE.replace("{question}", question).replace(
        "{context}", context
    )
    output = pipe(final_prompt, **generation_args)
    return output[0]['generated_text']
| |
|
# Simple Gradio text-in / text-out UI around the RAG pipeline.
interface = gr.Interface(
    fn=handle_query,
    inputs="text",
    outputs="text",
    title="IPL Match Summary Generator",
    description="Get the match summary of IPL teams based on your query.",
)

# BUG FIX: Gradio's public-link flag is `share`, not `sharing` —
# launch(sharing=True) raises a TypeError.
interface.launch(share=True)
| |
|