File size: 2,938 Bytes
6acbc8e
70629ed
3164c55
6acbc8e
 
8ebabda
e9b2bc1
 
 
 
 
 
3164c55
e9b2bc1
cf6722a
3164c55
 
6acbc8e
 
 
f82fb89
cf6722a
6acbc8e
 
cf6722a
f82fb89
cf6722a
 
 
 
 
 
e9b2bc1
cf6722a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
937738a
 
 
 
 
 
 
 
 
 
 
 
6acbc8e
 
 
 
f82fb89
 
 
cf6722a
 
 
 
 
 
 
6acbc8e
 
 
 
 
 
 
 
cf6722a
263c778
6acbc8e
 
 
 
3164c55
 
 
70629ed
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from fastapi import APIRouter, HTTPException
from .utils.PoeBot import SendMessage, GenerateImage
from .Schemas import BotRequest
from aiohttp import ClientSession
from pydantic import BaseModel
import asyncio
from ballyregan.models import Protocols, Anonymities
from ballyregan import ProxyFetcher

# Router that groups the chat / image / prediction endpoints below under the
# "Chat" tag.


chat_router = APIRouter(tags=["Chat"])
# Sticky proxy: the last proxy URL fetch_predictions got a usable response
# through; "" when none has worked (or after all proxies failed).
proxy: str = ""
# Candidate proxy URLs, reassigned by get_predictions on every /predictions
# request. This is module-level state shared across requests, so concurrent
# calls may observe each other's list.
proxies: list[str] = []


class InputData(BaseModel):
    # Request body for POST /predictions.
    # Forwarded verbatim as the "input" field of the upstream prediction payload.
    input: dict
    # Sent as the payload's "version"; presumably a Replicate model version id — confirm.
    version: str = "727e49a643e999d602a896c774a0658ffefea21465756a6ce24b7ea4165eba6a"
    # If non-empty, replaces the module-level proxy list; when empty,
    # get_predictions falls back to its built-in default proxies.
    proxies: list[str] = []
    # When True, fetch_predictions tries the proxy list instead of a direct call.
    is_proxied: bool = False


async def fetch_predictions(data, is_proxied=False):
    """POST *data* to Replicate's predictions endpoint and return the decoded JSON.

    Args:
        data: JSON-serialisable payload sent as the request body.
        is_proxied: If true, send the request through the URLs in the
            module-level ``proxies`` list, one at a time, trying the
            module-level sticky ``proxy`` (the last one that worked) first.

    Returns:
        The decoded JSON response, or ``None`` if the direct request failed
        or, when proxied, every proxy failed. Errors are printed, not raised.
    """
    global proxy
    url = "https://replicate.com/api/predictions"

    if not is_proxied:
        async with ClientSession() as session:
            try:
                async with session.post(url, json=data, timeout=5) as response:
                    return await response.json()
            except Exception as e:
                print("Error fetching", e)
                return None

    # Stable sort: the sticky proxy (if it is still in the list) goes first and
    # the rest keep their order. A dead or dropped sticky proxy must not stop
    # the remaining candidates from being tried.
    candidates = sorted(proxies, key=lambda p: p != proxy)
    async with ClientSession() as session:
        for p in candidates:
            try:
                # proxy=p is what actually routes this request through the proxy.
                async with session.post(
                    url, json=data, proxy=p, timeout=5
                ) as response:
                    # Any error status (a 4xx from Replicate, or e.g. a 502 from
                    # a broken proxy) counts as a failed attempt.
                    if response.status >= 400:
                        continue
                    result = await response.json()
            except Exception as e:
                print("Error fetching", e)
                continue
            # Remember the proxy only once its response was actually decoded.
            proxy = str(p)
            return result
    proxy = ""
    return None


@chat_router.post("/predictions")
async def get_predictions(input_data: InputData):
    """Create a Replicate prediction from the request body and return the upstream JSON.

    Side effect: replaces the module-level ``proxies`` list with the
    caller-supplied proxies, or with a built-in fallback list when none are
    given.

    Raises:
        HTTPException: 502 when the upstream call failed on every attempt
            (``fetch_predictions`` returned ``None``); 500 on any unexpected
            error raised while fetching.
    """
    global proxies
    # An empty list means "use the defaults", not "use no proxies".
    proxies = input_data.proxies or [
        "http://51.89.14.70:80",
        "http://52.151.210.204:9002",
        "http://38.180.36.19:80",
    ]
    data = {
        "input": input_data.input,
        "is_training": False,
        "create_model": "0",
        "stream": False,
        "version": input_data.version,
    }
    try:
        predictions = await fetch_predictions(data, input_data.is_proxied)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Internal Server Error: {str(e)}"
        ) from e
    if predictions is None:
        # fetch_predictions swallows network/proxy errors and returns None;
        # report that as a gateway failure instead of a 200 with a null body.
        # Raised outside the try so it is not re-wrapped as a 500.
        raise HTTPException(
            status_code=502, detail="Upstream prediction request failed"
        )
    return predictions


@chat_router.post("/chat")
async def chat(req: BotRequest):
    """Delegate *req* to ``PoeBot.SendMessage`` and return whatever it produces."""
    reply = await SendMessage(req)
    return reply


@chat_router.post("/generate_image")
# Must not reuse the name `chat`: that would rebind the /chat handler's
# module-level name and give both routes the same route name.
async def generate_image(req: BotRequest):
    """Delegate *req* to ``PoeBot.GenerateImage`` and return whatever it produces."""
    return await GenerateImage(req)