import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
import gradio as gr
import spaces

# Load the TinyLlama chat model and its tokenizer in half precision.
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    top_k=40,
    repetition_penalty=1.1,
)

llm = HuggingFacePipeline(pipeline=pipe)

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Load the prebuilt FAISS index shipped alongside the app.
db_FAISS = FAISS.load_local("/home/user/app/", embeddings, allow_dangerous_deserialization=True)
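# The index above is loaded from disk, so it has to be built in a separate step.
# A minimal sketch of how such an index could be created with the same embeddings,
# assuming the source material is a set of PDFs (the docs/ folder and file name
# below are hypothetical):
#
#   from langchain_community.document_loaders import PyPDFLoader
#   from langchain.text_splitter import RecursiveCharacterTextSplitter
#
#   pages = PyPDFLoader("docs/example.pdf").load()
#   chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(pages)
#   FAISS.from_documents(chunks, embeddings).save_local("/home/user/app/")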

# RetrievalQA with the "stuff" chain type: the top 3 retrieved chunks are stuffed
# into a single prompt, and the source documents are returned for citation.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=db_FAISS.as_retriever(search_kwargs={"k": 3}),
    return_source_documents=True,
)


@spaces.GPU
def query_documents(query):
    # Run the RAG chain and return the answer plus the metadata of the retrieved chunks.
    result = qa_chain.invoke({"query": query})
    answer = result['result']
    sources = [doc.metadata for doc in result['source_documents']]
    return answer, sources
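
# Hypothetical smoke test; the query string below is only an example.
# answer, sources = query_documents("What topics do the indexed documents cover?")
# print(answer, sources)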


def gradio_interface(query):
    answer, sources = query_documents(query)
    source_text = "\n\nSources:\n" + "\n".join(
        f"Source: {s.get('source', 'Unknown')}, Page: {s.get('page', 'Unknown')}" for s in sources
    )
    return answer + source_text


iface = gr.Interface(
    fn=gradio_interface,
    inputs="text",
    outputs="text",
    title="Document Q&A with TinyLlama",
    description="Ask questions about your documents",
)

iface.launch()