


















































































































import { getChatGenerate, getListWorker } from '@/utils/multiRound';
import { Vue, Component, Prop } from 'vue-property-decorator';
import ModelDetail from '../ModelDetail.vue';
import { cloneDeep, split } from 'lodash';
import axios from 'axios';

/**
 * Configuration of a single arena model, as received through the `models` prop.
 *
 * Field names are kept verbatim (including the `frepetition_penalty` spelling)
 * because the component reads them by these exact keys — presumably they mirror
 * the backend payload; confirm before renaming anything.
 */
interface IModel {
    /** Identifier used to look up the model's workers and key its worker id. */
    model_id: number;

    // --- sampling parameters, copied into the chat request's `parameter` object ---
    temperature: number;
    top_k: number;
    top_p: number;
    /** Sent as `repetitionPenalty`. */
    frepetition_penalty: number;
    /** Sent as `maxOutputLength`. NOTE(review): typed as string — confirm the API accepts that. */
    max_tokens: string;

    // --- conversation templating ---
    /** When non-empty, shown as an opening BOT message of a fresh conversation. */
    prompt: string;
    /** Text placed before each USER message; a `{round}` placeholder is substituted with the round number. */
    prefix?: string;
    /** Text placed after each USER message; a `{round}` placeholder is substituted with the round number. */
    suffix?: string;

    /** NOTE(review): not read by the arena component itself — confirm who uses it. */
    defaultDialogueList?: any[];
}

@Component({
    components: {
        ModelDetail,
    },
})
/**
 * Side-by-side "arena": two distinct models drawn at random from `models` answer
 * the same multi-round conversation and the user votes for the better one.
 *
 * Every per-slot array (`dialogueList`, `messageList`, `beBeingSentList`,
 * `cancelTokenSource`, `randomModel.modelList`) uses index 0 for model A and
 * index 1 for model B.
 *
 * Emits `saveCache` after a vote and after every generation finishes;
 * `transSaveData()` returns the state meant to be persisted (presumably by the
 * parent in response to `saveCache` — confirm).
 */
export default class MultiRoundModelArena extends Vue {
    /** Pool of candidate models; two distinct entries are drawn at random. */
    @Prop() models!: IModel[];
    /** Previously saved state (the shape returned by `transSaveData()`); restores the arena when set. */
    @Prop() dialogueHistoryData?: any;

    /** Conversation per slot; items are `{ role, content }` with role USER / BOT / ERROR. */
    dialogueList = [[], []];
    /** One snapshot per vote: both conversations at voting time followed by the chosen verdict. */
    originModelDialogList = [];
    /** Text of the input that is sent to both models at once. */
    allMessage = '';
    /** Text of each slot's own input. */
    messageList = ['', ''];
    /** Per-slot "generation in progress" flags. */
    beBeingSentList = [false, false];
    /** model_id -> id of the chosen RUNNING worker, or null when none is running. */
    workerId = {};
    /** Indices into `models` of the two drawn models, plus shallow copies of those models. */
    randomModel = {
        index: [],
        modelList: [],
    };
    /** Per-slot axios cancel source of the in-flight chat request. */
    cancelTokenSource = [];

    /** Set to true after an "A better"/"B better" vote and false after "both good"/"both bad". */
    isMeanwhileSend = true;
    /** `type` of the vote cast for the current round, or null when none has been cast. */
    modelEvaluate = null;

    modelSelectList = Object.freeze([
        {
            label: '👈A更好',
            type: 1,
        },
        {
            label: '👉B更好',
            type: 2,
        },
        {
            label: '🤝都好',
            type: 3,
        },
        {
            label: '👎都不好',
            type: 4,
        },
    ]);

    $refs: {
        dialogueListWrapper: HTMLDivElement[];
    };

    /** True once every drawn model has a RUNNING worker assigned. */
    get allModelStatus() {
        const ids = Object.values(this.workerId);
        return ids.length > 0 && ids.every(Boolean);
    }

    /** Whether either slot currently has a generation in progress. */
    hasBeBeingSent() {
        return this.beBeingSentList.some(Boolean);
    }

