Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from mistralai import Mistral | |
| from langchain_community.tools import TavilySearchResults, JinaSearch | |
| import concurrent.futures | |
| import json | |
| import os | |
| import arxiv | |
| from PIL import Image | |
| import io | |
| import base64 | |
| from langchain.chains import MapReduceDocumentsChain, ReduceDocumentsChain | |
| from langchain.text_splitter import CharacterTextSplitter | |
| from langchain_mistralai import ChatMistralAI | |
| from langchain.chains.combine_documents.stuff import StuffDocumentsChain | |
| from langchain.chains.llm import LLMChain | |
| from langchain_core.prompts import PromptTemplate | |
| from transformers import AutoTokenizer | |
# Pixtral tokenizer, loaded once at import time and shared by the token
# counter and the LangChain text splitters further down the file.
tokenizer = AutoTokenizer.from_pretrained("mistral-community/pixtral-12b")


def count_tokens_in_text(text):
    """Return how many tokens the shared tokenizer produces for *text*.

    The text is encoded without truncation and with special tokens added,
    and the length of the first (only) encoded sequence is returned.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=False,
        add_special_tokens=True,
    )
    first_sequence = encoded["input_ids"][0]
    return len(first_sequence)
# API credentials and clients.
#
# NOTE(review): the literal keys below are committed to source control and
# must be treated as leaked: rotate them and supply the replacements through
# the environment (TAVILY_API_KEY, MISTRAL_API_KEY_1..4). The literals are
# kept only as fallbacks so existing deployments keep starting.

# Tavily reads its key from the environment; do not clobber a key that the
# deployment has already provided.
os.environ.setdefault("TAVILY_API_KEY", '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')

# Mistral clients: one per call site (topics, question answering, summaries).
client_1 = Mistral(api_key=os.environ.get("MISTRAL_API_KEY_1", 'eLES5HrVqduOE1OSWG6C5XyEUeR7qpXQ'))
client_2 = Mistral(api_key=os.environ.get("MISTRAL_API_KEY_2", 'VPqG8sCy3JX5zFkpdiZ7bRSnTLKwngFJ'))
client_3 = Mistral(api_key=os.environ.get("MISTRAL_API_KEY_3", 'cvyu5Rdk2lS026epqL4VB6BMPUcUMSgt'))

# LangChain chat model used by the map-reduce chains for over-long inputs.
api_key_4 = os.environ.get("MISTRAL_API_KEY_4", 'aYls8aj48SOEov8AY1dwp4hr07MsCRFb')
client_4 = ChatMistralAI(api_key=api_key_4, model="pixtral-12b-2409")
# Function to encode images in base64
def encode_image_bytes(image_bytes):
    """Return *image_bytes* base64-encoded, as a UTF-8 decoded ``str``."""
    encoded = base64.b64encode(image_bytes)
    return encoded.decode('utf-8')
# Function to decode base64 images
def decode_base64_image(base64_str):
    """Decode a base64-encoded image into a PIL ``Image``.

    Accepts either a bare base64 payload or a data URI of the form
    ``data:<mime>;base64,<payload>`` (the form the UI placeholder shows).
    The data-URI header is removed before decoding; previously it was fed to
    the base64 decoder together with the payload and corrupted the bytes.

    Args:
        base64_str: Base64 text (or bytes) describing the image.

    Returns:
        The opened ``PIL.Image.Image``.

    Raises:
        Whatever ``base64.b64decode`` or ``Image.open`` raise for malformed
        input; callers are expected to handle these.
    """
    payload = base64_str
    if isinstance(payload, str):
        payload = payload.strip()
        if payload.startswith("data:"):
            # Everything up to the first comma is the "data:...;base64" header.
            payload = payload.partition(",")[2]
    image_data = base64.b64decode(payload)
    return Image.open(io.BytesIO(image_data))
# Modes Pillow's JPEG writer accepts directly; anything else (RGBA, P, LA, ...)
# must be converted before saving.
_JPEG_SAFE_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")


def _split_base64_images(images_base64):
    """Return the individual base64 payloads contained in *images_base64*.

    The UI delivers a single comma-separated string, optionally made of data
    URIs (``data:image/jpeg;base64,<payload>``). Splitting on commas yields
    alternating headers and payloads; headers are dropped. Base64 text never
    contains a comma, so the split cannot cut a payload. A list/tuple of such
    strings is accepted as well.
    """
    if not images_base64:
        return []
    if isinstance(images_base64, str):
        pieces = images_base64.split(",")
    else:
        pieces = []
        for item in images_base64:
            pieces.extend(item.split(",") if isinstance(item, str) else [item])

    payloads = []
    for piece in pieces:
        if isinstance(piece, str):
            piece = piece.strip()
            if not piece or piece.startswith("data:"):
                continue
        payloads.append(piece)
    return payloads


# Process text and images provided by the user
def process_input(text_input, images_base64):
    """Normalize user input into ``(text, image_parts)``.

    Each supplied image is decoded, re-encoded as JPEG and wrapped as an
    ``{"type": "image_url", "image_url": "data:image/jpeg;base64,..."}``
    content part. Images that fail to decode are reported on stdout and
    skipped (best effort), so one bad image does not abort the request.

    Args:
        text_input: The text entered by the user; returned unchanged.
        images_base64: ``None``/empty, a comma-separated string of base64
            images or data URIs, or a sequence of such strings.

    Returns:
        A tuple ``(text_input, images)`` where ``images`` is a list of
        image content parts (possibly empty).
    """
    images = []
    for img_data in _split_base64_images(images_base64):
        try:
            img = decode_base64_image(img_data)
            if img.mode not in _JPEG_SAFE_MODES:
                # e.g. PNG with alpha: JPEG cannot store it, so flatten to RGB.
                img = img.convert("RGB")
            buffered = io.BytesIO()
            img.save(buffered, format="JPEG")
            image_base64 = encode_image_bytes(buffered.getvalue())
            images.append({"type": "image_url", "image_url": f"data:image/jpeg;base64,{image_base64}"})
        except Exception as e:
            print(f"Error decoding image: {e}")
    return text_input, images
# Search setup function
def setup_search(question):
    """Search the web for *question* and return ``(results, tool_name)``.

    Tavily is tried first; if it raises or does not return a list, Jina is
    tried next (its output is parsed as JSON). ``tool_name`` is
    ``'tavily_tool'`` or ``'jina_tool'`` depending on which provider
    answered. When neither yields a list the result is ``([], '')``.
    Provider errors are printed and otherwise ignored.
    """
    # First choice: Tavily.
    try:
        tavily_hits = TavilySearchResults(max_results=20).invoke({"query": f"{question}"})
        if isinstance(tavily_hits, list):
            return tavily_hits, 'tavily_tool'
    except Exception as e:
        print("Error with TavilySearchResults:", e)

    # Fallback: Jina, whose tool output has to be JSON-decoded.
    try:
        raw_jina_output = JinaSearch().invoke({"query": f"{question}"})
        jina_hits = json.loads(str(raw_jina_output))
        if isinstance(jina_hits, list):
            return jina_hits, 'jina_tool'
    except Exception as e:
        print("Error with JinaSearch:", e)

    # Neither provider produced a usable list.
    return [], ''
# Function to extract key topics
def extract_key_topics(content, images=[]):
    """Ask Pixtral (via ``client_1``) for the main themes of *content*.

    The prompt requests a terse unordered list in English. Any image content
    parts in *images* are appended after the text part of the user message.
    Returns the raw text of the first completion choice.
    """
    prompt = f"""
