# Spaces: Runtime error
import asyncio
import logging
import os
import threading
from typing import Optional

import httpx
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# Initialize FastAPI
app = FastAPI()

# Logging setup
logging.basicConfig(level=logging.INFO)

# Telegram credentials come from the environment (e.g. deployment secrets)
# instead of being committed to source control. The bot token that used to be
# hard-coded here was stored in plain text and should be treated as
# compromised: revoke it and issue a new one via @BotFather.
TELEGRAM_TOKEN = os.environ.get("TELEGRAM_TOKEN", "")
# The only chat whose updates are answered, and the chat replies are sent to.
CHAT_ID = os.environ.get("TELEGRAM_CHAT_ID", "-1002228627548")
if not TELEGRAM_TOKEN:
    logging.warning("TELEGRAM_TOKEN is not set; Telegram replies will fail.")

# Templating setup: HTML templates are looked up in ./templates
templates = Jinja2Templates(directory="templates")
# WebSocket Manager
class WebSocketManager:
    """Remember the most recently accepted WebSocket and send text over it.

    Only one "active" connection is stored: accepting a new socket replaces
    the previous one. ``send_message`` and ``disconnect`` therefore take an
    optional explicit ``websocket`` so a handler serving several clients can
    target its own connection rather than whichever client connected last.
    """

    def __init__(self):
        # Most recently accepted connection, or None when nothing is connected.
        self.active_connection: Optional[WebSocket] = None

    async def connect(self, websocket: WebSocket):
        """Accept *websocket* and make it the active connection."""
        await websocket.accept()
        self.active_connection = websocket
        logging.info("WebSocket connected.")

    async def disconnect(self, websocket: Optional[WebSocket] = None):
        """Close *websocket* (default: the active connection) and forget it.

        Closing is best-effort: if the peer has already gone away (the usual
        situation after ``WebSocketDisconnect``), sending the close frame can
        fail, and that failure is logged instead of propagated.
        """
        target = websocket if websocket is not None else self.active_connection
        if target is None:
            return
        try:
            await target.close()
        except (RuntimeError, OSError, WebSocketDisconnect) as exc:
            logging.debug(f"Ignoring error while closing WebSocket: {exc!r}")
        # Clear the active slot only if it still points at the socket we just
        # closed; a newer client may have replaced it in the meantime.
        if self.active_connection is target:
            self.active_connection = None
        logging.info("WebSocket disconnected.")

    async def send_message(self, message: str, websocket: Optional[WebSocket] = None):
        """Send *message* as a text frame on *websocket* (default: the active one).

        Does nothing when there is no connection to send on.
        """
        target = websocket if websocket is not None else self.active_connection
        if target is None:
            return
        await target.send_text(message)
        logging.info(f"Sent via WebSocket: {message}")


websocket_manager = WebSocketManager()
# BLOOM Model Manager
class BloomAI:
    """Async-friendly wrapper around a bigscience/bloom-560m text-generation pipeline."""

    def __init__(self):
        # "cuda" when torch can see a GPU, otherwise "cpu".
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # transformers pipeline; stays None until load_model() has run.
        self.pipeline = None
        # Generation runs in worker threads (see generate_response) while the
        # pipeline and tokenizer are shared objects that are not assumed to be
        # thread-safe, so calls are serialized -- the same one-at-a-time
        # behaviour the event loop imposed when generation ran inline.
        self._generation_lock = threading.Lock()

    def load_model(self):
        """Load bigscience/bloom-560m and build the text-generation pipeline."""
        logging.info("Loading BLOOM model...")
        tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
        model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
        self.pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            # Reuse the device picked in __init__: 0 = first GPU, -1 = CPU.
            device=0 if self.device.type == "cuda" else -1,
        )
        logging.info("BLOOM model loaded successfully.")

    async def generate_response(self, prompt: str) -> str:
        """Return a sampled BLOOM continuation of *prompt*, stripped of outer whitespace.

        A blank prompt short-circuits to a warning string without touching the
        model.

        Raises:
            RuntimeError: if load_model() has not been called yet.
        """
        if not prompt.strip():
            return "⚠️ Please send a valid message."
        if self.pipeline is None:
            raise RuntimeError("BLOOM model is not loaded; call load_model() first.")
        logging.info(f"Generating response for prompt: {prompt}")
        generator = self.pipeline

        def _generate():
            with self._generation_lock:
                return generator(
                    prompt,
                    # max_length bounds prompt + generated tokens together.
                    max_length=100,
                    do_sample=True,
                    temperature=0.7,
                    top_k=50,
                    top_p=0.9,
                    num_return_sequences=1,
                    no_repeat_ngram_size=2,
                )

        # The pipeline call is synchronous and compute-bound; running it in the
        # default executor keeps the event loop free to serve other requests
        # and WebSocket frames while a reply is being generated.
        loop = asyncio.get_running_loop()
        outputs = await loop.run_in_executor(None, _generate)
        response = outputs[0]["generated_text"]
        return response.strip()
