Spaces:
Sleeping
Sleeping
Refactor, introduce CommaFixerInterface and remove duplication
Browse files
commafixer/routers/baseline.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
-
from fastapi import APIRouter
|
2 |
import logging
|
3 |
|
4 |
from commafixer.src.baseline import BaselineCommaFixer
|
5 |
-
|
6 |
|
7 |
logger = logging.Logger(__name__)
|
8 |
logging.basicConfig(level=logging.INFO)
|
@@ -16,10 +16,4 @@ router.model = BaselineCommaFixer()
|
|
16 |
@router.post('/fix-commas/')
|
17 |
async def fix_commas_with_baseline(data: dict):
|
18 |
json_field_name = 's'
|
19 |
-
|
20 |
-
logger.debug('Fixing commas.')
|
21 |
-
return {json_field_name: router.model.fix_commas(data['s'])}
|
22 |
-
else:
|
23 |
-
msg = f"Text '{json_field_name}' missing"
|
24 |
-
logger.debug(msg)
|
25 |
-
raise HTTPException(status_code=400, detail=msg)
|
|
|
1 |
+
from fastapi import APIRouter
|
2 |
import logging
|
3 |
|
4 |
from commafixer.src.baseline import BaselineCommaFixer
|
5 |
+
from common import fix_commas_request_handler
|
6 |
|
7 |
# Obtain the logger from the logging module's registry instead of instantiating Logger directly,
# so that it is part of the logger hierarchy configured by basicConfig below.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
|
|
16 |
@router.post('/fix-commas/')
async def fix_commas_with_baseline(data: dict):
    """
    POST endpoint that fixes commas using the model stored on this router.

    The request payload is passed to the shared `fix_commas_request_handler`, which looks for the text
    under the 's' field.

    :param data: the decoded JSON request body.
    :return: whatever `fix_commas_request_handler` returns for this payload.
    """
    return fix_commas_request_handler(
        json_field_name='s',
        data=data,
        logger=logger,
        model=router.model,
    )
|
|
|
|
|
|
|
|
|
|
|
|
commafixer/routers/common.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import HTTPException
|
2 |
+
from logging import Logger
|
3 |
+
|
4 |
+
from comma_fixer_interface import CommaFixerInterface
|
5 |
+
|
6 |
+
|
7 |
+
def fix_commas_request_handler(
        json_field_name: str,
        data: dict[str, str],
        logger: Logger,
        model: "CommaFixerInterface"
) -> dict[str, str]:
    """
    Run a comma fixer on the text stored under `json_field_name` in a request payload.

    :param json_field_name: key of the payload field that holds the text; the same key is used in the response.
    :param data: the decoded request payload.
    :param logger: logger that receives the debug messages emitted by this handler.
    :param model: the comma fixer whose `fix_commas` method processes the text.
    :return: a single-entry dict mapping `json_field_name` to the string returned by `model.fix_commas`.
    :raises HTTPException: with status code 400 when `json_field_name` is not a key of `data`.
    """
    if json_field_name not in data:
        msg = f"Text '{json_field_name}' missing"
        logger.debug(msg)
        raise HTTPException(status_code=400, detail=msg)

    logger.debug('Fixing commas.')
    # Read the text through the same key that was validated above and is echoed in the response,
    # rather than through a hard-coded field name.
    return {json_field_name: model.fix_commas(data[json_field_name])}
