|
|
import os
import pickle

import numpy as np
import tensorflow as tf
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
|
|
|
|
|
|
|
# Artifact locations and the model's fixed input length.
# Each value can be overridden by an environment variable named after it in
# upper case. The defaults are the original hard-coded values. Relative paths
# resolve against the process's current working directory, not this file.
tokenizer_file_path = os.environ.get('TOKENIZER_FILE_PATH', 'tokenizer.pkl')
model_dir_path = os.environ.get('MODEL_DIR_PATH', 'keras_model_savedmodel')
# Sequences are padded/truncated to this length before inference. It presumably
# must match the input length the SavedModel was exported with (TODO confirm).
max_sequence_length = int(os.environ.get('MAX_SEQUENCE_LENGTH', '12'))

# Loaded once at import time, so every request reuses the same tokenizer object.
# SECURITY: pickle.load can execute arbitrary code embedded in the file. Only
# point this at a trusted, locally produced artifact, never at uploaded data.
with open(tokenizer_file_path, 'rb') as handle:
    tokenizer = pickle.load(handle)

# Wrap the SavedModel's 'serve' endpoint as a callable Keras layer.
# It is also loaded once at import time and shared by all requests.
model = tf.keras.layers.TFSMLayer(model_dir_path, call_endpoint='serve')
|
|
|
|
|
|
|
|
def _first_example_scores(predicted_outputs):
    """Return the first batch row of the model's first output.

    The serving call's result is handled in three forms:
    - a dict of named outputs: the first value in insertion order is used;
    - a tuple/list of outputs: element 0 is used;
    - a single tensor: used directly.
    """
    if isinstance(predicted_outputs, dict):
        first_output = list(predicted_outputs.values())[0]
    elif isinstance(predicted_outputs, (tuple, list)):
        first_output = predicted_outputs[0]
    else:
        first_output = predicted_outputs
    # Only one example is sent per request, so batch row 0 is the answer.
    return first_output[0]


def generate_first_word(input_text, model, tokenizer, max_sequence_length):
    """Predict a single word for *input_text* using *model*.

    Args:
        input_text: Raw text to condition the prediction on.
        model: Callable that accepts a float32 tensor of shape
            ``(1, max_sequence_length)``. Here this is the module-level
            ``TFSMLayer``.
        tokenizer: Fitted tokenizer exposing ``texts_to_sequences`` and an
            ``index_word`` mapping from token id to word.
        max_sequence_length: Length the token sequence is padded or
            truncated to.

    Returns:
        The word whose id receives the highest score. Returns
        ``"<UNKNOWN_TOKEN>"`` when that id has no entry in
        ``tokenizer.index_word``; with a standard Keras tokenizer this
        presumably includes padding id 0.
    """
    token_ids = tokenizer.texts_to_sequences([input_text])

    # Pad after the tokens. pad_sequences truncates from the front by default,
    # so over-long inputs keep their last max_sequence_length tokens.
    padded = pad_sequences(token_ids, maxlen=max_sequence_length, padding='post')

    # Token ids are fed as float32. This presumably matches the exported
    # serving signature; TODO confirm against the SavedModel's input spec.
    model_input = tf.constant(padded, dtype=tf.float32)

    scores = _first_example_scores(model(model_input))

    # np.argmax yields a NumPy integer. Convert it so the lookup uses a plain
    # int key instead of relying on NumPy/int hash equality.
    predicted_token_id = int(np.argmax(scores))

    return tokenizer.index_word.get(predicted_token_id, "<UNKNOWN_TOKEN>")
|
|
|
|
|
|
|
|
class InputText(BaseModel):
    # Request body schema for POST /predict, i.e. a JSON object {"text": "..."}.
    # Documented with comments rather than a docstring so the generated
    # OpenAPI schema is left unchanged.
    text: str  # raw text handed to generate_first_word
|
|
|
|
|
|
|
|
# ASGI application instance. The route handlers in this module register on it
# via @app.get / @app.post. Serve it with an ASGI server, e.g. `uvicorn app:app`
# if this file is saved as app.py.
app = FastAPI()
|
|
|
|
|
|
|
|
@app.get("/")
async def read_root():
    """Landing endpoint for GET /: respond with a fixed welcome message."""
    greeting = "Welcome to the Generative Text Model API!"
    return dict(message=greeting)
|
|
|
|
|
|
|
|
@app.post("/predict")
async def predict_first_word(input_data: InputText):
    """Return the model's predicted word for the posted text.

    Response body: {"predicted_first_word": <word>}.

    generate_first_word performs synchronous (blocking) tokenization and model
    inference. Calling it directly inside this coroutine would block the event
    loop and stall every other in-flight request until it finished. The work
    is therefore handed to the thread pool, and the endpoint stays a coroutine
    so its interface is unchanged.
    """
    predicted_word = await run_in_threadpool(
        generate_first_word,
        input_data.text,
        model,
        tokenizer,
        max_sequence_length,
    )
    return {"predicted_first_word": predicted_word}
|
|
|
|
|
# Appears to be a leftover notice from the notebook cell that generated this
# file. It is printed every time the module is imported (e.g. once per server
# worker process).
_CONSOLIDATION_NOTICE = "API code consolidated into app.py."
print(_CONSOLIDATION_NOTICE)
|
|
|