import os

import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer, AutoTokenizer
from huggingface_hub import hf_hub_download, login
from llama_cpp import Llama
from llama_index.core import (
    ChatPromptTemplate,
    PromptTemplate,
    Settings,
    SimpleDirectoryReader,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.instructor import InstructorEmbedding

from interface import GemmaLLMInterface

huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    token=True,
)
model.eval()

# from accelerate import disk_offload
# disk_offload(model=model, offload_dir="offload")
# Models used by LlamaIndex:
# Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
# Settings.llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)
Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface(model_id="google/gemma-2-2b-it")
############################---------------------------------
# Sentence-level parser used to chunk documents into nodes
parser = SentenceSplitter.from_defaults(
    chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
)
def build_index():
    # Load documents from a file
    documents = SimpleDirectoryReader(input_files=["data/blockchainprova.txt"]).load_data()
    # Parse the documents into nodes
    nodes = parser.get_nodes_from_documents(documents)
    # Build the vector store index from the nodes
    index = VectorStoreIndex(nodes)
    return index
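
# Illustrative sketch (not in the original code): load_index_from_storage is
# imported above but never used, and build_index() re-embeds the document on
# every call. One hedged way to persist the index once and reload it on later
# runs could look like this; the "storage" directory name is an assumption.
from llama_index.core import StorageContext

def get_or_build_index(persist_dir="storage"):
    if os.path.isdir(persist_dir):
        # Reload the previously persisted vector index from disk
        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        return load_index_from_storage(storage_context)
    # First run: build the index, then persist it so later runs can skip re-embedding
    index = build_index()
    index.storage_context.persist(persist_dir=persist_dir)
    return index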
def handle_query(query_str, chathistory):
    index = build_index()

    qa_prompt_str = (
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Given the context information and not prior knowledge, "
        "answer the question: {query_str}\n"
    )

    # Text QA prompt. The system message (Italian) reads: "You are an Italian
    # assistant named Ossy who only answers relevant questions or requests."
    chat_text_qa_msgs = [
        (
            "system",
            "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. ",
        ),
        ("user", qa_prompt_str),
    ]
    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
    try:
        # Earlier, non-streaming query-engine approach, kept for reference:
        # query_engine = index.as_query_engine(
        #     text_qa_template=text_qa_template, streaming=False, similarity_top_k=1
        # )
        # streaming_response = query_engine.query(query_str)
        # r = streaming_response.response
        # cleaned_result = r.replace("<end_of_turn>", "").strip()
        # yield cleaned_result

        # Earlier streaming variant, kept for reference:
        # outputs = []
        # for text in streaming_response.response_gen:
        #     outputs.append(str(text))
        #     yield "".join(outputs)

        # Current approach: context chat engine with a short conversation memory.
        memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
        chat_engine = index.as_chat_engine(
            chat_mode="context",
            memory=memory,
            system_prompt=(
                "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. "
            ),
        )
        response = chat_engine.stream_chat(query_str)
        # response = chat_engine.chat(query_str)

        # Accumulate tokens so each yield carries the full partial answer
        # (Gradio replaces the displayed message with every yielded value).
        outputs = []
        for token in response.response_gen:
            outputs.append(token)
            yield "".join(outputs)
    except Exception as e:
        yield f"Error processing query: {str(e)}"