imSleepy commited on
Commit
f82f890
1 Parent(s): db96dc8

Update chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +2 -20
chatbot.py CHANGED
@@ -4,22 +4,16 @@ from transformers import T5Tokenizer, T5ForConditionalGeneration
4
  from sentence_transformers import SentenceTransformer
5
  from pinecone import Pinecone
6
 
7
- device = 'cpu'
8
 
9
  # Initialize Pinecone instance
10
- pc = Pinecone(api_key='89eeb534-da10-4068-92f7-12eddeabe1e5')
11
-
12
- # Check if the index exists; if not, create it
13
- index_name = 'abstractive-question-answering'
14
- index = pc.Index(index_name)
15
 
16
  # Initialize FastAPI app
17
  app = FastAPI()
18
 
19
  # Initialize the models
20
  def load_models():
21
- print("Loading models...")
22
-
23
  retriever = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base")
24
  tokenizer = T5Tokenizer.from_pretrained('t5-small')
25
  generator = T5ForConditionalGeneration.from_pretrained('t5-base').to(device)
@@ -38,29 +32,17 @@ def predict(query: QueryInput):
38
  xq = retriever.encode([query_text]).tolist()
39
  xc = index.query(vector=xq, top_k=1, include_metadata=True)
40
 
41
- # Check if 'matches' exists and is a list
42
  if 'matches' in xc and isinstance(xc['matches'], list):
43
  context = [m['metadata']['Output'] for m in xc['matches']]
44
  context_str = " ".join(context)
45
  formatted_query = f"answer the question: {query_text} context: {context_str}"
46
  else:
47
- # Handle the case where 'matches' isn't found or isn't in the expected format
48
  context_str = ""
49
  formatted_query = f"answer the question: {query_text} context: {context_str}"
50
 
51
  # Generate answer using T5 model
52
- output_text = context_str
53
- if len(output_text.splitlines()) > 5:
54
- return {"response": output_text}
55
-
56
- if output_text.lower() == "none":
57
- return {"response": "The topic is not covered in the student manual."}
58
-
59
  inputs = tokenizer.encode(formatted_query, return_tensors="pt", max_length=512, truncation=True).to(device)
60
  ids = generator.generate(inputs, num_beams=2, min_length=10, max_length=60, repetition_penalty=1.2)
61
  answer = tokenizer.decode(ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
62
 
63
  return {"response": answer}
64
-
65
- # To run the server (use uvicorn when deploying):
66
- # uvicorn chatbot:app --reload
 
4
  from sentence_transformers import SentenceTransformer
5
  from pinecone import Pinecone
6
 
7
+ device = 'cpu'
8
 
9
  # Initialize Pinecone instance
10
+ pc = Pinecone(api_key='your-pinecone-api-key')
 
 
 
 
11
 
12
  # Initialize FastAPI app
13
  app = FastAPI()
14
 
15
  # Initialize the models
16
  def load_models():
 
 
17
  retriever = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base")
18
  tokenizer = T5Tokenizer.from_pretrained('t5-small')
19
  generator = T5ForConditionalGeneration.from_pretrained('t5-base').to(device)
 
32
  xq = retriever.encode([query_text]).tolist()
33
  xc = index.query(vector=xq, top_k=1, include_metadata=True)
34
 
 
35
  if 'matches' in xc and isinstance(xc['matches'], list):
36
  context = [m['metadata']['Output'] for m in xc['matches']]
37
  context_str = " ".join(context)
38
  formatted_query = f"answer the question: {query_text} context: {context_str}"
39
  else:
 
40
  context_str = ""
41
  formatted_query = f"answer the question: {query_text} context: {context_str}"
42
 
43
  # Generate answer using T5 model
 
 
 
 
 
 
 
44
  inputs = tokenizer.encode(formatted_query, return_tensors="pt", max_length=512, truncation=True).to(device)
45
  ids = generator.generate(inputs, num_beams=2, min_length=10, max_length=60, repetition_penalty=1.2)
46
  answer = tokenizer.decode(ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
47
 
48
  return {"response": answer}