# -*- coding: utf-8 -*-
"""llama2
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/15UK6iHd1y0pMMQc-DbZIYhSteoTCUsMH
"""
# Install dependencies, then restart the runtime before running the rest.
!pip install accelerate
!pip install bitsandbytes
!pip install optimum
!pip install auto-gptq
!pip install transformers huggingface_hub

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import re
import requests

from huggingface_hub import notebook_login

# Log in to the Hugging Face Hub (needed for gated models such as Llama 2).
notebook_login()
mn = 'stabilityai/StableBeluga-7B'
#mn = "TheBloke/Llama-2-7b-Chat-GPTQ"

# Load the model on GPU 0, quantized to 8 bits to fit in Colab GPU memory.
model = AutoModelForCausalLM.from_pretrained(mn, device_map=0, load_in_8bit=True)
# Alternative: full fp16 weights (needs more VRAM).
#model = AutoModelForCausalLM.from_pretrained(mn, device_map=0, torch_dtype=torch.float16)
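# Optional sanity check (assumes a CUDA runtime): confirm where the weights
# landed and roughly how much GPU memory the 8-bit model occupies.
print(model.device)
print(f"{model.get_memory_footprint() / 1e9:.2f} GB")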
sb_sys = "### System:\nYou are an AI driving assistant in my car that follows instructions extremely well. Help as much as you can.\n\n"
# Load the matching tokenizer before defining the generation helper.
tokr = AutoTokenizer.from_pretrained(mn)

def gen(p, maxlen=15, sample=True):
    # Tokenize the prompt, generate up to `maxlen` new tokens on the GPU,
    # then move the result back to the CPU for decoding.
    toks = tokr(p, return_tensors="pt")
    res = model.generate(**toks.to("cuda"), max_new_tokens=maxlen, do_sample=sample).to('cpu')
    return tokr.batch_decode(res)
# Build a prompt in the specific format expected by the Stable Beluga fine-tune.
def mk_prompt(user, syst=sb_sys): return f"{syst}### User: {user}\n\n### Assistant:\n"
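# A quick check of the rendered template (the user text here is made up):
#   ### System:
#   You are an AI driving assistant in my car that follows instructions ...
#
#   ### User: Where can I park?
#
#   ### Assistant:
print(mk_prompt("Where can I park?"))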
complete_answer = ''

# Attempt to get the user's approximate location from their IP address.
response = requests.get("http://ip-api.com/json/")
data = response.json()
print(data['city'], data['lat'], data['lon'])
city = data['city']
lat = data['lat']
lon = data['lon']
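# ip-api.com is a free, unauthenticated endpoint and can time out or be
# rate-limited; a more defensive variant (the fallback values are made up)
# could look like:
#try:
#    data = requests.get("http://ip-api.com/json/", timeout=5).json()
#except requests.RequestException:
#    data = {'city': 'unknown', 'lat': 0.0, 'lon': 0.0}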
model_answer = ''
general_context = f'I am in my car in {city}, latitude {lat}, longitude {lon}, and I can drive my car to reach a destination'

# Pattern to pull the assistant's reply out of the decoded output
# (see the note below on why the newline is escaped twice).
pattern = r"Assistant:\\n(.*?)</s>"
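# Why the doubled backslash: gen() returns a list of strings, and str(list)
# renders each string's repr, so a real newline shows up as the two literal
# characters "\n". The pattern therefore matches a backslash followed by "n".
# A small self-contained check (the reply text is made up):
demo = str(["### Assistant:\nSure, take the next exit.</s>"])
assert re.search(pattern, demo, re.DOTALL).group(1) == "Sure, take the next exit."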
ques = "I hate pizzas"

ques_ctx = f"""Answer the question with the help of the provided context.
## Context
{general_context}.
## Question
{ques}"""

# Generate up to 150 new tokens; gen() returns a list, so stringify it
# for the regex search below.
complete_answer = str(gen(mk_prompt(ques_ctx), 150))
match = re.search(pattern, complete_answer, re.DOTALL)
if match:
    # Extract the assistant's reply from the decoded output.
    model_answer = match.group(1)
else:
    model_answer = "There has been an error with the generated response."

# Fold the answer back into the context so later questions can build on it.
general_context += model_answer
print(model_answer)
print(complete_answer)
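# Because general_context now includes the model's last answer, a second
# question can refer back to it. A minimal sketch reusing the same helpers
# (the follow-up text is made up):
ques2 = "Then suggest a restaurant near me instead"
ques2_ctx = f"""Answer the question with the help of the provided context.
## Context
{general_context}.
## Question
{ques2}"""
follow_up = str(gen(mk_prompt(ques2_ctx), 150))
match = re.search(pattern, follow_up, re.DOTALL)
if match:
    print(match.group(1))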