使用gradio部署微调后的模型
import gradio as gr
import numpy as np
import tensorflow as tf
# 加载预训练的模型和图片分类标签
model = tf.keras.applications.MobileNetV2()
with open("imagenet_class_index.json") as f:
CLASS_INDEX = json.load(f)
# 定义一个函数来对输入的图片进行预处理和预测
def predict(image):
if image.mode != "RGB":
image = image.convert("RGB")
image = np.asarray(image.resize((224, 224)))[None, ...]
image = image / 255.0
prediction = model.predict(image)
return np.argmax(prediction)
# 定义一个函数来将预测的类别索引转换为标签
def get_class_name(prediction):
return CLASS_INDEX[str(prediction)][0]
# 创建一个UI组件,用于选择和上传图片
image_input = gr.Image(label="Image")
# 创建一个UI组件,用于显示预测的类别
class_output = gr.Textbox(label="Class")
# 定义一个自定义的交互功能
def custom_function(image):
prediction = predict(image)
class_name = get_class_name(prediction)
return class_name
# 将UI组件和自定义的交互功能组合在一起
gr.Interface(fn=custom_function, inputs=image_input, outputs=class_output, live=True).launch()
这段代码使用了gradio库来创建一个用户界面,允许用户上传图片,并实时显示图片的分类结果。它展示了如何加载预训练的模型,如何对输入图片进行预处理,以及如何使用gradio的API来创建一个简单的用户界面。
评论已关闭