import os
import re
from typing import List

import torch
from dotenv import load_dotenv
from transformers import (
    AutoTokenizer,
    StoppingCriteria,
)

import runpod
from langchain.chains import ConversationChain
from langchain.llms import HuggingFaceTextGenInference
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.schema import BaseOutputParser

# Load environment variables from the .env file
load_dotenv()

# RunPod API key for the hosted inference endpoint
runpod.api_key = os.getenv("RUNPOD_API_KEY")

# Enable Weights & Biases tracing of LangChain runs
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
os.environ["WANDB_PROJECT"] = "falcon_hackathon"
os.environ["WANDB_API_KEY"] = os.getenv("WANDB_API_KEY")

# ID of the RunPod pod serving Falcon behind text-generation-inference
pod_id = os.getenv("POD_ID")
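
# Expected .env entries (illustrative; values omitted):
#   RUNPOD_API_KEY=...
#   WANDB_API_KEY=...
#   POD_ID=...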


class CleanupOutputParser(BaseOutputParser):
    """Strip role markers that the model sometimes echoes into its reply."""

    def parse(self, text: str) -> str:
        # Remove leaked "User", "Human:" and "AI:" turn markers from the completion
        text = re.sub(r"\nUser", "", text)
        text = re.sub(r"\nHuman:", "", text)
        return re.sub(r"\nAI:", "", text).strip()

    @property
    def _type(self) -> str:
        return "output_parser"


class StopGenerationCriteria(StoppingCriteria):
    """Stop generation as soon as any of the given token sequences is produced."""

    def __init__(
        self, tokens: List[List[str]], tokenizer: AutoTokenizer, device: torch.device
    ):
        super().__init__()
        # Convert each stop phrase (given as a list of tokens) into a tensor of token ids
        stop_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
        self.stop_token_ids = [
            torch.tensor(x, dtype=torch.long, device=device) for x in stop_token_ids
        ]

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        # Stop if the tail of the generated sequence matches any stop sequence
        for stop_ids in self.stop_token_ids:
            if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
                return True
        return False
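
# Illustrative sketch (assumption: this criteria class targets a locally loaded
# Falcon model; it is not wired into the remote chain below). The model name and
# stop phrases are placeholders, and the exact token split depends on the tokenizer:
#
#   from transformers import AutoModelForCausalLM, StoppingCriteriaList
#
#   model_name = "tiiuae/falcon-7b-instruct"
#   tokenizer = AutoTokenizer.from_pretrained(model_name)
#   model = AutoModelForCausalLM.from_pretrained(model_name)
#   stop_tokens = [["Human", ":"], ["AI", ":"]]
#   stopping_criteria = StoppingCriteriaList(
#       [StopGenerationCriteria(stop_tokens, tokenizer, model.device)]
#   )
#   inputs = tokenizer("Hello!", return_tensors="pt").to(model.device)
#   output_ids = model.generate(**inputs, stopping_criteria=stopping_criteria)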


class Falcon_7b_llm:
    """Conversation chain backed by a Falcon endpoint served on RunPod."""

    def __init__(self):
        # Build the prompt, memory, and LLM chain; restart_state() performs the
        # same setup, so delegate to it instead of duplicating the code here.
        self.restart_state()

    def restart_state(self):
        """(Re)build the prompt, memory, and LLM chain, discarding any previous conversation."""
        inference_server_url_cloud = f"https://{pod_id}-80.proxy.runpod.net"

        template = """You are a chatbot called 'Falcon Barista' working at a coffee shop.
Your primary function is to take orders from customers.
Start with a greeting.
You have the following menu with prices. Don't mention the price unless asked. Do not take orders for anything that is not on the menu.
- cappuccino - $5
- latte - $3
- frappuccino - $8
- juice - $3
If the user orders something else, apologise that you don't have that item.
Take the order politely and in a friendly way. After that, confirm the order, tell the order price, and say "Goodbye, have a nice day".
{chat_history}
Human: {human_input}
AI:"""

        prompt = PromptTemplate(
            input_variables=["chat_history", "human_input"], template=template
        )
        # Buffer memory stores the full chat history and injects it as {chat_history}
        memory = ConversationBufferMemory(memory_key="chat_history")

        # Client for the text-generation-inference server running on the RunPod pod
        llm_cloud = HuggingFaceTextGenInference(
            inference_server_url=inference_server_url_cloud,
            max_new_tokens=200,
            top_k=10,
            top_p=0.95,
            typical_p=0.95,
            temperature=0.01,
            repetition_penalty=1.0,
            stop_sequences=["Mini", "AI", "Human", ":"],
        )

        self.llm_chain_cloud = ConversationChain(
            prompt=prompt,
            llm=llm_cloud,
            verbose=True,
            memory=memory,
            output_parser=CleanupOutputParser(),
            input_key="human_input",
        )

    def get_llm_response(self, human_input):
        # Run one conversation turn against the cloud chain and return the completion
        completion = self.llm_chain_cloud.predict(human_input=human_input)
        return completion
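

# Minimal usage sketch (assumption: the Space's UI or handler imports this class;
# the block below is only a quick local smoke test).
if __name__ == "__main__":
    barista = Falcon_7b_llm()
    print(barista.get_llm_response("Hi, I'd like a latte please."))
    # Reset the chain and memory to start a brand-new conversation
    barista.restart_state()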