import json
from pathlib import Path

import gradio as gr
from pyvis.network import Network
from transformers import pipeline

# Load the text-generation pipeline once at import time; every request reuses it.
tourModel = pipeline(model="manhan/GPT-Tour")

# File written by vizTour() and read back for the HTML output component.
TOUR_HTML_PATH = 'tour.html'

# Upper bound on sampling retries so a request can never spin forever when the
# model keeps producing text that fails the checks in getTour().
MAX_GENERATION_ATTEMPTS = 25

# Income brackets as (legacy category-code bound, dollar bound, label); both
# bounds are exclusive.  The dollar figures are the ones the original code
# documented next to each category-code threshold.
_INCOME_BRACKETS = (
    (4, 25_000, 'poor'),
    (6, 50_000, 'low'),
    (7, 75_000, 'medium'),
    (9, 125_000, 'high'),
)

# Values below this are interpreted as legacy category codes rather than
# dollars, so callers that still pass a code keep getting the same label.
_LEGACY_INCOME_CODE_LIMIT = 20


def _income_bracket(income):
    """Return the income label for *income*, or None for negative (bad) data.

    The UI asks for an annual income in dollars, so dollar amounts are mapped
    with the dollar thresholds.  Small values (< 20) are still treated as the
    category codes the original thresholds were written for.
    Raises ValueError when the value cannot be parsed as an integer.
    """
    if isinstance(income, str):
        # Tolerate the usual ways people type money: "$50,000".
        income = income.replace('$', '').replace(',', '').strip()
    amount = int(income)
    if amount < 0:
        return None  # bad data: leave the attribute out of the prompt
    use_codes = amount < _LEGACY_INCOME_CODE_LIMIT
    for code_bound, dollar_bound, label in _INCOME_BRACKETS:
        if amount < (code_bound if use_codes else dollar_bound):
            return label
    return 'affluent'


def _size_bracket(size):
    """Return the household-size label for a head count."""
    hh_size = int(size)
    if hh_size == 1:
        return 'single'
    if hh_size == 2:
        return 'couple'
    if hh_size <= 4:
        return 'small'
    return 'large'  # more than four people


def _age_bracket(years):
    """Return the age-group label for an age in years."""
    age = int(years)
    if age < 18:
        return 'child'
    if age < 45:
        return 'younger'
    if age < 65:
        return 'older'
    return 'senior'


def _extract_list_text(generated):
    """Return the first '[...]' span of *generated*, or '' if there is none.

    The closing bracket is searched for *after* the opening one, so a stray
    ']' earlier in the text cannot produce an empty or reversed slice.
    """
    start = generated.find('[')
    if start == -1:
        return ''
    end = generated.find(']', start)
    if end == -1:
        return ''
    return generated[start:end + 1]


def _read_tour_html():
    """Return the contents of the rendered tour file, or '' if it is absent."""
    path = Path(TOUR_HTML_PATH)
    if not path.is_file():
        return ''
    # Default encoding, matching the plain open() call this replaces.
    return path.read_text()


def vizTour(activityList):
    """Render *activityList* as a directed graph and write it to tour.html.

    Returns True when the file was written, False when the list is empty
    (there is nothing to draw and no first activity to anchor the graph).
    """
    if not activityList:
        return False
    origin = activityList[0]
    net = Network(directed=True)
    net.add_node(origin)
    for activity in activityList[1:]:
        if activity not in net.get_nodes():
            net.add_node(activity)
        # NOTE(review): every edge starts at the first activity rather than
        # linking consecutive stops; kept as in the original - confirm intent.
        net.add_edge(origin, activity)
    # save_graph() only writes the file; show() would additionally try to open
    # a browser / notebook frame, which is unwanted inside a web app.
    net.save_graph(TOUR_HTML_PATH)
    return True


def getTour(income, size, years, sex, edu, wrk):
    """Generate an activity tour for one traveler.

    Args:
        income: annual household income in dollars (int or numeric string).
        size: household size, number of people.
        years: traveler age in years.
        sex, edu, wrk: categorical attributes passed through to the prompt.

    Returns:
        (activity_list, html): the generated list of activities and the HTML
        of the rendered graph.  Both are empty ([] and '') when no valid tour
        was produced within MAX_GENERATION_ATTEMPTS tries.

    Raises:
        ValueError: if income, size or years is not an integer value.
    """
    # person = a dict with person-level and hh-level attributes.  Insertion
    # order matters: the dict is serialized verbatim into the model prompt.
    person = {}
    hh_inc = _income_bracket(income)
    if hh_inc is not None:
        person['hh_inc'] = hh_inc
    person['hh_size'] = _size_bracket(size)
    person['age_grp'] = _age_bracket(years)
    person['sex'] = sex
    person['edu'] = edu
    person['wrk'] = wrk

    # Drop the closing brace so the model continues the JSON-like record.
    prompt = json.dumps(person)[:-1] + ", pattern: "
    print(person)

    activity_list = []
    for _attempt in range(MAX_GENERATION_ATTEMPTS):
        generated = tourModel(prompt, return_full_text=False, max_length=250, temperature=0.9)[0]['generated_text']
        activity_list_str = _extract_list_text(generated)
        print(f"Extracted: '{activity_list_str}'")

        # Reject tours that contradict the worker status.  (Original author's
        # note: this filter did not appear to be effective.)
        mentions_work = 'Work' in activity_list_str
        if person['wrk'] == 'yes' and not mentions_work:
            continue  # try again
        if person['wrk'] == 'no' and mentions_work:
            continue  # try again

        if not activity_list_str:
            print("Nothing extracted!")
            continue

        try:
            candidate = json.loads(activity_list_str)
        except ValueError as e:  # json.JSONDecodeError is a ValueError
            print("Error parsing activity list")
            print(e)
            continue

        # A usable tour is a non-empty list of activity names ending at home.
        if (candidate
                and all(isinstance(stop, str) for stop in candidate)
                and candidate[-1] == 'Home'):
            activity_list = candidate
            break

    if not activity_list:
        print(f"No valid tour after {MAX_GENERATION_ATTEMPTS} attempts")
        return activity_list, ''

    if not vizTour(activity_list):
        return activity_list, ''
    # The second output is an HTML component, so hand it the rendered page
    # rather than a success flag.
    return activity_list, _read_tour_html()


with gr.Interface(fn=getTour, inputs=[
        gr.inputs.Textbox(label="Annual Household Income (in dollars)"),
        gr.inputs.Textbox(label="Household Size (number of people)"),
        gr.inputs.Textbox(label="Traveler Age (years)"),
        gr.inputs.Dropdown(["unknown", "male", "female"], label="Gender/sex"),
        gr.inputs.Dropdown(["unknown", "grade school","highschool", "associates", "bachelors", "graduate"], label="Educational attainment level"),
        gr.inputs.Dropdown(["unknown", "yes","no"], label="Worker status")],
        # Show the last rendered tour if one exists; start empty otherwise
        # instead of failing on a missing file at startup.
        outputs=["json", gr.HTML(value=_read_tour_html())],
        title="GPT-Tour",
        description="Generate a tour for a person",
        allow_flagging=False,
        allow_screenshot=False,
        allow_embedding=False) as iface:
    iface.launch()