navaneethkrishnan commited on
Commit
ec1fb2e
·
verified ·
1 Parent(s): 2c49fd3

Create core/providers.py

Browse files
Files changed (1) hide show
  1. core/providers.py +105 -0
core/providers.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, json, re
2
+ from dataclasses import dataclass
3
+ from typing import Dict, Any, Optional, Tuple
4
+
5
+ # OpenAI
6
+ try:
7
+ from openai import OpenAI
8
+ except Exception:
9
+ OpenAI = None
10
+
11
+ # Anthropic
12
+ try:
13
+ import anthropic
14
+ except Exception:
15
+ anthropic = None
16
+
17
+ class ProviderKind:
18
+ OPENAI = "openai"
19
+ ANTHROPIC = "anthropic"
20
+
21
+ @dataclass
22
+ class Provider:
23
+ kind: str
24
+ model: str
25
+ label: str
26
+
27
+ def judge(self, system_prompt: str, user_prompt: str) -> Tuple[Dict[str, Any], Dict[str, int]]:
28
+ raise NotImplementedError
29
+
30
+ class OpenAIProvider(Provider):
31
+ def __init__(self, model: str):
32
+ super().__init__(ProviderKind.OPENAI, model, f"OpenAI/{model}")
33
+ if OpenAI is None:
34
+ raise RuntimeError("openai package not available. Add to requirements.txt")
35
+ self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
36
+
37
+ def judge(self, system_prompt: str, user_prompt: str):
38
+ # Use chat.completions for broader compatibility
39
+ resp = self.client.chat.completions.create(
40
+ model=self.model,
41
+ messages=[
42
+ {"role": "system", "content": system_prompt},
43
+ {"role": "user", "content": user_prompt}
44
+ ],
45
+ temperature=0
46
+ )
47
+ content = resp.choices[0].message.content
48
+ usage = resp.usage
49
+ usage_dict = {
50
+ "prompt": getattr(usage, "prompt_tokens", 0) or getattr(usage, "input_tokens", 0) or 0,
51
+ "completion": getattr(usage, "completion_tokens", 0) or getattr(usage, "output_tokens", 0) or 0,
52
+ "total": getattr(usage, "total_tokens", 0) or 0,
53
+ }
54
+ return _safe_json(content), usage_dict
55
+
56
+ class AnthropicProvider(Provider):
57
+ def __init__(self, model: str):
58
+ super().__init__(ProviderKind.ANTHROPIC, model, f"Anthropic/{model}")
59
+ if anthropic is None:
60
+ raise RuntimeError("anthropic package not available. Add to requirements.txt")
61
+ self.client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
62
+
63
+ def judge(self, system_prompt: str, user_prompt: str):
64
+ resp = self.client.messages.create(
65
+ model=self.model,
66
+ max_tokens=1000,
67
+ temperature=0,
68
+ system=system_prompt,
69
+ messages=[{"role": "user", "content": user_prompt}],
70
+ )
71
+ # Anthropic returns a list of content blocks
72
+ content = "".join([b.text for b in resp.content if hasattr(b, "text")])
73
+ usage = getattr(resp, "usage", None)
74
+ usage_dict = {
75
+ "prompt": getattr(usage, "input_tokens", 0) if usage else 0,
76
+ "completion": getattr(usage, "output_tokens", 0) if usage else 0,
77
+ "total": (getattr(usage, "input_tokens", 0) + getattr(usage, "output_tokens", 0)) if usage else 0,
78
+ }
79
+ return _safe_json(content), usage_dict
80
+
81
+
82
+ def get_provider(kind: str, model: str) -> Provider:
83
+ if kind == ProviderKind.OPENAI:
84
+ return OpenAIProvider(model)
85
+ elif kind == ProviderKind.ANTHROPIC:
86
+ return AnthropicProvider(model)
87
+ else:
88
+ raise ValueError(f"Unknown provider kind: {kind}")
89
+
90
+ # ----------------------
91
+ # Helpers
92
+ # ----------------------
93
+ JSON_RE = re.compile(r"\{[\s\S]*\}")
94
+
95
+ def _safe_json(s: str) -> dict:
96
+ try:
97
+ return json.loads(s)
98
+ except Exception:
99
+ m = JSON_RE.search(s or "")
100
+ if m:
101
+ try:
102
+ return json.loads(m.group(0))
103
+ except Exception:
104
+ pass
105
+ return {}