Extract the primary themes from the text below. List each theme in as few words as possible, focusing on essential concepts only. Format as a concise, unordered list with no extraneous words.
```{content}```
LIST IN ENGLISH:
-
"""
    text_part = {"type": "text", "text": prompt}
    user_message = {"role": "user", "content": [text_part] + images}
    response = client_1.chat.complete(
        model="pixtral-12b-2409",
        messages=[user_message],
    )
    first_choice = response.choices[0]
    return first_choice.message.content
def extract_key_topics_with_large_text(content, images=[]):
    """Extract key themes from a text too long for a single request.

    The text is split into overlapping chunks with the shared tokenizer, each
    chunk is sent through a "map" prompt, and the partial lists are merged by
    a "reduce" prompt (LangChain map-reduce over ``client_4``). The image
    parts are not sent as images: their ``image_url`` values are joined into
    a text block and substituted for ``{images}`` in the map prompt.

    Returns whatever the map-reduce chain's ``run`` returns.
    """
    # The templates contain only PromptTemplate placeholders ({docs},
    # {images}); no Python interpolation is needed.
    map_template = """
Текст: {docs}
Изображения: {images}
Extract the primary themes from the text below. List each theme in as few words as possible, focusing on essential concepts only. Format as a concise, unordered list with no extraneous words.
LIST IN ENGLISH:
-
:"""
    reduce_template = """Следующий текст состоит из нескольких кратких итогов:
{docs}
Extract the primary themes from the text below. List each theme in as few words as possible, focusing on essential concepts only. Format as a concise, unordered list with no extraneous words.
LIST IN ENGLISH:
-
:"""

    # Map step: one LLM call per chunk.
    per_chunk_chain = LLMChain(
        llm=client_4,
        prompt=PromptTemplate.from_template(map_template),
    )

    # Reduce step: stuff the partial results into a single prompt.
    merge_chain = LLMChain(
        llm=client_4,
        prompt=PromptTemplate.from_template(reduce_template),
    )
    stuff_chain = StuffDocumentsChain(
        llm_chain=merge_chain, document_variable_name="docs"
    )
    reducer = ReduceDocumentsChain(
        combine_documents_chain=stuff_chain,
        collapse_documents_chain=stuff_chain,
        token_max=128000,
    )

    # Full pipeline: map over chunks, then reduce.
    pipeline = MapReduceDocumentsChain(
        llm_chain=per_chunk_chain,
        reduce_documents_chain=reducer,
        document_variable_name="docs",
        return_intermediate_steps=False,
    )

    # Chunk the input by token count using the shared tokenizer.
    splitter = CharacterTextSplitter.from_huggingface_tokenizer(
        tokenizer,
        chunk_size=100000,
        chunk_overlap=14000,
    )
    chunks = splitter.create_documents([content])

    # Textual listing of the images for the {images} placeholder.
    image_lines = []
    for i, img in enumerate(images):
        image_lines.append(f"Изображение {i+1}: {img['image_url']}")
    image_descriptions = "\n".join(image_lines)

    return pipeline.run({"input_documents": chunks, "images": image_descriptions})
def search_relevant_articles_arxiv(key_topics, max_articles=100):
    """Query arXiv for every topic concurrently.

    Args:
        key_topics: Iterable of topic strings, each used as an arXiv query.
        max_articles: Upper bound on results requested per topic.

    Returns:
        ``(articles_by_topic, topics)``. ``articles_by_topic`` maps each
        topic that produced at least one article to a list of dicts with the
        keys ``title``, ``doi``, ``summary``, ``url`` and ``pdf_url``.
        ``topics`` is the de-duplicated list of topics for which an article
        was recorded (order not guaranteed, since it goes through a ``set``).

    Failures for a single topic are printed and yield no articles for it.
    """
    matched_topics = []

    def _collect(topic):
        # Runs in a worker thread: gather all results for one topic.
        found = []
        try:
            query = arxiv.Search(
                query=topic,
                max_results=max_articles,
                sort_by=arxiv.SortCriterion.Relevance,
            )
            for paper in query.results():
                found.append(
                    {
                        "title": paper.title,
                        "doi": paper.doi,
                        "summary": paper.summary,
                        "url": paper.entry_id,
                        "pdf_url": paper.pdf_url,
                    }
                )
                # Recorded once per article; duplicates are removed on return.
                matched_topics.append(topic)
        except Exception as e:
            print(f"Error fetching articles for topic '{topic}': {e}")
        return topic, found

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as pool:
        pending = {pool.submit(_collect, topic): topic for topic in key_topics}
        # Consume in completion order, as results become available.
        finished = [done.result() for done in concurrent.futures.as_completed(pending)]

    articles_by_topic = {topic: papers for topic, papers in finished if papers}
    return articles_by_topic, list(set(matched_topics))
def init(content, images=None):
    """Extract key topics from *content* and look up related arXiv articles.

    Texts under 128,000 tokens go through the single-request extractor;
    longer ones go through the map-reduce extractor. The model's answer is
    split into lines, list markers (``-``) and spaces are stripped from both
    ends, and lines that end up empty are discarded so that no empty query
    is sent to arXiv.

    Args:
        content: The text to analyse.
        images: Optional list of image content parts passed to the extractor.

    Returns:
        ``(final_topics, result_json)`` where ``final_topics`` is the list of
        topics returned by the arXiv search and ``result_json`` is the
        articles-by-topic mapping serialized as indented JSON.
    """
    images = [] if images is None else images

    if count_tokens_in_text(text=content) < 128_000:
        raw_topics = extract_key_topics(content, images)
    else:
        raw_topics = extract_key_topics_with_large_text(content, images)

    stripped = (line.strip("- ") for line in raw_topics.split("\n"))
    # A bare "-" or blank line strips down to "", which is not a usable query.
    key_topics = [topic for topic in stripped if topic]

    articles_by_topic, final_topics = search_relevant_articles_arxiv(key_topics)
    result_json = json.dumps(articles_by_topic, indent=4)
    return final_topics, result_json
# Summarization function
def process_article_for_summary(text, images=[], compression_percentage=30):
    """Summarize *text* in Russian with a single Pixtral request (``client_3``).

    The prompt asks for a Markdown summary that cuts the article by
    *compression_percentage* percent. When eight or more image parts are
    supplied, only the first seven are sent. Returns the text of the first
    completion choice.
    """
    prompt = f"""