|
commafixer/routers/fixer.py
CHANGED
@@ -2,6 +2,7 @@ from fastapi import APIRouter, HTTPException
|
|
2 |
import logging
|
3 |
|
4 |
from commafixer.src.fixer import CommaFixer
|
|
|
5 |
|
6 |
|
7 |
logger = logging.Logger(__name__)
|
@@ -16,10 +17,4 @@ router.model = CommaFixer()
|
|
16 |
@router.post('/')
|
17 |
async def fix_commas(data: dict):
|
18 |
json_field_name = 's'
|
19 |
-
|
20 |
-
logger.debug('Fixing commas.')
|
21 |
-
return {json_field_name: router.model.fix_commas(data['s'])}
|
22 |
-
else:
|
23 |
-
msg = f"Text '{json_field_name}' missing"
|
24 |
-
logger.debug(msg)
|
25 |
-
raise HTTPException(status_code=400, detail=msg)
|
|
|
2 |
import logging
|
3 |
|
4 |
from commafixer.src.fixer import CommaFixer
|
5 |
+
from commafixer.routers.common import fix_commas_request_handler
|
6 |
|
7 |
|
8 |
# Obtain the logger from the logging module's registry instead of instantiating Logger directly,
# so that it is part of the logger hierarchy and picks up the application's logging configuration.
logger = logging.getLogger(__name__)
|
|
|
17 |
@router.post('/')
async def fix_commas(data: dict):
    """
    POST endpoint that fixes commas using the model stored on this router.

    The request payload is passed to the shared `fix_commas_request_handler`, which looks for the text
    under the 's' field.

    :param data: the decoded JSON request body.
    :return: whatever `fix_commas_request_handler` returns for this payload.
    """
    return fix_commas_request_handler(
        json_field_name='s',
        data=data,
        logger=logger,
        model=router.model,
    )
|
|
|
|
|
|
|
|
|
|
|
|
commafixer/src/baseline.py
CHANGED
@@ -1,8 +1,10 @@
|
|
1 |
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline
|
2 |
import re
|
3 |
|
|
|
4 |
|
5 |
-
|
|
|
6 |
"""
|
7 |
A wrapper class for the oliverguhr/fullstop-punctuation-multilang-large baseline punctuation restoration model.
|
8 |
It adapts the model to perform comma fixing instead of full punctuation restoration, that is, removes the
|
|
|
1 |
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline
|
2 |
import re
|
3 |
|
4 |
+
from commafixer.src.comma_fixer_interface import CommaFixerInterface
|
5 |
|
6 |
+
|
7 |
+
class BaselineCommaFixer(CommaFixerInterface):
|
8 |
"""
|
9 |
A wrapper class for the oliverguhr/fullstop-punctuation-multilang-large baseline punctuation restoration model.
|
10 |
It adapts the model to perform comma fixing instead of full punctuation restoration, that is, removes the
|
commafixer/src/comma_fixer_interface.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
|
3 |
+
|
4 |
+
class CommaFixerInterface(ABC):
    """
    Abstract base class for comma fixers, i.e. objects that expose a `fix_commas` method.
    """

    @abstractmethod
    def fix_commas(self, s: str) -> str:
        """
        Abstract method; a subclass has to override it before it can be instantiated.

        Implementations take the input text and return the resulting text. This base method has no
        behavior of its own.

        :param s: the input text.
        :return: the text produced by the implementation.
        """
|
commafixer/src/fixer.py
CHANGED
@@ -3,8 +3,10 @@ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipelin
|
|
3 |
import nltk
|
4 |
import re
|
5 |
|
|
|
6 |
|
7 |
-
|
|
|
8 |
"""
|
9 |
A wrapper class for the fine-tuned comma fixer model.
|
10 |
"""
|
@@ -84,7 +86,7 @@ def _fix_commas_based_on_labels_and_offsets(
|
|
84 |
|
85 |
def _should_insert_comma(label, result, current_offset) -> bool:
|
86 |
# Only insert commas for the final token of a word, that is, if next word starts with a space.
|
87 |
-
# TODO
|
88 |
return label == 'B-COMMA' and result[current_offset].isspace()
|
89 |
|
90 |
|
|
|
3 |
import nltk
|
4 |
import re
|
5 |
|
6 |
+
from commafixer.src.comma_fixer_interface import CommaFixerInterface
|
7 |
|
8 |
+
|
9 |
+
class CommaFixer(CommaFixerInterface):
|
10 |
"""
|
11 |
A wrapper class for the fine-tuned comma fixer model.
|
12 |
"""
|
|
|
86 |
|
87 |
def _should_insert_comma(label, result, current_offset) -> bool:
|
88 |
# Only insert commas for the final token of a word, that is, if next word starts with a space.
|
89 |
+
# TODO perhaps for low confidence tokens, we should use the original decision of the user in the input?
|
90 |
return label == 'B-COMMA' and result[current_offset].isspace()
|
91 |
|
92 |
|