---
license: apache-2.0
language:
- en
pipeline_tag: image-to-text
tags:
- mplug-owl
---

# Usage

## Get the latest codebase from Github

```Bash
git clone https://github.com/X-PLUG/mPLUG-Owl.git
```

## Model initialization

```Python
import torch

from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl.tokenization_mplug_owl import MplugOwlTokenizer
from mplug_owl.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor

pretrained_ckpt = 'MAGAer13/mplug-owl-llama-7b'
model = MplugOwlForConditionalGeneration.from_pretrained(
    pretrained_ckpt,
    torch_dtype=torch.bfloat16,
)
image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
tokenizer = MplugOwlTokenizer.from_pretrained(pretrained_ckpt)
processor = MplugOwlProcessor(image_processor, tokenizer)
```

## Model inference

Prepare model inputs.

```Python
# We use a human/AI template to organize the context as a multi-turn conversation.
# <image> denotes an image placeholder.
prompts = [
'''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
Human: <image>
Human: Explain why this meme is funny.
AI: ''']

# The image paths should be placed in image_list and kept in the same order as in the prompts.
# We support urls, local file paths and base64 strings. You can customize the image pre-processing by modifying mplug_owl.modeling_mplug_owl.ImageProcessor
image_list = ['https://xxx.com/image.jpg']
```

Get the response.

```Python
# Generate kwargs (the same as in transformers) can be passed to model.generate()
generate_kwargs = {
    'do_sample': True,
    'top_k': 5,
    'max_length': 512
}

from io import BytesIO

import requests
from PIL import Image

# PIL cannot open URLs directly, so remote images are downloaded first;
# local file paths are opened as-is.
def load_image(path_or_url):
    if path_or_url.startswith('http'):
        return Image.open(BytesIO(requests.get(path_or_url).content))
    return Image.open(path_or_url)

images = [load_image(_) for _ in image_list]
inputs = processor(text=prompts, images=images, return_tensors='pt')
inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()}
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
    res = model.generate(**inputs, **generate_kwargs)
sentence = tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
print(sentence)
```
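Note that `from_pretrained` loads the checkpoint on the CPU by default, so `model.device` in the snippet above points to the CPU unless you move the model first. A minimal sketch, assuming a single CUDA GPU with bfloat16 support is available:

```Python
# Optional: move the model to a GPU before generation (assumes a CUDA device with bfloat16 support).
# The `inputs = {k: v.to(model.device) ...}` line above then places the inputs on the same device.
model.to('cuda:0')
model.eval()
```

CPU-only inference works without this step, but generation with a 7B-parameter model will be slow.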