You are a commentator.
# article:
{text}
# Instructions:
## Summarize IN RUSSIAN:
In clear and concise language, summarize the key points and themes presented in the article by cutting it by {compression_percentage} percent in the markdown format.
"""
    attached_images = images[:7] if len(images) >= 8 else images
    user_message = {
        "role": "user",
        "content": [{"type": "text", "text": prompt}] + attached_images,
    }
    response = client_3.chat.complete(
        model="pixtral-12b-2409",
        messages=[user_message],
    )
    first_choice = response.choices[0]
    return first_choice.message.content
def process_large_article_for_summary(text, images=[], compression_percentage=30):
    """Summarize a text too long for one request and fetch related articles.

    The text is chunked with the shared tokenizer, each chunk is compressed
    by a "map" prompt and the partial summaries are merged by a "reduce"
    prompt (LangChain map-reduce over ``client_4``). While the chain runs in
    the calling thread, ``init`` is executed in a worker thread to obtain the
    key topics and the arXiv articles.

    Returns:
        ``(summary, key_topics, result_article_json)``.
    """
    # Prompt for the per-chunk (map) step.
    map_template = f"""Следующий текст состоит из текста и изображений:
Текст: {{docs}}
Изображения: {{images}}
На основе приведенного материала, выполните сжатие текста, выделяя основные темы и важные моменты.
Уровень сжатия: {compression_percentage}%.
Ответ предоставьте на русском языке в формате Markdown.
Полезный ответ:"""
    per_chunk_chain = LLMChain(
        llm=client_4,
        prompt=PromptTemplate.from_template(map_template),
    )

    # Prompt for the merging (reduce) step.
    reduce_template = f"""Следующий текст состоит из нескольких кратких итогов:
{{docs}}
На основе этих кратких итогов, выполните финальное сжатие текста, объединяя основные темы и ключевые моменты.
Уровень сжатия: {compression_percentage}%.
Результат предоставьте на русском языке в формате Markdown.
Полезный ответ:"""
    merge_chain = LLMChain(
        llm=client_4,
        prompt=PromptTemplate.from_template(reduce_template),
    )

    # The same stuffing chain is used both to combine and to collapse.
    stuff_chain = StuffDocumentsChain(
        llm_chain=merge_chain, document_variable_name="docs"
    )
    reducer = ReduceDocumentsChain(
        combine_documents_chain=stuff_chain,
        collapse_documents_chain=stuff_chain,
        token_max=128000,
    )
    pipeline = MapReduceDocumentsChain(
        llm_chain=per_chunk_chain,
        reduce_documents_chain=reducer,
        document_variable_name="docs",
        return_intermediate_steps=False,
    )

    # Chunk the article by token count.
    splitter = CharacterTextSplitter.from_huggingface_tokenizer(
        tokenizer,
        chunk_size=100000,
        chunk_overlap=14000,
    )
    chunks = splitter.create_documents([text])

    # Images are passed to the prompt as a textual listing of their URLs.
    image_lines = []
    for i, img in enumerate(images):
        image_lines.append(f"Изображение {i+1}: {img['image_url']}")
    image_descriptions = "\n".join(image_lines)

    # Topic extraction runs in the background while the summary is produced.
    with concurrent.futures.ThreadPoolExecutor() as pool:
        topics_future = pool.submit(init, text, images)
        summary = pipeline.run({"input_documents": chunks, "images": image_descriptions})
        key_topics, result_article_json = topics_future.result()

    return summary, key_topics, result_article_json
# Question answering function
def ask_question_to_mistral(text, question, images=None):
    """Answer *question* about *text* with Pixtral, enriched by a web search.

    A web search for the question is run first and its hits are appended to
    the prompt as additional context. The user message is sent as a
    structured content list (one text part followed by the image parts).
    Previously the list was interpolated into a plain string, so the model
    received the Python ``repr`` of the parts and never saw the images.

    Args:
        text: Source text the question refers to.
        question: The user's question.
        images: Optional list of image content parts. When eight or more are
            given only the first seven are sent (original cap, kept as is).

    Returns:
        The text of the first completion choice (``client_2``).
    """
    images = [] if images is None else images
    prompt = f"Answer the following question without mentioning it or repeating the original text on which the question is asked in style markdown.IN RUSSIAN:\nQuestion: {question}\n\nText:\n{text}"
    if len(images) >= 8:
        images = images[:7]

    search_results, tool = setup_search(question)
    context = ''
    if search_results:
        # The two providers use different field names for their hits.
        if tool == 'tavily_tool':
            context = "".join(
                f"{result.get('url', 'N/A')} : {result.get('content', 'No content')} \n"
                for result in search_results
            )
        elif tool == 'jina_tool':
            context = "".join(
                f"{result.get('link', 'N/A')} : {result.get('snippet', 'No snippet')} : {result.get('content', 'No content')} \n"
                for result in search_results
            )

    text_part = {
        "type": "text",
        "text": f"{prompt}\n\nAdditional Context from Web Search:\n{context}",
    }
    message_content = [text_part] + images
    response = client_2.chat.complete(
        model="pixtral-12b-2409",
        messages=[{"role": "user", "content": message_content}]
    )
    return response.choices[0].message.content
def ask_question_to_mistral_with_large_text(text, question, images=None):
    """Answer *question* about a text too long for a single request.

    The text is chunked with the shared tokenizer; each chunk is asked the
    question by a "map" prompt and the partial answers are merged by a
    "reduce" prompt (LangChain map-reduce over ``client_4``). ``init`` runs
    in a worker thread meanwhile to collect key topics and arXiv articles.

    The templates are plain strings handed to ``PromptTemplate``, so the
    placeholders must use single braces. The previous ``{{docs}}`` /
    ``{{images}}`` were escaped literals, which left the prompts without a
    ``docs`` variable for the document chains to fill.

    Args:
        text: Source text the question refers to.
        question: The user's question.
        images: Optional list of image content parts; their ``image_url``
            values are listed as text in the map prompt.

    Returns:
        ``(answer, key_topics, result_article_json)``.
    """
    images = [] if images is None else images

    # Prompts for QA
    map_template = """Следующий текст содержит статью/произведение:
