|
import json
import logging
import os
import time

import gradio as gr
import requests
from requests.exceptions import RequestException
|
|
|
|
|
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
API_URL = "https://api-inference.huggingface.co/models/mattshumer/Reflection-Llama-3.1-70B" |
|
|
|
def query(payload, max_retries=3, delay=5): |
|
for attempt in range(max_retries): |
|
try: |
|
logger.info(f"Sending payload to API (attempt {attempt + 1}/{max_retries}): {payload}") |
|
response = requests.post(API_URL, json=payload, timeout=60) |
|
logger.info(f"Received response with status code: {response.status_code}") |
|
|
|
if response.status_code == 200: |
|
return response.json() |
|
elif response.status_code == 503: |
|
logger.warning("Model is loading. Retrying...") |
|
time.sleep(delay) |
|
else: |
|
logger.error(f"API request failed with status code {response.status_code}: {response.text}") |
|
return {"error": f"API request failed with status code {response.status_code}"} |
|
|
|
except RequestException as e: |
|
logger.error(f"Request failed: {str(e)}") |
|
if attempt < max_retries - 1: |
|
logger.info(f"Retrying in {delay} seconds...") |
|
time.sleep(delay) |
|
else: |
|
return {"error": f"Failed to connect after {max_retries} attempts: {str(e)}"} |
|
|
|
return {"error": "Maximum retries reached"} |
|
|
|
def generate_text(prompt): |
|
logger.info(f"Received prompt: {prompt}") |
|
|
|
try: |
|
payload = { |
|
"inputs": prompt, |
|
"parameters": { |
|
"max_new_tokens": 100, |
|
"temperature": 0.7, |
|
"top_p": 0.95, |
|
"do_sample": True |
|
} |
|
} |
|
logger.info("Calling Hugging Face Inference API for text generation...") |
|
response = query(payload) |
|
|
|
logger.info(f"Raw response from API: {json.dumps(response, indent=2)}") |
|
|
|
if "error" in response: |
|
return f"Error: {response['error']}" |
|
|
|
if isinstance(response, list) and len(response) > 0: |
|
generated_text = response[0].get('generated_text', '') |
|
logger.info(f"Processed response: {generated_text[:100]}...") |
|
return generated_text |
|
else: |
|
error_msg = f"Unexpected response format: {response}" |
|
logger.error(error_msg) |
|
return error_msg |
|
|
|
except Exception as e: |
|
error_msg = f"Error generating text: {str(e)}" |
|
logger.exception(error_msg) |
|
return error_msg |
|
|
|
iface = gr.Interface( |
|
fn=generate_text, |
|
inputs=gr.Textbox(lines=5, label="Enter your prompt"), |
|
outputs=gr.Textbox(label="Generated Response"), |
|
title="Reflection Llama 3.1 70B Demo", |
|
description="Enter a prompt to generate text using the Reflection Llama 3.1 70B model. Please be patient, as the model may take some time to respond." |
|
) |
|
|
|
if __name__ == "__main__": |
|
logger.info("Starting Gradio interface...") |
|
iface.launch() |
|
logger.info("Gradio interface launched.") |