# gbif-tools / app.py
# Author: cboettig
# Commit 4dcc7af: "here we go!"
# +
import os

import streamlit as st
import leafmap.maplibregl as leafmap
import numpy as np
from matplotlib import cm
import pandas as pd
import ibis
from ibis import _
from langchain_anthropic import ChatAnthropic
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import ConfigurableField
from langchain_core.tools import tool
from langchain.agents import create_tool_calling_agent, AgentExecutor
# Location of the GBIF GeoParquet file (presumably H3-indexed, per the variable
# name). The default is a machine-specific path; set the GBIF_PARQUET
# environment variable to run the app against a copy stored elsewhere.
h3_parquet = os.environ.get(
    "GBIF_PARQUET",
    "/home/rstudio/huggingface/spaces/boettiger-lab/gbif/gbif_ca.geoparquet",
)
# DuckDB database file backing the ibis connection.
con = ibis.duckdb.connect("duck.db")
# Register the parquet data under the table name "gbif", the name the SQL
# written by the agent refers to.
gbif = con.read_parquet(h3_parquet, "gbif")
@tool
def sql_query(sql: str):
    """Execute a DuckDB SQL query against the 'gbif' table and return at most 100 rows."""
    # The docstring is the tool description the LLM sees. The `str` hint gives
    # the generated argument schema an explicit string type; without it the
    # argument had no declared type.
    try:
        return gbif.sql(sql).limit(100).execute()
    except Exception as err:
        # Hand the error back to the agent so it can revise its SQL, instead
        # of letting the exception abort the whole Streamlit run.
        return f"Query failed: {err}"
# +
# Browser-tab title plus wide layout, then the page banner with a rainbow rule.
st.set_page_config(layout="wide", page_title="GBIF Observations Explorer")
st.header(body="GBIF Observations Explorer", divider="rainbow")
# -
def get_h3point_df(_df, zoom: float) -> pd.DataFrame:
    """Count observations per H3 cell and attach colour and height values.

    ``_df`` is an ibis table expression. It is expected to hold one H3 cell
    column per resolution, named ``h<zoom>`` (e.g. ``h6``). That is an
    assumption drawn from the column name built below; confirm it against
    the parquet schema.

    ``zoom`` is the H3 resolution. Whole numbers are accepted as int, float or
    numeric string (6, 6.0 or "6").

    Returns a pandas DataFrame with one row per cell and these columns:
    ``hex`` (cell id), ``n`` (observation count), ``v`` (natural log of ``n``),
    ``normalized_values`` (``v`` scaled into [0, 1]) and ``rgb`` (viridis RGBA
    colour as a list of four ints in 0-255).

    Raises:
        ValueError: if ``zoom`` is not a whole number.
    """
    # "h" + str(6.0) would produce the non-existent column "h6.0", so
    # normalise the resolution to an int first.
    resolution = float(zoom)
    if not resolution.is_integer():
        raise ValueError(f"zoom must be a whole-number H3 resolution, got {zoom!r}")
    column = f"h{int(resolution)}"

    df = (
        _df
        .rename(hex=column)
        .group_by(_.hex)
        .agg(n=_.count())
        .to_pandas()
    )

    # Log-scale the counts so a few very dense cells do not wash out the rest.
    v = np.log(df["n"].to_numpy(dtype=float))
    vmax = v.max() if v.size else 0.0
    df["v"] = v
    # When every cell holds exactly one observation, log(1) == 0 everywhere
    # and v / v.max() would be 0/0. That gives NaN heights and "bad"
    # (transparent) colours, so fall back to zeros instead.
    df["normalized_values"] = v / vmax if vmax > 0 else np.zeros_like(v)

    rgba = cm.viridis(df["normalized_values"].to_numpy())
    df["rgb"] = np.round(rgba * 255).astype(int).clip(0, 255).tolist()
    return df
import pydeck as pdk
def hex_layer(m, df: pd.DataFrame, v_scale=1):
    """Add ``df`` to map ``m`` as an extruded deck.gl H3HexagonLayer.

    ``df`` must supply the columns ``hex`` (H3 cell id), ``rgb`` (fill colour)
    and ``normalized_values`` (height source), as built by ``get_h3point_df``.
    ``v_scale`` is passed as the layer's elevation scale, multiplying heights.

    Returns whatever ``m.add_deck_layers`` returns.
    """
    layer_options = dict(
        get_hexagon="hex",
        get_fill_color="rgb",
        get_elevation="normalized_values",
        extruded=True,
        elevation_scale=v_scale,
        elevation_range=[0, 1],
    )
    hexagons = pdk.Layer("H3HexagonLayer", df, **layer_options)
    return m.add_deck_layers([hexagons])