Текст: {docs}
Изображения: {images}
На основе приведенного текста, ответьте на следующий вопрос:
Вопрос: {question}
Ответ должен быть точным. Пожалуйста, ответьте на русском языке в формате Markdown.
Полезный ответ:"""
    reduce_template = """Следующий текст содержит несколько кратких ответов на вопрос:
{docs}
Объедините их в финальный ответ. Ответ предоставьте на русском языке в формате Markdown.
Полезный ответ:"""
    map_prompt = PromptTemplate.from_template(map_template)
    map_chain = LLMChain(llm=client_4, prompt=map_prompt)
    reduce_prompt = PromptTemplate.from_template(reduce_template)
    reduce_chain = LLMChain(llm=client_4, prompt=reduce_prompt)

    # Combine documents chain for Reduce step
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain, document_variable_name="docs"
    )
    # ReduceDocumentsChain configuration
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=combine_documents_chain,
        token_max=128000,
    )
    # MapReduceDocumentsChain combining Map and Reduce
    map_reduce_chain = MapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name="docs",
        return_intermediate_steps=False,
    )

    # Text splitter configuration
    text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(
        tokenizer,
        chunk_size=100000,
        chunk_overlap=14000,
    )
    # Split the text into documents
    split_docs = text_splitter.create_documents([text])

    # Include image descriptions
    image_descriptions = "\n".join(
        [f"Изображение {i+1}: {img['image_url']}" for i, img in enumerate(images)]
    )

    # Topic/article lookup runs in the background while the answer is built.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        extract_future = executor.submit(init, text, images)
        answer = map_reduce_chain.run({"input_documents": split_docs, "question": question, "images": image_descriptions})
        key_topics, result_article_json = extract_future.result()
    return answer, key_topics, result_article_json
# Gradio interface
def gradio_interface(text_input, images_base64, task, question, compression_percentage):
    """Entry point wired to the Submit button.

    Dispatches to summarization or question answering, choosing the
    single-request path for texts under 128,000 tokens and the map-reduce
    path otherwise.

    Args:
        text_input: Text entered by the user.
        images_base64: Base64 image input from the UI (may be empty).
        task: ``"Summarization"`` or ``"Question Answering"``.
        question: The question (used only for question answering).
        compression_percentage: Requested compression (summarization only).

    Returns:
        A dict for the JSON output component with ``"Topics"`` and
        ``"Articles"`` plus either ``"Summary"`` or ``"Answer"``. If no
        question was given, ``"Answer"`` carries a notice and topics and
        articles are empty. If no known task was selected, a dict with an
        ``"Error"`` message is returned.
    """
    text, images = process_input(text_input, images_base64)

    if task == "Summarization":
        if count_tokens_in_text(text=text) < 128_000:
            topics, articles_json = init(text, images)
            summary = process_article_for_summary(text, images, compression_percentage)
            return {"Topics": topics, "Summary": summary, "Articles": articles_json}
        summary, key_topics, result_article_json = process_large_article_for_summary(text, images, compression_percentage)
        return {"Topics": key_topics, "Summary": summary, "Articles": result_article_json}

    if task == "Question Answering":
        if not question:
            # Nothing was computed on this path; the original referenced
            # undefined names here. "{}" matches the JSON-string type the
            # other branches use for "Articles".
            return {"Topics": [], "Answer": "No question provided.", "Articles": "{}"}
        if count_tokens_in_text(text=text) < 128_000:
            topics, articles_json = init(text, images)
            answer = ask_question_to_mistral(text, question, images)
            return {"Topics": topics, "Answer": answer, "Articles": articles_json}
        # The first element of the tuple is the answer (it used to be bound
        # to an unused name while an undefined `answer` was returned).
        answer, key_topics, result_article_json = ask_question_to_mistral_with_large_text(text, question, images)
        return {"Topics": key_topics, "Answer": answer, "Articles": result_article_json}

    # No task (or an unknown one) selected: say so instead of returning None.
    return {"Error": "Please select a task: Summarization or Question Answering."}
# Gradio UI definition.
# NOTE(review): the nesting of components inside the two `gr.Row()` blocks is
# reconstructed — the captured source had lost its indentation. Confirm the
# intended layout against the running app.
with gr.Blocks() as demo:
    gr.Markdown("## Text Analysis: Summarization or Question Answering")
    with gr.Row():
        text_input = gr.Textbox(label="Input Text")
        images_base64 = gr.Textbox(label="Base64 Images (comma-separated, if any)", placeholder="data:image/jpeg;base64,...", lines=2)
    task_choice = gr.Radio(["Summarization", "Question Answering"], label="Select Task")
    # Both task-specific inputs start hidden and are revealed by the handler below.
    question_input = gr.Textbox(label="Question (for Question Answering)", visible=False)
    compression_input = gr.Slider(label="Compression Percentage (for Summarization)", minimum=10, maximum=90, value=30, visible=False)
    # Show the question box only for question answering and the compression
    # slider only for summarization.
    task_choice.change(lambda choice: (gr.update(visible=choice == "Question Answering"),
                                       gr.update(visible=choice == "Summarization")),
                       inputs=task_choice, outputs=[question_input, compression_input])
    with gr.Row():
        result_output = gr.JSON(label="Results")
    submit_button = gr.Button("Submit")
    # Submit passes all five inputs to gradio_interface and shows its dict as JSON.
    submit_button.click(gradio_interface, [text_input, images_base64, task_choice, question_input, compression_input], result_output)
# Start the app; show_error=True is passed so errors are shown in the UI.
demo.launch(show_error=True)