|
from fastapi import FastAPI |
|
from fastapi.responses import HTMLResponse |
|
from transformers import AutoTokenizer |
|
from pydantic import BaseModel |
|
from llama_cpp import Llama |
|
import time |
|
|
|
class Message(BaseModel):
    """Request body accepted by ``POST /chat``."""

    # Text of the user's message.
    content: str
    # Generation budget; the chat handler passes it to the model as ``max_tokens``.
    token: int
|
|
|
class System(BaseModel):
    """Request body accepted by ``POST /setSystemPrompt``."""

    # Prompt text submitted from the page's "Enter System Prompt" field.
    sys_prompt: str
|
|
|
|
|
# ASGI application object; every route below is registered on it via decorators.
app = FastAPI()
|
|
|
@app.get("/", response_class=HTMLResponse)
def greet_json():
    """Serve the single-page chat UI.

    The page is fully self-contained (inline CSS and JavaScript). Its script
    posts JSON to ``/setSystemPrompt`` and ``/chat`` on this same server.

    Returns:
        The HTML document as a string; FastAPI wraps it in an
        ``HTMLResponse`` because of ``response_class``.
    """
    page_html = '''<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>FastAPI Chatbot</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 0;
            padding: 0;
            background-color: #f4f4f9;
            display: flex;
            flex-direction: column;
            align-items: center;
        }
        .container {
            margin-top: 60px;
            width: 90%;
            margin-bottom: 20px;
        }
        .system-prompt {
            display: flex;
            justify-content: space-between;
            margin-bottom: 20px;
        }
        .system-prompt input {
            width: 70%;
            padding: 10px;
            border: 1px solid #ccc;
            border-radius: 4px;
        }
        .system-prompt button {
            padding: 10px 20px;
            border: none;
            background-color: #007bff;
            color: white;
            border-radius: 4px;
            cursor: pointer;
        }
        .system-prompt button:hover {
            background-color: #0056b3;
        }
        .chatbox {
            background-color: #fff;
            border-radius: 8px;
            box-shadow: 0 2px 5px rgba(0,0,0,0.1);
            padding: 20px;
            height: 400px;
            overflow-y: auto;
        }
        .message {
            margin-bottom: 10px;
        }
        .user {
            text-align: right;
            color: #007bff;
        }
        .assistant {
            text-align: left;
            color: #333;
        }
        .input-section {
            display: flex;
            width: 100%;
            margin-top: 20px;
        }
        .input-section input {
            flex: 1;
            padding: 10px;
            border: 1px solid #ccc;
            border-radius: 4px;
            margin-right: 10px;
        }
        .input-section input:focus {
            outline: none;
            border-color: #007bff;
        }
        .input-section button {
            padding: 10px 20px;
            border: none;
            background-color: #28a745;
            color: white;
            border-radius: 4px;
            cursor: pointer;
        }
        .input-section button:hover {
            background-color: #218838;
        }
        .token-input {
            width: 100px;
            margin-left: 10px;
        }
    </style>
</head>
<body>
    <div class="container">
        <div class="system-prompt">
            <input type="text" id="systemPrompt" placeholder="Enter System Prompt">
            <button onclick="setSystemPromptAndClearHistory()">Set prompt and clear history</button>
        </div>
        <div class="chatbox" id="chatbox"></div>
        <div class="input-section">
            <input type="text" id="userInput" placeholder="Type your message here...">
            <input type="number" id="tokenLength" class="token-input" value="50" placeholder="Tokens">
            <button onclick="sendMessage()">Send</button>
        </div>
    </div>

    <script>
        async function setSystemPromptAndClearHistory() {
            const systemPrompt = document.getElementById('systemPrompt').value;
            const response = await fetch('/setSystemPrompt', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json'
                },
                body: JSON.stringify({ sys_prompt: systemPrompt })
            });
            if (response.ok) {
                document.getElementById('chatbox').innerHTML = '';
                alert('System prompt set and history cleared.');
            } else {
                alert('Failed to set system prompt.');
            }
        }

        async function sendMessage() {
            const userInput = document.getElementById('userInput').value;
            const tokenLength = parseInt(document.getElementById('tokenLength').value);
            if (!userInput || isNaN(tokenLength)) {
                alert('Please enter a valid message and token length.');
                return;
            }

            const chatbox = document.getElementById('chatbox');
            const userMessage = document.createElement('div');
            userMessage.className = 'message user';
            userMessage.textContent = userInput;
            chatbox.appendChild(userMessage);

            const response = await fetch('/chat', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json'
                },
                body: JSON.stringify({ content: userInput, token: tokenLength })
            });

            if (response.ok) {
                const data = await response.json();
                const assistantMessage = document.createElement('div');
                assistantMessage.className = 'message assistant';
                assistantMessage.textContent = data.response;
                chatbox.appendChild(assistantMessage);
                document.getElementById('userInput').value = '';
            } else {
                alert('Failed to get response from server.');
            }

            chatbox.scrollTop = chatbox.scrollHeight;
        }
    </script>
</body>
</html>'''
    return page_html
