欢迎您访问 最编程 本站为您分享编程语言代码,编程技术文章!
您现在的位置是: 首页

文心雕龙 ERNIE-Bot 4.0 模型流和非流 API 调用(SpringBoot+OkHttp3+SSE+WebSocket)

最编程 2024-04-17 17:01:24
...

博客地址:https://www.yuque.com/autunomy/emwi09

摘要
本文的主要内容是针对于文心一言的ERNIE-Bot 4.0模型的API调用,使用到的技术有JDK1.8 , OkHttp3 , WebSocket , SSE , SpringBoot。API有流式和非流式两种,这里都进行了详细的代码编写,并且针对于流式API给出了两种协议的写法并且是前后端交互式的写法,看完本文后几乎所有的大模型的流式、非流式API都可以轻松调用了。

准备工作

第一步申请资格:千帆服务体验申请
第二步打开控制台:百度智能云千帆大模型平台

点击模型广场并选择自己需要模型,打开API文档
image

打开文档之后会有两个部分,一个是API在线调试,一个是AccessToken的获取,那接下来我们先获取AccessToken
image

接下来就可以看到获取API Key的方式
image

点击创建应用来创建API Key,应用的名字可以随便写
image

最后就可以获取一个API Key 和 Secret Key,这两个就可以帮助我们获取AccessToken,接下来返回之前的AccessToken的获取页面,里面就会有详细的请求入参和响应,可以根据官方的文档获取AccessToken
AccessToken的作用就是一个通行证,每次我们与文心一言对话的时候都需要带着这个通行证才能访问,接下来我用java代码使用springboot+okHttp3的方式来演示AccessToken的获取方式

使用java代码获取AccessToken

第一步肯定是引入依赖

<dependency>
    <groupId>com.squareup.okhttp3</groupId>
    <artifactId>okhttp</artifactId>
    <version>4.9.3</version>
</dependency>
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>fastjson</artifactId>
    <version>1.2.79</version>
</dependency>

第二步是配置okHttp

import okhttp3.ConnectionPool;
import okhttp3.OkHttpClient;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import javax.net.ssl.*;
import java.security.*;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.concurrent.TimeUnit;


@Configuration
public class OkHttpConfiguration {

    @Value("${ok.http.connect-timeout}")
    private Integer connectTimeout;

    @Value("${ok.http.read-timeout}")
    private Integer readTimeout;

    @Value("${ok.http.write-timeout}")
    private Integer writeTimeout;

    @Value("${ok.http.max-idle-connections}")
    private Integer maxIdleConnections;

    @Value("${ok.http.keep-alive-duration}")
    private Long keepAliveDuration;

    @Bean
    public OkHttpClient okHttpClient() {
        return new OkHttpClient.Builder()
                .sslSocketFactory(sslSocketFactory(), x509TrustManager())
                // 是否开启缓存
                .retryOnConnectionFailure(false)
                .connectionPool(pool())
                .connectTimeout(connectTimeout, TimeUnit.SECONDS)
                .readTimeout(readTimeout, TimeUnit.SECONDS)
                .writeTimeout(writeTimeout,TimeUnit.SECONDS)
                .hostnameVerifier((hostname, session) -> true)
                // 设置代理
//            	.proxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 8888)))
                // 拦截器
//                .addInterceptor()
                .build();
    }

    @Bean
    public X509TrustManager x509TrustManager() {
        return new X509TrustManager() {
            @Override
            public void checkClientTrusted(X509Certificate[] chain, String authType)
                    throws CertificateException {
            }
            @Override
            public void checkServerTrusted(X509Certificate[] chain, String authType)
                    throws CertificateException {
            }
            @Override
            public X509Certificate[] getAcceptedIssuers() {
                return new X509Certificate[0];
            }
        };
    }

    @Bean
    public SSLSocketFactory sslSocketFactory() {
        try {
            // 信任任何链接
            SSLContext sslContext = SSLContext.getInstance("TLS");
            sslContext.init(null, new TrustManager[]{x509TrustManager()}, new SecureRandom());
            return sslContext.getSocketFactory();
        } catch (NoSuchAlgorithmException | KeyManagementException e) {
            e.printStackTrace();
        }
        return null;
    }

    @Bean
    public ConnectionPool pool() {
        return new ConnectionPool(maxIdleConnections, keepAliveDuration, TimeUnit.SECONDS);
    }
}

同时需要在application.properties配置文件中进行配置

