import os
from io import BytesIO

import gradio as gr
import requests
from pdfminer.high_level import extract_text
from huggingface_hub import InferenceClient
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.runnables import Runnable
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain.text_splitter import CharacterTextSplitter
# Never hard-code or obfuscate API tokens in source; read them from the environment.
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
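# On a Hugging Face Space, set HUGGINGFACEHUB_API_TOKEN under
# Settings -> "Variables and secrets" so the token never lives in the repo.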
client = InferenceClient(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    token=HF_TOKEN,
)
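# Note: this uses the hosted (serverless) Inference API; Mixtral-8x7B-Instruct
# may be gated or unavailable there over time, so substitute any chat-capable
# model your token can access if calls start failing.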
class HuggingFaceInferenceClientRunnable(Runnable):
    """Adapts huggingface_hub's InferenceClient to the LangChain Runnable
    interface so it can sit inside an LCEL chain."""

    def __init__(self, client, max_tokens=512, temperature=0.7, top_p=0.95):
        self.client = client
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_p = top_p

    def invoke(self, input, config=None):
        # The preceding ChatPromptTemplate emits a ChatPromptValue; use the
        # rendered text of its first message as the user prompt.
        prompt = input.to_messages()[0].content
        messages = [{"role": "user", "content": prompt}]
        # Stream the chat completion and accumulate the generated tokens.
        response = ""
        for chunk in self.client.chat_completion(
            messages,
            max_tokens=self.max_tokens,
            stream=True,
            temperature=self.temperature,
            top_p=self.top_p,
        ):
            for choice in chunk.choices:
                token = choice.delta.content
                if token:
                    response += token
        return response

    def update_params(self, max_tokens, temperature, top_p):
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_p = top_p
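# Quick standalone check of the wrapper (a hedged sketch; any template works):
#   pv = ChatPromptTemplate.from_template("Say hi to {name}").invoke({"name": "Ada"})
#   print(HuggingFaceInferenceClientRunnable(client).invoke(pv))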
def extract_pdf_text(url: str) -> str:
    """Download a PDF and return its plain text."""
    response = requests.get(url)
    response.raise_for_status()  # fail fast on a bad download
    pdf_file = BytesIO(response.content)
    return extract_text(pdf_file)
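# pdfminer's extract_text() accepts any binary file-like object, so the PDF is
# parsed straight from memory and never written to disk.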
# Build the RAG index: download the paper, split it into chunks, embed into Chroma.
pdf_url = "https://arxiv.org/pdf/2408.09869"
text = extract_pdf_text(pdf_url)
docs_list = [Document(page_content=text)]

text_splitter = CharacterTextSplitter.from_tiktoken_encoder(chunk_size=7500, chunk_overlap=100)
docs_splits = text_splitter.split_documents(docs_list)

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectorstore = Chroma.from_documents(
    documents=docs_splits,
    collection_name="rag-chroma",
    embedding=embeddings,
)
retriever = vectorstore.as_retriever()
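# as_retriever() defaults to similarity search returning the top 4 chunks;
# pass search_kwargs={"k": ...} to control how much context reaches the prompt.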
llm = HuggingFaceInferenceClientRunnable(client)
# Post-retrieval ("after RAG") prompt and chain.
after_rag_template = """You are a {role}. Summarize the following content for yourself and speak in the first person.
Only include content relevant to that role, like a resume summary.
Context:
{context}
Question: Give a one-paragraph summary of the key skills a {role} can have from this document.
"""
after_rag_prompt = ChatPromptTemplate.from_template(after_rag_template)

def format_query(input_dict):
    # Turn the selected role into the query string that gets embedded and searched.
    return f"Give a one-paragraph summary of the key skills a {input_dict['role']} can have from this document."

after_rag_chain = (
    {
        "context": format_query | retriever,
        "role": lambda x: x["role"],
    }
    | after_rag_prompt
    | llm
    | StrOutputParser()
)
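# The leading dict runs its entries on the same input: format_query builds the
# retrieval query and the retriever resolves it to matching chunks ("context"),
# while "role" passes through unchanged for the prompt template.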
def process_query(role, system_message, max_tokens, temperature, top_p):
    # system_message is accepted from the UI but not yet wired into the prompt.
    llm.update_params(max_tokens, temperature, top_p)
    after_rag_result = after_rag_chain.invoke({"role": role})
    return f"**RAG Summary**\n{after_rag_result}"
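# Example (hypothetical values mirroring the UI defaults):
#   process_query("SDE", "You are a friendly chatbot.", 512, 0.7, 0.95)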
with gr.Blocks() as demo:
    gr.Markdown("## Zephyr Chatbot Controls")
    role_dropdown = gr.Dropdown(choices=["SDE", "BA"], label="Select Role", value="SDE")
    system_message = gr.Textbox(value="You are a friendly chatbot.", label="System message")
    max_tokens = gr.Slider(1, 2048, value=512, label="Max tokens")
    temperature = gr.Slider(0.1, 4.0, value=0.7, label="Temperature", step=0.1)
    top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p", step=0.05)
    output = gr.Textbox(label="Output", lines=20)
    submit_btn = gr.Button("Submit")
    clear_btn = gr.Button("Clear")

    submit_btn.click(
        fn=process_query,
        inputs=[role_dropdown, system_message, max_tokens, temperature, top_p],
        outputs=output,
    )

    def clear_output():
        # gr.Info shows a toast; only the empty string is bound to the textbox.
        # (Returning a tuple here would not match the single output component.)
        gr.Info("Chat cleared!")
        return ""

    clear_btn.click(fn=clear_output, outputs=[output])
if __name__ == "__main__":
    demo.launch()