Tonic commited on
Commit
3a69f52
·
verified ·
1 Parent(s): bb75784

Create app.py

Browse files

initial commit

Files changed (1) hide show
  1. app.py +134 -0
app.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ```python
2
+ import json
3
+ import torch
4
+ import gradio as gr
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+
7
+ # Load model and tokenizer
8
+ model_name = "Salesforce/xLAM-7b-r"
9
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto", trust_remote_code=True)
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+
12
+ # Set random seed for reproducibility
13
+ torch.random.manual_seed(0)
14
+
15
+ # Task and format instructions
16
+ task_instruction = """
17
+ Based on the previous context and API request history, generate an API request or a response as an AI assistant.""".strip()
18
+
19
+ format_instruction = """
20
+ The output should be of the JSON format, which specifies a list of generated function calls. The example format is as follows, please make sure the parameter type is correct. If no function call is needed, please make
21
+ tool_calls an empty list "[]".
22
+ ```
23
+ {"thought": "the thought process, or an empty string", "tool_calls": [{"name": "api_name1", "arguments": {"argument1": "value1", "argument2": "value2"}}]}
24
+ ```
25
+ """.strip()
26
+
27
+ def convert_to_xlam_tool(tools):
28
+ if isinstance(tools, dict):
29
+ return {
30
+ "name": tools["name"],
31
+ "description": tools["description"],
32
+ "parameters": {k: v for k, v in tools["parameters"].get("properties", {}).items()}
33
+ }
34
+ elif isinstance(tools, list):
35
+ return [convert_to_xlam_tool(tool) for tool in tools]
36
+ else:
37
+ return tools
38
+
39
+ def build_conversation_history_prompt(conversation_history: str):
40
+ parsed_history = []
41
+ for step_data in conversation_history:
42
+ parsed_history.append({
43
+ "step_id": step_data["step_id"],
44
+ "thought": step_data["thought"],
45
+ "tool_calls": step_data["tool_calls"],
46
+ "next_observation": step_data["next_observation"],
47
+ "user_input": step_data['user_input']
48
+ })
49
+
50
+ history_string = json.dumps(parsed_history)
51
+ return f"\n[BEGIN OF HISTORY STEPS]\n{history_string}\n[END OF HISTORY STEPS]\n"
52
+
53
+ def build_prompt(task_instruction: str, format_instruction: str, tools: list, query: str, conversation_history: list):
54
+ prompt = f"[BEGIN OF TASK INSTRUCTION]\n{task_instruction}\n[END OF TASK INSTRUCTION]\n\n"
55
+ prompt += f"[BEGIN OF AVAILABLE TOOLS]\n{json.dumps(tools)}\n[END OF AVAILABLE TOOLS]\n\n"
56
+ prompt += f"[BEGIN OF FORMAT INSTRUCTION]\n{format_instruction}\n[END OF FORMAT INSTRUCTION]\n\n"
57
+ prompt += f"[BEGIN OF QUERY]\n{query}\n[END OF QUERY]\n\n"
58
+
59
+ if len(conversation_history) > 0:
60
+ prompt += build_conversation_history_prompt(conversation_history)
61
+ return prompt
62
+
63
+ def generate_response(tools_input, query):
64
+ try:
65
+ tools = json.loads(tools_input)
66
+ except json.JSONDecodeError:
67
+ return "Error: Invalid JSON format for tools input."
68
+
69
+ xlam_format_tools = convert_to_xlam_tool(tools)
70
+ conversation_history = []
71
+ content = build_prompt(task_instruction, format_instruction, xlam_format_tools, query, conversation_history)
72
+
73
+ messages = [
74
+ {'role': 'user', 'content': content}
75
+ ]
76
+
77
+ inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
78
+ outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
79
+ agent_action = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
80
+
81
+ return agent_action
82
+
83
+ # Gradio interface
84
+ iface = gr.Interface(
85
+ fn=generate_response,
86
+ inputs=[
87
+ gr.Textbox(
88
+ label="Available Tools (JSON format)",
89
+ lines=10,
90
+ value=json.dumps([
91
+ {
92
+ "name": "get_weather",
93
+ "description": "Get the current weather for a location",
94
+ "parameters": {
95
+ "type": "object",
96
+ "properties": {
97
+ "location": {
98
+ "type": "string",
99
+ "description": "The city and state, e.g. San Francisco, New York"
100
+ },
101
+ "unit": {
102
+ "type": "string",
103
+ "enum": ["celsius", "fahrenheit"],
104
+ "description": "The unit of temperature to return"
105
+ }
106
+ },
107
+ "required": ["location"]
108
+ }
109
+ },
110
+ {
111
+ "name": "search",
112
+ "description": "Search for information on the internet",
113
+ "parameters": {
114
+ "type": "object",
115
+ "properties": {
116
+ "query": {
117
+ "type": "string",
118
+ "description": "The search query, e.g. 'latest news on AI'"
119
+ }
120
+ },
121
+ "required": ["query"]
122
+ }
123
+ }
124
+ ], indent=2)
125
+ ),
126
+ gr.Textbox(label="User Query", lines=2, value="What's the weather like in New York in fahrenheit?")
127
+ ],
128
+ outputs=gr.Textbox(label="Generated Response", lines=10),
129
+ title="xLAM-7b-r API Request Generator",
130
+ description="Enter available tools in JSON format and a user query to generate an API request or response.",
131
+ )
132
+
133
+ if __name__ == "__main__":
134
+ iface.launch()