# okhttp3配置
ok.http.connect-timeout=30
ok.http.read-timeout=30
ok.http.write-timeout=30
# 连接池中整体的空闲连接的最大数量
ok.http.max-idle-connections=200
# 连接空闲时间最多为 300 秒
ok.http.keep-alive-duration=300

第三步是配置一些全局变量

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import okhttp3.MediaType;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;

import java.io.IOException;
import java.util.Date;
/**
 * @author hty
 * @date 2023-11-18 9:42
 * @email 1156388927@qq.com
 * @description
 */

@Configuration
@Slf4j
public class WenXinConfig {

    @Value("${wenxin.apiKey}")
    public String API_KEY;
    @Value("${wenxin.secretKey}")
    public String SECRET_KEY;
    @Value("${wenxin.accessTokenUrl}")
    public String ACCESS_TOKEN_URL;
    @Value("${wenxin.ERNIE-Bot4.0URL}")
    public String ERNIE_Bot_4_0_URL;

    //过期时间为30天
    public String ACCESS_TOKEN = null;
    public String REFRESH_TOKEN = null;

    public Date CREATE_TIME = null;//accessToken创建时间

    public Date EXPIRATION_TIME = null;//accessToken到期时间

    /**
     * 获取accessToken
     * @return true表示成功 false表示失败
     */
    public synchronized String flushAccessToken(){
        //判断当前AccessToken是否为空且判断是否过期
        if(ACCESS_TOKEN != null && EXPIRATION_TIME.getTime() > CREATE_TIME.getTime()) return ACCESS_TOKEN;

        //构造请求体 包含请求参数和请求头等信息
        RequestBody body = RequestBody.create(MediaType.parse("application/json"),"");
        Request request = new Request.Builder()
                .url(ACCESS_TOKEN_URL+"?client_id="+API_KEY+"&client_secret="+SECRET_KEY+"&grant_type=client_credentials")
                .method("POST", body)
                .addHeader("Content-Type", "application/json")
                .addHeader("Accept", "application/json")
                .build();
        OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build();
        String response = null;
        try {
            //请求
            response = HTTP_CLIENT.newCall(request).execute().body().string();
        } catch (IOException e) {
            log.error("ACCESS_TOKEN获取失败");
            return null;
        }

        //刷新令牌以及更新令牌创建时间和过期时间
        JSONObject jsonObject = JSON.parseObject(response);
        ACCESS_TOKEN = jsonObject.getString("access_token");
        REFRESH_TOKEN = jsonObject.getString("refresh_token");
        CREATE_TIME = new Date();
        EXPIRATION_TIME = new Date(Long.parseLong(jsonObject.getString("expires_in")) + CREATE_TIME.getTime());

        return ACCESS_TOKEN;
    }
}

可以看到代码中有@Value注解,表示需要注入属性,所以继续在application.properties配置文件中进行配置

wenxin.apiKey=你的apiKey
wenxin.secretKey=你的secretKey
#获取AccessToken的url地址
wenxin.accessTokenUrl=https://aip.baidubce.com/oauth/2.0/token
#文心ERNIE-Bot4.0模型访问地址
wenxin.ERNIE-Bot4.0URL=https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro

非流式回答

至此配置完成,接下来就是编写代码来进行访问,我使用的是springboot框架,所以就在controller中编写一个方法来实现。

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.hty.config.WenXinConfig;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RestController;

import javax.annotation.Resource;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@RestController
@Slf4j
public class TestController {

    @Resource
    private WenXinConfig wenXinConfig;

    //历史对话,需要按照user,assistant
    List<Map<String,String>> messages = new ArrayList<>();

    /**
     * 非流式问答
     * @param question 用户的问题
     * @return
     * @throws IOException
     */
    @PostMapping("/test1")
    public String test1(String question) throws IOException {
        if(question == null || question.equals("")){
            return "请输入问题";
        }
        String responseJson = null;
        //先获取令牌然后才能访问api
        if (wenXinConfig.flushAccessToken() != null) {
            HashMap<String, String> user = new HashMap<>();
            user.put("role","user");
            user.put("content",question);
            messages.add(user);
            String requestJson = constructRequestJson(1,0.95,0.8,1.0,false,messages);
            RequestBody body = RequestBody.create(MediaType.parse("application/json"), requestJson);
            Request request = new Request.Builder()
                    .url(wenXinConfig.ERNIE_Bot_4_0_URL + "?access_token=" + wenXinConfig.flushAccessToken())
                    .method("POST", body)
                    .addHeader("Content-Type", "application/json")
                    .build();
            OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build();
            try {
                responseJson = HTTP_CLIENT.newCall(request).execute().body().string();
                //将回复的内容转为一个JSONObject
                JSONObject responseObject = JSON.parseObject(responseJson);
                //将回复的内容添加到消息中
                HashMap<String, String> assistant = new HashMap<>();
                assistant.put("role","assistant");
                assistant.put("content",responseObject.getString("result"));
                messages.add(assistant);
            } catch (IOException e) {
                log.error("网络有问题");
                return "网络有问题,请稍后重试";
            }
        }
        return responseJson;
    }

