"""Generate a structured paper review by prompting an LLM one question at a time."""

import logging
import os
import re
import time
from typing import List, Optional, Union

import anthropic
import cohere
import google.generativeai as genai
import openai
import tiktoken
from openai import OpenAI

from file_utils import read_file

# First standalone run of digits in a model reply; used to extract scores.
_FIRST_INTEGER_RE = re.compile(r'\b\d+\b')


class Paper:
    """A paper to review: its arXiv identifier plus the value used as prompt text."""

    def __init__(self, arxiv_id, tex_file):
        self.arxiv_id = arxiv_id
        # Used verbatim as the base prompt (see PaperProcessor.prepare_base_prompt).
        # NOTE(review): presumably the LaTeX source text, not a path -- confirm.
        self.tex_file = tex_file


class PaperProcessor:
    """Ask a chat model a fixed series of review questions about a paper.

    Prompts are read from ``prompt_dir``: ``systemrole.txt`` supplies the
    system message and ``question1.txt`` .. ``question11.txt`` supply the
    individual review questions.
    """

    # Upper bound on prompt length, in tokens of ``encoding``.
    # NOTE(review): presumably sized to leave headroom below a 128k context
    # window -- confirm.
    MAX_TOKENS = 127192

    # Tokenizer used for all counting/truncation, whichever provider is called.
    # NOTE(review): counts are likely only approximate for non-OpenAI models.
    encoding = tiktoken.encoding_for_model("gpt-4-0125-preview")

    # OpenAI chat models that share one request shape in call_model.
    _OPENAI_CHAT_MODELS = ('gpt-4-turbo-2024-04-09', 'gpt-4o')

    # Section labels, in question order (question N -> _HEADERS[N - 1]).
    _HEADERS = (
        'Summary:', 'Soundness:', 'Presentation:', 'Contribution:',
        'Strengths:', 'Weaknesses:', 'Questions:', 'Flag For Ethics Review:',
        'Rating:', 'Confidence:', 'Code Of Conduct:',
    )

    # Question numbers whose answers are scores out of 5 / out of 10.
    _OUT_OF_FIVE_QUESTIONS = frozenset({2, 3, 4, 10})
    _OUT_OF_TEN_QUESTION = 9

    def __init__(self, prompt_dir, model, openai_api_key, claude_api_key, gemini_api_key, commandr_api_key):
        """Store the prompt directory, target model name and provider API keys."""
        self.prompt_dir = prompt_dir
        self.model = model
        self.openai_api_key = openai_api_key
        self.claude_api_key = claude_api_key
        self.gemini_api_key = gemini_api_key
        self.commandr_api_key = commandr_api_key

    def count_tokens(self, text: str) -> int:
        """Return the number of tokens in *text* under ``self.encoding``."""
        return len(self.encoding.encode(text))

    def truncate_content(self, content: str) -> str:
        """Return *content* cut down to at most ``MAX_TOKENS`` tokens.

        Content already within the limit is returned unchanged.
        """
        # Encode once and reuse the tokens for both the count and the slice.
        tokens = self.encoding.encode(content)
        logging.debug(f"Token count before truncation: {len(tokens)}")
        if len(tokens) <= self.MAX_TOKENS:
            return content

        truncated_content = self.encoding.decode(tokens[:self.MAX_TOKENS])
        logging.debug(f"Content truncated. Token count after truncation: {self.count_tokens(truncated_content)}")
        return truncated_content

    def prepare_base_prompt(self, paper):
        """Return the text that represents *paper* in prompts (its ``tex_file``)."""
        return paper.tex_file

    def call_model(self, prompt: str, model_type: str) -> Optional[str]:
        """Send *prompt* to the provider selected by *model_type*.

        The system message is read from ``systemrole.txt`` in ``prompt_dir``.

        Returns:
            The model's reply text, or ``None`` when the system-role file is
            missing, *model_type* is not supported, or the provider call
            raises (the error is logged, not propagated).
        """
        system_role_file_path = os.path.join(self.prompt_dir, "systemrole.txt")
        if not os.path.exists(system_role_file_path):
            logging.error(f"System role file not found: {system_role_file_path}")
            return None
        system_role = read_file(system_role_file_path)

        logging.debug(f"Token count of full prompt: {self.count_tokens(prompt)}")
        logging.info(f"Sending the following prompt to {model_type}: {prompt}")

        try:
            if model_type in self._OPENAI_CHAT_MODELS:
                client = OpenAI(api_key=self.openai_api_key)
                messages = [
                    {"role": "system", "content": system_role},
                    {"role": "user", "content": prompt},
                ]
                completion = client.chat.completions.create(
                    model=model_type,
                    messages=messages,
                    temperature=1,
                )
                print(completion)
                return completion.choices[0].message.content.strip()

            if model_type == 'claude-3-opus-20240229':
                client = anthropic.Anthropic(api_key=self.claude_api_key)
                response = client.messages.create(
                    model='claude-3-opus-20240229',
                    max_tokens=4096,
                    system=system_role,
                    temperature=0.5,
                    messages=[{"role": "user", "content": prompt}],
                )
                print(response)
                return response.content[0].text

            if model_type == 'command-r-plus':
                co = cohere.Client(self.commandr_api_key)
                response = co.chat(
                    model="command-r-plus",
                    message=prompt,
                    preamble=system_role,
                )
                print(response)
                return response.text

            if model_type == 'gemini-pro':
                # NOTE(review): the system role is not passed to Gemini here;
                # only the user prompt is sent. Confirm this is intended.
                genai.configure(api_key=self.gemini_api_key)
                model = genai.GenerativeModel('gemini-pro')
                response = model.generate_content(prompt)
                print(response)
                return response.candidates[0].content.parts[0].text
        except Exception as e:
            # Best-effort boundary: provider failures become a None result.
            logging.error(f"Exception occurred: {e}")
            print(e)
            return None

        # No branch matched: make the otherwise-silent None result visible.
        logging.error(f"Unsupported model type: {model_type}")
        return None

    def is_content_appropriate(self, content: str) -> bool:
        """Return ``False`` if OpenAI moderation flags *content*, else ``True``.

        Any error while calling the moderation endpoint is logged and treated
        as "appropriate".
        """
        try:
            # Use an explicit client so this works without process_paper having
            # set the module-level ``openai.api_key`` first.
            client = OpenAI(api_key=self.openai_api_key)
            response = client.moderations.create(input=content)
            # Attribute access, matching how process_paper reads the same
            # response; subscripting it raised and was swallowed below, so the
            # check always reported content as appropriate.
            return not response.results[0].flagged
        except Exception as e:
            logging.error(f"Exception occurred while checking content appropriateness: {e}")
            return True  # In case of an error, default to content being appropriate

    def get_prompt_files(self, prompt_dir) -> List[str]:
        """List file names in *prompt_dir* that look like ``question*.txt``."""
        return [
            name
            for name in os.listdir(prompt_dir)
            if name.startswith('question') and name.endswith('.txt')
        ]

    @staticmethod
    def _format_score(question_number: int, response: str) -> str:
        """Rewrite score answers as ``N/5`` or ``N/10``; pass others through.

        The first standalone integer in *response* is used. Out-of-five scores
        above 5 are clamped to ``5/5``. If no integer is found the response is
        returned unchanged.
        """
        is_out_of_five = question_number in PaperProcessor._OUT_OF_FIVE_QUESTIONS
        is_out_of_ten = question_number == PaperProcessor._OUT_OF_TEN_QUESTION
        if not (is_out_of_five or is_out_of_ten):
            return response

        number_match = _FIRST_INTEGER_RE.search(response)
        if not number_match:
            return response

        digits = number_match.group(0)
        if is_out_of_ten:
            return digits + '/10'
        return '5/5' if int(digits) > 5 else digits + '/5'

    def process_paper(self, paper) -> Union[str, List[str]]:
        """Run moderation and then the full question sequence for *paper*.

        Returns:
            * an error string if the base prompt is ``None``;
            * ``["Desk Rejected", <reason>]`` if moderation flags the paper;
            * otherwise a list of ``"<Header> <answer>"`` strings, one per
              question, in question order.

        Exceptions from the moderation call are not caught here.
        """
        openai.api_key = self.openai_api_key
        start_time = time.time()

        base_prompt = self.prepare_base_prompt(paper)
        if base_prompt is None:
            return "Error: Base prompt could not be prepared."

        moderation_response = openai.moderations.create(input=base_prompt)
        if moderation_response.results[0].flagged:
            return ["Desk Rejected", "The paper contains inappropriate or harmful content."]

        review_output = []
        previous_responses = []

        for i, header in enumerate(self._HEADERS, start=1):
            question_file = os.path.join(self.prompt_dir, f"question{i}.txt")
            question_text = read_file(question_file)

            if i == 1:
                # Only the first question is sent together with the paper text.
                prompt = f"{question_text}\n\n####\n{base_prompt}\n####"
            else:
                # Later questions are sent with the answers gathered so far.
                prompt = f"\nHere is your review so far:\n{' '.join(previous_responses)}\n\nHere are your reviewer instructions. Please answer the following question:\n{question_text}"

            truncated_prompt = self.truncate_content(prompt)
            logging.info(f"Processing prompt for question {i}")

            response = self.call_model(truncated_prompt, self.model)
            if response is None:
                response = "N/A"
            response = self._format_score(i, response)

            review_output.append(f"{header} {response}")
            previous_responses.append(response)

        elapsed_time = time.time() - start_time
        print(f"Time taken to process paper: {elapsed_time:.2f} seconds")
        return review_output