Really-amin's picture
Update app.py
56a6340 verified
raw
history blame
6.32 kB
import asyncio
import logging
import os
from typing import Optional

import httpx
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# Initialize FastAPI
app = FastAPI()

# Logging setup
logging.basicConfig(level=logging.INFO)

# Telegram credentials.
# The TELEGRAM_TOKEN / TELEGRAM_CHAT_ID environment variables take precedence
# so secrets can be supplied by the deployment instead of living in source.
# NOTE(review): the literal fallbacks are kept only for backward
# compatibility. A bot token committed to a repository should be treated as
# compromised: rotate it and remove the fallback.
TELEGRAM_TOKEN = os.environ.get(
    "TELEGRAM_TOKEN", "7437859619:AAGeGG3ZkLM0OVaw-Exx1uMRE55JtBCZZCY"
)
CHAT_ID = os.environ.get("TELEGRAM_CHAT_ID", "-1002228627548")

# Templating setup
templates = Jinja2Templates(directory="templates")
# WebSocket Manager
class WebSocketManager:
    """Tracks a single active WebSocket connection.

    Only one connection is remembered at a time: a later ``connect`` call
    overwrites the stored reference without closing the previous socket.
    """

    def __init__(self):
        # None whenever no client is currently tracked.
        self.active_connection: Optional["WebSocket"] = None

    async def connect(self, websocket: "WebSocket"):
        """Accept ``websocket`` and remember it as the active connection."""
        await websocket.accept()
        self.active_connection = websocket
        logging.info("WebSocket connected.")

    async def disconnect(self):
        """Close the active connection, if any, and forget it.

        Safe to call after the peer has already gone away: the stored
        reference is cleared even when closing the socket fails.
        """
        connection = self.active_connection
        if connection is None:
            return
        # Clear the reference first so a failing close() cannot leave a dead
        # socket registered as the active connection.
        self.active_connection = None
        try:
            await connection.close()
        except RuntimeError:
            # Closing a socket whose peer already disconnected can be
            # rejected by the ASGI server with RuntimeError; there is
            # nothing left to close in that case.
            logging.debug("WebSocket was already closed by the peer.")
        logging.info("WebSocket disconnected.")

    async def send_message(self, message: str):
        """Send ``message`` as text over the active connection.

        Does nothing (and logs nothing) when no connection is active.
        """
        connection = self.active_connection
        if connection is None:
            return
        await connection.send_text(message)
        logging.info(f"Sent via WebSocket: {message}")