|
|
|
# Q8_0-quantized GGUF build of Qwen2.5-1.5B-Instruct, loaded once at import
# time and shared by every request.
# NOTE(review): no ``n_ctx`` is passed, so llama-cpp's default context size
# applies -- confirm it is large enough for the multi-turn history kept below.
llm = Llama.from_pretrained(
    repo_id="Qwen/Qwen2.5-1.5B-Instruct-GGUF",
    filename="qwen2.5-1.5b-instruct-q8_0.gguf",
    verbose=False,
)

# The tokenizer is used only for its chat template, so it must come from the
# same model generation as the GGUF weights above (Qwen2.5, not Qwen2);
# otherwise the rendered prompt can differ from what the model was tuned on.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")

# Conversation history as chat-template dicts ({"role": ..., "content": ...}).
# It is a single module-level list, so all clients share one conversation.
messages = []
|
|
|
@app.post("/chat")
def chat(req: Message):
    """Add the user's message to the shared history and return the model's reply.

    The whole conversation so far is rendered with the tokenizer's chat
    template and completed by the llama.cpp model, with ``req.token`` passed
    as ``max_tokens``.

    Args:
        req: Request body carrying the user's text (``content``) and the
            generation budget (``token``).

    Returns:
        A dict with the generated text under ``"response"`` and the elapsed
        seconds under ``"time"``.

    Raises:
        Exception: Anything raised while templating or generating is
            propagated unchanged, after the pending user turn is rolled back.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments the way a difference of time.time() values can.
    started = time.perf_counter()

    user_turn = {"role": "user", "content": req.content}
    messages.append(user_turn)
    try:
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        output = llm(prompt, max_tokens=req.token, echo=False)
        response = output["choices"][0]["text"]
    except Exception:
        # Without this rollback a failed request would leave an unanswered
        # user turn in the history, and every later prompt would contain two
        # consecutive user messages.
        if messages and messages[-1] is user_turn:
            messages.pop()
        raise

    messages.append({"role": "assistant", "content": response})
    return {"response": response, "time": time.perf_counter() - started}
|
|
|
|
|
@app.post("/setSystemPrompt")
def set_system_prompt(req: System):
    """Start a fresh conversation governed by the given system prompt.

    The page labels this action "Set prompt and clear history" and empties
    its chat box on success, so the server-side history is reset here too.

    Args:
        req: Request body carrying the new prompt in ``sys_prompt``.

    Returns:
        A confirmation payload: ``{"response": "System has been set"}``.
    """
    # Mutate the shared list in place so every reference to it sees the reset.
    messages.clear()

    # A blank prompt only clears the history; no empty system turn is stored.
    if req.sys_prompt.strip():
        # The role must be "system" -- stored as "user" the prompt would be
        # rendered as an ordinary user message rather than as instructions.
        messages.append({"role": "system", "content": req.sys_prompt})

    return {"response": "System has been set"}
|
|
|
@app.post("/clear_chat")
def clear_chat():
    """Discard every stored conversation turn.

    Returns:
        A confirmation payload: ``{"message": "Chat history cleared"}``.
    """
    # The history lives in the module-level ``messages`` list. Clearing it in
    # place (rather than rebinding some other global name) is what actually
    # resets the conversation used by the chat handler.
    messages.clear()
    return {"message": "Chat history cleared"}
|
|