| from fastai.vision import * | |
| from duckduckgo_search import ddg_images | |
| from fastcore.all import * | |
| from fastbook import * | |
| from fastdownload import download_url | |
| from PIL import Image | |
| import os | |
def search_images(term, max_images=60):
    """Return the image URLs found by a DuckDuckGo image search for *term*.

    Requests at most ``max_images`` results from ``ddg_images`` and collects
    the ``'image'`` field of each result into a fastcore ``L``.
    """
    print(f"Searching for '{term}'")
    hits = ddg_images(term, max_results=max_images)
    return L(hits).map(lambda hit: hit['image'])
def download_images(dest, urls):
    """Download every URL in *urls* into the directory *dest*.

    A URL is skipped when a file with the URL's last path component as its
    name (query string and fragment ignored) already exists in *dest*.
    Presumably this is the name ``download_url`` saves under (fastcore
    ``urlsave``/``urldest`` naming) — verify if the download helper changes.
    Per-URL failures are printed and skipped, so one bad URL does not abort
    the whole batch.

    Args:
        dest: target directory, as a ``Path`` or ``str``.
        urls: sized iterable of image URLs (``len`` is used for the banner).
    """
    from urllib.parse import urlparse

    dest = Path(dest)
    print(f"Downloading {len(urls)} images")
    for url in urls:
        # Compare against the file name the download produces, not the raw
        # URL: joining a full URL onto `dest` yields a path that never exists.
        name = Path(urlparse(url).path).name
        if name and (dest / name).exists():
            continue
        try:
            download_url(url, dest, show_progress=False)
        except Exception as e:
            # Best effort: report and move on to the next URL.
            print(f"Error while downloading photo: {e}")
def resize_image(fpath, max_size):
    """Replace the image at *fpath* with a ``max_size`` x ``max_size`` copy.

    The copy is written next to the source as ``<stem>_<max_size><suffix>``
    (for example ``cat.jpg`` -> ``cat_400.jpg``), and then the source file is
    deleted. The image is squashed to a square, so the aspect ratio is not
    preserved. If the source cannot be opened, resized or saved, the error is
    printed and the source is still deleted, which discards unreadable
    downloads.

    The following files are left untouched:
    * files without an extension, because Pillow picks the save format from
      the extension and there would be no distinct output name;
    * files whose resized copy already exists.

    Args:
        fpath: path of the source image, as a ``str`` or ``Path``.
        max_size: side length in pixels of the square output.
    """
    src = Path(fpath)
    if not src.suffix:
        # No extension: do not overwrite or delete the only copy.
        print(src)
        return
    dst = src.with_name(f"{src.stem}_{max_size}{src.suffix}")
    print(dst)
    if dst.exists():
        return
    try:
        # `with` closes the file handle Image.open keeps open.
        with Image.open(src) as img:
            img.resize((max_size, max_size)).save(dst)
    except Exception as e:
        # Catch Exception, not everything: a Ctrl-C must not be swallowed and
        # followed by deleting the source file.
        print(f"Error while resizing {src}: {e}")
    os.remove(src)
def resize_images(path, max_size):
    """Resize every regular file directly inside *path* with ``resize_image``.

    Files whose stem already ends in ``_<max_size>`` are outputs of an earlier
    run and are skipped, so re-running is safe. Sub-directories are ignored.

    Args:
        path: directory containing the images, as a ``str`` or ``Path``.
        max_size: side length passed on to ``resize_image``.
    """
    print("resizing images")
    # Derive the marker from max_size instead of hard-coding "_400". Otherwise
    # any other size re-resizes its own outputs on every run.
    marker = f"_{max_size}"
    # os.listdir takes a snapshot, so the deletes and creates that
    # resize_image performs do not disturb the iteration.
    for name in os.listdir(path):
        fpath = f"{path}/{name}"
        if not os.path.isfile(fpath):
            continue
        # Test only the file stem. A directory name containing the marker
        # must not cause every file inside it to be skipped.
        if os.path.splitext(name)[0].endswith(marker):
            continue
        resize_image(fpath, max_size)
# Build a three-class ball image dataset from DuckDuckGo results, then
# fine-tune a ResNet-18 classifier on it and export the trained learner.
from functools import partial

searches = 'football ball', 'basketball ball', 'tennis ball'
path = Path('.')
for o in searches:
    dest = path / o
    # Images are downloaded only when a class folder is first created. Delete
    # the folder to force a fresh download, e.g. after an interrupted run.
    if not dest.exists():
        dest.mkdir(exist_ok=True, parents=True)
        download_images(dest, urls=search_images(f'{o} photo'))
    # Already-resized files are skipped, so this is cheap on re-runs.
    resize_images(dest, max_size=400)

dls = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    # Collect images only from the class folders. Scanning all of `path` (the
    # working directory) would also pick up unrelated images and label them
    # by whatever folder they happen to sit in.
    get_items=partial(get_image_files, folders=list(searches)),
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label,
    item_tfms=[Resize(192, method='squish')],
).dataloaders(path, bs=32)
dls.show_batch(max_n=20)

learn = vision_learner(dls, resnet18, metrics=error_rate)
learn.fine_tune(6)
learn.export()