imSleepy committed
Commit b5e1894 · verified · 1 Parent(s): 6636126

reverted back to original (chatbot.py)

Files changed (1)
  1. chatbot.py +24 -17
chatbot.py CHANGED
@@ -1,19 +1,19 @@
-from fastapi import FastAPI
-from pydantic import BaseModel
 from transformers import T5Tokenizer, T5ForConditionalGeneration
 from sentence_transformers import SentenceTransformer
 from pinecone import Pinecone
 
-device = 'cpu'
+device = 'cpu'
 
 # Initialize Pinecone instance
-pc = Pinecone(api_key='your-pinecone-api-key')
+pc = Pinecone(api_key='89eeb534-da10-4068-92f7-12eddeabe1e5')
 
-# Initialize FastAPI app
-app = FastAPI()
+# Check if the index exists; if not, create it
+index_name = 'abstractive-question-answering'
+index = pc.Index(index_name)
 
-# Initialize the models
 def load_models():
+    print("Loading models...")
+
     retriever = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base")
     tokenizer = T5Tokenizer.from_pretrained('t5-small')
     generator = T5ForConditionalGeneration.from_pretrained('t5-base').to(device)
@@ -22,27 +22,34 @@ def load_models():
 
 retriever, generator, tokenizer = load_models()
 
-class QueryInput(BaseModel):
-    input: str
-
-@app.post("/predict")
-def predict(query: QueryInput):
-    query_text = query.input
+def process_query(query):
     # Query Pinecone
-    xq = retriever.encode([query_text]).tolist()
+    xq = retriever.encode([query]).tolist()
     xc = index.query(vector=xq, top_k=1, include_metadata=True)
+
+    # Print the response to check the structure
+    print("Pinecone response:", xc)
 
+    # Check if 'matches' exists and is a list
     if 'matches' in xc and isinstance(xc['matches'], list):
         context = [m['metadata']['Output'] for m in xc['matches']]
         context_str = " ".join(context)
-        formatted_query = f"answer the question: {query_text} context: {context_str}"
+        formatted_query = f"answer the question: {query} context: {context_str}"
     else:
+        # Handle the case where 'matches' isn't found or isn't in the expected format
        context_str = ""
-        formatted_query = f"answer the question: {query_text} context: {context_str}"
+        formatted_query = f"answer the question: {query} context: {context_str}"
 
     # Generate answer using T5 model
+    output_text = context_str
+    if len(output_text.splitlines()) > 5:
+        return output_text
+
+    if output_text.lower() == "none":
+        return "The topic is not covered in the student manual."
+
     inputs = tokenizer.encode(formatted_query, return_tensors="pt", max_length=512, truncation=True).to(device)
     ids = generator.generate(inputs, num_beams=2, min_length=10, max_length=60, repetition_penalty=1.2)
     answer = tokenizer.decode(ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
 
-    return {"response": answer}
+    return answer
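
For reference, a minimal sketch of how the reverted chatbot.py might be exercised from a separate driver script. The driver filename and the sample question are assumptions for illustration, not part of this commit; importing chatbot runs its module-level Pinecone setup and load_models() call before process_query can be used.

# run_chatbot.py (hypothetical driver, assuming chatbot.py is on the import path)
from chatbot import process_query

if __name__ == "__main__":
    # Sample question is illustrative only
    print(process_query("What topics does the student manual cover?"))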