cahya committed
Commit e5f8262
2 Parent(s): 3354e39 cc1e48d

Merge branch 'gpt'

app/{web_socket.py → api.py} RENAMED
@@ -1,7 +1,12 @@
 from fastapi import FastAPI, WebSocket
 from fastapi.responses import HTMLResponse
+from fastapi import Form, Depends, HTTPException, status
+from transformers import pipeline, set_seed, AutoConfig, AutoTokenizer, GPT2LMHeadModel
+import torch
 import os
-
+import time
+import re
+import json
 
 app = FastAPI()
 
@@ -45,6 +50,7 @@ html = """
 async def get():
     return HTMLResponse(html)
 
+
 @app.get("/env")
 async def env():
     environment_variables = "<h3>Environment Variables</h3>"
@@ -52,6 +58,7 @@ async def env():
         environment_variables += f"{name}: {value}<br>"
     return HTMLResponse(environment_variables)
 
+
 @app.websocket("/ws")
 async def websocket_endpoint(websocket: WebSocket):
     await websocket.accept()
@@ -59,3 +66,72 @@ async def websocket_endpoint(websocket: WebSocket):
         data = await websocket.receive_text()
         await websocket.send_text(f"Message text was: {data}")
 
+
+@app.post("/api/indochat/v1")
+async def indochat(
+        text: str = Form(default="", description="The Prompt"),
+        max_length: int = Form(default=250, description="Maximal length of the generated text"),
+        do_sample: bool = Form(default=True, description="Whether to use sampling; use greedy decoding otherwise"),
+        top_k: int = Form(default=50, description="The number of highest probability vocabulary tokens to keep "
+                                                  "for top-k-filtering"),
+        top_p: float = Form(default=0.95, description="If set to float < 1, only the most probable tokens with "
+                                                      "probabilities that add up to top_p or higher are kept "
+                                                      "for generation"),
+        temperature: float = Form(default=1.0, description="The Temperature of the softmax distribution"),
+        penalty_alpha: float = Form(default=0.6, description="Penalty alpha"),
+        repetition_penalty: float = Form(default=1.0, description="Repetition penalty"),
+        seed: int = Form(default=42, description="Random Seed"),
+        max_time: float = Form(default=60.0, description="Maximal time in seconds to generate the text")
+):
+    set_seed(seed)
+    if repetition_penalty == 0.0:
+        min_penalty = 1.05
+        max_penalty = 1.5
+        repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)
+    prompt = f"User: {text}\nAssistant: "
+    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
+    model.eval()
+    print("Generating text...")
+    print(f"max_length: {max_length}, do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, "
+          f"temperature: {temperature}, repetition_penalty: {repetition_penalty}, penalty_alpha: {penalty_alpha}")
+    time_start = time.time()
+    sample_outputs = model.generate(input_ids,
+                                    penalty_alpha=penalty_alpha,
+                                    do_sample=do_sample,
+                                    min_length=200,
+                                    max_length=max_length,
+                                    top_k=top_k,
+                                    top_p=top_p,
+                                    temperature=temperature,
+                                    repetition_penalty=repetition_penalty,
+                                    num_return_sequences=1,
+                                    max_time=max_time
+                                    )
+    result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
+    # result = result[len(prompt) + 1:]
+    time_end = time.time()
+    time_diff = time_end - time_start
+    print(f"result:\n{result}")
+    generated_text = result
+    return {"generated_text": generated_text, "processing_time": time_diff}
+
+
+def get_text_generator(model_name: str, device: str = "cpu"):
+    hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
+    print(f"hf_auth_token: {hf_auth_token}")
+    print(f"Loading model with device: {device}...")
+    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
+    model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id,
+                                            use_auth_token=hf_auth_token)
+    model.to(device)
+    print("Model loaded")
+    return model, tokenizer
+
+
+def get_config():
+    return json.load(open("config.json", "r"))
+
+
+config = get_config()
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model, tokenizer = get_text_generator(model_name=config["model_name"], device=device)
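The new POST endpoint reads its generation parameters as form fields (not JSON) and returns the generated text together with the processing time. A minimal client sketch, assuming the FastAPI app is reachable directly on uvicorn's port 7880 as configured in start.sh (the deployed Space may expose it through nginx on a different port); the prompt text is an illustrative placeholder:

import requests

# Call the new endpoint; all parameters are sent as form data, and any
# field left out falls back to the Form(default=...) value in api.py.
resp = requests.post(
    "http://localhost:7880/api/indochat/v1",
    data={
        "text": "Siapa presiden pertama Indonesia?",  # hypothetical example prompt
        "max_length": 250,
        "top_k": 50,
        "top_p": 0.95,
        "temperature": 1.0,
        "seed": 42,
    },
)
resp.raise_for_status()
body = resp.json()
print(body["generated_text"])
print(f"processing_time: {body['processing_time']:.1f}s")

One detail worth noting: sending repetition_penalty as 0.0 makes the handler derive it from temperature via max(1.05 + (1.0 - temperature) * (1.5 - 1.05), 0.8), so at the default temperature of 1.0 it resolves to 1.05.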
app/config.json ADDED
@@ -0,0 +1,3 @@
+{
+  "model_name": "cahya/indochat-tiny"
+}
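The model name now lives in this small config file, which get_config() loads relative to the working directory (start.sh cd's to /home/user/app before launching the server). A minimal sketch mirroring that lookup, with an explicit guard added purely for illustration:

import json

# Load the config the same way get_config() does: from config.json in the
# current working directory, which start.sh sets to /home/user/app.
with open("config.json", "r") as f:
    config = json.load(f)

model_name = config.get("model_name")
if not model_name:
    raise ValueError("config.json must define 'model_name'")  # illustrative guard
print(model_name)  # -> cahya/indochat-tiny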
app/start.sh CHANGED
@@ -3,21 +3,16 @@ set -e
 
 cd /home/user/app
 
-id
-ls -ld /var/log/nginx/ /var/lib/nginx/ /run
-ls -la /
-ls -la ~
-
 nginx
 
 python whisper.py&
 
 if [ "$DEBUG" = true ] ; then
     echo 'Debugging - ON'
-    uvicorn web_socket:app --host 0.0.0.0 --port 7880 --reload
+    uvicorn api:app --host 0.0.0.0 --port 7880 --reload
 else
     echo 'Debugging - OFF'
-    uvicorn web_socket:app --host 0.0.0.0 --port 7880
+    uvicorn api:app --host 0.0.0.0 --port 7880
     echo $?
     echo END
 fi
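Because the file rename changes the uvicorn import string from web_socket:app to api:app, a quick way to confirm the server still comes up is to hit one of the unchanged endpoints. A minimal smoke-test sketch, assuming the app answers directly on port 7880 (requests here is only a test dependency, not part of the app):

import requests

# GET /env is served by the same app object uvicorn now loads as api:app,
# so a 200 response confirms the renamed module imports and runs.
resp = requests.get("http://localhost:7880/env", timeout=10)
resp.raise_for_status()
assert "Environment Variables" in resp.text
print("api:app is serving on port 7880")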