# -*- coding: utf-8 -*-
"""satellite_app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27
"""

import gradio as gr
from safetensors.torch import load_model
from timm import create_model
from huggingface_hub import hf_hub_download
from datasets import load_dataset
import torch
import torchvision.transforms as T
import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import functools
import os
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_fireworks import ChatFireworks
from langchain_core.prompts import ChatPromptTemplate
from transformers import AutoModelForImageClassification, AutoImageProcessor

# Local weights file. It can alternatively be fetched with:
#   hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification",
#                   filename="model.safetensors")
safe_tensors = "model.safetensors"
model_name = 'swin_s3_base_224'

# Class names in the order of the classifier's output units (multi-label head).
CLASS_NAMES = (
    'conventional_mine', 'habitation', 'primary', 'water', 'agriculture',
    'bare_ground', 'cultivation', 'blow_down', 'road', 'cloudy', 'blooming',
    'partly_cloudy', 'selective_logging', 'artisinal_mine', 'slash_burn',
    'clear', 'haze',
)

# Initialize the model architecture with one output per class and load the fine-tuned weights.
model = create_model(
    model_name,
    num_classes=len(CLASS_NAMES),
)
load_model(model, safe_tensors)
# torch modules start in training mode; switch to inference mode so that any
# dropout / stochastic-depth layers are disabled and predictions are deterministic.
model.eval()

# Preprocessing applied to every input image; built once instead of per request.
_TEST_TFMS = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])


def one_hot_decoding(labels):
    """Map a multi-hot indicator vector to the names of the active classes.

    Args:
        labels: sequence of per-class indicators ordered like ``CLASS_NAMES``;
            entries equal to 1 (e.g. ``1`` or ``1.0``) mark a detected class.

    Returns:
        list[str]: names of the classes whose indicator equals 1, in class order.
    """
    return [name for name, flag in zip(CLASS_NAMES, labels) if flag == 1]


def ragChain():
    """Build the retrieval-augmented generation chain.

    Loads the pre-built FAISS index from the local ``faiss_index`` directory
    (queried with default ``HuggingFaceEmbeddings()``). It retrieves the 5 most
    similar chunks for each query and passes them, together with the query, to
    Mixtral-8x7B-Instruct served by Fireworks. The Fireworks API key is read
    from the ``FIREWORKS_API_KEY`` environment variable.

    Returns:
        A LangChain runnable that maps a query string to the LLM's text answer.
    """
    # The index is loaded pre-built from disk, so document.txt is not re-read or
    # re-split here (the original split result was never used).
    # allow_dangerous_deserialization unpickles the docstore: acceptable only
    # because "faiss_index" is a trusted, locally shipped artifact.
    vectorstore = FAISS.load_local(
        "faiss_index",
        embeddings=HuggingFaceEmbeddings(),
        allow_dangerous_deserialization=True,
    )
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})

    api_key = os.getenv("FIREWORKS_API_KEY")
    llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct", api_key=api_key)

    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a knowledgeable landscape deforestation analyst. """
            ),
            (
                "human",
                """First mention the detected labels only with short description. Provide not more than 4 precautionary measures which are related to the detected labels that can be taken to control deforestation. Don't include conversational messages. """,
            ),
            ("human", "{context}, {question}"),
        ]
    )

    rag_chain = (
        {
            "context": retriever,
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain


@functools.lru_cache(maxsize=1)
def _get_rag_chain():
    """Return the RAG chain, building it on first use and reusing it afterwards."""
    return ragChain()


def model_output(image):
    """Classify a satellite image and build the LLM query naming the detected labels.

    Args:
        image: numpy image array as delivered by the Gradio ``"image"`` input
            (presumably H x W x C; grayscale or RGBA arrays are converted to RGB).

    Returns:
        str: a query sentence listing the labels whose sigmoid score exceeds 0.5.
    """
    # Let PIL infer the mode from the array shape, then normalize to RGB so
    # grayscale / RGBA inputs do not get misinterpreted as 3-channel data.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    img = _TEST_TFMS(pil_image)

    with torch.no_grad():
        logits = model(img.unsqueeze(0))

    # Multi-label decision: every class with probability > 0.5 is reported.
    predictions = (logits.sigmoid() > 0.5).float().numpy().flatten()
    pred_labels = one_hot_decoding(predictions)
    output_text = " ".join(pred_labels)
    query = (
        f"Detected labels in the provided satellite image are {output_text}. "
        "Give information on the labels."
    )
    return query


def generate_response(rag_chain, query):
    """Generate an answer for ``query`` using the LLM and the knowledge base.

    Args:
        rag_chain: runnable returned by ``ragChain()``.
        query: question text (converted to ``str`` before invocation).

    Returns:
        The text generated by the LLM.
    """
    return rag_chain.invoke(str(query))


def main(image):
    """Gradio handler: classify ``image`` and return the LLM's analysis text."""
    query = model_output(image)
    # Reuse one chain across requests instead of reloading embeddings/index each time.
    chain = _get_rag_chain()
    return generate_response(chain, query)


title = "Satellite Image Landscape Analysis for Deforestation"
description = "This bot will take any satellite image and analyze the factors which lead to deforestation by identifying the landscape based on forest areas, roads, habitation, water etc."

app = gr.Interface(
    fn=main,
    inputs="image",
    outputs="text",
    title=title,
    description=description,
    examples=[
        ["sample_images/train_142.jpg"],
        ["sample_images/train_32.jpg"],
        ["sample_images/random_satellite3.png"],
        ["sample_images/random_satellite2.png"],
        ["sample_images/train_75.jpg"],
        ["sample_images/train_92.jpg"],
        ["sample_images/random_satellite.png"],
    ],
)

if __name__ == "__main__":
    app.launch(share=True)