import os
import uuid
from typing import Dict, Optional
from fastapi import FastAPI, HTTPException
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
import torch
from pydantic import BaseModel
import traceback
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationChain
from langchain.prompts import PromptTemplate
from starlette.responses import StreamingResponse
import asyncio
import json
from langchain_community.llms import HuggingFacePipeline
import uvicorn
from huggingface_hub import login

app = FastAPI()

# Get the Hugging Face API token from environment variables (best practice).
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if HUGGINGFACEHUB_API_TOKEN is None:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")

# --- Explicitly log in to Hugging Face Hub ---
try:
    login(token=HUGGINGFACEHUB_API_TOKEN)
    print("Successfully logged into Hugging Face Hub.")
except Exception as e:
    print(f"Failed to log into Hugging Face Hub: {e}")

# --- Initialize tokenizer and model globally (heavy to load, shared across sessions) ---
model_id = "mistralai/Mistral-7B-Instruct-v0.3"

# --- Quantization configuration for 4-bit loading, tuned for a T4 GPU ---
# This tells Hugging Face Transformers to load the model weights in 4-bit
# precision via the bitsandbytes library.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,               # Enable 4-bit quantization
    bnb_4bit_quant_type="nf4",       # "nf4" (NormalFloat 4-bit) is recommended for transformer models
    # T4 GPUs (Turing architecture) have no native bfloat16 support, so use
    # float16 for computation; this is faster and avoids offloading work to the CPU.
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,  # Double quantization for slightly better quality
)

tokenizer = AutoTokenizer.from_pretrained(model_id, token=HUGGINGFACEHUB_API_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",               # Let Accelerate place layers; quantization keeps them on the GPU
    quantization_config=bnb_config,  # The compute dtype is set via bnb_4bit_compute_dtype above
    trust_remote_code=True,
    token=HUGGINGFACEHUB_API_TOKEN,
)
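
# Optional sanity check (an illustrative addition, not part of the original app):
# confirm the quantized model actually fits on the GPU instead of being offloaded.
print(f"Model memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
print(f"Device placement: {getattr(model, 'hf_device_map', 'n/a')}")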

# Global dictionary of active conversation chains, keyed by session_id.
# IMPORTANT: this in-memory store is lost whenever the server restarts. For true
# persistence you would use an external database (e.g., Redis, Firestore); a
# sketch of a Redis-backed alternative follows below.
active_conversations: Dict[str, ConversationChain] = {}
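
# For reference only, a minimal sketch of the persistent alternative mentioned above.
# It assumes a reachable Redis instance and the langchain_community
# RedisChatMessageHistory integration; the URL and helper name are illustrative,
# and nothing below calls this function.
def build_persistent_memory(session_id: str) -> ConversationBufferWindowMemory:
    from langchain_community.chat_message_histories import RedisChatMessageHistory

    history = RedisChatMessageHistory(
        session_id=session_id,
        url="redis://localhost:6379/0",  # assumed Redis URL
    )
    # Same 5-turn window as the per-session memory used below, but stored in
    # Redis so it survives server restarts.
    return ConversationBufferWindowMemory(k=5, chat_memory=history)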

# --- Prompt template ---
template = """<|im_start|>system
You are a concise and direct AI assistant named Siddhi.
You strictly avoid asking any follow-up questions.
You do not generate any additional conversational turns (e.g., "Human: ...").
If asked for your name, you respond with "I am Siddhi."
If you do not know the answer to a question, you truthfully say that you do not know.
<|im_end|>
<|im_start|>user
{history}
{input}<|im_end|>
<|im_start|>assistant
"""
PROMPT = PromptTemplate(input_variables=["history", "input"], template=template)


class QuestionRequest(BaseModel):
    question: str
    session_id: Optional[str] = None  # Optional session ID for continuing an existing conversation


class ChatResponse(BaseModel):
    response: str
    session_id: str  # Returned so the client can track the conversation


@app.post("/generate")  # Route path is an assumption; adjust to match your client.
async def generate_text(request: QuestionRequest):
    """
    Handles text generation requests, maintaining conversation history per session.
    """
    session_id = request.session_id

    # If no session_id is provided, generate a new one; this starts a new conversation.
    if session_id is None:
        session_id = str(uuid.uuid4())
        print(f"Starting new conversation with session_id: {session_id}")

    # Retrieve or create the ConversationChain for this session_id.
    if session_id not in active_conversations:
        print(f"Creating new ConversationChain for session_id: {session_id}")
        # Wrap the globally loaded model/tokenizer in a LangChain pipeline for this session.
        llm = HuggingFacePipeline(pipeline=pipeline(
            "text-generation",
            model=model,          # Use the globally loaded model
            tokenizer=tokenizer,  # Use the globally loaded tokenizer
            max_new_tokens=512,
            return_full_text=True,
            temperature=0.2,
            do_sample=True,
        ))
        # Memory for this specific session: remembers the last 5 human-AI interaction pairs.
        memory = ConversationBufferWindowMemory(k=5)
        conversation = ConversationChain(llm=llm, memory=memory, prompt=PROMPT, verbose=True)
        active_conversations[session_id] = conversation
    else:
        print(f"Continuing conversation for session_id: {session_id}")
        conversation = active_conversations[session_id]

    async def generate_stream():
        """
        Asynchronous generator that streams the response as newline-delimited JSON,
        one JSON object per line.
        """
        # Becomes True once we reach the AI's actual response in the generated text.
        started_streaming_ai_response = False
        try:
            # Send the session_id first so the client can pick it up immediately.
            yield json.dumps({"type": "session_info", "session_id": session_id}) + "\n"

            response_stream = conversation.stream({"input": request.question})
            stop_sequences_to_check = ["Human:", "AI:", "\nHuman:", "\nAI:", "<|im_end|>"]
            assistant_start_marker = "<|im_start|>assistant\n"

            for chunk in response_stream:
                # LangChain may yield dicts ({"response": ...}) or plain strings.
                if isinstance(chunk, dict) and "response" in chunk:
                    full_text_chunk = chunk["response"]
                else:
                    full_text_chunk = str(chunk)

                if not started_streaming_ai_response:
                    # Skip the echoed prompt until the assistant marker appears.
                    if assistant_start_marker in full_text_chunk:
                        token_content = full_text_chunk.split(assistant_start_marker, 1)[1]
                        started_streaming_ai_response = True
                    else:
                        token_content = ""
                else:
                    token_content = full_text_chunk

                # Stop streaming as soon as a stop sequence appears.
                for stop_seq in stop_sequences_to_check:
                    if stop_seq in token_content:
                        token_content = token_content.split(stop_seq, 1)[0]
                        if token_content:
                            yield json.dumps({"type": "token", "content": token_content}) + "\n"
                            await asyncio.sleep(0.01)
                        yield json.dumps({"type": "end", "status": "completed", "session_id": session_id}) + "\n"
                        return

                if token_content:
                    yield json.dumps({"type": "token", "content": token_content}) + "\n"
                    await asyncio.sleep(0.01)

            yield json.dumps({"type": "end", "status": "completed", "session_id": session_id}) + "\n"
        except Exception as e:
            print(f"Error during streaming generation for session {session_id}:")
            traceback.print_exc()
            yield json.dumps({"type": "error", "message": str(e), "session_id": session_id}) + "\n"

    # Stream newline-delimited JSON back to the client.
    return StreamingResponse(generate_stream(), media_type="application/json")
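

# Example client, for reference only: consumes the newline-delimited JSON stream
# produced above. The URL and route are assumptions (they match the assumed
# @app.post("/generate") registration and the default port), the server never
# calls this function, and it requires the `requests` package.
def example_stream_client(question: str, session_id: Optional[str] = None) -> str:
    import requests  # local import so the server has no hard dependency on it

    payload = {"question": question, "session_id": session_id}
    answer_parts = []
    with requests.post("http://localhost:7860/generate", json=payload, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines():
            if not line:
                continue
            event = json.loads(line)
            if event["type"] == "session_info":
                session_id = event["session_id"]  # remember for follow-up requests
            elif event["type"] == "token":
                answer_parts.append(event["content"])
            elif event["type"] in ("end", "error"):
                break
    return "".join(answer_parts)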


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))