Ask-ANRG / main.py
FloraJ's picture
update demo
c12d231
raw
history blame
1.23 kB
import threading
from typing import List

import openai
from fastapi import FastAPI, Depends, HTTPException, Query
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from helper import get_response_from_model
# The ASGI application. Files under the local "static" directory
# (resolved relative to the current working directory) are served
# beneath the /static URL prefix.
app = FastAPI()
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)
class InputData(BaseModel):
    """JSON request body for the ``/chat/`` endpoint."""

    # The user's message; handed to the model helper as-is.
    user_input: str
    # Caller-supplied OpenAI API key. Treat as a secret: do not log it.
    api_key: str
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve the front-end page.

    Returns:
        HTMLResponse: the contents of ``static/index.html``. The path is
        relative to the process's current working directory, like the
        ``/static`` mount.
    """
    # Decode explicitly as UTF-8. Without an encoding argument, open() uses
    # the platform's locale encoding, which can mis-decode or reject
    # non-ASCII characters in the HTML.
    with open("static/index.html", "r", encoding="utf-8") as f:
        content = f.read()
    return HTMLResponse(content=content)
# Local model and tokenizer initialization (currently disabled; kept for reference):
# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-14B-Chat-int4")
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-14B-Chat-int4").eval()
# openai.api_key is a process-wide global. FastAPI runs plain ``def``
# endpoints in a thread pool, so concurrent requests could otherwise
# interleave: one request sets its key, another overwrites it, and the first
# request is then sent with the wrong user's credentials.
_OPENAI_KEY_LOCK = threading.Lock()


@app.post("/chat/")
def chat(input_data: InputData):
    """Answer a chat message using the caller-supplied OpenAI API key.

    Args:
        input_data: request body carrying the user's message and API key.

    Returns:
        dict: ``{"response": ...}`` wrapping whatever
        ``get_response_from_model`` returns.
    """
    # Log only the message. Printing the whole model would write the
    # caller's API key to server output in plain text.
    print("user_input: ", input_data.user_input)
    user_input = input_data.user_input
    api_key = input_data.api_key
    # Hold the lock across both setting the global key and the model call.
    # The helper presumably reads openai.api_key; it is not visible here, so
    # confirm this. The lock serializes chat requests, trading throughput for
    # never mixing credentials between users.
    with _OPENAI_KEY_LOCK:
        openai.api_key = api_key
        response = get_response_from_model(user_input)
    return {"response": response}