# (removed extraction artifact "Spaces: / Sleeping / Sleeping" — stray text, not part of the script)
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
""" | |
Created on Thu Aug 17 12:11:26 2023 | |
@author: crodrig1 | |
""" | |
from optparse import OptionParser | |
import sys, re, os | |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
from pymongo import MongoClient | |
from pprint import pprint | |
import torch | |
import warnings | |
import re, string | |
from dotenv import load_dotenv | |
load_dotenv() | |
MONGO_URI = os.environ.get("MONGO_URI") | |
warnings.filterwarnings("ignore") | |
# Tokenizer of the fine-tuned BLOOM 1.3B weather model (loaded once at import time).
tokenizer = AutoTokenizer.from_pretrained("crodri/bloom1.3_meteo")
from transformers import BitsAndBytesConfig
# 4-bit NF4 quantization settings. NOTE(review): currently unused — the
# quantization_config argument below is commented out, so the model is loaded
# unquantized on CPU. Kept so quantization can be re-enabled easily.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
# Causal LM used by givePrediction() to generate the final answer.
model_4bit = AutoModelForCausalLM.from_pretrained(
    "crodri/bloom1.3_meteo",
    model_type="BloomForCausalLM",
    device_map="cpu",
    # verbose=False,
    # quantization_config=quantization_config,
    trust_remote_code=True)
# #tokenizer = AutoTokenizer.from_pretrained(model_id)
# tokenizer = AutoTokenizer.from_pretrained(model_4bit)
# Shared text-generation pipeline (sampling with top-k=10, single sequence).
llm_pipeline = pipeline(
    "text-generation",
    model=model_4bit,
    tokenizer=tokenizer,
    use_cache=True,
    device_map="auto",
    #max_length=800,
    do_sample=True,
    top_k=10,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,  # BLOOM defines no pad token; EOS is reused
)
def retrieveFor(result):
    """Look up the weather record matching the NER entities in *result*.

    Parameters
    ----------
    result : list[dict]
        Output of the token-classification pipeline: dicts with at least
        ``entity_group``, ``score`` and ``word`` keys.

    Returns
    -------
    tuple | None
        ``(record, interval)`` where *record* is the MongoDB document for the
        detected location/day and *interval* is the detected interval entity
        (``'tot el dia'`` when none was found). ``None`` when location/day
        could not be extracted or no record matches.
    """
    def retrievehighest(key, result):
        # Highest-scoring entity of the requested group, or None when absent.
        # (Previously returned [] and relied on a downstream TypeError.)
        candidates = [x for x in result if x["entity_group"] == key]
        if not candidates:
            return None
        return max(candidates, key=lambda x: x['score'])['word'].strip()

    def removeend(frase):
        # Strip question marks and commas so the value matches the DB keys.
        return frase.replace("?", "").replace(",", "")

    intervalo = [x["word"].strip() for x in result if x["entity_group"] == "interval"]
    location = retrievehighest("location", result)
    day = retrievehighest("day", result)
    if location is None or day is None:
        print("No hem trobat el lloc o la data. Torna a provar")
        return None
    location = removeend(location)
    day = removeend(day)
    client = MongoClient(MONGO_URI)
    try:
        collection = client['aina']['new_ccma_meteo']
        record = collection.find({"location": location.strip(),
                                  "day": day.lower().strip()})
        try:
            # find() is lazy: the query actually runs on the first next().
            j = next(record)
        except Exception:
            # Best-effort like the original: no match (StopIteration) and
            # DB/connection errors all yield the same user message.
            print("No hem trobat el lloc o la data. Torna a provar")
            return None
        return (j, intervalo[0]) if intervalo else (j, 'tot el dia')
    finally:
        # The original leaked the client connection.
        client.close()
#context": "Day: dilluns | Location: Sant Salvador de Guardiola | mati: la nuvolositat anirà en augment | tarda: els núvols alts taparan el cel | nit: cel clar | temp: Lleugera pujada de les temperatures" | |
# NER pipeline: tags location / day / interval entities in the user question.
pipe = pipeline("token-classification", model="crodri/ccma_ner",aggregation_strategy='first')
# Intent classifier: decides whether the question is a weather query at all.
intent = pipeline("text-classification", model="projecte-aina/roberta-large-ca-v2-massive")
def pipeIt(jresponse):
    """Return the pre-built context string of the first record in *jresponse*.

    *jresponse* is the ``(record, interval)`` pair produced by ``retrieveFor``;
    the record's ``context`` field already holds the formatted weather summary,
    so no re-assembly is needed.
    """
    # Removed: an unused punctuation-stripping regex local and the dead,
    # commented-out manual context-building code.
    return jresponse[0]["context"]
#question = "Quin temps farà a la tarda a Algete dijous?" | |
def givePrediction(question, context, temperature, repetition):
    """Ask the LLM pipeline to answer *question* given *context*.

    *temperature* and *repetition* are forwarded to the generation call.
    Returns the text that follows the final '### Answer' marker, with the
    8-character ' Answer\\n' header stripped off.
    """
    prompt = f"### Instruction\n{question}\n\n### Context\n{context}\n\n### Answer\n"
    generated = llm_pipeline(
        prompt,
        temperature=temperature,
        repetition_penalty=repetition,
        max_new_tokens=100,
    )[0]["generated_text"]
    # Keep only the segment after the last '###', dropping ' Answer\n'.
    return generated.rsplit("###", 1)[-1][8:]
def assistant(question):
    """Route *question* by intent; return (record, interval) for weather queries.

    Greetings get a prompt to ask a weather question; other intents get an
    apology. Both of those paths — and a weather query with no DB match —
    return None.
    """
    label = intent(question)[0]['label']
    if label == 'weather_query':
        entities = pipe(question)
        match = retrieveFor(entities)
        if match:
            print("Context: ", match[0]['context'])
            print()
            return match
    elif label in ("general_greet", "general_quirky"):
        print("Hola, quina es la teva consulta meteorològica?")
    else:
        print(label)
        print("Ho sento. Jo només puc respondre a preguntes sobre el temps a alguna localitat en concret ...")
    return None
def generate(question, temperature, repetition):
    """End-to-end run: route the question, generate an answer, print both.

    Returns a dict with the lookup context, the CCMA reference response and
    the model's answer, or None when the question could not be resolved.
    """
    match = assistant(question)
    if not match:
        print("No response")
        return None
    record, interval = match
    context = {"codis": record['codis'], "interval": interval}
    ccma_response = record['response']
    answer = givePrediction(question, context, temperature, repetition)
    print("CCMA generated: ", ccma_response)
    print("=" * 16)
    print("LLM answer: ", answer)
    print()
    return {"context": context, "ccma_response": ccma_response, "model_answer": answer}
def main():
    """CLI entry point: parse -q/-t/-r options and run one question.

    NOTE(review): optparse has been deprecated since Python 3.2; argparse is
    the modern replacement, kept as-is here to avoid changing the interface.
    """
    parser = OptionParser()
    parser.add_option("-q", "--question", dest="question", type="string",
                      help="question to test", default="Quin temps farà a la tarda a Algete dijous?")
    parser.add_option("-t", "--temperature", dest="temperature", type="float",
                      help="temperature generation", default=1.0)
    parser.add_option("-r", "--repetition", dest="repetition", type="float",
                      help="repetition penalty", default=1.0)
    # BUG FIX: parse_args() defaults to sys.argv[1:]; the original passed
    # sys.argv, so the script name leaked in as a stray positional argument.
    options, _args = parser.parse_args()
    print(options)
    generate(options.question, options.temperature, options.repetition)
if __name__ == "__main__":
    main()