import os
from pprint import pprint

from dotenv import load_dotenv
from langchain import hub
from langchain.prompts import ChatPromptTemplate
from langchain_anthropic import ChatAnthropic
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser, XMLOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter

load_dotenv()

# suppress grpc and glog logs for gemini
os.environ["GRPC_VERBOSITY"] = "ERROR"
os.environ["GLOG_minloglevel"] = "2"
# RAG parameters
CHUNK_SIZE = 1024
CHUNK_OVERLAP = CHUNK_SIZE // 8
K = 10
FETCH_K = 20

# Map UI-facing model names to provider model identifiers
llm_model_translation = {
    "LLaMA 3": "llama3-70b-8192",
    "OpenAI GPT 4o Mini": "gpt-4o-mini",
    "OpenAI GPT 4o": "gpt-4o",
    "OpenAI GPT 4": "gpt-4-turbo",
    "Gemini 1.5 Pro": "gemini-1.5-pro",
    "Claude Sonnet 3.5": "claude-3-5-sonnet-20240620",
}

# Map model identifiers to their LangChain chat classes
llm_classes = {
    "llama3-70b-8192": ChatGroq,
    "gpt-4o-mini": ChatOpenAI,
    "gpt-4o": ChatOpenAI,
    "gpt-4-turbo": ChatOpenAI,
    "gemini-1.5-pro": ChatGoogleGenerativeAI,
    "claude-3-5-sonnet-20240620": ChatAnthropic,
}
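
# Illustrative lookup (not used at runtime): a UI-facing name resolves to a chat
# class in two hops, e.g.
#   llm_classes[llm_model_translation["OpenAI GPT 4o Mini"]]  # -> ChatOpenAI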
xml_system = """You're a helpful AI assistant. Given a user prompt and some related sources, \
fulfill all the requirements of the prompt and provide citations. If a part of the generated text does \
not use any of the sources, don't put a citation for that part. Otherwise, list all sources used for that part of the text.
At the end of each relevant part, add a citation in square brackets, numbered sequentially starting from [0], regardless of the source's original ID.
Remember, you must return both the requested text and citations. A citation consists of a VERBATIM quote that \
justifies the text and a sequential number (starting from 0) for the quote's article. Return a citation for every quote, across all articles, \
that justifies the text. Use the following format for your final output:

<cited_text>
<text></text>
<citations>
<citation><source_id></source_id><source></source><quote></quote></citation>
<citation><source_id></source_id><source></source><quote></quote></citation>
...
</citations>
</cited_text>

Here are the sources: {context}"""

xml_prompt = ChatPromptTemplate.from_messages(
    [("system", xml_system), ("human", "{input}")]
)
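
# For reference: XMLOutputParser turns the <cited_text> format above into nested
# dicts and lists, roughly as sketched below (values are hypothetical). Both
# citations_to_html() and generate_rag() index into this structure positionally.
#
# {'cited_text': [
#     {'text': 'Generated answer ... [0]'},
#     {'citations': [
#         {'citation': [
#             {'source_id': '0'},
#             {'source': 'https://example.com/article'},
#             {'quote': 'verbatim supporting quote'},
#         ]},
#     ]},
# ]}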
def format_docs_xml(docs: list[Document]) -> str:
    """Render retrieved chunks as an XML <sources> block for the citation prompt."""
    formatted = []
    for i, doc in enumerate(docs):
        # URL-derived documents may lack a 'title', so fall back to empty strings
        doc_str = f"""\
<source id="{i}">
    <source>{doc.metadata.get('source', '')}</source>
    <title>{doc.metadata.get('title', '')}</title>
    <article_snippet>{doc.page_content}</article_snippet>
</source>"""
        formatted.append(doc_str)
    return "\n\n<sources>" + "\n".join(formatted) + "</sources>"
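
# Illustrative output of format_docs_xml for a single chunk (values hypothetical):
#
# <sources>
# <source id="0">
#     <source>https://example.com/article</source>
#     <title>Example Article</title>
#     <article_snippet>First CHUNK_SIZE characters of the chunk ...</article_snippet>
# </source>
# </sources>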
def citations_to_html(citations_data):
    """Render parsed <citation> entries as an HTML list."""
    if citations_data:
        html_output = "<ul>"
        for index, citation in enumerate(citations_data):
            # XMLOutputParser yields each <citation> as a positional list of
            # single-key dicts: [source_id, source, quote]
            source_id = citation['citation'][0]['source_id']
            source = citation['citation'][1]['source']
            quote = citation['citation'][2]['quote']
            html_output += f"""
            <li>
                [{index}] - "{source}" <br>
                "{quote}"
            </li>
            """
        html_output += "</ul>"
        return html_output
    return ""
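
# Example: for one parsed citation entry the function renders, roughly:
#   <ul> <li> [0] - "https://example.com/article" <br> "verbatim quote" </li> </ul>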
def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
    model_name = llm_model_translation.get(model)
    llm_class = llm_classes.get(model_name)
    if not llm_class:
        raise ValueError(f"Model {model} not supported.")
    try:
        # Use the standardized `model` keyword; not every provider class accepts
        # the `model_name` alias (e.g. ChatGoogleGenerativeAI). API keys are read
        # from the environment via load_dotenv(), so `api_key` is unused here.
        llm = llm_class(model=model_name, temperature=temperature, max_tokens=max_length)
    except Exception as e:
        print(f"An error occurred: {e}")
        llm = None
    return llm
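
# Minimal usage sketch (assumes the relevant API key, e.g. OPENAI_API_KEY, is set
# in the environment / .env file):
#
#   llm = load_llm("OpenAI GPT 4o Mini", api_key="", temperature=0.2, max_length=512)
#   if llm is not None:
#       print(llm.invoke("Say hello").content)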
def create_db_with_langchain(path: list[str], url_content: dict):
    all_docs = []
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    embedding_function = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
    if path:
        for file in path:
            loader = PyMuPDFLoader(file)
            data = loader.load()
            # split it into chunks
            docs = text_splitter.split_documents(data)
            all_docs.extend(docs)

    if url_content:
        for url, content in url_content.items():
            doc = Document(page_content=content, metadata={"source": url})
            # split it into chunks
            docs = text_splitter.split_documents([doc])
            all_docs.extend(docs)

    # print docs
    for idx, doc in enumerate(all_docs):
        print(f"Doc: {idx} | Length = {len(doc.page_content)}")

    assert len(all_docs) > 0, "No PDFs or scraped data provided"
    db = Chroma.from_documents(all_docs, embedding_function)
    return db
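
# Usage sketch (paths and URL content below are hypothetical):
#
#   db = create_db_with_langchain(
#       path=["paper.pdf"],
#       url_content={"https://example.com/post": "scraped page text ..."},
#   )
#   retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})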
def generate_rag(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    llm = load_llm(model, api_key, temperature, max_length)
    if llm is None:
        print("Failed to load LLM. Aborting operation.")
        return None

    db = create_db_with_langchain(path, url_content)
    retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})

    def format_docs(docs):
        # used only by the alternative plain-RAG chain kept in comments below
        if all(isinstance(doc, Document) for doc in docs):
            return "\n\n".join(doc.page_content for doc in docs)
        else:
            raise TypeError("All items in docs must be instances of Document.")

    docs = retriever.invoke(topic)

    # Alternative: plain RAG chain without citations (kept for reference)
    # rag_prompt = hub.pull("rlm/rag-prompt")
    # formatted_docs = format_docs(docs)
    # rag_chain = (
    #     {"context": lambda _: formatted_docs, "question": RunnablePassthrough()} | rag_prompt | llm | StrOutputParser()
    # )
    # return rag_chain.invoke(prompt)

    formatted_docs = format_docs_xml(docs)
    rag_chain = (
        RunnablePassthrough.assign(context=lambda _: formatted_docs)
        | xml_prompt
        | llm
        | XMLOutputParser()
    )
    result = rag_chain.invoke({"input": prompt})
    pprint(result)  # debug: inspect the parsed XML structure
    return result['cited_text'][0]['text'], citations_to_html(result['cited_text'][1]['citations'])
def generate_base(
    prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message=""
):
    llm = load_llm(model, api_key, temperature, max_length)
    if llm is None:
        print("Failed to load LLM. Aborting operation.")
        return None
    try:
        output = llm.invoke(prompt).content
        return output
    except Exception as e:
        print(f"An error occurred while running the model: {e}")
        return None
def generate(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    if path or url_content:
        return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)
    else:
        return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)
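
if __name__ == "__main__":
    # Smoke-test sketch with hypothetical inputs; a real run needs valid API keys
    # in .env and either a PDF path or scraped URL content.
    result = generate(
        prompt="Summarize the key findings with citations.",
        topic="key findings",
        model="OpenAI GPT 4o Mini",
        url_content={"https://example.com/post": "Example scraped article text."},
        path=[],
        temperature=0.7,
        max_length=1024,
    )
    print(result)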