    /**
     * Whether slot `index` currently ends with a non-USER message that has content.
     * Returns false for an empty conversation.
     */
    generateStatus(index) {
        const dialogue = this.dialogueList[index] || [];
        const lastData = dialogue[dialogue.length - 1];
        if (!lastData || lastData.role === 'USER') {
            return false;
        }
        return !!lastData.content;
    }

    /** Snapshot of the state to persist; the vote is stored as its full `modelSelectList` entry. */
    transSaveData() {
        return {
            dialogueList: this.dialogueList,
            randomModel: this.randomModel,
            modelEvaluate: this.modelSelectList.find((m) => m.type === this.modelEvaluate),
            originModelDialogList: this.originModelDialogList,
        };
    }

    /**
     * Record the user's vote for the current round.
     *
     * @param type 1 = A better, 2 = B better, 3 = both good, 4 = both bad.
     * @returns false when the vote is refused (already voted, or a model is still generating).
     */
    handleEvaluate(type) {
        if (this.modelEvaluate || this.hasBeBeingSent()) {
            return false;
        }
        this.modelEvaluate = type;
        const selectEvaluate = this.modelSelectList.find((m) => m.type === type);
        // Archive both conversations as they were when the vote was cast, plus the verdict.
        this.originModelDialogList.push([...cloneDeep(this.dialogueList), { ...selectEvaluate }]);
        switch (type) {
            case 1:
                // A won: B continues from A's conversation. $set keeps the array reactive.
                this.$set(this.dialogueList, 1, cloneDeep(this.dialogueList[0]));
                this.isMeanwhileSend = true;
                this.$message.success('模型A +1');
                break;
            case 2:
                // B won: A continues from B's conversation.
                this.$set(this.dialogueList, 0, cloneDeep(this.dialogueList[1]));
                this.isMeanwhileSend = true;
                this.$message.success('模型B +1');
                break;
            case 3:
                this.$message.success('模型A +1   模型B +1');
                this.isMeanwhileSend = false;
                break;
            case 4:
                this.isMeanwhileSend = false;
                this.$message.success('模型A +0   模型B +0');
                break;
            default:
                break;
        }
        this.$emit('saveCache');
    }

    /**
     * Truncate slot `j`'s conversation after message `index` and ask the model to
     * continue from there.
     */
    async handleCarryOut(j, index) {
        try {
            this.beBeingSentList[j] = true;
            this.$set(this.dialogueList, j, this.dialogueList[j].slice(0, index + 1));
            this.$forceUpdate();
            const model = this.randomModel.modelList[j];
            await this.chat(this.buildChatParams(model, this.dialogueList[j]), j);
        } catch (error) {
            console.log(error);
        } finally {
            // chat() clears the flag itself; this also covers failures before chat() runs.
            this.beBeingSentList[j] = false;
            this.$forceUpdate();
        }
    }

    /**
     * Build the `getChatGenerate` payload for `model` from its conversation.
     * USER messages are wrapped in the model's prefix/suffix templates, with
     * `{round}` set to floor(messageIndex / 2); ERROR messages and messages with
     * empty content are dropped.
     */
    buildChatParams(model, dialogue) {
        const prompts = dialogue
            .map((d, i) => {
                if (d.role !== 'USER') {
                    return { role: d.role, content: d.content };
                }
                const round = `${Math.floor(i / 2)}`;
                const prefix = this.fillRound(model.prefix, round);
                const suffix = this.fillRound(model.suffix, round);
                return { role: d.role, content: `${prefix}${d.content}${suffix}` };
            })
            .filter((d) => d.role !== 'ERROR' && d.content);
        return {
            workerId: this.workerId[model.model_id],
            parameter: {
                topP: model.top_p,
                topK: model.top_k,
                temperature: model.temperature,
                maxOutputLength: model.max_tokens,
                repetitionPenalty: model.frepetition_penalty,
            },
            prompts,
            promptPrefix: model.prefix,
            promptSuffix: model.suffix,
        };
    }

    /** Replace every `{round}` in an optional template; a missing template yields ''. */
    fillRound(template, round) {
        return (template || '').replace(/\{round\}/g, round);
    }

