Spaces:
Sleeping
Sleeping
| import time | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
# Load the tokenizer and the causal-LM weights once, at import time, so that
# every call to the inference function reuses the same objects.
_CHECKPOINT = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)
# Inference function
def _build_prompt(messages):
    """Flatten chat messages into the "User:/Assistant:" transcript fed to the model."""
    parts = []
    for msg in messages:
        role = msg.get("role", "")
        # "content" may be present but null; render it as empty text, not "None".
        content = msg.get("content") or ""
        if role == "user":
            parts.append(f"User: {content}\n")
        elif role == "assistant":
            parts.append(f"Assistant: {content}\n")
        # Messages with any other role are left out of the prompt.
    # Trailing cue so the model continues with the assistant's turn.
    parts.append("Assistant:")
    return "".join(parts)


def chat_completion(messages, model_name="mock-gpt-model", max_tokens=512, temperature=0.1, stream=False):
    """Generate one assistant reply for a chat history.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        messages: Sequence of dicts with "role" and "content" keys. Only
            "user" and "assistant" roles are included in the prompt.
        model_name: Echoed back in the "model" field of the response; it does
            not select which model runs.
        max_tokens: Upper bound on newly generated tokens. Coerced to an
            integer of at least 1.
        temperature: Sampling temperature. Values that are ``None`` or not
            greater than zero switch to greedy decoding.
        stream: Accepted so the signature matches the UI inputs; not used.

    Returns:
        A dict shaped like a "chat.completion" object with a single choice,
        or ``{"error": ...}`` when ``messages`` is empty or malformed.
    """
    if not messages:
        return {
            "error": "No messages provided."
        }
    # Reject malformed payloads up front instead of failing on ``.get`` below.
    if not isinstance(messages, (list, tuple)) or not all(
        isinstance(msg, dict) for msg in messages
    ):
        return {
            "error": "messages must be a list of objects with 'role' and 'content' keys."
        }

    # Rebuild prompt
    prompt = _build_prompt(messages)

    # Generate output
    inputs = tokenizer(prompt, return_tensors="pt")
    generation_kwargs = {
        # UI sliders may deliver floats; the token budget must be an integer.
        "max_new_tokens": max(1, int(max_tokens)),
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature is not None and temperature > 0:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = float(temperature)
    else:
        # Sampling requires a strictly positive temperature; decode greedily.
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    # Extract assistant reply: the output sequence starts with the prompt
    # tokens, so drop them by token count rather than by decoded text length,
    # which is only correct if decoding reproduces the prompt exactly.
    prompt_length = inputs["input_ids"].shape[-1]
    assistant_reply = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()

    return {
        "id": "1337",
        "object": "chat.completion",
        # Whole seconds since the epoch, as an integer timestamp.
        "created": int(time.time()),
        "model": model_name,
        "choices": [{
            "message": {
                "role": "assistant",
                "content": assistant_reply
            }
        }]
    }
# Gradio API endpoint setup
# The inputs are listed in the same order as chat_completion's parameters
# (messages, model_name, max_tokens, temperature, stream).
demo = gr.Interface(
    fn=chat_completion,
    inputs=[
        gr.JSON(label="messages"),  # List[{"role":..., "content":...}]
        gr.Textbox(label="model", value="mock-gpt-model"),
        # step=1 so every selectable value is a whole token count.
        gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="max_tokens"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.1, label="temperature"),
        # Passed through to chat_completion, which does not use it.
        gr.Checkbox(label="stream", value=False),
    ],
    outputs=gr.JSON(label="response"),
    title="OpenAI-compatible Chat API (Gradio + Transformers)",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()