Atharv Subhekar committed on
Commit 592981e
1 Parent(s): f76d03b

Application update

.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
app.py CHANGED
@@ -7,15 +7,10 @@ Original file is located at
     https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27
 """
 
-#!pip install gradio --quiet
-#!pip install -Uq transformers datasets timm accelerate evaluate
-
-import subprocess
-# subprocess.run('pip3 install datasets timm cv2 huggingface_hub torch pillow matplotlib' ,shell=True)
-
 import gradio as gr
-from huggingface_hub import hf_hub_download
 from safetensors.torch import load_model
+from timm import create_model
+from huggingface_hub import hf_hub_download
 from datasets import load_dataset
 import torch
 import torchvision.transforms as T
@@ -23,8 +18,17 @@ import cv2
 import matplotlib.pyplot as plt
 import numpy as np
 from PIL import Image
-from timm import create_model
+import os
 
+from langchain_community.document_loaders import TextLoader
+from langchain_community.vectorstores import FAISS
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain.text_splitter import CharacterTextSplitter
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnablePassthrough
+from langchain_fireworks import ChatFireworks
+from langchain_core.prompts import ChatPromptTemplate
+from transformers import AutoModelForImageClassification, AutoImageProcessor
 
 
 safe_tensors = "model.safetensors" #hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification", filename="model.safetensors")
@@ -52,8 +56,56 @@ def one_hot_decoding(labels):
         true_labels.append(id2label[i])
     return true_labels
 
+def ragChain():
+    """
+    function: creates a rag chain
+    output: rag chain
+    """
+    loader = TextLoader("document.txt")
+    docs = loader.load()
+
+    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+    docs = text_splitter.split_documents(docs)
+
+    vectorstore = FAISS.load_local("faiss_index", embeddings = HuggingFaceEmbeddings(), allow_dangerous_deserialization = True)
+    retriever = vectorstore.as_retriever(search_type = "similarity", search_kwargs = {"k": 5})
+
+    api_key = os.getenv("FIREWORKS_API_KEY")
+    llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct", api_key = api_key)
+
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            (
+                "system",
+                """You are a knowledgeable landscape deforestation analyst.
+                """
+            ),
+            (
+                "human",
+                """First mention the detected labels only with short description.
+                Provide not more than 4 precautionary measures which are related to the detected labels that can be taken to control deforestation.
+                Don't include conversational messages.
+                """,
+            ),
+
+            ("human", "{context}, {question}"),
+        ]
+    )
+
+    rag_chain = (
+        {
+            "context": retriever,
+            "question": RunnablePassthrough()
+        }
+        | prompt
+        | llm
+        | StrOutputParser()
+    )
+
+    return rag_chain
+
 def model_output(image):
-    image = cv2.imread(image)
+
     PIL_image = Image.fromarray(image.astype('uint8'), 'RGB')
 
     img_size = (224,224)
@@ -72,8 +124,27 @@ def model_output(image):
     pred_labels = one_hot_decoding(predictions)
     output_text = " ".join(pred_labels)
 
-    return output_text
-
-app = gr.Interface(fn=model_output, inputs="image", outputs="text")
-app.launch(debug=True)
+    query = f"Detected labels in the provided satellite image are {output_text}. Give information on the labels."
+
+    return query
+
+def generate_response(rag_chain, query):
+    """
+    input: rag chain, query
+    function: generates response using llm and knowledge base
+    output: generated response by the llm
+    """
+    return rag_chain.invoke(f"{query}")
+
+def main(image):
+    query = model_output(image)
+    chain = ragChain()
+    output = generate_response(chain, query)
+    return output
+title = "Satellite Image Landscape Analysis for Deforestation"
+description = "This bot will take any satellite image and analyze the factors which lead to deforestation by identify the landscape based on forest areas, roads, habitation, water etc."
+app = gr.Interface(fn=main, inputs="image", outputs="text", title=title,
+                   description=description,
+                   examples=[["sampleimages/train_142.jpg"], ["sampleimages/train_32.jpg"], ["sampleimages/train_59.jpg"], ["sampleimages/train_67.jpg"], ["sampleimages/train_75.jpg"], ["sampleimages/train_92.jpg"], ["sampleimages/random_satellite.jpg"]])
+app.launch(share = True)
 
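Note on the faiss_index directory: ragChain() above loads a prebuilt vector store with FAISS.load_local("faiss_index", ...) and HuggingFaceEmbeddings (the document.txt chunks it splits are never themselves added to that index). A minimal sketch of how such a directory could be produced offline, assuming the index really was built from the same document.txt with the default HuggingFaceEmbeddings model (sentence-transformers/all-MiniLM-L6-v2); the script name and these assumptions are illustrative, not part of the commit:

# build_index.py - hypothetical one-off script; the commit only loads "faiss_index", it does not create it
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import CharacterTextSplitter

docs = TextLoader("document.txt").load()  # same knowledge file app.py reads
chunks = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0).split_documents(docs)

# Embed the chunks and write the index to the directory name ragChain() expects.
vectorstore = FAISS.from_documents(chunks, HuggingFaceEmbeddings())
vectorstore.save_local("faiss_index")

Whichever embedding model writes the index here must also be the one passed when FAISS.load_local() reads it back, otherwise the similarity search compares vectors from different embedding spaces.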
requirements.txt CHANGED
@@ -1,6 +1,6 @@
 transformers
 datasets
-timm
+Time
 langchain-fireworks
 langchain_core
 langchain_community
@@ -10,4 +10,4 @@ safetensors
 torch
 torchvision
 opencv-python
-pillow
+pillow
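requirements.txt controls what pip installs into the Space, while the import block at the top of the updated app.py is what actually has to resolve at runtime (gradio, timm, safetensors, huggingface_hub, datasets, torch, torchvision, cv2, PIL, langchain_community, langchain_fireworks, transformers). A small, hypothetical smoke check of the installed environment against those imports; the script name is illustrative and not part of the commit:

# check_env.py - hypothetical: verify that app.py's top-level imports resolve in this environment
import importlib

MODULES = ["gradio", "timm", "safetensors", "huggingface_hub", "datasets", "torch",
           "torchvision", "cv2", "PIL", "langchain_community", "langchain_fireworks", "transformers"]

for name in MODULES:
    try:
        importlib.import_module(name)
        print(f"ok      {name}")
    except ImportError as err:
        print(f"MISSING {name} ({err})")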
sample_images/Screenshot 2024-06-28 at 1.35.57 PM.png ADDED
~$cumentation.docx DELETED
Binary file (162 Bytes)
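For reference on the chain wiring in the app.py diff above: the dict at the head of rag_chain is treated as a parallel runnable, so the single query string passed to invoke() goes both to the retriever (which uses it as the similarity-search text that fills {context}) and through RunnablePassthrough() into {question}. A small, self-contained illustration of that routing, with a stub standing in for the FAISS retriever; the stub and its printed output are illustrative, not from the commit:

# routing_demo.py - hypothetical: shows how the {"context": ..., "question": ...} head fans out one input
from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough

fake_retriever = RunnableLambda(lambda q: [f"(retrieved chunk for: {q})"])  # stand-in for vectorstore.as_retriever()

head = RunnableParallel(context=fake_retriever, question=RunnablePassthrough())
print(head.invoke("Detected labels in the provided satellite image are road water."))
# prints a dict with both keys, e.g. {'context': ['(retrieved chunk for: ...)'], 'question': '...'}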