Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import joblib | |
| import numpy as np | |
| import pandas as pd | |
| from sentence_transformers import SentenceTransformer | |
| from collections import OrderedDict | |
class RiasecPredictor:
    """Predict RIASEC scores (R, I, A, S, E, C) for free-text job descriptions.

    Pipeline: the text is embedded with a SentenceTransformer model, the
    regressor predicts in the scaled target space, and the saved scaler maps
    the prediction back to the original scale, which is clipped to [1.0, 7.0].
    """

    def __init__(self, regressor_path='riasec_regressor_v1.pkl',
                 scaler_path='riasec_scaler.pkl',
                 embedding_model_path='all-MiniLM-L6-v2'):
        """Load the embedding model, the regressor and the target scaler.

        Args:
            regressor_path: Path to the joblib-serialized regressor.
            scaler_path: Path to the joblib-serialized scaler whose
                ``inverse_transform`` maps predictions back to the 1-7 scale.
            embedding_model_path: Model name or path given to ``SentenceTransformer``.

        Raises:
            FileNotFoundError: If the scaler file is missing (chained to the
                original error so the underlying path problem stays visible).
        """
        print("Loading models...")
        self.embedding_model = SentenceTransformer(embedding_model_path)
        # NOTE: joblib.load unpickles arbitrary objects; only load trusted artifacts.
        self.regressor = joblib.load(regressor_path)
        try:
            self.scaler = joblib.load(scaler_path)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Scaler file not found at {scaler_path}. "
                                    "Did you save it during training?") from err
        # Canonical output order; every ordered result below follows it.
        self.riasec_labels = ['R', 'I', 'A', 'S', 'E', 'C']
        self.code_to_name = {
            'R': 'Realistic', 'I': 'Investigative', 'A': 'Artistic',
            'S': 'Social', 'E': 'Enterprising', 'C': 'Conventional'
        }
        print("✅ Models and scaler loaded successfully!")

    def predict(self, job_title=None, job_description=None, full_text=None, sort_by_score=True):
        """Predict RIASEC scores for one job, on the original 1-7 scale.

        Args:
            job_title: Job title; used together with ``job_description``.
            job_description: Job description; used together with ``job_title``.
            full_text: Complete text to score; takes precedence over title/description.
            sort_by_score: If True, return an ``OrderedDict`` sorted by score
                (highest first); otherwise a plain dict in R-I-A-S-E-C order.

        Returns:
            Mapping of RIASEC code -> float score clipped to [1.0, 7.0].

        Raises:
            ValueError: If neither ``full_text`` nor both title and description are given.
        """
        if full_text is not None:
            text = full_text
        elif job_title is not None and job_description is not None:
            text = f"{job_title} {job_description}"
        else:
            raise ValueError("Provide either full_text OR both job_title and job_description")

        embedding = self.embedding_model.encode([text], convert_to_numpy=True)
        # First (only) row of the regressor output, still in scaled space;
        # asarray makes the reshape below work even if the model returns a list.
        prediction_scaled = np.asarray(self.regressor.predict(embedding)[0])
        # The scaler expects a 2-D (n_samples, n_targets) array.
        prediction = self.scaler.inverse_transform(prediction_scaled.reshape(1, -1))[0]
        prediction = np.clip(prediction, 1.0, 7.0)  # enforce the valid 1-7 range

        riasec_dict = dict(zip(self.riasec_labels, prediction.tolist()))
        if sort_by_score:
            return OrderedDict(sorted(riasec_dict.items(), key=lambda item: item[1], reverse=True))
        return riasec_dict

    def predict_with_names(self, job_title=None, job_description=None, full_text=None):
        """Predict scores keyed by full type name (e.g. 'Realistic'), in R-I-A-S-E-C order.

        Accepts the same inputs as :meth:`predict` and raises the same ``ValueError``.
        """
        results = self.predict(job_title, job_description, full_text, sort_by_score=False)
        # Reuse self.riasec_labels so the order cannot drift from predict().
        return OrderedDict(
            (self.code_to_name[code], results[code]) for code in self.riasec_labels
        )