    /**
     * 构造请求的请求参数
     * @param userId
     * @param temperature
     * @param topP
     * @param penaltyScore
     * @param messages
     * @return
     */
    public String constructRequestJson(Integer userId,
                                       Double temperature,
                                       Double topP,
                                       Double penaltyScore,
                                       boolean stream,
                                       List<Map<String, String>> messages) {
        Map<String,Object> request = new HashMap<>();
        request.put("user_id",userId.toString());
        request.put("temperature",temperature);
        request.put("top_p",topP);
        request.put("penalty_score",penaltyScore);
        request.put("stream",stream);
        request.put("messages",messages);
        System.out.println(JSON.toJSONString(request));
        return JSON.toJSONString(request);
    }
}

使用PostMan调用接口即可获取到回复

流式回答-输出到控制台

流式回答就是在原来非流式回答的基础上将stream这个参数从false改为true即可,但是还要注意的一个事情就是响应中的内容是会有一点变化的,会在原来响应的内容的前面加上data这个字符,需要在转jsonObject的时候注意

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.hty.config.WenXinConfig;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RestController;

import javax.annotation.Resource;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@RestController
@Slf4j
public class TestController {

    @Resource
    private WenXinConfig wenXinConfig;

    //历史对话,需要按照user,assistant
    List<Map<String,String>> messages = new ArrayList<>();

    /**
     * 流式回答
     * @return
     */
    @PostMapping("/test2")
    public String test2(String question){
        OkHttpClient client = new OkHttpClient();

        HashMap<String, String> user = new HashMap<>();
        user.put("role","user");
        user.put("content",question);
        messages.add(user);
        String requestJson = constructRequestJson(1,0.95,0.8,1.0,true,messages);
        RequestBody body = RequestBody.create(MediaType.parse("application/json"), requestJson);
        Request request = new Request.Builder()
                .url(wenXinConfig.ERNIE_Bot_4_0_URL + "?access_token=" + wenXinConfig.flushAccessToken())
                .method("POST", body)
                .addHeader("Content-Type", "application/json")
                .build();

        StringBuilder answer = new StringBuilder();
        // 发起异步请求
        try {
            Response response = client.newCall(request).execute();
            // 检查响应是否成功
            if (response.isSuccessful()) {
                // 获取响应流
                try (ResponseBody responseBody = response.body()) {
                    if (responseBody != null) {
                        InputStream inputStream = responseBody.byteStream();
                        // 以流的方式处理响应内容,输出到控制台
                        byte[] buffer = new byte[1024];
                        int bytesRead;
                        while ((bytesRead = inputStream.read(buffer)) != -1) {
                            // 在控制台输出每个数据块
                            System.out.write(buffer, 0, bytesRead);
                            //将结果汇总起来
                            answer.append(new String(buffer, 0, bytesRead));
                        }
                    }
                }
            } else {
                System.out.println("Unexpected code " + response);
            }

        } catch (IOException e) {
            log.error("流式请求出错");
            throw new RuntimeException(e);
        }
        //将回复的内容添加到消息中
        HashMap<String, String> assistant = new HashMap<>();
        assistant.put("role","assistant");
        assistant.put("content","");
        //取出我们需要的内容,也就是result部分
        String[] answerArray = answer.toString().split("data: ");
        for (int i=1;i<answerArray.length;++i) {
            answerArray[i] = answerArray[i].substring(0,answerArray[i].length() - 2);
            assistant.put("content",assistant.get("content") + JSON.parseObject(answerArray[i]).get("result"));
        }
        messages.add(assistant);
        return assistant.get("content");
    }

    /**
     * 构造请求的请求参数
     * @param userId
     * @param temperature
     * @param topP
     * @param penaltyScore
     * @param messages
     * @return
     */
    public String constructRequestJson(Integer userId,
                                       Double temperature,
                                       Double topP,
                                       Double penaltyScore,
                                       boolean stream,
                                       List<Map<String, String>> messages) {
        Map<String,Object> request = new HashMap<>();
        request.put("user_id",userId.toString());
        request.put("temperature",temperature);
        request.put("top_p",topP);
        request.put("penalty_score",penaltyScore);
        request.put("stream",stream);
        request.put("messages",messages);
        System.out.println(JSON.toJSONString(request));
        return JSON.toJSONString(request);
    }
}

