import os
from typing import Any, Iterator, List

import torch
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download, login
from llama_cpp import Llama
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaTokenizerFast, TextIteratorStreamer

from llama_index.core import (
    ChatPromptTemplate,
    PromptTemplate,
    Settings,
    SimpleDirectoryReader,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.chat_engine import CondensePlusContextChatEngine
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.instructor import InstructorEmbedding

from interface import GemmaLLMInterface
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_id = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    token=True,
)
# Load the tokenizer that matches the Gemma checkpoint above
model.tokenizer = AutoTokenizer.from_pretrained(model_id)
model.eval()
# Models that will be used by LlamaIndex:
Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface()
documents_paths = {
    'blockchain': 'data/blockchainprova.txt',
    'metaverse': 'data/metaverso',
    'payment': 'data/payment'
}
session_state = {"documents_loaded": False,
                 "document_db": None,
                 "original_message": None,
                 "clarification": False,
                 "index": False}  # "index" is checked and set by handle_query below
############################---------------------------------
# Get the parser
parser = SentenceSplitter.from_defaults(
    chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
)
def build_index(path: str):
    # Load documents from a file
    documents = SimpleDirectoryReader(input_files=[path]).load_data()
    # Parse the documents into nodes
    nodes = parser.get_nodes_from_documents(documents)
    # Build the vector store index from the nodes
    index = VectorStoreIndex(nodes)
    return index
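

# A minimal sketch (an assumption, not part of the flow above) of how an index built
# by build_index() could be persisted to disk and reloaded with the already-imported
# load_index_from_storage; "./storage" is a hypothetical directory name.
def persist_and_reload_index(index: VectorStoreIndex, persist_dir: str = "./storage") -> VectorStoreIndex:
    from llama_index.core import StorageContext
    # Write the vector store, docstore and index metadata to disk
    index.storage_context.persist(persist_dir=persist_dir)
    # Rebuild the index from the persisted storage context
    storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
    return load_index_from_storage(storage_context)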
def handle_query(query_str: str,
                 chat_history: list[tuple[str, str]],
                 session: dict[str, Any]) -> Iterator[str]:
    global index

    # Rebuild the previous turns as ChatMessage objects; the chat engine needs this
    # history in every branch below.
    conversation: List[ChatMessage] = []
    for user, assistant in chat_history:
        conversation.extend(
            [
                ChatMessage(role=MessageRole.USER, content=user),
                ChatMessage(role=MessageRole.ASSISTANT, content=assistant),
            ]
        )

    if not session["index"]:
        # No index yet: try to match the query against one of the known topics.
        matched_path = None
        words = query_str.lower()
        for key, path in documents_paths.items():
            if key in words:
                matched_path = path
                break
        if matched_path:
            index = build_index(matched_path)
            session["index"] = True
        else:  # ask the user for clarification
            index = build_index("data/chiarimento.txt")
    # else: the index is already built, so it is reused as-is.

    try:
        memory = ChatMemoryBuffer.from_defaults(token_limit=None)
        chat_engine = index.as_chat_engine(
            chat_mode="condense_plus_context",
            memory=memory,
            similarity_top_k=4,
            response_mode="tree_summarize",  # good for summarization purposes
            # System prompt (Italian): Odi is a Q&A research assistant created by the
            # Osservatori Digitali; it answers only relevant questions, using the
            # retrieved context and the previous chat history.
            context_prompt=(
                "Sei un assistente Q&A italiano di nome Odi, che risponde solo alle domande o richieste pertinenti in modo preciso."
                " Quando un utente ti chiede informazioni su di te o sul tuo creatore puoi dire che sei un assistente ricercatore creato dagli Osservatori Digitali e fornire gli argomenti di cui sei esperto."
                " Ecco i documenti rilevanti per il contesto:\n"
                "{context_str}"
                "\nIstruzione: Usa la cronologia delle chat precedenti, o il contesto sopra, per interagire e aiutare l'utente a rispondere alla sua domanda."
            ),
            verbose=False,
        )
        outputs = []
        response = chat_engine.stream_chat(query_str, conversation)
        # Stream partial responses back to the caller as tokens arrive
        for token in response.response_gen:
            outputs.append(token)
            yield "".join(outputs)
    except Exception as e:
        yield f"Error processing query: {str(e)}"