|
import logging |
|
|
|
import gradio as gr |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed |
|
import torch |
|
import spaces |
|
|
|
|
|
|
|
|
|
|
|
# Configure the root logger once at import: INFO level, with timestamp,
# logger name and severity in front of each message.
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
|
|
|
|
|
# Example rows offered in the demo. Each row is [image path, ingredient text],
# matching the order of the `inputs` given to gr.Examples below
# (image component first, ingredients textbox second).
# The texts contain typos (presumably raw OCR output), which is what the
# spellcheck is meant to correct, so they must be kept verbatim.
# Note: the numbering goes 1, 2, 4, 5 -- there is no ingredients_3 entry.
EXAMPLES = [
    ["images/ingredients_1.jpg", "24.36% chocolat noir 63% origine non UE (cacao, sucre, beurre de cacao, émulsifiant léci - thine de colza, vanille bourbon gousse), œuf, farine de blé, beurre, sucre, miel, sucre perlé, levure chimique, zeste de citron."],
    ["images/ingredients_2.jpg", "farine de froment, œufs, lait entier pasteurisé Aprigine: France), sucre, sel, extrait de vanille naturelle Conditi( 35."],

    ["images/ingredients_4.jpg", "Eau de noix de coco 93.9%, Arôme natutel de fruit"],
    ["images/ingredients_5.jpg", "Sucre, pâte de cacao, beurre de cacao, émulsifiant: léci - thines (soja). Peut contenir des traces de lait. Chocolat noir: cacao: 50% minimum. À conserver à l'abri de la chaleur et de l'humidité. Élaboré en France."],
]
|
|
|
# Model identifier passed to both AutoTokenizer.from_pretrained and
# AutoModelForCausalLM.from_pretrained below.
MODEL_ID = "openfoodfacts/spellcheck-mistral-7b"
|
|
|
# Markdown text rendered at the top of the demo page through gr.Markdown.
# Paragraphs are separated by a single blank line; consecutive lines belong to
# the same paragraph.
PRESENTATION = """# 🍊 Ingredients Spellcheck - Open Food Facts

Open Food Facts is a non-profit organization building the largest open food database in the world. 🌎

When a product is added to the database, all its details, such as allergens, additives, or nutritional values, are either written down by the contributor,
or automatically extracted from the product pictures using OCR.

However, it often happens that the information extracted by OCR contains typos and errors due to bad quality pictures: low-definition, curved product, light reflection, etc...

To solve this problem, we developed an 🍊 **Ingredient Spellcheck** 🍊, a model capable of correcting typos in a list of ingredients following a defined guideline.
The model, based on Mistral-7B-v0.3, was fine-tuned on thousands of corrected lists of ingredients extracted from the database. More information in the model card.

## Project in progress

## 👇 Links

* Open Food Facts website: https://world.openfoodfacts.org/discover
* Open Food Facts Github: https://github.com/openfoodfacts
* Spellcheck project: https://github.com/openfoodfacts/openfoodfacts-ai/tree/develop/spellcheck
* Model card: https://huggingface.co/openfoodfacts/spellcheck-mistral-7b
"""
|
|
|
|
|
# One-element tensor moved to the CUDA device when the module is imported.
# process() reads `zero.device` to decide where to send the tokenized prompt.
# NOTE(review): this raises at import on a machine without CUDA -- presumably
# fine on the target Space (the `spaces` package is imported above); confirm.
zero = torch.Tensor([0]).cuda()

# Fix the random seed through transformers' `set_seed` helper.
set_seed(42)
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the tokenizer for MODEL_ID once, at import time.
logging.info(f"Load tokenizer from {MODEL_ID}.")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Reuse the end-of-sequence token as padding token (both the token string and
# its id are overwritten).
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# Load the causal language model for MODEL_ID once, at import time.
logging.info(f"Load model from {MODEL_ID}.")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # NOTE(review): device placement is delegated to the loader via "auto";
    # process() separately sends inputs to `zero.device` -- confirm both agree.
    device_map="auto",
)
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU
def process(text: str) -> str:
    """Generate the corrected list of ingredients for ``text``.

    The text is wrapped in the instruction prompt built by
    ``prepare_instruction``, tokenized with the module-level ``tokenizer`` and
    completed by the module-level ``model`` with greedy decoding
    (``do_sample=False``), limited to 512 new tokens.

    Args:
        text (str): List of ingredients to correct.

    Returns:
        str: The generated correction, stripped of surrounding whitespace.
            An empty string is returned for empty or whitespace-only input.
    """
    # Nothing to correct: avoid running the model on an empty prompt body.
    if not text or not text.strip():
        return ""

    prompt = prepare_instruction(text)
    input_ids = tokenizer(
        prompt,
        add_special_tokens=True,
        return_tensors="pt",
    ).input_ids.to(zero.device)

    with torch.no_grad():
        output = model.generate(
            input_ids,
            do_sample=False,
            max_new_tokens=512,
        )

    # The generated sequence starts with the prompt tokens. Drop them at token
    # level instead of slicing the decoded text by len(prompt): decoding is not
    # guaranteed to reproduce the prompt character for character, in which case
    # a character offset would cut into (or leak the prompt into) the answer.
    generated_tokens = output[0][input_ids.shape[-1]:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
|
|
|
|
|
def prepare_instruction(text: str) -> str:
    """Wrap a list of ingredients in the instruction prompt template.

    The same template is meant to be used for fine-tuning and inference, so
    it has to stay identical to the one used during training.

    Args:
        text (str): List of ingredients

    Returns:
        str: Instruction.
    """
    header = "###Correct the list of ingredients:\n"
    footer = "\n\n###Correction:\n"
    return "".join((header, text, footer))
|
|
|
|
|
|
|
|
|
|
|
# Build the Gradio page: presentation text, an image/text input area, the
# selectable examples, and the button wired to process().
with gr.Blocks() as demo:

    gr.Markdown(PRESENTATION)

    with gr.Row():
        with gr.Column():
            # Display-only image slot (interactive=False); it is listed as an
            # input of gr.Examples below, whose rows carry an image path.
            image = gr.Image(type="pil", label="image_input", interactive=False)

        with gr.Column():
            # Text to correct; this is the only input passed to process().
            ingredients = gr.Textbox(label="List of ingredients")
            spellcheck_button = gr.Button(value='Run spellcheck')
            # Read-only box that receives the value returned by process().
            correction = gr.Textbox(label="Correction", interactive=False)

    with gr.Row():
        # NOTE(review): process() accepts a single `text` argument, but two
        # inputs (image, ingredients) are listed here and no outputs are given.
        # Presumably `fn` is never invoked in this configuration -- confirm
        # against the Gradio Examples documentation before enabling caching.
        gr.Examples(
            fn=process,
            examples=EXAMPLES,
            inputs=[
                image,
                ingredients,
            ],
        )
    # Clicking the button sends the textbox content to process() and writes
    # the result into the "Correction" box.
    spellcheck_button.click(
        fn=process,
        inputs=[ingredients],
        outputs=[correction]
    )
|
|
|
|
|
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":

    demo.launch()
|
|