VagoX1 committed on
Commit 417b39a
1 Parent(s): 6daaa3f

Update app.py

Files changed (1):
  app.py +214 -179
app.py CHANGED
@@ -1,134 +1,21 @@
- import joblib
- import gradio as gr
- from collections import Counter
- from typing import TypedDict
- from abc import ABC, abstractmethod
- from typing import Any, Dict, Type
- from scipy.sparse._csc import csc_matrix
- from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
- import pickle
  from dataclasses import dataclass
  import tqdm
  import re
- import os
  import nltk
  nltk.download("stopwords", quiet=True)
  from nltk.corpus import stopwords as nltk_stopwords
- import math
- from dataclasses import dataclass
- from typing import Optional
- from datasets import load_dataset
- from enum import Enum
- import numpy as np
-
- @dataclass
- class Document:
-     collection_id: str
-     text: str
-
-
- @dataclass
- class Query:
-     query_id: str
-     text: str
-
-
- @dataclass
- class QRel:
-     query_id: str
-     collection_id: str
-     relevance: int
-     answer: Optional[str] = None
-
- class Split(str, Enum):
-     train = "train"
-     dev = "dev"
-     test = "test"
-
- @dataclass
- class IRDataset:
-     corpus: List[Document]
-     queries: List[Query]
-     split2qrels: Dict[Split, List[QRel]]
-
-     def get_stats(self) -> Dict[str, int]:
-         stats = {"|corpus|": len(self.corpus), "|queries|": len(self.queries)}
-         for split, qrels in self.split2qrels.items():
-             stats[f"|qrels-{split}|"] = len(qrels)
-         return stats
-
-     def get_qrels_dict(self, split: Split) -> Dict[str, Dict[str, int]]:
-         qrels_dict = {}
-         for qrel in self.split2qrels[split]:
-             qrels_dict.setdefault(qrel.query_id, {})
-             qrels_dict[qrel.query_id][qrel.collection_id] = qrel.relevance
-         return qrels_dict
-
-     def get_split_queries(self, split: Split) -> List[Query]:
-         qrels = self.split2qrels[split]
-         qids = {qrel.query_id for qrel in qrels}
-         return list(filter(lambda query: query.query_id in qids, self.queries))
-
-
- @(joblib.Memory(".cache").cache)
- def load_sciq(verbose: bool = False) -> IRDataset:
-     train = load_dataset("allenai/sciq", split="train")
-     validation = load_dataset("allenai/sciq", split="validation")
-     test = load_dataset("allenai/sciq", split="test")
-     data = {Split.train: train, Split.dev: validation, Split.test: test}
-
-     # Duplicated records are identical to one another:
-     df = train.to_pandas() + validation.to_pandas() + test.to_pandas()
-     for question, group in df.groupby("question"):
-         assert len(set(group["support"].tolist())) == len(group)
-         assert len(set(group["correct_answer"].tolist())) == len(group)
-
-     # Build:
-     corpus = []
-     queries = []
-     split2qrels: Dict[str, List[dict]] = {}
-     question2id = {}
-     support2id = {}
-     for split, rows in data.items():
-         if verbose:
-             print(f"|raw_{split}|", len(rows))
-         split2qrels[split] = []
-         for i, row in enumerate(rows):
-             example_id = f"{split}-{i}"
-             support: str = row["support"]
-             if len(support.strip()) == 0:
-                 continue
-             question = row["question"]
-             if len(support.strip()) == 0:
-                 continue
-             if support in support2id:
-                 continue
-             else:
-                 support2id[support] = example_id
-             if question in question2id:
-                 continue
-             else:
-                 question2id[question] = example_id
-             doc = {"collection_id": example_id, "text": support}
-             query = {"query_id": example_id, "text": row["question"]}
-             qrel = {
-                 "query_id": example_id,
-                 "collection_id": example_id,
-                 "relevance": 1,
-                 "answer": row["correct_answer"],
-             }
-             corpus.append(Document(**doc))
-             queries.append(Query(**query))
-             split2qrels[split].append(QRel(**qrel))
-
-     # Assemble and return:
-     return IRDataset(corpus=corpus, queries=queries, split2qrels=split2qrels)

  LANGUAGE = "english"
  word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
  stopwords = set(nltk_stopwords.words(LANGUAGE))

  def word_splitting(text: str) -> List[str]:
      return word_splitter(text.lower())
@@ -149,6 +36,7 @@ class PostingList:
      docid_postings: List[int]  # docid_postings[i] is the docid (int) of the i-th associated posting
      tweight_postings: List[float]  # tweight_postings[i] is the term weight (float) of the i-th associated posting

  @dataclass
  class InvertedIndex:
      posting_lists: List[PostingList]  # tid -> posting_list
@@ -171,24 +59,8 @@ class InvertedIndex:
          index = pickle.load(f)
          return index

- class BaseRetriever(ABC):
-
-     @property
-     @abstractmethod
-     def index_class(self) -> Type[Any]:
-         pass
-
-     def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
-         raise NotImplementedError
-
-     @abstractmethod
-     def score(self, query: str, cid: str) -> float:
-         pass
-
-     @abstractmethod
-     def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
-         pass

  @dataclass
  class Counting:
      posting_lists: List[PostingList]
@@ -260,6 +132,17 @@ def run_counting(
          doc_texts=doc_texts,
      )

  @dataclass
  class BM25Index(InvertedIndex):

@@ -305,7 +188,7 @@ class BM25Index(InvertedIndex):

      @classmethod
      def build_from_documents(
-         cls: Type["BM25Index"],
          documents: Iterable[Document],
          store_raw: bool = True,
          output_dir: Optional[str] = None,
@@ -313,7 +196,7 @@ class BM25Index(InvertedIndex):
          show_progress_bar: bool = True,
          k1: float = 0.9,
          b: float = 0.4,
-     ) -> "BM25Index":
          # Counting TFs, DFs, doc_lengths, etc.:
          counting = run_counting(
              documents=documents,
@@ -346,6 +229,147 @@ class BM25Index(InvertedIndex):
          )
          return index

  @dataclass
  class CSCInvertedIndex:
@@ -369,6 +393,7 @@ class CSCInvertedIndex:
          index = pickle.load(f)
          return index

  @dataclass
  class CSCBM25Index(CSCInvertedIndex):

@@ -388,7 +413,6 @@ class CSCBM25Index(CSCInvertedIndex):
      ) -> csc_matrix:
          """Compute and cache the term weights."""

-         ## YOUR_CODE_STARTS_HERE
          data = []
          indices = []
          indptr = [0]
@@ -431,7 +455,7 @@ class CSCBM25Index(CSCInvertedIndex):

      @classmethod
      def build_from_documents(
-         cls: Type["CSCBM25Index"],
          documents: Iterable[Document],
          store_raw: bool = True,
          output_dir: Optional[str] = None,
@@ -439,7 +463,7 @@ class CSCBM25Index(CSCInvertedIndex):
          show_progress_bar: bool = True,
          k1: float = 0.9,
          b: float = 0.4,
-     ) -> "CSCBM25Index":
          # Counting TFs, DFs, doc_lengths, etc.:
          counting = run_counting(
              documents=documents,
@@ -472,6 +496,15 @@ class CSCBM25Index(CSCInvertedIndex):
          )
          return index

  class BaseCSCInvertedIndexRetriever(BaseRetriever):

      @property
@@ -503,28 +536,29 @@ class BaseCSCInvertedIndexRetriever(BaseRetriever):

      def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
          ## YOUR_CODE_STARTS_HERE
          toks = self.index.tokenize(query)
          docid2score: Dict[int, float] = {}
-
          for tok in toks:
-             if tok not in self.index.vocab:
-                 continue
-             tid = self.index.vocab[tok]
-             # Get weights for all documents for the current term
-             weights_for_term = self.index.posting_lists_matrix.getcol(tid).toarray()[:, 0]
-
-             for docid, weight in enumerate(weights_for_term):
-                 docid2score.setdefault(docid, 0)
-                 docid2score[docid] += weight  # Accumulate scores for each document
-
-         # Sort and keep the topk documents
          docid2score = dict(
-             sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
-         )
-         return {
-             self.index.collection_ids[docid]: score
-             for docid, score in docid2score.items()
-         }

          ## YOUR_CODE_ENDS_HERE

  class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):
@@ -533,6 +567,9 @@ class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):
      def index_class(self) -> Type[CSCBM25Index]:
          return CSCBM25Index

  class Hit(TypedDict):
      cid: str
      score: float
@@ -542,28 +579,26 @@ demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
  return_type = List[Hit]

  ## YOUR_CODE_STARTS_HERE
- def search(query: str) -> List[Hit]:
-     bm25_index = BM25Index.build_from_documents(
-         documents=iter(sciq.corpus),
-         ndocs=12160,
-         show_progress_bar=True
-     )
-     bm25_index.save("output/bm25_index")
-     bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
-     ranking = bm25_retriever.retrieve(query=query)
-     hits = []
-     for cid, score in ranking.items():
-         doc = next((doc for doc in sciq.corpus if doc.collection_id == cid), None)
-         if doc:
-             hits.append({"cid": cid, "score": score, "text": doc.text})
-     return hits

  demo = gr.Interface(
      fn=search,
-     inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
-     outputs=gr.JSON(label="Search Results"),
-     title="SciQ Search Engine",
-     description="Enter a query to search the SciQ dataset using BM25.",
  )

  ## YOUR_CODE_ENDS_HERE
- demo.launch()

+ from __future__ import annotations

  from dataclasses import dataclass
+ import pickle
+ import os
+ from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
+ from nlp4web_codebase.ir.data_loaders.dm import Document
+ from collections import Counter
  import tqdm
  import re
  import nltk
  nltk.download("stopwords", quiet=True)
  from nltk.corpus import stopwords as nltk_stopwords

  LANGUAGE = "english"
  word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
  stopwords = set(nltk_stopwords.words(LANGUAGE))

+
  def word_splitting(text: str) -> List[str]:
      return word_splitter(text.lower())

      docid_postings: List[int]  # docid_postings[i] is the docid (int) of the i-th associated posting
      tweight_postings: List[float]  # tweight_postings[i] is the term weight (float) of the i-th associated posting

+
  @dataclass
  class InvertedIndex:
      posting_lists: List[PostingList]  # tid -> posting_list

          index = pickle.load(f)
          return index

+ # The output of the counting function:
  @dataclass
  class Counting:
      posting_lists: List[PostingList]

          doc_texts=doc_texts,
      )

+ from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
+ sciq = load_sciq()
+ counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
+
+ from dataclasses import asdict, dataclass
+ import math
+ import os
+ from typing import Iterable, List, Optional, Type
+ import tqdm
+ from nlp4web_codebase.ir.data_loaders.dm import Document
+
  @dataclass
  class BM25Index(InvertedIndex):

      @classmethod
      def build_from_documents(
+         cls: Type[BM25Index],
          documents: Iterable[Document],
          store_raw: bool = True,
          output_dir: Optional[str] = None,

          show_progress_bar: bool = True,
          k1: float = 0.9,
          b: float = 0.4,
+     ) -> BM25Index:
          # Counting TFs, DFs, doc_lengths, etc.:
          counting = run_counting(
              documents=documents,

          )
          return index

+ bm25_index = BM25Index.build_from_documents(
+     documents=iter(sciq.corpus),
+     ndocs=12160,
+     show_progress_bar=True,
+ )
+ bm25_index.save("output/bm25_index")
+
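Note: the hunk above elides the body of build_from_documents, where the per-term weights that k1 and b control are computed and cached. For reference, a minimal sketch of the standard BM25 term weight; the helper name and exact IDF smoothing are assumptions here, not necessarily this file's code:

    import math

    def bm25_term_weight(tf: int, df: int, dl: int, avgdl: float, ndocs: int,
                         k1: float = 0.9, b: float = 0.4) -> float:
        # Smoothed IDF times saturated, length-normalized term frequency.
        idf = math.log(1 + (ndocs - df + 0.5) / (df + 0.5))
        return idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avgdl))
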
+ from nlp4web_codebase.ir.models import BaseRetriever
+ from typing import Type
+ from abc import abstractmethod
+
+ class BaseInvertedIndexRetriever(BaseRetriever):
+
+     @property
+     @abstractmethod
+     def index_class(self) -> Type[InvertedIndex]:
+         pass
+
+     def __init__(self, index_dir: str) -> None:
+         self.index = self.index_class.from_saved(index_dir)
+
+     def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
+         toks = self.index.tokenize(query)
+         target_docid = self.index.cid2docid[cid]
+         term_weights = {}
+         for tok in toks:
+             if tok not in self.index.vocab:
+                 continue
+             tid = self.index.vocab[tok]
+             posting_list = self.index.posting_lists[tid]
+             for docid, tweight in zip(
+                 posting_list.docid_postings, posting_list.tweight_postings
+             ):
+                 if docid == target_docid:
+                     term_weights[tok] = tweight
+                     break
+         return term_weights
+
+     def score(self, query: str, cid: str) -> float:
+         return sum(self.get_term_weights(query=query, cid=cid).values())
+
+     def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
+         toks = self.index.tokenize(query)
+         docid2score: Dict[int, float] = {}
+         for tok in toks:
+             if tok not in self.index.vocab:
+                 continue
+             tid = self.index.vocab[tok]
+             posting_list = self.index.posting_lists[tid]
+             for docid, tweight in zip(
+                 posting_list.docid_postings, posting_list.tweight_postings
+             ):
+                 docid2score.setdefault(docid, 0)
+                 docid2score[docid] += tweight
+         docid2score = dict(
+             sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
+         )
+         return {
+             self.index.collection_ids[docid]: score
+             for docid, score in docid2score.items()
+         }
+
+ class BM25Retriever(BaseInvertedIndexRetriever):
+
+     @property
+     def index_class(self) -> Type[BM25Index]:
+         return BM25Index
+
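With the index saved, the retriever added above can be smoke-tested directly; the query string here is only an illustration:

    bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
    print(bm25_retriever.retrieve("What is the unit of electric current?", topk=3))
    # -> {collection_id: BM25 score, ...}, highest-scoring documents first
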
+ from nlp4web_codebase.ir.data_loaders import Split
+ import pytrec_eval
+ import numpy as np
+
+ def evaluate_map(rankings: Dict[str, Dict[str, float]], split=Split.dev) -> float:
+     metric = "map_cut_10"
+     qrels = sciq.get_qrels_dict(split)
+     evaluator = pytrec_eval.RelevanceEvaluator(qrels, (metric,))
+     qps = evaluator.evaluate(rankings)
+     return float(np.mean([qp[metric] for qp in qps.values()]))
+
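evaluate_map follows pytrec_eval's dict-of-dicts convention for qrels and runs; a self-contained sketch of the expected shapes:

    import pytrec_eval
    qrels = {"q1": {"d1": 1}}             # query_id -> {collection_id: relevance}
    run = {"q1": {"d1": 2.5, "d2": 1.1}}  # query_id -> {collection_id: score}
    evaluator = pytrec_eval.RelevanceEvaluator(qrels, ("map_cut_10",))
    print(evaluator.evaluate(run)["q1"]["map_cut_10"])  # 1.0: the relevant doc ranks first
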
+ # Loading the dataset:
+ from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
+ sciq = load_sciq()
+ counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
+
+ # Building the BM25 index and saving it:
+ bm25_index = BM25Index.build_from_documents(
+     documents=iter(sciq.corpus),
+     ndocs=12160,
+     show_progress_bar=True
+ )
+ bm25_index.save("output/bm25_index")
+
+ plots_b: Dict[str, List[float]] = {
+     "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
+     "Y": []
+ }
+ plots_k1: Dict[str, List[float]] = {
+     "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
+     "Y": []
+ }
+
+ ## YOUR_CODE_STARTS_HERE
+ # Step 1: Tune b (with k1 fixed at 0.9)
+ for b_val in plots_b["X"]:
+     bm25_index = BM25Index.build_from_documents(
+         documents=iter(sciq.corpus),
+         ndocs=12160,
+         show_progress_bar=True,
+         k1=0.9,  # fix k1
+         b=b_val
+     )
+     bm25_index.save("output/bm25_index")
+     bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
+     rankings = {}
+     for query in sciq.get_split_queries(Split.dev):
+         ranking = bm25_retriever.retrieve(query=query.text)
+         rankings[query.query_id] = ranking
+     map_score = evaluate_map(rankings)
+     plots_b["Y"].append(map_score)
+
+ # Step 2: Tune k1 (with the best b from step 1)
+ best_b = plots_b["X"][np.argmax(plots_b["Y"])]  # get best b
+ for k1_val in plots_k1["X"]:
+     bm25_index = BM25Index.build_from_documents(
+         documents=iter(sciq.corpus),
+         ndocs=12160,
+         show_progress_bar=True,
+         k1=k1_val,
+         b=best_b  # use best b
+     )
+     bm25_index.save("output/bm25_index")
+     bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
+     rankings = {}
+     for query in sciq.get_split_queries(Split.dev):
+         ranking = bm25_retriever.retrieve(query=query.text)
+         rankings[query.query_id] = ranking
+     map_score = evaluate_map(rankings)
+     plots_k1["Y"].append(map_score)
+
+ best_k1 = plots_k1["X"][np.argmax(plots_k1["Y"])]  # get best k1; used below when building the CSC index
+
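plots_b and plots_k1 hold the two sweeps in a plot-ready shape; one possible rendering (matplotlib is an assumption here, it is not imported anywhere in this file):

    import matplotlib.pyplot as plt
    plt.plot(plots_b["X"], plots_b["Y"], label="b sweep (k1=0.9)")
    plt.plot(plots_k1["X"], plots_k1["Y"], label="k1 sweep (b=best_b)")
    plt.xlabel("parameter value")
    plt.ylabel("MAP@10 on dev")
    plt.legend()
    plt.savefig("bm25_tuning.png")
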
+ from scipy.sparse._csc import csc_matrix

  @dataclass
  class CSCInvertedIndex:

          index = pickle.load(f)
          return index

+
  @dataclass
  class CSCBM25Index(CSCInvertedIndex):

      ) -> csc_matrix:
          """Compute and cache the term weights."""

          data = []
          indices = []
          indptr = [0]

      @classmethod
      def build_from_documents(
+         cls: Type[CSCBM25Index],
          documents: Iterable[Document],
          store_raw: bool = True,
          output_dir: Optional[str] = None,

          show_progress_bar: bool = True,
          k1: float = 0.9,
          b: float = 0.4,
+     ) -> CSCBM25Index:
          # Counting TFs, DFs, doc_lengths, etc.:
          counting = run_counting(
              documents=documents,

          )
          return index

+ csc_bm25_index = CSCBM25Index.build_from_documents(
+     documents=iter(sciq.corpus),
+     ndocs=12160,
+     show_progress_bar=True,
+     k1=best_k1,
+     b=best_b
+ )
+ csc_bm25_index.save("output/csc_bm25_index")
+
  class BaseCSCInvertedIndexRetriever(BaseRetriever):

      @property

      def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
          ## YOUR_CODE_STARTS_HERE
+         ranking: Dict[str, float] = {}
          toks = self.index.tokenize(query)
          docid2score: Dict[int, float] = {}
          for tok in toks:
+             if tok not in self.index.vocab:
+                 continue
+             tid = self.index.vocab[tok]
+             tid2documents = self.index.posting_lists_matrix.getcol(tid)
+             for docid, tweight in zip(tid2documents.indices, tid2documents.data):
+                 docid2score.setdefault(docid, 0)
+                 docid2score[docid] += tweight
          docid2score = dict(
+             sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
+         )
+         ranking = {
+             self.index.collection_ids[docid]: score
+             for docid, score in docid2score.items()
+         }
+         return ranking
+
          ## YOUR_CODE_ENDS_HERE
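
The new retrieve walks one csc_matrix column per query term through .indices/.data, so it touches only documents that actually contain the term; the removed version densified each column with .toarray(). A toy illustration of that access pattern:

    import numpy as np
    from scipy.sparse import csc_matrix

    # 3 documents x 2 terms; column tid holds each document's weight for that term.
    m = csc_matrix(np.array([[0.0, 1.5], [2.0, 0.0], [0.5, 3.0]]))
    col = m.getcol(1)
    print(col.indices, col.data)  # [0 2] [1.5 3.0]: docids with nonzero weight, and those weights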

  class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):

      def index_class(self) -> Type[CSCBM25Index]:
          return CSCBM25Index

+ import gradio as gr
+ from typing import TypedDict
+
  class Hit(TypedDict):
      cid: str
      score: float

  return_type = List[Hit]

  ## YOUR_CODE_STARTS_HERE
+ def search(query: str) -> List[Hit]:
+     hits: List[Hit] = []
+     bm_25_retriever = BM25Retriever(index_dir="output/bm25_index")
+     ranking = bm_25_retriever.retrieve(query)
+     for cid, score in ranking.items():
+         hit = {
+             "cid": cid,
+             "score": score,
+             "text": bm_25_retriever.index.doc_texts[bm_25_retriever.index.cid2docid[cid]],
+         }
+         hits.append(hit)
+     return hits

  demo = gr.Interface(
      fn=search,
+     inputs=["text"],
+     outputs=gr.Textbox()
  )

  ## YOUR_CODE_ENDS_HERE
+ demo.launch(share=True)
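
Since search returns a list of dicts, the plain Textbox output stringifies it; the removed version already used gr.JSON, so a structured variant would be (a sketch reusing the same search function):

    demo_json = gr.Interface(
        fn=search,
        inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
        outputs=gr.JSON(label="Search Results"),
    )
    demo_json.launch()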