from textwrap import fill

from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

from create_retriever import ensemble_retriever
from model_ret import zephyr_model, llama_model, mistral_model, phi_model, flant5_model

# Hugging Face repo ids used when a model is loaded online.
hf_model_map = {
    "Zephyr": "HuggingFaceH4/zephyr-7b-beta",
    "Llama": "NousResearch/Meta-Llama-3-8B",
    "Mistral": "unsloth/mistral-7b-instruct-v0.3",
    "Phi": "microsoft/Phi-3-mini-4k-instruct",
    "Flant5": "google/flan-t5-base",
}


class model_chain:
    """Wires an ensemble retriever and an LLM into a RAG chain."""

    model_name = ""

    def __init__(self,
                 model_name_local,
                 model_name_online="Llama",
                 use_local=True,
                 embedding_name="BAAI/bge-base-en-v1.5",
                 splitter_type_dropdown="character",
                 chunk_size_slider=512,
                 chunk_overlap_slider=30,
                 separator_textbox="\n",
                 max_tokens_slider=2048) -> None:
        if use_local:
            # Local checkpoints are named "<quantization>_<model>", so the
            # prefix selects the quantization and the suffix selects the model.
            quantization, self.model_name = model_name_local.split("_")[:2]
            model_name_temp = model_name_local
        else:
            # The quantization tag is only encoded in local checkpoint names.
            quantization = None
            self.model_name = model_name_online
            model_name_temp = hf_model_map[model_name_online]

        if self.model_name == "Zephyr":
            self.llm = zephyr_model(model_name_temp, quantization, use_local=use_local)
        elif self.model_name == "Llama":
            self.llm = llama_model(model_name_temp, quantization, use_local=use_local)
        elif self.model_name == "Mistral":
            self.llm = mistral_model(model_name_temp, quantization, use_local=use_local)
        elif self.model_name == "Phi":
            self.llm = phi_model(model_name_temp, quantization, use_local=use_local)
        elif self.model_name == "Flant5":
            # Flan-T5 also exposes its tokenizer and raw model for the manual
            # generate() path in ans_ret().
            self.tokenizer, self.model, self.llm = flant5_model(model_name_temp, use_local=use_local)
        else:
            raise ValueError(f"Unknown model name: {self.model_name!r}")

        self.retriever = ensemble_retriever(embedding_name,
                                            splitter_type=splitter_type_dropdown,
                                            chunk_size=chunk_size_slider,
                                            chunk_overlap=chunk_overlap_slider,
                                            separator=separator_textbox,
                                            max_tokens=max_tokens_slider)

        # Standard RAG prompt from the LangChain hub: retrieved context is
        # stuffed into {context}, the user query into {question}.
        prompt = hub.pull("rlm/rag-prompt")
        self.rag_chain = (
            {"context": self.retriever | self.format_docs, "question": RunnablePassthrough()}
            | prompt
            | self.llm
            | StrOutputParser()
        )

    def format_docs(self, docs):
        """Join retrieved documents into a single context string."""
        return "\n\n".join(doc.page_content for doc in docs)

    def rag_chain_ret(self):
        return self.rag_chain

    def ans_ret(self, inp, rag_chain):
        if self.model_name == "Flant5":
            # Flan-T5 bypasses the LangChain chain: retrieve, build the prompt
            # from the top two chunks, and call generate() directly.
            data = self.retriever.invoke(inp)
            context = ""
            for x in data[:2]:
                context += x.page_content + "\n"
            prompt = f"Please answer this question using this context:\n{context}\n{inp}"
            inputs = self.tokenizer(prompt, return_tensors="pt")
            outputs = self.model.generate(**inputs)
            answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return fill(answer, width=100)

        ans = rag_chain.invoke(inp)
        # Keep only the text after the prompt's "Answer:" marker, if present.
        parts = ans.split("Answer:")
        return parts[1] if len(parts) > 1 else ans
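

# A minimal usage sketch, assuming a local checkpoint following the
# "<quantization>_<model>" naming convention ("4bit_Zephyr" below is
# illustrative, not a checkpoint shipped with this repo) and that
# ensemble_retriever can build its index from the configured documents.
if __name__ == "__main__":
    chain = model_chain("4bit_Zephyr")
    rag_chain = chain.rag_chain_ret()
    print(chain.ans_ret("What is KUET?", rag_chain))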