# AWS-CHATBOOT-SUPER / calculate_elo.py
# jordonpeter01's picture
# Duplicate from jordonpeter01/rlhf-arena-aws
# d9076ee
import logging
import os
from datetime import datetime
from decimal import Decimal
from typing import List
import boto3
from boto3.dynamodb.conditions import Attr, Key
from datasets import Dataset
# Root logger verbosity comes from the LOG_LEVEL environment variable (default "INFO").
logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))

# Module-wide DynamoDB resource shared by every table helper below;
# the region is fixed to us-east-1.
dynamodb = boto3.resource('dynamodb', region_name='us-east-1')
def _create_arena_table():
    """Create the ``oaaic_chatbot_arena`` DynamoDB table.

    The table is keyed by ``arena_battle_id`` (string hash key) and carries a
    global secondary index ``TimestampIndex`` on
    (``arena_battle_id``, ``timestamp``) that projects all attributes.
    Both the table and the index are provisioned at 5 read / 5 write units.

    One-off setup helper; ``main()`` in this file does not call it.
    """
    battle_id_attr = 'arena_battle_id'
    timestamp_attr = 'timestamp'

    def _throughput():
        # Fresh dict per use so the table and the index never share one object.
        return {'ReadCapacityUnits': 5, 'WriteCapacityUnits': 5}

    hash_key = {'AttributeName': battle_id_attr, 'KeyType': 'HASH'}
    range_key = {'AttributeName': timestamp_attr, 'KeyType': 'RANGE'}

    timestamp_index = {
        'IndexName': 'TimestampIndex',
        'KeySchema': [dict(hash_key), range_key],
        'Projection': {'ProjectionType': 'ALL'},
        'ProvisionedThroughput': _throughput(),
    }

    dynamodb.create_table(
        TableName='oaaic_chatbot_arena',
        KeySchema=[hash_key],
        # Every attribute used in a key schema (table or index) must be declared.
        AttributeDefinitions=[
            {'AttributeName': battle_id_attr, 'AttributeType': 'S'},
            {'AttributeName': timestamp_attr, 'AttributeType': 'S'},
        ],
        ProvisionedThroughput=_throughput(),
        GlobalSecondaryIndexes=[timestamp_index],
    )
def _create_elo_scores_table():
    """Create the ``elo_scores`` DynamoDB table.

    The table has a single string partition key, ``chatbot_name``, and is
    provisioned at 5 read / 5 write capacity units.

    One-off setup helper; ``main()`` in this file does not call it.
    """
    name_attr = 'chatbot_name'
    partition_key = {'AttributeName': name_attr, 'KeyType': 'HASH'}
    name_definition = {'AttributeName': name_attr, 'AttributeType': 'S'}
    capacity = {'ReadCapacityUnits': 5, 'WriteCapacityUnits': 5}

    dynamodb.create_table(
        TableName='elo_scores',
        KeySchema=[partition_key],
        AttributeDefinitions=[name_definition],
        ProvisionedThroughput=capacity,
    )
