# mistralai / app.py
# R-TA's picture
# Update app.py
# 3651383 verified
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Hugging Face Hub identifier of the checkpoint this app serves.
model_id = "mistralai/Devstral-Small-2505"

# The slow tokenizer is requested explicitly to avoid sentencepiece errors.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# GPU: float16 weights with automatic device placement.
# CPU: float32 weights and no device map.
if device == "cuda":
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        device_map=None,
    )
def chat(prompt: str) -> str:
    """Generate a sampled reply to ``prompt`` using the module-level model.

    Args:
        prompt: Raw text entered by the user.

    Returns:
        The newly generated text only (the prompt is not echoed back), or
        an empty string when ``prompt`` is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        # Nothing to condition on; skip the generation call entirely.
        return ""

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[-1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        eos_token_id=tokenizer.eos_token_id,
    )

    # For a causal LM, generate() returns the prompt tokens followed by the
    # continuation. Decode only the continuation so the reply shown in the
    # UI does not repeat the user's own input.
    reply_ids = outputs[0][prompt_length:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True)
# Single-textbox UI: the submitted text is passed to chat() and the returned
# string is displayed as plain text.
demo = gr.Interface(
    fn=chat,
    inputs=gr.Textbox(lines=2, placeholder="Talk with Devstral..."),
    outputs="text",
    title="Devstral-Small Chat",
    theme="compact",
)

# Start the web server as soon as the script is executed.
demo.launch()