subhuatharva's picture
Update app.py
cb1234d verified
raw
history blame contribute delete
No virus
4.93 kB
# -*- coding: utf-8 -*-
"""satellite_app.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27
"""
import gradio as gr
from safetensors.torch import load_model
from timm import create_model
from huggingface_hub import hf_hub_download
from datasets import load_dataset
import torch
import torchvision.transforms as T
import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_fireworks import ChatFireworks
from langchain_core.prompts import ChatPromptTemplate
from transformers import AutoModelForImageClassification, AutoImageProcessor
# Local path to the fine-tuned classifier weights shipped with this app.
# (Alternatively they could be fetched with hf_hub_download from the repo
# "subhuatharva/swim-224-base-satellite-image-classification", file
# "model.safetensors", as an earlier version of this file noted.)
safe_tensors = "model.safetensors"
model_name = 'swin_s3_base_224'
# Initialize the timm architecture with a 17-way head, matching the 17 class
# names decoded by one_hot_decoding, then load the trained weights into it.
model = create_model(
    model_name,
    num_classes=17,
)
load_model(model, safe_tensors)
# nn.Module instances start in training mode; this model is only used for
# inference, so switch it to eval mode to disable any train-only behaviour
# (dropout, drop-path, etc.) for the predictions in model_output.
model.eval()
# Label vocabulary of the multi-label classifier: position i of the model's
# output vector is decoded as CLASS_NAMES[i].
CLASS_NAMES = (
    'conventional_mine', 'habitation', 'primary', 'water', 'agriculture',
    'bare_ground', 'cultivation', 'blow_down', 'road', 'cloudy', 'blooming',
    'partly_cloudy', 'selective_logging', 'artisinal_mine', 'slash_burn',
    'clear', 'haze',
)
# Index -> label lookup, built once instead of on every call.
_ID2LABEL = dict(enumerate(CLASS_NAMES))


def one_hot_decoding(labels):
    """Convert a multi-hot prediction vector into the matching class names.

    Args:
        labels: iterable of per-class flags in CLASS_NAMES order; an entry
            equal to 1 (e.g. ``1``, ``1.0``) marks that class as present.

    Returns:
        List of class names whose flag equals 1, in CLASS_NAMES order.

    Raises:
        KeyError: if a flag equal to 1 appears at a position with no class.
    """
    return [_ID2LABEL[idx] for idx, flag in enumerate(labels) if flag == 1]
def _format_docs(docs):
    """Join the ``page_content`` of retrieved documents into one context string.

    Feeding the raw list of documents into the prompt would interpolate each
    document's Python repr (metadata included); joining the text keeps the
    prompt context to the retrieved passages only.
    """
    return "\n\n".join(doc.page_content for doc in docs)


