import spaces  # on Hugging Face ZeroGPU Spaces, import this before any CUDA-touching module

import re
import tempfile

import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, pipeline
# Run on GPU when available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model identifiers on the Hugging Face Hub
editorial_model = "LLMDH/Estienne"
bibliography_model = "PleIAs/Bibliography-Formatter"
bibliography_style = "PleIAs/Bibliography-Classifier"

tokenizer = AutoTokenizer.from_pretrained(editorial_model, model_max_length=512)

# Token classifier that segments the input and tags bibliography spans
editorial_classifier = pipeline(
    "token-classification", model=editorial_model, aggregation_strategy="simple", device=device
)
# Token classifier that tags the fields (author, title, journal, ...) of one reference
bibliography_classifier = pipeline(
    "token-classification", model=bibliography_model, aggregation_strategy="simple", device=device
)
# Sequence classifier that estimates the citation style of one reference
style_classifier = pipeline("text-classification", model=bibliography_style, tokenizer=tokenizer, device=device)
# Helper functions
def preprocess_text(text):
    """Strip HTML tags and collapse all whitespace into single spaces."""
    text = re.sub(r'<[^>]+>', '', text)
    text = re.sub(r'\n', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
def split_text(text, max_tokens=500):
    """Split text into chunks of at most max_tokens tokens, preferring newline boundaries."""
    parts = text.split("\n")
    chunks = []
    current_chunk = ""
    for part in parts:
        temp_chunk = current_chunk + "\n" + part if current_chunk else part
        num_tokens = len(tokenizer.tokenize(temp_chunk))
        if num_tokens <= max_tokens:
            current_chunk = temp_chunk
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = part
    if current_chunk:
        chunks.append(current_chunk)

    # Fallback: if the text has no usable newlines and is still too long,
    # cut it roughly in half, moving the cut forward to the next whitespace.
    if len(chunks) == 1 and len(tokenizer.tokenize(chunks[0])) > max_tokens:
        long_text = chunks[0]
        chunks = []
        while len(tokenizer.tokenize(long_text)) > max_tokens:
            split_point = len(long_text) // 2
            while split_point < len(long_text) and not re.match(r'\s', long_text[split_point]):
                split_point += 1
            if split_point >= len(long_text):
                split_point = len(long_text) - 1
            chunks.append(long_text[:split_point].strip())
            long_text = long_text[split_point:].strip()
        if long_text:
            chunks.append(long_text)
    return chunks
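# A minimal usage sketch (hypothetical input, not part of the app flow):
#
#   chunks = split_text("First reference.\nSecond reference.", max_tokens=500)
#   for c in chunks:
#       print(len(tokenizer.tokenize(c)), c[:40])
#
# Note that the halving fallback re-checks only the remainder, so a chunk cut
# from a very long single block can still slightly exceed max_tokens.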
def disambiguate_bibtex_ids(bibtex_entries):
    """Append 'a', 'b', 'c', ... to BibTeX IDs that occur more than once."""
    id_count = {}
    disambiguated_entries = []
    for entry in bibtex_entries:
        # Extract the current ID
        match = re.search(r'@\w+{(\w+),', entry)
        if not match:
            disambiguated_entries.append(entry)
            continue
        original_id = match.group(1)
        # Check if this ID has been seen before
        if original_id in id_count:
            id_count[original_id] += 1
            new_id = f"{original_id}{chr(96 + id_count[original_id])}"  # 'a', 'b', 'c', etc.
            new_entry = re.sub(r'(@\w+{)(\w+)(,)', f'\\1{new_id}\\3', entry, count=1)
            disambiguated_entries.append(new_entry)
        else:
            id_count[original_id] = 0
            disambiguated_entries.append(entry)
    return disambiguated_entries
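# Sketch of the intended behaviour (hypothetical entries): the first occurrence
# of a key is kept as-is, later duplicates get an alphabetic suffix.
#
#   disambiguate_bibtex_ids(['@book{smith2020,', '@book{smith2020,'])
#   # -> ['@book{smith2020,', '@book{smith2020a,']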
def remove_punctuation(text):
    """Drop every character that is not alphanumeric or whitespace."""
    return re.sub(r'[^\w\s]', '', text)

def extract_year(text):
    """Return the first four-digit number in the text, or None."""
    year_match = re.search(r'\b(\d{4})\b', text)
    return year_match.group(1) if year_match else None
def create_bibtex_entry(data):
    """Turn a dict of tagged reference fields into a single BibTeX entry."""
    # Guess the entry type from the fields that were detected
    if 'journal' in data:
        entry_type = 'article'
    elif 'booktitle' in data:
        entry_type = 'inproceedings'
    else:
        entry_type = 'book'

    # Untagged tokens may still contain the year; both key spellings occur upstream
    none_content = data.pop('None', '') + data.pop('none', '')
    year = extract_year(none_content)
    if year and 'year' not in data:
        data['year'] = year
    if "year" in data:
        match_year = re.search(r'(\d{4})', data['year'])
        if match_year:
            data['year'] = match_year.group(1)
            year = data['year']
        else:
            data.pop('year', '')

    # Pages conformity: keep only "N" or "N-M"
    if 'pages' in data:
        match = re.search(r'(\d+(-\d+)?)', data['pages'])
        if match:
            data['pages'] = match.group(1)
        else:
            data.pop('pages', '')

    # Build the citation key from the first author word plus the year
    author_words = data.get('author', '').split()
    first_author = author_words[0] if author_words else 'unknown'
    bibtex_id = f"{first_author}{year}" if year else first_author
    bibtex_id = remove_punctuation(bibtex_id.lower())

    bibtex = f"@{entry_type}{{{bibtex_id},\n"
    for key, value in data.items():
        if value.strip():
            if key in ['volume', 'year']:
                value = remove_punctuation(value)
            if key == 'pages':
                value = value.replace('p. ', '')
            if key != "separator":
                bibtex += f" {key.lower()} = {{{value.strip()}}},\n"
    bibtex = bibtex.rstrip(',\n') + "\n}"
    return bibtex
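# A minimal sketch of the output (the field names here are assumptions about
# the tags the formatter model emits):
#
#   create_bibtex_entry({'author': 'Doe, John', 'title': 'On Things',
#                        'journal': 'Revue', 'None': 'Paris, 1999.'})
#   # -> @article{doe1999,
#   #     author = {Doe, John},
#   #     title = {On Things},
#   #     journal = {Revue},
#   #     year = {1999}
#   #    }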
def save_bibtex(bibtex_content):
    """Write the BibTeX text to a temporary .bib file and return its path."""
    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.bib') as temp_file:
        temp_file.write(bibtex_content)
    return temp_file.name
class CombinedProcessor:
    def process(self, user_message):
        # Precaution to reinforce bibliography detection.
        editorial_text = "Bibliography\n" + user_message
        # Our fix for the lack of newline support in DeBERTa: encode line breaks as pilcrows
        editorial_text = re.sub("\n", " ¶ ", editorial_text)

        num_tokens = len(tokenizer.tokenize(editorial_text))
        batch_prompts = split_text(editorial_text, max_tokens=500) if num_tokens > 500 else [editorial_text]
        editorial_out = editorial_classifier(batch_prompts)
        editorial_df = pd.concat([pd.DataFrame(classification) for classification in editorial_out])

        # Keep only the spans the editorial model tagged as bibliography
        bibliography_entries = editorial_df[editorial_df['entity_group'] == 'bibliography']['word'].tolist()
        if not bibliography_entries:
            # Avoid pd.concat on empty lists when nothing was tagged as bibliography
            return "No bibliography entries detected.", ""

        bibtex_entries = []
        list_style = []
        for entry in bibliography_entries:
            # Undo the pilcrow encoding: re-join hyphenated line breaks, then plain ones
            entry = re.sub(r'- ?[\n¶] ?', r'', entry)
            entry = re.sub(r' ?[\n¶] ?', r' ', entry)

            style = pd.DataFrame(style_classifier(entry, truncation=True, padding=True, top_k=1))
            list_style.append(style)

            # Pad punctuation with spaces so the field tagger sees clean token boundaries
            entry = re.sub(r'\s*([;:,\.])\s*', r' \1 ', entry)
            bib_out = bibliography_classifier(entry)
            bib_df = pd.DataFrame(bib_out)

            # Stitch the tagged tokens back into one value per field
            bibtex_data = {}
            current_entity = None
            for _, row in bib_df.iterrows():
                entity_group = row['entity_group']
                word = row['word']
                if entity_group != 'None':
                    if entity_group in bibtex_data:
                        bibtex_data[entity_group] += ' ' + word
                    else:
                        bibtex_data[entity_group] = word
                    current_entity = entity_group
                else:
                    # Untagged words extend the previous field, or pool under 'None'
                    if current_entity:
                        bibtex_data[current_entity] += ' ' + word
                    else:
                        bibtex_data['None'] = bibtex_data.get('None', '') + ' ' + word

            bibtex_entry = create_bibtex_entry(bibtex_data)
            bibtex_entries.append(bibtex_entry)

        # Average the per-entry style scores and report the best-scoring style
        list_style = pd.concat(list_style)
        list_style = list_style.groupby('label')['score'].mean().sort_values(ascending=False).reset_index()
        top_style = list_style.iloc[0]['label']
        top_style_score = list_style.iloc[0]['score']
        # Create the style information string
        style_info = f"Estimated bibliographic style: {top_style} (Mean score: {top_style_score:.6f})"

        # Suffix duplicate citation keys (same first author and year), then join
        bibtex_entries = disambiguate_bibtex_ids(bibtex_entries)
        bibtex_content = "\n\n".join(bibtex_entries)
        return style_info, bibtex_content
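# Hypothetical direct call, bypassing the UI:
#
#   style_info, bibtex = CombinedProcessor().process(
#       "Smith, J. (2020). A Title. Journal of Examples, 1(2), 3-14."
#   )
#   print(style_info)
#   print(bibtex)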
# Create the processor instance
processor = CombinedProcessor()

# Define the Gradio interface
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">Reversed Zotero</h1>""")
    text_input = gr.Textbox(label="Your unstructured bibliography", type="text", lines=10)
    text_button = gr.Button("Process Text")
    style_output = gr.Textbox(label="Bibliographic Style", lines=2)
    bibtex_output = gr.Textbox(label="BibTeX Entries", lines=15)
    export_button = gr.Button("Export BibTeX")
    export_output = gr.File(label="Exported BibTeX File")

    text_button.click(processor.process, inputs=text_input, outputs=[style_output, bibtex_output])
    export_button.click(save_bibtex, inputs=[bibtex_output], outputs=[export_output])

if __name__ == "__main__":
    demo.queue().launch()