流式回答-输出到前端

这个地方分两种方式,一种是使用SSE(Server-Sent Events),一种是使用websocket

SSE方式

引入依赖

<dependency>
    <groupId>com.squareup.okhttp3</groupId>
    <artifactId>okhttp</artifactId>
    <version>4.9.3</version>
</dependency>
<dependency>
    <groupId>com.squareup.okhttp3</groupId>
    <artifactId>okhttp-sse</artifactId>
    <version>4.9.3</version>
</dependency>

配置跨域

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.CorsRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

@Configuration
public class CorsConfig {
    @Bean
    public WebMvcConfigurer corsConfigurer() {
        return new WebMvcConfigurer() {
            @Override
            public void addCorsMappings(CorsRegistry registry) {
                registry.addMapping("/**")
                        .allowedOrigins("*");//允许域名访问,如果*,代表所有域名
            }
        };
    }
}

前端代码

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>SSE Example</title>
</head>
<body>
    <div id="sse-data" v-html="message"></div>

    <script src="https://code.jquery.com/jquery-3.6.4.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/vue@2"></script>

    <script>
        const app = new Vue({
            el: '#sse-data',
            data: {
                message: '',
            },
            created() {
                this.createSseConnect();
            },
            methods: {
				// 建立连接
				createSseConnect(){
					let clientId = 1;
					if(window.EventSource){
						const eventSource = new EventSource('http://localhost:8080/sse/connect?clientId=' + clientId);
						console.log(eventSource)
						
						eventSource.onmessage = (event) =>{
							console.log("onmessage: "+event.data)
							this.message = this.message + event.data
						};
						
						eventSource.onopen = (event) =>{
							console.log("onopen:"+event)
						};
						
						eventSource.onerror = (event) =>{
							console.log("onerror:"+event)
						};
						
						eventSource.close = (event) =>{
							console.log("close :"+event)
						};
				 
					}else{
						console.log("你的浏览器不支持SSE~")
					}
					console.log(" 测试 打印")
				},
            },
        });
    </script>
</body>
</html

后端代码部分的写法和WebSocket非常的像,都是向创建连接,之后再发送消息

/**
 * 这个用来保存用户与服务器之间的连接信息
 */
private static final Map<Long, SseEmitter> sseEmitterMap = new ConcurrentHashMap<>();

//创建连接
@GetMapping(value = "/sse/connect", produces="text/event-stream;charset=UTF-8")
public SseEmitter sseConnect(Long clientId){
    //已经连接过,直接返回连接
    if (sseEmitterMap.containsKey(clientId)) {
        return sseEmitterMap.get(clientId);
    }

    try {
        // 设置超时时间,0表示不过期。默认30秒
        SseEmitter sseEmitter = new SseEmitter(30 * 1000L);

        // 注册回调
        sseEmitter.onCompletion(completionCallBack(clientId));
        sseEmitter.onTimeout(timeoutCallBack(clientId));
        sseEmitterMap.put(clientId, sseEmitter);
        log.info("创建sse连接完成,当前客户端:{}", clientId);
        return sseEmitter;
    } catch (Exception e) {
        log.info("创建sse连接异常,当前客户端:{}", clientId);
    }
    return null;
}

//发送消息,这里采用异步的方式来进行发送
/**
 * 用来异步发送消息
 */
private final ExecutorService executorService = Executors.newCachedThreadPool();
/**
 * SSE方式向前端发送消息
 * @param clientId
 * @param question
 */
