1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
| import * as tf from '@tensorflow/tfjs'; import * as tfvis from '@tensorflow/tfjs-vis'; import { getInputs } from './data'; import { img2x, file2img } from './utils';
// URL of the pretrained MobileNet model.json loaded with tf.loadLayersModel.
// NOTE(review): hard-coded to a local dev server — confirm for deployment.
const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';

// Human-readable class names; a predicted class index is looked up here.
// NOTE(review): order must match the label encoding produced by ./data — confirm.
const BRAND_CLASSES = Object.freeze(['android', 'apple', 'windows']);

// Derived from BRAND_CLASSES so the classifier width and the name list can never disagree.
const NUM_CLASSES = BRAND_CLASSES.length;
/**
 * Load the pretrained MobileNet and cut it off at the `conv_pw_13_relu` layer,
 * so its output can be used as image features for a new classifier.
 *
 * @returns {Promise<{truncatedMobilenet: tf.LayersModel, featureShape: number[]}>}
 *   the truncated model and its per-example output shape (batch dimension removed).
 */
const loadTruncatedMobilenet = async () => {
  const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
  mobilenet.summary(); // print the full model architecture to the console

  const featureLayer = mobilenet.getLayer('conv_pw_13_relu');
  const truncatedMobilenet = tf.model({
    inputs: mobilenet.inputs,
    outputs: featureLayer.output,
  });

  return {
    truncatedMobilenet,
    // outputShape starts with the (null) batch dimension; drop it for inputShape.
    featureShape: featureLayer.outputShape.slice(1),
  };
};

/**
 * Build and compile the small trainable classifier that runs on top of the
 * truncated MobileNet features.
 *
 * @param {number[]} featureShape - per-example feature shape, batch dimension excluded.
 * @returns {tf.Sequential} compiled model with NUM_CLASSES softmax outputs.
 */
const createClassifier = (featureShape) => {
  const model = tf.sequential();
  // Flatten the feature map into a 1-D vector (no trainable parameters).
  model.add(tf.layers.flatten({ inputShape: featureShape }));
  model.add(tf.layers.dense({ units: 10, activation: 'relu' }));
  // Classification layer: one probability per class.
  model.add(tf.layers.dense({ units: NUM_CLASSES, activation: 'softmax' }));

  model.compile({
    loss: 'categoricalCrossentropy',
    optimizer: tf.train.adam(),
  });
  return model;
};

/**
 * Page entry point: show the training images, train a classifier head on
 * MobileNet features, then expose `window.predict` and `window.download`.
 */
window.onload = async () => {
  // Start downloading the model now so it overlaps with loading the images.
  const mobilenetReady = loadTruncatedMobilenet();

  const { inputs, labels } = await getInputs();
  const surface = tfvis.visor().surface({ name: '输入示例', styles: { height: 250 } });
  inputs.forEach((img) => {
    surface.drawArea.appendChild(img);
  });

  const { truncatedMobilenet, featureShape } = await mobilenetReady;
  const model = createClassifier(featureShape);

  // Pass every training image through the truncated MobileNet once; the
  // resulting features are the training inputs for the classifier head.
  const { xs, ys } = tf.tidy(() => ({
    xs: tf.concat(inputs.map((imgEl) => truncatedMobilenet.predict(img2x(imgEl)))),
    // NOTE(review): assumes `labels` is one-hot with NUM_CLASSES columns, as
    // categoricalCrossentropy requires — confirm against ./data.
    ys: tf.tensor(labels),
  }));

  try {
    await model.fit(xs, ys, {
      epochs: 20,
      callbacks: tfvis.show.fitCallbacks(
        { name: '训练效果' },
        ['loss'],
        { callbacks: ['onEpochEnd'] },
      ),
    });
  } finally {
    // The features are only needed for training; release their memory.
    tf.dispose([xs, ys]);
  }

  /**
   * Classify an image file: append the image to the page and alert the
   * predicted class name.
   *
   * @param {File} file - image file to classify (decoded by file2img).
   */
  window.predict = async (file) => {
    const img = await file2img(file);
    document.body.appendChild(img);

    // Only the argMax result survives the tidy; intermediate tensors are freed.
    const classIndexTensor = tf.tidy(() => {
      const features = truncatedMobilenet.predict(img2x(img));
      return model.predict(features).argMax(1);
    });

    let index;
    try {
      [index] = await classIndexTensor.data();
    } finally {
      classIndexTensor.dispose();
    }

    // Defer the blocking alert, presumably so the appended image can render first.
    setTimeout(() => {
      alert(`预测结果:${BRAND_CLASSES[index]}`);
    }, 0);
  };

  /** Save the trained classifier head (not the MobileNet base) as a browser download. */
  window.download = async () => {
    await model.save('downloads://model');
  };
};
|