# Named-entity recognition over French TV transcripts: chunk each transcript
# to fit the model's context window, run a token-classification pipeline on
# every chunk, and write the aggregated entities to a TSV file.

import re

import pandas as pd
from tqdm.auto import tqdm
from transformers import AutoTokenizer, pipeline

model_checkpoint = "Pclanglais/French-TV-transcript-NER"

# Token-classification pipeline with simple aggregation, so subword pieces
# are merged back into whole words and entity spans.
token_classifier = pipeline(
    "token-classification", model=model_checkpoint, aggregation_strategy="simple"
)

# The model's own tokenizer is used separately to count tokens when chunking.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

def split_text(text, max_tokens=500, sep="\n"):
    """Split `text` into chunks of at most `max_tokens` tokens, cutting on `sep`."""
    parts = text.split(sep)
    chunks = []
    current_chunk = ""

    # Greedily accumulate parts until adding one more would exceed the budget.
    for part in parts:
        if current_chunk:
            temp_chunk = current_chunk + sep + part
        else:
            temp_chunk = part

        num_tokens = len(tokenizer.tokenize(temp_chunk))
        if num_tokens <= max_tokens:
            current_chunk = temp_chunk
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = part

    if current_chunk:
        chunks.append(current_chunk)

    # Fallback for a text with no usable separator: cut it in half at a
    # whitespace boundary repeatedly until it fits the token budget.
    if len(chunks) == 1 and len(tokenizer.tokenize(chunks[0])) > max_tokens:
        long_text = chunks[0]
        chunks = []
        while len(tokenizer.tokenize(long_text)) > max_tokens:
            split_point = len(long_text) // 2
            while split_point < len(long_text) and not re.match(r"\s", long_text[split_point]):
                split_point += 1
            if split_point >= len(long_text):
                split_point = len(long_text) - 1
            chunks.append(long_text[:split_point].strip())
            long_text = long_text[split_point:].strip()
        if long_text:
            chunks.append(long_text)

    return chunks

# Load the transcripts to annotate (path left as a placeholder).
complete_data = pd.read_parquet("[file with transcripts]")

print(complete_data)

classified_list = []

# Flat lists of chunks to classify, plus their source file and transcript id.
list_prompt = []
list_page = []
list_file = []
list_id = []
text_id = 1

for index, row in complete_data.iterrows():
    prompt, current_file = str(row["corrected_text"]), row["identifier"]

    # Replace newlines with a visible ¶ marker so line breaks survive the
    # pipeline and can be restored from the "word" column afterwards.
    prompt = re.sub("\n", " ¶ ", prompt)

    num_tokens = len(tokenizer.tokenize(prompt))

    if num_tokens > 500:
        # Newlines were already turned into " ¶ ", so split on that marker.
        chunks = split_text(prompt, max_tokens=500, sep=" ¶ ")
        for chunk in chunks:
            list_file.append(current_file)
            list_prompt.append(chunk)
            list_id.append(text_id)
    else:
        list_file.append(current_file)
        list_prompt.append(prompt)
        list_id.append(text_id)

    text_id = text_id + 1

full_classification = []
batch_size = 4

# The pipeline yields one result per input text, so the progress bar total
# is the number of chunks, not the number of batches.
for out in tqdm(token_classifier(list_prompt, batch_size=batch_size), total=len(list_prompt)):
    full_classification.append(out)

for id_row, classification in enumerate(full_classification):
    try:
        df = pd.DataFrame(classification)

        # Attach the source file and transcript id to every entity row.
        df["identifier"] = list_file[id_row]
        df["text_id"] = list_id[id_row]

        # Restore the newlines that were swapped for the ¶ marker.
        df["word"] = df["word"].str.replace(" ¶ ", " \n ", regex=False)

        print(df)

        classified_list.append(df)

    except Exception:
        # Skip chunks with no usable result (e.g. no entities detected,
        # hence no "word" column).
        pass

classified_list = pd.concat(classified_list)

print(classified_list)

classified_list.to_csv("result_transcripts.tsv", sep="\t")