# ggruiz-amd's picture
# add all files for qwen
# 0d87b57
import os
import logging
import gradio as gr
from typing import Iterator
from gateway import request_generation
# Setup logging
logging.basicConfig(level=logging.INFO)


def _require_env(name: str) -> str:
    """Return the value of environment variable ``name``, failing fast if absent.

    Args:
        name: name of the environment variable to read.

    Returns:
        The non-empty value of the variable.

    Raises:
        EnvironmentError: if the variable is unset or empty.
    """
    value = os.getenv(name)
    if not value:
        raise EnvironmentError(f"{name} is not set.")
    return value


# Validate environment variables at import time so a misconfigured deployment
# stops immediately instead of failing on the first request.
CLOUD_GATEWAY_API: str = _require_env("API_ENDPOINT")
MODEL_NAME: str = _require_env("MODEL_NAME")

# Get API Key. Only its presence is checked here; the key's validity is not
# verified by this module. A missing key raises EnvironmentError like the other
# settings (EnvironmentError is still an Exception subclass).
API_KEY: str = _require_env("API_KEY")

# Create a header, avoid declaring multiple times
HEADER = {"x-api-key": API_KEY}
def generate(
    message: str,
    chat_history: list,
    system_prompt: str,
    temperature: float = 0.6,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
) -> Iterator[str]:
    """Send a request to the backend and stream the growing response to the UI.

    Args:
        message: input message from the user.
        chat_history: chat history of the session as supplied by the UI. It is
            accepted to match the callback signature but is NOT forwarded to
            the backend, so each request carries only ``message``.
            NOTE(review): confirm whether dropping multi-turn context is
            intended.
        system_prompt: system prompt forwarded to the backend.
        temperature: temperature value forwarded to the backend.
            Defaults to 0.6.
        frequency_penalty: frequency penalty forwarded to the backend.
            Defaults to 0.0.
        presence_penalty: presence penalty forwarded to the backend.
            Defaults to 0.0.

    Yields:
        The response accumulated so far: every yielded string is the
        concatenation of all chunks received up to that point, not just the
        newest chunk.
    """
    # Keep a running string rather than re-joining the full chunk list on
    # every iteration; the yielded values are the same.
    response = ""
    for text in request_generation(
        header=HEADER,
        message=message,
        system_prompt=system_prompt,
        temperature=temperature,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        cloud_gateway_api=CLOUD_GATEWAY_API,
        model_name=MODEL_NAME,
    ):
        response += text
        yield response
# Markdown text passed to the chat interface as its description.
description = """
This Space is an Alpha release that demonstrates the [Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B) model running on AMD MI300 infrastructure. The space is built with Qwen 3 [License](https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/LICENSE). Feel free to play with it!
"""

# Initial value of the "System prompt" textbox.
_DEFAULT_SYSTEM_PROMPT = "You are a highly capable AI assistant. Provide accurate, concise, and fact-based responses that are directly relevant to the user's query. Avoid speculation, ensure logical consistency, and maintain clarity in longer outputs. Keep answers well-structured and under 1200 tokens unless explicitly requested otherwise."

# Example prompts offered in the UI; each one becomes a single-item example row.
_EXAMPLE_PROMPTS = (
    "Plan a three-day trip to Washington DC for Cherry Blossom Festival.",
    "Compose a short, joyful musical piece for kids celebrating spring sunshine and blossom.",
    "Can you explain briefly to me what is the Python programming language?",
    "Explain the plot of Cinderella in a sentence.",
    "How many hours does it take a man to eat a Helicopter?",
    "Write a 100-word article on 'Benefits of Open-Source in AI research'.",
)


def _penalty_slider(label):
    """Build a slider over [-2.0, 2.0] in 0.1 steps, starting at 0.0."""
    return gr.Slider(
        label=label,
        minimum=-2.0,
        maximum=2.0,
        step=0.1,
        value=0.0,
    )


# Components are created in the same order as before: chatbot first, then the
# additional inputs in the positional order `generate` expects after
# (message, chat_history): system prompt, temperature, frequency penalty,
# presence penalty.
_chat_window = gr.Chatbot(
    type="messages",
    scale=2,
    allow_tags=True,
)
_extra_controls = [
    gr.Textbox(
        label="System prompt",
        value=_DEFAULT_SYSTEM_PROMPT,
        lines=3,
    ),
    gr.Slider(
        label="Temperature",
        minimum=0.1,
        maximum=4.0,
        step=0.1,
        value=0.3,
    ),
    _penalty_slider("Frequency penalty"),
    _penalty_slider("Presence penalty"),
]

demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    chatbot=_chat_window,
    stop_btn=None,
    additional_inputs=_extra_controls,
    examples=[[prompt] for prompt in _EXAMPLE_PROMPTS],
    cache_examples=False,
    title="Qwen3-30B-A3B",
    description=description,
)
def _int_env(name: str) -> int:
    """Read environment variable ``name`` and return it as an integer.

    Args:
        name: name of the environment variable to read.

    Returns:
        The variable's value converted with ``int``.

    Raises:
        EnvironmentError: if the variable is unset, empty, or not a valid
            integer. This replaces the bare ``TypeError``/``ValueError`` that
            ``int(os.getenv(name))`` would raise with a message naming the
            offending variable.
    """
    raw = os.getenv(name)
    if not raw:
        raise EnvironmentError(f"{name} is not set.")
    try:
        return int(raw)
    except ValueError as err:
        raise EnvironmentError(f"{name} must be an integer, got {raw!r}.") from err


if __name__ == "__main__":
    # Queue size and concurrency limit are taken from the environment; both
    # are required when the file is run as a script.
    demo.queue(
        max_size=_int_env("QUEUE"),
        default_concurrency_limit=_int_env("CONCURRENCY_LIMIT"),
    ).launch()