    /** Initial conversation for a model: its `prompt` as an opening BOT message, or empty. */
    createInitialDialogue(model) {
        return model?.prompt ? [{ content: model.prompt, role: 'BOT' }] : [];
    }

    /** Reset slot `index` to its initial conversation and clear the current vote. */
    relaunch(index) {
        this.$set(this.dialogueList, index, this.createInitialDialogue(this.randomModel.modelList[index]));
        this.$forceUpdate();
        this.modelEvaluate = null;
    }

    /**
     * Whether a keydown should submit the input: Enter only, and not while an IME
     * composition is being confirmed. Backspace (keyCode 8) is deliberately excluded.
     */
    isSubmitKey(e) {
        if (e.isComposing) {
            return false;
        }
        return e.key === 'Enter' || e.keyCode === 13;
    }

    /** Keydown handler of slot `j`'s input: Enter sends that slot's message. */
    handleEnter(e, j) {
        if (!this.isSubmitKey(e)) {
            return;
        }
        if (!this.messageList[j]) {
            return false;
        }
        e.preventDefault();
        this.handleSendMessage(j);
    }

    /** Keydown handler of the shared input: Enter sends the message to both models. */
    handleAllEnter(e) {
        if (!this.isSubmitKey(e)) {
            return;
        }
        if (!this.allMessage) {
            return false;
        }
        e.preventDefault();
        this.handleSendAllMessage();
    }

    /** Append slot `j`'s input as a USER message and generate that model's reply. */
    async handleSendMessage(j) {
        // Guards run outside try/finally so a refused send cannot clear the busy flag
        // of a generation that is still running.
        if (!this.messageList[j].trim()) {
            this.$message.warning('请输入内容！');
            return false;
        }
        if (this.beBeingSentList[j]) {
            this.$message.warning('请稍后');
            return false;
        }
        try {
            this.beBeingSentList[j] = true;
            this.dialogueList[j].push({
                content: this.messageList[j],
                role: 'USER',
            });
            this.$set(this.messageList, j, '');
            this.$forceUpdate();
            this.scrollBottom();
            // Continue from the USER message just appended (the last one).
            await this.handleCarryOut(j, this.dialogueList[j].length - 1);
        } catch (error) {
            console.log(error);
        } finally {
            this.beBeingSentList[j] = false;
            this.$forceUpdate();
        }
    }

    /** Append the shared input as a USER message to both slots and query both models in parallel. */
    async handleSendAllMessage() {
        if (!this.allMessage.trim()) {
            this.$message.warning('请输入内容！');
            return false;
        }
        if (this.hasBeBeingSent()) {
            this.$message.warning('请稍后');
            return false;
        }
        try {
            const content = this.allMessage;
            this.dialogueList.forEach((dialogue) => {
                dialogue.push({ content, role: 'USER' });
            });
            this.scrollBottom();
            const paramsList = this.randomModel.modelList.map((model, slot) =>
                this.buildChatParams(model, this.dialogueList[slot])
            );
            this.allMessage = '';
            await Promise.all(paramsList.map((params, slot) => this.chat(params, slot)));
        } catch (error) {
            console.log('获取失败', error);
        }
    }

    /**
     * Parse the newest `data:` event out of the accumulated streamed response text
     * (it looks like a server-sent-events body). The last segment may still be
     * incomplete, so the one before it is tried next. Returns null when neither parses.
     */
    parseStreamChunk(responseText) {
        const segments = responseText.split('data:');
        try {
            return JSON.parse(segments.at(-1).replace(/\n\n$/, ''));
        } catch (e1) {
            try {
                return JSON.parse(segments.at(-2).replace(/\n\n$/, ''));
            } catch (e2) {
                console.log('JSON解析失败', e1, e2, segments);
                return null;
            }
        }
    }

