saifeddinemk committed on
Commit 7d0de60 · 1 Parent(s): 5aaa320

Fixed app v2

Files changed (1)
  1. app.py +21 -56
app.py CHANGED
@@ -1,78 +1,43 @@
-import torch
-import json
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import pipeline, set_seed
 from typing import List
 
 # Initialize the FastAPI app
 app = FastAPI()
 
-# Model and tokenizer paths and loading
-model_path = "WhiteRabbitNeo/WhiteRabbitNeo-2.5-Qwen-2.5-Coder-7B"
-output_file_path = "/home/user/conversations.jsonl"
-
-model = AutoModelForCausalLM.from_pretrained(
-    model_path,
-    torch_dtype=torch.float16,
-    device_map="auto",
-    load_in_4bit=False,
-    trust_remote_code=False,
-)
-
-tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-
-# Function to generate text
-def generate_text(instruction):
-    tokens = tokenizer.encode(instruction)
-    tokens = torch.LongTensor(tokens).unsqueeze(0)
-    tokens = tokens.to("cuda")
-
-    instance = {
-        "input_ids": tokens,
-        "top_p": 1.0,
-        "temperature": 0.75,
-        "generate_len": 2048,
-        "top_k": 50,
-    }
-
-    length = len(tokens[0])
-    with torch.no_grad():
-        rest = model.generate(
-            input_ids=tokens,
-            max_length=length + instance["generate_len"],
-            use_cache=True,
-            do_sample=True,
-            top_p=instance["top_p"],
-            temperature=instance["temperature"],
-            top_k=instance["top_k"],
-            num_return_sequences=1,
-            pad_token_id=tokenizer.eos_token_id,
-        )
-    output = rest[0][length:]
-    string = tokenizer.decode(output, skip_special_tokens=True)
-    return f"{string}"
+# Initialize the generator pipeline
+generator = pipeline('text-generation', model='gpt2-medium')
+set_seed(42)
 
 # Data model for FastAPI input
 class UserInput(BaseModel):
     conversation: str
     user_input: str
+    max_length: int = 50  # default length
+    num_return_sequences: int = 1  # default number of sequences
 
 @app.post("/generate/")
 async def generate_response(user_input: UserInput):
     try:
-        # Construct the prompt
-        conversation = user_input.conversation
-        llm_prompt = f"{conversation}{user_input.user_input}<|im_end|>\n<|im_start|>assistant\nSure! Let me provide a complete and a thorough answer to your question, with functional and production-ready code.\n"
-
+        # Construct the prompt from the conversation and user input
+        prompt = f"{user_input.conversation}{user_input.user_input}"
+
         # Generate response
-        answer = generate_text(llm_prompt)
-
-        # Update conversation for future requests
-        updated_conversation = f"{llm_prompt}{answer}<|im_end|>\n<|im_start|>user\n"
+        responses = generator(
+            prompt,
+            max_length=user_input.max_length,
+            num_return_sequences=user_input.num_return_sequences
+        )
+
+        # Extract text from each generated sequence
+        generated_texts = [response["generated_text"] for response in responses]
+
+        # Update conversation with the last generated text
+        updated_conversation = f"{prompt}\n{generated_texts[-1]}"
 
         return {
-            "response": answer,
+            "responses": generated_texts,
             "updated_conversation": updated_conversation
         }
     except Exception as e:
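
For quick manual testing of the updated endpoint, a minimal client sketch follows. It assumes the app is served locally (for example with uvicorn app:app --port 8000); the host, port, and sample prompt are illustrative only, while the field names and defaults mirror the UserInput model in this commit.

    # Client sketch: call the /generate/ endpoint of the new app.py.
    # "localhost:8000" is an assumed local deployment, not part of the commit.
    import requests

    payload = {
        "conversation": "",                    # prior turns; empty on the first request
        "user_input": "Write a short poem about the sea.",
        "max_length": 80,                      # optional, UserInput defaults to 50
        "num_return_sequences": 1,             # optional, UserInput defaults to 1
    }

    resp = requests.post("http://localhost:8000/generate/", json=payload)
    resp.raise_for_status()
    data = resp.json()

    print(data["responses"][0])                # first generated sequence
    # Feed data["updated_conversation"] back as "conversation" on the next call
    # to keep multi-turn context, matching how the handler builds the prompt.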