@PostMapping(value = "/sse/chat")
public void streamOutputToPage(Long clientId,String question){
//异步发送消息
    executorService.execute(() -> {
        SseEmitter sseEmitter = sseEmitterMap.get(clientId);
        if(sseEmitter == null){
            sseEmitter = sseChat(clientId);
        }

        OkHttpClient client = new OkHttpClient();

        String requestJson = wenxinUtils.constructRequestJson(1,0.95,1.0,true,messages);
        RequestBody body = RequestBody.create(MediaType.parse("application/json"), requestJson);
        Request request = new Request.Builder()
                .url(wenXinConfig.ERNIE_Bot_4_0_URL + "?access_token=" + wenXinConfig.flushAccessToken())
                .method("POST", body)
                .addHeader("Content-Type", "application/json")
                .build();

        //将回复的内容添加到消息中
        HashMap<String, String> assistant = new HashMap<>();
        assistant.put("role","assistant");
        assistant.put("content","");

        // 发起异步请求
        try {
            Response response = client.newCall(request).execute();
            // 检查响应是否成功
            if (response.isSuccessful()) {
                // 获取响应流
                try (ResponseBody responseBody = response.body()) {
                    if (responseBody != null) {
                        InputStream inputStream = responseBody.byteStream();
                        // 以流的方式处理响应内容,输出到控制台 这里的数组大小一定不能太小,否则会导致接收中文字符的时候产生乱码
                        byte[] buffer = new byte[2048];
                        int bytesRead;
                        StringBuilder temp = new StringBuilder();
                        while ((bytesRead = inputStream.read(buffer)) != -1) {
                            //TODO:这部分不需要使用\n\n来进行分割了,只需要将缓冲区开的尽可能大即可
                            //消息分割采用标识符 \n\n 来分割 并且需要从后向前找\n\n,因为每条消息分割点的最后才是\n\n
                            temp.append(new String(buffer, 0, bytesRead));
                            String result = "";
                            if(temp.lastIndexOf("\n\n") != -1){
                                //从6开始 因为有 data: 这个前缀 占了6个字符所以 0 + 6 = 6
                                JSONObject jsonObject = JSON.parseObject(temp.substring(6, temp.lastIndexOf("\n\n")));
                                temp = new StringBuilder(temp.substring(temp.lastIndexOf("\n\n") + 2));
                                if(jsonObject != null && jsonObject.getString("result") != null){
                                    result = jsonObject.getString("result");
                                }
                            }
                            if(!result.equals("")){
                                //SSE协议默认是以两个\n换行符为结束标志 需要在进行一次转义才能成功发送给前端
                                sseEmitter.send(SseEmitter.event().data(result.replace("\n","\\n")));
                                //将结果汇总起来
                                assistant.put("content",assistant.get("content") + result);
                            }
                        }
                        messages.add(assistant);
                    }
                }
            } else {
                System.out.println("Unexpected code " + response);
            }

        } catch (IOException e) {
            log.error("流式请求出错,断开与{}的连接",clientId);
            //移除当前的连接
            sseEmitterMap.remove(clientId);
            //移除本次对话的内容
            messages.remove(user);
        }
    });
}

至此代码编写完毕,调用的方式也是非常的简单,首先将前端页面打开就可以在控制台的network部分看到连接了,之后我们想要发送问题的时候直接使用PostMan之类的API调试工具发送请求即可,具体操作看下图
image

question部分就是你的问题,clientId就是刚才在前端页面的clientId,表示服务器推送问题答案的服务器,最后在页面中即可看到显示出来的信息。
但是页面中的内容出现的非常突兀,没有一个字一个字的弹出,这个就不是后端的技术问题了,是前端的CSS样式,自行搜索即可。

问题1:如何向前端使用流式返回?

利用到的是SSE中的SseEmitter对象来向前端发送消息,其实流式只是一个称呼而已,最终的实际情况也不过是将消息拆分开来进行发送,这样做的好处就是让用户使用过程中更加的舒服,防止用户没有目的的等待。

问题2:如何分隔两条流式消息?

我们通过对返回数据的观察发现,所有的data之间都使用两个换行符来进行分割的,那我们就可以利用\n\n来进行消息的分割'

问题3:如何保证\n\n一定是两条消息之间的分割而不是内容中的\n\n?

让缓冲区尽量大,保证每次都能将一条data数据读取完毕,这样就能保证\n\n一定是两条消息之间的分隔符,目前也只有这一种方法可行,根据我的测试来看,只要设置1024个字节以上的缓冲区,基本不会存在一次性读取不完数据的情况

问题3补充

由于我们缓冲区设置的足够大,所以我们甚至不需要再对消息进行分割,我们每次缓冲区读取到的数据就一定是独立的一条消息

问题4:发送给前端数据时数据产生截断问题?

SSE协议默认是以两个\n换行符为结束标志,所以在数据中如果有\n\n存在的话就会导致后面的数据被截断,无法收到,解决办法就是再对\n进行一次转义即可

WebSocket方式

待补充

注意事项

本文并没有针对于token进行统计,需要用户自行进行代码来实现token的统计。
上述代码并没有进行抽取合并,大家可以自行抽取,后面我也会将我抽取的代码放在Github上供大家参考。