stroke_predict / utils.py
jagadeesh72's picture
initial backend
4975e29
raw
history blame contribute delete
852 Bytes
# Utility functions for image preprocessing and stroke-type prediction
import tensorflow as tf
import numpy as np
from PIL import Image
# Labels returned by predict_image. The position in this list is what maps the
# argmax of the model output to a label, so the order matters.
# NOTE(review): presumably this matches the class order used when the model
# was trained -- confirm against the training pipeline.
CLASS_NAMES = ['hemorrhagic_stroke', 'ischemic_stroke', 'no_stroke']
def preprocess_image(image_file, target_size=(224, 224)):
    """
    Preprocess an uploaded image for prediction.

    Args:
        image_file: Anything ``PIL.Image.open`` accepts (a filename, a path
            object, or a binary file-like object).
        target_size: Size handed straight to ``PIL.Image.Image.resize``,
            which reads it as ``(width, height)``.

    Returns:
        The resized RGB image converted with
        ``tf.keras.utils.img_to_array`` and given a new leading axis of
        length 1 via ``tf.expand_dims``. Pixel values are not rescaled here.
    """
    # convert() returns an independent, fully loaded copy, so the source
    # image (and any file handle Pillow opened for it) can be released as
    # soon as the conversion is done instead of lingering until garbage
    # collection.
    with Image.open(image_file) as source:
        img = source.convert("RGB")

    img = img.resize(target_size)
    img_array = tf.keras.utils.img_to_array(img)

    # Add a leading axis so the single image is presented as a batch of one.
    return tf.expand_dims(img_array, 0)
def predict_image(model, image_file):
    """
    Classify an image and report the winning class with its score.

    Args:
        model: Object exposing ``predict(batch)``. The first entry of its
            result is treated as per-class scores indexed like
            ``CLASS_NAMES``.
        image_file: Image source, forwarded unchanged to
            ``preprocess_image``.

    Returns:
        dict with two keys: "prediction", the ``CLASS_NAMES`` entry at the
        arg-max position, and "confidence", the maximum score multiplied by
        100 and rounded to two decimals.
    """
    batch = preprocess_image(image_file)

    # Only one image is submitted, so only the first row of the output is used.
    scores = model.predict(batch)[0]

    best_position = np.argmax(scores)
    top_score_pct = float(np.max(scores) * 100)

    result = {}
    result["prediction"] = CLASS_NAMES[best_position]
    result["confidence"] = round(top_score_pct, 2)
    return result