kitrofimov commited on
Commit
645bc1a
·
1 Parent(s): 512f7dc

Fix label names

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -2,10 +2,11 @@ import gradio as gr
2
  from transformers import pipeline
3
 
4
  pipe = pipeline("text-classification", model="kitrofimov/news-clf", top_k=3)
 
5
 
6
  def classify(text):
7
  preds = pipe(text)[0]
8
- return {p["label"]: float(p["score"]) for p in preds}
9
 
10
  with gr.Blocks() as demo:
11
  gr.Markdown("# News Classifier")
 
2
  from transformers import pipeline
3
 
4
  pipe = pipeline("text-classification", model="kitrofimov/news-clf", top_k=3)
5
+ label_names = dict(zip([f"LABEL_{i}" for i in range(4)], ["World", "Sports", "Business", "Sci/Tech"]))
6
 
7
  def classify(text):
8
  preds = pipe(text)[0]
9
+ return {label_names[p["label"]]: float(p["score"]) for p in preds}
10
 
11
  with gr.Blocks() as demo:
12
  gr.Markdown("# News Classifier")