"""Gradio demo that serves a fine-tuned, Alpaca-prompted language model."""

import gradio as gr
from transformers import AutoTokenizer
from unsloth import FastLanguageModel
from transformers import TextStreamer
import torch

# Define the Alpaca-style prompt template. The three placeholders are filled,
# in order, with the instruction, the input passage and the response.
# NOTE(review): the line breaks below follow the conventional Alpaca layout;
# confirm they match the template that was used when fine-tuning the model.
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}
"""

# Load the model and tokenizer once at import time and switch the model to
# inference mode so every request reuses the same instance.
model_name = "kenzykhaled/model"
device = "cuda" if torch.cuda.is_available() else "cpu"
model, tokenizer = FastLanguageModel.from_pretrained(model_name)
FastLanguageModel.for_inference(model)


# Define the generation function
def generate(instruction, passage, max_tokens=160):
    """Generate the model's response for an instruction/passage pair.

    Args:
        instruction: Task description placed in the prompt's instruction slot.
        passage: Context text placed in the prompt's input slot.
        max_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The decoded completion only (the prompt is not echoed back), stripped
        of surrounding whitespace. If anything raises, a string of the form
        ``"Error: <message>"`` is returned instead so the UI can display it.
    """
    try:
        # The response slot is left empty so the model continues from there.
        prompt = alpaca_prompt.format(instruction, passage, "")
        # NOTE(review): the prompt is truncated to 512 tokens; with the usual
        # right-side truncation a very long passage could cut off the trailing
        # "### Response:" header -- confirm this is acceptable.
        inputs = tokenizer(
            [prompt],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(device)

        # Echoes generated text to stdout while generation runs (server log).
        text_streamer = TextStreamer(tokenizer, skip_prompt=True)
        outputs = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            streamer=text_streamer,
            # The value comes from a UI slider; make sure it is an integer.
            max_new_tokens=int(max_tokens),
            pad_token_id=tokenizer.eos_token_id,
        )

        # For a causal LM, generate() returns the prompt tokens followed by
        # the new tokens. Drop the prompt so only the response is shown.
        prompt_length = inputs.input_ids.shape[1]
        completion_ids = outputs[0][prompt_length:]
        return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the app.
        return f"Error: {str(e)}"


# Create Gradio interface
interface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Instruction", placeholder="Enter the task instruction here"),
        gr.Textbox(label="Passage", placeholder="Enter the input passage here"),
        gr.Slider(50, 512, step=10, value=160, label="Max Tokens"),
    ],
    outputs=gr.Textbox(label="Generated Response"),
    title="Fine-Tuned Model",
    description="Generate responses based on input instructions and passages. Use the slider to control the maximum tokens generated.",
    examples=[
        [
            "Generate a multiple-choice question (MCQ) based on the passage, provide options, and indicate the correct option.",
            (
                "Photosynthesis is the process by which green plants, algae, and some bacteria convert sunlight into energy. "
                "This process primarily takes place in the chloroplasts of plant cells, where chlorophyll absorbs sunlight."
            ),
            160,
        ]
    ],
)

# Launch the app
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside the host/container.
    interface.launch(server_name="0.0.0.0", server_port=7860)