websocket_manager = WebSocketManager()
# BLOOM Model Manager
class BloomAI:
    """Wraps a ``bigscience/bloom-560m`` text-generation pipeline."""

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Populated by load_model(); None until then.
        self.pipeline = None
        # Created lazily inside the running event loop (see generate_response).
        self._generation_lock = None

    def load_model(self):
        """Load the tokenizer and model and build the generation pipeline.

        The pipeline is placed on GPU 0 when CUDA is available, otherwise on
        the CPU (``device=-1``).
        """
        logging.info("Loading BLOOM model...")
        tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
        model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
        self.pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=0 if torch.cuda.is_available() else -1,
        )
        logging.info("BLOOM model loaded successfully.")

    async def generate_response(self, prompt: str) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: User text to continue.

        Returns:
            The stripped ``generated_text`` of the first returned sequence, or
            a fixed warning string when ``prompt`` is empty or whitespace.

        Raises:
            RuntimeError: If ``load_model()`` has not been called yet.
        """
        if not prompt or not prompt.strip():
            return "⚠️ Please send a valid message."
        if self.pipeline is None:
            raise RuntimeError("BLOOM model is not loaded; call load_model() first.")

        logging.info(f"Generating response for prompt: {prompt}")

        def _run_pipeline():
            # NOTE(review): max_length is passed through unchanged; per the
            # transformers generation docs it bounds prompt + generated
            # tokens together — confirm this is the intended limit.
            return self.pipeline(
                prompt,
                max_length=100,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.9,
                num_return_sequences=1,
                no_repeat_ngram_size=2,
            )

        # The pipeline call is synchronous; running it directly in this
        # coroutine would stall the event loop (and every other request) for
        # the whole generation. Run it in a worker thread instead, and
        # serialize calls so only one generation uses the pipeline at a time,
        # as before.
        if self._generation_lock is None:
            self._generation_lock = asyncio.Lock()
        loop = asyncio.get_running_loop()
        async with self._generation_lock:
            outputs = await loop.run_in_executor(None, _run_pipeline)

        return outputs[0]["generated_text"].strip()
# Initialize BLOOM
# Module-level instance used by the Telegram webhook and WebSocket handlers.
# load_model() is called here at import time, so the tokenizer and model are
# loaded before any request handler can run.
bloom_ai = BloomAI()
bloom_ai.load_model()
# Telegram Message Handling
async def send_telegram_message(text: str):
    """Send ``text`` to the configured Telegram chat via the Bot API.

    Delivery is best-effort: both transport failures and non-200 responses
    are logged and swallowed, so callers never see an exception from a failed
    send.

    Args:
        text: Message body posted to ``CHAT_ID``.
    """
    url = f"https://api.telegram.org/bot{TELEGRAM_TOKEN}/sendMessage"
    payload = {"chat_id": CHAT_ID, "text": text}

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(url, json=payload)
    except httpx.HTTPError as exc:
        # Log only the exception type: the request URL embeds the bot token,
        # and it must not end up in the logs via the exception message.
        logging.error(f"Failed to send message to Telegram: {type(exc).__name__}")
        return

    if response.status_code == 200:
        logging.info(f"Sent to Telegram: {text}")
    else:
        logging.error(f"Failed to send message to Telegram: {response.text}")
@app.post("/telegram")
async def telegram_webhook(update: dict):
    """Handle a Telegram webhook update.

    Text messages from the configured chat are answered with a generated
    reply. Updates without a ``message``, and messages without text, are
    acknowledged and ignored.

    Args:
        update: Decoded JSON body of the webhook request.

    Returns:
        ``{"status": "Unauthorized"}`` when the message comes from a chat
        other than ``CHAT_ID``; ``{"status": "ok"}`` otherwise.
    """
    message = update.get("message")
    if not isinstance(message, dict):
        return {"status": "ok"}

    chat = message.get("chat") or {}
    chat_id = str(chat.get("id"))
    if chat_id != CHAT_ID:
        return {"status": "Unauthorized"}

    user_message = message.get("text")
    if not isinstance(user_message, str):
        # Not every message carries a "text" field; indexing it directly
        # would raise KeyError and turn the webhook call into a server error.
        logging.info("Ignoring Telegram message without text.")
        return {"status": "ok"}

    logging.info(f"Received from Telegram: {user_message}")

    # Process the message
    response = await bloom_ai.generate_response(user_message)
    await send_telegram_message(response)
    return {"status": "ok"}
# WebSocket Endpoint
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one WebSocket client: reply to each text message with generated text.

    The loop runs until the client disconnects.
    """
    await websocket_manager.connect(websocket)
    try:
        while True:
            # Receive message from WebSocket
            data = await websocket.receive_text()
            logging.info(f"Received from WebSocket: {data}")

            # Process the message
            response = await bloom_ai.generate_response(data)

            # Reply on the socket that asked. The manager tracks only the
            # most recently connected client, so sending through it could
            # deliver this reply to a different client.
            await websocket.send_text(response)
            logging.info(f"Sent via WebSocket: {response}")
    except WebSocketDisconnect:
        if websocket_manager.active_connection is not websocket:
            # A newer client has replaced this one in the manager; leave the
            # newer connection untouched.
            logging.info("WebSocket disconnected.")
            return
        try:
            await websocket_manager.disconnect()
        except RuntimeError:
            # The peer is already gone, so closing the socket again can be
            # rejected by the ASGI server. Make sure the dead socket is not
            # left registered as the active connection.
            websocket_manager.active_connection = None
            logging.info("WebSocket disconnected.")
# HTML Test UI
@app.get("/")
async def get_ui(request: Request):
    """Render the ``index.html`` template for the WebSocket UI."""
    context = {"request": request}
    page = templates.TemplateResponse("index.html", context)
    return page
# Simple UI (fallback in case templates folder is not available)
@app.get("/simple-ui")
async def simple_ui():
    """Return a self-contained HTML page for testing the ``/ws`` endpoint.

    The page derives the WebSocket URL from the address it was loaded from
    (``wss`` when served over HTTPS), so it works on any host and port rather
    than only on ``localhost:8000``.
    """
    return HTMLResponse(content="""
<!DOCTYPE html>
<html>
<head>
<title>WebSocket Test</title>
<script>
const wsScheme = window.location.protocol === "https:" ? "wss" : "ws";
let ws = new WebSocket(wsScheme + "://" + window.location.host + "/ws");
ws.onopen = () => {
console.log("WebSocket connection opened.");
};
ws.onmessage = (event) => {
console.log("Message from server:", event.data);
const msgContainer = document.getElementById("messages");
const msg = document.createElement("div");
msg.innerText = event.data;
msgContainer.appendChild(msg);
};
ws.onclose = () => {
console.log("WebSocket connection closed.");
};
function sendMessage() {
const input = document.getElementById("messageInput");
const message = input.value;
ws.send(message);
input.value = "";
}
</script>
</head>
<body>
<h1>WebSocket Test</h1>
<div id="messages" style="border: 1px solid black; height: 200px; overflow-y: scroll;"></div>
<input id="messageInput" type="text" />
<button onclick="sendMessage()">Send</button>
</body>
</html>
""")