import os
import uuid 
from typing import Dict, Optional
from fastapi import FastAPI, HTTPException
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig # Import BitsAndBytesConfig
import torch
from pydantic import BaseModel
import traceback
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationChain
from langchain.prompts import PromptTemplate
from starlette.responses import StreamingResponse
import asyncio
import json
from langchain_community.llms import HuggingFacePipeline 
import uvicorn
from huggingface_hub import login

app = FastAPI()

# Get the Hugging Face API token from environment variables (BEST PRACTICE)
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

if HUGGINGFACEHUB_API_TOKEN is None:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")
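
# For local runs, the token would typically be exported before launching the
# process (a deployment detail assumed here, not enforced by this file), e.g.:
#   export HUGGINGFACEHUB_API_TOKEN=hf_xxx
#   python main.py  # "main.py" is a hypothetical filename for this module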

# --- Explicitly log in to Hugging Face Hub ---
try:
    login(token=HUGGINGFACEHUB_API_TOKEN)
    print("Successfully logged into Hugging Face Hub.")
except Exception as e:
    print(f"Failed to log into Hugging Face Hub: {e}")

# --- Initialize tokenizer and model globally (heavy to load, shared across sessions) ---
model_id = "mistralai/Mistral-7B-Instruct-v0.3"

# --- NEW: Quantization configuration for 4-bit loading, optimized for T4 ---
# This configuration tells Hugging Face Transformers to load the model weights
# in 4-bit precision using the bitsandbytes library.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, # Enable 4-bit quantization
    bnb_4bit_quant_type="nf4", # Specify the quantization type: "nf4" (NormalFloat 4-bit) is recommended for transformers
    # --- IMPORTANT CHANGE: Use float16 for compute dtype for T4 compatibility ---
    # T4 GPUs (Turing architecture) do not have native bfloat16 support.
    # Using float16 for computations is more efficient on Turing hardware than bfloat16.
    bnb_4bit_compute_dtype=torch.float16, 
    bnb_4bit_use_double_quant=True, # Use double quantization for slightly better quality
)
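
# Rough back-of-the-envelope illustration (not computed anywhere in this code):
# a 7B-parameter model stored in fp16 needs about 7e9 * 2 bytes = 14 GB for the
# weights alone, while 4-bit NF4 weights take roughly 7e9 * 0.5 bytes = 3.5 GB
# plus quantization overhead, which is why the model fits on a single 16 GB T4.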

tokenizer = AutoTokenizer.from_pretrained(model_id, token=HUGGINGFACEHUB_API_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto", # 'auto' handles device placement, including offloading to CPU if necessary (but quantization aims to prevent this)
    quantization_config=bnb_config, # Pass the quantization configuration here
    # torch_dtype=torch.bfloat16, # REMOVED: This is now handled by bnb_4bit_compute_dtype
    trust_remote_code=True,
    token=HUGGINGFACEHUB_API_TOKEN 
)

# Global dictionary to store active conversation chains, keyed by session_id.
# IMPORTANT: In a production environment, this in-memory dictionary will reset
# if the server restarts. For true persistence, you would use a database (e.g., Redis, Firestore).
active_conversations: Dict[str, ConversationChain] = {}
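
# A minimal sketch of what external persistence could look like (hypothetical and
# not wired into this app): keep the serialized history in Redis keyed by
# session_id and rebuild the chain's memory on each request, e.g.:
#
#   import redis
#   _store = redis.Redis(host="localhost", port=6379, db=0)
#
#   def load_history(session_id: str) -> list:
#       raw = _store.get(f"history:{session_id}")
#       return json.loads(raw) if raw else []
#
#   def save_history(session_id: str, messages: list) -> None:
#       _store.set(f"history:{session_id}", json.dumps(messages))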

# --- UPDATED PROMPT TEMPLATE ---
template = """<|im_start|>system
You are a concise and direct AI assistant named Siddhi.
You strictly avoid asking any follow-up questions.
You do not generate any additional conversational turns (e.g., "Human: ...").
If asked for your name, you respond with "I am Siddhi."
If you do not know the answer to a question, you truthfully state that you do not know.
<|im_end|>
<|im_start|>user
{history}
{input}<|im_end|>
<|im_start|>assistant
"""

PROMPT = PromptTemplate(input_variables=["history", "input"], template=template)
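
# For illustration (not executed here), PROMPT.format(history="Human: Hi\nAI: Hello!",
# input="Who are you?") renders the system block above followed by a single user turn
# containing the accumulated history and the new question, ending with an open
# "<|im_start|>assistant" tag that the model is expected to complete.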

class QuestionRequest(BaseModel):
    question: str
    session_id: Optional[str] = None # Optional session ID for continuing conversations

class ChatResponse(BaseModel):
    response: str
    session_id: str # Include session_id in the response for client to track
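
# Illustrative exchange (assumed values; the shapes match the models above and the
# stream format produced below):
#   request body:   {"question": "Who are you?", "session_id": null}
#   streamed lines: {"type": "session_info", "session_id": "..."}
#                   {"type": "token", "content": "I am Siddhi."}
#                   {"type": "end", "status": "completed", "session_id": "..."}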

@app.post("/api/generate")
async def generate_text(request: QuestionRequest):
    """
    Handles text generation requests, maintaining conversation history per session.
    """
    session_id = request.session_id

    # If no session_id is provided, generate a new one.
    # This signifies the start of a new conversation.
    if session_id is None:
        session_id = str(uuid.uuid4())
        print(f"Starting new conversation with session_id: {session_id}")

    # Retrieve or create a ConversationChain for this session_id
    if session_id not in active_conversations:
        print(f"Creating new ConversationChain for session_id: {session_id}")
        # Initialize Langchain HuggingFacePipeline for this session
        llm = HuggingFacePipeline(pipeline=pipeline(
            "text-generation",
            model=model, # Use the globally loaded model
            tokenizer=tokenizer, # Use the globally loaded tokenizer
            max_new_tokens=512,  
            return_full_text=True, 
            temperature=0.2,      
            do_sample=True,        
        ))
        # Initialize memory for this specific session
        memory = ConversationBufferWindowMemory(k=5) # Remembers the last 5 human-AI interaction pairs
        conversation = ConversationChain(llm=llm, memory=memory, prompt=PROMPT, verbose=True)
        active_conversations[session_id] = conversation
    else:
        print(f"Continuing conversation for session_id: {session_id}")
        conversation = active_conversations[session_id]

    async def generate_stream():
        """
        An asynchronous generator function to stream text responses token-by-token.
        Each yielded item will be a JSON string representing a part of the stream.
        """
        # Flag to indicate when we've started streaming the AI's actual response
        started_streaming_ai_response = False
        
        try:
            # First, send a JSON object containing the session_id.
            # This allows the client to immediately get the session ID.
            yield json.dumps({"type": "session_info", "session_id": session_id}) + "\n"

            response_stream = conversation.stream({"input": request.question})

            stop_sequences_to_check = ["Human:", "AI:", "\nHuman:", "\nAI:", "<|im_end|>"]
            assistant_start_marker = "<|im_start|>assistant\n" 

            for chunk in response_stream:
                full_text_chunk = ""
                # Chain .stream() chunks are usually dicts keyed by output name;
                # fall back to the string form for anything else.
                if isinstance(chunk, dict) and 'response' in chunk:
                    full_text_chunk = chunk['response']
                else:
                    full_text_chunk = str(chunk)

                if not started_streaming_ai_response:
                    if assistant_start_marker in full_text_chunk:
                        token_content = full_text_chunk.split(assistant_start_marker, 1)[1]
                        started_streaming_ai_response = True
                    else:
                        token_content = ""
                else:
                    token_content = full_text_chunk

                for stop_seq in stop_sequences_to_check:
                    if stop_seq in token_content:
                        token_content = token_content.split(stop_seq, 1)[0] 
                        if token_content: 
                            yield json.dumps({"type": "token", "content": token_content}) + "\n"
                            await asyncio.sleep(0.01)
                        yield json.dumps({"type": "end", "status": "completed", "session_id": session_id}) + "\n" 
                        return 

                if token_content:
                    yield json.dumps({"type": "token", "content": token_content}) + "\n"
                    await asyncio.sleep(0.01)

            yield json.dumps({"type": "end", "status": "completed", "session_id": session_id}) + "\n"

        except Exception as e:
            print(f"Error during streaming generation for session {session_id}:")
            traceback.print_exc()
            yield json.dumps({"type": "error", "message": str(e), "session_id": session_id}) + "\n"

    # Return a StreamingResponse with application/json media type
    return StreamingResponse(generate_stream(), media_type="application/json")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
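
# A possible client-side consumer of the newline-delimited JSON stream (a sketch that
# assumes the server is reachable on localhost:7860; `requests` is not a dependency
# of this file):
#
#   import requests, json
#
#   with requests.post(
#       "http://localhost:7860/api/generate",
#       json={"question": "Hello, who are you?"},
#       stream=True,
#   ) as resp:
#       for line in resp.iter_lines():
#           if not line:
#               continue
#           event = json.loads(line)
#           if event["type"] == "token":
#               print(event["content"], end="", flush=True)
#           elif event["type"] == "end":
#               print()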