def _create_elo_logs_table():
    """Create the ``elo_logs`` DynamoDB table.

    Primary key: ``arena_battle_id`` (partition) + ``battle_timestamp`` (sort).
    A global secondary index ``AllTimestampIndex`` keyed on
    (``all``, ``battle_timestamp``) projects every attribute; it is the index
    ``get_last_processed_timestamp`` queries. Table and index are both
    provisioned at 10 read / 10 write capacity units.

    One-off setup helper; ``main()`` in this file does not call it.
    """
    def _string_attr(name):
        # All three key attributes are declared as DynamoDB strings ('S').
        return {'AttributeName': name, 'AttributeType': 'S'}

    def _key(name, key_type):
        return {'AttributeName': name, 'KeyType': key_type}

    def _capacity():
        return {'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10}

    all_timestamp_index = {
        'IndexName': 'AllTimestampIndex',
        'KeySchema': [
            _key('all', 'HASH'),                 # partition key for the GSI
            _key('battle_timestamp', 'RANGE'),   # sort key for the GSI
        ],
        'Projection': {'ProjectionType': 'ALL'},
        'ProvisionedThroughput': _capacity(),
    }

    dynamodb.create_table(
        TableName='elo_logs',
        KeySchema=[
            _key('arena_battle_id', 'HASH'),     # partition key
            _key('battle_timestamp', 'RANGE'),   # sort key
        ],
        AttributeDefinitions=[
            _string_attr('arena_battle_id'),
            _string_attr('battle_timestamp'),
            _string_attr('all'),
        ],
        ProvisionedThroughput=_capacity(),
        GlobalSecondaryIndexes=[all_timestamp_index],
    )
def get_unprocessed_battles(last_processed_timestamp):
    """Return every arena battle whose ``timestamp`` is after the given value.

    The ``oaaic_chatbot_arena`` table is scanned with a filter on
    ``timestamp``. A single ``scan`` call returns only one page of results
    and sets ``LastEvaluatedKey`` when more remain, so the scan is repeated
    with ``ExclusiveStartKey`` until the whole table has been covered.
    Without that loop, battles beyond the first page would never be rated.

    Args:
        last_processed_timestamp: Timestamp string; only battles with a
            strictly greater ``timestamp`` attribute are returned.

    Returns:
        A list of battle items (dicts). A scan gives no ordering guarantee,
        so callers that need temporal order must sort the result themselves.
    """
    table = dynamodb.Table('oaaic_chatbot_arena')
    scan_kwargs = {
        'FilterExpression': Attr('timestamp').gt(last_processed_timestamp),
    }

    battles = []
    while True:
        response = table.scan(**scan_kwargs)
        battles.extend(response['Items'])

        # The filter is applied after a page is read, so a page can come back
        # empty while LastEvaluatedKey still signals more data to read.
        last_evaluated_key = response.get('LastEvaluatedKey')
        if not last_evaluated_key:
            break
        scan_kwargs['ExclusiveStartKey'] = last_evaluated_key

    return battles
def calculate_elo(rating1, rating2, result, K=32):
    """Apply one Elo update and return both new ratings.

    Args:
        rating1: Current rating of the first player; anything ``float()`` accepts.
        rating2: Current rating of the second player; anything ``float()`` accepts.
        result: Score of the first player (1 = win, 0.5 = draw, 0 = loss).
        K: Scaling factor applied to the difference between the actual and
            the expected score.

    Returns:
        A ``(new_rating1, new_rating2)`` tuple of ``Decimal`` values quantized
        to two decimal places.
    """
    first = float(rating1)
    second = float(rating2)

    # Expected score of the first player on the usual 400-point logistic
    # scale; the second player's expectation is the complement.
    expected_first = 1.0 / (1.0 + 10.0 ** ((second - first) / 400.0))
    expected_second = 1.0 - expected_first

    updated_first = first + K * (result - expected_first)
    updated_second = second + K * ((1.0 - result) - expected_second)

    two_places = Decimal('0.00')
    new_rating1 = Decimal(updated_first).quantize(two_places)
    new_rating2 = Decimal(updated_second).quantize(two_places)
    return new_rating1, new_rating2
def get_last_processed_timestamp():
    """Return the ``battle_timestamp`` of the newest entry in ``elo_logs``.

    Queries the ``AllTimestampIndex`` index for the partition ``all == 'ALL'``
    in descending sort-key order and takes the first item.

    Returns:
        The ``battle_timestamp`` of that item, or ``'1970-01-01T00:00:00'``
        when the query yields no items.
    """
    logs_table = dynamodb.Table('elo_logs')

    # ScanIndexForward=False walks the sort key (battle_timestamp) newest-first,
    # so a limit of one yields the latest logged battle.
    newest_items = logs_table.query(
        IndexName='AllTimestampIndex',
        KeyConditionExpression=Key('all').eq('ALL'),
        ScanIndexForward=False,
        Limit=1,
    )['Items']

    if newest_items:
        return newest_items[0]['battle_timestamp']

    # Nothing logged yet: fall back to a timestamp that precedes every battle.
    return '1970-01-01T00:00:00'
def log_elo_update(arena_battle_id, battle_timestamp, new_rating1, new_rating2):
    """Write one rating-update record to the ``elo_logs`` table.

    Args:
        arena_battle_id: Identifier of the battle (partition key).
        battle_timestamp: Timestamp of the battle itself (sort key).
        new_rating1: Updated rating of the first chatbot.
        new_rating2: Updated rating of the second chatbot.
    """
    logs_table = dynamodb.Table('elo_logs')

    log_entry = {
        'arena_battle_id': arena_battle_id,
        'battle_timestamp': battle_timestamp,
        # When this log row was written, as opposed to when the battle happened.
        'log_timestamp': datetime.now().isoformat(),
        'new_rating1': new_rating1,
        'new_rating2': new_rating2,
        # Constant partition value for the AllTimestampIndex lookup.
        'all': 'ALL',
    }
    logs_table.put_item(Item=log_entry)
def get_elo_score(chatbot_name, elo_scores):
    """Look up a chatbot's current Elo rating.

    Args:
        chatbot_name: Name of the chatbot (key of the ``elo_scores`` table).
        elo_scores: Mapping of already-known ratings; consulted first so the
            table is only read for names it does not contain.

    Returns:
        The cached rating if present, otherwise the ``elo_score`` stored in
        DynamoDB, otherwise the default starting rating of 1500.
    """
    if chatbot_name not in elo_scores:
        stored = dynamodb.Table('elo_scores').get_item(
            Key={'chatbot_name': chatbot_name}
        )
        # get_item omits 'Item' when no row exists for the key.
        return stored['Item']['elo_score'] if 'Item' in stored else 1500

    return elo_scores[chatbot_name]
def update_elo_score(chatbot_name, new_elo_score):
    """Store a chatbot's rating in the ``elo_scores`` table.

    ``put_item`` overwrites an existing row or creates a new one.

    Args:
        chatbot_name: Name of the chatbot (partition key).
        new_elo_score: New rating; it is converted via ``str`` to ``Decimal``
            before being written.
    """
    scores_table = dynamodb.Table('elo_scores')
    record = {
        'chatbot_name': chatbot_name,
        'elo_score': Decimal(str(new_elo_score)),
    }
    scores_table.put_item(Item=record)
def get_elo_scores():
    """Return every item stored in the ``elo_scores`` table.

    A single ``scan`` call returns only one page of results and sets
    ``LastEvaluatedKey`` when more remain. The scan is therefore repeated
    with ``ExclusiveStartKey`` until the table is exhausted, so the result
    is complete regardless of table size.

    Returns:
        A list of item dicts as returned by DynamoDB.
    """
    table = dynamodb.Table('elo_scores')

    scan_kwargs = {}
    data = []
    while True:
        response = table.scan(**scan_kwargs)
        data.extend(response['Items'])

        last_evaluated_key = response.get('LastEvaluatedKey')
        if not last_evaluated_key:
            break
        scan_kwargs['ExclusiveStartKey'] = last_evaluated_key

    return data
def _backfill_logs():
    """Set ``all = 'ALL'`` on every existing item in the ``elo_logs`` table.

    ``AllTimestampIndex`` is keyed on the ``all`` attribute, so rows without
    it are invisible to that index; this helper adds the attribute to every
    row.

    The scan follows ``LastEvaluatedKey`` across pages. A single ``scan``
    call returns only one page, which would leave later rows un-backfilled.
    """
    table = dynamodb.Table('elo_logs')

    scan_kwargs = {}
    while True:
        response = table.scan(**scan_kwargs)

        for item in response['Items']:
            table.update_item(
                Key={
                    'arena_battle_id': item['arena_battle_id'],
                    'battle_timestamp': item['battle_timestamp'],
                },
                # '#all' placeholder keeps the attribute name out of the
                # expression text (ALL is a DynamoDB reserved word).
                UpdateExpression="SET #all = :value",
                ExpressionAttributeNames={'#all': 'all'},
                ExpressionAttributeValues={':value': 'ALL'},
            )

        last_evaluated_key = response.get('LastEvaluatedKey')
        if not last_evaluated_key:
            break
        scan_kwargs['ExclusiveStartKey'] = last_evaluated_key
def main():
    """Rate all not-yet-processed battles and publish the updated leaderboard.

    Steps:
      1. Find the timestamp of the last battle recorded in ``elo_logs``.
      2. Fetch the newer battles and process them in ascending ``timestamp``
         order, since each Elo update depends on the ratings left by the
         previous one.
      3. For each battle with a supported label, compute the new ratings,
         write a log row and persist both chatbots' scores.
      4. Read back the full score table, print it, and, if any battles were
         fetched, push it to the Hugging Face Hub dataset.
    """
    last_processed_timestamp = get_last_processed_timestamp()
    battles: List[dict] = get_unprocessed_battles(last_processed_timestamp)
    battles = sorted(battles, key=lambda x: x['timestamp'])

    # In-run cache of current ratings so each chatbot is read from DynamoDB
    # at most once, and later battles see the results of earlier ones.
    ratings = {}

    for battle in battles:
        print(repr(battle))

        outcome = battle['label']
        if outcome not in {-1, 0, 1, 2}:
            # Unsupported label: the battle is neither rated nor logged.
            continue

        first_name = battle['choice1_name']
        second_name = battle['choice2_name']

        for chatbot_name in (first_name, second_name):
            if chatbot_name not in ratings:
                ratings[chatbot_name] = get_elo_score(chatbot_name, ratings)

        # Map the label to the first chatbot's Elo score:
        # 1: the first player (or team) won the match.
        # 0.5: the match ended in a draw.
        # 0: the first player (or team) lost the match.
        if outcome == 0 or outcome == -1:
            elo_result = 0.5
        elif outcome == 1:
            elo_result = 1
        else:
            elo_result = 0

        new_rating1, new_rating2 = calculate_elo(ratings[first_name], ratings[second_name], elo_result)
        logging.info(f"{battle['choice1_name']}: {ratings[first_name]} -> {new_rating1} | {battle['choice2_name']}: {ratings[second_name]} -> {new_rating2}")

        # Update the cache once, in the original order (if both names are the
        # same chatbot, the second rating wins), then persist the results.
        ratings[first_name] = new_rating1
        ratings[second_name] = new_rating2
        log_elo_update(battle['arena_battle_id'], battle['timestamp'], new_rating1, new_rating2)
        update_elo_score(first_name, new_rating1)
        update_elo_score(second_name, new_rating2)

    # Fresh read of the whole leaderboard; update_elo_score stores scores as
    # Decimal, so convert them to float before printing and export.
    leaderboard = get_elo_scores()
    for row in leaderboard:
        row["elo_score"] = float(row["elo_score"])
    print(leaderboard)

    if battles:
        # Convert the data into a format suitable for Hugging Face Dataset
        elo_dataset = Dataset.from_list(leaderboard)
        elo_dataset.push_to_hub("openaccess-ai-collective/chatbot-arena-elo-scores", private=False)
# Run the Elo update only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()