zting committed on
Commit
2ba8a65
1 Parent(s): 137ec6d

Create app.py

Files changed (1)
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
+ import os
+ import gradio as gr
+ import torch
+ from PIL import Image
+ from transformers import BlipProcessor, BlipForConditionalGeneration
+ from langchain.agents import initialize_agent, AgentType
+ from langchain.chat_models import AzureChatOpenAI
+ from langchain.chains.conversation.memory import ConversationBufferWindowMemory
+ from langchain.chains import LLMChain
+ from langchain.tools import BaseTool
+ from langchain import PromptTemplate
+
+ # Azure OpenAI settings are read from environment variables.
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+ OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")
+ DEP_NAME = os.getenv("deployment_name")
+ llm = AzureChatOpenAI(
+     deployment_name=DEP_NAME,
+     openai_api_base=OPENAI_API_BASE,
+     openai_api_key=OPENAI_API_KEY,
+     openai_api_version="2023-03-15-preview",
+     model_name="gpt-3.5-turbo",
+ )
+
+ # BLIP image-captioning model used to describe uploaded images.
+ image_to_text_model = "Salesforce/blip-image-captioning-large"
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ processor = BlipProcessor.from_pretrained(image_to_text_model)
+ model = BlipForConditionalGeneration.from_pretrained(image_to_text_model).to(device)
+
+ def descImage(image_url):
+     """Caption the image at the given local path with BLIP."""
+     image_obj = Image.open(image_url).convert('RGB')
+     inputs = processor(image_obj, return_tensors='pt').to(device)
+     outputs = model.generate(**inputs)
+     return processor.decode(outputs[0], skip_special_tokens=True)
+
+ def toChinese(en: str):
+     """Translate an English sentence into Chinese with the LLM."""
+     pp = "翻译下面语句到中文\n{en}"  # "Translate the following sentence into Chinese"
+     prompt = PromptTemplate(
+         input_variables=["en"],
+         template=pp
+     )
+     llchain = LLMChain(llm=llm, prompt=prompt)
+     return llchain.run(en)
+
+ class DescTool(BaseTool):
+     """LangChain tool that returns a caption for an image file path."""
+     name = "Describe Image Tool"
+     description = "use this tool to describe an image"
+
+     def _run(self, url: str):
+         return descImage(url)
+
+     def _arun(self, query: str):
+         raise NotImplementedError('未实现')  # async mode is not implemented
+
+ # Conversational agent with a 5-turn memory window and the image-description tool.
+ tools = [DescTool()]
+ memory = ConversationBufferWindowMemory(
+     memory_key='chat_history',
+     k=5,
+     return_messages=True
+ )
+
+ agent = initialize_agent(
+     agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
+     tools=tools,
+     llm=llm,
+     verbose=False,
+     max_iterations=3,
+     early_stopping_method='generate',
+     memory=memory
+ )
+ def reset_user_input():
+     return gr.update(value='')
+
+ def reset_state():
+     return [], []
+
+ def predict(file, input, chatbot, history):
+     # Combine the question with the uploaded image's file path, let the agent
+     # answer (calling DescTool as needed), then translate the answer to Chinese.
+     input1 = input + "\n" + file
+     out = agent(input1)
+     anws = toChinese(out['output'])
+     chatbot.append((input, anws))
+     yield chatbot, history
+
+ # Gradio UI: image upload and chat window on top, text input and buttons below.
+ with gr.Blocks(css=".chat-blocks{height:calc(100vh - 332px);} .mychat{flex:1} .mychat .block{min-height:100%} .mychat .block .wrap{max-height: calc(100vh - 330px);} .myinput{flex:initial !important;min-height:180px}") as demo:
+     title = '图像识别'  # "Image Recognition"
+     demo.title = title
+     with gr.Column(elem_classes="chat-blocks"):
+         with gr.Row(elem_classes="mychat"):
+             file = gr.Image(type="filepath")
+             chatbot = gr.Chatbot(label="图像识别", show_label=False)
+         with gr.Column(elem_classes="myinput"):
+             user_input = gr.Textbox(show_label=False, placeholder="请输入...", lines=1).style(container=False)
+             submitBtn = gr.Button("提交", variant="primary", elem_classes="btn1")  # "Submit"
+             emptyBtn = gr.Button("清除历史").style(container=False)  # "Clear history"
+
+     history = gr.State([])
+     submitBtn.click(predict, [file, user_input, chatbot, history], [chatbot, history],
+                     show_progress=True)
+     submitBtn.click(reset_user_input, [], [user_input])
+     emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
+
+ demo.queue(api_open=False, concurrency_count=20).launch()
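
The app reads its Azure OpenAI settings from the environment: OPENAI_API_KEY, OPENAI_API_BASE, and deployment_name. A minimal local-run sketch, assuming a valid Azure OpenAI deployment and using placeholder values that are not part of this commit:

import os

# Placeholders: substitute your own Azure OpenAI key, endpoint, and deployment name.
os.environ["OPENAI_API_KEY"] = "<azure-openai-key>"
os.environ["OPENAI_API_BASE"] = "https://<resource>.openai.azure.com/"
os.environ["deployment_name"] = "<gpt-35-turbo-deployment>"

# Importing app.py loads the BLIP captioner, builds the LangChain agent,
# and launches the Gradio server (demo.queue(...).launch() runs at import time).
import app  # noqa: F401

Each chat turn sends the typed question plus the uploaded image's file path to the agent, which calls DescTool to caption the image; toChinese() then translates the final answer before it appears in the chat window.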