0x70DA committed on
Commit
633e625
1 Parent(s): a464a99

Add lots of code

Browse files
Files changed (3) hide show
  1. README.md +3 -3
  2. app.py +133 -0
  3. requirements.txt +8 -0
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
- title: Abstractive Qa Demo
3
- emoji:
4
- colorFrom: purple
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 3.23.0
 
1
  ---
2
+ title: Abstractive QA Demo
3
+ emoji:
4
+ colorFrom: pink
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 3.23.0
app.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import faiss
4
+ import numpy as np
5
+ import gradio as gr
6
+ import requests
7
+ import torch
8
+ from bs4 import BeautifulSoup
9
+ from datasets import Dataset
10
+ from sentence_transformers import SentenceTransformer
11
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12
+
13
# This app only runs inference, so gradient tracking is switched off globally.
torch.set_grad_enabled(False)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Retriever: sentence-embedding model used to rank scraped text chunks.
retriever = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1", device=device)

# Generator: seq2seq model (and its tokenizer) that writes the final answer.
tokenizer = AutoTokenizer.from_pretrained("yjernite/bart_eli5")
model = AutoModelForSeq2SeqLM.from_pretrained("yjernite/bart_eli5")
model = model.to(device)
21
+
22
+
23
def scrape(
    urls: List[str],
    chunk_size: int = 100,
    min_words: int = 15,
    timeout: float = 10.0,
) -> Dataset:
    """Download each URL and split its visible text into word chunks.

    Every text node of a page (not only ``<p>`` tags) is split on whitespace
    and cut into consecutive chunks of at most ``chunk_size`` words. Chunks
    with fewer than ``min_words`` words are discarded.

    Args:
        urls: Pages to download.
        chunk_size: Maximum number of words per chunk; must be positive.
        min_words: Minimum number of words a chunk needs to be kept.
        timeout: Seconds to wait for each HTTP request.

    Returns:
        A ``Dataset`` with one row per kept chunk and the columns
        ``context`` (the chunk text) and ``url`` (the page it came from).
        It has no rows when nothing could be scraped.
    """
    data = []
    for url in urls:
        try:
            response = requests.get(url, timeout=timeout)
        except requests.RequestException:
            # A single unreachable or slow page must not abort the scrape.
            continue

        soup = BeautifulSoup(response.text, "html.parser")

        for string in soup.stripped_strings:
            # Split the text itself: repr() would wrap it in quotes and add
            # backslash escapes, which then leak into the contexts.
            words = str(string).split()
            for start in range(0, len(words), chunk_size):
                chunk = words[start : start + chunk_size]
                if len(chunk) >= min_words:
                    data.append({"context": " ".join(chunk), "url": url})

    return Dataset.from_list(data)
46
+
47
+
48
def search_web(query: str) -> List[str]:
    """Return the result URLs of a Google search for ``query``.

    Only the first ten ``<div class="g">`` result containers are inspected,
    and only absolute links (starting with ``http``) are kept.

    Args:
        query: Free-text search query; it is URL-encoded before sending.

    Returns:
        The distinct result URLs in the order they appear on the page.
        The list is empty when no result container is found.
    """
    search_url = "https://www.google.com/search"

    # Set the user agent to avoid being blocked by Google
    headers = {
        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36"
    }

    # Passing the query through ``params`` lets requests percent-encode it,
    # so characters such as "&", "#" or "+" cannot corrupt the URL.
    response = requests.get(
        search_url, params={"q": query}, headers=headers, timeout=10
    )

    soup = BeautifulSoup(response.content, "html.parser")
    search_results = soup.find_all("div", class_="g")

    # A dict keeps insertion order, so duplicates are removed without
    # losing the ranking of the results.
    urls = {}
    for result in search_results[:10]:
        link = result.find("a")
        if link is None:
            continue
        href = link.get("href")
        if href and href.startswith("http"):
            urls[href] = None

    return list(urls)
73
+
74
+
75
def generate_answer(question_doc: str) -> str:
    """Generate an answer for a combined question-and-context prompt.

    Args:
        question_doc: The prompt fed to the generation model, containing the
            question followed by its supporting document.

    Returns:
        The decoded answer with special tokens removed and surrounding
        whitespace stripped.
    """
    # Explicit padding/truncation replaces the deprecated
    # ``pad_to_max_length`` flag and makes the 1024-token cap explicit
    # rather than relying on the tokenizer's implicit fallback.
    encoded = tokenizer(
        [question_doc],
        max_length=1024,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).to(device)

    model_output = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        min_new_tokens=32,
        max_new_tokens=256,
        no_repeat_ngram_size=3,
        num_beams=2,
        do_sample=True,
        length_penalty=1.5,
    )
    answer = tokenizer.batch_decode(model_output, skip_special_tokens=True)[0]
    return answer.strip()
95
+
96
+
97
def predict(question: str) -> str:
    """Answer ``question`` from text retrieved on the web.

    The question is searched on the web, the result pages are scraped into
    text chunks, the chunks are embedded and indexed with Faiss, and the
    closest chunks are handed to the generation model as supporting context.

    Args:
        question: The user's question.

    Returns:
        The generated answer, or a short explanatory message when the
        question is blank or no web content could be retrieved.
    """
    if not question or not question.strip():
        return "Please enter a question."

    urls = search_web(question)
    data = scrape(urls)
    if len(data) == 0:
        # Without any rows there is no "context" column to embed or index.
        return "Sorry, I could not find any web content to answer this question."

    # Create vector embeddings and add Faiss index
    data_with_embeds = data.map(
        lambda batch: {"embeddings": retriever.encode(batch["context"])},
        batched=True,
    )
    data_with_embeds.add_faiss_index(
        column="embeddings", metric_type=faiss.METRIC_INNER_PRODUCT
    )

    # Never ask for more neighbours than there are indexed chunks.
    k = min(20, len(data_with_embeds))
    _scores, relevant_examples = data_with_embeds.get_nearest_examples(
        "embeddings", retriever.encode([question]), k=k
    )

    # The support document for the model
    doc = "<P> " + " <P> ".join(relevant_examples["context"])

    # Generate answer
    question_doc = f"question: {question} context: {doc}"
    return generate_answer(question_doc)
118
+
119
+
120
# Gradio widgets: a single question in, a single answer out.
input_box = gr.Textbox(label="Question")
output_box = gr.Textbox(label="Answer")

# HTML snippet passed to the interface as its description.
description = """
<div style="text-align: center;">
<p style="font-style: italic;"> Disclaimer: This is just a stupid demo and it crashes a lot. Don't take it too seriously.</p>
✌😎
</div>
"""


# Build the interface around ``predict``, enable queuing, then launch it.
demo = gr.Interface(
    fn=predict, inputs=input_box, outputs=output_box, description=description
).queue()
demo.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ sentence-transformers
3
+ datasets
4
+ torch
5
+ beautifulsoup4
6
+ requests
7
+ numpy
8
+ faiss-cpu