gaur3009 commited on
Commit
12fdd3d
1 Parent(s): 7231690

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import BertTokenizer, BertModel, GPT2LMHeadModel, GPT2Tokenizer
3
+ import numpy as np
4
+ import pandas as pd
5
+ import os
6
+ import json
7
+ from fastapi import FastAPI
8
+ from pydantic import BaseModel
9
+
10
+ app = FastAPI()
11
+
12
+ data = {
13
+ "questions": [
14
+ "What is Rookus?",
15
+ "How does Rookus use AI in its designs?",
16
+ "What products does Rookus offer?",
17
+ "Can I see samples of Rookus' designs?",
18
+ "How can I join the waitlist for Rookus?",
19
+ "How does Rookus ensure the quality of its AI-generated designs?",
20
+ "Is there a custom design option available at Rookus?",
21
+ "How long does it take to receive a product from Rookus?"
22
+ ],
23
+ "answers": [
24
+ "Rookus is a startup that leverages AI to create unique designs for various products such as clothes, posters, and different arts and crafts.",
25
+ "Rookus uses advanced AI algorithms to generate innovative and aesthetically pleasing designs. These AI models are trained on vast datasets of art and design to produce high-quality mockups.",
26
+ "Rookus offers a variety of products, including clothing, posters, and a range of arts and crafts items, all featuring AI-generated designs.",
27
+ "Yes, Rookus provides samples of its designs on its website. You can view a gallery of products showcasing the AI-generated artwork.",
28
+ "To join the waitlist for Rookus, visit our website and sign up with your email. You'll receive updates on our launch and exclusive early access opportunities.",
29
+ "Rookus ensures the quality of its AI-generated designs through rigorous testing and refinement. Each design goes through multiple review stages to ensure it meets our high standards.",
30
+ "Yes, Rookus offers custom design options. You can submit your preferences, and our AI will generate a design tailored to your specifications.",
31
+ "The delivery time for products from Rookus varies based on the product type and location. Typically, it takes 2-4 weeks for production and delivery."
32
+ ],
33
+ "default_answers": "I'm sorry, I cannot answer this right now. Your question has been saved, and we will get back to you with a response soon."
34
+ }
35
+
36
+ bert_model_name = 'models/bert'
37
+ bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
38
+ bert_model = BertModel.from_pretrained(bert_model_name)
39
+
40
+ gpt2_model_name = 'models/gpt2'
41
+ gpt2_tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
42
+ gpt2_model = GPT2LMHeadModel.from_pretrained(gpt2_model_name)
43
+
44
+ def get_bert_embeddings(texts):
45
+ inputs = bert_tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
46
+ with torch.no_grad():
47
+ outputs = bert_model(**inputs)
48
+ return outputs.last_hidden_state[:, 0, :].numpy()
49
+
50
+ def get_closest_question(user_query, questions, threshold=0.95):
51
+ all_texts = questions + [user_query]
52
+ embeddings = get_bert_embeddings(all_texts)
53
+ cosine_similarities = np.dot(embeddings[-1], embeddings[:-1].T) / (
54
+ np.linalg.norm(embeddings[-1]) * np.linalg.norm(embeddings[:-1], axis=1)
55
+ )
56
+ max_similarity = np.max(cosine_similarities)
57
+
58
+ if max_similarity >= threshold:
59
+ most_similar_index = np.argmax(cosine_similarities)
60
+ return questions[most_similar_index], max_similarity
61
+ else:
62
+ return None, max_similarity
63
+
64
+ def generate_gpt2_response(prompt, model, tokenizer, max_length=100):
65
+ inputs = tokenizer.encode(prompt, return_tensors='pt')
66
+ outputs = model.generate(inputs, max_length=max_length, num_return_sequences=1)
67
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
68
+
69
+ class QueryRequest(BaseModel):
70
+ query: str
71
+
72
+ @app.post("/query/")
73
+ def answer_query(request: QueryRequest):
74
+ user_query = request.query
75
+ closest_question, similarity = get_closest_question(user_query, data['questions'], threshold=0.95)
76
+ if closest_question and similarity >= 0.95:
77
+ answer_index = data['questions'].index(closest_question)
78
+ answer = data['answers'][answer_index]
79
+ else:
80
+ excel_file = 'new_questions1.xlsx'
81
+ if not os.path.isfile(excel_file):
82
+ df = pd.DataFrame(columns=['question'])
83
+ df.to_excel(excel_file, index=False)
84
+
85
+ new_data = pd.DataFrame({'questions': [user_query]})
86
+ df = pd.read_excel(excel_file)
87
+ df = pd.concat([df, new_data], ignore_index=True)
88
+ with pd.ExcelWriter(excel_file, engine='openpyxl', mode='w') as writer:
89
+ df.to_excel(writer, index=False)
90
+ answer = data['default_answers']
91
+
92
+ return {"query": user_query, "answer": answer}
93
+
94
+ if __name__ == "__main__":
95
+ import uvicorn
96
+ uvicorn.run(app, host="0.0.0.0", port=8000)