Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Create app.py
Browse files
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,77 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import json
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import gradio as gr
         | 
| 4 | 
            +
            import spaces
         | 
| 5 | 
            +
            from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from format.format_output import format_output
         | 
| 8 | 
            +
            from validate.validate_ingredients import validate_ingredients
         | 
| 9 | 
            +
            from device.get_device_id import get_device_id
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            tokenizer = AutoTokenizer.from_pretrained("Ashikan/dut-recipe-generator")
         | 
| 12 | 
            +
            model = AutoModelForCausalLM.from_pretrained("Ashikan/dut-recipe-generator")
         | 
| 13 | 
            +
            pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=get_device_id())
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            @spaces.GPU
         | 
| 17 | 
            +
            def perform_model_inference(ingredients_list):
         | 
| 18 | 
            +
                for ingredient_index in range(len(ingredients_list)):
         | 
| 19 | 
            +
                    ingredients_list[ingredient_index] = ingredients_list[ingredient_index].strip()
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                input_text = '{"prompt": ' + json.dumps(ingredients_list)
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                output = pipe(input_text, max_length=1024, temperature=0.1, do_sample=True, truncation=True)[0]["generated_text"]
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                return format_output(output)
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            def generate_recipe(ingredients):
         | 
| 29 | 
            +
                ingredients_list = ingredients.lower().split(',')
         | 
| 30 | 
            +
                is_ingredients_valid = validate_ingredients(ingredients_list)
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                if is_ingredients_valid:
         | 
| 33 | 
            +
                    generated_text = perform_model_inference(ingredients_list)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                    return {
         | 
| 36 | 
            +
                        generated_recipe: gr.Markdown(value=generated_text, label="Generated Recipe",
         | 
| 37 | 
            +
                                                      elem_id="recipe-container", visible=True)
         | 
| 38 | 
            +
                    }
         | 
| 39 | 
            +
                else:
         | 
| 40 | 
            +
                    error_text = "## Invalid ingredients. Please include at least 2 ingredients in a comma separated list. e.g. brown rice, onions, garlic"
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                    return {
         | 
| 43 | 
            +
                        generated_recipe: gr.Markdown(value=error_text, elem_id="recipe-container", visible=True)
         | 
| 44 | 
            +
                    }
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             | 
| 47 | 
            +
            with gr.Blocks(css="./css/styles.css") as recipegen:
         | 
| 48 | 
            +
                #gr.Image("./assets/dut.png", interactive=False, show_share_button=False, show_download_button=False,
         | 
| 49 | 
            +
                 #        show_fullscreen_button=False, show_label=False, elem_id="dut-logo", height=256)
         | 
| 50 | 
            +
                gr.Markdown("#  Recipe Generator", elem_id="header")
         | 
| 51 | 
            +
                gr.Markdown("### An AI Model Attempting To Produce Healthier, Diabetic-Friendly Recipes",
         | 
| 52 | 
            +
                            elem_id="header-sub-heading")
         | 
| 53 | 
            +
                gr.Markdown("Start by entering a comma-separated list of ingredients below.", elem_id="header-instructions")
         | 
| 54 | 
            +
                with gr.Column() as column:
         | 
| 55 | 
            +
                    user_ingredients = gr.Textbox(label="Ingredients", autofocus=True, max_lines=1, elem_id="ingredients-input")
         | 
| 56 | 
            +
                    generate_button = gr.Button(value="Generate")
         | 
| 57 | 
            +
                with gr.Column():
         | 
| 58 | 
            +
                    generated_recipe = gr.Markdown(visible=True)
         | 
| 59 | 
            +
                examples = gr.Examples(
         | 
| 60 | 
            +
                    elem_id="examples",
         | 
| 61 | 
            +
                    examples=[
         | 
| 62 | 
            +
                        "sweet potato, mushrooms, cheese, garlic",
         | 
| 63 | 
            +
                        "chicken breast, chili, onion, tomato, parmesan cheese",
         | 
| 64 | 
            +
                        "strawberries, vanilla, honey, rolled oats, almonds, butter",
         | 
| 65 | 
            +
                        "hake, spring onion, lemon"
         | 
| 66 | 
            +
                    ],
         | 
| 67 | 
            +
                    inputs=[user_ingredients]
         | 
| 68 | 
            +
                )
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                generate_button.click(
         | 
| 71 | 
            +
                    fn=generate_recipe,
         | 
| 72 | 
            +
                    inputs=[user_ingredients],
         | 
| 73 | 
            +
                    outputs=[generated_recipe]
         | 
| 74 | 
            +
                )
         | 
| 75 | 
            +
             | 
| 76 | 
            +
            recipegen.launch(share=True)
         | 
| 77 | 
            +
             |