# ChatLab / app_MVP.py
import os
import streamlit as st
import requests
# Set up Hugging Face API details
token = os.getenv("HF_TOKEN", None)
headers = {"Authorization": f"Bearer {token}"}
API_URL = "https://api-inference.huggingface.co/models/facebook/bart-large-cnn"
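# Note: HF_TOKEN is expected to come from the environment (for example a Space secret).
# If it is missing, requests to the Inference API may be rejected or rate-limited.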
# Title and description for this particular project
st.title("Large Language Model using Inference API")
st.write("This project will show how Inference API and Bart LLM uses text summarization.")
st.write("It is very simple implementation, and other models can be used.")
# Function to query the Hugging Face model
def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()
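# Example usage (a sketch of the expected shapes, based on the checks performed below):
#   query({"inputs": "Some long article text ...", "parameters": {"do_sample": False}})
# normally returns a list such as [{"summary_text": "..."}], while an error or a model
# that is still loading returns a dict like {"error": "..."}.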
# Input textbox to introduce prompt
user_input = st.text_input("You:", "")
# Submit button to run the inference API
if st.button("Send"):
    if user_input.strip() != "":
        # Query the Hugging Face model
        data = query({"inputs": user_input, "parameters": {"do_sample": False}})
        # Display the response; the summarization endpoint returns a list of
        # {"summary_text": ...} dicts, so anything else is treated as an error
        if isinstance(data, list) and data and "summary_text" in data[0]:
            st.text_area("Bot:", value=data[0]["summary_text"], height=150)
        else:
            st.error("No response from the model")
# Model selection
model = st.radio(
    "Model",
    [
        "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
        "OpenAssistant/oasst-sft-1-pythia-12b",
        "google/flan-t5-xxl",
        "google/flan-ul2",
        "bigscience/bloom",
        "bigscience/bloomz",
        "EleutherAI/gpt-neox-20b",
    ],
)
# Input textbox
input_text = st.text_input(label="Type an input and press Enter", placeholder="What is Deep Learning?")
# Parameters
with st.expander("Parameters", expanded=False):
    typical_p = st.slider("Typical P mass", min_value=0.0, max_value=1.0, value=0.2, step=0.05)
    top_p = st.slider("Top-p (nucleus sampling)", min_value=0.0, max_value=1.0, value=0.25, step=0.05)
    temperature = st.slider("Temperature", min_value=0.0, max_value=5.0, value=0.6, step=0.1)
    top_k = st.slider("Top-k", min_value=1, max_value=50, value=50, step=1)
    repetition_penalty = st.slider("Repetition Penalty", min_value=0.1, max_value=3.0, value=1.03, step=0.01)
    watermark = st.checkbox("Text watermarking", value=False)
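# predict() is called below but not defined in this file. The following is a minimal
# sketch of what it might look like, assuming the selected model is queried through the
# same Inference API with text-generation style parameters; the parameter names and
# response shape used here are assumptions, not part of the original app.
def predict(model, input_text, typical_p, top_p, temperature, top_k, repetition_penalty, watermark):
    # Build the endpoint URL for the selected model
    model_url = f"https://api-inference.huggingface.co/models/{model}"
    payload = {
        "inputs": input_text,
        "parameters": {
            "typical_p": typical_p,
            "top_p": top_p,
            "temperature": temperature,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
            "watermark": watermark,
        },
    }
    response = requests.post(model_url, headers=headers, json=payload)
    data = response.json()
    # Text-generation endpoints typically return a list like [{"generated_text": "..."}]
    if isinstance(data, list) and data and "generated_text" in data[0]:
        st.text_area("Bot:", value=data[0]["generated_text"], height=150)
    else:
        st.error("No response from the model")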
# Submit button
if st.button("Submit"):
    # Perform prediction
    predict(model, input_text, typical_p, top_p, temperature, top_k, repetition_penalty, watermark)
# Reset button (currently disabled)
# if st.button("Reset"):
#     input_text = ""
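# To try the app locally (assuming Streamlit and requests are installed):
#   streamlit run app_MVP.py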