gaur3009 commited on
Commit
6685714
1 Parent(s): 12fdd3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -21
app.py CHANGED
@@ -3,12 +3,18 @@ from transformers import BertTokenizer, BertModel, GPT2LMHeadModel, GPT2Tokenize
3
  import numpy as np
4
  import pandas as pd
5
  import os
6
- import json
7
- from fastapi import FastAPI
8
- from pydantic import BaseModel
9
 
10
- app = FastAPI()
 
 
 
11
 
 
 
 
 
 
12
  data = {
13
  "questions": [
14
  "What is Rookus?",
@@ -33,14 +39,6 @@ data = {
33
  "default_answers": "I'm sorry, I cannot answer this right now. Your question has been saved, and we will get back to you with a response soon."
34
  }
35
 
36
- bert_model_name = 'models/bert'
37
- bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
38
- bert_model = BertModel.from_pretrained(bert_model_name)
39
-
40
- gpt2_model_name = 'models/gpt2'
41
- gpt2_tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
42
- gpt2_model = GPT2LMHeadModel.from_pretrained(gpt2_model_name)
43
-
44
  def get_bert_embeddings(texts):
45
  inputs = bert_tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
46
  with torch.no_grad():
@@ -66,12 +64,7 @@ def generate_gpt2_response(prompt, model, tokenizer, max_length=100):
66
  outputs = model.generate(inputs, max_length=max_length, num_return_sequences=1)
67
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
68
 
69
- class QueryRequest(BaseModel):
70
- query: str
71
-
72
- @app.post("/query/")
73
- def answer_query(request: QueryRequest):
74
- user_query = request.query
75
  closest_question, similarity = get_closest_question(user_query, data['questions'], threshold=0.95)
76
  if closest_question and similarity >= 0.95:
77
  answer_index = data['questions'].index(closest_question)
@@ -89,8 +82,15 @@ def answer_query(request: QueryRequest):
89
  df.to_excel(writer, index=False)
90
  answer = data['default_answers']
91
 
92
- return {"query": user_query, "answer": answer}
 
 
 
 
 
 
 
 
93
 
94
  if __name__ == "__main__":
95
- import uvicorn
96
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
3
  import numpy as np
4
  import pandas as pd
5
  import os
6
+ import gradio as gr
 
 
7
 
8
+ # Load the models and tokenizers
9
+ bert_model_name = 'bert-base-uncased'
10
+ bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
11
+ bert_model = BertModel.from_pretrained(bert_model_name)
12
 
13
+ gpt2_model_name = 'gpt2'
14
+ gpt2_tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
15
+ gpt2_model = GPT2LMHeadModel.from_pretrained(gpt2_model_name)
16
+
17
+ # Load the data
18
  data = {
19
  "questions": [
20
  "What is Rookus?",
 
39
  "default_answers": "I'm sorry, I cannot answer this right now. Your question has been saved, and we will get back to you with a response soon."
40
  }
41
 
 
 
 
 
 
 
 
 
42
  def get_bert_embeddings(texts):
43
  inputs = bert_tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
44
  with torch.no_grad():
 
64
  outputs = model.generate(inputs, max_length=max_length, num_return_sequences=1)
65
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
66
 
67
+ def answer_query(user_query):
 
 
 
 
 
68
  closest_question, similarity = get_closest_question(user_query, data['questions'], threshold=0.95)
69
  if closest_question and similarity >= 0.95:
70
  answer_index = data['questions'].index(closest_question)
 
82
  df.to_excel(writer, index=False)
83
  answer = data['default_answers']
84
 
85
+ return answer
86
+
87
+ iface = gr.Interface(
88
+ fn=answer_query,
89
+ inputs="text",
90
+ outputs="text",
91
+ title="Rookus AI Query Interface",
92
+ description="Ask questions about Rookus and get answers generated by AI."
93
+ )
94
 
95
  if __name__ == "__main__":
96
+ iface.launch()