def ragChain():
    """
    function: creates a retrieval-augmented generation (RAG) chain. Given a
        query string, it retrieves the 5 most similar passages from the local
        FAISS index and asks a Fireworks-hosted Mixtral model to describe the
        detected labels and suggest precautionary measures.
    output: rag chain (LCEL runnable); ``invoke(query)`` returns the model's
        answer as a string.
    """
    # The knowledge base comes from the prebuilt FAISS index on disk; the
    # source text file is not re-read or re-split at request time.
    # NOTE(review): HuggingFaceEmbeddings() uses the library's default model,
    # which must be the same model the index was built with — confirm.
    # SECURITY: load_local unpickles the stored docstore, hence
    # allow_dangerous_deserialization. This is only acceptable because
    # "faiss_index" is a trusted artifact bundled with the app — never point
    # it at user-supplied data.
    vectorstore = FAISS.load_local(
        "faiss_index",
        embeddings=HuggingFaceEmbeddings(),
        allow_dangerous_deserialization=True,
    )
    retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 5},
    )

    # The API key is taken from the environment rather than hard-coded.
    llm = ChatFireworks(
        model="accounts/fireworks/models/mixtral-8x7b-instruct",
        api_key=os.getenv("FIREWORKS_API_KEY"),
    )

    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a knowledgeable landscape deforestation analyst.\n",
            ),
            (
                "human",
                "First mention the detected labels only with short description.\n"
                "Provide not more than 4 precautionary measures which are related to the detected labels that can be taken to control deforestation.\n"
                "Don't include conversational messages.\n",
            ),
            ("human", "{context}, {question}"),
        ]
    )

    # The incoming query is used both as the retrieval query (its retrieved
    # passages are flattened to plain text for {context}) and verbatim as
    # {question}.
    rag_chain = (
        {
            "context": retriever | _format_docs,
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain
def model_output(image):
    """Classify a satellite image and phrase the detected labels as an LLM query.

    Args:
        image: image array convertible to 8-bit RGB via ``Image.fromarray``
            (presumably an HxWx3 array from the Gradio "image" input — confirm).

    Returns:
        A query string naming every label whose sigmoid score exceeds 0.5.
    """
    rgb_image = Image.fromarray(image.astype('uint8'), 'RGB')

    # Resize to the 224x224 input of the model and scale pixels to [0, 1].
    # NOTE(review): no mean/std normalization is applied — confirm this
    # matches the preprocessing used in training.
    preprocess = T.Compose([T.Resize((224, 224)), T.ToTensor()])
    batch = preprocess(rgb_image).unsqueeze(0)

    with torch.no_grad():
        scores = model(batch).sigmoid()

    # Multi-label decision: each class independently present if score > 0.5.
    multi_hot = (scores > 0.5).float().numpy().flatten()
    label_text = " ".join(one_hot_decoding(multi_hot))

    return f"Detected labels in the provided satellite image are {label_text}. Give information on the labels."
def generate_response(rag_chain, query):
    """Ask the RAG chain to answer ``query``.

    Args:
        rag_chain: runnable exposing ``invoke`` (e.g. the object built by
            ``ragChain()``).
        query: the question; it is formatted to a string before invoking.

    Returns:
        Whatever ``rag_chain.invoke`` returns; for the chain built by
        ``ragChain()`` this is the LLM's text answer.
    """
    prompt_text = f"{query}"
    answer = rag_chain.invoke(prompt_text)
    return answer
# RAG chain shared across requests; built lazily on the first request.
_rag_chain = None


def _get_rag_chain():
    """Return the shared RAG chain, building it on first use.

    Building the chain loads the FAISS index and the embedding model, so it
    is done once rather than on every request. If two first requests race,
    each may build a chain and the last assignment wins, which is harmless.
    """
    global _rag_chain
    if _rag_chain is None:
        _rag_chain = ragChain()
    return _rag_chain


def main(image):
    """Gradio entry point: classify the image, then answer via the RAG chain.

    Args:
        image: the uploaded image as an array (model_output calls
            ``.astype`` on it), or None when the user submits without one.

    Returns:
        The LLM's text response describing the detected labels.

    Raises:
        gr.Error: if no image was provided, so the UI shows a clear message.
    """
    if image is None:
        raise gr.Error("Please upload a satellite image.")
    query = model_output(image)
    return generate_response(_get_rag_chain(), query)
title = "Satellite Image Landscape Analysis for Deforestation"
description = "This bot will take any satellite image and analyze the factors which lead to deforestation by identifying the landscape based on forest areas, roads, habitation, water, etc."
# Sample images bundled with the app, offered as one-click examples in the UI.
example_images = [
    "sample_images/train_142.jpg",
    "sample_images/train_32.jpg",
    "sample_images/random_satellite3.png",
    "sample_images/random_satellite2.png",
    "sample_images/train_75.jpg",
    "sample_images/train_92.jpg",
    "sample_images/random_satellite.png",
]
app = gr.Interface(
    fn=main,
    inputs="image",
    outputs="text",
    title=title,
    description=description,
    # Interface expects one inner list per example, one value per input.
    examples=[[path] for path in example_images],
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the
    # local server.
    app.launch(share=True)