"""Gradio chat web UI that answers questions with a `chatpdf.Rag` model."""
import argparse
import os

import gradio as gr
from loguru import logger
from similarities import BertSimilarity, BM25Similarity

from chatpdf import Rag

# Directory containing this script; used to resolve the avatar image paths.
pwd_path = os.path.abspath(os.path.dirname(__file__))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--sim_model_name", type=str, default="sentence-transformers/all-mpnet-base-v2")
    parser.add_argument("--gen_model_type", type=str, default="auto")
    parser.add_argument("--gen_model_name", type=str, default="Qwen/Qwen2-0.5B-Instruct")
    parser.add_argument("--lora_model", type=str, default=None)
    parser.add_argument("--rerank_model_name", type=str, default="")
    # Comma-separated list of corpus files (split on ',' below).
    parser.add_argument("--corpus_files", type=str, default="Acuerdo009.pdf")
    parser.add_argument("--device", type=str, default=None)
    # parser.add_argument("--int4", action='store_true', help="use int4 quantization")
    # parser.add_argument("--int8", action='store_true', help="use int8 quantization")
    parser.add_argument("--chunk_size", type=int, default=220)
    parser.add_argument("--chunk_overlap", type=int, default=0)
    parser.add_argument("--num_expand_context_chunk", type=int, default=1)
    # NOTE(review): --server_name, --server_port and --share are parsed but are
    # not forwarded to `launch()` at the bottom of this script.
    parser.add_argument("--server_name", type=str, default="0.0.0.0")
    parser.add_argument("--server_port", type=int, default=8082)
    # NOTE(review): `store_true` combined with `default=True` means this value
    # is True whether or not the flag is given.
    parser.add_argument("--share", action='store_true', default=True, help="share model")
    args = parser.parse_args()
    logger.info(args)

    # Initialize the similarity model and the RAG model built on top of it.
    sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
    model = Rag(
        similarity_model=sim_model,
        generate_model_type=args.gen_model_type,
        generate_model_name_or_path=args.gen_model_name,
        lora_model_name_or_path=args.lora_model,
        corpus_files=args.corpus_files.split(','),
        device=args.device,
        chunk_size=args.chunk_size,
        chunk_overlap=args.chunk_overlap,
        num_expand_context_chunk=args.num_expand_context_chunk,
        rerank_model_name_or_path=args.rerank_model_name,
    )
    logger.info(f"chatpdf model: {model}")

    def predict_stream(message, history):
        """Yield the chunks produced by `model.predict_stream` for `message`.

        Before generating, `model.history` is replaced with `history`
        converted to a list of `[human, assistant]` pairs.

        Not currently wired into the chat interface below, which uses
        `predict`.
        """
        model.history = [[human, assistant] for human, assistant in history]
        yield from model.predict_stream(message)

    def predict(message, history):
        """Return the model's answer followed by its reference results.

        The reply is the response text, a blank line, and then the reference
        results joined one per line. `history` is accepted because the chat
        interface passes it, but it is not used here.
        """
        logger.debug(message)
        response, reference_results = model.predict(message)
        r = response + "\n\n" + '\n'.join(reference_results)
        logger.debug(r)
        return r

    chatbot_stream = gr.Chatbot(
        height=600,
        avatar_images=(
            os.path.join(pwd_path, "assets/user.png"),
            os.path.join(pwd_path, "assets/Logo1.png"),
        ),
        bubble_full_width=False,
    )

    # Title and (disabled) description shown in the UI.
    title = " 🤖ChatPDF Zonia🤖 "
    # description = "Github link: [shibing624/ChatPDF](https://github.com/shibing624/ChatPDF)"
    # The CSS priority keyword must be spelled `!important`; any other spelling
    # makes the declaration invalid, so the rule would be ignored.
    css = """.toast-wrap { display: none !important } """
    examples = ['Puede hablarme del PNL?', 'Introducción a la PNL']
    chat_interface_stream = gr.ChatInterface(
        predict,
        textbox=gr.Textbox(lines=4, placeholder="Ask me question", scale=7),
        # Add submit=True
        title=title,
        # description=description,
        chatbot=chatbot_stream,
        css=css,
        examples=examples,
        theme='soft',
    )

    # Launch the application without `server_name` or `server_port`.
    chat_interface_stream.launch()