@tool
def maplibre_plot(sql: str, zoom: int = 4, vertical_exaggeration: float = 0):
    """Plot the rows selected by a DuckDB SQL query on the 'gbif' table as H3 hexagons on the map.

    sql: full SQL query, including the table name gbif.
    zoom: whole-number H3 resolution at which observations are aggregated.
    vertical_exaggeration: multiplier for hexagon heights (0 draws them flat).
    """
    # The docstring is the tool description the LLM sees. The type hints become
    # the tool's argument schema, so values like "6" should be coerced to int.
    try:
        cells = get_h3point_df(gbif.sql(sql), zoom=zoom)
    except Exception as err:
        # Report query/column errors back to the agent so it can retry,
        # rather than the exception aborting the Streamlit run.
        return f"Query failed: {err}"
    # `m` is the module-level leafmap Map created for the page.
    return hex_layer(m, cells, vertical_exaggeration)
# Tool-calling chat model served through Ollama; temperature=0 minimises
# sampling randomness in the generated SQL and tool calls.
llm = ChatOllama(temperature=0, model="llama3-groq-tool-use:70b")
# Column names/types of the registered "gbif" table. Passed as "schema" when
# the agent is invoked, filling the {schema} slot in the system prompt.
schema = con.table("gbif").schema()
# Agent prompt. ChatPromptTemplate formats these messages as templates:
#   {schema}            - filled from the "schema" key passed to agent_executor.invoke
#   {input}             - the user's chat message
#   {agent_scratchpad}  - placeholder that create_tool_calling_agent fills with the
#                         agent's intermediate tool calls and results
# Any literal curly brace added to the system text must be doubled ({{ }}),
# otherwise the template will treat it as a variable.
# NOTE(review): the system text states the defaults (zoom 6, vertical
# exaggeration 1) twice, and they differ from maplibre_plot's own parameter
# defaults (zoom=4, vertical_exaggeration=0). The model is relied on to pass
# them explicitly.
prompt = ChatPromptTemplate.from_messages([
("system", '''
You're a helpful assistant working with data from the Global Biodiversity Information Facility, GBIF. You are an expert in duckdb SQL.
Your job is to provide answers about GBIF data by using the tools I provide. Assume any question about species, animals, plants, etc can be answered
using the data in the gbif table provided.
The schema of the table is {schema}
You will construct appropriate SQL queries that can be run against this table to answer the user's questions. If the user asks you to "show" the
results on a map, you should call the maplibre_plot() tool with your query. If the user asks for information that would not require a map, such
as "the number of unique species" or other summary statistic, use the sql_query() tool.
For example, if the user says "show all mammals on the map", you will produce the SQL query to filter the data to include only mammals:
SELECT * FROM gbif WHERE class='Mammalia'
Because the user asked for a map, you will call the tool, maplibre_plot() with that sql query. If the user just names a species or taxonomic group without
saying what they want you to do, just
write the SQL query that will filter for that specific group. Think carefully, you will need to determine the appropriate taxonomic rank or ranks to
use (kingdom, phylum, class, order), and provide the scientific name of the group, e.g. class='Mammalia'. For specific species, use the binomial scientific
name. For example if the user just writes "wolves", run:
SELECT * FROM gbif WHERE scientificname = 'Canis lupus'
using the maplibre_plot() tool.
If you do not have enough information, please ask the user clarifying questions.
When using the maplibre_plot() tool, you can control the "zoom" (or resolution at which the data is shown) and the vertical_exaggeration of the results.
If the user does not specify a zoom, please choose zoom 6. If not specified, the vertical exaggeration should be 1.
You can call the sql_query() tool to help you explore the data better before you try and answer the question. For instance, you may want to use the
tool to call sql_query("select * from gbif limit 1;") to examine the table schema before formulating your query.
If not specified, the vertical_exaggeration should be 1 and the zoom should be 6.
If not specified, the table name is 'gbif'. Include the table name in the SQL query, it is not a parameter.
'''
),
("human", "{input}"),
("placeholder", "{agent_scratchpad}"),
])
# The two tools the agent may call; it chooses between them using each
# tool's name and docstring description.
tools = [sql_query, maplibre_plot]
agent = create_tool_calling_agent(llm=llm, tools=tools, prompt=prompt)
# verbose=True makes the executor log its intermediate steps.
agent_executor = AgentExecutor(tools=tools, agent=agent, verbose=True)
# Placeholder text shown in the empty chat box.
example = "filter the gbif table to show only bird species. visualize results at zoom 10 and set the vertical exaggeration to 5000."
# The map must exist before the agent runs: the maplibre_plot tool draws onto
# this module-level `m`.
m = leafmap.Map(style="positron", center=(-121.4, 37.50), zoom=8)
with st.container():
    st.markdown("🦜 Or try our chat-based query:")
    # Bound to `user_input` rather than `prompt` so the user's text does not
    # shadow the module-level ChatPromptTemplate named `prompt`.
    if user_input := st.chat_input(example, key="chain"):
        st.chat_message("user").write(user_input)
        with st.chat_message("assistant"):
            # "schema" fills the {schema} placeholder in the system prompt.
            out = agent_executor.invoke({"input": user_input, "schema": schema})
            st.write(out["output"])
# Render the map after the agent has had a chance to add layers to it.
m.to_streamlit()
st.divider()
# A bare string literal in a Streamlit script is displayed as Markdown by
# Streamlit's "magic", so this renders the credits section at the page bottom.
'''
## Credits
DRAFT. Open Source Software developed at UC Berkeley.
'''