    /**
     * Stream a reply for slot `index` into a new BOT message, replacing its content
     * on every progress event. Cancellation (see `handleStopGenerate`) is silent.
     */
    async chat(params, index) {
        // A new answer invalidates the vote for the current round.
        this.modelEvaluate = null;
        try {
            this.beBeingSentList[index] = true;
            this.cancelTokenSource[index] = axios.CancelToken.source();
            this.dialogueList[index].push({
                content: '',
                role: 'BOT',
            });
            await getChatGenerate(
                params,
                (progressEvent) => {
                    const data = this.parseStreamChunk(progressEvent.currentTarget.responseText);
                    if (data) {
                        const dialogue = this.dialogueList[index];
                        dialogue.splice(dialogue.length - 1, 1, {
                            content: data.result,
                            role: 'BOT',
                        });
                        this.$forceUpdate();
                        this.scrollBottom();
                    }
                },
                { cancelToken: this.cancelTokenSource[index].token }
            );
        } catch (error) {
            if (axios.isCancel(error)) {
                return false;
            }
            console.log(error);
            // Show a readable text rather than handing the raw error object to the toast.
            this.$message.error(error?.message || String(error));
        } finally {
            this.$emit('saveCache');
            this.beBeingSentList[index] = false;
            this.$forceUpdate();
        }
    }

    /** Cancel slot `index`'s in-flight request (no-op if none was started) and mark it idle. */
    handleStopGenerate(index) {
        this.cancelTokenSource[index]?.cancel();
        this.beBeingSentList[index] = false;
        this.$forceUpdate();
    }

    /** Scroll every conversation panel to the bottom after the pending DOM update. */
    scrollBottom() {
        (this.$refs.dialogueListWrapper || []).forEach((item) => {
            setTimeout(() => {
                item.scrollTop = item.scrollHeight;
            });
        });
    }

    /**
     * Draw two distinct models uniformly at random from `models` into `randomModel`.
     * Warns and leaves `randomModel` untouched when fewer than two models exist
     * (the former retry loop never terminated in that case).
     */
    getRandomModel() {
        const total = this.models?.length ?? 0;
        if (total < 2) {
            this.$message.warning('可用模型少于两个，无法进行对比。');
            return;
        }
        const first = Math.floor(Math.random() * total);
        // Pick among the remaining total - 1 indices, skipping over `first`.
        let second = Math.floor(Math.random() * (total - 1));
        if (second >= first) {
            second += 1;
        }
        this.randomModel.index = [first, second];
        this.randomModel.modelList = this.randomModel.index.map((i) => ({ ...this.models[i] }));
    }

    /**
     * For each drawn model, pick a random RUNNING worker into `workerId`
     * (null when none is running) and warn about every model without one.
     */
    async handleGetWorkerId() {
        try {
            const modelList = this.randomModel.modelList;
            const res = await Promise.all(modelList.map((model) => getListWorker(model.model_id)));
            res.forEach((item, index) => {
                const runningWorkerList = (item?.data || []).filter((r) => r.status === 'RUNNING');
                const picked = runningWorkerList[Math.floor(Math.random() * runningWorkerList.length)];
                // Always record the model so allModelStatus sees it, even when nothing is running.
                this.$set(this.workerId, modelList[index].model_id, picked?.workerId || null);
            });
            modelList.forEach(({ model_id }) => {
                if (!this.workerId[model_id]) {
                    this.$message.warning(`模型Id为${model_id}的模型未启动成功。`);
                }
            });
        } catch (error) {
            console.log(error, '获取失败');
        }
    }

    /** Give each drawn model its initial conversation. */
    initDialogList() {
        this.randomModel.modelList.forEach((model, index) => {
            this.$set(this.dialogueList, index, this.createInitialDialogue(model));
        });
    }

    /** Restore from `dialogueHistoryData` when provided; otherwise start a fresh match. */
    init() {
        const history = this.dialogueHistoryData;
        if (history) {
            this.dialogueList = history.dialogueList;
            this.randomModel = history.randomModel;
            this.modelEvaluate = history.modelEvaluate?.type ?? null;
            this.originModelDialogList = history.originModelDialogList || [];
            if ([3, 4].includes(this.modelEvaluate)) {
                this.isMeanwhileSend = false;
            }
            this.handleGetWorkerId();
            return false;
        }
        this.getRandomModel();
        this.handleGetWorkerId();
        this.initDialogList();
    }

    mounted() {
        this.init();
    }
}
