Spaces:
Build error
Build error
update app
Browse files- app.py +4 -4
- weakly_supervised_parser/utils/populate_chart.py +16 -15
app.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
import gradio
|
2 |
import benepar
|
3 |
import spacy
|
4 |
-
from IPython.display import display
|
5 |
import nltk
|
6 |
from nltk.tree import Tree
|
7 |
nltk.download('stopwords')
|
@@ -19,9 +18,10 @@ benepar.download('benepar_en3')
|
|
19 |
nlp = spacy.load("en_core_web_md")
|
20 |
nlp.add_pipe("benepar", config={"model": "benepar_en3"})
|
21 |
|
22 |
-
|
23 |
-
fetch_url_inside_model = hf_hub_url(repo_id="nickil/weakly-supervised-parsing", filename="inside_model.
|
24 |
-
inside_model = LightningModel.load_from_checkpoint(checkpoint_path=cached_download(fetch_url_inside_model))
|
|
|
25 |
|
26 |
# outside_model = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=64)
|
27 |
# outside_model.load_model(pre_trained_model_path=TRAINED_MODEL_PATH + "outside_model.onnx")
|
|
|
1 |
import gradio
|
2 |
import benepar
|
3 |
import spacy
|
|
|
4 |
import nltk
|
5 |
from nltk.tree import Tree
|
6 |
nltk.download('stopwords')
|
|
|
18 |
nlp = spacy.load("en_core_web_md")
|
19 |
nlp.add_pipe("benepar", config={"model": "benepar_en3"})
|
20 |
|
21 |
+
inside_model = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=256)
|
22 |
+
fetch_url_inside_model = hf_hub_url(repo_id="nickil/weakly-supervised-parsing", filename="inside_model.onnx", revision="main")
|
23 |
+
# inside_model = LightningModel.load_from_checkpoint(checkpoint_path=cached_download(fetch_url_inside_model))
|
24 |
+
inside_model.load_model(pre_trained_model_path=cached_download(fetch_url_inside_model))
|
25 |
|
26 |
# outside_model = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=64)
|
27 |
# outside_model.load_model(pre_trained_model_path=TRAINED_MODEL_PATH + "outside_model.onnx")
|
weakly_supervised_parser/utils/populate_chart.py
CHANGED
@@ -26,9 +26,9 @@ ptb_top_100_common = ['this', 'myself', 'shouldn', 'not', 'analysts', 'same', 'm
|
|
26 |
# ptb_most_common_first_token = RuleBasedHeuristic(corpus=ptb.retrieve_all_sentences()).augment_using_most_frequent_starting_token(N=1)[0][0].lower()
|
27 |
ptb_most_common_first_token = "the"
|
28 |
|
29 |
-
from pytorch_lightning import Trainer
|
30 |
|
31 |
-
trainer = Trainer(accelerator="auto", enable_progress_bar=False, max_epochs=-1)
|
32 |
|
33 |
|
34 |
class PopulateCKYChart:
|
@@ -54,19 +54,20 @@ class PopulateCKYChart:
|
|
54 |
|
55 |
if predict_type == "inside":
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
70 |
|
71 |
data["inside_scores"] = inside_scores
|
72 |
data.loc[
|
|
|
26 |
# ptb_most_common_first_token = RuleBasedHeuristic(corpus=ptb.retrieve_all_sentences()).augment_using_most_frequent_starting_token(N=1)[0][0].lower()
|
27 |
ptb_most_common_first_token = "the"
|
28 |
|
29 |
+
# from pytorch_lightning import Trainer
|
30 |
|
31 |
+
# trainer = Trainer(accelerator="auto", enable_progress_bar=False, max_epochs=-1)
|
32 |
|
33 |
|
34 |
class PopulateCKYChart:
|
|
|
54 |
|
55 |
if predict_type == "inside":
|
56 |
|
57 |
+
if data.shape[0] > chunks:
|
58 |
+
data_chunks = np.array_split(data, data.shape[0] // chunks)
|
59 |
+
for data_chunk in data_chunks:
|
60 |
+
inside_scores.extend(model.predict_proba(spans=data_chunk.rename(columns={"inside_sentence": "sentence"})[["sentence"]],
|
61 |
+
scale_axis=scale_axis,
|
62 |
+
predict_batch_size=predict_batch_size)[:, 1])
|
63 |
+
else:
|
64 |
+
inside_scores.extend(model.predict_proba(spans=data.rename(columns={"inside_sentence": "sentence"})[["sentence"]],
|
65 |
+
scale_axis=scale_axis,
|
66 |
+
predict_batch_size=predict_batch_size)[:, 1])
|
67 |
+
|
68 |
+
# test_dataloader = DataModule(model_name_or_path="roberta-base", train_df=None, eval_df=None,
|
69 |
+
# test_df=data.rename(columns={"inside_sentence": "sentence"})[["sentence"]])
|
70 |
+
# inside_scores.extend(trainer.predict(model, dataloaders=test_dataloader)[0])
|
71 |
|
72 |
data["inside_scores"] = inside_scores
|
73 |
data.loc[
|