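"""Minimal RAG proof-of-concept: index a toy document with LlamaIndex and
answer questions about it through a Gradio interface."""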
import gradio as gr

from llama_index.core import Document, Settings, VectorStoreIndex
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM

# Imports used only by the commented-out experiments below:
# import torch
# from transformers import pipeline
# from transformers.utils import logging
# from llama_index.llms.huggingface import HuggingFaceInferenceAPI
# Earlier experiments (kept for reference):
# logging.set_verbosity_error()
# llm = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
# llm = pipeline("question-answering", model="sshleifer/distilbart-cnn-12-6")
# pipe = pipeline("text-generation", model="HuggingFaceH4/zephyr-7b-beta", torch_dtype=torch.bfloat16, device_map="auto")
# Settings.llm = pipeline("text2text-generation", model="facebook/blenderbot-400M-distill")
# Small local chat model for generation; BlenderBot's maximum sequence length
# is 128 tokens, so the context window must stay at 128.
Settings.llm = HuggingFaceLLM(
    model_name="facebook/blenderbot-400M-distill",
    tokenizer_name="facebook/blenderbot-400M-distill",
    device_map="cpu",
    context_window=128,
)
# Sentence-embedding model used for retrieval.
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
# Toy corpus: a single in-memory document to query against.
documents = [Document(text="Indian parliament elections happened in April-May 2024. BJP Party won.")]
index = VectorStoreIndex.from_documents(documents)
query_engine = index.as_query_engine()
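# Illustrative usage (query() returns a Response object, not a plain string):
# print(query_engine.query("Who won the 2024 Indian parliament elections?"))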
# def get_embeddings(mytext):
#     embeddings = embed_model.get_text_embedding("Hello World!")
#     subarray = embeddings[:5]
#     out = [str(i) for i in subarray]
#     return '::'.join(out)

# def summarize(mytext):
#     out = llm(mytext, max_length=130, min_length=30, do_sample=False)
#     return str(out[0])

# def chat_completion(question):
#     messages = [
#         {"role": "system", "content": "You are a friendly chatbot."},
#         {"role": "user", "content": question},
#     ]
#     prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
#     outputs = pipe(prompt, max_new_tokens=1024, do_sample=False, temperature=0.3, top_k=75, top_p=0.97)
#     return outputs[0]["generated_text"]

# def generate_response(input_text):
#     result = pipe(input_text)
#     return result[0]['generated_text']
def rag(input_text, file):
    """Answer a question against the indexed documents.

    The uploaded file is currently ignored; only the in-memory index is queried.
    """
    return str(query_engine.query(input_text))
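# Possible extension (a sketch, not wired in): index the uploaded file instead
# of the toy document. Assumes Gradio passes a tempfile object whose .name is
# the on-disk path (newer Gradio versions may pass the path string directly);
# SimpleDirectoryReader comes from llama_index.core.
#
# from llama_index.core import SimpleDirectoryReader
#
# def rag_with_file(input_text, file):
#     if file is not None:
#         docs = SimpleDirectoryReader(input_files=[file.name]).load_data()
#         file_index = VectorStoreIndex.from_documents(docs)
#         return str(file_index.as_query_engine().query(input_text))
#     return str(query_engine.query(input_text))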
iface = gr.Interface(
    fn=rag,
    inputs=[gr.Textbox(label="Question", lines=6), gr.File()],
    outputs=[gr.Textbox(label="Result", lines=6)],
    title="Answer my question",
    description="CoolChatBot",
)
iface.launch()