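"""Provider registry for chat LLMs.

Exposes Google Gemini through a small hand-rolled REST client and wires up
Anthropic, OpenAI, and a local Ollama model through their LangChain wrappers.
API keys are loaded from a .env file via python-dotenv.
"""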
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from langchain_core.messages import BaseMessage, HumanMessage
from typing import Optional, Dict, List, Any
import os
import requests
from dotenv import load_dotenv
from dataclasses import dataclass

load_dotenv()

@dataclass
class GeminiResponse:
    """Lightweight wrapper mirroring the .content attribute of LangChain messages."""
    content: str

class GeminiProvider:
    """Minimal REST client for the Google Gemini generateContent endpoint."""

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.base_url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent"

    def chat(self, messages: List[Any]) -> GeminiResponse:
        # Convert messages (LangChain objects or role/content dicts) to Gemini format
        gemini_messages = []
        for msg in messages:
            if isinstance(msg, BaseMessage):
                role = "user" if isinstance(msg, HumanMessage) else "model"
                content = msg.content
            else:
                role = "user" if msg["role"] == "human" else "model"
                content = msg["content"]
            gemini_messages.append({
                "role": role,
                "parts": [{"text": content}],
            })

        # Prepare the request
        headers = {"Content-Type": "application/json"}
        params = {"key": self.api_key}
        data = {
            "contents": gemini_messages,
            "generationConfig": {
                "temperature": 0.7,
                "topP": 0.8,
                "topK": 40,
                "maxOutputTokens": 2048,
            },
        }

        try:
            response = requests.post(
                self.base_url,
                headers=headers,
                params=params,
                json=data,
                # NetFree SSL-inspection CA bundle; adjust or drop on other machines
                verify='C:\\ProgramData\\NetFree\\CA\\netfree-ca-bundle-curl.crt',
            )
            response.raise_for_status()
            result = response.json()
            if result.get("candidates"):
                return GeminiResponse(
                    content=result["candidates"][0]["content"]["parts"][0]["text"]
                )
            raise Exception("No response generated")
        except Exception as e:
            raise Exception(f"Error calling Gemini API: {e}") from e

    def invoke(self, messages: List[BaseMessage], **kwargs) -> GeminiResponse:
        # LangChain-style entry point so this class can stand in for a chat model
        return self.chat(messages)

    def generate(self, prompts, **kwargs) -> GeminiResponse:
        if isinstance(prompts, str):
            return self.invoke([HumanMessage(content=prompts)])
        elif isinstance(prompts, list):
            return self.invoke([HumanMessage(content=prompts[0])])
        raise ValueError("Unsupported prompt format")

class LLMProvider:
    """Registry of available chat model providers, keyed by display name."""

    def __init__(self):
        self.providers: Dict[str, Any] = {}
        self._setup_providers()

    def _setup_providers(self):
        # Route all requests traffic through the NetFree SSL-inspection CA bundle
        os.environ['REQUESTS_CA_BUNDLE'] = 'C:\\ProgramData\\NetFree\\CA\\netfree-ca-bundle-curl.crt'

        # Google Gemini
        if google_key := os.getenv('GOOGLE_API_KEY'):
            self.providers['Gemini'] = GeminiProvider(api_key=google_key)

        # Anthropic
        if anthropic_key := os.getenv('ANTHROPIC_API_KEY'):
            self.providers['Claude'] = ChatAnthropic(
                api_key=anthropic_key,
                model_name="claude-3-5-sonnet-20241022",
            )

        # OpenAI
        if openai_key := os.getenv('OPENAI_API_KEY'):
            self.providers['ChatGPT'] = ChatOpenAI(
                api_key=openai_key,
                model_name="gpt-4o-2024-11-20",
            )

        # Ollama (local); note the constructor does not contact the server,
        # so an unreachable Ollama only surfaces on first invoke
        try:
            self.providers['Ollama-dictalm2.0'] = ChatOllama(model="dictaLM")
        except Exception:
            pass  # Ollama not available

    def get_available_providers(self) -> List[str]:
        """Return list of available provider names."""
        return list(self.providers.keys())

    def get_provider(self, name: str) -> Optional[Any]:
        """Get LLM provider by name."""
        return self.providers.get(name)
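

# Minimal usage sketch (assumes at least one API key is present in .env;
# provider names match the keys registered in _setup_providers above):
if __name__ == "__main__":
    llm = LLMProvider()
    print("Available providers:", llm.get_available_providers())

    # Pick the first configured provider and send a single-turn prompt;
    # both GeminiResponse and LangChain messages expose .content
    if names := llm.get_available_providers():
        provider = llm.get_provider(names[0])
        reply = provider.invoke([HumanMessage(content="Hello! Who are you?")])
        print(reply.content)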