import os import gradio as gr from transformers import pipeline models = [ 'borabuyukbas/distilbert-base-turkish-cased-product-names-1e-labels-exp', 'borabuyukbas/distilbert-base-turkish-cased-product-names-1e-labels', 'borabuyukbas/distilbert-base-turkish-cased-product-names-5e-labels', 'borabuyukbas/distilbert-base-turkish-cased-product-names-1e', 'borabuyukbas/distilbert-base-turkish-cased-product-names-2e', 'borabuyukbas/distilbert-base-turkish-cased-product-names-10e', 'borabuyukbas/distilbert-base-turkish-cased-product-names', 'borabuyukbas/bert-base-turkish-cased-product-names-v4', 'borabuyukbas/bert-base-turkish-cased-product-names-v3', 'borabuyukbas/bert-base-turkish-cased-product-names-v2', 'borabuyukbas/bert-base-turkish-cased-product-names', ] staple_check = pipeline( "text-classification", model="borabuyukbas/distilbert-base-turkish-cased-product-names-filter_v2", use_auth_token=os.environ['HF_TOKEN'], ) pipelines = [pipeline( "zero-shot-classification", model=model, use_auth_token=os.environ['HF_TOKEN'], ) for model in models] default_categories = ['antep fıstığı', 'armut', 'ay çekirdeği', 'ayçiçek yağı', 'ayran', 'ayva', 'badem i̇çi', 'baharat', 'baklava', 'bal', 'balık', 'bebek sütü (toz karışım)', 'beyaz lahana', 'beyaz peynir', 'bisküvi', 'buğday unu', 'bulgur', 'çarliston biber', 'çay', 'ceviz i̇çi', 'çikolata krem', 'çikolata tablet', 'çilek', 'cipsler', 'dana eti', 'dereotu', 'barbunya', 'dolmalık biber', 'domates', 'dondurma', 'ekmek', 'ekmek hamuru (yufka)', 'elma', 'erik', 'fındık ezmesi', 'fındık i̇çi', 'gazoz meyveli', 'gofret', 'havuç', 'hazır çorbalar', 'hazır et yemekleri', 'hazır kahve', 'kahve', 'hazır pakette toz tatlılar (puding)', 'hazır sütlü tatlılar', 'ispanak', 'kabak', 'kabak çekirdeği', 'kabartma maddeleri', 'kağıtlı şeker', 'kakao', 'kakaolu toz i̇çecekler', 'karnabahar', 'karpuz', 'kaşar peyniri', 'kavun', 'kayısı', 'kek', 'kesme şeker', 'ketçap', 'kiraz', 'kırmızı lahana', 'kırmızı turp', 'kivi', 'kola', 'konserve balık', 'konserveler', 'kraker', 'krem peynir', 'kuru fasulye', 'kuru kayısı', 'kuru soğan', 'kuzu eti', 'leblebi', 'limon', 'lokum', 'maden suyu ve sodası', 'makarna', 'mandalina', 'mantar', 'margarin', 'marul', 'maydanoz', 'mayonez', 'mercimek', 'meyve suyu', 'mısırözü yağı', 'muz', 'nane', 'nar', 'nohut', 'pasta', 'patates', 'patlıcan', 'pekmez', 'pırasa', 'pirinç', 'portakal', 'bitki ve meyve çayı (poşet)', 'reçel', 'roka', 'sakatat', 'sakız', 'salam', 'salatalık', 'salça', 'sarımsak', 'şeftali', 'şehriye', 'sirke', 'sivri biber', 'soğuk çay', 'sosis', 'su', 'sucuk', 'süt', 'tahıl gevreği', 'tahin', 'tahin helvası', 'tavuk eti', 'taze fasulye', 'tere', 'tereyağı (kahvaltılık)', 'toz şeker', 'tulum peyniri', 'tuz', 'üzüm', 'yer fıstığı', 'yeşil soğan', 'yoğurt', 'yumurta', 'zeytin', 'zeytinyağı', 'kuru üzüm'] def predict(model_choice, text, candidate_labels): staple_result = staple_check(text.strip().lower()) if (staple_result[0]["label"] == "GIDA"): pipeline = pipelines[model_choice] predictions = pipeline(text.strip().lower(), candidate_labels.strip().lower().split("\n")) return [{ staple_result[0]["label"]: staple_result[0]["score"] }, { label: predictions['scores'][i] for (i, label) in enumerate(predictions['labels']) }] else: return [{ staple_result[0]["label"]: staple_result[0]["score"] }, { }] gr.Interface( predict, inputs=[ gr.Dropdown(label="Select Model", choices=models, value=models[0], type="index", interactive=True), gr.Textbox(label="Product name"), gr.Textbox( label="Categories", value="\n".join(default_categories), ) ], outputs=[gr.outputs.Label(label="Staple Food Check"), gr.outputs.Label(label="Staple Food Classification", num_top_classes=10)], title="Zero-shot Turkish Product Categorization", ).launch()