sibimani committed
Commit 376da64 · verified · 1 Parent(s): 0ce2647

Upload 2 files

Files changed (2)
  1. app.py +80 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,80 @@
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel
+ import torch
+ import os
+ import re
+ import json
+ from flask import Flask, request, jsonify
+
+ app = Flask(__name__)
+
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ adapter_path = os.path.join(script_dir, "lora-playwright-adapter")
+ model_name = "mistralai/Mistral-7B-Instruct-v0.1"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ # Ensure padding token
+ if tokenizer.pad_token is None:
+     tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
+
+ base_model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype=torch.float16,
+     device_map="auto"
+ )
+
+ model = PeftModel.from_pretrained(base_model, adapter_path)
+ model.eval()
+
+ # Example test goals (you can extend this to load from Excel/CSV)
+
+ def generate_action_sequence(test_goals):
+     full_response = []
+     for goal in test_goals:
+
+         prompt = f"Goal: {goal}\nReturn only one valid JSON array, no explanation.\nOutput:"
+
+         inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
+
+         with torch.no_grad():
+             outputs = model.generate(
+                 input_ids=inputs["input_ids"],
+                 attention_mask=inputs["attention_mask"],
+                 max_new_tokens=150,
+                 pad_token_id=tokenizer.pad_token_id,
+                 top_p=1.0,
+                 repetition_penalty=1.2,
+                 do_sample=False
+             )
+
+         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         # Extract JSON array part
+         match = re.search(r'\[.*\]', response, re.DOTALL)
+         if match:
+             response_text = match.group(0)
+             try:
+                 response_json = json.loads(response_text)
+                 full_response.extend(response_json)
+             except json.JSONDecodeError:
+                 print(f"Invalid JSON for goal: {goal}")
+         else:
+             print(f"No JSON found for goal: {goal}")
+
+     return full_response
+
+ @app.route("/")
+ def health():
+     return "OK", 200
+
+ @app.route("/generate", methods=["POST"])
+ def generate():
+     data = request.get_json()
+     test_goals = data.get("goals", [])
+     result = generate_action_sequence(test_goals)
+     return jsonify({"result": result})
+
+
+ if __name__ == "__main__":
+     port = int(os.environ.get("PORT", 5000))
+     app.run(host="0.0.0.0", port=port)
+
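For reference, a minimal sketch of how the /generate route added above could be exercised once the app is running. The local host and port, the example goal string, the timeout, and the use of the requests library are illustrative assumptions, not part of this commit.

import requests  # assumed installed for this example; not listed in the commit

# Hypothetical client call against a locally running app.py (default PORT is 5000)
resp = requests.post(
    "http://localhost:5000/generate",
    json={"goals": ["Log in with valid credentials and verify the dashboard loads"]},
    timeout=600,  # generation with a 7B model can take a while
)
resp.raise_for_status()
print(resp.json()["result"])  # flat list of actions parsed from the model's JSON arrays

The request body mirrors what the route reads (data.get("goals", [])), and the response wraps the combined action list under the "result" key, as returned by jsonify in app.py.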
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ openai
+ langchain
+ chromadb
+ python-dotenv
+ rich
+ typer