import pickle
import time

from grazie.api.client.chat.prompt import ChatPrompt
from grazie.api.client.endpoints import GrazieApiGatewayUrls
from grazie.api.client.gateway import GrazieApiGatewayClient, GrazieAgent, AuthType
from grazie.api.client.profiles import LLMProfile

import config
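
# Grazie API gateway client; all LLM requests below are sent through it.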
client = GrazieApiGatewayClient(
grazie_agent=GrazieAgent(name="commit-rewriting-synthetic-end-to-start", version="dev"),
url=GrazieApiGatewayUrls.STAGING,
auth_type=AuthType.USER,
grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN
)
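
# On-disk cache of LLM responses for the configured model:
# LLM_CACHE maps prompt -> list of generated responses,
# LLM_CACHE_USED maps prompt -> number of responses already consumed.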
LLM_CACHE_FILE = config.CACHE_DIR / f"{config.LLM_MODEL}.cache.pkl"
LLM_CACHE = {}
LLM_CACHE_USED = {}
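
# Create an empty cache file on first run, then load the persisted cache.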
if not LLM_CACHE_FILE.exists():
with open(LLM_CACHE_FILE, "wb") as file:
pickle.dump(obj=LLM_CACHE, file=file)
with open(LLM_CACHE_FILE, "rb") as file:
LLM_CACHE = pickle.load(file=file)


def llm_request(prompt):
    """Send `prompt` to the LLM and return its response, retrying on API errors."""
    output = None
    while output is None:
        try:
            output = client.chat(
                chat=ChatPrompt()
                .add_system("You are a helpful assistant.")
                .add_user(prompt),
                profile=LLMProfile(config.LLM_MODEL)
            ).content
        except Exception:
            # Wait before retrying, e.g. after a rate limit or a transient network error.
            time.sleep(config.GRAZIE_TIMEOUT_SEC)
    return output


def generate_for_prompt(prompt):
    """Return the next unused cached response for `prompt`, querying the LLM on a cache miss."""
    if prompt not in LLM_CACHE:
        LLM_CACHE[prompt] = []
    if prompt not in LLM_CACHE_USED:
        LLM_CACHE_USED[prompt] = 0
    # Generate (and persist) responses until an unused one is available for this prompt.
    while LLM_CACHE_USED[prompt] >= len(LLM_CACHE[prompt]):
        new_response = llm_request(prompt)
        LLM_CACHE[prompt].append(new_response)
        with open(LLM_CACHE_FILE, "wb") as file:
            pickle.dump(obj=LLM_CACHE, file=file)
    result = LLM_CACHE[prompt][LLM_CACHE_USED[prompt]]
    LLM_CACHE_USED[prompt] += 1
    return result
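

# A minimal usage sketch (an illustration, not part of the original module):
# repeated calls with the same prompt first consume cached responses, then
# trigger new LLM requests that are appended to the persisted cache.
if __name__ == "__main__":
    first = generate_for_prompt("Summarize this diff in one line.")
    second = generate_for_prompt("Summarize this diff in one line.")
    print(first)
    print(second)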