# Initialize predictor
# Module-level instance used by predict_riasec below. It is built at import
# time with RiasecPredictor's default file paths, so importing this module
# loads the models and fails if those files are missing.
predictor = RiasecPredictor()
def predict_riasec(job_title, job_description):
    """Gradio callback: score a job and format the result for the UI.

    Args:
        job_title: Text from the "Job Title" textbox; ``None`` is treated as empty.
        job_description: Text from the "Job Description" textbox; ``None`` is
            treated as empty.

    Returns:
        ``(bar_data, top_3_markdown)``. ``bar_data`` is a DataFrame with
        ``Category``/``Score`` columns in R-I-A-S-E-C order, and
        ``top_3_markdown`` lists the three highest-scoring codes as styled
        HTML cards. On missing input or any error, ``bar_data`` is ``None`` and
        the second element is a user-facing message.
    """
    try:
        # Guard against None so a missing value yields the friendly prompt
        # instead of an "Error: 'NoneType' object has no attribute 'strip'".
        job_title = job_title or ""
        job_description = job_description or ""
        if not job_title.strip() or not job_description.strip():
            return None, "Please provide both job title and job description."

        result = predictor.predict(
            job_title=job_title,
            job_description=job_description,
            sort_by_score=False,  # keep R-I-A-S-E-C order for the chart
        )

        # Single source of truth for the category order (abbreviated codes).
        codes = list(predictor.riasec_labels)
        bar_data = pd.DataFrame({
            "Category": codes,
            "Score": [result[code] for code in codes],
        })

        # Highest three codes; sorted() is stable, so ties keep R-I-A-S-E-C order.
        top_codes = sorted(result, key=result.get, reverse=True)[:3]
        card_style = (
            "font-size: 1.5em; font-weight: bold; margin: 5px 0; padding: 10px; "
            "background-color: #f0f0f0; color: #000000; border-radius: 5px; "
            "text-align: center; border: 1px solid #cccccc;"
        )
        top_3_result = "### Top 3 RIASEC Types\n\n" + "".join(
            f"<div style='{card_style}'>{code}</div>\n" for code in top_codes
        )
        return bar_data, top_3_result
    except Exception as e:
        # UI boundary: show the message to the user, but keep the traceback
        # in the server log so failures are debuggable.
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"
# Updated Gradio UI
# Builds the page: a header, then one row of three columns
# (inputs | score bar chart | top-3 summary), then a note, the button wiring,
# and clickable examples. Component placement follows creation order.
# NOTE(review): the scraped source lost its indentation, so this nesting is a
# reconstruction. Confirm whether the "Note: ..." Markdown, the click wiring
# and gr.Examples sat at Blocks level (as here) or inside the third column.
with gr.Blocks(title="RIASEC Predictor") as demo:
    gr.Markdown("# RIASEC Predictor")
    gr.Markdown("Predict RIASEC personality type scores for job descriptions")
    with gr.Row():
        # Column 1: the two text inputs and the trigger button.
        with gr.Column():
            job_title = gr.Textbox(label="Job Title", placeholder="e.g., Data Scientist")
            job_description = gr.Textbox(label="Job Description", placeholder="e.g., Analyze large datasets...", lines=4)
            submit_btn = gr.Button("Predict RIASEC Scores", variant="primary")
        # Column 2: bar chart. x/y name the "Category"/"Score" columns of the
        # DataFrame that predict_riasec returns.
        # NOTE(review): check that vertical/show_legend are accepted by the
        # installed Gradio version's BarPlot.
        with gr.Column():
            output_chart = gr.BarPlot(
                x="Category",
                y="Score",
                title="RIASEC Scores",
                vertical=False,  # Horizontal bars
                tooltip=["Category", "Score"],
                show_legend=False,
                height=400
            )
        # Column 3: Markdown area that receives the top-3 HTML cards.
        # elem_classes only matters if CSS targets it; none is passed to
        # gr.Blocks in the visible source.
        with gr.Column():
            top_3_output = gr.Markdown(label="Top 3 RIASEC", elem_classes="top-3-riasec")
    gr.Markdown("Note: Please provide both job title and job description.")
    # The outputs order must match predict_riasec's (bar_data, top_3_result) return tuple.
    submit_btn.click(
        fn=predict_riasec,
        inputs=[job_title, job_description],
        outputs=[output_chart, top_3_output],
        show_progress=True
    )
    # Sample (title, description) pairs wired to the same callback and outputs.
    # NOTE(review): cache_examples=False presumably means example outputs are
    # not precomputed at startup; confirm for the installed Gradio version.
    gr.Examples(
        examples=[
            ["Data Scientist", "Analyze large datasets and build machine learning models"],
            ["Graphic Designer", "Create visual content and design marketing materials"],
            ["Software Engineer", "Develop and maintain software applications"]
        ],
        inputs=[job_title, job_description],
        outputs=[output_chart, top_3_output],
        fn=predict_riasec,
        cache_examples=False,
    )
if __name__ == "__main__":
    # Enable the request queue on the Blocks app, then start the server with a
    # public share link.
    queued_app = demo.queue()
    queued_app.launch(share=True)