# Initialize BLOOM
def _create_bloom_ai() -> BloomAI:
    """Build a BloomAI wrapper and load its model weights before returning it."""
    instance = BloomAI()
    instance.load_model()
    return instance


# One shared instance, used by both the Telegram webhook and the WebSocket
# endpoint; the model is loaded eagerly, at import time.
bloom_ai = _create_bloom_ai()
# Telegram Message Handling
async def send_telegram_message(text: str, chat_id: Optional[str] = None):
    """Send *text* to a Telegram chat via the Bot API ``sendMessage`` method.

    Args:
        text: Message body to send.
        chat_id: Destination chat; defaults to the configured ``CHAT_ID``.

    Delivery problems are logged, never raised: a non-200 reply logs the
    response body, and transport errors (timeouts, connection failures) are
    caught so a webhook handler awaiting this call still returns normally.
    """
    target_chat = CHAT_ID if chat_id is None else chat_id
    url = f"https://api.telegram.org/bot{TELEGRAM_TOKEN}/sendMessage"
    payload = {"chat_id": target_chat, "text": text}
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(url, json=payload)
    except httpx.HTTPError as exc:
        # Log only the exception type: the request URL embeds the bot token.
        logging.error(f"Failed to send message to Telegram: {type(exc).__name__}")
        return
    if response.status_code == 200:
        logging.info(f"Sent to Telegram: {text}")
    else:
        logging.error(f"Failed to send message to Telegram: {response.text}")
async def telegram_webhook(update: dict):
    """Answer a Telegram update from the configured chat with a BLOOM reply.

    Returns:
        ``{"status": "Unauthorized"}`` for messages from any chat other than
        ``CHAT_ID``; ``{"status": "ok"}`` otherwise, including for updates
        that are ignored (no ``message``, or a message without text).

    NOTE(review): no route decorator is attached in this file, so this handler
    is not registered on ``app`` as written -- confirm where it is exposed.
    """
    message = update.get("message")
    if not isinstance(message, dict):
        return {"status": "ok"}
    chat_id = str(message["chat"]["id"])
    if chat_id != CHAT_ID:
        return {"status": "Unauthorized"}
    user_message = message.get("text")
    # Photos, stickers, member joins, etc. carry no "text". Indexing it directly
    # raised KeyError, i.e. a 500 response, which Telegram retries.
    if not isinstance(user_message, str):
        logging.info("Ignoring Telegram message without text.")
        return {"status": "ok"}
    logging.info(f"Received from Telegram: {user_message}")
    # Process the message and reply in the configured chat
    response = await bloom_ai.generate_response(user_message)
    await send_telegram_message(response)
    return {"status": "ok"}
# WebSocket Endpoint
# Registered at /ws, the path the fallback test page below connects to.
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one WebSocket client, answering each text frame with a BLOOM reply."""
    await websocket_manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            logging.info(f"Received from WebSocket: {data}")
            response = await bloom_ai.generate_response(data)
            # Reply on this client's own socket. The manager only remembers the
            # most recently connected socket, so routing through it would hand
            # this client's answer to whichever client connected last.
            await websocket.send_text(response)
            logging.info(f"Sent via WebSocket: {response}")
    except WebSocketDisconnect:
        # The client has already closed this socket, so there is nothing left to
        # close. Forget it only if it is still the manager's active connection;
        # a newer client may have replaced it.
        if websocket_manager.active_connection is websocket:
            websocket_manager.active_connection = None
        logging.info("WebSocket disconnected.")
# HTML Test UI
async def get_ui(request: Request):
    """Render the ``index.html`` page from the ``templates`` directory.

    The incoming request is placed in the template context under ``"request"``.

    NOTE(review): no route decorator is attached in this file, so this handler
    is not registered on ``app`` as written.
    """
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
# Simple UI (fallback in case templates folder is not available) | |
async def simple_ui(): | |
"""Fallback HTML for WebSocket Test""" | |
return HTMLResponse(content=""" | |
<!DOCTYPE html> | |
<html> | |
<head> | |
<title>WebSocket Test</title> | |
<script> | |
let ws = new WebSocket("ws://localhost:8000/ws"); | |
ws.onopen = () => { | |
console.log("WebSocket connection opened."); | |
}; | |
ws.onmessage = (event) => { | |
console.log("Message from server:", event.data); | |
const msgContainer = document.getElementById("messages"); | |
const msg = document.createElement("div"); | |
msg.innerText = event.data; | |
msgContainer.appendChild(msg); | |
}; | |
ws.onclose = () => { | |
console.log("WebSocket connection closed."); | |
}; | |
function sendMessage() { | |
const input = document.getElementById("messageInput"); | |
const message = input.value; | |
ws.send(message); | |
input.value = ""; | |
} | |
</script> | |
</head> | |
<body> | |
<h1>WebSocket Test</h1> | |
<div id="messages" style="border: 1px solid black; height: 200px; overflow-y: scroll;"></div> | |
<input id="messageInput" type="text" /> | |
<button onclick="sendMessage()">Send</button> | |
</body> | |
</html> | |
""") | |